Spaces:
Running on Zero
Running on Zero
| """ | |
| HY-World-2.0-Demo - WorldMirror 2.0 World Reconstruction | |
| OrangeRed Gradio Theme · gr.Tabs · Rerun for 3D + Gaussian Splats · gr.Gallery Preview | |
| Requires Gradio version: 5.49.1 | |
| """ | |
| import os | |
| os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" | |
| import gc | |
| import json | |
| import shutil | |
| import time | |
| import uuid | |
| from datetime import datetime | |
| from glob import glob | |
| from pathlib import Path | |
| from typing import Iterable, List, Optional | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import rerun as rr | |
| try: | |
| import rerun.blueprint as rrb | |
| except ImportError: | |
| rrb = None | |
| from gradio_rerun import Rerun | |
try:
    import spaces
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

    class _DummyGPU:
        """No-op stand-in used when the ``spaces`` package is not installed."""

        @staticmethod
        def GPU(func=None, *, duration=120, **_options):
            """Mimic the ``spaces.GPU`` decorator without doing anything.

            Works both bare (``@spaces.GPU``) and parameterised
            (``@spaces.GPU(duration=60)``). ``duration`` and any other
            options are accepted and ignored.
            """
            # Bare usage: the decorated function arrives as the first argument.
            if callable(func):
                return func

            def decorator(fn):
                return fn

            return decorator

    # Being a staticmethod, GPU no longer receives the instance as its first
    # argument, so ``spaces.GPU(duration=...)`` works on this object.
    spaces = _DummyGPU()
| try: | |
| from pillow_heif import register_heif_opener | |
| register_heif_opener() | |
| except ImportError: | |
| pass | |
| import gradio as gr | |
| from gradio.themes import Soft | |
| from gradio.themes.utils import colors, fonts, sizes | |
# Register a custom "orange_red" palette on gradio's `colors` module so that
# it can be referenced as `colors.orange_red` (used below as the default
# secondary hue of OrangeRedTheme). Shades run from lightest (c50) to
# darkest (c950).
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)
class OrangeRedTheme(Soft):
    """``Soft`` theme variant with gray primary and orange-red secondary hues.

    All constructor arguments are keyword-only and are forwarded unchanged to
    ``Soft.__init__``; a fixed set of style overrides is applied afterwards.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)

        # Style overrides, grouped by the part of the UI they affect.
        page_overrides = {
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
        }
        primary_button_overrides = {
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_600)",
        }
        secondary_button_overrides = {
            "button_secondary_text_color": "black",
            "button_secondary_text_color_hover": "white",
            "button_secondary_background_fill": "linear-gradient(90deg, *primary_300, *primary_300)",
            "button_secondary_background_fill_hover": "linear-gradient(90deg, *primary_400, *primary_400)",
            "button_secondary_background_fill_dark": "linear-gradient(90deg, *primary_500, *primary_600)",
            "button_secondary_background_fill_hover_dark": "linear-gradient(90deg, *primary_500, *primary_500)",
        }
        widget_overrides = {
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        super().set(
            **page_overrides,
            **primary_button_overrides,
            **secondary_button_overrides,
            **widget_overrides,
        )


# Shared theme instance handed to gr.Blocks.
orange_red_theme = OrangeRedTheme()
# Page-level CSS passed to gr.Blocks: centres the main row and enlarges the title.
css = """
#col-container {
margin: 0 auto;
max-width: 1400px;
}
#main-title h1 { font-size: 2.4em !important; }
"""

# Working directories live next to this file and are created on import.
BASE_DIR = Path(__file__).resolve().parent
OUTPUT_DIR = BASE_DIR.joinpath("outputs")
UPLOAD_DIR = BASE_DIR.joinpath("uploads")
for _workdir in (OUTPUT_DIR, UPLOAD_DIR):
    _workdir.mkdir(exist_ok=True)
del _workdir

# Lazily populated by get_model_and_device().
_model = None
_device = None
def get_model_and_device():
    """Return the process-wide ``(model, device)`` pair, loading it on first use.

    On the first call the ``HY-WorldMirror-2.0`` subfolder of the
    ``tencent/HY-World-2.0`` repo is downloaded, the model is built from
    ``config.yaml`` (preferred) or ``config.json``, and weights from
    ``model.safetensors`` are copied in for every key whose name and shape
    match. The model is then switched to eval mode with gradients disabled.

    The module globals are only assigned once loading has fully succeeded, so
    a failed attempt never leaves a partially initialised model cached.

    Raises:
        FileNotFoundError: if neither config file exists in the model folder.
    """
    global _model, _device

    # Fast path: skip the heavy imports entirely once the model is cached.
    if _model is not None:
        return _model, _device

    import torch
    from safetensors.torch import load_file as load_safetensors
    from omegaconf import OmegaConf
    from huggingface_hub import snapshot_download
    from hyworldmirror.models.models.worldmirror import WorldMirror

    print("[Lazy Load] Loading HY-World-2.0 model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        print(f"[Lazy Load] CUDA: {torch.cuda.get_device_name(torch.cuda.current_device())}")

    model_path = "tencent/HY-World-2.0"
    subfolder = "HY-WorldMirror-2.0"
    repo_root = snapshot_download(
        repo_id=model_path,
        allow_patterns=[f"{subfolder}/**"],
    )
    model_dir = Path(repo_root) / subfolder
    yaml_path = model_dir / "config.yaml"
    json_path = model_dir / "config.json"

    if yaml_path.exists():
        cfg = OmegaConf.load(yaml_path)
        # The model section may be nested under wrapper.model, under model,
        # or be the whole file.
        if hasattr(cfg, "wrapper") and hasattr(cfg.wrapper, "model"):
            model_cfg = cfg.wrapper.model
        elif hasattr(cfg, "model"):
            model_cfg = cfg.model
        else:
            model_cfg = cfg
        model_cfg = OmegaConf.to_container(model_cfg, resolve=True)
        # "_target_" is dropped because it is not a constructor argument.
        model_cfg.pop("_target_", None)
    elif json_path.exists():
        with open(json_path) as f:
            model_cfg = json.load(f)
    else:
        raise FileNotFoundError(f"No config found in {model_dir}")

    model = WorldMirror(**model_cfg).to(device)

    safetensors_path = model_dir / "model.safetensors"
    if safetensors_path.exists():
        state = load_safetensors(str(safetensors_path))
        current = model.state_dict()
        matched = 0
        # Only overwrite entries whose name and shape both match the checkpoint.
        for key in current:
            if key in state and current[key].shape == state[key].shape:
                current[key] = state[key]
                matched += 1
        model.load_state_dict(current, strict=True)
        print(f"[Lazy Load] Loaded {matched}/{len(current)} keys")
        del state
    else:
        print(
            f"[Lazy Load] WARNING: {safetensors_path} not found; "
            "model keeps its freshly initialised weights"
        )

    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    # Publish only after everything above succeeded.
    _model, _device = model, device
    print("[Lazy Load] Model ready")
    return _model, _device
_VIDEO_SUFFIXES = frozenset({".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv"})
_HEIF_SUFFIXES = frozenset({".heic", ".heif"})


def _extract_video_frames(video_path, images_dir, time_interval):
    """Write every sampled frame of ``video_path`` as PNG and return the paths."""
    written = []
    cap = cv2.VideoCapture(str(video_path))
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Fall back to 30 when the container reports no usable frame rate
        # (zero, negative, or NaN).
        if not fps > 0:
            fps = 30
        step = max(1, int(fps * time_interval))
        frame_idx = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # Sample frames 0, step, 2*step, ... so that a clip shorter than
            # one interval still contributes its first frame.
            if frame_idx % step == 0:
                dst = images_dir / f"{video_path.stem}_{len(written):06d}.png"
                if cv2.imwrite(str(dst), frame):
                    written.append(str(dst))
            frame_idx += 1
    finally:
        # Release the decoder even if reading or writing raised.
        cap.release()
    return written


def process_uploaded_files(files, time_interval: float = 1.0):
    """Copy or convert uploaded files into a fresh per-upload image folder.

    Videos are sampled roughly every ``time_interval`` seconds (starting with
    the first frame) and saved as PNG; HEIC/HEIF images are re-encoded as
    JPEG; every other file is copied with its suffix lower-cased. Entries
    whose path does not exist are skipped.

    Args:
        files: iterable of path strings, path-like objects, or objects with a
            ``name`` attribute holding the path. ``None`` is treated as empty.
        time_interval: seconds between sampled video frames.

    Returns:
        ``(target_dir, sorted_image_paths)`` where ``target_dir`` is the new
        upload folder (as ``str``) and the images live in ``target_dir/images``.
    """
    target_dir = UPLOAD_DIR / f"input_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
    images_dir = target_dir / "images"
    images_dir.mkdir(parents=True, exist_ok=True)

    image_paths = []
    for f in files or []:
        src = Path(f) if isinstance(f, (str, os.PathLike)) else Path(f.name)
        if not src.exists():
            continue
        ext = src.suffix.lower()

        if ext in _VIDEO_SUFFIXES:
            image_paths.extend(_extract_video_frames(src, images_dir, time_interval))
        elif ext in _HEIF_SUFFIXES:
            dst = images_dir / f"{src.stem}.jpg"
            with Image.open(str(src)) as img:
                img.convert("RGB").save(dst, "JPEG", quality=95)
            image_paths.append(str(dst))
        else:
            # Lower-case the suffix so files such as "IMG_1.JPG" can be found
            # by lowercase "*.jpg"-style patterns on case-sensitive filesystems.
            dst = images_dir / f"{src.stem}{ext}"
            shutil.copy2(str(src), str(dst))
            image_paths.append(str(dst))

    return str(target_dir), sorted(image_paths)
def render_depth_colormap(depth_map, mask=None):
    """Render a depth map as an RGB ``uint8`` image with the ``turbo_r`` colormap.

    Pixels that are positive (and selected by ``mask`` when given) are rescaled
    using their 5th and 95th percentiles before colouring. When ``mask`` is
    given, pixels outside it are painted white.

    Args:
        depth_map: 2-D array of depth values; it is not modified.
        mask: optional array of the same shape; non-zero/True marks valid pixels.

    Returns:
        ``(H, W, 3)`` ``uint8`` array.
    """
    import matplotlib.pyplot as plt

    depth = np.array(depth_map, copy=True)
    # Integer depth would truncate the normalised [0, 1] values to 0/1.
    if not np.issubdtype(depth.dtype, np.floating):
        depth = depth.astype(np.float32)

    valid = depth > 0
    if mask is not None:
        # Force a boolean mask: on an integer mask "~" is a bitwise NOT and
        # would be used as (wrong) integer indices below.
        mask = np.asarray(mask, dtype=bool)
        valid &= mask

    if valid.any():
        lo, hi = np.percentile(depth[valid], 5), np.percentile(depth[valid], 95)
        # The epsilon avoids division by zero on constant-depth inputs.
        depth[valid] = (depth[valid] - lo) / (hi - lo + 1e-9)

    rgb = (plt.cm.turbo_r(depth)[:, :, :3] * 255).astype(np.uint8)
    if mask is not None:
        rgb[~mask] = [255, 255, 255]
    return rgb
def render_normal_colormap(normal_map, mask=None):
    """Map normal components from ``[-1, 1]`` to ``uint8`` colours in ``[0, 255]``.

    Args:
        normal_map: array of normal vectors; it is not modified.
        mask: optional array matching the leading dimensions of ``normal_map``;
            non-zero/True marks valid pixels. Pixels outside the mask are
            zeroed before mapping, so they come out mid-gray.

    Returns:
        ``uint8`` array with the same shape as ``normal_map``.
    """
    normals = np.array(normal_map, copy=True)
    if mask is not None:
        # Coerce to bool first: "~" on an integer mask is a bitwise NOT, which
        # would then be interpreted as integer indices.
        normals[~np.asarray(mask, dtype=bool)] = 0
    return ((normals + 1.0) * 0.5 * 255).clip(0, 255).astype(np.uint8)
def on_files_uploaded(files, time_interval):
    """Handle an upload (or interval change) and refresh the preview widgets.

    Returns a 3-tuple: gallery items as ``(filepath, caption)`` pairs, the
    upload folder produced by ``process_uploaded_files`` (``None`` when there
    is nothing to process), and a dropdown update listing ``"All"`` followed
    by one index per extracted image.
    """
    if files:
        upload_dir, extracted = process_uploaded_files(files, time_interval)
        # Each gallery entry is captioned with the image's file stem.
        previews = [(path, Path(path).stem) for path in extracted]
        options = ["All", *(str(n) for n in range(len(extracted)))]
        return previews, upload_dir, gr.update(choices=options, value="All")

    # Nothing uploaded: clear the gallery and reset the selector.
    return [], None, gr.update(choices=["All"], value="All")
def _make_rec(app_id: str) -> "rr.RecordingStream":
    """Create a Rerun recording for ``app_id`` with a fresh random recording id.

    Uses ``rr.new_recording`` when the installed ``rerun`` module exposes it
    and falls back to constructing ``rr.RecordingStream`` otherwise.
    """
    recording_id = str(uuid.uuid4())
    factory = rr.new_recording if hasattr(rr, "new_recording") else rr.RecordingStream
    return factory(application_id=app_id, recording_id=recording_id)
def build_reconstruction_rrd(
    output_subdir: Path,
    glb_path: Path,
    world_points: np.ndarray,
    images_np: np.ndarray,
    camera_poses: np.ndarray,
    camera_intrs: Optional[np.ndarray],
    filter_mask: list,
    normals: Optional[np.ndarray],
) -> str:
    """Build and save a Rerun recording of the full scene reconstruction.

    Logs, all as static data under ``world/``: axis arrows, the GLB asset (if
    the file exists), one coloured point cloud per view, one camera transform
    (plus pinhole and image when intrinsics are available) per view, and one
    normal-map image per view. A two-column blueprint is sent when
    ``rerun.blueprint`` could be imported. Individual logging failures are
    printed and do not abort the recording.

    Args:
        output_subdir: directory that receives ``reconstruction.rrd``.
        glb_path: path of the exported GLB scene.
        world_points: per-view points (4-D, or 5-D with a leading axis that is
            dropped) or a single shared point array of lower rank.
        images_np: images indexed as ``[view, channel, row, col]`` with values
            clipped to ``[0, 1]``; a single 3-D image or a 5-D array with a
            leading axis is also accepted.
        camera_poses: per-view 4x4 matrices; rotation and translation are read
            from the top-left 3x3 block and the last column.
        camera_intrs: optional per-view 3x3 intrinsics.
        filter_mask: optional per-view validity masks applied to the points
            and normal maps (only when ``world_points`` is per-view).
        normals: optional per-view normal maps.

    Returns:
        Path of the saved ``.rrd`` file as ``str``.
    """
    rrd_path = str(output_subdir / "reconstruction.rrd")
    rec = _make_rec("HY-World-2.0-Reconstruction")
    rec.log("world", rr.Clear(recursive=True), static=True)
    rec.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, static=True)

    # Axis helpers (best effort, but no longer silent on failure).
    try:
        rec.log("world/axes/x", rr.Arrows3D(vectors=[[0.5, 0, 0]], colors=[[220, 50, 50]]), static=True)
        rec.log("world/axes/y", rr.Arrows3D(vectors=[[0, 0.5, 0]], colors=[[50, 200, 50]]), static=True)
        rec.log("world/axes/z", rr.Arrows3D(vectors=[[0, 0, 0.5]], colors=[[50, 100, 220]]), static=True)
    except Exception as e:
        print(f"[Rerun] Axis helpers failed (non-fatal): {e}")

    # GLB mesh asset
    if glb_path.exists():
        try:
            rec.log("world/scene_mesh", rr.Asset3D(path=str(glb_path)), static=True)
        except Exception as e:
            print(f"[Rerun] GLB log failed: {e}")

    # Normalise ranks so the per-view loop can always index images_np[i]:
    # a lone (C, H, W) image gains a view axis, a 5-D array loses its
    # leading axis (same convention the pipeline uses for predictions).
    if images_np.ndim == 5:
        images_np = images_np[0]
    elif images_np.ndim == 3:
        images_np = images_np[np.newaxis]
    if world_points.ndim == 5:
        world_points = world_points[0]
    points_per_view = world_points.ndim == 4

    for i in range(images_np.shape[0]):
        img_hwc = np.transpose(images_np[i], (1, 2, 0))
        img_u8 = (np.clip(img_hwc, 0, 1) * 255).astype(np.uint8)
        H, W = img_u8.shape[:2]

        # The mask is only meaningful when the points are per-view.
        mask = None
        if (
            points_per_view
            and filter_mask is not None
            and i < len(filter_mask)
            and filter_mask[i] is not None
        ):
            mask = np.asarray(filter_mask[i], dtype=bool)

        # Point cloud
        if not points_per_view:
            pts = world_points
        elif i < len(world_points):
            pts = world_points[i]
        else:
            pts = None

        if pts is not None:
            try:
                if pts.ndim == 3:
                    # Colour each point with the pixel at the same position.
                    pts_f = pts.reshape(-1, 3)
                    col_f = img_u8.reshape(-1, 3)
                    if mask is not None:
                        mf = mask.reshape(-1)
                        pts_f = pts_f[mf]
                        col_f = col_f[mf]
                else:
                    pts_f = pts
                    col_f = None
                if pts_f.shape[0] > 0:
                    rec.log(
                        f"world/point_cloud/view_{i:03d}",
                        rr.Points3D(
                            positions=pts_f.astype(np.float32),
                            colors=col_f,
                            radii=0.003,
                        ),
                        static=True,
                    )
            except Exception as e:
                print(f"[Rerun] Points3D view {i} failed: {e}")

        # Camera
        if camera_poses is not None and i < len(camera_poses):
            c2w = camera_poses[i]
            intr = None
            if camera_intrs is not None and i < len(camera_intrs):
                intr = camera_intrs[i]
            try:
                rec.log(
                    f"world/cameras/cam_{i:03d}",
                    rr.Transform3D(
                        translation=c2w[:3, 3].astype(np.float32),
                        mat3x3=c2w[:3, :3].astype(np.float32),
                    ),
                    static=True,
                )
                if intr is not None:
                    rec.log(
                        f"world/cameras/cam_{i:03d}/pinhole",
                        rr.Pinhole(
                            focal_length=[float(intr[0, 0]), float(intr[1, 1])],
                            principal_point=[float(intr[0, 2]), float(intr[1, 2])],
                            width=W, height=H,
                        ),
                        static=True,
                    )
                    rec.log(
                        f"world/cameras/cam_{i:03d}/pinhole/image",
                        rr.Image(img_u8),
                        static=True,
                    )
            except Exception as e:
                print(f"[Rerun] Camera {i} failed: {e}")

        # Normal map
        if normals is not None and i < len(normals):
            try:
                rec.log(
                    f"world/normals/view_{i:03d}",
                    rr.Image(render_normal_colormap(normals[i], mask)),
                    static=True,
                )
            except Exception as e:
                print(f"[Rerun] Normal {i} failed: {e}")

    # Blueprint: 3D scene on the left, first camera image and first normal
    # map stacked on the right.
    if rrb is not None:
        try:
            bp = rrb.Blueprint(
                rrb.Horizontal(
                    rrb.Spatial3DView(origin="/world", name="3D Scene"),
                    rrb.Vertical(
                        rrb.Spatial2DView(
                            origin="/world/cameras/cam_000/pinhole/image",
                            name="Camera View",
                        ),
                        rrb.Spatial2DView(
                            origin="/world/normals/view_000",
                            name="Normal Map",
                        ),
                    ),
                    column_shares=[3, 1],
                ),
                collapse_panels=True,
            )
            rec.send_blueprint(bp)
        except Exception as e:
            print(f"[Rerun] Reconstruction blueprint failed (non-fatal): {e}")

    rec.save(rrd_path)
    print(f"[Rerun] Reconstruction saved -> {rrd_path}")
    return rrd_path
def build_gaussians_rrd(
    output_subdir: Path,
    means: np.ndarray,
    colors_np: np.ndarray,
    opacities: np.ndarray,
    scales: np.ndarray,
) -> str:
    """Build and save a Rerun recording that shows Gaussian splats as points.

    Each splat becomes one point: its position is the mean with the Y
    component negated, its colour is ``colors_np`` (scaled to ``uint8`` unless
    already ``uint8``) with opacity as alpha, and its radius is half the norm
    of its scale vector clamped to ``[0.001, 0.1]``. At most 2,000,000 points
    are logged; larger inputs are subsampled with a fixed seed.

    Orientation (per the original author's note): the means are assumed to be
    stored with +Y pointing down (OpenCV convention), so Y is flipped to
    appear upright in a Y-up viewer.
    NOTE(review): negating a single axis is a reflection, so the cloud is
    mirrored relative to the input; confirm this is intended rather than a
    combined Y/Z flip or a Y-down view-coordinates setting.

    Args:
        output_subdir: directory that receives ``gaussians.rrd``.
        means: ``(N, 3)`` splat centres; not modified.
        colors_np: ``(N, 3)`` colours, floats in ``[0, 1]`` or ``uint8``.
        opacities: ``N`` opacity values in ``[0, 1]``.
        scales: ``(N, 3)`` per-axis scales.

    Returns:
        Path of the saved ``.rrd`` file as ``str``.
    """
    rrd_path = str(output_subdir / "gaussians.rrd")
    rec = _make_rec("HY-World-2.0-Gaussians")

    # Flip Y on a copy so the caller's array is left untouched.
    means_corrected = means.copy()
    means_corrected[:, 1] = -means_corrected[:, 1]

    rec.log("world", rr.Clear(recursive=True), static=True)
    # Use RIGHT_HAND_Y_UP so +Y is treated as up in the viewer
    rec.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, static=True)

    if colors_np.dtype != np.uint8:
        colors_u8 = (np.clip(colors_np, 0, 1) * 255).astype(np.uint8)
    else:
        colors_u8 = colors_np
    alpha = (np.clip(opacities, 0, 1) * 255).astype(np.uint8).reshape(-1, 1)
    colors_rgba = np.concatenate([colors_u8, alpha], axis=1)
    radii = np.clip(np.linalg.norm(scales, axis=1) * 0.5, 0.001, 0.1).astype(np.float32)

    # Cap the number of logged points; the fixed seed keeps the subsample
    # reproducible between runs.
    MAX_PTS = 2_000_000
    N = means_corrected.shape[0]
    if N > MAX_PTS:
        sel = np.random.default_rng(42).choice(N, size=MAX_PTS, replace=False)
        means_corrected = means_corrected[sel]
        colors_rgba = colors_rgba[sel]
        radii = radii[sel]

    try:
        rec.log(
            "world/gaussian_splats",
            rr.Points3D(
                positions=means_corrected.astype(np.float32),
                colors=colors_rgba,
                radii=radii,
            ),
            static=True,
        )
    except Exception as e:
        print(f"[Rerun] Gaussians Points3D failed: {e}")

    # Axis helpers (best effort, but no longer silent on failure).
    try:
        rec.log("world/axes/x", rr.Arrows3D(vectors=[[0.5, 0, 0]], colors=[[255, 50, 50]]), static=True)
        rec.log("world/axes/y", rr.Arrows3D(vectors=[[0, 0.5, 0]], colors=[[50, 200, 50]]), static=True)
        rec.log("world/axes/z", rr.Arrows3D(vectors=[[0, 0, 0.5]], colors=[[50, 100, 220]]), static=True)
    except Exception as e:
        print(f"[Rerun] Gaussians axis helpers failed (non-fatal): {e}")

    if rrb is not None:
        try:
            bp = rrb.Blueprint(
                rrb.Spatial3DView(origin="/world", name="Gaussian Splats"),
                collapse_panels=True,
            )
            rec.send_blueprint(bp)
        except Exception as e:
            print(f"[Rerun] Gaussians blueprint failed (non-fatal): {e}")

    rec.save(rrd_path)
    print(f"[Rerun] Gaussians saved -> {rrd_path}")
    return rrd_path
def run_reconstruction_pipeline(
    target_dir: str,
    frame_selector: str,
    show_camera: bool,
    filter_sky_bg: bool,
    show_mesh: bool,
    filter_ambiguous: bool,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the WorldMirror pipeline on the frames already stored in ``target_dir``.

    Steps: collect the images in ``target_dir/images``, run the model, compute
    sky/filter masks, export a GLB scene, build the reconstruction Rerun
    recording, render per-view depth and normal PNGs, save camera parameters,
    and, when the model returns splats, write a Gaussian PLY plus its Rerun
    recording. All outputs go into a new ``recon_<id>`` folder under
    ``OUTPUT_DIR``.

    Args:
        target_dir: upload folder produced by ``process_uploaded_files``.
        frame_selector: value forwarded as ``filter_by_frames`` to the GLB
            converter; an empty value is treated as ``"All"``.
        show_camera, filter_sky_bg, show_mesh, filter_ambiguous: forwarded to
            the GLB converter as ``show_camera``, ``mask_sky_bg``, ``as_mesh``
            and ``mask_ambiguous``.
        progress: Gradio progress reporter.

    Returns:
        recon_rrd_path str
        depth_imgs     List[str]
        depth_first    str | None
        normal_imgs    List[str]
        normal_first   str | None
        status         str
        gs_rrd_path    str | None
        gs_ply_path    str | None

    Raises:
        gr.Error: if ``target_dir`` is empty or contains no supported images.
    """
    import torch
    from hyworldmirror.utils.inference_utils import (
        prepare_images_to_tensor,
        compute_adaptive_target_size,
        compute_sky_mask,
        compute_filter_mask,
        _voxel_prune_gaussians,
        depth_to_world_coords_points,
    )
    from hyworldmirror.utils.visual_util import convert_predictions_to_glb_scene
    from hyworldmirror.utils.save_utils import save_camera_params, save_gs_ply

    def _to_numpy(value):
        """Convert a tensor to a float32 NumPy array; pass other values through."""
        if isinstance(value, torch.Tensor):
            # .float() first: NumPy has no bfloat16, which autocast may produce.
            return value.detach().cpu().float().numpy()
        return value

    if not target_dir:
        raise gr.Error("Please upload files first before running reconstruction.")

    # Match suffixes case-insensitively so that e.g. "IMG_1.JPG" is not
    # silently dropped on case-sensitive filesystems.
    image_folder = Path(target_dir) / "images"
    supported = {".png", ".jpg", ".jpeg", ".webp"}
    img_paths = sorted(
        str(p)
        for p in image_folder.glob("*")
        if p.is_file()
        and not p.name.startswith(".")
        and p.suffix.lower() in supported
    )
    if not img_paths:
        raise gr.Error("No valid images found. Please re-upload your files.")

    # A cleared dropdown yields None; treat it like the default selection.
    frame_selector = frame_selector or "All"

    progress(0.10, desc="Loading model...")
    model, device = get_model_and_device()
    effective = compute_adaptive_target_size(img_paths, max_target_size=952)

    progress(0.20, desc="Running inference...")
    with torch.no_grad():
        imgs = prepare_images_to_tensor(img_paths, target_size=effective, resize_strategy="crop").to(device)
        views = {"img": imgs}
        B, S, C, H, W = imgs.shape
        use_amp = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        t0 = time.perf_counter()
        with torch.amp.autocast("cuda", enabled=use_amp, dtype=torch.bfloat16):
            predictions = model(views=views, cond_flags=[0, 0, 0], is_inference=True)
    # Wait for queued GPU work so the measured time covers the whole forward pass.
    if device.type == "cuda":
        torch.cuda.synchronize()
    infer_time = time.perf_counter() - t0

    progress(0.45, desc="Computing masks...")
    sky_mask = compute_sky_mask(
        img_paths, H, W, S, predictions=predictions,
        source="auto", model_threshold=0.45,
        processed_aspect_ratio=W / H,
    )
    filter_mask, gs_filter_mask = compute_filter_mask(
        predictions, imgs, img_paths, H, W, S,
        apply_confidence_mask=False,
        apply_edge_mask=True,
        apply_sky_mask=True,
        confidence_percentile=10.0,
        edge_normal_threshold=1.0,
        edge_depth_threshold=0.03,
        sky_mask=sky_mask,
        use_gs_depth=("gs_depth" in predictions),
    )

    output_id = uuid.uuid4().hex[:8]
    output_subdir = OUTPUT_DIR / f"recon_{output_id}"
    output_subdir.mkdir(exist_ok=True, parents=True)
    safe_name = frame_selector.replace(".", "_").replace(":", "").replace(" ", "_")
    glb_path = output_subdir / f"scene_{safe_name}.glb"

    # Converted once and reused by every consumer below.
    camera_poses_np = predictions["camera_poses"][0].cpu().float().numpy()
    camera_intrs_np = predictions["camera_intrs"][0].cpu().float().numpy()

    # Assemble glb_preds dict
    glb_preds = {}
    if "world_points" not in predictions:
        # No direct point prediction: back-project depth with the cameras.
        glb_preds["world_points"] = depth_to_world_coords_points(
            predictions["depth"][0, ..., 0],
            predictions["camera_poses"][0],
            predictions["camera_intrs"][0],
        )[0].cpu().float().numpy()
    else:
        pts = _to_numpy(predictions["world_points"])
        if pts.ndim == 5:
            pts = pts[0]
        glb_preds["world_points"] = pts

    glb_preds["images"] = _to_numpy(imgs[0])

    c2w = _to_numpy(predictions["camera_poses"])
    if c2w.ndim == 4:
        c2w = c2w[0]
    glb_preds["camera_poses"] = c2w
    glb_preds["final_mask"] = np.array(filter_mask)
    glb_preds["sky_mask"] = np.array(sky_mask)

    # Resolve the normal prediction once so the GLB, the Rerun recording and
    # the PNG renders all use the same key ("normal" wins over "normals").
    normals_np = None
    for key in ("normal", "normals"):
        if key in predictions and predictions[key] is not None:
            n = _to_numpy(predictions[key])
            if n.ndim == 5:
                n = n[0]
            glb_preds["normal"] = n
            normals_np = n
            break

    progress(0.60, desc="Building GLB mesh...")
    glb_scene = convert_predictions_to_glb_scene(
        glb_preds,
        filter_by_frames=frame_selector,
        show_camera=show_camera,
        mask_sky_bg=filter_sky_bg,
        as_mesh=show_mesh,
        mask_ambiguous=filter_ambiguous,
    )
    glb_scene.export(file_obj=str(glb_path))

    progress(0.70, desc="Building Rerun reconstruction recording...")
    recon_rrd_path = build_reconstruction_rrd(
        output_subdir=output_subdir,
        glb_path=glb_path,
        world_points=glb_preds["world_points"],
        images_np=glb_preds["images"],
        camera_poses=glb_preds["camera_poses"],
        camera_intrs=camera_intrs_np,
        filter_mask=list(filter_mask),
        normals=normals_np,
    )

    progress(0.78, desc="Rendering depth and normal maps...")
    depth_imgs, normal_imgs = [], []
    for i in range(S):
        depth = predictions["depth"][0, i].cpu().float().numpy().squeeze()
        mask = filter_mask[i] if i < len(filter_mask) else None

        depth_p = output_subdir / f"depth_{i:03d}.png"
        Image.fromarray(render_depth_colormap(depth, mask)).save(depth_p)
        depth_imgs.append(str(depth_p))

        # Fall back to an all-zero normal map when the model returned none.
        if normals_np is not None and i < len(normals_np):
            normal = normals_np[i]
        else:
            normal = np.zeros((H, W, 3), dtype=np.float32)
        normal_p = output_subdir / f"normal_{i:03d}.png"
        Image.fromarray(render_normal_colormap(normal, mask)).save(normal_p)
        normal_imgs.append(str(normal_p))

    save_camera_params(camera_poses_np, camera_intrs_np, str(output_subdir))

    progress(0.88, desc="Processing Gaussian splats...")
    gs_rrd_path = None
    gs_ply_path = None
    if "splats" in predictions:
        sp = predictions["splats"]
        means = sp["means"][0].reshape(-1, 3).cpu()
        scales = sp["scales"][0].reshape(-1, 3).cpu()
        quats = sp["quats"][0].reshape(-1, 4).cpu()
        # NOTE(review): "sh" is preferred over "colors" and used directly as
        # RGB downstream; confirm it holds plain colours, not SH coefficients.
        colors_t = (sp.get("sh", sp.get("colors"))[0]).reshape(-1, 3).cpu()
        opacities = sp["opacities"][0].reshape(-1).cpu()
        weights = sp["weights"][0].reshape(-1).cpu() if "weights" in sp else torch.ones_like(opacities)

        if gs_filter_mask is not None:
            keep = torch.from_numpy(gs_filter_mask.reshape(-1)).bool()
            means, scales, quats = means[keep], scales[keep], quats[keep]
            colors_t, opacities, weights = colors_t[keep], opacities[keep], weights[keep]

        means, scales, quats, colors_t, opacities = _voxel_prune_gaussians(
            means, scales, quats, colors_t, opacities, weights
        )

        # Cap the exported splat count with a reproducible random subsample.
        if means.shape[0] > 5_000_000:
            idx = torch.from_numpy(
                np.random.default_rng(42).choice(means.shape[0], 5_000_000, replace=False)
            ).long()
            means, scales, quats, colors_t, opacities = (
                means[idx], scales[idx], quats[idx], colors_t[idx], opacities[idx]
            )

        ply_out = str(output_subdir / "gaussians.ply")
        save_gs_ply(ply_out, means, scales, quats, colors_t, opacities)
        gs_ply_path = ply_out

        progress(0.94, desc="Building Rerun recording for Gaussian splats...")
        gs_rrd_path = build_gaussians_rrd(
            output_subdir=output_subdir,
            means=means.float().numpy(),
            colors_np=colors_t.float().numpy(),
            opacities=opacities.float().numpy(),
            scales=scales.float().numpy(),
        )

    # Drop the large tensors before asking CUDA to release cached memory.
    del predictions, glb_preds, imgs
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    status = (
        f"Reconstruction complete — {S} view(s), "
        f"inference time: {infer_time:.2f}s"
    )
    return (
        recon_rrd_path,
        depth_imgs,
        depth_imgs[0] if depth_imgs else None,
        normal_imgs,
        normal_imgs[0] if normal_imgs else None,
        status,
        gs_rrd_path,
        gs_ply_path,
    )
# Gradio UI: controls on the left, result tabs on the right, followed by the
# event wiring. Handler functions keep their original names because they are
# registered as event callbacks.
with gr.Blocks(theme=orange_red_theme, css=css, delete_cache=(600, 600)) as demo:
    gr.Markdown(
        "# **HY-World-2.0-Demo — WorldMirror 2.0 World Reconstruction**",
        elem_id="main-title",
    )
    gr.Markdown(
        "Upload images or video to reconstruct a full 3D scene: "
        "point clouds, camera poses, depth maps, surface normals and Gaussian splats — "
        "powered by [HY-World-2.0 / WorldMirror](https://huggingface.co/papers/2604.14268). "
        "After Reconstruction hold tight, the rendering will happen in fraction of seconds."
    )

    # Persistent state
    target_dir_state = gr.State(None)
    depth_list_state = gr.State([])
    normal_list_state = gr.State([])
    depth_idx_state = gr.State(0)
    normal_idx_state = gr.State(0)

    with gr.Row(elem_id="col-container"):
        with gr.Column(scale=1, min_width=380):
            file_upload = gr.File(
                label="Upload Images or Video",
                file_count="multiple",
                file_types=["image", "video", ".heic", ".heif"],
            )
            # Input frame preview gallery
            input_gallery = gr.Gallery(
                label="Input Frames Preview",
                columns=2,
                rows=2,
                height=220,
                object_fit="cover",
                show_label=True,
                preview=False,
            )
            time_interval = gr.Slider(
                0.1, 10.0, value=1.0, step=0.1,
                label="Video Interval (secs)",
            )
            frame_selector = gr.Dropdown(
                choices=["All"],
                value="All",
                label="Frame Selector",
            )
            with gr.Accordion("Reconstruction Options", open=False):
                show_camera = gr.Checkbox(value=True, label="Show Camera Frustums")
                show_mesh = gr.Checkbox(value=True, label="Show as Mesh (vs Point Cloud)")
                filter_ambiguous = gr.Checkbox(value=True, label="Filter Low Confidence / Edges")
                filter_sky_bg = gr.Checkbox(value=False, label="Filter Sky Background")
            btn_reconstruct = gr.Button("Reconstruct 3D Scene", variant="primary")
            status_box = gr.Textbox(
                label="Status",
                interactive=False,
                lines=2,
                visible=False,
                placeholder="Awaiting reconstruction...",
            )
            with gr.Accordion("Downloads", open=False):
                dl_glb = gr.DownloadButton(label="Download GLB Mesh", variant="secondary")
                dl_gs_ply = gr.DownloadButton(label="Download Gaussian PLY", variant="secondary")

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("3D Reconstruction"):
                    gr.Markdown("### 3D Scene — Rerun Viewer")
                    rerun_recon = Rerun(
                        label="3D Reconstruction — Rerun Viewer",
                        height=630,
                    )
                with gr.Tab("Gaussian Splats"):
                    gr.Markdown("### Gaussian Splat Point Cloud — Rerun Viewer")
                    rerun_gs = Rerun(
                        label="Gaussian Splats — Rerun Viewer",
                        height=630,
                    )
                    gr.Markdown(
                        "The PLY file (downloadable above) can also be opened in "
                        "[SuperSplat](https://supersplat.xyz) or PlayCanvas for full splat rendering."
                    )
                with gr.Tab("Depth Maps"):
                    gr.Markdown("### Per-View Depth Maps")
                    depth_image = gr.Image(
                        label="Depth Map",
                        type="filepath",
                        height=540,
                        interactive=False,
                    )
                    with gr.Row():
                        btn_depth_prev = gr.Button("Previous", variant="secondary", scale=1)
                        depth_counter = gr.Textbox(
                            value="—", label="Frame", interactive=False, scale=1
                        )
                        btn_depth_next = gr.Button("Next", variant="secondary", scale=1)
                with gr.Tab("Normal Maps"):
                    gr.Markdown("### Per-View Surface Normal Maps")
                    normal_image = gr.Image(
                        label="Normal Map",
                        type="filepath",
                        height=540,
                        interactive=False,
                    )
                    with gr.Row():
                        btn_normal_prev = gr.Button("Previous", variant="secondary", scale=1)
                        normal_counter = gr.Textbox(
                            value="—", label="Frame", interactive=False, scale=1
                        )
                        btn_normal_next = gr.Button("Next", variant="secondary", scale=1)

    # Re-extract frames whenever the upload or the sampling interval changes.
    file_upload.upload(
        on_files_uploaded,
        inputs=[file_upload, time_interval],
        outputs=[input_gallery, target_dir_state, frame_selector],
    )
    time_interval.change(
        on_files_uploaded,
        inputs=[file_upload, time_interval],
        outputs=[input_gallery, target_dir_state, frame_selector],
    )

    def reconstruct(
        target_dir,
        frame_selector,
        show_camera,
        filter_sky_bg,
        show_mesh,
        filter_ambiguous,
        progress=gr.Progress(track_tqdm=True),
    ):
        """Run the pipeline and map its results onto the output components.

        The returned tuple follows the ``outputs`` list of
        ``btn_reconstruct.click`` below, in the same order.
        """
        (
            recon_rrd_path,
            depth_list,
            depth_first,
            normal_list,
            normal_first,
            status,
            gs_rrd_path,
            gs_ply_path,
        ) = run_reconstruction_pipeline(
            target_dir,
            frame_selector,
            show_camera,
            filter_sky_bg,
            show_mesh,
            filter_ambiguous,
            progress=progress,
        )

        # The pipeline writes "scene_<selector>.glb" next to the reconstruction
        # recording; look it up so the GLB button serves the GLB file and not
        # the Gaussian PLY.
        glb_file = None
        if recon_rrd_path:
            glb_found = sorted(Path(recon_rrd_path).parent.glob("scene_*.glb"))
            if glb_found:
                glb_file = str(glb_found[0])

        depth_ctr = f"1 / {len(depth_list)}" if depth_list else "—"
        normal_ctr = f"1 / {len(normal_list)}" if normal_list else "—"
        return (
            recon_rrd_path,  # rerun_recon
            gs_rrd_path,     # rerun_gs
            depth_first,     # depth_image
            depth_list,      # depth_list_state
            0,               # depth_idx_state
            depth_ctr,       # depth_counter
            normal_first,    # normal_image
            normal_list,     # normal_list_state
            0,               # normal_idx_state
            normal_ctr,      # normal_counter
            status,          # status_box
            glb_file,        # dl_glb
            gs_ply_path,     # dl_gs_ply
        )

    btn_reconstruct.click(
        reconstruct,
        inputs=[
            target_dir_state, frame_selector,
            show_camera, filter_sky_bg, show_mesh, filter_ambiguous,
        ],
        outputs=[
            rerun_recon,
            rerun_gs,
            depth_image,
            depth_list_state,
            depth_idx_state,
            depth_counter,
            normal_image,
            normal_list_state,
            normal_idx_state,
            normal_counter,
            status_box,
            dl_glb,
            dl_gs_ply,
        ],
    )

    def _step_frame(img_list, idx, delta):
        """Move ``delta`` frames through ``img_list``, clamped to its bounds.

        Returns ``(image_path_or_None, new_index, counter_text)``.
        """
        if not img_list:
            return None, 0, "—"
        # Clamp on both sides so a stale or missing index can never overrun
        # the list.
        new_idx = min(max((idx or 0) + delta, 0), len(img_list) - 1)
        return img_list[new_idx], new_idx, f"{new_idx + 1} / {len(img_list)}"

    # Depth navigation
    def depth_prev(img_list, idx):
        """Show the previous depth map."""
        return _step_frame(img_list, idx, -1)

    def depth_next(img_list, idx):
        """Show the next depth map."""
        return _step_frame(img_list, idx, 1)

    btn_depth_prev.click(
        depth_prev,
        inputs=[depth_list_state, depth_idx_state],
        outputs=[depth_image, depth_idx_state, depth_counter],
    )
    btn_depth_next.click(
        depth_next,
        inputs=[depth_list_state, depth_idx_state],
        outputs=[depth_image, depth_idx_state, depth_counter],
    )

    # Normal navigation
    def normal_prev(img_list, idx):
        """Show the previous normal map."""
        return _step_frame(img_list, idx, -1)

    def normal_next(img_list, idx):
        """Show the next normal map."""
        return _step_frame(img_list, idx, 1)

    btn_normal_prev.click(
        normal_prev,
        inputs=[normal_list_state, normal_idx_state],
        outputs=[normal_image, normal_idx_state, normal_counter],
    )
    btn_normal_next.click(
        normal_next,
        inputs=[normal_list_state, normal_idx_state],
        outputs=[normal_image, normal_idx_state, normal_counter],
    )
if __name__ == "__main__":
    # Directories Gradio is allowed to serve files from.
    served_dirs = [str(folder) for folder in (OUTPUT_DIR, UPLOAD_DIR)]
    launch_options = {
        "mcp_server": True,
        "ssr_mode": False,
        "show_error": True,
        "allowed_paths": served_dirs,
    }
    demo.launch(**launch_options)