prithivMLmods commited on
Commit
fbcc0a9
·
verified ·
1 Parent(s): f90be73

update app --files

Browse files
Files changed (42) hide show
  1. hyworldmirror/__init__.py +0 -0
  2. hyworldmirror/comm/__init__.py +0 -0
  3. hyworldmirror/comm/communication.py +61 -0
  4. hyworldmirror/comm/padding.py +134 -0
  5. hyworldmirror/models/__init__.py +0 -0
  6. hyworldmirror/models/heads/__init__.py +0 -0
  7. hyworldmirror/models/heads/camera_head.py +184 -0
  8. hyworldmirror/models/heads/dense_head.py +672 -0
  9. hyworldmirror/models/heads/gs_head.py +83 -0
  10. hyworldmirror/models/layers/__init__.py +5 -0
  11. hyworldmirror/models/layers/attention.py +131 -0
  12. hyworldmirror/models/layers/block.py +269 -0
  13. hyworldmirror/models/layers/drop_path.py +29 -0
  14. hyworldmirror/models/layers/layer_scale.py +17 -0
  15. hyworldmirror/models/layers/mlp.py +64 -0
  16. hyworldmirror/models/layers/norm_rope.py +140 -0
  17. hyworldmirror/models/layers/patch_embed.py +155 -0
  18. hyworldmirror/models/layers/rope.py +182 -0
  19. hyworldmirror/models/layers/swiglu_ffn.py +46 -0
  20. hyworldmirror/models/layers/vision_transformer.py +394 -0
  21. hyworldmirror/models/models/__init__.py +0 -0
  22. hyworldmirror/models/models/rasterization.py +525 -0
  23. hyworldmirror/models/models/visual_transformer.py +542 -0
  24. hyworldmirror/models/models/worldmirror.py +685 -0
  25. hyworldmirror/models/utils/__init__.py +0 -0
  26. hyworldmirror/models/utils/act_gs.py +22 -0
  27. hyworldmirror/models/utils/camera_utils.py +75 -0
  28. hyworldmirror/models/utils/frustum.py +196 -0
  29. hyworldmirror/models/utils/geometry.py +111 -0
  30. hyworldmirror/models/utils/grid.py +90 -0
  31. hyworldmirror/models/utils/priors.py +168 -0
  32. hyworldmirror/models/utils/rotation.py +126 -0
  33. hyworldmirror/models/utils/sh_utils.py +116 -0
  34. hyworldmirror/utils/__init__.py +0 -0
  35. hyworldmirror/utils/geometry.py +531 -0
  36. hyworldmirror/utils/inference_utils.py +824 -0
  37. hyworldmirror/utils/render_utils.py +294 -0
  38. hyworldmirror/utils/save_utils.py +261 -0
  39. hyworldmirror/utils/video_utils.py +557 -0
  40. hyworldmirror/utils/visual_util.py +617 -0
  41. hyworldmirror/utils/warnings.py +29 -0
  42. pipeline.py +847 -0
hyworldmirror/__init__.py ADDED
File without changes
hyworldmirror/comm/__init__.py ADDED
File without changes
hyworldmirror/comm/communication.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+
4
+ def all2all(tensor, scatter_dim, gather_dim, cur_group, async_op):
5
+ group_size = dist.get_world_size(group=cur_group)
6
+ scatter_tensor_list = list(chunk.contiguous() for chunk in torch.chunk(tensor, chunks=group_size, dim=scatter_dim))
7
+ gather_tensor_list = [torch.zeros_like(x) for x in scatter_tensor_list]
8
+ comm = dist.all_to_all(gather_tensor_list, scatter_tensor_list, group=cur_group, async_op=async_op)
9
+ if async_op:
10
+ def wait():
11
+ comm.wait()
12
+ recieved_tensor = torch.cat(gather_tensor_list, dim=gather_dim).contiguous()
13
+ return recieved_tensor
14
+ return wait()
15
+ recieved_tensor = torch.cat(gather_tensor_list, dim=gather_dim).contiguous()
16
+ return recieved_tensor
17
+
18
+ def all_gather(tensor, gather_dim, cur_group, async_op):
19
+ tensor = tensor.contiguous()
20
+ group_size = dist.get_world_size(group=cur_group)
21
+ gather_list = [torch.zeros_like(tensor) for _ in range(group_size)]
22
+ comm = dist.all_gather(gather_list, tensor, group=cur_group, async_op=async_op)
23
+ gather_tensor = torch.cat(gather_list, dim=gather_dim)
24
+ if async_op:
25
+ def wait():
26
+ comm.wait()
27
+ gather_tensor = torch.cat(gather_list, dim=gather_dim)
28
+ return gather_tensor
29
+ return wait()
30
+ return gather_tensor
31
+
32
+
33
+ class _All2All(torch.autograd.Function):
34
+ @staticmethod
35
+ def forward(ctx, tensor, scatter_dim, gather_dim, cur_group, async_op):
36
+ ctx.cur_group = cur_group
37
+ ctx.scatter_dim = scatter_dim
38
+ ctx.gather_dim = gather_dim
39
+ ctx.async_op = async_op
40
+ return all2all(tensor=tensor, scatter_dim=scatter_dim, gather_dim=gather_dim, cur_group=cur_group, async_op=async_op)
41
+
42
+ @staticmethod
43
+ def backward(ctx, grad_outputs):
44
+ input_t = grad_outputs
45
+ return (all2all(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.cur_group, False), None, None, None, None)
46
+
47
+ class _Allgather(torch.autograd.Function):
48
+ @staticmethod
49
+ def forward(ctx, tensor, gather_dim, cur_group, async_op):
50
+ ctx.gather_dim = gather_dim
51
+ ctx.cur_group = cur_group
52
+ ctx.async_op = async_op
53
+ return all_gather(tensor=tensor, gather_dim=gather_dim, cur_group=cur_group, async_op=async_op)
54
+
55
+ @staticmethod
56
+ def backward(ctx, grad_outputs):
57
+ sp_group = ctx.cur_group
58
+ sp_group_size = dist.get_world_size(group=sp_group)
59
+ rank = dist.get_rank()
60
+ rank_in_group = dist.get_group_rank(group=sp_group, global_rank=rank)
61
+ return (grad_outputs.split(grad_outputs.shape[ctx.gather_dim] // sp_group_size, dim=ctx.gather_dim)[rank_in_group], None, None, None)
hyworldmirror/comm/padding.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ def minimal_pad_to_divisible(tensor: torch.Tensor, sp_size: int, dim: int = 1, pad_value: float = 0.0):
5
+ """
6
+ 对三维或更高维度的tensor在指定维度进行最小化padding,使其长度能被 sp_size 整除。
7
+
8
+ Args:
9
+ tensor: 输入的PyTorch tensor (例如:[B, L, C] 或 [B, H, W, C] 等)。
10
+ sp_size: 要求的最小分割尺寸。
11
+ dim: 需要进行padding的维度索引(默认为 1,即第二维)。
12
+ pad_value: 填充的值(默认为 0.0)。
13
+
14
+ Returns:
15
+ padded_tensor: 填充后的 tensor。
16
+ """
17
+
18
+ current_size = tensor.size(dim)
19
+
20
+ # 计算需要填充的长度
21
+ # (sp_size - current_size % sp_size) % sp_size
22
+ # 保证了如果 current_size 已经是 sp_size 的倍数,padding_len 为 0。
23
+ # 否则,计算出最小的填充长度。
24
+ padding_len = (sp_size - current_size % sp_size) % sp_size
25
+
26
+ if padding_len == 0:
27
+ # 如果长度已经可以整除,直接返回原 tensor
28
+ return tensor, 0
29
+
30
+ # 构建 pad 元组
31
+ # torch.nn.functional.pad 的 pad 参数是从**最末尾的维度**开始,**成对** (后填充, 前填充) 指定的。
32
+ # 假设你的 tensor 是 [D0, D1, D2]
33
+ # 如果 dim=1 (第二维, D1),pad 应该是 (0, 0, padding_len, 0, 0, 0, ...)
34
+ #
35
+ # 由于我们需要在第二维 (dim=1) 的末尾进行填充,我们需要确定 pad 元组中对应 dim=1 的位置。
36
+ # 维度数量 D = tensor.dim()
37
+ # dim=0 对应 pad 元组的最后两位
38
+ # dim=1 对应 pad 元组的倒数第 4, 3 位
39
+ # dim=2 对应 pad 元组的倒数第 6, 5 位 (对于三维 tensor,即前两位)
40
+
41
+ # 在 dim 维度进行 '后填充' (在末尾添加)
42
+ # padding_dims 是一个长度为 2 * D 的元组,所有维度默认不填充
43
+ padding_dims = [0] * (2 * tensor.dim())
44
+
45
+ # 对应 dim 维度的 '后填充' (即 pad 元组中的偶数索引位置,从后往前数)
46
+ # 填充的位置是 (2 * tensor.dim() - 2 * dim - 2)
47
+ # 例如:D=3, dim=1 -> 2*3 - 2*1 - 2 = 2
48
+ # pad 元组为 (d2_start, d2_end, d1_start, d1_end, d0_start, d0_end)
49
+ # 我们要填充 d1_end,它在索引 2 的位置
50
+
51
+ # F.pad 要求的是 (最后维度 start, 最后维度 end, 倒数第二维度 start, 倒数第二维度 end, ...)
52
+ # 我们的 dim=1 是倒数第 (D - 1 - dim) + 1 = D - dim 个维度
53
+ # 它在 pad 元组中是倒数第 2 * (D - dim) 位和倒数第 2 * (D - dim) - 1 位
54
+ #
55
+ # 填充位置的索引 (从 0 开始, 从左往右):
56
+ # (2 * (tensor.dim() - dim - 1)) 是 '前填充' 的位置
57
+ # (2 * (tensor.dim() - dim - 1) + 1) 是 '后填充' 的位置
58
+ pad_index = 2 * (tensor.dim() - dim - 1) + 1
59
+
60
+ if pad_index < len(padding_dims):
61
+ padding_dims[pad_index] = padding_len
62
+ else:
63
+ raise ValueError("Invalid dimension index.")
64
+
65
+ # 转换回 tuple
66
+ pad = tuple(padding_dims)
67
+
68
+ # 使用 F.pad 进行填充,模式为 'constant'
69
+ padded_tensor = F.pad(tensor, pad=pad, mode='constant', value=pad_value)
70
+
71
+ return padded_tensor, padding_len
72
+
73
+
74
+ def depad_by_length(padded_tensor: torch.Tensor, depadding_len: int, dim: int = 1) -> torch.Tensor:
75
+ """
76
+ 在指定维度上去除末尾的 padding 部分。
77
+
78
+ Args:
79
+ padded_tensor: 已经经过 padding 的 PyTorch tensor。
80
+ depadding_len: 需要从末尾去除的长度。
81
+ dim: 需要去除 padding 的维度索引(默认为 1,即第二维)。
82
+
83
+ Returns:
84
+ depadded_tensor: 去除 padding 后的 tensor。
85
+ """
86
+
87
+ # 检查去除长度是否合理
88
+ current_size = padded_tensor.size(dim)
89
+ if depadding_len < 0:
90
+ raise ValueError("depadding_len 必须是非负数。")
91
+ if depadding_len > current_size:
92
+ raise ValueError(f"要去除的长度 {depadding_len} 大于当前维度长度 {current_size}。")
93
+
94
+ # 计算去除 padding 后的目标长度
95
+ target_size = current_size - depadding_len
96
+
97
+ # 构造切片操作所需的索引元组
98
+ # 对于所有维度,我们默认使用完整的切片 `:`
99
+ slices = [slice(None)] * padded_tensor.dim()
100
+
101
+ # 在指定维度 dim 上,我们只取从 0 到 target_size 的部分
102
+ # Python 切片 [0:target_size] 会保留 target_size 个元素,即去除了末尾的 depadding_len
103
+ slices[dim] = slice(0, target_size)
104
+
105
+ # 使用元组解包进行切片操作
106
+ depadded_tensor = padded_tensor[tuple(slices)]
107
+
108
+ return depadded_tensor
109
+
110
+
111
+
112
+
113
+ def pad_by_length(padded_tensor: torch.Tensor, padding_len: int, dim: int = 1,pad_value: float = 0.0) -> torch.Tensor:
114
+
115
+ if padding_len < 0:
116
+ raise ValueError("padding_len 必须是非负数。")
117
+
118
+ if dim < 0 or dim >= padded_tensor.dim():
119
+ raise ValueError(f"维度索引 {dim} 超出有效范围 [0, {padded_tensor.dim() - 1}]。")
120
+
121
+ # 构建padding参数
122
+ # F.pad需要为每个维度指定左右两边的padding长度
123
+ # 格式为: (最后一个维度的左边, 最后一个维度的右边, 倒数第二个维度的左边, 倒数第二个维度的右边, ...)
124
+ pad_tuple = [0] * (2 * padded_tensor.dim())
125
+
126
+ # 将指定维度右边的padding长度设置为padding_len
127
+ # F.pad的维度顺序是从最后一个维度开始的,所以需要进行转换
128
+ pad_idx = 2 * (padded_tensor.dim() - 1 - dim) + 1
129
+ pad_tuple[pad_idx] = padding_len
130
+
131
+ # 调用F.pad进行padding
132
+ padded_tensor = F.pad(padded_tensor, pad=tuple(pad_tuple), mode='constant', value=pad_value)
133
+
134
+ return padded_tensor
hyworldmirror/models/__init__.py ADDED
File without changes
hyworldmirror/models/heads/__init__.py ADDED
File without changes
hyworldmirror/models/heads/camera_head.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inspired by https://github.com/facebookresearch/vggt/blob/main/src/models/heads/camera_head.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from ..layers import Mlp, MlpFP32
7
+ from ..layers.block import Block, DistBlock
8
+
9
+
10
+ class CameraHead(nn.Module):
11
+ """
12
+ Camera head module: predicts camera parameters from token representations using iterative refinement
13
+
14
+ Processes dedicated camera tokens through a series of transformer blocks
15
+ """
16
+ def __init__(
17
+ self,
18
+ dim_in: int = 2048,
19
+ trunk_depth: int = 4,
20
+ num_heads: int = 16,
21
+ mlp_ratio: int = 4,
22
+ init_values: float = 0.01,
23
+ trans_act: str = "linear",
24
+ quat_act: str = "linear",
25
+ fl_act: str = "relu",
26
+ block_fn: nn.Module = Block,
27
+ ):
28
+ super().__init__()
29
+
30
+ self.out_dim = 9
31
+ self.trans_act = trans_act
32
+ self.quat_act = quat_act
33
+ self.fl_act = fl_act
34
+ self.depth = trunk_depth
35
+
36
+ # Build refinement network using transformer block sequence
37
+ self.refine_net = nn.Sequential(
38
+ *[
39
+ block_fn(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
40
+ for _ in range(trunk_depth)
41
+ ]
42
+ )
43
+
44
+ # Normalization for camera tokens and network output
45
+ self.token_norm = nn.LayerNorm(dim_in)
46
+ self.out_norm = nn.LayerNorm(dim_in)
47
+
48
+ # Learnable initial camera parameter token
49
+ self.init_token = nn.Parameter(torch.zeros(1, 1, self.out_dim))
50
+ self.param_embed = nn.Linear(self.out_dim, dim_in)
51
+
52
+ # Generate adaptive normalization parameters: shift, scale, and gate
53
+ self.adapt_norm_gen = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
54
+
55
+ # Adaptive layer normalization (no learnable parameters)
56
+ self.adapt_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
57
+ # self.param_predictor = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.out_dim, drop=0)
58
+ self.param_predictor = MlpFP32(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.out_dim, drop=0)
59
+
60
+ def to(self, *args, **kwargs):
61
+ self.refine_net = self.refine_net.to(*args, **kwargs)
62
+ self.token_norm = self.token_norm.to(*args, **kwargs)
63
+ self.out_norm = self.out_norm.to(*args, **kwargs)
64
+ self.adapt_norm_gen = self.adapt_norm_gen.to(*args, **kwargs)
65
+ self.adapt_norm = self.adapt_norm.to(*args, **kwargs)
66
+ self.param_predictor = self.param_predictor.to(*args, **kwargs)
67
+
68
+ # keep these parameters in FP32
69
+ args, kwargs = MlpFP32.map_to_args_to_float(args, kwargs)
70
+ self.init_token = nn.Parameter(self.init_token.to(*args, **kwargs))
71
+ self.param_embed = self.param_embed.to(*args, **kwargs)
72
+
73
+ return self
74
+
75
+ def forward(self, feat_seq: list, steps: int = 4) -> list:
76
+ """
77
+ Forward pass to predict camera parameters
78
+
79
+ Args:
80
+ feat_seq: List of token tensors from network, last one used for prediction
81
+ steps: Number of iterative refinement steps, default 4
82
+
83
+ Returns:
84
+ List of predicted camera encodings (post-activation) from each iteration
85
+ """
86
+ # Use tokens from last block for camera prediction
87
+ latest_feat = feat_seq[-1]
88
+
89
+ # Extract camera tokens
90
+ cam_tokens = latest_feat[:, :, 0]
91
+ cam_tokens = self.token_norm(cam_tokens)
92
+
93
+ # Iteratively refine camera pose predictions
94
+ b, seq_len, feat_dim = cam_tokens.shape # seq_len expected to be 1
95
+ curr_pred = None
96
+ pred_seq = []
97
+
98
+ for step in range(steps):
99
+ # Use learned initial token for first iteration
100
+ if curr_pred is None:
101
+ net_input = self.param_embed(self.init_token.expand(b, seq_len, -1))
102
+ else:
103
+ curr_pred = curr_pred.detach()
104
+ net_input = self.param_embed(curr_pred)
105
+ net_input = net_input.to(cam_tokens.dtype)
106
+ norm_shift, norm_scale, norm_gate = self.adapt_norm_gen(net_input).chunk(3, dim=-1)
107
+ mod_cam_feat = norm_gate * self.apply_adaptive_modulation(self.adapt_norm(cam_tokens), norm_shift, norm_scale)
108
+ mod_cam_feat = mod_cam_feat + cam_tokens
109
+
110
+ proc_feat = self.refine_net(mod_cam_feat)
111
+ param_delta = self.param_predictor(self.out_norm(proc_feat))
112
+
113
+ if curr_pred is None:
114
+ curr_pred = param_delta
115
+ else:
116
+ curr_pred = curr_pred + param_delta
117
+
118
+ # Apply final activation functions for translation, quaternion, and field-of-view
119
+ activated_params = self.apply_camera_parameter_activation(curr_pred)
120
+ pred_seq.append(activated_params)
121
+
122
+ return pred_seq
123
+
124
+ def apply_camera_parameter_activation(self, params: torch.Tensor) -> torch.Tensor:
125
+ """
126
+ Apply activation functions to camera parameter components
127
+
128
+ Args:
129
+ params: Tensor containing camera parameters [translation, quaternion, focal_length]
130
+
131
+ Returns:
132
+ Activated camera parameters tensor
133
+ """
134
+ trans_vec = params[..., :3]
135
+ quat_vec = params[..., 3:7]
136
+ fl_vec = params[..., 7:] # or field of view
137
+
138
+ trans_vec = self.apply_parameter_activation(trans_vec, self.trans_act)
139
+ quat_vec = self.apply_parameter_activation(quat_vec, self.quat_act)
140
+ fl_vec = self.apply_parameter_activation(fl_vec, self.fl_act)
141
+
142
+ activated_params = torch.cat([trans_vec, quat_vec, fl_vec], dim=-1)
143
+ return activated_params
144
+
145
+ def apply_parameter_activation(self, tensor: torch.Tensor, act_type: str) -> torch.Tensor:
146
+ """
147
+ Apply specified activation function to parameter tensor
148
+
149
+ Args:
150
+ tensor: Tensor containing parameter values
151
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
152
+
153
+ Returns:
154
+ Activated parameter tensor
155
+ """
156
+ if act_type == "linear":
157
+ return tensor
158
+ elif act_type == "inv_log":
159
+ return self.apply_inverse_logarithm_transform(tensor)
160
+ elif act_type == "exp":
161
+ return torch.exp(tensor)
162
+ elif act_type == "relu":
163
+ return F.relu(tensor)
164
+ else:
165
+ raise ValueError(f"Unknown activation_type: {act_type}")
166
+
167
+ def apply_inverse_logarithm_transform(self, x: torch.Tensor) -> torch.Tensor:
168
+ """
169
+ Apply inverse logarithm transform: sign(y) * (exp(|y|) - 1)
170
+
171
+ Args:
172
+ x: Input tensor
173
+
174
+ Returns:
175
+ Transformed tensor
176
+ """
177
+ return torch.sign(x) * (torch.expm1(torch.abs(x)))
178
+
179
+ def apply_adaptive_modulation(self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
180
+ """
181
+ Apply adaptive modulation to input tensor using scaling and shifting parameters
182
+ """
183
+ # Modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
184
+ return x * (1 + scale) + shift
hyworldmirror/models/heads/dense_head.py ADDED
@@ -0,0 +1,672 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inspired by https://github.com/DepthAnything/Depth-Anything-V2
2
+ from typing import List, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.utils.checkpoint import checkpoint
8
+
9
+ from ..layers.mlp import MlpFP32
10
+ from ..utils.grid import create_uv_grid, position_grid_to_embed
11
+
12
+
13
+ class _BaseDPTHead(nn.Module):
14
+ """Base class with shared DPT feature extraction: projects, resize, scratch, and fusion."""
15
+
16
+ def __init__(
17
+ self,
18
+ dim_in: int,
19
+ patch_size: int = 14,
20
+ features: int = 256,
21
+ out_channels: List[int] = [256, 512, 1024, 1024],
22
+ pos_embed: bool = True,
23
+ down_ratio: int = 1,
24
+ gradient_checkpoint: bool = False,
25
+ _cast_pos_embed_dtype: bool = True,
26
+ ) -> None:
27
+ super().__init__()
28
+ self.patch_size = patch_size
29
+ self.pos_embed = pos_embed
30
+ self.down_ratio = down_ratio
31
+ self.gradient_checkpoint = gradient_checkpoint
32
+ self._cast_pos_embed_dtype = _cast_pos_embed_dtype
33
+
34
+ self.norm = nn.LayerNorm(dim_in)
35
+ self.projects = nn.ModuleList([
36
+ nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0)
37
+ for oc in out_channels
38
+ ])
39
+ self.resize_layers = nn.ModuleList([
40
+ nn.ConvTranspose2d(
41
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
42
+ ),
43
+ nn.ConvTranspose2d(
44
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
45
+ ),
46
+ nn.Identity(),
47
+ nn.Conv2d(
48
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
49
+ ),
50
+ ])
51
+ self.scratch = _make_scratch(out_channels, features, expand=False)
52
+ self.scratch.stem_transpose = None
53
+ self.scratch.refinenet1 = _make_fusion_block(features)
54
+ self.scratch.refinenet2 = _make_fusion_block(features)
55
+ self.scratch.refinenet3 = _make_fusion_block(features)
56
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
57
+
58
+ head_features_1 = features
59
+ self.scratch.output_conv1 = nn.Conv2d(
60
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
61
+ )
62
+
63
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
64
+ patch_w = x.shape[-1]
65
+ patch_h = x.shape[-2]
66
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
67
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
68
+ pos_embed = pos_embed * ratio
69
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
70
+ if self._cast_pos_embed_dtype:
71
+ pos_embed = pos_embed.to(x.dtype)
72
+ return x + pos_embed
73
+
74
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
75
+ layer_1, layer_2, layer_3, layer_4 = features
76
+
77
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
78
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
79
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
80
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
81
+
82
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
83
+ del layer_4_rn, layer_4
84
+
85
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
86
+ del layer_3_rn, layer_3
87
+
88
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
89
+ del layer_2_rn, layer_2
90
+
91
+ out = self.scratch.refinenet1(out, layer_1_rn)
92
+ del layer_1_rn, layer_1
93
+
94
+ out = self.scratch.output_conv1(out)
95
+ return out
96
+
97
+ def _extract_fused_features(
98
+ self,
99
+ token_list: List[torch.Tensor],
100
+ B: int,
101
+ S: int,
102
+ H: int,
103
+ W: int,
104
+ patch_start_idx: int,
105
+ frame_start: int = None,
106
+ frame_end: int = None,
107
+ ) -> torch.Tensor:
108
+ """Extract multi-scale features from tokens, fuse via scratch network, and upsample."""
109
+ ph = H // self.patch_size
110
+ pw = W // self.patch_size
111
+
112
+ feats = []
113
+ for proj, resize, tokens in zip(self.projects, self.resize_layers, token_list):
114
+ patch_tokens = tokens[:, :, patch_start_idx:]
115
+ if frame_start is not None and frame_end is not None:
116
+ patch_tokens = patch_tokens[:, frame_start:frame_end]
117
+
118
+ patch_tokens = patch_tokens.reshape(B * S, -1, patch_tokens.shape[-1])
119
+ patch_tokens = self.norm(patch_tokens)
120
+
121
+ feat = patch_tokens.permute(0, 2, 1).reshape(B * S, patch_tokens.shape[-1], ph, pw)
122
+ feat = proj(feat)
123
+
124
+ if self.pos_embed:
125
+ feat = self._apply_pos_embed(feat, W, H)
126
+ feat = resize(feat)
127
+ feats.append(feat)
128
+
129
+ fused = checkpoint(self.scratch_forward, feats, use_reentrant=False) if self.gradient_checkpoint else self.scratch_forward(feats)
130
+ _interpolate_fn = lambda t: custom_interpolate(
131
+ t,
132
+ size=(
133
+ int(ph * self.patch_size / self.down_ratio),
134
+ int(pw * self.patch_size / self.down_ratio)
135
+ ),
136
+ mode="bilinear",
137
+ align_corners=True,
138
+ )
139
+ fused = checkpoint(_interpolate_fn, fused, use_reentrant=False) if self.gradient_checkpoint else _interpolate_fn(fused)
140
+
141
+ if self.pos_embed:
142
+ fused = self._apply_pos_embed(fused, W, H)
143
+
144
+ return fused
145
+
146
+
147
+ class DPTHead(_BaseDPTHead):
148
+ """
149
+ # DPT Head for dense prediction tasks.
150
+
151
+ # This module implements the DPT (Dense Prediction Transformer) head as proposed in
152
+ # "Vision Transformers for Dense Prediction" (https://arxiv.org/abs/2103.13413).
153
+ # It takes features from a vision transformer backbone and generates dense (per-pixel) predictions
154
+ # by fusing multi-scale features through a series of projection, upsampling, and refinement blocks.
155
+
156
+ # Args:
157
+ # dim_in (int): Number of input feature channels.
158
+ # patch_size (int, optional): Patch size used by the backbone, default is 14.
159
+ # output_dim (int, optional): Number of output channels, default is 4.
160
+ # activation (str, optional): Activation function type for the output head, default is "inv_log".
161
+ # conf_activation (str, optional): Activation function type for the confidence/output uncertainty head, default is "expp1".
162
+ # features (int, optional): Number of channels used in intermediate feature representations, default is 256.
163
+ # out_channels (List[int], optional): Number of channels for each intermediate multi-scale feature.
164
+ # intermediate_layer_idx (List[int], optional): Indices specifying which backbone layers to use for multi-scale fusion.
165
+ # pos_embed (bool, optional): Whether to add positional encoding to the features, default is True.
166
+ # feature_only (bool, optional): If True, only return intermediate features (skip final prediction and activations).
167
+ # down_ratio (int, optional): Downsampling ratio of the output predictions, default is 1 (no downsampling).
168
+ """
169
+
170
+ def __init__(
171
+ self,
172
+ dim_in: int,
173
+ patch_size: int = 14,
174
+ output_dim: int = 4,
175
+ activation: str = "inv_log+expp1",
176
+ features: int = 256,
177
+ out_channels: List[int] = [256, 512, 1024, 1024],
178
+ pos_embed: bool = True,
179
+ down_ratio: int = 1,
180
+ is_gsdpt: bool = False,
181
+ enable_depth_mask: bool = False,
182
+ gradient_checkpoint: bool = False,
183
+ ) -> None:
184
+ super().__init__(
185
+ dim_in=dim_in, patch_size=patch_size, features=features,
186
+ out_channels=out_channels, pos_embed=pos_embed,
187
+ down_ratio=down_ratio, gradient_checkpoint=gradient_checkpoint,
188
+ )
189
+ self.activation = activation
190
+ self.is_gsdpt = is_gsdpt
191
+ self.enable_depth_mask = enable_depth_mask
192
+
193
+ head_features_2 = 32
194
+ conv2_in_channels = features // 2
195
+
196
+ self.scratch.output_conv2 = nn.Sequential(
197
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
198
+ nn.ReLU(inplace=True),
199
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
200
+ )
201
+ if self.is_gsdpt:
202
+ self.input_merger = nn.Sequential(
203
+ nn.Conv2d(3, conv2_in_channels, 7, 1, 3),
204
+ nn.ReLU()
205
+ )
206
+
207
+ def to(self, *args, **kwargs):
208
+ self.norm = self.norm.to(*args, **kwargs)
209
+ self.projects = self.projects.to(*args, **kwargs)
210
+ self.resize_layers = self.resize_layers.to(*args, **kwargs)
211
+ if self.is_gsdpt:
212
+ self.input_merger = self.input_merger.to(*args, **kwargs)
213
+ for key in ('layer1_rn', 'layer2_rn', 'layer3_rn', 'layer4_rn',
214
+ 'refinenet1', 'refinenet2', 'refinenet3', 'refinenet4',
215
+ 'output_conv1'):
216
+ if not hasattr(self.scratch, key):
217
+ continue
218
+ setattr(self.scratch, key, getattr(self.scratch, key).to(*args, **kwargs))
219
+
220
+ # keep output_conv2 in FP32
221
+ args, kwargs = MlpFP32.map_to_args_to_float(args, kwargs)
222
+ self.scratch.output_conv2 = self.scratch.output_conv2.to(*args, **kwargs)
223
+
224
+ return self
225
+
226
+ def forward(
227
+ self,
228
+ token_list: List[torch.Tensor],
229
+ images: torch.Tensor,
230
+ patch_start_idx: int,
231
+ frames_chunk_size: int = 8,
232
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
233
+ """
234
+ Forward pass with optional frame chunking for memory efficiency.
235
+
236
+ Args:
237
+ token_list: List of token tensors from transformer, each [B, N, C]
238
+ images: Input images [B, S, 3, H, W], range [0, 1]
239
+ patch_start_idx: Starting index of patch tokens
240
+ frames_chunk_size: Number of frames per chunk. If None or >= S, process all at once
241
+ gradient_checkpoint: Whether to use gradient checkpointing
242
+
243
+ Returns:
244
+ For is_gsdpt: predictions [B, S, ...]
245
+ Otherwise: (predictions, confidence), [B, S, X, H, W] and [B, S, 1, H, W]
246
+ """
247
+ B, S, _, H, W = images.shape
248
+
249
+ # Process all frames together if chunk size not specified or large enough
250
+ if frames_chunk_size is None or frames_chunk_size >= S:
251
+ return self._forward_impl(token_list, images, patch_start_idx)
252
+
253
+ assert frames_chunk_size > 0
254
+
255
+ # Process frames in chunks
256
+ preds_chunks = []
257
+ conf_chunks = []
258
+ gs_chunks = []
259
+ depth_mask_chunks = []
260
+
261
+ for frame_start in range(0, S, frames_chunk_size):
262
+ frame_end = min(frame_start + frames_chunk_size, S)
263
+
264
+ if self.is_gsdpt:
265
+ if self.enable_depth_mask:
266
+ gs, preds, conf, depth_mask = self._forward_impl(
267
+ token_list, images, patch_start_idx, frame_start, frame_end
268
+ )
269
+ gs_chunks.append(gs)
270
+ preds_chunks.append(preds)
271
+ conf_chunks.append(conf)
272
+ depth_mask_chunks.append(depth_mask)
273
+ else:
274
+ gs, preds, conf = self._forward_impl(
275
+ token_list, images, patch_start_idx, frame_start, frame_end
276
+ )
277
+ gs_chunks.append(gs)
278
+ preds_chunks.append(preds)
279
+ conf_chunks.append(conf)
280
+ else:
281
+ if self.enable_depth_mask:
282
+ preds, conf, depth_mask = self._forward_impl(
283
+ token_list, images, patch_start_idx, frame_start, frame_end
284
+ )
285
+ preds_chunks.append(preds)
286
+ conf_chunks.append(conf)
287
+ depth_mask_chunks.append(depth_mask)
288
+ else:
289
+ preds, conf = self._forward_impl(
290
+ token_list, images, patch_start_idx, frame_start, frame_end
291
+ )
292
+ preds_chunks.append(preds)
293
+ conf_chunks.append(conf)
294
+
295
+ # Concatenate chunks along frame dimension
296
+ if self.is_gsdpt:
297
+ if self.enable_depth_mask:
298
+ return (
299
+ torch.cat(gs_chunks, dim=1),
300
+ torch.cat(preds_chunks, dim=1),
301
+ torch.cat(conf_chunks, dim=1),
302
+ torch.cat(depth_mask_chunks, dim=1),
303
+ )
304
+ return torch.cat(gs_chunks, dim=1), torch.cat(preds_chunks, dim=1), torch.cat(conf_chunks, dim=1)
305
+ else:
306
+ if self.enable_depth_mask:
307
+ return torch.cat(preds_chunks, dim=1), torch.cat(conf_chunks, dim=1), torch.cat(depth_mask_chunks, dim=1)
308
+ else:
309
+ return torch.cat(preds_chunks, dim=1), torch.cat(conf_chunks, dim=1)
310
+
311
+ def _forward_impl(
312
+ self,
313
+ token_list: List[torch.Tensor],
314
+ images: torch.Tensor,
315
+ patch_start_idx: int,
316
+ frame_start: int = None,
317
+ frame_end: int = None,
318
+ ) -> torch.Tensor:
319
+ """
320
+ Core forward implementation for DPT head.
321
+
322
+ Args:
323
+ token_list: List of transformer tokens from each layer, [B, S, N, C]
324
+ images: Input images [B, S, 3, H, W]
325
+ patch_start_idx: Starting index of patch tokens
326
+ frame_start: Start index for frame chunking (optional)
327
+ frame_end: End index for frame chunking (optional)
328
+
329
+ Returns:
330
+ If is_gsdpt: (features, preds, conf)
331
+ Else: (preds, conf)
332
+ """
333
+ if frame_start is not None and frame_end is not None:
334
+ images = images[:, frame_start:frame_end].contiguous()
335
+
336
+ B, S, _, H, W = images.shape
337
+
338
+ fused = self._extract_fused_features(token_list, B, S, H, W, patch_start_idx, frame_start, frame_end)
339
+
340
+ # Generate predictions and confidence
341
+ if self.is_gsdpt:
342
+ out = self.scratch.output_conv2(fused.float().contiguous())
343
+ if self.enable_depth_mask:
344
+ preds, conf, depth_mask = self.activate_head(out, activation=self.activation)
345
+ else:
346
+ preds, conf = self.activate_head(out, activation=self.activation)
347
+ preds = preds.reshape(B, S, *preds.shape[1:])
348
+ conf = conf.reshape(B, S, *conf.shape[1:])
349
+
350
+ # Merge direct image features
351
+ img_flat = images.reshape(B * S, -1, H, W)
352
+ img_feat = self.input_merger(img_flat)
353
+ fused = fused + img_feat
354
+ fused = fused.reshape(B, S, *fused.shape[1:]).float().contiguous()
355
+ if self.enable_depth_mask:
356
+ depth_mask = depth_mask.reshape(B, S, *depth_mask.shape[1:])
357
+ return fused, preds, conf, depth_mask
358
+ return fused, preds, conf
359
+ else:
360
+ out = self.scratch.output_conv2(fused.float().contiguous())
361
+ if self.enable_depth_mask:
362
+ preds, conf, depth_mask = self.activate_head(out, activation=self.activation)
363
+ preds = preds.reshape(B, S, *preds.shape[1:])
364
+ conf = conf.reshape(B, S, *conf.shape[1:])
365
+ depth_mask = depth_mask.reshape(B, S, *depth_mask.shape[1:])
366
+ return preds, conf, depth_mask
367
+ else:
368
+ preds, conf = self.activate_head(out, activation=self.activation)
369
+ preds = preds.reshape(B, S, *preds.shape[1:])
370
+ conf = conf.reshape(B, S, *conf.shape[1:])
371
+ return preds, conf
372
+
373
+ def activate_head(self, out_head: torch.Tensor, activation: str = "inv_log+expp1") -> Tuple[torch.Tensor, torch.Tensor]:
374
+ """
375
+ Process network output to extract attribute (e.g. points, depth, etc.) and confidence values.
376
+
377
+ Args:
378
+ out_head: Network output tensor (B, C, H, W)
379
+ activation: Activation type for processing (e.g., "inv_log+expp1")
380
+
381
+ Returns:
382
+ Tuple of (attribute tensor, confidence tensor)
383
+ """
384
+ # Parse activation string
385
+ if self.enable_depth_mask:
386
+ act_attr, act_conf, act_depth_mask = (activation.split("+") if "+" in activation else (activation, "expp1", "linear"))
387
+
388
+ # (B,C,H,W) -> (B,H,W,C)
389
+ feat = out_head.permute(0, 2, 3, 1)
390
+ attr, conf, depth_mask = feat[..., :-2], feat[..., -2], feat[..., -1]
391
+ else:
392
+ act_attr, act_conf = (activation.split("+") if "+" in activation else (activation, "expp1"))
393
+
394
+ # (B,C,H,W) -> (B,H,W,C)
395
+ feat = out_head.permute(0, 2, 3, 1)
396
+ attr, conf = feat[..., :-1], feat[..., -1]
397
+
398
+ # Map point activations to lambdas for clarity and conciseness
399
+ attr_activations = {
400
+ "norm_exp": lambda x: (x / x.norm(dim=-1, keepdim=True).clamp(min=1e-8)) * torch.expm1(x.norm(dim=-1, keepdim=True)),
401
+ "norm": lambda x: x / x.norm(dim=-1, keepdim=True),
402
+ "exp": torch.exp,
403
+ "relu": F.relu,
404
+ "inv_log": self._apply_inverse_log_transform,
405
+ "xy_inv_log": lambda x: torch.cat([
406
+ x[..., :2] * self._apply_inverse_log_transform(x[..., 2:]),
407
+ self._apply_inverse_log_transform(x[..., 2:])
408
+ ], dim=-1),
409
+ "sigmoid": torch.sigmoid,
410
+ "linear": lambda x: x
411
+ }
412
+
413
+ if act_attr not in attr_activations:
414
+ raise ValueError(f"Unknown attribute activation: {act_attr}")
415
+ attr_out = attr_activations[act_attr](attr)
416
+
417
+ # Confidence activation mapping
418
+ conf_activations = {
419
+ "expp1": lambda c: 1 + c.exp(),
420
+ "expp0": torch.exp,
421
+ "sigmoid": torch.sigmoid
422
+ }
423
+ if act_conf not in conf_activations:
424
+ raise ValueError(f"Unknown confidence activation: {act_conf}")
425
+ conf_out = conf_activations[act_conf](conf)
426
+
427
+ if self.enable_depth_mask:
428
+ depth_mask_activations = {
429
+ "sigmoid": torch.sigmoid,
430
+ "linear": lambda x: x,
431
+ }
432
+ if act_depth_mask not in depth_mask_activations:
433
+ raise ValueError(f"Unknown depth mask activation: {act_depth_mask}")
434
+ depth_mask_out = depth_mask_activations[act_depth_mask](depth_mask)
435
+ return attr_out, conf_out, depth_mask_out
436
+ else:
437
+ return attr_out, conf_out
438
+
439
+ def _apply_inverse_log_transform(self, input_tensor: torch.Tensor) -> torch.Tensor:
440
+ """
441
+ Apply inverse logarithm transform: sign(y) * (exp(|y|) - 1)
442
+
443
+ Args:
444
+ input_tensor: Input tensor
445
+
446
+ Returns:
447
+ Transformed tensor
448
+ """
449
+ return torch.sign(input_tensor) * (torch.expm1(torch.abs(input_tensor)))
450
+
451
+
452
+
453
+ ################################################################################
454
+ # DPT Modules
455
+ ################################################################################
456
+
457
+
458
+ def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
459
+ return FeatureFusionBlock(
460
+ features,
461
+ nn.ReLU(inplace=True),
462
+ deconv=False,
463
+ bn=False,
464
+ expand=False,
465
+ align_corners=True,
466
+ size=size,
467
+ has_residual=has_residual,
468
+ groups=groups,
469
+ )
470
+
471
+
472
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
473
+ scratch = nn.Module()
474
+ out_shape1 = out_shape
475
+ out_shape2 = out_shape
476
+ out_shape3 = out_shape
477
+ if len(in_shape) >= 4:
478
+ out_shape4 = out_shape
479
+
480
+ if expand:
481
+ out_shape1 = out_shape
482
+ out_shape2 = out_shape * 2
483
+ out_shape3 = out_shape * 4
484
+ if len(in_shape) >= 4:
485
+ out_shape4 = out_shape * 8
486
+
487
+ scratch.layer1_rn = nn.Conv2d(
488
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
489
+ )
490
+ scratch.layer2_rn = nn.Conv2d(
491
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
492
+ )
493
+ scratch.layer3_rn = nn.Conv2d(
494
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
495
+ )
496
+ if len(in_shape) >= 4:
497
+ scratch.layer4_rn = nn.Conv2d(
498
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
499
+ )
500
+ return scratch
501
+
502
+
503
+ class ResidualConvUnit(nn.Module):
504
+ """Residual convolution module with skip connection."""
505
+
506
+ def __init__(self, features, activation, bn, groups=1):
507
+ """Initialize ResidualConvUnit.
508
+
509
+ Args:
510
+ features (int): Number of input/output feature channels
511
+ activation: Activation function to use
512
+ bn (bool): Whether to use batch normalization (currently unused)
513
+ groups (int): Number of groups for grouped convolution
514
+ """
515
+ super().__init__()
516
+
517
+ self.bn = bn
518
+ self.groups = groups
519
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
520
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
521
+
522
+ self.norm1 = None
523
+ self.norm2 = None
524
+
525
+ self.activation = activation
526
+ self.skip_add = nn.quantized.FloatFunctional()
527
+
528
+ def forward(self, x):
529
+ """Forward pass with residual connection.
530
+
531
+ Args:
532
+ x (tensor): Input tensor of shape (B, C, H, W)
533
+
534
+ Returns:
535
+ tensor: Output tensor of shape (B, C, H, W) with residual added
536
+ """
537
+
538
+ out = self.activation(x)
539
+ out = self.conv1(out)
540
+ if self.norm1 is not None:
541
+ out = self.norm1(out)
542
+
543
+ out = self.activation(out)
544
+ out = self.conv2(out)
545
+ if self.norm2 is not None:
546
+ out = self.norm2(out)
547
+
548
+ return self.skip_add.add(out, x)
549
+
550
+
551
+ class FeatureFusionBlock(nn.Module):
552
+ """Feature fusion block."""
553
+
554
+ def __init__(
555
+ self,
556
+ features,
557
+ activation,
558
+ deconv=False,
559
+ bn=False,
560
+ expand=False,
561
+ align_corners=True,
562
+ size=None,
563
+ has_residual=True,
564
+ groups=1,
565
+ ):
566
+ """Initialize FeatureFusionBlock.
567
+
568
+ Args:
569
+ features (int): Number of input/output feature channels
570
+ activation: Activation function to use
571
+ deconv (bool): Whether to use deconvolution
572
+ bn (bool): Whether to use batch normalization
573
+ expand (bool): Whether to expand features (halve output channels)
574
+ align_corners (bool): Align corners for interpolation
575
+ size: Target size for upsampling
576
+ has_residual (bool): Whether to include residual connection
577
+ groups (int): Number of groups for grouped convolution
578
+ """
579
+ super(FeatureFusionBlock, self).__init__()
580
+
581
+ self.deconv = deconv
582
+ self.align_corners = align_corners
583
+ self.groups = groups
584
+ self.expand = expand
585
+ out_features = features
586
+ if self.expand == True:
587
+ out_features = features // 2
588
+
589
+ self.out_conv = nn.Conv2d(
590
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
591
+ )
592
+
593
+ if has_residual:
594
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
595
+
596
+ self.has_residual = has_residual
597
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
598
+
599
+ self.skip_add = nn.quantized.FloatFunctional()
600
+ self.size = size
601
+
602
+ def forward(self, *xs, size=None):
603
+ """Forward pass through the feature fusion block.
604
+
605
+ Args:
606
+ *xs: Variable number of input tensors. First tensor is the main input,
607
+ second tensor (if present) is used for residual connection.
608
+ size: Optional target size for upsampling. If None, uses self.size or scale_factor=2.
609
+
610
+ Returns:
611
+ torch.Tensor: Fused and upsampled output tensor.
612
+ """
613
+ output = xs[0]
614
+
615
+ if self.has_residual:
616
+ res = self.resConfUnit1(xs[1])
617
+ output = self.skip_add.add(output, res)
618
+
619
+ output = self.resConfUnit2(output)
620
+
621
+ if (size is None) and (self.size is None):
622
+ modifier = {"scale_factor": 2}
623
+ elif size is None:
624
+ modifier = {"size": self.size}
625
+ else:
626
+ modifier = {"size": size}
627
+
628
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
629
+ output = self.out_conv(output)
630
+
631
+ return output
632
+
633
+
634
+ def custom_interpolate(
635
+ x: torch.Tensor,
636
+ size: Tuple[int, int] = None,
637
+ scale_factor: float = None,
638
+ mode: str = "bilinear",
639
+ align_corners: bool = True,
640
+ ) -> torch.Tensor:
641
+ """
642
+ Custom interpolation function to handle large tensors by chunking.
643
+
644
+ Avoids INT_MAX overflow issues in nn.functional.interpolate when dealing with
645
+ very large input tensors by splitting them into smaller chunks.
646
+
647
+ Args:
648
+ x: Input tensor to interpolate
649
+ size: Target output size (H, W)
650
+ scale_factor: Scaling factor if size is not provided
651
+ mode: Interpolation mode (default: "bilinear")
652
+ align_corners: Whether to align corners in interpolation
653
+
654
+ Returns:
655
+ Interpolated tensor
656
+ """
657
+ if size is None:
658
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
659
+
660
+ INT_MAX = 1610612736
661
+
662
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
663
+
664
+ if input_elements > INT_MAX:
665
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
666
+ interpolated_chunks = [
667
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
668
+ ]
669
+ x = torch.cat(interpolated_chunks, dim=0)
670
+ return x.contiguous()
671
+ else:
672
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
hyworldmirror/models/heads/gs_head.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .dense_head import _BaseDPTHead
7
+
8
+
9
+ class GSFeatHead(_BaseDPTHead):
10
+ """
11
+ GS feature head that only outputs fused GS features.
12
+
13
+ This head is used when gs depth is disabled. It skips the prediction
14
+ conv (output_conv2) and returns only the fused GS feature map.
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ dim_in: int,
20
+ patch_size: int = 14,
21
+ features: int = 256,
22
+ out_channels: List[int] = [256, 512, 1024, 1024],
23
+ pos_embed: bool = True,
24
+ down_ratio: int = 1,
25
+ gradient_checkpoint: bool = False,
26
+ ) -> None:
27
+ super().__init__(
28
+ dim_in=dim_in, patch_size=patch_size, features=features,
29
+ out_channels=out_channels, pos_embed=pos_embed,
30
+ down_ratio=down_ratio, gradient_checkpoint=gradient_checkpoint,
31
+ _cast_pos_embed_dtype=False,
32
+ )
33
+ conv2_in_channels = features // 2
34
+ self.input_merger = nn.Sequential(
35
+ nn.Conv2d(3, conv2_in_channels, 7, 1, 3),
36
+ nn.ReLU(),
37
+ )
38
+
39
+ def forward(
40
+ self,
41
+ token_list: List[torch.Tensor],
42
+ images: torch.Tensor,
43
+ patch_start_idx: int,
44
+ frames_chunk_size: int = 8,
45
+ ) -> torch.Tensor:
46
+ B, S, _, H, W = images.shape
47
+
48
+ if frames_chunk_size is None or frames_chunk_size >= S:
49
+ return self._forward_impl(token_list, images, patch_start_idx)
50
+
51
+ assert frames_chunk_size > 0
52
+ gs_chunks = []
53
+ for frame_start in range(0, S, frames_chunk_size):
54
+ frame_end = min(frame_start + frames_chunk_size, S)
55
+ gs = self._forward_impl(
56
+ token_list, images, patch_start_idx, frame_start, frame_end
57
+ )
58
+ gs_chunks.append(gs)
59
+
60
+ return torch.cat(gs_chunks, dim=1)
61
+
62
+ def _forward_impl(
63
+ self,
64
+ token_list: List[torch.Tensor],
65
+ images: torch.Tensor,
66
+ patch_start_idx: int,
67
+ frame_start: int = None,
68
+ frame_end: int = None,
69
+ ) -> torch.Tensor:
70
+ if frame_start is not None and frame_end is not None:
71
+ images = images[:, frame_start:frame_end].contiguous()
72
+
73
+ B, S, _, H, W = images.shape
74
+
75
+ fused = self._extract_fused_features(
76
+ token_list, B, S, H, W, patch_start_idx, frame_start, frame_end
77
+ )
78
+
79
+ img_flat = images.reshape(B * S, -1, H, W)
80
+ img_feat = self.input_merger(img_flat)
81
+ fused = fused + img_feat
82
+ fused = fused.reshape(B, S, *fused.shape[1:])
83
+ return fused
hyworldmirror/models/layers/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .mlp import Mlp, MlpFP32
2
+ from .patch_embed import PatchEmbed, PatchEmbed_Mlp
3
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
4
+ from .block import NestedTensorBlock
5
+ from .attention import MemEffAttention
hyworldmirror/models/layers/attention.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # References:
2
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
4
+
5
+ from torch import Tensor
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ import torch
9
+
10
+ try:
11
+ from flash_attn_interface import flash_attn_func as flash_attn_func_v3
12
+ _USE_FLASH_ATTN_V3 = True
13
+ except ImportError:
14
+ from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_func_v2
15
+ _USE_FLASH_ATTN_V3 = False
16
+ from ...comm.padding import minimal_pad_to_divisible, depad_by_length, pad_by_length
17
+ import torch.distributed as dist
18
+ from ...comm.communication import _All2All, _Allgather
19
+
20
+
21
+ class Attention(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dim: int,
25
+ num_heads: int = 8,
26
+ qkv_bias: bool = True,
27
+ proj_bias: bool = True,
28
+ attn_drop: float = 0.0,
29
+ proj_drop: float = 0.0,
30
+ norm_layer: nn.Module = nn.LayerNorm,
31
+ qk_norm: bool = False,
32
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
33
+ rope=None,
34
+ ) -> None:
35
+ super().__init__()
36
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
37
+ self.num_heads = num_heads
38
+ self.head_dim = dim // num_heads
39
+ self.scale = self.head_dim**-0.5
40
+ self.fused_attn = fused_attn
41
+
42
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
43
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
44
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+ self.rope = rope
49
+
50
+ def _compute_qkv(self, x: Tensor):
51
+ B, N, C = x.shape
52
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
53
+ q, k, v = qkv.unbind(0)
54
+ q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
55
+ return q, k, v, B, N, C
56
+
57
+ def _apply_attention(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
58
+ if q.dtype==torch.bfloat16 or q.dtype==torch.float16:
59
+ if q.is_contiguous():
60
+ q = q.transpose(1,2)
61
+ else:
62
+ q = q.transpose(1, 2).contiguous()
63
+ if k.is_contiguous():
64
+ k = k.transpose(1, 2)
65
+ else:
66
+ k = k.transpose(1, 2).contiguous()
67
+ if v.is_contiguous():
68
+ v = v.transpose(1, 2)
69
+ else:
70
+ v = v.transpose(1, 2).contiguous()
71
+ if _USE_FLASH_ATTN_V3:
72
+ x = flash_attn_func_v3(q, k, v)
73
+ else:
74
+ x = flash_attn_func_v2(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
75
+ if x.is_contiguous():
76
+ x = x.transpose(1, 2)
77
+ else:
78
+ x = x.transpose(1, 2).contiguous()
79
+ else:
80
+ x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
81
+ return x
82
+
83
+ def _project_output(self, x: Tensor, B: int, N: int, C: int) -> Tensor:
84
+ x = x.transpose(1, 2).reshape(B, N, C)
85
+ x = self.proj(x)
86
+ x = self.proj_drop(x)
87
+ return x
88
+
89
+ def forward(self, x: Tensor, pos=None) -> Tensor:
90
+ q, k, v, B, N, C = self._compute_qkv(x)
91
+
92
+ if self.rope is not None:
93
+ q = self.rope(q, pos)
94
+ k = self.rope(k, pos)
95
+
96
+ x = self._apply_attention(q, k, v)
97
+ return self._project_output(x, B, N, C)
98
+
99
+ class DistAttention(Attention):
100
+ def forward(self, x: Tensor, pos=None, sp_size=1, sp_group=None, padding_tokens=0) -> Tensor:
101
+
102
+ q, k, v, B, N, C = self._compute_qkv(x)
103
+
104
+ if sp_size>1:
105
+
106
+ q = _All2All.apply(q,1,2,sp_group,False)
107
+ k = _All2All.apply(k,1,2,sp_group,False)
108
+ v = _All2All.apply(v,1,2,sp_group,False)
109
+ q = depad_by_length(q,padding_tokens,2)
110
+ k = depad_by_length(k,padding_tokens,2)
111
+ v = depad_by_length(v,padding_tokens,2)
112
+
113
+ if self.rope is not None:
114
+ q = self.rope(q, pos)
115
+ k = self.rope(k, pos)
116
+
117
+ x = self._apply_attention(q, k, v)
118
+
119
+ if sp_size>1:
120
+ x = pad_by_length(x,padding_tokens,2,0)
121
+ x = _All2All.apply(x,2,1,sp_group,False)
122
+
123
+ return self._project_output(x, B, N, C)
124
+
125
+
126
+ class MemEffAttention(Attention):
127
+ def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
128
+ assert pos is None
129
+ if attn_bias is not None:
130
+ raise AssertionError("xFormers is required for using nested tensors")
131
+ return super().forward(x)
hyworldmirror/models/layers/block.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # References:
2
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
4
+
5
+ from typing import Callable, List, Any, Tuple, Dict
6
+
7
+ import torch
8
+ from torch import nn, Tensor
9
+
10
+ from .attention import Attention, DistAttention
11
+ from .drop_path import DropPath
12
+ from .layer_scale import LayerScale
13
+ from .mlp import Mlp
14
+
15
+
16
+ XFORMERS_AVAILABLE = False
17
+
18
+ def modulate(x, shift, scale):
19
+ return x * (1 + scale.unsqueeze(2)) + shift.unsqueeze(2)
20
+
21
+ class Block(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dim: int,
25
+ num_heads: int,
26
+ mlp_ratio: float = 4.0,
27
+ qkv_bias: bool = True,
28
+ proj_bias: bool = True,
29
+ ffn_bias: bool = True,
30
+ drop: float = 0.0,
31
+ attn_drop: float = 0.0,
32
+ init_values=None,
33
+ drop_path: float = 0.0,
34
+ act_layer: Callable[..., nn.Module] = nn.GELU,
35
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
36
+ attn_class: Callable[..., nn.Module] = Attention,
37
+ ffn_layer: Callable[..., nn.Module] = Mlp,
38
+ qk_norm: bool = False,
39
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
40
+ rope=None
41
+ ) -> None:
42
+ super().__init__()
43
+
44
+ self.norm1 = norm_layer(dim)
45
+
46
+ self.attn = attn_class(
47
+ dim,
48
+ num_heads=num_heads,
49
+ qkv_bias=qkv_bias,
50
+ proj_bias=proj_bias,
51
+ attn_drop=attn_drop,
52
+ proj_drop=drop,
53
+ qk_norm=qk_norm,
54
+ fused_attn=fused_attn,
55
+ rope=rope,
56
+ )
57
+
58
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
59
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
60
+
61
+ self.norm2 = norm_layer(dim)
62
+ mlp_hidden_dim = int(dim * mlp_ratio)
63
+ self.mlp = ffn_layer(
64
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
65
+ )
66
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
67
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
68
+
69
+ self.sample_drop_ratio = drop_path
70
+
71
+
72
+ def forward(self, x: Tensor, pos=None) -> Tensor:
73
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
74
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
75
+
76
+ def ffn_residual_func(x: Tensor) -> Tensor:
77
+ return self.ls2(self.mlp(self.norm2(x)))
78
+
79
+ if self.training and self.sample_drop_ratio > 0.1:
80
+ # the overhead is compensated only for a drop path rate larger than 0.1
81
+ x = drop_add_residual_stochastic_depth(
82
+ x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
83
+ )
84
+ x = drop_add_residual_stochastic_depth(
85
+ x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
86
+ )
87
+ elif self.training and self.sample_drop_ratio > 0.0:
88
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
89
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
90
+ else:
91
+ x = x + attn_residual_func(x, pos=pos)
92
+ x = x + ffn_residual_func(x)
93
+ return x
94
+
95
+ class DistBlock(Block):
96
+ def __init__(self, *args, attn_class: Callable[..., nn.Module] = DistAttention, **kwargs):
97
+ super().__init__(*args, attn_class=attn_class, **kwargs)
98
+
99
+ def forward(self, x: Tensor, pos=None, sp_size=1,sp_group=None,padding_tokens=0,block_type = None, token_shape=None) -> Tensor:
100
+ def attn_residual_func(x: Tensor, pos=None, sp_size=1,sp_group=None,padding_tokens=0) -> Tensor:
101
+ return self.ls1(self.attn(self.norm1(x), pos=pos,sp_size=sp_size,sp_group=sp_group,padding_tokens=padding_tokens))
102
+
103
+ def ffn_residual_func(x: Tensor) -> Tensor:
104
+ return self.ls2(self.mlp(self.norm2(x)))
105
+
106
+ if self.training and self.sample_drop_ratio > 0.1:
107
+ # the overhead is compensated only for a drop path rate larger than 0.1
108
+ x = drop_add_residual_stochastic_depth(
109
+ x, pos=pos, sp_size=sp_size,sp_group=sp_group,padding_tokens=padding_tokens,residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
110
+ )
111
+ x = drop_add_residual_stochastic_depth(
112
+ x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
113
+ )
114
+ elif self.training and self.sample_drop_ratio > 0.0:
115
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos,sp_size=sp_size,sp_group=sp_group,padding_tokens=padding_tokens))
116
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
117
+ else:
118
+ x = x + attn_residual_func(x, pos=pos,sp_size=sp_size,sp_group=sp_group,padding_tokens=padding_tokens)
119
+ x = x + ffn_residual_func(x)
120
+ return x
121
+
122
+
123
+ def drop_add_residual_stochastic_depth(
124
+ x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
125
+ ) -> Tensor:
126
+ # 1) extract subset using permutation
127
+ b, n, d = x.shape
128
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
129
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
130
+ x_subset = x[brange]
131
+
132
+ # 2) apply residual_func to get residual
133
+ if pos is not None:
134
+ # if necessary, apply rope to the subset
135
+ pos = pos[brange]
136
+ residual = residual_func(x_subset, pos=pos)
137
+ else:
138
+ residual = residual_func(x_subset)
139
+
140
+ x_flat = x.flatten(1)
141
+ residual = residual.flatten(1)
142
+
143
+ residual_scale_factor = b / sample_subset_size
144
+
145
+ # 3) add the residual
146
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147
+ return x_plus_residual.view_as(x)
148
+
149
+
150
+ def get_branges_scales(x, sample_drop_ratio=0.0):
151
+ b, n, d = x.shape
152
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
153
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
154
+ residual_scale_factor = b / sample_subset_size
155
+ return brange, residual_scale_factor
156
+
157
+
158
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
159
+ if scaling_vector is None:
160
+ x_flat = x.flatten(1)
161
+ residual = residual.flatten(1)
162
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
163
+ else:
164
+ x_plus_residual = scaled_index_add(
165
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
166
+ )
167
+ return x_plus_residual
168
+
169
+
170
+ attn_bias_cache: Dict[Tuple, Any] = {}
171
+
172
+
173
+ def get_attn_bias_and_cat(x_list, branges=None):
174
+ """
175
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
176
+ """
177
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
178
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
179
+ if all_shapes not in attn_bias_cache.keys():
180
+ seqlens = []
181
+ for b, x in zip(batch_sizes, x_list):
182
+ for _ in range(b):
183
+ seqlens.append(x.shape[1])
184
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
185
+ attn_bias._batch_sizes = batch_sizes
186
+ attn_bias_cache[all_shapes] = attn_bias
187
+
188
+ if branges is not None:
189
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
190
+ else:
191
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
192
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
193
+
194
+ return attn_bias_cache[all_shapes], cat_tensors
195
+
196
+
197
+ def drop_add_residual_stochastic_depth_list(
198
+ x_list: List[Tensor],
199
+ residual_func: Callable[[Tensor, Any], Tensor],
200
+ sample_drop_ratio: float = 0.0,
201
+ scaling_vector=None,
202
+ ) -> Tensor:
203
+ # 1) generate random set of indices for dropping samples in the batch
204
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
205
+ branges = [s[0] for s in branges_scales]
206
+ residual_scale_factors = [s[1] for s in branges_scales]
207
+
208
+ # 2) get attention bias and index+concat the tensors
209
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
210
+
211
+ # 3) apply residual_func to get residual, and split the result
212
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
213
+
214
+ outputs = []
215
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
216
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
217
+ return outputs
218
+
219
+
220
+ class NestedTensorBlock(Block):
221
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
222
+ """
223
+ x_list contains a list of tensors to nest together and run
224
+ """
225
+ assert isinstance(self.attn, MemEffAttention)
226
+
227
+ if self.training and self.sample_drop_ratio > 0.0:
228
+
229
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
230
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
231
+
232
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
233
+ return self.mlp(self.norm2(x))
234
+
235
+ x_list = drop_add_residual_stochastic_depth_list(
236
+ x_list,
237
+ residual_func=attn_residual_func,
238
+ sample_drop_ratio=self.sample_drop_ratio,
239
+ scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None),
240
+ )
241
+ x_list = drop_add_residual_stochastic_depth_list(
242
+ x_list,
243
+ residual_func=ffn_residual_func,
244
+ sample_drop_ratio=self.sample_drop_ratio,
245
+ scaling_vector=(self.ls2.gamma if isinstance(self.ls1, LayerScale) else None),
246
+ )
247
+ return x_list
248
+ else:
249
+
250
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
251
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
252
+
253
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
254
+ return self.ls2(self.mlp(self.norm2(x)))
255
+
256
+ attn_bias, x = get_attn_bias_and_cat(x_list)
257
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
258
+ x = x + ffn_residual_func(x)
259
+ return attn_bias.split(x)
260
+
261
+ def forward(self, x_or_x_list):
262
+ if isinstance(x_or_x_list, Tensor):
263
+ return super().forward(x_or_x_list)
264
+ elif isinstance(x_or_x_list, list):
265
+ if not XFORMERS_AVAILABLE:
266
+ raise AssertionError("xFormers is required for using nested tensors")
267
+ return self.forward_nested(x_or_x_list)
268
+ else:
269
+ raise AssertionError
hyworldmirror/models/layers/drop_path.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # References:
2
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
4
+
5
+
6
+ from torch import nn
7
+
8
+
9
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
10
+ if drop_prob == 0.0 or not training:
11
+ return x
12
+ keep_prob = 1 - drop_prob
13
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
14
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
15
+ if keep_prob > 0.0:
16
+ random_tensor.div_(keep_prob)
17
+ output = x * random_tensor
18
+ return output
19
+
20
+
21
+ class DropPath(nn.Module):
22
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
23
+
24
+ def __init__(self, drop_prob=None):
25
+ super(DropPath, self).__init__()
26
+ self.drop_prob = drop_prob
27
+
28
+ def forward(self, x):
29
+ return drop_path(x, self.drop_prob, self.training)
hyworldmirror/models/layers/layer_scale.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
2
+
3
+ from typing import Union
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ from torch import nn
8
+
9
+
10
+ class LayerScale(nn.Module):
11
+ def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
12
+ super().__init__()
13
+ self.inplace = inplace
14
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
15
+
16
+ def forward(self, x: Tensor) -> Tensor:
17
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
hyworldmirror/models/layers/mlp.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # References:
2
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
4
+
5
+
6
+ from typing import Callable, Optional
7
+ import torch
8
+ from torch import Tensor, nn
9
+
10
+
11
+ class Mlp(nn.Module):
12
+ def __init__(
13
+ self,
14
+ in_features: int,
15
+ hidden_features: Optional[int] = None,
16
+ out_features: Optional[int] = None,
17
+ act_layer: Callable[..., nn.Module] = nn.GELU,
18
+ drop: float = 0.0,
19
+ bias: bool = True,
20
+ ) -> None:
21
+ super().__init__()
22
+ out_features = out_features or in_features
23
+ hidden_features = hidden_features or in_features
24
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
25
+ self.act = act_layer()
26
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
27
+ self.drop = nn.Dropout(drop)
28
+
29
+ def forward(self, x: Tensor) -> Tensor:
30
+ x = self.fc1(x)
31
+ x = self.act(x)
32
+ x = self.drop(x)
33
+ x = self.fc2(x)
34
+ x = self.drop(x)
35
+ return x
36
+
37
+
38
+ class MlpFP32(Mlp):
39
+ @staticmethod
40
+ def map_to_args_to_float(args, kwargs):
41
+ args = tuple(
42
+ torch.float32 if isinstance(arg, torch.dtype) else arg
43
+ for arg in args
44
+ )
45
+ kwargs = dict(kwargs)
46
+ for key in kwargs:
47
+ if key == "dtype":
48
+ kwargs[key] = torch.float32
49
+ return args, kwargs
50
+
51
+ def to(self, *args, **kwargs):
52
+ self.fc1 = self.fc1.to(*args, **kwargs)
53
+ args, kwargs = self.map_to_args_to_float(args, kwargs)
54
+ self.fc2 = self.fc2.to(*args, **kwargs)
55
+ return self
56
+
57
+ def forward_infer(self, x):
58
+ x = self.fc1(x)
59
+ x = 0.5 * x * (1 + torch.erf(x * 2**-0.5))
60
+ x = self.fc2(x.float())
61
+ return x
62
+
63
+ def forward(self, x: Tensor) -> Tensor:
64
+ return self.forward_infer(x)
hyworldmirror/models/layers/norm_rope.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Dict, Literal, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import Tensor
8
+
9
+
10
+ class PositionGetter:
11
+ """Generates and caches 2D spatial positions for patches in a grid."""
12
+
13
+ def __init__(self) -> None:
14
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
15
+
16
+ def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
17
+ if (height, width) not in self.position_cache:
18
+ y_coords = torch.arange(height, device=device)
19
+ x_coords = torch.arange(width, device=device)
20
+ self.position_cache[height, width] = torch.cartesian_prod(y_coords, x_coords)
21
+
22
+ cached_positions = self.position_cache[height, width]
23
+ return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
24
+
25
+
26
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
27
+ x1, x2 = x.chunk(2, dim=-1)
28
+ return torch.cat((-x2, x1), dim=-1)
29
+
30
+
31
+ class NormalizedRotaryPositionEmbedding2D(nn.Module):
32
+ """DINOv3-aligned 2D Rotary Position Embedding."""
33
+
34
+ def __init__(
35
+ self,
36
+ *,
37
+ head_dim: int,
38
+ base: float = 100.0,
39
+ normalize_coords: Literal["min", "max", "separate"] = "separate",
40
+ shift_coords: Union[float, None] = None,
41
+ jitter_coords: Union[float, None] = None,
42
+ rescale_coords: Union[float, None] = None,
43
+ dtype: Union[torch.dtype, None] = None,
44
+ device: Union[torch.device, None] = None,
45
+ **ignored_kwargs,
46
+ ) -> None:
47
+ super().__init__()
48
+ if len(ignored_kwargs) > 0:
49
+ # maintain parity with DINOv3 implementation that warns on ignored kwargs
50
+ pass
51
+
52
+ if head_dim % 4 != 0:
53
+ raise ValueError("head_dim must be divisible by 4 for 2D RoPE")
54
+
55
+ self.head_dim = head_dim
56
+ self.base = base
57
+ self.normalize_coords = normalize_coords
58
+ self.shift_coords = shift_coords
59
+ self.jitter_coords = jitter_coords
60
+ self.rescale_coords = rescale_coords
61
+ self.dtype = dtype
62
+
63
+ quarter_dim = head_dim // 4
64
+ self.register_buffer(
65
+ "periods",
66
+ torch.empty(quarter_dim, device=device, dtype=dtype),
67
+ persistent=True,
68
+ )
69
+ self._init_periods()
70
+
71
+ def _init_periods(self) -> None:
72
+ quarter_dim = self.periods.shape[0]
73
+ half_dim = self.head_dim // 2
74
+ exponents = 2 * torch.arange(quarter_dim, device=self.periods.device, dtype=self.dtype) / half_dim
75
+ periods = self.base ** exponents
76
+ self.periods.data.copy_(periods)
77
+
78
+ def _get_sincos_for_grid(self, H: int, W: int, device: torch.device, dtype: torch.dtype) -> Tuple[Tensor, Tensor]:
79
+ dd = {"device": device, "dtype": dtype}
80
+
81
+ if self.normalize_coords == "max":
82
+ max_hw = max(H, W)
83
+ coords_h = torch.arange(0.5, H, **dd) / max_hw
84
+ coords_w = torch.arange(0.5, W, **dd) / max_hw
85
+ elif self.normalize_coords == "min":
86
+ min_hw = min(H, W)
87
+ coords_h = torch.arange(0.5, H, **dd) / min_hw
88
+ coords_w = torch.arange(0.5, W, **dd) / min_hw
89
+ elif self.normalize_coords == "separate":
90
+ coords_h = torch.arange(0.5, H, **dd) / H
91
+ coords_w = torch.arange(0.5, W, **dd) / W
92
+ else:
93
+ raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
94
+
95
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # [H, W, 2]
96
+ coords = coords.flatten(0, 1) # [HW, 2]
97
+ coords = 2.0 * coords - 1.0
98
+
99
+ if self.training:
100
+ if self.shift_coords is not None:
101
+ shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords)
102
+ coords += shift_hw[None, :]
103
+ if self.jitter_coords is not None:
104
+ jitter_max = np.log(self.jitter_coords)
105
+ jitter_hw = torch.empty(2, **dd).uniform_(-jitter_max, jitter_max).exp()
106
+ coords *= jitter_hw[None, :]
107
+ if self.rescale_coords is not None:
108
+ rescale_max = np.log(self.rescale_coords)
109
+ rescale_hw = torch.empty(1, **dd).uniform_(-rescale_max, rescale_max).exp()
110
+ coords *= rescale_hw
111
+
112
+ periods = self.periods.to(device=device, dtype=dtype)
113
+ angles = (2 * math.pi * coords[:, :, None]) / periods[None, None, :] # [HW, 2, D/4]
114
+ angles = angles.flatten(1, 2) # [HW, D/2]
115
+ angles = torch.cat((angles, angles), dim=-1) # [HW, D]
116
+
117
+ cos = torch.cos(angles)
118
+ sin = torch.sin(angles)
119
+ return sin, cos
120
+
121
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
122
+ # Validate inputs
123
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
124
+ assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
125
+
126
+ B, _, N, C_head = tokens.shape
127
+ if C_head != self.head_dim:
128
+ raise ValueError(f"Head dim {C_head} doesn't match configured {self.head_dim}")
129
+
130
+ H = int(positions[..., 0].max().item() + 1)
131
+ W = int(positions[..., 1].max().item() + 1)
132
+
133
+ sin, cos = self._get_sincos_for_grid(H, W, tokens.device, tokens.dtype)
134
+
135
+ indices = (positions[..., 0] * W + positions[..., 1]).long()
136
+ flat_indices = indices.view(-1)
137
+ gathered_sin = sin[flat_indices].view(B, 1, N, C_head)
138
+ gathered_cos = cos[flat_indices].view(B, 1, N, C_head)
139
+ return (tokens * gathered_cos) + (_rotate_half(tokens) * gathered_sin)
140
+
hyworldmirror/models/layers/patch_embed.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # References:
2
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
4
+
5
+ from typing import Callable, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from torch import Tensor
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from itertools import repeat
12
+ import collections.abc
13
+
14
+ def make_2tuple(x):
15
+ if isinstance(x, tuple):
16
+ assert len(x) == 2
17
+ return x
18
+
19
+ assert isinstance(x, int)
20
+ return (x, x)
21
+
22
+
23
+ class PatchEmbed(nn.Module):
24
+ """
25
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
26
+
27
+ Args:
28
+ img_size: Image size.
29
+ patch_size: Patch token size.
30
+ in_chans: Number of input image channels.
31
+ embed_dim: Number of linear projection output channels.
32
+ norm_layer: Normalization layer.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ img_size: Union[int, Tuple[int, int]] = 224,
38
+ patch_size: Union[int, Tuple[int, int]] = 16,
39
+ in_chans: int = 3,
40
+ embed_dim: int = 768,
41
+ norm_layer: Optional[Callable] = None,
42
+ flatten_embedding: bool = True,
43
+ ) -> None:
44
+ super().__init__()
45
+
46
+ image_HW = make_2tuple(img_size)
47
+ patch_HW = make_2tuple(patch_size)
48
+ patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
49
+
50
+ self.img_size = image_HW
51
+ self.patch_size = patch_HW
52
+ self.patches_resolution = patch_grid_size
53
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
54
+
55
+ self.in_chans = in_chans
56
+ self.embed_dim = embed_dim
57
+
58
+ self.flatten_embedding = flatten_embedding
59
+
60
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
61
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
62
+
63
+ def forward(self, x: Tensor) -> Tensor:
64
+ _, _, H, W = x.shape
65
+ patch_H, patch_W = self.patch_size
66
+
67
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
68
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
69
+
70
+ x = self.proj(x) # B C H W
71
+ H, W = x.size(2), x.size(3)
72
+ x = x.flatten(2).transpose(1, 2) # B HW C
73
+ x = self.norm(x)
74
+ if not self.flatten_embedding:
75
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
76
+ return x
77
+
78
+
79
+ class PatchEmbed_Mlp(PatchEmbed):
80
+ def __init__(self, img_size=224,
81
+ patch_size=16,
82
+ in_chans=3,
83
+ embed_dim=768,
84
+ norm_layer=None,
85
+ flatten_embedding=True):
86
+ super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten_embedding)
87
+
88
+ self.proj = nn.Sequential(
89
+ PixelUnshuffle(patch_size),
90
+ Permute((0,2,3,1)),
91
+ Mlp(in_chans * patch_size**2, 4*embed_dim, embed_dim),
92
+ Permute((0,3,1,2)),
93
+ )
94
+
95
+
96
+ class PixelUnshuffle (nn.Module):
97
+ def __init__(self, downscale_factor):
98
+ super().__init__()
99
+ self.downscale_factor = downscale_factor
100
+
101
+ def forward(self, input):
102
+ if input.numel() == 0:
103
+ # this is not in the original torch implementation
104
+ C,H,W = input.shape[-3:]
105
+ assert H and W and H % self.downscale_factor == W%self.downscale_factor == 0
106
+ return input.view(*input.shape[:-3], C*self.downscale_factor**2, H//self.downscale_factor, W//self.downscale_factor)
107
+ else:
108
+ return F.pixel_unshuffle(input, self.downscale_factor)
109
+
110
+
111
+ class Permute(torch.nn.Module):
112
+ dims: tuple[int, ...]
113
+ def __init__(self, dims: tuple[int, ...]) -> None:
114
+ super().__init__()
115
+ self.dims = tuple(dims)
116
+
117
+ def __repr__(self):
118
+ return f"Permute{self.dims}"
119
+
120
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
121
+ return input.permute(*self.dims)
122
+
123
+
124
+ def _ntuple(n):
125
+ def parse(x):
126
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
127
+ return x
128
+ return tuple(repeat(x, n))
129
+ return parse
130
+ to_2tuple = _ntuple(2)
131
+
132
+ class Mlp(nn.Module):
133
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
134
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
135
+ super().__init__()
136
+ out_features = out_features or in_features
137
+ hidden_features = hidden_features or in_features
138
+ bias = to_2tuple(bias)
139
+ drop_probs = to_2tuple(drop)
140
+
141
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
142
+ self.act = act_layer()
143
+ self.drop1 = nn.Dropout(drop_probs[0])
144
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
145
+ self.drop2 = nn.Dropout(drop_probs[1])
146
+
147
+ def forward(self, x):
148
+ x = self.fc1(x)
149
+ x = self.act(x)
150
+ x = self.drop1(x)
151
+ x = self.fc2(x)
152
+ x = self.drop2(x)
153
+ return x
154
+
155
+
hyworldmirror/models/layers/rope.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Implementation of 2D Rotary Position Embeddings (RoPE).
2
+
3
+ # This module provides a clean implementation of 2D Rotary Position Embeddings,
4
+ # which extends the original RoPE concept to handle 2D spatial positions.
5
+
6
+ # Inspired by:
7
+ # https://github.com/meta-llama/codellama/blob/main/llama/model.py
8
+ # https://github.com/naver-ai/rope-vit
9
+
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from typing import Dict, Tuple
16
+
17
+
18
+ class PositionGetter:
19
+ """Generates and caches 2D spatial positions for patches in a grid.
20
+
21
+ This class efficiently manages the generation of spatial coordinates for patches
22
+ in a 2D grid, caching results to avoid redundant computations.
23
+
24
+ Attributes:
25
+ position_cache: Dictionary storing precomputed position tensors for different
26
+ grid dimensions.
27
+ """
28
+
29
+ def __init__(self):
30
+ """Initializes the position generator with an empty cache."""
31
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
32
+
33
+ def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
34
+ """Generates spatial positions for a batch of patches.
35
+
36
+ Args:
37
+ batch_size: Number of samples in the batch.
38
+ height: Height of the grid in patches.
39
+ width: Width of the grid in patches.
40
+ device: Target device for the position tensor.
41
+
42
+ Returns:
43
+ Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
44
+ for each position in the grid, repeated for each batch item.
45
+ """
46
+ if (height, width) not in self.position_cache:
47
+ y_coords = torch.arange(height, device=device)
48
+ x_coords = torch.arange(width, device=device)
49
+ positions = torch.cartesian_prod(y_coords, x_coords)
50
+ self.position_cache[height, width] = positions
51
+
52
+ cached_positions = self.position_cache[height, width]
53
+ return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
54
+
55
+
56
+ class RotaryPositionEmbedding2D(nn.Module):
57
+ """2D Rotary Position Embedding implementation.
58
+
59
+ This module applies rotary position embeddings to input tokens based on their
60
+ 2D spatial positions. It handles the position-dependent rotation of features
61
+ separately for vertical and horizontal dimensions.
62
+
63
+ Args:
64
+ frequency: Base frequency for the position embeddings. Default: 100.0
65
+ scaling_factor: Scaling factor for frequency computation. Default: 1.0
66
+
67
+ Attributes:
68
+ base_frequency: Base frequency for computing position embeddings.
69
+ scaling_factor: Factor to scale the computed frequencies.
70
+ frequency_cache: Cache for storing precomputed frequency components.
71
+ """
72
+
73
+ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0,):
74
+ """Initializes the 2D RoPE module."""
75
+ super().__init__()
76
+ self.base_frequency = frequency
77
+ self.scaling_factor = scaling_factor
78
+ self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
79
+
80
+ def _compute_frequency_components(
81
+ self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
82
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
83
+ """Computes frequency components for rotary embeddings.
84
+
85
+ Args:
86
+ dim: Feature dimension (must be even).
87
+ seq_len: Maximum sequence length.
88
+ device: Target device for computations.
89
+ dtype: Data type for the computed tensors.
90
+
91
+ Returns:
92
+ Tuple of (cosine, sine) tensors for frequency components.
93
+ """
94
+ cache_key = (dim, seq_len, device, dtype)
95
+ if cache_key not in self.frequency_cache:
96
+ # Compute frequency bands
97
+ exponents = torch.arange(0, dim, 2, device=device).float() / dim
98
+ inv_freq = 1.0 / (self.base_frequency**exponents)
99
+
100
+ # Generate position-dependent frequencies
101
+ positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
102
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
103
+
104
+ # Compute and cache frequency components
105
+ angles = angles.to(dtype)
106
+ angles = torch.cat((angles, angles), dim=-1)
107
+ cos_components = angles.cos().to(dtype)
108
+ sin_components = angles.sin().to(dtype)
109
+ self.frequency_cache[cache_key] = (cos_components, sin_components)
110
+
111
+ return self.frequency_cache[cache_key]
112
+
113
+ @staticmethod
114
+ def _rotate_features(x: torch.Tensor) -> torch.Tensor:
115
+ """Performs feature rotation by splitting and recombining feature dimensions.
116
+
117
+ Args:
118
+ x: Input tensor to rotate.
119
+
120
+ Returns:
121
+ Rotated feature tensor.
122
+ """
123
+ feature_dim = x.shape[-1]
124
+ x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
125
+ return torch.cat((-x2, x1), dim=-1)
126
+
127
+ def _apply_1d_rope(
128
+ self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
129
+ ) -> torch.Tensor:
130
+ """Applies 1D rotary position embeddings along one dimension.
131
+
132
+ Args:
133
+ tokens: Input token features.
134
+ positions: Position indices.
135
+ cos_comp: Cosine components for rotation.
136
+ sin_comp: Sine components for rotation.
137
+
138
+ Returns:
139
+ Tokens with applied rotary position embeddings.
140
+ """
141
+ # Embed positions with frequency components
142
+ cos = F.embedding(positions, cos_comp)[:, None, :, :]
143
+ sin = F.embedding(positions, sin_comp)[:, None, :, :]
144
+
145
+ # Apply rotation
146
+ return (tokens * cos) + (self._rotate_features(tokens) * sin)
147
+
148
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
149
+ """Applies 2D rotary position embeddings to input tokens.
150
+
151
+ Args:
152
+ tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
153
+ The feature dimension (dim) must be divisible by 4.
154
+ positions: Position tensor of shape (batch_size, n_tokens, 2) containing
155
+ the y and x coordinates for each token.
156
+
157
+ Returns:
158
+ Tensor of same shape as input with applied 2D rotary position embeddings.
159
+
160
+ Raises:
161
+ AssertionError: If input dimensions are invalid or positions are malformed.
162
+ """
163
+ # Validate inputs
164
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
165
+ assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
166
+
167
+ # Compute feature dimension for each spatial direction
168
+ feature_dim = tokens.size(-1) // 2
169
+
170
+ # Get frequency components
171
+ max_position = int(positions.max()) + 1
172
+ cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
173
+
174
+ # Split features for vertical and horizontal processing
175
+ vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
176
+
177
+ # Apply RoPE separately for each dimension
178
+ vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
179
+ horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
180
+
181
+ # Combine processed features
182
+ return torch.cat((vertical_features, horizontal_features), dim=-1)
hyworldmirror/models/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional
2
+
3
+ from torch import Tensor, nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class SwiGLUFFN(nn.Module):
8
+ def __init__(
9
+ self,
10
+ in_features: int,
11
+ hidden_features: Optional[int] = None,
12
+ out_features: Optional[int] = None,
13
+ act_layer: Callable[..., nn.Module] = None,
14
+ drop: float = 0.0,
15
+ bias: bool = True,
16
+ ) -> None:
17
+ super().__init__()
18
+ out_features = out_features or in_features
19
+ hidden_features = hidden_features or in_features
20
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
21
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
22
+
23
+ def forward(self, x: Tensor) -> Tensor:
24
+ x12 = self.w12(x)
25
+ x1, x2 = x12.chunk(2, dim=-1)
26
+ hidden = F.silu(x1) * x2
27
+ return self.w3(hidden)
28
+
29
+
30
+ SwiGLU = SwiGLUFFN
31
+
32
+
33
+ class SwiGLUFFNFused(SwiGLU):
34
+ def __init__(
35
+ self,
36
+ in_features: int,
37
+ hidden_features: Optional[int] = None,
38
+ out_features: Optional[int] = None,
39
+ act_layer: Callable[..., nn.Module] = None,
40
+ drop: float = 0.0,
41
+ bias: bool = True,
42
+ ) -> None:
43
+ out_features = out_features or in_features
44
+ hidden_features = hidden_features or in_features
45
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
46
+ super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias)
hyworldmirror/models/layers/vision_transformer.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # References:
2
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
4
+
5
+ from functools import partial
6
+ import math
7
+ import logging
8
+ from typing import Sequence, Tuple, Union, Callable
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.utils.checkpoint import checkpoint
13
+ from torch.nn.init import trunc_normal_
14
+ from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
15
+
16
+ log = logging.getLogger(__name__)
17
+
18
+
19
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
20
+ if not depth_first and include_root:
21
+ fn(module=module, name=name)
22
+ for child_name, child_module in module.named_children():
23
+ child_name = ".".join((name, child_name)) if name else child_name
24
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
25
+ if depth_first and include_root:
26
+ fn(module=module, name=name)
27
+ return module
28
+
29
+
30
+ class BlockChunk(nn.ModuleList):
31
+ def forward(self, x):
32
+ for b in self:
33
+ x = b(x)
34
+ return x
35
+
36
+
37
+ class DinoVisionTransformer(nn.Module):
38
+ def __init__(
39
+ self,
40
+ img_size=224,
41
+ patch_size=16,
42
+ in_chans=3,
43
+ embed_dim=768,
44
+ depth=12,
45
+ num_heads=12,
46
+ mlp_ratio=4.0,
47
+ qkv_bias=True,
48
+ ffn_bias=True,
49
+ proj_bias=True,
50
+ drop_path_rate=0.0,
51
+ drop_path_uniform=False,
52
+ init_values=None, # for layerscale: None or 0 => no layerscale
53
+ embed_layer=PatchEmbed,
54
+ act_layer=nn.GELU,
55
+ block_fn=Block,
56
+ ffn_layer="mlp",
57
+ block_chunks=1,
58
+ num_register_tokens=0,
59
+ interpolate_antialias=False,
60
+ interpolate_offset=0.1,
61
+ qk_norm=False,
62
+ ):
63
+ """
64
+ Args:
65
+ img_size (int, tuple): input image size
66
+ patch_size (int, tuple): patch size
67
+ in_chans (int): number of input channels
68
+ embed_dim (int): embedding dimension
69
+ depth (int): depth of transformer
70
+ num_heads (int): number of attention heads
71
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
72
+ qkv_bias (bool): enable bias for qkv if True
73
+ proj_bias (bool): enable bias for proj in attn if True
74
+ ffn_bias (bool): enable bias for ffn if True
75
+ drop_path_rate (float): stochastic depth rate
76
+ drop_path_uniform (bool): apply uniform drop rate across blocks
77
+ weight_init (str): weight init scheme
78
+ init_values (float): layer-scale init values
79
+ embed_layer (nn.Module): patch embedding layer
80
+ act_layer (nn.Module): MLP activation layer
81
+ block_fn (nn.Module): transformer block class
82
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
83
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
84
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
85
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
86
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
87
+ """
88
+ super().__init__()
89
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
90
+
91
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
92
+ self.num_tokens = 1
93
+ self.n_blocks = depth
94
+ self.num_heads = num_heads
95
+ self.patch_size = patch_size
96
+ self.num_register_tokens = num_register_tokens
97
+ self.interpolate_antialias = interpolate_antialias
98
+ self.interpolate_offset = interpolate_offset
99
+ self.use_reentrant = False # hardcoded to False
100
+
101
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
102
+ num_patches = self.patch_embed.num_patches
103
+
104
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
105
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
106
+ assert num_register_tokens >= 0
107
+ self.register_tokens = (
108
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
109
+ )
110
+
111
+ if drop_path_uniform is True:
112
+ dpr = [drop_path_rate] * depth
113
+ else:
114
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
115
+
116
+ if ffn_layer == "mlp":
117
+ log.info("using MLP layer as FFN")
118
+ ffn_layer = Mlp
119
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
120
+ log.info("using SwiGLU layer as FFN")
121
+ ffn_layer = SwiGLUFFNFused
122
+ elif ffn_layer == "identity":
123
+ log.info("using Identity layer as FFN")
124
+
125
+ def f(*args, **kwargs):
126
+ return nn.Identity()
127
+
128
+ ffn_layer = f
129
+ else:
130
+ raise NotImplementedError
131
+
132
+ blocks_list = [
133
+ block_fn(
134
+ dim=embed_dim,
135
+ num_heads=num_heads,
136
+ mlp_ratio=mlp_ratio,
137
+ qkv_bias=qkv_bias,
138
+ proj_bias=proj_bias,
139
+ ffn_bias=ffn_bias,
140
+ drop_path=dpr[i],
141
+ norm_layer=norm_layer,
142
+ act_layer=act_layer,
143
+ ffn_layer=ffn_layer,
144
+ init_values=init_values,
145
+ qk_norm=qk_norm,
146
+ )
147
+ for i in range(depth)
148
+ ]
149
+ if block_chunks > 0:
150
+ self.chunked_blocks = True
151
+ chunked_blocks = []
152
+ chunksize = depth // block_chunks
153
+ for i in range(0, depth, chunksize):
154
+ # this is to keep the block index consistent if we chunk the block list
155
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
156
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
157
+ else:
158
+ self.chunked_blocks = False
159
+ self.blocks = nn.ModuleList(blocks_list)
160
+
161
+ self.norm = norm_layer(embed_dim)
162
+ self.head = nn.Identity()
163
+
164
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
165
+
166
+ self.init_weights()
167
+
168
+ def init_weights(self):
169
+ trunc_normal_(self.pos_embed, std=0.02)
170
+ nn.init.normal_(self.cls_token, std=1e-6)
171
+ if self.register_tokens is not None:
172
+ nn.init.normal_(self.register_tokens, std=1e-6)
173
+ named_apply(init_weights_vit_timm, self)
174
+
175
+ def interpolate_pos_encoding(self, x, w, h):
176
+ previous_dtype = x.dtype
177
+ npatch = x.shape[1] - 1
178
+ N = self.pos_embed.shape[1] - 1
179
+ if npatch == N and w == h:
180
+ return self.pos_embed
181
+ pos_embed = self.pos_embed.float()
182
+ class_pos_embed = pos_embed[:, 0]
183
+ patch_pos_embed = pos_embed[:, 1:]
184
+ dim = x.shape[-1]
185
+ w0 = w // self.patch_size
186
+ h0 = h // self.patch_size
187
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
188
+ assert N == M * M
189
+ kwargs = {}
190
+ if self.interpolate_offset:
191
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
192
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
193
+ sx = float(w0 + self.interpolate_offset) / M
194
+ sy = float(h0 + self.interpolate_offset) / M
195
+ kwargs["scale_factor"] = (sx, sy)
196
+ else:
197
+ # Simply specify an output size instead of a scale factor
198
+ kwargs["size"] = (w0, h0)
199
+ patch_pos_embed = nn.functional.interpolate(
200
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
201
+ mode="bicubic",
202
+ antialias=self.interpolate_antialias,
203
+ **kwargs,
204
+ )
205
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
206
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
207
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
208
+
209
+ def prepare_tokens_with_masks(self, x, masks=None):
210
+ B, nc, w, h = x.shape
211
+ x = self.patch_embed(x)
212
+ if masks is not None:
213
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
214
+
215
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
216
+ x = x + self.interpolate_pos_encoding(x, w, h)
217
+
218
+ if self.register_tokens is not None:
219
+ x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1)
220
+
221
+ return x
222
+
223
+ def forward_features_list(self, x_list, masks_list):
224
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
225
+
226
+ for blk in self.blocks:
227
+ if self.training:
228
+ # x = blk(x)
229
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
230
+ else:
231
+ x = blk(x)
232
+
233
+ all_x = x
234
+ output = []
235
+ for x, masks in zip(all_x, masks_list):
236
+ x_norm = self.norm(x)
237
+ output.append(
238
+ {
239
+ "x_norm_clstoken": x_norm[:, 0],
240
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
241
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
242
+ "x_prenorm": x,
243
+ "masks": masks,
244
+ }
245
+ )
246
+ return output
247
+
248
+ def forward_features(self, x, masks=None):
249
+ if isinstance(x, list):
250
+ return self.forward_features_list(x, masks)
251
+
252
+ x = self.prepare_tokens_with_masks(x, masks)
253
+
254
+ for blk in self.blocks:
255
+ if self.training:
256
+ # x = blk(x)
257
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
258
+ else:
259
+ x = blk(x)
260
+
261
+ x_norm = self.norm(x)
262
+ return {
263
+ "x_norm_clstoken": x_norm[:, 0],
264
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
265
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
266
+ "x_prenorm": x,
267
+ "masks": masks,
268
+ }
269
+
270
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
271
+ x = self.prepare_tokens_with_masks(x)
272
+ # If n is an int, take the n last blocks. If it's a list, take them
273
+ output, total_block_len = [], len(self.blocks)
274
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
275
+ for i, blk in enumerate(self.blocks):
276
+ x = blk(x)
277
+ if i in blocks_to_take:
278
+ output.append(x)
279
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
280
+ return output
281
+
282
+ def _get_intermediate_layers_chunked(self, x, n=1):
283
+ x = self.prepare_tokens_with_masks(x)
284
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
285
+ # If n is an int, take the n last blocks. If it's a list, take them
286
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
287
+ for block_chunk in self.blocks:
288
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
289
+ x = blk(x)
290
+ if i in blocks_to_take:
291
+ output.append(x)
292
+ i += 1
293
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
294
+ return output
295
+
296
+ def get_intermediate_layers(
297
+ self,
298
+ x: torch.Tensor,
299
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
300
+ reshape: bool = False,
301
+ return_class_token: bool = False,
302
+ norm=True,
303
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
304
+ if self.chunked_blocks:
305
+ outputs = self._get_intermediate_layers_chunked(x, n)
306
+ else:
307
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
308
+ if norm:
309
+ outputs = [self.norm(out) for out in outputs]
310
+ class_tokens = [out[:, 0] for out in outputs]
311
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
312
+ if reshape:
313
+ B, _, w, h = x.shape
314
+ outputs = [
315
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
316
+ for out in outputs
317
+ ]
318
+ if return_class_token:
319
+ return tuple(zip(outputs, class_tokens))
320
+ return tuple(outputs)
321
+
322
+ def forward(self, *args, is_training=True, **kwargs):
323
+ ret = self.forward_features(*args, **kwargs)
324
+ if is_training:
325
+ return ret
326
+ else:
327
+ return self.head(ret["x_norm_clstoken"])
328
+
329
+
330
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
331
+ """ViT weight initialization, original timm impl (for reproducibility)"""
332
+ if isinstance(module, nn.Linear):
333
+ trunc_normal_(module.weight, std=0.02)
334
+ if module.bias is not None:
335
+ nn.init.zeros_(module.bias)
336
+
337
+
338
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
339
+ model = DinoVisionTransformer(
340
+ patch_size=patch_size,
341
+ embed_dim=384,
342
+ depth=12,
343
+ num_heads=6,
344
+ mlp_ratio=4,
345
+ block_fn=partial(Block, attn_class=MemEffAttention),
346
+ num_register_tokens=num_register_tokens,
347
+ **kwargs,
348
+ )
349
+ return model
350
+
351
+
352
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
353
+ model = DinoVisionTransformer(
354
+ patch_size=patch_size,
355
+ embed_dim=768,
356
+ depth=12,
357
+ num_heads=12,
358
+ mlp_ratio=4,
359
+ block_fn=partial(Block, attn_class=MemEffAttention),
360
+ num_register_tokens=num_register_tokens,
361
+ **kwargs,
362
+ )
363
+ return model
364
+
365
+
366
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
367
+ model = DinoVisionTransformer(
368
+ patch_size=patch_size,
369
+ embed_dim=1024,
370
+ depth=24,
371
+ num_heads=16,
372
+ mlp_ratio=4,
373
+ block_fn=partial(Block, attn_class=MemEffAttention),
374
+ num_register_tokens=num_register_tokens,
375
+ **kwargs,
376
+ )
377
+ return model
378
+
379
+
380
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
381
+ """
382
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
383
+ """
384
+ model = DinoVisionTransformer(
385
+ patch_size=patch_size,
386
+ embed_dim=1536,
387
+ depth=40,
388
+ num_heads=24,
389
+ mlp_ratio=4,
390
+ block_fn=partial(Block, attn_class=MemEffAttention),
391
+ num_register_tokens=num_register_tokens,
392
+ **kwargs,
393
+ )
394
+ return model
hyworldmirror/models/models/__init__.py ADDED
File without changes
hyworldmirror/models/models/rasterization.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Tuple
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor
6
+ from einops import rearrange
7
+
8
+ from gsplat.rendering import rasterization
9
+ from gsplat.strategy import DefaultStrategy
10
+
11
+ from ..utils.frustum import calculate_unprojected_mask
12
+ from ..utils.geometry import depth_to_world_coords_points
13
+ from ..utils import sh_utils, act_gs
14
+
15
+ from typing import List
16
+
17
+ class Rasterizer:
18
+ def __init__(self, rasterization_mode="classic", packed=True, abs_grad=True, with_eval3d=False,
19
+ camera_model="pinhole", sparse_grad=False, distributed=False, grad_strategy=DefaultStrategy):
20
+ self.rasterization_mode = rasterization_mode
21
+ self.packed = packed
22
+ self.abs_grad = abs_grad
23
+ self.camera_model = camera_model
24
+ self.sparse_grad = sparse_grad
25
+ self.grad_strategy = grad_strategy
26
+ self.distributed = distributed
27
+ self.with_eval3d = with_eval3d
28
+
29
+ def rasterize_splats(
30
+ self,
31
+ means,
32
+ quats,
33
+ scales,
34
+ opacities,
35
+ colors,
36
+ camtoworlds: Tensor,
37
+ Ks: Tensor,
38
+ width: int,
39
+ height: int,
40
+ **kwargs,
41
+ ) -> Tuple[Tensor, Tensor, Dict]:
42
+ render_colors, render_alphas, _ = rasterization(
43
+ means=means,
44
+ quats=quats,
45
+ scales=scales,
46
+ opacities=opacities,
47
+ colors=colors,
48
+ viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
49
+ Ks=Ks, # [C, 3, 3]
50
+ width=width,
51
+ height=height,
52
+ packed=self.packed,
53
+ absgrad=(
54
+ self.abs_grad
55
+ if isinstance(self.grad_strategy, DefaultStrategy)
56
+ else False
57
+ ),
58
+ sparse_grad=self.sparse_grad,
59
+ rasterize_mode=self.rasterization_mode,
60
+ distributed=self.distributed,
61
+ camera_model=self.camera_model,
62
+ with_eval3d=self.with_eval3d,
63
+ render_mode="RGB+ED",
64
+ **kwargs,
65
+ )
66
+ return render_colors[..., :3], render_colors[..., 3:], render_alphas
67
+
68
+ def rasterize_batches(self, means, quats, scales, opacities, colors, viewmats, Ks, width, height, **kwargs):
69
+ rendered_colors, rendered_depths, rendered_alphas = [], [], []
70
+ batch_size = len(means)
71
+ for i in range(batch_size):
72
+ means_i = means[i] # [N, 4]
73
+ quats_i = quats[i] # [N, 4]
74
+ scales_i = scales[i] # [N, 3]
75
+ opacities_i = opacities[i] # [N,]
76
+ colors_i = colors[i] # [N, 3]
77
+ viewmats_i = viewmats[i] # [V, 4, 4]
78
+ Ks_i = Ks[i] # [V, 3, 3]
79
+ render_colors_i, render_depths_i, render_alphas_i = self.rasterize_splats(
80
+ means_i, quats_i, scales_i, opacities_i, colors_i, viewmats_i, Ks_i, width, height, **kwargs
81
+ )
82
+
83
+ rendered_colors.append(render_colors_i) # V H W 3
84
+ rendered_depths.append(render_depths_i) # V H W 1
85
+ rendered_alphas.append(render_alphas_i) # V H W 1
86
+
87
+ rendered_colors = torch.stack(rendered_colors, dim=0) # B V H W 3
88
+ rendered_depths = torch.stack(rendered_depths, dim=0) # B V H W 1
89
+ rendered_alphas = torch.stack(rendered_alphas, dim=0) # B V H W 1
90
+
91
+ return rendered_colors, rendered_depths, rendered_alphas
92
+
93
+
94
+ class GaussianSplatRenderer(nn.Module):
95
+ def __init__(
96
+ self,
97
+ feature_dim: int = 256, # Output channels of gs_feat_head
98
+ sh_degree: int = 0,
99
+ enable_prune: bool = True,
100
+ voxel_size: float = 0.002, # Default voxel size for prune_gs
101
+ enable_conf_filter: bool = False, # Enable confidence filtering
102
+ conf_threshold_percent: float = 30.0, # Confidence threshold percentage
103
+ max_gaussians: int = 5000000, # Maximum number of Gaussians
104
+ ):
105
+ super().__init__()
106
+
107
+ self.feature_dim = feature_dim
108
+ self.sh_degree = sh_degree
109
+ self.nums_sh = (sh_degree + 1) ** 2
110
+ self.voxel_size = voxel_size
111
+ self.enable_prune = enable_prune
112
+ self.enable_conf_filter = enable_conf_filter
113
+ self.conf_threshold_percent = conf_threshold_percent
114
+ self.max_gaussians = max_gaussians
115
+
116
+ # Predict Gaussian parameters from GS features (quaternions/scales/opacities/SH/weights)
117
+ splits_and_inits = [
118
+ (4, 1.0, 0.0), # quats
119
+ (3, 0.00003, -7.0), # scales
120
+ (1, 1.0, -2.0), # opacities
121
+ (3 * self.nums_sh, 1.0, 0.0), # residual_sh
122
+ (1, 1.0, -2.0), # weights
123
+ ]
124
+ gaussian_raw_channels = 4 + 3 + 1 + self.nums_sh * 3 + 1
125
+
126
+ self.gs_head = nn.Sequential(
127
+ nn.Conv2d(feature_dim // 2, feature_dim, kernel_size=3, padding=1, bias=False),
128
+ nn.ReLU(True),
129
+ nn.Conv2d(feature_dim, gaussian_raw_channels, kernel_size=1),
130
+ )
131
+ # Initialize weights and biases of the final layer by segments
132
+ final_conv_layer = self.gs_head[-1]
133
+ start_channels = 0
134
+ for out_channel, s, b in splits_and_inits:
135
+ nn.init.xavier_uniform_(final_conv_layer.weight[start_channels:start_channels+out_channel], s)
136
+ nn.init.constant_(final_conv_layer.bias[start_channels:start_channels+out_channel], b)
137
+ start_channels += out_channel
138
+
139
+ # Rasterizer
140
+ self.rasterizer = Rasterizer()
141
+
142
+ # ======== Main entry point: Complete GS rendering and fill results back to predictions ========
143
+ def render(
144
+ self,
145
+ gs_feats: torch.Tensor, # [B, S, 3, H, W]
146
+ images: torch.Tensor, # [B, S+V, 3, H, W]
147
+ predictions: Dict[str, torch.Tensor], # From WorldMirror: pose/depth/pts3d etc
148
+ views: Dict[str, torch.Tensor],
149
+ context_predictions: Dict[str, torch.Tensor],
150
+ is_inference: bool=True,
151
+ ) -> Dict[str, torch.Tensor]:
152
+ """
153
+ Returns predictions with the following fields filled:
154
+ - rendered_colors / rendered_depths / (rendered_alphas during training)
155
+ - gt_colors / gt_depths / valid_masks
156
+ - splats / rendered_extrinsics / rendered_intrinsics
157
+ """
158
+ B, _, _, H, W = images.shape
159
+ S = context_predictions.get("imgs", images).shape[1] # context view nums
160
+ V = images.shape[1] - S # target view nums
161
+
162
+ # 1) Predict GS features from tokens, then convert to Gaussian parameters
163
+ gs_feats_reshape = rearrange(gs_feats, "b s c h w -> (b s) c h w")
164
+ # Align input dtype with gs_head weights (handles fp32 input from
165
+ # precision-critical DPT output_conv2 when model runs in bf16 mode)
166
+ head_dtype = next(self.gs_head.parameters()).dtype
167
+ if gs_feats_reshape.dtype != head_dtype:
168
+ gs_feats_reshape = gs_feats_reshape.to(head_dtype)
169
+ gs_params = self.gs_head(gs_feats_reshape)
170
+ gt_colors = images.permute(0, 1, 3, 4, 2)
171
+
172
+ # 2) Select rendering cameras
173
+ if self.training:
174
+ # Using all gt cameras
175
+ render_viewmats, render_Ks = self.prepare_cameras(views, S + V)
176
+ gt_valid_masks_src = views["valid_mask"][:, :S] # [B, S, H, W]
177
+ gt_valid_masks_tgt = views["valid_mask"][:, S:] # [B, V, H, W]
178
+ unproject_masks = calculate_unprojected_mask(views, S) # [B, V, H, W]
179
+ valid_masks = torch.cat([gt_valid_masks_src, (gt_valid_masks_tgt & unproject_masks)], dim=1)
180
+ else:
181
+ # Re-predict the camera for novel views and perform translation scale alignment
182
+ pred_all_extrinsic, pred_all_intrinsic = self.prepare_cameras(predictions, S + V)
183
+ scale_factor = torch.ones(B, device=images.device)
184
+ if "camera_poses" in context_predictions:
185
+ pred_context_extrinsic, _ = self.prepare_cameras(context_predictions, S)
186
+ scale_factor = pred_context_extrinsic[:, :, :3, 3].norm(dim=-1).mean(dim=1, keepdim=True) / (
187
+ pred_all_extrinsic[:, :S, :3, 3].norm(dim=-1).mean(dim=1, keepdim=True) + 1e-6
188
+ )
189
+
190
+ pred_all_extrinsic[..., :3, 3] = pred_all_extrinsic[..., :3, 3] * scale_factor.unsqueeze(-1)
191
+ render_viewmats, render_Ks = pred_all_extrinsic, pred_all_intrinsic
192
+ valid_masks = views.get("valid_mask", torch.ones(B, S + V, H, W, dtype=bool, device=images.device))
193
+
194
+ # 3) Generate splats from gs_params + predictions, and perform voxel merging
195
+ if self.training:
196
+ splats = self.prepare_splats(
197
+ views,
198
+ predictions,
199
+ images,
200
+ gs_params,
201
+ S,
202
+ position_from="gsdepth+gtcamera",
203
+ )
204
+ elif not is_inference:
205
+ splats = self.prepare_splats(
206
+ views,
207
+ predictions,
208
+ images,
209
+ gs_params,
210
+ S,
211
+ context_predictions,
212
+ position_from="gsdepth+predcamera",
213
+ )
214
+ else:
215
+ splats = self.prepare_splats(
216
+ views,
217
+ predictions,
218
+ images,
219
+ gs_params,
220
+ S,
221
+ position_from="gsdepth+predcamera",
222
+ )
223
+
224
+ if is_inference:
225
+ predictions["splats"] = splats
226
+ return predictions
227
+
228
+ # Apply confidence filtering before pruning
229
+ if self.enable_conf_filter and "gs_depth_conf" in predictions:
230
+ splats = self.apply_confidence_filter(splats, predictions["gs_depth_conf"])
231
+
232
+ if self.enable_prune:
233
+ splats = self.prune_gs(splats, voxel_size=self.voxel_size)
234
+
235
+ predictions["splats"] = splats
236
+
237
+ # 4) Rasterization rendering (training: chunked rendering + novel view valid mask correction; evaluation: view-by-view)
238
+
239
+ # Prevent OOM by using chunked rendering
240
+ rendered_colors_list, rendered_depths_list, rendered_alphas_list = [], [], []
241
+ chunk_size = 2
242
+ for i in range(0, gt_colors.shape[1], chunk_size):
243
+ end_idx = min(i + chunk_size, gt_colors.shape[1])
244
+ viewmats_i = render_viewmats[:, i:end_idx]
245
+ Ks_i = render_Ks[:, i:end_idx]
246
+
247
+ rendered_colors, rendered_depths, rendered_alphas = self.rasterizer.rasterize_batches(
248
+ splats["means"], splats["quats"], splats["scales"], splats["opacities"],
249
+ splats["sh"] if "sh" in splats else splats["colors"],
250
+ viewmats_i.detach(), Ks_i.detach(),
251
+ width=images.shape[-1], height=images.shape[-2],
252
+ sh_degree=min(self.sh_degree, 0) if "sh" in splats else None,
253
+ )
254
+ rendered_colors_list.append(rendered_colors)
255
+ rendered_depths_list.append(rendered_depths)
256
+ rendered_alphas_list.append(rendered_alphas)
257
+
258
+ rendered_colors = torch.cat(rendered_colors_list, dim=1)
259
+ rendered_depths = torch.cat(rendered_depths_list, dim=1)
260
+ rendered_alphas = torch.cat(rendered_alphas_list, dim=1)
261
+
262
+ if self.training and V > 0:
263
+ nvs_rendered_mask = rendered_alphas[:, S:, ..., 0].detach() > 0.1
264
+ valid_masks[:, S:] = nvs_rendered_mask & valid_masks[:, S:]
265
+
266
+ # 5) return predictions
267
+ predictions["rendered_colors"] = rendered_colors
268
+ predictions["rendered_depths"] = rendered_depths
269
+ predictions["rendered_alphas"] = rendered_alphas
270
+ predictions["gt_colors"] = gt_colors.float()
271
+ predictions["gt_depths"] = views.get("depthmap")
272
+ predictions["valid_masks"] = valid_masks.bool()
273
+ predictions["rendered_extrinsics"] = render_viewmats
274
+ predictions["rendered_intrinsics"] = render_Ks
275
+
276
+
277
+ return predictions
278
+
279
+ def apply_confidence_filter(self, splats, gs_depth_conf):
280
+ """
281
+ Apply confidence filtering to Gaussian splats before pruning.
282
+ Discard bottom p% confidence points, keep top (100-p)%.
283
+
284
+ Args:
285
+ splats: Dictionary containing Gaussian parameters
286
+ gs_depth_conf: Confidence tensor [B, S, H, W]
287
+
288
+ Returns:
289
+ Filtered splats dictionary
290
+ """
291
+ if not self.enable_conf_filter or gs_depth_conf is None:
292
+ return splats
293
+
294
+ device = splats["means"].device
295
+ B, N = splats["means"].shape[:2]
296
+
297
+ # Flatten confidence: [B, S, H, W] -> [B, N]
298
+ conf = gs_depth_conf.flatten(1).to(device)
299
+ # Mask invalid/very small values
300
+ conf = conf.masked_fill(conf <= 1e-5, float("-inf"))
301
+
302
+ # Keep top (100-p)% points, discard bottom p%
303
+ if self.conf_threshold_percent > 0:
304
+ keep_from_percent = int(np.ceil(N * (100.0 - self.conf_threshold_percent) / 100.0))
305
+ else:
306
+ keep_from_percent = N
307
+ K = max(1, min(self.max_gaussians, keep_from_percent))
308
+
309
+ # Select top-K indices for each batch (deterministic, no randomness)
310
+ topk_idx = torch.topk(conf, K, dim=1, largest=True, sorted=False).indices # [B, K]
311
+
312
+ filtered = {}
313
+ mask_keys = ["means", "quats", "scales", "opacities", "sh", "weights"]
314
+
315
+ for key in splats.keys():
316
+ if key in mask_keys and key in splats:
317
+ x = splats[key]
318
+ if x.ndim == 2: # [B, N]
319
+ filtered[key] = torch.gather(x, 1, topk_idx)
320
+ else:
321
+ # Expand indices to match tensor dimensions
322
+ expand_idx = topk_idx.clone()
323
+ for i in range(x.ndim - 2):
324
+ expand_idx = expand_idx.unsqueeze(-1)
325
+ expand_idx = expand_idx.expand(-1, -1, *x.shape[2:])
326
+ filtered[key] = torch.gather(x, 1, expand_idx)
327
+ else:
328
+ filtered[key] = splats[key]
329
+
330
+ return filtered
331
+
332
+ def prune_gs(self, splats, voxel_size=0.002, filter_mask=None):
333
+ """
334
+ Prune Gaussian splats by optional mask filtering + voxel merging.
335
+
336
+ Args:
337
+ splats: Dictionary containing Gaussian parameters.
338
+ Each value is [B, S*H*W, ...] (batch of per-pixel gaussians).
339
+ voxel_size: Size of voxels for spatial grouping.
340
+ filter_mask: Optional bool tensor [B, S*H*W] or numpy [S, H, W].
341
+ True = keep, False = discard. Applied before voxel merge.
342
+
343
+ Returns:
344
+ Dictionary with pruned/merged splats (list-of-tensors per batch).
345
+ """
346
+ B = splats["means"].shape[0]
347
+ merged_splats_list = []
348
+ device = splats["means"].device
349
+
350
+ for i in range(B):
351
+ # Extract splats for current batch
352
+ splats_i = {k: splats[k][i] for k in ["means", "quats", "scales", "opacities", "sh", "weights"]}
353
+
354
+ # --- Apply filter_mask (discard unwanted gaussians before merge) ---
355
+ if filter_mask is not None:
356
+ if isinstance(filter_mask, np.ndarray):
357
+ fm = torch.from_numpy(filter_mask.reshape(-1)).to(device)
358
+ elif filter_mask.dim() == 3:
359
+ # [S, H, W] -> flatten
360
+ fm = filter_mask.reshape(-1).to(device)
361
+ else:
362
+ fm = filter_mask[i].to(device)
363
+ fm = fm.bool()
364
+ splats_i = {k: v[fm] for k, v in splats_i.items()}
365
+
366
+ N_in = splats_i["means"].shape[0]
367
+ if N_in == 0:
368
+ # All filtered out — push empty tensors
369
+ merged_splats_list.append({
370
+ "means": torch.zeros((0, 3), device=device),
371
+ "quats": torch.zeros((0, 4), device=device),
372
+ "scales": torch.zeros((0, 3), device=device),
373
+ "opacities": torch.zeros(0, device=device),
374
+ "sh": torch.zeros((0, self.nums_sh, 3), device=device),
375
+ })
376
+ continue
377
+
378
+ # Compute voxel indices
379
+ coords = splats_i["means"]
380
+ voxel_indices = (coords / voxel_size).floor().long()
381
+ min_indices = voxel_indices.min(dim=0)[0]
382
+ voxel_indices = voxel_indices - min_indices
383
+ max_dims = voxel_indices.max(dim=0)[0] + 1
384
+
385
+ # Flatten 3D voxel indices to 1D
386
+ flat_indices = (voxel_indices[:, 0] * max_dims[1] * max_dims[2] +
387
+ voxel_indices[:, 1] * max_dims[2] +
388
+ voxel_indices[:, 2])
389
+
390
+ # Find unique voxels and inverse mapping
391
+ unique_voxels, inverse_indices = torch.unique(flat_indices, return_inverse=True)
392
+ K = len(unique_voxels)
393
+
394
+ # Initialize merged splats
395
+ merged = {
396
+ "means": torch.zeros((K, 3), device=device),
397
+ "quats": torch.zeros((K, 4), device=device),
398
+ "scales": torch.zeros((K, 3), device=device),
399
+ "opacities": torch.zeros(K, device=device),
400
+ "sh": torch.zeros((K, self.nums_sh, 3), device=device)
401
+ }
402
+
403
+ # Get weights and compute weight sums per voxel
404
+ weights = splats_i["weights"]
405
+ weight_sums = torch.zeros(K, device=device)
406
+ weight_sums.scatter_add_(0, inverse_indices, weights)
407
+ weight_sums = torch.clamp(weight_sums, min=1e-8)
408
+
409
+ # Merge means (weighted average)
410
+ for d in range(3):
411
+ merged["means"][:, d].scatter_add_(0, inverse_indices,
412
+ splats_i["means"][:, d] * weights)
413
+ merged["means"] = merged["means"] / weight_sums.unsqueeze(1)
414
+
415
+ # Merge spherical harmonics (weighted average)
416
+ for d in range(3):
417
+ merged["sh"][:, 0, d].scatter_add_(0, inverse_indices,
418
+ splats_i["sh"][:, 0, d] * weights)
419
+ merged["sh"] = merged["sh"] / weight_sums.unsqueeze(-1).unsqueeze(-1)
420
+
421
+ # Merge opacities (weighted sum of squares)
422
+ merged["opacities"].scatter_add_(0, inverse_indices, weights * weights)
423
+ merged["opacities"] = merged["opacities"] / weight_sums
424
+
425
+ # Merge scales (weighted average)
426
+ for d in range(3):
427
+ merged["scales"][:, d].scatter_add_(0, inverse_indices,
428
+ splats_i["scales"][:, d] * weights)
429
+ merged["scales"] = merged["scales"] / weight_sums.unsqueeze(1)
430
+
431
+ # Merge quaternions (weighted average + normalization)
432
+ for d in range(4):
433
+ merged["quats"][:, d].scatter_add_(0, inverse_indices,
434
+ splats_i["quats"][:, d] * weights)
435
+ quat_norms = torch.norm(merged["quats"], dim=1, keepdim=True)
436
+ merged["quats"] = merged["quats"] / torch.clamp(quat_norms, min=1e-8)
437
+
438
+ merged_splats_list.append(merged)
439
+
440
+ # Reorganize output
441
+ output = {}
442
+ for key in ["means", "sh", "opacities", "scales", "quats"]:
443
+ output[key] = [merged[key] for merged in merged_splats_list]
444
+
445
+ return output
446
+
447
+ def prepare_splats(self, views, predictions, images, gs_params, context_nums,
448
+ context_predictions={}, position_from="gsdepth+gtcamera"):
449
+ """
450
+ Prepare Gaussian splats from model predictions and input data.
451
+
452
+ Args:
453
+ views: Dictionary containing view data (camera poses, intrinsics, etc.)
454
+ predictions: Model predictions including depth, pose_enc, etc.
455
+ images: Input images [B, S_all, 3, H, W]
456
+ gs_params: Gaussian splatting parameters from model
457
+ context_predictions: Optional context predictions for camera poses
458
+ position_from: Method to compute 3D positions ("pts3d", "gsdepth+gtcamera", "gsdepth+predcamera",
459
+ "depth_head+gtcamera", "depth_head+predcamera")
460
+ debug: Whether to use debug mode with ground truth data
461
+
462
+ Returns:
463
+ splats: Dictionary containing prepared Gaussian splat parameters
464
+ """
465
+ B, _, _, H, W = images.shape
466
+ S = context_nums
467
+ splats = {}
468
+
469
+ # Only take parameters from source view branch
470
+ gs_params = rearrange(gs_params, "(b s) c h w -> b s h w c", b=B)
471
+ splats["gs_feats"] = gs_params.reshape(B, S*H*W, -1)
472
+
473
+ # Split Gaussian parameters
474
+ quats, scales, opacities, residual_sh, weights = torch.split(
475
+ gs_params, [4, 3, 1, self.nums_sh * 3, 1], dim=-1
476
+ )
477
+
478
+ # Apply activation functions to Gaussian parameters
479
+ splats["quats"] = act_gs.reg_dense_rotation(quats.reshape(B, S * H * W, 4))
480
+ splats["scales"] = act_gs.reg_dense_scales(scales.reshape(B, S * H * W, 3)).clamp_max(0.3)
481
+ splats["opacities"] = act_gs.reg_dense_opacities(opacities.reshape(B, S * H * W))
482
+ residual_sh = act_gs.reg_dense_sh(residual_sh.reshape(B, S * H * W, self.nums_sh * 3))
483
+
484
+ # Handle spherical harmonics (SH) coefficients
485
+ new_sh = torch.zeros_like(residual_sh)
486
+ new_sh[..., 0, :] = sh_utils.RGB2SH(
487
+ images[:, :S].permute(0, 1, 3, 4, 2).reshape(B, S * H * W, 3)
488
+ )
489
+ splats['sh'] = new_sh + residual_sh
490
+ splats['residual_sh'] = residual_sh
491
+
492
+ splats["weights"] = act_gs.reg_dense_weights(weights.reshape(B, S * H * W))
493
+
494
+ # Compute 3D positions based on specified method
495
+ if position_from == "pts3d":
496
+ pts3d = predictions["pts3d"][:, :S].reshape(B, S * H * W, 3)
497
+ splats["means"] = pts3d
498
+ elif position_from == "gsdepth+gtcamera":
499
+ depth = predictions["gs_depth"][:, :S].reshape(B * S, H, W)
500
+ pose4x4 = views["camera_poses"][:, :S].reshape(B * S, 4, 4)
501
+ intrinsic = views["camera_intrs"][:, :S].reshape(B * S, 3, 3)
502
+ pts3d, _, _ = depth_to_world_coords_points(depth, pose4x4, intrinsic)
503
+ pts3d = pts3d.reshape(B, S * H * W, 3)
504
+ splats["means"] = pts3d
505
+
506
+ elif position_from == "gsdepth+predcamera":
507
+ depth = predictions["gs_depth"][:, :S].reshape(B * S , H, W)
508
+ pose4x4 = context_predictions.get("camera_poses", predictions["camera_poses"])[:, :S].reshape(B * S, 4, 4)
509
+ intrinsic = context_predictions.get("camera_intrs", predictions["camera_intrs"])[:, :S].reshape(B * S, 3, 3)
510
+ pts3d, _, _ = depth_to_world_coords_points(depth, pose4x4.detach(), intrinsic.detach())
511
+ pts3d = pts3d.reshape(B, S * H * W, 3)
512
+ splats["means"] = pts3d
513
+ else:
514
+ raise ValueError(f"Invalid position_from={position_from}")
515
+
516
+ return splats
517
+
518
+ def prepare_cameras(self, views, nums):
519
+ viewmats = views['camera_poses'][:, :nums]
520
+ Ks = views['camera_intrs'][:, :nums]
521
+ return viewmats, Ks
522
+
523
+
524
+
525
+
hyworldmirror/models/models/visual_transformer.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import random
3
+ from typing import Tuple, List
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.checkpoint import checkpoint
8
+
9
+ from ..layers import PatchEmbed, PatchEmbed_Mlp
10
+ from ..layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
11
+ from ..layers.block import Block, DistBlock
12
+ from ...comm.padding import minimal_pad_to_divisible,depad_by_length,pad_by_length
13
+ import torch.distributed as dist
14
+ from ...comm.communication import _All2All,_Allgather
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
19
+ _RESNET_STD = [0.229, 0.224, 0.225]
20
+
21
+
22
+ class VisualGeometryTransformer(nn.Module):
23
+ """
24
+ The VisualGeometryTransformer applies alternating-attention over input frames,
25
+ as described in VGGT: Visual Geometry Grounded Transformer.
26
+
27
+ Args:
28
+ img_size (int): Image size in pixels.
29
+ patch_size (int): Size of each patch for PatchEmbed.
30
+ embed_dim (int): Dimension of the token embeddings.
31
+ depth (int): Number of blocks.
32
+ num_heads (int): Number of attention heads.
33
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
34
+ num_register_tokens (int): Number of register tokens.
35
+ block_fn (nn.Module): The block type used for attention (Block by default).
36
+ qkv_bias (bool): Whether to include bias in QKV projections.
37
+ proj_bias (bool): Whether to include bias in the output projection.
38
+ ffn_bias (bool): Whether to include bias in MLP layers.
39
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
40
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
41
+ qk_norm (bool): Whether to apply QK normalization.
42
+ rope_base (int): Base frequency for rotary embedding.
43
+ rope_normalize_coords (str): Normalize coordinates for rotary embedding.
44
+ rope_shift_coords (float): Shift coordinates for rotary embedding.
45
+ rope_jitter_coords (float): Jitter coordinates for rotary embedding.
46
+ rope_rescale_coords (float): Rescale coordinates for rotary embedding.
47
+ init_values (float): Init scale for layer scale.
48
+ enable_condition (bool): Whether to enable conditioning inputs.
49
+ sampling_strategy (str): Sampling strategy for patches.
50
+ fixed_patch_embed (bool): Whether to fix patch embedding weights.
51
+ condition_strategy (list[str]): Strategy for each conditioning input.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ img_size=518,
57
+ patch_size=14,
58
+ embed_dim=1024,
59
+ depth=24,
60
+ num_heads=16,
61
+ mlp_ratio=4.0,
62
+ num_register_tokens=4,
63
+ block_fn=Block,
64
+ qkv_bias=True,
65
+ proj_bias=True,
66
+ ffn_bias=True,
67
+ patch_embed="dinov2_vitl14_reg",
68
+ qk_norm=True,
69
+ rope_base=100.0,
70
+ normalized_rope=False,
71
+ rope_normalize_coords="separate",
72
+ rope_shift_coords=None,
73
+ rope_jitter_coords=None,
74
+ rope_rescale_coords=None,
75
+ init_values=0.01,
76
+ enable_cond=False,
77
+ sampling_strategy="uniform",
78
+ fixed_patch_embed=False,
79
+ condition_strategy=["token", "pow3r", "token"],
80
+ intermediate_idxs: List[int] = [4, 11, 17, 23]
81
+ ):
82
+ super().__init__()
83
+ # Store config parameters
84
+ self.enable_cond = enable_cond
85
+ self.sampling_strategy = sampling_strategy
86
+ self.cond_methods = condition_strategy
87
+ self.intermediate_idxs = intermediate_idxs
88
+ self.depth = depth
89
+ self.patch_size = patch_size
90
+
91
+ # Initialize patch embedding module
92
+ self.patch_embed = self._init_patch_embedding_module(
93
+ patch_embed, img_size, patch_size, num_register_tokens,
94
+ embed_dim=embed_dim, is_fixed=fixed_patch_embed
95
+ )
96
+
97
+ # Initialize conditioning embeddings if enabled
98
+ if self.enable_cond:
99
+ self._init_cond_embeddings(embed_dim, img_size, patch_size, num_register_tokens)
100
+
101
+ # Initialize rotary position embedding
102
+ self._init_rotary_position_embedding(rope_base, normalized_rope, embed_dim // num_heads, rope_normalize_coords, rope_shift_coords, rope_jitter_coords, rope_rescale_coords)
103
+
104
+ # Initialize transformer blocks
105
+ self._init_transformer_blocks(block_fn, embed_dim, num_heads, mlp_ratio, qkv_bias, proj_bias, ffn_bias, init_values, qk_norm)
106
+
107
+ # Initialize learnable tokens
108
+ self._init_learnable_tokens(embed_dim, num_register_tokens)
109
+
110
+ # Calculate patch start index based on conditioning
111
+ if self.enable_cond:
112
+ self.patch_start_idx = 1 + num_register_tokens + 1 + 1 # camera + register + pose + rays
113
+ else:
114
+ self.patch_start_idx = 1 + num_register_tokens # camera + register
115
+
116
+ # Register normalization constants
117
+ for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
118
+ self.register_buffer(name, torch.FloatTensor(value).reshape(1, 1, 3, 1, 1), persistent=False)
119
+
120
+ self.use_reentrant = False
121
+
122
+ def _init_patch_embedding_module(
123
+ self,
124
+ patch_embed_type,
125
+ img_size,
126
+ patch_size,
127
+ num_reg_tokens,
128
+ interpolate_antialias=True,
129
+ interpolate_offset=0.0,
130
+ block_chunks=0,
131
+ init_values=1.0,
132
+ embed_dim=1024,
133
+ is_fixed=False,
134
+ in_chans=3
135
+ ):
136
+ """
137
+ Create the patch embedding module. If 'conv', we use a
138
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
139
+ """
140
+ if "conv" in patch_embed_type:
141
+ if 'mlp' in patch_embed_type:
142
+ patch_embed_module = PatchEmbed_Mlp(
143
+ img_size=img_size,
144
+ patch_size=patch_size,
145
+ in_chans=in_chans,
146
+ embed_dim=embed_dim
147
+ )
148
+ else:
149
+ patch_embed_module = PatchEmbed(
150
+ img_size=img_size,
151
+ patch_size=patch_size,
152
+ in_chans=in_chans,
153
+ embed_dim=embed_dim
154
+ )
155
+ else:
156
+ vit_models = {
157
+ "dinov2_vitl14_reg": vit_large,
158
+ "dinov2_vitb14_reg": vit_base,
159
+ "dinov2_vits14_reg": vit_small,
160
+ "dinov2_vitg2_reg": vit_giant2,
161
+ }
162
+
163
+ patch_embed_module = vit_models[patch_embed_type](
164
+ img_size=img_size,
165
+ patch_size=patch_size,
166
+ num_register_tokens=num_reg_tokens,
167
+ interpolate_antialias=interpolate_antialias,
168
+ interpolate_offset=interpolate_offset,
169
+ block_chunks=block_chunks,
170
+ init_values=init_values,
171
+ )
172
+
173
+ # Disable gradient updates for mask token
174
+ if hasattr(patch_embed_module, "mask_token"):
175
+ patch_embed_module.mask_token.requires_grad_(False)
176
+
177
+ if is_fixed:
178
+ for param in patch_embed_module.parameters():
179
+ param.requires_grad_(False)
180
+
181
+ return patch_embed_module
182
+
183
+ def _init_cond_embeddings(self, embed_dim, img_size, patch_size, num_reg_tokens):
184
+ """Initialize conditioning embeddings for camera, depth, and rays."""
185
+ assert self.cond_methods is not None
186
+ assert self.cond_methods[0] == "token"
187
+
188
+ # Camera pose embedding
189
+ if self.cond_methods[0] == "token":
190
+ self.pose_embed = nn.Sequential(
191
+ nn.Linear(7, embed_dim, bias=True),
192
+ nn.SiLU(),
193
+ nn.Linear(embed_dim, embed_dim, bias=True)
194
+ )
195
+ else:
196
+ raise NotImplementedError
197
+
198
+ # Depth map embedding
199
+ if self.cond_methods[1] == "pow3r":
200
+ self.depth_embed = self._init_patch_embedding_module(
201
+ "conv+mlp", img_size, patch_size, num_reg_tokens,
202
+ embed_dim=embed_dim, in_chans=1
203
+ )
204
+ else:
205
+ raise NotImplementedError
206
+
207
+ # Ray direction embedding
208
+ if self.cond_methods[2] == "token":
209
+ self.ray_embed = nn.Sequential(
210
+ nn.Linear(4, embed_dim, bias=True),
211
+ nn.SiLU(),
212
+ nn.Linear(embed_dim, embed_dim, bias=True)
213
+ )
214
+ else:
215
+ raise NotImplementedError
216
+
217
+ def _init_rotary_position_embedding(self, rope_base, normalized_rope, head_dim, rope_normalize_coords, rope_shift_coords, rope_jitter_coords, rope_rescale_coords):
218
+ if normalized_rope:
219
+ print("[INFO] Using normalized RoPE!")
220
+ from ..layers.norm_rope import NormalizedRotaryPositionEmbedding2D, PositionGetter
221
+ if head_dim % 4 != 0:
222
+ raise ValueError("RoPE requires head_dim divisible by 4 (embed_dim must be divisible by 4*num_heads)")
223
+ self.rope = NormalizedRotaryPositionEmbedding2D(
224
+ head_dim=head_dim,
225
+ base=rope_base,
226
+ normalize_coords=rope_normalize_coords,
227
+ shift_coords=rope_shift_coords,
228
+ jitter_coords=rope_jitter_coords,
229
+ rescale_coords=rope_rescale_coords,
230
+ ) if rope_base > 0 else None
231
+ self.pos_getter = PositionGetter() if self.rope is not None else None
232
+ else:
233
+ from ..layers.rope import RotaryPositionEmbedding2D, PositionGetter
234
+ print("[INFO] Using standard RoPE!")
235
+ self.rope = RotaryPositionEmbedding2D(
236
+ frequency=rope_base,
237
+ ) if rope_base > 0 else None
238
+ self.pos_getter = PositionGetter() if self.rope is not None else None
239
+
240
+ def _init_transformer_blocks(self, block_fn, embed_dim, num_heads, mlp_ratio, qkv_bias, proj_bias, ffn_bias, init_values, qk_norm):
241
+ self.frame_blocks = nn.ModuleList([
242
+ block_fn(
243
+ dim=embed_dim,
244
+ num_heads=num_heads,
245
+ mlp_ratio=mlp_ratio,
246
+ qkv_bias=qkv_bias,
247
+ proj_bias=proj_bias,
248
+ ffn_bias=ffn_bias,
249
+ init_values=init_values,
250
+ qk_norm=qk_norm,
251
+ rope=self.rope,
252
+ )
253
+ for _ in range(self.depth)
254
+ ])
255
+
256
+ self.global_blocks = nn.ModuleList([
257
+ block_fn(
258
+ dim=embed_dim,
259
+ num_heads=num_heads,
260
+ mlp_ratio=mlp_ratio,
261
+ qkv_bias=qkv_bias,
262
+ proj_bias=proj_bias,
263
+ ffn_bias=ffn_bias,
264
+ init_values=init_values,
265
+ qk_norm=qk_norm,
266
+ rope=self.rope
267
+ )
268
+ for _ in range(self.depth)
269
+ ])
270
+
271
+ def _init_learnable_tokens(self, embed_dim, num_reg_tokens):
272
+ """Initialize learnable tokens."""
273
+ self.cam_token = nn.Parameter(torch.zeros(1, 2, 1, embed_dim))
274
+ self.reg_token = nn.Parameter(torch.zeros(1, 2, num_reg_tokens, embed_dim))
275
+ nn.init.normal_(self.cam_token, std=1e-6)
276
+ nn.init.normal_(self.reg_token, std=1e-6)
277
+
278
+ def forward(self, images: torch.Tensor, priors: List | None=None, cond_flags: List[int]=[0,0,0], ctx_frames: int=None, enable_bf16=False, sp_size: int=1, sp_group: torch._C._distributed_c10d.ProcessGroup=None) -> Tuple[List[torch.Tensor], int]:
279
+ """
280
+ Args:
281
+ images: Input images with shape [B, S, 3, H, W], in range [0, 1]
282
+ priors: Optional tuple of (depth, rays, poses) for conditioning
283
+ cond_flags: List indicating which conditions to use [pose, depth, rays]
284
+ ctx_frames: Number of context frames to use
285
+
286
+ Returns:
287
+ (list[torch.Tensor], int): List of attention block outputs and patch_start_idx
288
+ """
289
+ depth_maps, ray_dirs, poses = priors if priors is not None else (None, None, None)
290
+
291
+ # Slice to context frames if specified
292
+ if ctx_frames is not None:
293
+ for var_name in ['images', 'depth_maps', 'ray_dirs', 'poses']:
294
+ var = locals()[var_name]
295
+ if var is not None:
296
+ locals()[var_name] = var[:, :ctx_frames].clone()
297
+
298
+ # Process image tokens
299
+ b, seq_len, ch, h, w = images.shape
300
+ if ch != 3:
301
+ raise ValueError(f"Expected 3 input channels, got {ch}")
302
+
303
+ with torch.amp.autocast('cuda', enabled=(not enable_bf16), dtype=torch.bfloat16):
304
+ images = (images - self._resnet_mean) / self._resnet_std
305
+ images = images.reshape(b * seq_len, ch, h, w)
306
+ patch_tokens = self.patch_embed(images)
307
+ if isinstance(patch_tokens, dict):
308
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
309
+
310
+ _, patch_count, embed_dim = patch_tokens.shape
311
+
312
+ # Prepare special tokens
313
+ cam_tokens = expand_and_flatten_special_tokens(self.cam_token, b, seq_len)
314
+ reg_tokens = expand_and_flatten_special_tokens(self.reg_token, b, seq_len)
315
+
316
+ # Process all tokens (optional conditioning)
317
+ if self.enable_cond:
318
+ pose_tokens, depth_tokens, ray_tokens = self._process_conditioning(depth_maps, ray_dirs, poses, b, seq_len, patch_count, embed_dim, images, cond_flags)
319
+ # Add condition tokens to patch tokens
320
+ patch_tokens = patch_tokens + depth_tokens
321
+ all_tokens = torch.cat([cam_tokens, reg_tokens, pose_tokens, ray_tokens, patch_tokens], dim=1)
322
+ else:
323
+ all_tokens = torch.cat([cam_tokens, reg_tokens, patch_tokens], dim=1)
324
+
325
+ _, patch_count, embed_dim = all_tokens.shape
326
+
327
+ # Position embedding
328
+ pos_emb = None
329
+ if self.rope is not None:
330
+ pos_emb = self.pos_getter(b * seq_len, h // self.patch_size, w // self.patch_size, device=images.device)
331
+ if self.patch_start_idx > 0:
332
+ pos_emb = pos_emb + 1
333
+ special_pos = torch.zeros(b * seq_len, self.patch_start_idx, 2, device=images.device, dtype=pos_emb.dtype)
334
+ pos_emb = torch.cat([special_pos, pos_emb], dim=1)
335
+
336
+ if sp_size>1:
337
+ rank_in_sp_group = dist.get_group_rank(sp_group,dist.get_rank())
338
+ all_tokens,tk_padding_len = minimal_pad_to_divisible(all_tokens, sp_size, dim=1,pad_value=0)
339
+ all_tokens = torch.chunk(all_tokens, sp_size,dim=1)[rank_in_sp_group]
340
+
341
+ _, patch_count, embed_dim = all_tokens.shape
342
+ token_shape = (b, seq_len, patch_count, embed_dim)
343
+ # Forward through attention blocks
344
+ with torch.amp.autocast('cuda', enabled=(not enable_bf16), dtype=torch.bfloat16):
345
+ outputs = []
346
+ global_tokens = None
347
+ if sp_size>1:
348
+ for idx in range(self.depth):
349
+ local_tokens = self._process_dist_attention_blocks(
350
+ tokens=all_tokens if global_tokens is None else global_tokens,
351
+ b=b,
352
+ seq_len=seq_len,
353
+ patch_count=patch_count,
354
+ embed_dim=embed_dim,
355
+ block_idx=idx,
356
+ blocks=self.frame_blocks,
357
+ block_type='frame',
358
+ pos=pos_emb,
359
+ sp_size = sp_size,
360
+ sp_group = sp_group,
361
+ padding_tokens = tk_padding_len
362
+ )
363
+ global_tokens = self._process_dist_attention_blocks(
364
+ tokens=local_tokens,
365
+ b=b,
366
+ seq_len=seq_len,
367
+ patch_count=patch_count,
368
+ embed_dim=embed_dim,
369
+ block_idx=idx,
370
+ blocks=self.global_blocks,
371
+ block_type='global',
372
+ pos=pos_emb,
373
+ sp_size = sp_size,
374
+ sp_group = sp_group,
375
+ padding_tokens = tk_padding_len
376
+ )
377
+ global_tokens = global_tokens.reshape(b,-1,embed_dim)
378
+ global_tokens = _Allgather.apply(global_tokens,1,sp_group,False)
379
+ global_tokens = depad_by_length(global_tokens,tk_padding_len*seq_len,1)
380
+ global_tokens = global_tokens.reshape(b,seq_len,-1,embed_dim)
381
+ global_tokens = pad_by_length(global_tokens,tk_padding_len,2)
382
+ global_tokens = torch.chunk(global_tokens, sp_size,dim=2)[rank_in_sp_group]
383
+
384
+ # Combine frame and global intermediates
385
+ if idx in self.intermediate_idxs:
386
+ local_tokens = _Allgather.apply(local_tokens,2,sp_group,False)
387
+ local_tokens = depad_by_length(local_tokens,tk_padding_len,2)
388
+ global_tokens = _Allgather.apply(global_tokens,2,sp_group,False)
389
+ global_tokens = depad_by_length(global_tokens,tk_padding_len,2)
390
+ combined_out = torch.cat([local_tokens, global_tokens], dim=-1)
391
+ outputs.append(combined_out)
392
+ global_tokens = pad_by_length(global_tokens,tk_padding_len,2)
393
+ global_tokens = torch.chunk(global_tokens, sp_size,dim=2)[rank_in_sp_group]
394
+ else:
395
+ for idx in range(self.depth):
396
+ local_tokens = self._process_attention_blocks(
397
+ tokens=all_tokens if global_tokens is None else global_tokens,
398
+ b=b,
399
+ seq_len=seq_len,
400
+ patch_count=patch_count,
401
+ embed_dim=embed_dim,
402
+ block_idx=idx,
403
+ blocks=self.frame_blocks,
404
+ block_type='frame',
405
+ pos=pos_emb,
406
+ )
407
+ global_tokens = self._process_attention_blocks(
408
+ tokens=local_tokens,
409
+ b=b,
410
+ seq_len=seq_len,
411
+ patch_count=patch_count,
412
+ embed_dim=embed_dim,
413
+ block_idx=idx,
414
+ blocks=self.global_blocks,
415
+ block_type='global',
416
+ pos=pos_emb,
417
+ )
418
+ # Combine frame and global intermediates
419
+ if idx in self.intermediate_idxs:
420
+ combined_out = torch.cat([local_tokens, global_tokens], dim=-1)
421
+ outputs.append(combined_out)
422
+
423
+ # Combine frame and global intermediates
424
+ if idx in self.intermediate_idxs:
425
+ combined_out = torch.cat([local_tokens, global_tokens], dim=-1)
426
+ outputs.append(combined_out)
427
+
428
+ return outputs, self.patch_start_idx
429
+
430
+ def _process_conditioning(self, depth_maps, ray_dirs, poses, b, seq_len, patch_count, embed_dim, images, cond_flags):
431
+ """Process conditioning inputs."""
432
+ h, w = images.shape[-2:]
433
+
434
+ # Process camera pose embedding
435
+ use_poses = (cond_flags[0] == 1 and poses is not None)
436
+ if use_poses:
437
+ poses = poses.reshape(b*seq_len, -1)
438
+ pose_tokens = self.pose_embed(poses).unsqueeze(1)
439
+ else:
440
+ pose_tokens = torch.zeros((b*seq_len, 1, embed_dim), device=images.device, dtype=images.dtype)
441
+
442
+ # Process depth map embedding
443
+ use_depth = cond_flags[1] == 1 and depth_maps is not None
444
+ if use_depth:
445
+ depth_maps = depth_maps.reshape(b*seq_len, 1, h, w)
446
+ depth_tokens = self.depth_embed(depth_maps).reshape(b * seq_len, patch_count, embed_dim)
447
+ else:
448
+ depth_tokens = torch.zeros((b*seq_len, patch_count, embed_dim), device=images.device, dtype=images.dtype)
449
+
450
+ # Process ray direction embedding
451
+ use_rays = cond_flags[2] == 1 and ray_dirs is not None
452
+ if use_rays:
453
+ ray_dirs = ray_dirs.reshape(b*seq_len, -1)
454
+ ray_tokens = self.ray_embed(ray_dirs).unsqueeze(1)
455
+ else:
456
+ ray_tokens = torch.zeros((b*seq_len, 1, embed_dim), device=images.device, dtype=images.dtype)
457
+
458
+ return pose_tokens, depth_tokens, ray_tokens
459
+
460
+ def _process_attention_blocks(self, tokens, b, seq_len, patch_count, embed_dim, block_idx, blocks, block_type, pos=None):
461
+ """Process attention blocks with tokens in shape (B*S, P, C)."""
462
+ token_shape = (b, seq_len, patch_count, embed_dim)
463
+ if block_type == 'frame': # local
464
+ target_shape = (b * seq_len, patch_count, embed_dim)
465
+ pos_target_shape = (b * seq_len, patch_count, 2) if pos is not None else None
466
+ else: # global
467
+ target_shape = (b, seq_len * patch_count, embed_dim)
468
+ pos_target_shape = (b, seq_len * patch_count, 2) if pos is not None else None
469
+
470
+ if tokens.shape != target_shape:
471
+ tokens = tokens.reshape(*target_shape)
472
+
473
+ if pos is not None and pos.shape != pos_target_shape:
474
+ pos = pos.reshape(*pos_target_shape)
475
+
476
+ if self.training:
477
+ # tokens = blocks[block_idx](tokens, pos=pos)
478
+ tokens = checkpoint(blocks[block_idx], tokens, pos=pos, use_reentrant=self.use_reentrant)
479
+ else:
480
+ tokens = blocks[block_idx](tokens, pos=pos)
481
+
482
+ return tokens.reshape(*token_shape)
483
+
484
+ def _process_dist_attention_blocks(self, tokens, b, seq_len, patch_count, embed_dim, block_idx, blocks, block_type, pos=None,
485
+ sp_size = 1,
486
+ sp_group = None,
487
+ padding_tokens = 0):
488
+ """Process attention blocks with tokens in shape (B*S, P, C)."""
489
+ token_shape = (b, seq_len, patch_count, embed_dim)
490
+ if block_type == 'frame': # local
491
+ target_shape = (b * seq_len, patch_count, embed_dim)
492
+ pos_target_shape = (b * seq_len, patch_count*sp_size-padding_tokens, 2) if pos is not None else None
493
+ else: # global
494
+ target_shape = (b, seq_len * patch_count, embed_dim)
495
+ pos_target_shape = (b, seq_len * (patch_count*sp_size-padding_tokens), 2) if pos is not None else None
496
+ # padding_tokens = padding_tokens*seq_len
497
+
498
+ if block_type=="global":
499
+ rank_in_sp_group = dist.get_group_rank(sp_group,dist.get_rank())
500
+ tokens = _Allgather.apply(tokens,2,sp_group,False) #(1,7,4*146,64)
501
+ tokens = depad_by_length(tokens,padding_tokens,2) #(1,7,4*146-2,64)
502
+ tokens = tokens.reshape(b,-1,embed_dim) #(1,7*(4*146-2),64)
503
+ padding_tokens = padding_tokens*seq_len
504
+ tokens = pad_by_length(tokens,padding_tokens,1) #(1,4088,1024)
505
+ tokens = torch.chunk(tokens, sp_size,dim=1)[rank_in_sp_group]
506
+
507
+ if tokens.shape != target_shape:
508
+ tokens = tokens.reshape(*target_shape)
509
+
510
+ if pos is not None and pos.shape != pos_target_shape:
511
+ pos = pos.reshape(*pos_target_shape)
512
+
513
+ if self.training:
514
+ # tokens = blocks[block_idx](tokens, pos=pos)
515
+ tokens = checkpoint(blocks[block_idx], tokens, pos=pos, use_reentrant=self.use_reentrant, sp_size=sp_size, sp_group=sp_group, padding_tokens=padding_tokens, block_type =block_type,token_shape=token_shape)
516
+ else:
517
+ tokens = blocks[block_idx](tokens, pos=pos, sp_size=sp_size, sp_group=sp_group, padding_tokens=padding_tokens, block_type =block_type,token_shape=token_shape)
518
+
519
+ return tokens.reshape(*token_shape)
520
+
521
+
522
+
523
+ def expand_and_flatten_special_tokens(token_tensor, b, seq_len):
524
+ """
525
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing.
526
+ Uses first position for frame 0, second position for remaining frames.
527
+
528
+ Args:
529
+ token_tensor: Input tensor with shape (1, 2, X, C)
530
+ b: Batch size
531
+ seq_len: Sequence length
532
+
533
+ Returns:
534
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
535
+ """
536
+ # First frame uses position 0, remaining frames use position 1
537
+ first_frame_tokens = token_tensor[:, 0:1, ...].expand(b, 1, *token_tensor.shape[2:])
538
+ remaining_frame_tokens = token_tensor[:, 1:, ...].expand(b, seq_len - 1, *token_tensor.shape[2:])
539
+
540
+ # Concatenate and flatten
541
+ combined_tokens = torch.cat([first_frame_tokens, remaining_frame_tokens], dim=1)
542
+ return combined_tokens.reshape(b * seq_len, *combined_tokens.shape[2:])
hyworldmirror/models/models/worldmirror.py ADDED
@@ -0,0 +1,685 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .visual_transformer import VisualGeometryTransformer
8
+ from ..heads.camera_head import CameraHead
9
+ from ..heads.dense_head import DPTHead
10
+ from ..heads.gs_head import GSFeatHead
11
+ from .rasterization import GaussianSplatRenderer
12
+ from ..utils.camera_utils import (
13
+ vector_to_camera_matrices,
14
+ extrinsics_to_vector,
15
+ )
16
+ from ..utils.priors import normalize_depth, normalize_poses
17
+ from huggingface_hub import PyTorchModelHubMixin
18
+
19
+ from ..layers.block import Block, DistBlock
20
+ import torch.distributed as dist
21
+
22
+
23
+ class WorldMirror(nn.Module, PyTorchModelHubMixin):
24
+ def __init__(
25
+ self,
26
+ img_size=518,
27
+ patch_size=14,
28
+ model_size="large",
29
+ embed_dim=1024,
30
+ depth=24,
31
+ num_heads=16,
32
+ mlp_ratio=4.0,
33
+ gs_dim=256,
34
+ num_register_tokens=4,
35
+ enable_cond=True,
36
+ enable_cam=True,
37
+ enable_pts=True,
38
+ enable_depth=True,
39
+ enable_depth_mask=True,
40
+ enable_norm=True,
41
+ enable_gs=True,
42
+ enable_bf16=False,
43
+ patch_embed="dinov2_vitl14_reg",
44
+ fixed_patch_embed=False,
45
+ sampling_strategy="uniform",
46
+ dpt_gradient_checkpoint=False,
47
+ condition_strategy=["token", "pow3r", "token"],
48
+ rope_base=100.0,
49
+ normalized_rope=True,
50
+ rope_normalize_coords="separate",
51
+ rope_shift_coords=None,
52
+ rope_jitter_coords=None,
53
+ rope_rescale_coords=None,
54
+ sp_size=1,
55
+ # Legacy parameters (ignored, kept for checkpoint compatibility)
56
+ set_sky_region_to_maxdepth=False,
57
+ disable_gs_depth=False,
58
+ ):
59
+
60
+ super().__init__()
61
+
62
+ self.intermediate_layer_idx = {
63
+ "small": [2, 5, 8, 11],
64
+ "base": [2, 5, 8, 11],
65
+ "large": [4, 11, 17, 23],
66
+ "giant": [9, 19, 29, 39],
67
+ }
68
+ self.model_size = model_size
69
+ if model_size == "large":
70
+ embed_dim = 1024
71
+ depth = 24
72
+ num_heads = 16
73
+ mlp_ratio = 4.0
74
+ gs_dim = 256
75
+ num_register_tokens = 4
76
+ elif model_size == "base":
77
+ embed_dim = 768
78
+ depth = 12
79
+ num_heads = 12
80
+ mlp_ratio = 4.0
81
+ gs_dim = 256
82
+ num_register_tokens = 4
83
+ elif model_size == "small":
84
+ embed_dim = 384
85
+ depth = 12
86
+ num_heads = 6
87
+ mlp_ratio = 4.0
88
+ gs_dim = 128
89
+ num_register_tokens = 4
90
+ elif model_size is None:
91
+ pass
92
+ print(
93
+ f"[WorldMirror] model_size: {model_size}, embed_dim: {embed_dim}, "
94
+ f"depth: {depth}, num_heads: {num_heads}, mlp_ratio: {mlp_ratio}, "
95
+ f"gs_dim: {gs_dim}, num_register_tokens: {num_register_tokens}"
96
+ )
97
+
98
+ self.img_size = img_size
99
+ self.patch_size = patch_size
100
+ self.embed_dim = embed_dim
101
+ self.depth = depth
102
+ self.num_heads = num_heads
103
+ self.mlp_ratio = mlp_ratio
104
+ self.gs_dim = gs_dim
105
+ self.num_register_tokens = num_register_tokens
106
+
107
+ self.normalized_rope = normalized_rope
108
+ self.rope_normalize_coords = rope_normalize_coords
109
+ self.rope_shift_coords = rope_shift_coords
110
+ self.rope_jitter_coords = rope_jitter_coords
111
+ self.rope_rescale_coords = rope_rescale_coords
112
+
113
+ self.enable_cam = enable_cam
114
+ self.enable_pts = enable_pts
115
+ self.enable_depth = enable_depth
116
+ self.enable_depth_mask = enable_depth_mask
117
+ self.enable_cond = enable_cond
118
+ self.enable_norm = enable_norm
119
+ self.enable_gs = enable_gs
120
+ self.enable_bf16 = enable_bf16
121
+ self.patch_embed = patch_embed
122
+ self.sampling = sampling_strategy
123
+ self.dpt_checkpoint = dpt_gradient_checkpoint
124
+ self.cond_methods = condition_strategy
125
+ self.config = self._store_config()
126
+ self.sp_size = sp_size
127
+
128
+ self.visual_geometry_transformer = VisualGeometryTransformer(
129
+ img_size=img_size,
130
+ patch_size=patch_size,
131
+ embed_dim=embed_dim,
132
+ depth=depth,
133
+ num_heads=num_heads,
134
+ mlp_ratio=mlp_ratio,
135
+ num_register_tokens=num_register_tokens,
136
+ block_fn=Block if self.sp_size == 1 else DistBlock,
137
+ normalized_rope=normalized_rope,
138
+ rope_normalize_coords=rope_normalize_coords,
139
+ rope_shift_coords=rope_shift_coords,
140
+ rope_jitter_coords=rope_jitter_coords,
141
+ rope_rescale_coords=rope_rescale_coords,
142
+ enable_cond=enable_cond,
143
+ sampling_strategy=sampling_strategy,
144
+ patch_embed=patch_embed,
145
+ fixed_patch_embed=fixed_patch_embed,
146
+ condition_strategy=condition_strategy,
147
+ intermediate_idxs=self.intermediate_layer_idx[model_size],
148
+ )
149
+
150
+ self._init_heads(embed_dim, patch_size, gs_dim)
151
+
152
+ if enable_bf16:
153
+ self.to = self._bf16_to
154
+
155
+ def _store_config(self):
156
+ """Save the model configuration."""
157
+ return {
158
+ "img_size": self.img_size,
159
+ "patch_size": self.patch_size,
160
+ "embed_dim": self.embed_dim,
161
+ "depth": self.depth,
162
+ "num_heads": self.num_heads,
163
+ "mlp_ratio": self.mlp_ratio,
164
+ "gs_dim": self.gs_dim,
165
+ "num_register_tokens": self.num_register_tokens,
166
+ "normalized_rope": self.normalized_rope,
167
+ "rope_normalize_coords": self.rope_normalize_coords,
168
+ "rope_shift_coords": self.rope_shift_coords,
169
+ "rope_jitter_coords": self.rope_jitter_coords,
170
+ "rope_rescale_coords": self.rope_rescale_coords,
171
+ "enable_cam": self.enable_cam,
172
+ "enable_pts": self.enable_pts,
173
+ "enable_depth": self.enable_depth,
174
+ "enable_depth_mask": self.enable_depth_mask,
175
+ "enable_norm": self.enable_norm,
176
+ "enable_gs": self.enable_gs,
177
+ "patch_embed": self.patch_embed,
178
+ "sampling_strategy": self.sampling,
179
+ "dpt_gradient_checkpoint": self.dpt_checkpoint,
180
+ "condition_strategy": self.cond_methods,
181
+ "model_size": self.model_size,
182
+ }
183
+
184
+ def _init_heads(self, dim, patch_size, gs_dim):
185
+ """Initialize all prediction heads."""
186
+
187
+ if self.enable_cam:
188
+ self.cam_head = CameraHead(
189
+ dim_in=2 * dim,
190
+ block_fn=Block if self.sp_size == 1 else DistBlock,
191
+ )
192
+
193
+ if self.enable_pts:
194
+ self.pts_head = DPTHead(
195
+ dim_in=2 * dim,
196
+ output_dim=4,
197
+ patch_size=patch_size,
198
+ activation="inv_log+expp1",
199
+ gradient_checkpoint=self.dpt_checkpoint,
200
+ )
201
+
202
+ if self.enable_depth:
203
+ self.depth_head = DPTHead(
204
+ dim_in=2 * dim,
205
+ output_dim=2 if not self.enable_depth_mask else 3,
206
+ patch_size=patch_size,
207
+ activation="exp+expp1" if not self.enable_depth_mask else "exp+expp1+linear",
208
+ enable_depth_mask=self.enable_depth_mask,
209
+ gradient_checkpoint=self.dpt_checkpoint,
210
+ )
211
+
212
+ if self.enable_norm:
213
+ self.norm_head = DPTHead(
214
+ dim_in=2 * dim,
215
+ output_dim=4,
216
+ patch_size=patch_size,
217
+ activation="norm+expp1",
218
+ gradient_checkpoint=self.dpt_checkpoint,
219
+ )
220
+
221
+ if self.enable_gs:
222
+ self.gs_head = DPTHead(
223
+ dim_in=2 * dim,
224
+ output_dim=2 if not self.enable_depth_mask else 3,
225
+ patch_size=patch_size,
226
+ features=gs_dim,
227
+ is_gsdpt=True,
228
+ activation="exp+expp1" if not self.enable_depth_mask else "exp+expp1+linear",
229
+ enable_depth_mask=self.enable_depth_mask,
230
+ gradient_checkpoint=self.dpt_checkpoint,
231
+ )
232
+ self.gs_renderer = GaussianSplatRenderer(
233
+ feature_dim=gs_dim,
234
+ sh_degree=0,
235
+ enable_prune=True,
236
+ voxel_size=0.002,
237
+ )
238
+
239
+ def _bf16_to(self, *args, **kwargs):
240
+ """Custom to() for bf16 mode: selectively move heads to target device/dtype."""
241
+ self.visual_geometry_transformer = self.visual_geometry_transformer.to(*args, **kwargs)
242
+ if self.enable_cam:
243
+ self.cam_head = self.cam_head.to(*args, **kwargs)
244
+ if self.enable_pts:
245
+ self.pts_head = self.pts_head.to(*args, **kwargs)
246
+ if self.enable_depth:
247
+ self.depth_head = self.depth_head.to(*args, **kwargs)
248
+ if self.enable_norm:
249
+ self.norm_head = self.norm_head.to(*args, **kwargs)
250
+ if self.enable_gs:
251
+ self.gs_head = self.gs_head.to(*args, **kwargs)
252
+ self.gs_renderer = self.gs_renderer.to(*args, **kwargs)
253
+ return self
254
+
255
+ def forward(
256
+ self,
257
+ views: Dict[str, torch.Tensor],
258
+ cond_flags: List[int] = [0, 0, 0],
259
+ is_inference=True,
260
+ sp_size=1,
261
+ sp_group=None,
262
+ ):
263
+ """Execute forward pass through the WorldMirror model.
264
+
265
+ Args:
266
+ views: Input data dictionary containing 'img' and optional priors.
267
+ cond_flags: Conditioning flags [pose, depth, intrinsics].
268
+ is_inference: Whether running in inference mode.
269
+ sp_size: Sequence parallel size (>1 for multi-GPU).
270
+ sp_group: Process group for SP communication.
271
+
272
+ Returns:
273
+ dict: Prediction results dictionary.
274
+ """
275
+ if self.enable_bf16:
276
+ views['img'] = views['img'].to(torch.bfloat16)
277
+
278
+ imgs = views["img"]
279
+ use_cond = sum(cond_flags) > 0
280
+
281
+ if use_cond:
282
+ priors = self.extract_priors(views)
283
+ token_list, patch_start_idx = self.visual_geometry_transformer(
284
+ imgs, priors, cond_flags=cond_flags,
285
+ enable_bf16=self.enable_bf16, sp_size=sp_size, sp_group=sp_group,
286
+ )
287
+ else:
288
+ token_list, patch_start_idx = self.visual_geometry_transformer(
289
+ imgs, enable_bf16=self.enable_bf16, sp_size=sp_size, sp_group=sp_group,
290
+ )
291
+
292
+ with torch.amp.autocast('cuda', enabled=(not self.enable_bf16), dtype=torch.float32):
293
+ if sp_size > 1:
294
+ preds = self._gen_all_preds_frame_sp(
295
+ token_list, imgs, patch_start_idx, views, cond_flags,
296
+ is_inference, sp_size, sp_group,
297
+ )
298
+ else:
299
+ preds = self._gen_all_preds(
300
+ token_list, imgs, patch_start_idx, views, cond_flags, is_inference,
301
+ )
302
+
303
+ return preds
304
+
305
+ def _gen_all_preds_frame_sp(
306
+ self, token_list, imgs, patch_start_idx, views, cond_flags, is_inference,
307
+ sp_size, sp_group,
308
+ ):
309
+ """Generate predictions with frame-parallel DPT heads for SP inference.
310
+
311
+ Splits S frames across sp_size ranks. Each rank processes S/sp_size frames
312
+ through ALL head types, then Allgather to reconstruct full results.
313
+ CameraHead runs on all frames on every rank (cross-view attention needed).
314
+ """
315
+ preds = {}
316
+ rank = dist.get_rank()
317
+ rank_in_sp = dist.get_group_rank(sp_group, rank)
318
+
319
+ B, S, C_img, H, W = imgs.shape
320
+
321
+ # Determine frame assignment for this rank
322
+ if S >= sp_size:
323
+ base_chunk = S // sp_size
324
+ remainder = S % sp_size
325
+ if rank_in_sp < remainder:
326
+ my_count = base_chunk + 1
327
+ my_start = rank_in_sp * (base_chunk + 1)
328
+ else:
329
+ my_count = base_chunk
330
+ my_start = remainder * (base_chunk + 1) + (rank_in_sp - remainder) * base_chunk
331
+ else:
332
+ if rank_in_sp < S:
333
+ my_count = 1
334
+ my_start = rank_in_sp
335
+ else:
336
+ my_count = 0
337
+ my_start = S
338
+
339
+ my_end = my_start + my_count
340
+ has_frames = my_count > 0
341
+
342
+ if has_frames:
343
+ token_list_chunk = [t[:, my_start:my_end].contiguous() for t in token_list]
344
+ imgs_chunk = imgs[:, my_start:my_end].contiguous()
345
+
346
+ # Camera head: runs on ALL frames on every rank (cross-view attention)
347
+ if self.enable_cam:
348
+ cam_seq = self.cam_head(token_list)
349
+ cam_params = cam_seq[-1]
350
+ preds["camera_params"] = cam_params
351
+ c2w_mat, int_mat = self.transform_camera_vector(cam_params, H, W)
352
+ preds["camera_poses"] = c2w_mat
353
+ preds["camera_intrs"] = int_mat
354
+
355
+ # DPT heads: frame-parallel
356
+ if self.enable_depth:
357
+ if has_frames:
358
+ if self.enable_depth_mask:
359
+ depth_chunk, depth_conf_chunk, depth_mask_logits_chunk = self.depth_head(
360
+ token_list_chunk, images=imgs_chunk, patch_start_idx=patch_start_idx,
361
+ )
362
+ else:
363
+ depth_chunk, depth_conf_chunk = self.depth_head(
364
+ token_list_chunk, images=imgs_chunk, patch_start_idx=patch_start_idx,
365
+ )
366
+ else:
367
+ depth_chunk = torch.zeros(B, 0, H, W, 1, dtype=imgs.dtype, device=imgs.device)
368
+ depth_conf_chunk = torch.zeros(B, 0, H, W, dtype=imgs.dtype, device=imgs.device)
369
+ if self.enable_depth_mask:
370
+ depth_mask_logits_chunk = torch.zeros(B, 0, H, W, dtype=imgs.dtype, device=imgs.device)
371
+
372
+ preds["depth"] = self._frame_allgather_variable(depth_chunk, my_count, S, sp_size, sp_group, dim=1)
373
+ preds["depth_conf"] = self._frame_allgather_variable(depth_conf_chunk, my_count, S, sp_size, sp_group, dim=1)
374
+ if self.enable_depth_mask:
375
+ depth_mask_logits_full = self._frame_allgather_variable(
376
+ depth_mask_logits_chunk, my_count, S, sp_size, sp_group, dim=1,
377
+ )
378
+ preds["depth_mask_logits"] = depth_mask_logits_full
379
+ preds["depth_mask"] = depth_mask_logits_full.sigmoid()
380
+
381
+ if self.enable_pts:
382
+ if has_frames:
383
+ pts_chunk, pts_conf_chunk = self.pts_head(
384
+ token_list_chunk, images=imgs_chunk, patch_start_idx=patch_start_idx,
385
+ )
386
+ else:
387
+ pts_chunk = torch.zeros(B, 0, H, W, 3, dtype=imgs.dtype, device=imgs.device)
388
+ pts_conf_chunk = torch.zeros(B, 0, H, W, dtype=imgs.dtype, device=imgs.device)
389
+
390
+ preds["pts3d"] = self._frame_allgather_variable(pts_chunk, my_count, S, sp_size, sp_group, dim=1)
391
+ preds["pts3d_conf"] = self._frame_allgather_variable(pts_conf_chunk, my_count, S, sp_size, sp_group, dim=1)
392
+
393
+ if self.enable_norm:
394
+ if has_frames:
395
+ normals_chunk, norm_conf_chunk = self.norm_head(
396
+ token_list_chunk, images=imgs_chunk, patch_start_idx=patch_start_idx,
397
+ )
398
+ else:
399
+ normals_chunk = torch.zeros(B, 0, H, W, 3, dtype=imgs.dtype, device=imgs.device)
400
+ norm_conf_chunk = torch.zeros(B, 0, H, W, dtype=imgs.dtype, device=imgs.device)
401
+
402
+ preds["normals"] = self._frame_allgather_variable(normals_chunk, my_count, S, sp_size, sp_group, dim=1)
403
+ preds["normals_conf"] = self._frame_allgather_variable(norm_conf_chunk, my_count, S, sp_size, sp_group, dim=1)
404
+
405
+ # GS head: frame-parallel, then render on full gathered data
406
+ if self.enable_gs:
407
+ context_preds, context_nums = self.prepare_contexts(views, cond_flags, is_inference)
408
+ gs_token_list = context_preds.get("token_list", token_list)
409
+ gs_imgs = context_preds.get("imgs", imgs)
410
+ gs_S = gs_imgs.shape[1]
411
+
412
+ if gs_S == S and has_frames:
413
+ gs_token_chunk = [t[:, my_start:my_end].contiguous() for t in gs_token_list]
414
+ gs_imgs_chunk = gs_imgs[:, my_start:my_end].contiguous()
415
+
416
+ if self.enable_depth_mask:
417
+ gs_feat_chunk, gs_depth_chunk, gs_depth_conf_chunk, gs_dmask_chunk = self.gs_head(
418
+ gs_token_chunk, images=gs_imgs_chunk, patch_start_idx=patch_start_idx,
419
+ )
420
+ else:
421
+ gs_feat_chunk, gs_depth_chunk, gs_depth_conf_chunk = self.gs_head(
422
+ gs_token_chunk, images=gs_imgs_chunk, patch_start_idx=patch_start_idx,
423
+ )
424
+
425
+ gs_feat = self._frame_allgather_variable(gs_feat_chunk, my_count, gs_S, sp_size, sp_group, dim=1)
426
+ gs_depth = self._frame_allgather_variable(gs_depth_chunk, my_count, gs_S, sp_size, sp_group, dim=1)
427
+ gs_depth_conf = self._frame_allgather_variable(gs_depth_conf_chunk, my_count, gs_S, sp_size, sp_group, dim=1)
428
+ if self.enable_depth_mask:
429
+ gs_depth_mask_logits = self._frame_allgather_variable(
430
+ gs_dmask_chunk, my_count, gs_S, sp_size, sp_group, dim=1,
431
+ )
432
+ preds["gs_depth_mask_logits"] = gs_depth_mask_logits
433
+ preds["gs_depth_mask"] = gs_depth_mask_logits.sigmoid()
434
+
435
+ elif gs_S == S and not has_frames:
436
+ gs_feat_c = self.gs_dim // 2
437
+ gs_feat_chunk = torch.zeros(B, 0, gs_feat_c, H, W, dtype=imgs.dtype, device=imgs.device)
438
+ gs_depth_chunk = torch.zeros(B, 0, H, W, 1, dtype=imgs.dtype, device=imgs.device)
439
+ gs_depth_conf_chunk = torch.zeros(B, 0, H, W, dtype=imgs.dtype, device=imgs.device)
440
+
441
+ gs_feat = self._frame_allgather_variable(gs_feat_chunk, 0, gs_S, sp_size, sp_group, dim=1)
442
+ gs_depth = self._frame_allgather_variable(gs_depth_chunk, 0, gs_S, sp_size, sp_group, dim=1)
443
+ gs_depth_conf = self._frame_allgather_variable(gs_depth_conf_chunk, 0, gs_S, sp_size, sp_group, dim=1)
444
+ if self.enable_depth_mask:
445
+ gs_dmask_chunk = torch.zeros(B, 0, H, W, dtype=imgs.dtype, device=imgs.device)
446
+ gs_depth_mask_logits = self._frame_allgather_variable(
447
+ gs_dmask_chunk, 0, gs_S, sp_size, sp_group, dim=1,
448
+ )
449
+ preds["gs_depth_mask_logits"] = gs_depth_mask_logits
450
+ preds["gs_depth_mask"] = gs_depth_mask_logits.sigmoid()
451
+ else:
452
+ if self.enable_depth_mask:
453
+ gs_feat, gs_depth, gs_depth_conf, gs_depth_mask_logits = self.gs_head(
454
+ gs_token_list, images=gs_imgs, patch_start_idx=patch_start_idx,
455
+ )
456
+ preds["gs_depth_mask_logits"] = gs_depth_mask_logits
457
+ preds["gs_depth_mask"] = gs_depth_mask_logits.sigmoid()
458
+ else:
459
+ gs_feat, gs_depth, gs_depth_conf = self.gs_head(
460
+ gs_token_list, images=gs_imgs, patch_start_idx=patch_start_idx,
461
+ )
462
+
463
+ preds["gs_depth"] = gs_depth
464
+ preds["gs_depth_conf"] = gs_depth_conf
465
+
466
+ preds = self.gs_renderer.render(
467
+ gs_feats=gs_feat,
468
+ images=imgs,
469
+ predictions=preds,
470
+ views=views,
471
+ context_predictions=context_preds,
472
+ is_inference=is_inference,
473
+ )
474
+
475
+ return preds
476
+
477
+ def _frame_allgather_variable(self, chunk, my_count, total_S, sp_size, sp_group, dim=1):
478
+ """Allgather tensors with potentially variable chunk sizes across ranks.
479
+
480
+ Pads each chunk to max_chunk_size, allgathers, then extracts valid frames
481
+ from each rank's chunk to reconstruct the correct frame order.
482
+ """
483
+ if sp_size <= 1:
484
+ return chunk
485
+
486
+ if total_S >= sp_size:
487
+ base_chunk = total_S // sp_size
488
+ remainder = total_S % sp_size
489
+ counts = [(base_chunk + 1) if r < remainder else base_chunk
490
+ for r in range(sp_size)]
491
+ else:
492
+ counts = [1 if r < total_S else 0 for r in range(sp_size)]
493
+
494
+ max_chunk = max(counts)
495
+
496
+ current_size = chunk.shape[dim]
497
+ if current_size < max_chunk:
498
+ pad_size = max_chunk - current_size
499
+ pad_shape = list(chunk.shape)
500
+ pad_shape[dim] = pad_size
501
+ padding = torch.zeros(pad_shape, dtype=chunk.dtype, device=chunk.device)
502
+ chunk = torch.cat([chunk, padding], dim=dim)
503
+
504
+ chunk = chunk.contiguous()
505
+ gathered_list = [torch.zeros_like(chunk) for _ in range(sp_size)]
506
+ dist.all_gather(gathered_list, chunk, group=sp_group)
507
+
508
+ valid_chunks = []
509
+ for r in range(sp_size):
510
+ cnt = counts[r]
511
+ if cnt > 0:
512
+ slices = [slice(None)] * gathered_list[r].dim()
513
+ slices[dim] = slice(0, cnt)
514
+ valid_chunks.append(gathered_list[r][tuple(slices)])
515
+
516
+ return torch.cat(valid_chunks, dim=dim).contiguous()
517
+
518
+ def _gen_all_preds(
519
+ self, token_list, imgs, patch_start_idx, views, cond_flags, is_inference
520
+ ):
521
+ """Generate all enabled predictions (single-GPU path)."""
522
+ preds = {}
523
+
524
+ if self.enable_cam:
525
+ cam_seq = self.cam_head(token_list)
526
+ cam_params = cam_seq[-1]
527
+ preds["camera_params"] = cam_params
528
+ c2w_mat, int_mat = self.transform_camera_vector(
529
+ cam_params, imgs.shape[-2], imgs.shape[-1]
530
+ )
531
+ preds["camera_poses"] = c2w_mat
532
+ preds["camera_intrs"] = int_mat
533
+
534
+ if self.enable_depth:
535
+ if self.enable_depth_mask:
536
+ depth, depth_conf, depth_mask_logits = self.depth_head(
537
+ token_list, images=imgs, patch_start_idx=patch_start_idx,
538
+ )
539
+ preds["depth_mask_logits"] = depth_mask_logits
540
+ preds["depth_mask"] = depth_mask_logits.sigmoid()
541
+ else:
542
+ depth, depth_conf = self.depth_head(
543
+ token_list, images=imgs, patch_start_idx=patch_start_idx,
544
+ )
545
+ preds["depth"] = depth
546
+ preds["depth_conf"] = depth_conf
547
+
548
+ if self.enable_pts:
549
+ pts, pts_conf = self.pts_head(
550
+ token_list, images=imgs, patch_start_idx=patch_start_idx,
551
+ )
552
+ preds["pts3d"] = pts
553
+ preds["pts3d_conf"] = pts_conf
554
+
555
+ if self.enable_norm:
556
+ normals, norm_conf = self.norm_head(
557
+ token_list, images=imgs, patch_start_idx=patch_start_idx,
558
+ )
559
+ preds["normals"] = normals
560
+ preds["normals_conf"] = norm_conf
561
+
562
+ if self.enable_gs:
563
+ context_preds, context_nums = self.prepare_contexts(views, cond_flags, is_inference)
564
+ if self.enable_depth_mask:
565
+ gs_feat, gs_depth, gs_depth_conf, gs_depth_mask_logits = self.gs_head(
566
+ context_preds.get("token_list", token_list),
567
+ images=context_preds.get("imgs", imgs),
568
+ patch_start_idx=patch_start_idx,
569
+ )
570
+ preds["gs_depth_mask_logits"] = gs_depth_mask_logits
571
+ preds["gs_depth_mask"] = gs_depth_mask_logits.sigmoid()
572
+ else:
573
+ gs_feat, gs_depth, gs_depth_conf = self.gs_head(
574
+ context_preds.get("token_list", token_list),
575
+ images=context_preds.get("imgs", imgs),
576
+ patch_start_idx=patch_start_idx,
577
+ )
578
+
579
+ preds["gs_depth"] = gs_depth
580
+ preds["gs_depth_conf"] = gs_depth_conf
581
+
582
+ preds = self.gs_renderer.render(
583
+ gs_feats=gs_feat,
584
+ images=imgs,
585
+ predictions=preds,
586
+ views=views,
587
+ context_predictions=context_preds,
588
+ is_inference=is_inference,
589
+ )
590
+
591
+ return preds
592
+
593
+ def extract_priors(self, views):
594
+ """Extract and normalize geometric priors from input views.
595
+
596
+ Returns (depths, rays, poses) tuple — each may be None if unavailable.
597
+ """
598
+ h, w = views["img"].shape[-2:]
599
+ depths = rays = poses = None
600
+
601
+ if "camera_poses" in views:
602
+ extrinsics = views["camera_poses"][:, :, :3]
603
+ extrinsics = normalize_poses(extrinsics)
604
+ cam_params = extrinsics_to_vector(extrinsics)
605
+ poses = cam_params[:, :, :7]
606
+ if self.enable_bf16:
607
+ poses = poses.to(torch.bfloat16)
608
+
609
+ if "depthmap" in views:
610
+ depth_h, depth_w = views["depthmap"].shape[-2:]
611
+ depths = views["depthmap"]
612
+ if depth_h != h or depth_w != w:
613
+ depths = F.interpolate(depths, size=(h, w), mode="bilinear", align_corners=False)
614
+ depths = normalize_depth(depths)
615
+ if self.enable_bf16:
616
+ depths = depths.to(torch.bfloat16)
617
+
618
+ if "camera_intrs" in views:
619
+ intrinsics = views["camera_intrs"][:, :, :3, :3]
620
+ fx, fy = intrinsics[:, :, 0, 0] / w, intrinsics[:, :, 1, 1] / h
621
+ cx, cy = intrinsics[:, :, 0, 2] / w, intrinsics[:, :, 1, 2] / h
622
+ rays = torch.stack([fx, fy, cx, cy], dim=-1)
623
+ if self.enable_bf16:
624
+ rays = rays.to(torch.bfloat16)
625
+
626
+ return (depths, rays, poses)
627
+
628
+ def transform_camera_vector(self, camera_params, h, w):
629
+ """Convert camera parameter vector to c2w and intrinsic matrices."""
630
+ ext_mat, int_mat = vector_to_camera_matrices(camera_params, image_hw=(h, w))
631
+ homo_row = torch.tensor([0, 0, 0, 1], device=ext_mat.device).view(1, 1, 1, 4)
632
+ homo_row = homo_row.repeat(ext_mat.shape[0], ext_mat.shape[1], 1, 1)
633
+ w2c_mat = torch.cat([ext_mat, homo_row], dim=2)
634
+ try:
635
+ c2w_mat = torch.linalg.inv(w2c_mat)
636
+ except Exception as e:
637
+ print(f"[WorldMirror] linalg.inv fallback to CPU: {e}")
638
+ c2w_mat = torch.linalg.inv(w2c_mat.cpu()).to(camera_params.device)
639
+ return c2w_mat, int_mat
640
+
641
+ def prepare_contexts(self, views, cond_flags, is_inference):
642
+ """Prepare context views for GS rendering (training only, passthrough in inference)."""
643
+ context_preds = {}
644
+ if is_inference:
645
+ return context_preds, views["img"].shape[1]
646
+
647
+ assert self.enable_cam and self.enable_gs
648
+ if "is_target" not in views:
649
+ context_nums = views["img"].shape[1]
650
+ else:
651
+ context_nums = (views["is_target"][0] == False).sum().item()
652
+ context_imgs = views["img"][:, :context_nums]
653
+
654
+ use_cond = sum(cond_flags) > 0
655
+
656
+ if self.enable_bf16:
657
+ context_imgs = context_imgs.to(torch.bfloat16)
658
+
659
+ with torch.amp.autocast('cuda', enabled=(not self.enable_bf16), dtype=torch.bfloat16):
660
+ if use_cond:
661
+ priors = self.extract_priors(views)
662
+ context_priors = (
663
+ prior[:, :context_nums] if prior is not None else None
664
+ for prior in priors
665
+ )
666
+ context_token_list, _ = self.visual_geometry_transformer(
667
+ context_imgs, context_priors, cond_flags=cond_flags,
668
+ enable_bf16=self.enable_bf16,
669
+ )
670
+ else:
671
+ context_token_list, _ = self.visual_geometry_transformer(
672
+ context_imgs, enable_bf16=self.enable_bf16,
673
+ )
674
+
675
+ context_cam_seq = self.cam_head(context_token_list)
676
+ context_cam_params = context_cam_seq[-1]
677
+ context_c2w_mat, context_int_mat = self.transform_camera_vector(
678
+ context_cam_params, context_imgs.shape[-2], context_imgs.shape[-1]
679
+ )
680
+ context_preds["camera_poses"] = context_c2w_mat
681
+ context_preds["camera_intrs"] = context_int_mat
682
+ context_preds["token_list"] = context_token_list
683
+ context_preds["imgs"] = context_imgs
684
+
685
+ return context_preds, context_nums
hyworldmirror/models/utils/__init__.py ADDED
File without changes
hyworldmirror/models/utils/act_gs.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+
4
+
5
+ def reg_dense_offsets(xyz, shift=6.0):
6
+ d = xyz.norm(dim=-1, keepdim=True)
7
+ return xyz / d.clamp(min=1e-8) * (torch.exp(d - shift) - torch.exp(-shift))
8
+
9
+ def reg_dense_scales(scales):
10
+ return scales.exp()
11
+
12
+ def reg_dense_rotation(rotations, eps=1e-8):
13
+ return rotations / (rotations.norm(dim=-1, keepdim=True) + eps)
14
+
15
+ def reg_dense_sh(sh):
16
+ return rearrange(sh, '... (d_sh xyz) -> ... d_sh xyz', xyz=3)
17
+
18
+ def reg_dense_opacities(opacities):
19
+ return opacities.sigmoid()
20
+
21
+ def reg_dense_weights(weights):
22
+ return weights.sigmoid()
hyworldmirror/models/utils/camera_utils.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .rotation import quat_to_rotmat, rotmat_to_quat
3
+
4
+
5
+ def camera_params_to_vector(
6
+ ext, intr, image_hw=None
7
+ ):
8
+ """Convert camera matrices to a compact vector."""
9
+ # ext: (..., 3, 4): Camera-to-world extrinsic [R|t]
10
+ # intr: (..., 3, 3): Intrinsics
11
+ # image_hw: (h, w)
12
+ R = ext[..., :3, :3] # Rotation part
13
+ t = ext[..., :3, 3] # Translation part
14
+ q = rotmat_to_quat(R) # Quaternion (wxyz)
15
+ h, w = image_hw
16
+ fov_v = 2.0 * torch.atan(h * 0.5 / intr[..., 1, 1]) # Vertical FOV
17
+ fov_u = 2.0 * torch.atan(w * 0.5 / intr[..., 0, 0]) # Horizontal FOV
18
+ vec = torch.stack([
19
+ t[..., 0], t[..., 1], t[..., 2],
20
+ q[..., 0], q[..., 1], q[..., 2], q[..., 3],
21
+ fov_v, fov_u
22
+ ], dim=-1).float()
23
+ return vec
24
+
25
+ def extrinsics_to_vector(ext):
26
+ """Convert extrinsics to [t, q] vector."""
27
+ # ext: (..., 3, 4)
28
+ R = ext[..., :3, :3]
29
+ t = ext[..., :3, 3]
30
+ q = rotmat_to_quat(R)
31
+ vec = torch.stack([
32
+ t[..., 0], t[..., 1], t[..., 2],
33
+ q[..., 0], q[..., 1], q[..., 2], q[..., 3]
34
+ ], dim=-1).float()
35
+ return vec
36
+
37
+ def vector_to_extrinsics(cam_vec):
38
+ """Convert [t, q] vector to extrinsic matrix."""
39
+ # cam_vec: (..., 7)
40
+ q = cam_vec[..., 3:7]
41
+ t = cam_vec[..., :3]
42
+ R = quat_to_rotmat(q)
43
+ ext = torch.cat([R, t.unsqueeze(-1)], dim=-1)
44
+ return ext
45
+
46
+ def vector_to_camera_matrices(
47
+ cam_vec, image_hw=None, build_intr=True
48
+ ):
49
+ """Reconstruct extrinsic and intrinsic matrix from vector."""
50
+ # cam_vec: (..., 9)
51
+ intr = None
52
+ # Decompose vector
53
+ t = cam_vec[..., 0:3]
54
+ q = cam_vec[..., 3:7]
55
+ fov_v = cam_vec[..., 7]
56
+ fov_u = cam_vec[..., 8]
57
+
58
+ # Build extrinsic: [R|t]
59
+ R = quat_to_rotmat(q)
60
+ ext = torch.cat([R, t.unsqueeze(-1)], dim=-1)
61
+
62
+ # Build intrinsic if needed
63
+ if build_intr:
64
+ h, w = image_hw
65
+ fy = h * 0.5 / torch.tan(fov_v * 0.5)
66
+ fx = w * 0.5 / torch.tan(fov_u * 0.5)
67
+ shape = cam_vec.shape[:-1] + (3, 3)
68
+ intr = torch.zeros(shape, device=cam_vec.device, dtype=cam_vec.dtype)
69
+ intr[..., 0, 0] = fx
70
+ intr[..., 1, 1] = fy
71
+ intr[..., 0, 2] = w * 0.5
72
+ intr[..., 1, 2] = h * 0.5
73
+ intr[..., 2, 2] = 1.0
74
+
75
+ return ext, intr
hyworldmirror/models/utils/frustum.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import einops
2
+ import torch
3
+
4
+
5
+ # Calculate the loss mask for the target views in the batch
6
+ @torch.no_grad()
7
+ def calculate_unprojected_mask(views, context_nums):
8
+ '''Calcuate the loss mask for the target views in the batch'''
9
+ target_depth = views["depthmap"][:, context_nums:]
10
+ target_intrinsics = views["camera_intrs"][:, context_nums:]
11
+ target_c2w = views["camera_poses"][:, context_nums:]
12
+ context_depth = views["depthmap"][:, :context_nums]
13
+ context_intrinsics = views["camera_intrs"][:, :context_nums]
14
+ context_c2w = views["camera_poses"][:, :context_nums]
15
+
16
+ target_intrinsics = target_intrinsics[..., :3, :3]
17
+ context_intrinsics = context_intrinsics[..., :3, :3]
18
+
19
+ mask = calculate_in_frustum_mask(
20
+ target_depth, target_intrinsics, target_c2w,
21
+ context_depth, context_intrinsics, context_c2w
22
+ )
23
+ return mask
24
+
25
+ @torch.no_grad()
26
+ def calculate_in_frustum_mask(depth_1, intrinsics_1, c2w_1, depth_2, intrinsics_2, c2w_2):
27
+ """
28
+ A function that takes in the depth, intrinsics and c2w matrices of two sets
29
+ of views, and then works out which of the pixels in the first set of views
30
+ has a direct corresponding pixel in any of views in the second set
31
+
32
+ Args:
33
+ depth_1: (b, v1, h, w)
34
+ intrinsics_1: (b, v1, 3, 3)
35
+ c2w_1: (b, v1, 4, 4)
36
+ depth_2: (b, v2, h, w)
37
+ intrinsics_2: (b, v2, 3, 3)
38
+ c2w_2: (b, v2, 4, 4)
39
+
40
+ Returns:
41
+ torch.Tensor: valid mask with shape (b, v1, v2, h, w).
42
+ """
43
+
44
+ _, v1, h, w = depth_1.shape
45
+ _, v2, _, _ = depth_2.shape
46
+
47
+ # Unproject the depth to get the 3D points in world space
48
+ points_3d = unproject_depth(depth_1[..., None], intrinsics_1, c2w_1) # (b, v1, h, w, 3)
49
+
50
+ # Project the 3D points into the pixel space of all the second views simultaneously
51
+ camera_points = world_space_to_camera_space(points_3d, c2w_2) # (b, v1, v2, h, w, 3)
52
+ points_2d = camera_space_to_pixel_space(camera_points, intrinsics_2) # (b, v1, v2, h, w, 2)
53
+
54
+ # Calculate the depth of each point
55
+ rendered_depth = camera_points[..., 2] # (b, v1, v2, h, w)
56
+
57
+ # We use three conditions to determine if a point should be masked
58
+
59
+ # Condition 1: Check if the points are in the frustum of any of the v2 views
60
+ in_frustum_mask = (
61
+ (points_2d[..., 0] > 0) &
62
+ (points_2d[..., 0] < w) &
63
+ (points_2d[..., 1] > 0) &
64
+ (points_2d[..., 1] < h)
65
+ ) # (b, v1, v2, h, w)
66
+ in_frustum_mask = in_frustum_mask.any(dim=-3) # (b, v1, h, w)
67
+
68
+ # Condition 2: Check if the points have non-zero (i.e. valid) depth in the input view
69
+ non_zero_depth = depth_1 > 1e-6
70
+
71
+ # Condition 3: Check if the points have matching depth to any of the v2
72
+ # views torch.nn.functional.grid_sample expects the input coordinates to
73
+ # be normalized to the range [-1, 1], so we normalize first
74
+ points_2d[..., 0] /= w
75
+ points_2d[..., 1] /= h
76
+ points_2d = points_2d * 2 - 1
77
+ matching_depth = torch.ones_like(rendered_depth, dtype=torch.bool)
78
+ for b in range(depth_1.shape[0]):
79
+ for i in range(v1):
80
+ for j in range(v2):
81
+ depth = einops.rearrange(depth_2[b, j], 'h w -> 1 1 h w')
82
+ coords = einops.rearrange(points_2d[b, i, j], 'h w c -> 1 h w c')
83
+ sampled_depths = torch.nn.functional.grid_sample(depth, coords, align_corners=False)[0, 0]
84
+ matching_depth[b, i, j] = torch.isclose(rendered_depth[b, i, j], sampled_depths, atol=1e-1)
85
+
86
+ matching_depth = matching_depth.any(dim=-3) # (..., v1, h, w)
87
+
88
+ mask = in_frustum_mask & non_zero_depth & matching_depth
89
+ return mask
90
+
91
+ # --- Projections ---
92
+ def homogenize_points(points):
93
+ """Append a '1' along the final dimension of the tensor (i.e. convert xyz->xyz1)"""
94
+ return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
95
+
96
+
97
+ def normalize_homogenous_points(points):
98
+ """Normalize the point vectors"""
99
+ return points / points[..., -1:]
100
+
101
+
102
+ def pixel_space_to_camera_space(pixel_space_points, depth, intrinsics):
103
+ """
104
+ Convert pixel space points to camera space points.
105
+
106
+ Args:
107
+ pixel_space_points (torch.Tensor): Pixel space points with shape (h, w, 2)
108
+ depth (torch.Tensor): Depth map with shape (b, v, h, w, 1)
109
+ intrinsics (torch.Tensor): Camera intrinsics with shape (b, v, 3, 3)
110
+
111
+ Returns:
112
+ torch.Tensor: Camera space points with shape (b, v, h, w, 3).
113
+ """
114
+ pixel_space_points = homogenize_points(pixel_space_points)
115
+ camera_space_points = torch.einsum('b v i j , h w j -> b v h w i', intrinsics.inverse(), pixel_space_points)
116
+ camera_space_points = camera_space_points * depth
117
+ return camera_space_points
118
+
119
+
120
+ def camera_space_to_world_space(camera_space_points, c2w):
121
+ """
122
+ Convert camera space points to world space points.
123
+
124
+ Args:
125
+ camera_space_points (torch.Tensor): Camera space points with shape (b, v, h, w, 3)
126
+ c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v, 4, 4)
127
+
128
+ Returns:
129
+ torch.Tensor: World space points with shape (b, v, h, w, 3).
130
+ """
131
+ camera_space_points = homogenize_points(camera_space_points)
132
+ world_space_points = torch.einsum('b v i j , b v h w j -> b v h w i', c2w, camera_space_points)
133
+ return world_space_points[..., :3]
134
+
135
+
136
+ def camera_space_to_pixel_space(camera_space_points, intrinsics):
137
+ """
138
+ Convert camera space points to pixel space points.
139
+
140
+ Args:
141
+ camera_space_points (torch.Tensor): Camera space points with shape (b, v1, v2, h, w, 3)
142
+ c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v2, 3, 3)
143
+
144
+ Returns:
145
+ torch.Tensor: World space points with shape (b, v1, v2, h, w, 2).
146
+ """
147
+ camera_space_points = normalize_homogenous_points(camera_space_points)
148
+ pixel_space_points = torch.einsum('b u i j , b v u h w j -> b v u h w i', intrinsics, camera_space_points)
149
+ return pixel_space_points[..., :2]
150
+
151
+
152
+ def world_space_to_camera_space(world_space_points, c2w):
153
+ """
154
+ Convert world space points to pixel space points.
155
+
156
+ Args:
157
+ world_space_points (torch.Tensor): World space points with shape (b, v1, h, w, 3)
158
+ c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v2, 4, 4)
159
+
160
+ Returns:
161
+ torch.Tensor: Camera space points with shape (b, v1, v2, h, w, 3).
162
+ """
163
+ world_space_points = homogenize_points(world_space_points)
164
+ camera_space_points = torch.einsum('b u i j , b v h w j -> b v u h w i', c2w.inverse(), world_space_points)
165
+ return camera_space_points[..., :3]
166
+
167
+
168
+ def unproject_depth(depth, intrinsics, c2w):
169
+ """
170
+ Turn the depth map into a 3D point cloud in world space
171
+
172
+ Args:
173
+ depth: (b, v, h, w, 1)
174
+ intrinsics: (b, v, 3, 3)
175
+ c2w: (b, v, 4, 4)
176
+
177
+ Returns:
178
+ torch.Tensor: World space points with shape (b, v, h, w, 3).
179
+ """
180
+
181
+ # Compute indices of pixels
182
+ h, w = depth.shape[-3], depth.shape[-2]
183
+ x_grid, y_grid = torch.meshgrid(
184
+ torch.arange(w, device=depth.device, dtype=torch.float32),
185
+ torch.arange(h, device=depth.device, dtype=torch.float32),
186
+ indexing='xy'
187
+ ) # (h, w), (h, w)
188
+
189
+ # Compute coordinates of pixels in camera space
190
+ pixel_space_points = torch.stack((x_grid, y_grid), dim=-1) # (..., h, w, 2)
191
+ camera_points = pixel_space_to_camera_space(pixel_space_points, depth, intrinsics) # (..., h, w, 3)
192
+
193
+ # Convert points to world space
194
+ world_points = camera_space_to_world_space(camera_points, c2w) # (..., h, w, 3)
195
+
196
+ return world_points
hyworldmirror/models/utils/geometry.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from typing import Tuple
4
+
5
+ def depth_to_camera_coords(depthmap, camera_intrinsics):
6
+ """
7
+ Convert depth map to 3D camera coordinates.
8
+
9
+ Args:
10
+ depthmap (BxHxW tensor): Batch of depth maps
11
+ camera_intrinsics (Bx3x3 tensor): Camera intrinsics matrix for each camera
12
+
13
+ Returns:
14
+ X_cam (BxHxWx3 tensor): 3D points in camera coordinates
15
+ valid_mask (BxHxW tensor): Mask indicating valid depth pixels
16
+ """
17
+ B, H, W = depthmap.shape
18
+ device = depthmap.device
19
+ dtype = depthmap.dtype
20
+
21
+ # Ensure intrinsics are float
22
+ camera_intrinsics = camera_intrinsics.float()
23
+
24
+ # Extract focal lengths and principal points
25
+ fx = camera_intrinsics[:, 0, 0] # (B,)
26
+ fy = camera_intrinsics[:, 1, 1] # (B,)
27
+ cx = camera_intrinsics[:, 0, 2] # (B,)
28
+ cy = camera_intrinsics[:, 1, 2] # (B,)
29
+
30
+ # Generate pixel grid
31
+ v_grid, u_grid = torch.meshgrid(
32
+ torch.arange(H, dtype=dtype, device=device),
33
+ torch.arange(W, dtype=dtype, device=device),
34
+ indexing='ij'
35
+ )
36
+
37
+ # Reshape for broadcasting: (1, H, W)
38
+ u_grid = u_grid.unsqueeze(0)
39
+ v_grid = v_grid.unsqueeze(0)
40
+
41
+ # Compute 3D camera coordinates
42
+ # X = (u - cx) * Z / fx
43
+ # Y = (v - cy) * Z / fy
44
+ # Z = depth
45
+ z_cam = depthmap # (B, H, W)
46
+ x_cam = (u_grid - cx.view(B, 1, 1)) * z_cam / fx.view(B, 1, 1)
47
+ y_cam = (v_grid - cy.view(B, 1, 1)) * z_cam / fy.view(B, 1, 1)
48
+
49
+ # Stack to form (B, H, W, 3)
50
+ X_cam = torch.stack([x_cam, y_cam, z_cam], dim=-1)
51
+
52
+ # Valid depth mask
53
+ valid_mask = depthmap > 0.0
54
+
55
+ return X_cam, valid_mask
56
+
57
+ def depth_to_world_coords_points(
58
+ depth_map: torch.Tensor, extrinsic: torch.Tensor, intrinsic: torch.Tensor, eps=1e-8
59
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
60
+ """
61
+ Convert a batch of depth maps to world coordinates.
62
+
63
+ Args:
64
+ depth_map (torch.Tensor): (B, H, W) Depth map
65
+ extrinsic (torch.Tensor): (B, 4, 4) Camera extrinsic matrix (camera-to-world transformation)
66
+ intrinsic (torch.Tensor): (B, 3, 3) Camera intrinsic matrix
67
+
68
+ Returns:
69
+ world_coords_points (torch.Tensor): (B, H, W, 3) World coordinates
70
+ camera_points (torch.Tensor): (B, H, W, 3) Camera coordinates
71
+ point_mask (torch.Tensor): (B, H, W) Valid depth mask
72
+ """
73
+ if depth_map is None:
74
+ return None, None, None
75
+
76
+ # Valid depth mask (B, H, W)
77
+ point_mask = depth_map > eps
78
+
79
+ # Convert depth map to camera coordinates (B, H, W, 3)
80
+ camera_points, _ = depth_to_camera_coords(depth_map, intrinsic)
81
+
82
+ # Apply extrinsic matrix (camera -> world)
83
+ R_cam_to_world = extrinsic[:, :3, :3] # (B, 3, 3)
84
+ t_cam_to_world = extrinsic[:, :3, 3] # (B, 3)
85
+
86
+ # Transform (B, H, W, 3) x (B, 3, 3)^T + (B, 3) -> (B, H, W, 3)
87
+ world_coords_points = torch.einsum('bhwi,bji->bhwj', camera_points, R_cam_to_world) + t_cam_to_world[:, None, None, :]
88
+
89
+ return world_coords_points, camera_points, point_mask
90
+
91
+
92
+ def closed_form_inverse_se3(se3: torch.Tensor) -> torch.Tensor:
93
+ """
94
+ Efficiently invert batched SE(3) matrices of shape (B, 4, 4).
95
+
96
+ Args:
97
+ se3 (torch.Tensor): (B, 4, 4) Transformation matrices
98
+
99
+ Returns:
100
+ out (torch.Tensor): (B, 4, 4) Inverse transformation matrices
101
+ """
102
+ assert se3.ndim == 3 and se3.shape[1:] == (4, 4), f"se3 must be (B, 4, 4), got {se3.shape}"
103
+ R = se3[:, :3, :3] # (B, 3, 3)
104
+ t = se3[:, :3, 3] # (B, 3)
105
+ Rt = R.transpose(1, 2) # (B, 3, 3)
106
+ t_inv = -torch.bmm(Rt, t.unsqueeze(-1)).squeeze(-1) # (B, 3)
107
+ out = se3.new_zeros(se3.shape)
108
+ out[:, :3, :3] = Rt
109
+ out[:, :3, 3] = t_inv
110
+ out[:, 3, 3] = 1.0
111
+ return out
hyworldmirror/models/utils/grid.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
5
+ """
6
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
7
+
8
+ Args:
9
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
10
+ embed_dim: Output channel dimension for embeddings
11
+ omega_0: Base frequency for sinusoidal encoding
12
+
13
+ Returns:
14
+ Tensor of shape (H, W, embed_dim) with positional embeddings
15
+ """
16
+ H, W, grid_dim = pos_grid.shape
17
+ assert grid_dim == 2
18
+ assert embed_dim % 2 == 0
19
+
20
+ device = pos_grid.device
21
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
22
+
23
+ # Generate frequency bands
24
+ omega = torch.arange(embed_dim // 4, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
25
+ omega /= embed_dim / 4.0
26
+ omega = 1.0 / omega_0**omega # (D/4,)
27
+
28
+ # Process x and y coordinates separately
29
+ pos_x = pos_flat[:, 0].reshape(-1) # (H*W,)
30
+ pos_y = pos_flat[:, 1].reshape(-1) # (H*W,)
31
+
32
+ # Compute outer products
33
+ out_x = torch.einsum("m,d->md", pos_x, omega) # (H*W, D/4)
34
+ out_y = torch.einsum("m,d->md", pos_y, omega) # (H*W, D/4)
35
+
36
+ # Apply sin and cos
37
+ emb_x = torch.cat([torch.sin(out_x), torch.cos(out_x)], dim=1) # (H*W, D/2)
38
+ emb_y = torch.cat([torch.sin(out_y), torch.cos(out_y)], dim=1) # (H*W, D/2)
39
+
40
+ # Combine x and y embeddings
41
+ emb = torch.cat([emb_x, emb_y], dim=-1) # (H*W, D)
42
+
43
+ return emb.float().view(H, W, embed_dim) # [H, W, D]
44
+
45
+
46
+ # Inspired by https://github.com/microsoft/moge
47
+ def create_uv_grid(
48
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
49
+ ) -> torch.Tensor:
50
+ """
51
+ Create a normalized UV grid of shape (width, height, 2).
52
+
53
+ The grid spans horizontally and vertically according to an aspect ratio,
54
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
55
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
56
+
57
+ Args:
58
+ width (int): Number of points horizontally.
59
+ height (int): Number of points vertically.
60
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
61
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
62
+ device (torch.device, optional): Device on which the tensor is created.
63
+
64
+ Returns:
65
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
66
+ """
67
+ # Derive aspect ratio if not explicitly provided
68
+ if aspect_ratio is None:
69
+ aspect_ratio = float(width) / float(height)
70
+
71
+ # Compute normalized spans for X and Y
72
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
73
+ span_x = aspect_ratio / diag_factor
74
+ span_y = 1.0 / diag_factor
75
+
76
+ # Establish the linspace boundaries
77
+ left_x = -span_x * (width - 1) / width
78
+ right_x = span_x * (width - 1) / width
79
+ top_y = -span_y * (height - 1) / height
80
+ bottom_y = span_y * (height - 1) / height
81
+
82
+ # Generate 1D coordinates
83
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
84
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
85
+
86
+ # Create 2D meshgrid (width x height) and stack into UV
87
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
88
+ uv_grid = torch.stack((uu, vv), dim=-1)
89
+
90
+ return uv_grid
hyworldmirror/models/utils/priors.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def normalize_poses(extrinsics, padding=0.1, return_stats=False):
5
+ """
6
+ Normalize camera positions to unit cube, processing each batch separately
7
+
8
+ Args:
9
+ extrinsics: Camera extrinsic matrices with shape (B, S, 3, 4)
10
+ padding: Boundary space within [0,1] range to prevent values near boundaries
11
+ return_stats: Whether to return normalization statistics
12
+
13
+ Returns:
14
+ normalized_extrinsics: Normalized extrinsic matrices
15
+ (optional) stats: Dictionary containing scale and translation information
16
+ """
17
+ B, S, _, _ = extrinsics.shape
18
+ device = extrinsics.device
19
+
20
+ # Check input validity and handle NaN/Inf values
21
+ for i in range(B):
22
+ if torch.isnan(extrinsics[i]).any() or torch.isinf(extrinsics[i]).any():
23
+ print(f"Warning: dataset sample has NaN/Inf in extrinsics")
24
+ extrinsics[i] = torch.nan_to_num(
25
+ extrinsics[i], nan=0.0, posinf=1e6, neginf=-1e6
26
+ )
27
+
28
+ normalized_extrinsics = extrinsics.clone()
29
+
30
+ # Store normalization parameters if needed
31
+ if return_stats:
32
+ stats = {
33
+ 'scale_factors': torch.zeros(B, device=device),
34
+ 'translation_vectors': torch.zeros(B, 3, device=device)
35
+ }
36
+
37
+ for b in range(B):
38
+ # Extract camera positions for this batch
39
+ positions = extrinsics[b, :, :3, 3] # (S, 3)
40
+
41
+ # Filter valid positions to ignore outliers
42
+ valid_mask = torch.isfinite(positions).all(dim=1) # (S,)
43
+
44
+ if valid_mask.sum() == 0:
45
+ # No valid positions, use default values
46
+ print(f"Warning: Batch {b} has no valid camera positions")
47
+ normalized_extrinsics[b, :, :3, 3] = 0.5 # Place at center
48
+ if return_stats:
49
+ stats['scale_factors'][b] = 1.0
50
+ stats['translation_vectors'][b] = 0.0
51
+ continue
52
+
53
+ valid_positions = positions[valid_mask]
54
+
55
+ # Calculate bounds using percentiles for robustness
56
+ if valid_positions.shape[0] > 10:
57
+ # Use 5% and 95% percentiles instead of min/max
58
+ min_pos = torch.quantile(valid_positions, 0.05, dim=0)
59
+ max_pos = torch.quantile(valid_positions, 0.95, dim=0)
60
+ else:
61
+ # Too few samples, use min/max
62
+ min_pos = torch.min(valid_positions, dim=0)[0]
63
+ max_pos = torch.max(valid_positions, dim=0)[0]
64
+
65
+ # Calculate scale factor considering all dimensions
66
+ pos_range = max_pos - min_pos
67
+
68
+ # Add small epsilon to prevent dimension collapse
69
+ eps = torch.maximum(
70
+ torch.tensor(1e-6, device=device),
71
+ torch.abs(max_pos) * 1e-6
72
+ )
73
+ pos_range = torch.maximum(pos_range, eps)
74
+
75
+ # Use maximum range as scale factor for uniform scaling
76
+ scale_factor = torch.max(pos_range)
77
+ scale_factor = torch.clamp(scale_factor, min=1e-6, max=1e6)
78
+
79
+ # Calculate center point for centering
80
+ center = (min_pos + max_pos) / 2.0
81
+
82
+ # Normalize: center first, then scale with padding
83
+ actual_scale = scale_factor / (1 - 2 * padding)
84
+ normalized_positions = (positions - center) / actual_scale + 0.5
85
+
86
+ # Ensure all values are within valid range
87
+ normalized_positions = torch.clamp(normalized_positions, 0.0, 1.0)
88
+
89
+ # Handle invalid positions by setting them to scene center
90
+ invalid_mask = ~torch.isfinite(positions).all(dim=1)
91
+ if invalid_mask.any():
92
+ normalized_positions[invalid_mask] = 0.5
93
+
94
+ normalized_extrinsics[b, :, :3, 3] = normalized_positions
95
+
96
+ if return_stats:
97
+ stats['scale_factors'][b] = actual_scale
98
+ stats['translation_vectors'][b] = center
99
+
100
+ # Final validation
101
+ assert torch.isfinite(normalized_extrinsics).all(), "Output contains non-finite values"
102
+
103
+ if return_stats:
104
+ return normalized_extrinsics, stats
105
+ return normalized_extrinsics
106
+
107
+
108
+ def normalize_depth(depth, eps=1e-6, min_percentile=1, max_percentile=99):
109
+ """
110
+ Normalize depth values to [0, 1] range using percentile-based scaling.
111
+
112
+ Args:
113
+ depth: Input depth tensor with shape (B, S, H, W)
114
+ eps: Small epsilon value to prevent division by zero
115
+ min_percentile: Lower percentile for robust min calculation (default: 1)
116
+ max_percentile: Upper percentile for robust max calculation (default: 99)
117
+
118
+ Returns:
119
+ normalized_depth: Depth tensor normalized to [0, 1] range with same shape (B, S, H, W)
120
+ """
121
+ B, S, H, W = depth.shape
122
+ depth = depth.flatten(0,1) # [B*S, H, W]
123
+
124
+ # Handle invalid values
125
+ depth = torch.nan_to_num(depth, nan=0.0, posinf=1e6, neginf=0.0)
126
+
127
+ normalized_list = []
128
+ for i in range(depth.shape[0]):
129
+ depth_img = depth[i] # [H, W]
130
+ depth_flat = depth_img.flatten()
131
+
132
+ # Filter out zero values if needed
133
+ non_zero_mask = depth_flat > 0
134
+ if non_zero_mask.sum() > 0:
135
+ values_to_use = depth_flat[non_zero_mask]
136
+ else:
137
+ values_to_use = depth_flat
138
+
139
+ # Only calculate percentiles when there are enough values
140
+ if values_to_use.numel() > 100: # Ensure enough samples for percentile calculation
141
+ # Calculate min and max percentiles
142
+ depth_min = torch.quantile(values_to_use, min_percentile/100.0)
143
+ depth_max = torch.quantile(values_to_use, max_percentile/100.0)
144
+ else:
145
+ # If too few samples, use min/max values
146
+ depth_min = values_to_use.min()
147
+ depth_max = values_to_use.max()
148
+
149
+ # Handle case where max equals min
150
+ if depth_max == depth_min:
151
+ depth_max = depth_min + 1.0
152
+
153
+ # Use relative epsilon
154
+ scale = torch.abs(depth_max - depth_min)
155
+ eps_val = max(eps, scale.item() * eps)
156
+
157
+ # Perform normalization
158
+ depth_norm_img = (depth_img - depth_min) / (depth_max - depth_min + eps_val)
159
+
160
+ # Ensure output is within [0,1] range
161
+ depth_norm_img = torch.clamp(depth_norm_img, 0.0, 1.0)
162
+
163
+ normalized_list.append(depth_norm_img)
164
+
165
+ # Recombine all normalized images
166
+ depth_norm = torch.stack(normalized_list)
167
+
168
+ return depth_norm.reshape(B, S, H, W)
hyworldmirror/models/utils/rotation.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
2
+
3
+ import torch
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def quat_to_rotmat(quaternions: torch.Tensor) -> torch.Tensor:
9
+ """
10
+ Quaternion Order: XYZW or say ijkr, scalar-last
11
+
12
+ Convert rotations given as quaternions to rotation matrices.
13
+ Args:
14
+ quaternions: quaternions with real part last,
15
+ as tensor of shape (..., 4).
16
+
17
+ Returns:
18
+ Rotation matrices as tensor of shape (..., 3, 3).
19
+ """
20
+ i, j, k, r = torch.unbind(quaternions, -1)
21
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
22
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
23
+
24
+ o = torch.stack(
25
+ (
26
+ 1 - two_s * (j * j + k * k),
27
+ two_s * (i * j - k * r),
28
+ two_s * (i * k + j * r),
29
+ two_s * (i * j + k * r),
30
+ 1 - two_s * (i * i + k * k),
31
+ two_s * (j * k - i * r),
32
+ two_s * (i * k - j * r),
33
+ two_s * (j * k + i * r),
34
+ 1 - two_s * (i * i + j * j),
35
+ ),
36
+ -1,
37
+ )
38
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
39
+
40
+
41
+ def rotmat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
42
+ """
43
+ Convert rotations given as rotation matrices to quaternions.
44
+
45
+ Args:
46
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
47
+
48
+ Returns:
49
+ quaternions with real part last, as tensor of shape (..., 4).
50
+ Quaternion Order: XYZW or say ijkr, scalar-last
51
+ """
52
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
53
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
54
+
55
+ batch_dim = matrix.shape[:-2]
56
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
57
+
58
+ q_abs = _sqrt_positive_part(
59
+ torch.stack(
60
+ [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
61
+ )
62
+ )
63
+
64
+ # we produce the desired quaternion multiplied by each of r, i, j, k
65
+ quat_by_rijk = torch.stack(
66
+ [
67
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
68
+ # `int`.
69
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
70
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
71
+ # `int`.
72
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
73
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
74
+ # `int`.
75
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
76
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
77
+ # `int`.
78
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
79
+ ],
80
+ dim=-2,
81
+ )
82
+
83
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
84
+ # the candidate won't be picked.
85
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
86
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
87
+
88
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
89
+ # forall i; we pick the best-conditioned one (with the largest denominator)
90
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
91
+
92
+ # Convert from rijk to ijkr
93
+ out = out[..., [1, 2, 3, 0]]
94
+
95
+ out = standardize_quaternion(out)
96
+
97
+ return out
98
+
99
+
100
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
101
+ """
102
+ Returns torch.sqrt(torch.max(0, x))
103
+ but with a zero subgradient where x is 0.
104
+ """
105
+ ret = torch.zeros_like(x)
106
+ positive_mask = x > 0
107
+ if torch.is_grad_enabled():
108
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
109
+ else:
110
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
111
+ return ret
112
+
113
+
114
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
115
+ """
116
+ Convert a unit quaternion to a standard form: one in which the real
117
+ part is non negative.
118
+
119
+ Args:
120
+ quaternions: Quaternions with real part last,
121
+ as tensor of shape (..., 4).
122
+
123
+ Returns:
124
+ Standardized quaternions as tensor of shape (..., 4).
125
+ """
126
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
hyworldmirror/models/utils/sh_utils.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The PlenOctree Authors.
2
+ # Redistribution and use in source and binary forms, with or without
3
+ # modification, are permitted provided that the following conditions are met:
4
+ #
5
+ # 1. Redistributions of source code must retain the above copyright notice,
6
+ # this list of conditions and the following disclaimer.
7
+ #
8
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
9
+ # this list of conditions and the following disclaimer in the documentation
10
+ # and/or other materials provided with the distribution.
11
+ #
12
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22
+ # POSSIBILITY OF SUCH DAMAGE.
23
+
24
+ C0 = 0.28209479177387814
25
+ C1 = 0.4886025119029199
26
+ C2 = [
27
+ 1.0925484305920792,
28
+ -1.0925484305920792,
29
+ 0.31539156525252005,
30
+ -1.0925484305920792,
31
+ 0.5462742152960396
32
+ ]
33
+ C3 = [
34
+ -0.5900435899266435,
35
+ 2.890611442640554,
36
+ -0.4570457994644658,
37
+ 0.3731763325901154,
38
+ -0.4570457994644658,
39
+ 1.445305721320277,
40
+ -0.5900435899266435
41
+ ]
42
+ C4 = [
43
+ 2.5033429417967046,
44
+ -1.7701307697799304,
45
+ 0.9461746957575601,
46
+ -0.6690465435572892,
47
+ 0.10578554691520431,
48
+ -0.6690465435572892,
49
+ 0.47308734787878004,
50
+ -1.7701307697799304,
51
+ 0.6258357354491761,
52
+ ]
53
+
54
+
55
+ def eval_sh(deg, sh, dirs):
56
+ """
57
+ Evaluate spherical harmonics at unit directions
58
+ using hardcoded SH polynomials.
59
+ Works with torch/np/jnp.
60
+ ... Can be 0 or more batch dimensions.
61
+ Args:
62
+ deg: int SH deg. Currently, 0-3 supported
63
+ sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
64
+ dirs: jnp.ndarray unit directions [..., 3]
65
+ Returns:
66
+ [..., C]
67
+ """
68
+ assert deg <= 4 and deg >= 0
69
+ coeff = (deg + 1) ** 2
70
+ assert sh.shape[-1] >= coeff
71
+
72
+ result = C0 * sh[..., 0]
73
+ if deg > 0:
74
+ x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
75
+ result = (result -
76
+ C1 * y * sh[..., 1] +
77
+ C1 * z * sh[..., 2] -
78
+ C1 * x * sh[..., 3])
79
+
80
+ if deg > 1:
81
+ xx, yy, zz = x * x, y * y, z * z
82
+ xy, yz, xz = x * y, y * z, x * z
83
+ result = (result +
84
+ C2[0] * xy * sh[..., 4] +
85
+ C2[1] * yz * sh[..., 5] +
86
+ C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
87
+ C2[3] * xz * sh[..., 7] +
88
+ C2[4] * (xx - yy) * sh[..., 8])
89
+
90
+ if deg > 2:
91
+ result = (result +
92
+ C3[0] * y * (3 * xx - yy) * sh[..., 9] +
93
+ C3[1] * xy * z * sh[..., 10] +
94
+ C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
95
+ C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
96
+ C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
97
+ C3[5] * z * (xx - yy) * sh[..., 14] +
98
+ C3[6] * x * (xx - 3 * yy) * sh[..., 15])
99
+
100
+ if deg > 3:
101
+ result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
102
+ C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
103
+ C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
104
+ C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
105
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
106
+ C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
107
+ C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
108
+ C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
109
+ C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
110
+ return result
111
+
112
+ def RGB2SH(rgb):
113
+ return (rgb - 0.5) / C0
114
+
115
+ def SH2RGB(sh):
116
+ return sh * C0 + 0.5
hyworldmirror/utils/__init__.py ADDED
File without changes
hyworldmirror/utils/geometry.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for geometry operations.
3
+
4
+ References: DUSt3R, MoGe
5
+ """
6
+
7
+ from numbers import Number
8
+ from typing import Tuple, Union
9
+
10
+ import numpy as np
11
+ from .warnings import no_warnings
12
+
13
+
14
+ def colmap_to_opencv_intrinsics(K):
15
+ """
16
+ Modify camera intrinsics to follow a different convention.
17
+ Coordinates of the center of the top-left pixels are by default:
18
+ - (0.5, 0.5) in Colmap
19
+ - (0,0) in OpenCV
20
+ """
21
+ K = K.copy()
22
+ K[0, 2] -= 0.5
23
+ K[1, 2] -= 0.5
24
+
25
+ return K
26
+
27
+
28
+ def opencv_to_colmap_intrinsics(K):
29
+ """
30
+ Modify camera intrinsics to follow a different convention.
31
+ Coordinates of the center of the top-left pixels are by default:
32
+ - (0.5, 0.5) in Colmap
33
+ - (0,0) in OpenCV
34
+ """
35
+ K = K.copy()
36
+ K[0, 2] += 0.5
37
+ K[1, 2] += 0.5
38
+
39
+ return K
40
+
41
+
42
+ def angle_diff_vec3_numpy(v1: np.ndarray, v2: np.ndarray, eps: float = 1e-12):
43
+ """
44
+ Compute angle difference between 3D vectors using NumPy.
45
+
46
+ Args:
47
+ v1 (np.ndarray): First vector of shape (..., 3)
48
+ v2 (np.ndarray): Second vector of shape (..., 3)
49
+ eps (float, optional): Small epsilon value for numerical stability. Defaults to 1e-12.
50
+
51
+ Returns:
52
+ np.ndarray: Angle differences in radians
53
+ """
54
+ return np.arctan2(
55
+ np.linalg.norm(np.cross(v1, v2, axis=-1), axis=-1) + eps, (v1 * v2).sum(axis=-1)
56
+ )
57
+
58
+
59
+ @no_warnings(category=RuntimeWarning)
60
+ def points_to_normals(
61
+ point: np.ndarray, mask: np.ndarray = None, edge_threshold: float = None
62
+ ) -> np.ndarray:
63
+ """
64
+ Calculate normal map from point map. Value range is [-1, 1].
65
+
66
+ Args:
67
+ point (np.ndarray): shape (height, width, 3), point map
68
+ mask (optional, np.ndarray): shape (height, width), dtype=bool. Mask of valid depth pixels. Defaults to None.
69
+ edge_threshold (optional, float): threshold for the angle (in degrees) between the normal and the view direction. Defaults to None.
70
+
71
+ Returns:
72
+ normal (np.ndarray): shape (height, width, 3), normal map.
73
+ """
74
+ height, width = point.shape[-3:-1]
75
+ has_mask = mask is not None
76
+
77
+ if mask is None:
78
+ mask = np.ones_like(point[..., 0], dtype=bool)
79
+ mask_pad = np.zeros((height + 2, width + 2), dtype=bool)
80
+ mask_pad[1:-1, 1:-1] = mask
81
+ mask = mask_pad
82
+
83
+ pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype)
84
+ pts[1:-1, 1:-1, :] = point
85
+ up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :]
86
+ left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :]
87
+ down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :]
88
+ right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :]
89
+ normal = np.stack(
90
+ [
91
+ np.cross(up, left, axis=-1),
92
+ np.cross(left, down, axis=-1),
93
+ np.cross(down, right, axis=-1),
94
+ np.cross(right, up, axis=-1),
95
+ ]
96
+ )
97
+ normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
98
+
99
+ valid = (
100
+ np.stack(
101
+ [
102
+ mask[:-2, 1:-1] & mask[1:-1, :-2],
103
+ mask[1:-1, :-2] & mask[2:, 1:-1],
104
+ mask[2:, 1:-1] & mask[1:-1, 2:],
105
+ mask[1:-1, 2:] & mask[:-2, 1:-1],
106
+ ]
107
+ )
108
+ & mask[None, 1:-1, 1:-1]
109
+ )
110
+ if edge_threshold is not None:
111
+ view_angle = angle_diff_vec3_numpy(pts[None, 1:-1, 1:-1, :], normal)
112
+ view_angle = np.minimum(view_angle, np.pi - view_angle)
113
+ valid = valid & (view_angle < np.deg2rad(edge_threshold))
114
+
115
+ normal = (normal * valid[..., None]).sum(axis=0)
116
+ normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
117
+
118
+ if has_mask:
119
+ normal_mask = valid.any(axis=0)
120
+ normal = np.where(normal_mask[..., None], normal, 0)
121
+ return normal, normal_mask
122
+ else:
123
+ return normal
124
+
125
+
126
+ def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1):
127
+ """
128
+ Create a sliding window view of the input array along a specified axis.
129
+
130
+ This function creates a memory-efficient view of the input array with sliding windows
131
+ of the specified size and stride. The window dimension is appended to the end of the
132
+ output array's shape. This is useful for operations like convolution, pooling, or
133
+ any analysis that requires examining local neighborhoods in the data.
134
+
135
+ Args:
136
+ x (np.ndarray): Input array with shape (..., axis_size, ...)
137
+ window_size (int): Size of the sliding window
138
+ stride (int): Stride of the sliding window (step size between consecutive windows)
139
+ axis (int, optional): Axis to perform sliding window over. Defaults to -1 (last axis)
140
+
141
+ Returns:
142
+ np.ndarray: View of the input array with shape (..., n_windows, ..., window_size),
143
+ where n_windows = (axis_size - window_size + 1) // stride
144
+
145
+ Raises:
146
+ AssertionError: If window_size is larger than the size of the specified axis
147
+
148
+ Example:
149
+ >>> x = np.array([1, 2, 3, 4, 5, 6])
150
+ >>> sliding_window_1d(x, window_size=3, stride=2)
151
+ array([[1, 2, 3],
152
+ [3, 4, 5]])
153
+ """
154
+ assert x.shape[axis] >= window_size, (
155
+ f"kernel_size ({window_size}) is larger than axis_size ({x.shape[axis]})"
156
+ )
157
+ axis = axis % x.ndim
158
+ shape = (
159
+ *x.shape[:axis],
160
+ (x.shape[axis] - window_size + 1) // stride,
161
+ *x.shape[axis + 1 :],
162
+ window_size,
163
+ )
164
+ strides = (
165
+ *x.strides[:axis],
166
+ stride * x.strides[axis],
167
+ *x.strides[axis + 1 :],
168
+ x.strides[axis],
169
+ )
170
+ x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
171
+ return x_sliding
172
+
173
+
174
+ def sliding_window_nd(
175
+ x: np.ndarray,
176
+ window_size: Tuple[int, ...],
177
+ stride: Tuple[int, ...],
178
+ axis: Tuple[int, ...],
179
+ ) -> np.ndarray:
180
+ """
181
+ Create sliding windows along multiple dimensions of the input array.
182
+
183
+ This function applies sliding_window_1d sequentially along multiple axes to create
184
+ N-dimensional sliding windows. This is useful for operations that need to examine
185
+ local neighborhoods in multiple dimensions simultaneously.
186
+
187
+ Args:
188
+ x (np.ndarray): Input array
189
+ window_size (Tuple[int, ...]): Size of the sliding window for each axis
190
+ stride (Tuple[int, ...]): Stride of the sliding window for each axis
191
+ axis (Tuple[int, ...]): Axes to perform sliding window over
192
+
193
+ Returns:
194
+ np.ndarray: Array with sliding windows along the specified dimensions.
195
+ The window dimensions are appended to the end of the shape.
196
+
197
+ Note:
198
+ The length of window_size, stride, and axis tuples must be equal.
199
+
200
+ Example:
201
+ >>> x = np.random.rand(10, 10)
202
+ >>> windows = sliding_window_nd(x, window_size=(3, 3), stride=(2, 2), axis=(-2, -1))
203
+ >>> # Creates 3x3 sliding windows with stride 2 in both dimensions
204
+ """
205
+ axis = [axis[i] % x.ndim for i in range(len(axis))]
206
+ for i in range(len(axis)):
207
+ x = sliding_window_1d(x, window_size[i], stride[i], axis[i])
208
+ return x
209
+
210
+
211
+ def sliding_window_2d(
212
+ x: np.ndarray,
213
+ window_size: Union[int, Tuple[int, int]],
214
+ stride: Union[int, Tuple[int, int]],
215
+ axis: Tuple[int, int] = (-2, -1),
216
+ ) -> np.ndarray:
217
+ """
218
+ Create 2D sliding windows over the input array.
219
+
220
+ Convenience function for creating 2D sliding windows, commonly used for image
221
+ processing operations like convolution, pooling, or patch extraction.
222
+
223
+ Args:
224
+ x (np.ndarray): Input array
225
+ window_size (Union[int, Tuple[int, int]]): Size of the 2D sliding window.
226
+ If int, same size is used for both dimensions.
227
+ stride (Union[int, Tuple[int, int]]): Stride of the 2D sliding window.
228
+ If int, same stride is used for both dimensions.
229
+ axis (Tuple[int, int], optional): Two axes to perform sliding window over.
230
+ Defaults to (-2, -1) (last two dimensions).
231
+
232
+ Returns:
233
+ np.ndarray: Array with 2D sliding windows. The window dimensions (height, width)
234
+ are appended to the end of the shape.
235
+
236
+ Example:
237
+ >>> image = np.random.rand(100, 100)
238
+ >>> patches = sliding_window_2d(image, window_size=8, stride=4)
239
+ >>> # Creates 8x8 patches with stride 4 from the image
240
+ """
241
+ if isinstance(window_size, int):
242
+ window_size = (window_size, window_size)
243
+ if isinstance(stride, int):
244
+ stride = (stride, stride)
245
+ return sliding_window_nd(x, window_size, stride, axis)
246
+
247
+
248
+ def max_pool_1d(
249
+ x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1
250
+ ):
251
+ """
252
+ Perform 1D max pooling on the input array.
253
+
254
+ Max pooling reduces the dimensionality of the input by taking the maximum value
255
+ within each sliding window. This is commonly used in neural networks and signal
256
+ processing for downsampling and feature extraction.
257
+
258
+ Args:
259
+ x (np.ndarray): Input array
260
+ kernel_size (int): Size of the pooling kernel
261
+ stride (int): Stride of the pooling operation
262
+ padding (int, optional): Amount of padding to add on both sides. Defaults to 0.
263
+ axis (int, optional): Axis to perform max pooling over. Defaults to -1.
264
+
265
+ Returns:
266
+ np.ndarray: Max pooled array with reduced size along the specified axis
267
+
268
+ Note:
269
+ - For floating point arrays, padding is done with np.nan values
270
+ - For integer arrays, padding is done with the minimum value of the dtype
271
+ - np.nanmax is used to handle NaN values in the computation
272
+
273
+ Example:
274
+ >>> x = np.array([1, 3, 2, 4, 5, 1, 2])
275
+ >>> max_pool_1d(x, kernel_size=3, stride=2)
276
+ array([3, 5, 2])
277
+ """
278
+ axis = axis % x.ndim
279
+ if padding > 0:
280
+ fill_value = np.nan if x.dtype.kind == "f" else np.iinfo(x.dtype).min
281
+ padding_arr = np.full(
282
+ (*x.shape[:axis], padding, *x.shape[axis + 1 :]),
283
+ fill_value=fill_value,
284
+ dtype=x.dtype,
285
+ )
286
+ x = np.concatenate([padding_arr, x, padding_arr], axis=axis)
287
+ a_sliding = sliding_window_1d(x, kernel_size, stride, axis)
288
+ max_pool = np.nanmax(a_sliding, axis=-1)
289
+ return max_pool
290
+
291
+
292
+ def max_pool_nd(
293
+ x: np.ndarray,
294
+ kernel_size: Tuple[int, ...],
295
+ stride: Tuple[int, ...],
296
+ padding: Tuple[int, ...],
297
+ axis: Tuple[int, ...],
298
+ ) -> np.ndarray:
299
+ """
300
+ Perform N-dimensional max pooling on the input array.
301
+
302
+ This function applies max_pool_1d sequentially along multiple axes to perform
303
+ multi-dimensional max pooling. This is useful for downsampling multi-dimensional
304
+ data while preserving the most important features.
305
+
306
+ Args:
307
+ x (np.ndarray): Input array
308
+ kernel_size (Tuple[int, ...]): Size of the pooling kernel for each axis
309
+ stride (Tuple[int, ...]): Stride of the pooling operation for each axis
310
+ padding (Tuple[int, ...]): Amount of padding for each axis
311
+ axis (Tuple[int, ...]): Axes to perform max pooling over
312
+
313
+ Returns:
314
+ np.ndarray: Max pooled array with reduced size along the specified axes
315
+
316
+ Note:
317
+ The length of kernel_size, stride, padding, and axis tuples must be equal.
318
+ Max pooling is applied sequentially along each axis in the order specified.
319
+
320
+ Example:
321
+ >>> x = np.random.rand(10, 10, 10)
322
+ >>> pooled = max_pool_nd(x, kernel_size=(2, 2, 2), stride=(2, 2, 2),
323
+ ... padding=(0, 0, 0), axis=(-3, -2, -1))
324
+ >>> # Reduces each dimension by half with 2x2x2 max pooling
325
+ """
326
+ for i in range(len(axis)):
327
+ x = max_pool_1d(x, kernel_size[i], stride[i], padding[i], axis[i])
328
+ return x
329
+
330
+
331
+ def max_pool_2d(
332
+ x: np.ndarray,
333
+ kernel_size: Union[int, Tuple[int, int]],
334
+ stride: Union[int, Tuple[int, int]],
335
+ padding: Union[int, Tuple[int, int]],
336
+ axis: Tuple[int, int] = (-2, -1),
337
+ ):
338
+ """
339
+ Perform 2D max pooling on the input array.
340
+
341
+ Convenience function for 2D max pooling, commonly used in computer vision
342
+ and image processing for downsampling images while preserving important features.
343
+
344
+ Args:
345
+ x (np.ndarray): Input array
346
+ kernel_size (Union[int, Tuple[int, int]]): Size of the 2D pooling kernel.
347
+ If int, same size is used for both dimensions.
348
+ stride (Union[int, Tuple[int, int]]): Stride of the 2D pooling operation.
349
+ If int, same stride is used for both dimensions.
350
+ padding (Union[int, Tuple[int, int]]): Amount of padding for both dimensions.
351
+ If int, same padding is used for both dimensions.
352
+ axis (Tuple[int, int], optional): Two axes to perform max pooling over.
353
+ Defaults to (-2, -1) (last two dimensions).
354
+
355
+ Returns:
356
+ np.ndarray: 2D max pooled array with reduced size along the specified axes
357
+
358
+ Example:
359
+ >>> image = np.random.rand(64, 64)
360
+ >>> pooled = max_pool_2d(image, kernel_size=2, stride=2, padding=0)
361
+ >>> # Reduces image size from 64x64 to 32x32 with 2x2 max pooling
362
+ """
363
+ if isinstance(kernel_size, Number):
364
+ kernel_size = (kernel_size, kernel_size)
365
+ if isinstance(stride, Number):
366
+ stride = (stride, stride)
367
+ if isinstance(padding, Number):
368
+ padding = (padding, padding)
369
+ axis = tuple(axis)
370
+ return max_pool_nd(x, kernel_size, stride, padding, axis)
371
+
372
+
373
+ @no_warnings(category=RuntimeWarning)
374
+ def depth_edge(
375
+ depth: np.ndarray,
376
+ atol: float = None,
377
+ rtol: float = None,
378
+ kernel_size: int = 3,
379
+ mask: np.ndarray = None,
380
+ ) -> np.ndarray:
381
+ """
382
+ Compute the edge mask from depth map. The edge is defined as the pixels whose neighbors have large difference in depth.
383
+
384
+ Args:
385
+ depth (np.ndarray): shape (..., height, width), linear depth map
386
+ atol (float): absolute tolerance
387
+ rtol (float): relative tolerance
388
+
389
+ Returns:
390
+ edge (np.ndarray): shape (..., height, width) of dtype torch.bool
391
+ """
392
+ if mask is None:
393
+ diff = max_pool_2d(
394
+ depth, kernel_size, stride=1, padding=kernel_size // 2
395
+ ) + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)
396
+ else:
397
+ diff = max_pool_2d(
398
+ np.where(mask, depth, -np.inf),
399
+ kernel_size,
400
+ stride=1,
401
+ padding=kernel_size // 2,
402
+ ) + max_pool_2d(
403
+ np.where(mask, -depth, -np.inf),
404
+ kernel_size,
405
+ stride=1,
406
+ padding=kernel_size // 2,
407
+ )
408
+
409
+ edge = np.zeros_like(depth, dtype=bool)
410
+ if atol is not None:
411
+ edge |= diff > atol
412
+
413
+ if rtol is not None:
414
+ edge |= diff / depth > rtol
415
+ return edge
416
+
417
+
418
+ def depth_aliasing(
419
+ depth: np.ndarray,
420
+ atol: float = None,
421
+ rtol: float = None,
422
+ kernel_size: int = 3,
423
+ mask: np.ndarray = None,
424
+ ) -> np.ndarray:
425
+ """
426
+ Compute the map that indicates the aliasing of x depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors.
427
+ Args:
428
+ depth (np.ndarray): shape (..., height, width), linear depth map
429
+ atol (float): absolute tolerance
430
+ rtol (float): relative tolerance
431
+
432
+ Returns:
433
+ edge (np.ndarray): shape (..., height, width) of dtype torch.bool
434
+ """
435
+ if mask is None:
436
+ diff_max = (
437
+ max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth
438
+ )
439
+ diff_min = (
440
+ max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth
441
+ )
442
+ else:
443
+ diff_max = (
444
+ max_pool_2d(
445
+ np.where(mask, depth, -np.inf),
446
+ kernel_size,
447
+ stride=1,
448
+ padding=kernel_size // 2,
449
+ )
450
+ - depth
451
+ )
452
+ diff_min = (
453
+ max_pool_2d(
454
+ np.where(mask, -depth, -np.inf),
455
+ kernel_size,
456
+ stride=1,
457
+ padding=kernel_size // 2,
458
+ )
459
+ + depth
460
+ )
461
+ diff = np.minimum(diff_max, diff_min)
462
+
463
+ edge = np.zeros_like(depth, dtype=bool)
464
+ if atol is not None:
465
+ edge |= diff > atol
466
+ if rtol is not None:
467
+ edge |= diff / depth > rtol
468
+ return edge
469
+
470
+
471
+ @no_warnings(category=RuntimeWarning)
472
+ def normals_edge(
473
+ normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None
474
+ ) -> np.ndarray:
475
+ """
476
+ Compute the edge mask from normal map.
477
+
478
+ Args:
479
+ normal (np.ndarray): shape (..., height, width, 3), normal map
480
+ tol (float): tolerance in degrees
481
+
482
+ Returns:
483
+ edge (np.ndarray): shape (..., height, width) of dtype torch.bool
484
+ """
485
+ assert normals.ndim >= 3 and normals.shape[-1] == 3, (
486
+ "normal should be of shape (..., height, width, 3)"
487
+ )
488
+ normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12)
489
+
490
+ padding = kernel_size // 2
491
+ normals_window = sliding_window_2d(
492
+ np.pad(
493
+ normals,
494
+ (
495
+ *([(0, 0)] * (normals.ndim - 3)),
496
+ (padding, padding),
497
+ (padding, padding),
498
+ (0, 0),
499
+ ),
500
+ mode="edge",
501
+ ),
502
+ window_size=kernel_size,
503
+ stride=1,
504
+ axis=(-3, -2),
505
+ )
506
+ if mask is None:
507
+ angle_diff = np.arccos(
508
+ (normals[..., None, None] * normals_window).sum(axis=-3)
509
+ ).max(axis=(-2, -1))
510
+ else:
511
+ mask_window = sliding_window_2d(
512
+ np.pad(
513
+ mask,
514
+ (*([(0, 0)] * (mask.ndim - 3)), (padding, padding), (padding, padding)),
515
+ mode="edge",
516
+ ),
517
+ window_size=kernel_size,
518
+ stride=1,
519
+ axis=(-3, -2),
520
+ )
521
+ angle_diff = np.where(
522
+ mask_window,
523
+ np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)),
524
+ 0,
525
+ ).max(axis=(-2, -1))
526
+
527
+ angle_diff = max_pool_2d(
528
+ angle_diff, kernel_size, stride=1, padding=kernel_size // 2
529
+ )
530
+ edge = angle_diff > np.deg2rad(tol)
531
+ return edge
hyworldmirror/utils/inference_utils.py ADDED
@@ -0,0 +1,824 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference utilities for WorldMirror pipeline.
3
+
4
+ Includes: image preprocessing, input preparation, prior loading, mask computation,
5
+ result saving, and timing utilities.
6
+ """
7
+
8
+ import glob
9
+ import json
10
+ import os
11
+ import time
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ from pathlib import Path
14
+
15
+ import cv2
16
+ import numpy as np
17
+ import torch
18
+ from PIL import Image
19
+ from torchvision import transforms
20
+
21
+ from ..models.utils.camera_utils import vector_to_camera_matrices
22
+ from ..models.utils.geometry import depth_to_world_coords_points
23
+ from .save_utils import (
24
+ save_depth_png, save_depth_npy, save_normal_png,
25
+ save_gs_ply, save_points_ply, save_camera_params,
26
+ )
27
+ from .video_utils import video_to_image_frames, video_to_image_frames_new
28
+ from .visual_util import segment_sky, download_file_from_url
29
+ from .geometry import depth_edge, normals_edge
30
+
31
+ _IO_WORKERS = 8
32
+
33
+ # ============================================================
34
+ # Image Preprocessing
35
+ # ============================================================
36
+
37
+ def _handle_alpha_channel(img_data):
38
+ """Process RGBA images by blending with white background."""
39
+ if img_data.mode == "RGBA":
40
+ white_bg = Image.new("RGBA", img_data.size, (255, 255, 255, 255))
41
+ img_data = Image.alpha_composite(white_bg, img_data)
42
+ return img_data.convert("RGB")
43
+
44
+
45
+ def _calculate_resize_dims(orig_w, orig_h, max_dim, resize_strategy, patch_size=14):
46
+ """Calculate new dimensions based on resize strategy."""
47
+ if orig_w >= orig_h:
48
+ new_w = max_dim
49
+ new_h = round(orig_h * (new_w / orig_w) / patch_size) * patch_size
50
+ else:
51
+ new_h = max_dim
52
+ new_w = round(orig_w * (new_h / orig_h) / patch_size) * patch_size
53
+ return new_w, new_h
54
+
55
+
56
+ def _apply_padding(tensor_img, target_dim):
57
+ """Apply padding to make tensor square."""
58
+ h_pad = target_dim - tensor_img.shape[1]
59
+ w_pad = target_dim - tensor_img.shape[2]
60
+ if h_pad > 0 or w_pad > 0:
61
+ pad_top, pad_bottom = h_pad // 2, h_pad - h_pad // 2
62
+ pad_left, pad_right = w_pad // 2, w_pad - w_pad // 2
63
+ return torch.nn.functional.pad(
64
+ tensor_img, (pad_left, pad_right, pad_top, pad_bottom),
65
+ mode="constant", value=1.0,
66
+ )
67
+ return tensor_img
68
+
69
+
70
+ def prepare_images_to_tensor(file_paths, resize_strategy="crop", target_size=518):
71
+ """Process image files into uniform tensor batch [1, N, 3, H, W]."""
72
+ if not file_paths:
73
+ raise ValueError("At least 1 image is required")
74
+ if resize_strategy not in ["crop", "pad"]:
75
+ raise ValueError("Strategy must be 'crop' or 'pad'")
76
+
77
+ tensor_list = []
78
+ converter = transforms.ToTensor()
79
+
80
+ for file_path in file_paths:
81
+ img_data = Image.open(file_path)
82
+ img_data = _handle_alpha_channel(img_data)
83
+ orig_w, orig_h = img_data.size
84
+ new_w, new_h = _calculate_resize_dims(orig_w, orig_h, target_size, resize_strategy)
85
+
86
+ img_data = img_data.resize((new_w, new_h), Image.Resampling.BICUBIC)
87
+ tensor_img = converter(img_data)
88
+
89
+ if resize_strategy == "crop":
90
+ if new_h > target_size:
91
+ crop_start = (new_h - target_size) // 2
92
+ tensor_img = tensor_img[:, crop_start:crop_start + target_size, :]
93
+ if new_w > target_size:
94
+ crop_start = (new_w - target_size) // 2
95
+ tensor_img = tensor_img[:, :, crop_start:crop_start + target_size]
96
+ elif resize_strategy == "pad":
97
+ tensor_img = _apply_padding(tensor_img, target_size)
98
+
99
+ tensor_list.append(tensor_img)
100
+
101
+ shapes = set((t.shape[1], t.shape[2]) for t in tensor_list)
102
+ if len(shapes) > 1:
103
+ raise ValueError(
104
+ f"Inconsistent resolutions after preprocessing: {shapes}. "
105
+ f"All input images must have the same aspect ratio."
106
+ )
107
+
108
+ batch_tensor = torch.stack(tensor_list)
109
+ if batch_tensor.dim() == 3:
110
+ batch_tensor = batch_tensor.unsqueeze(0)
111
+ return batch_tensor.unsqueeze(0)
112
+
113
+
114
+ # ============================================================
115
+ # Input Preparation
116
+ # ============================================================
117
+
118
+ def prepare_input(input_path, target_size=518, fps=1,
119
+ video_strategy="new", min_frames=1, max_frames=64,
120
+ temp_dir=None):
121
+ """Read images or extract video frames. Returns (img_paths, subdir_name)."""
122
+ input_path = Path(input_path)
123
+ video_exts = ['.mp4', '.avi', '.mov', '.webm', '.gif']
124
+
125
+ if input_path.is_file() and input_path.suffix.lower() in video_exts:
126
+ subdir_name = input_path.stem
127
+ frames_dir = Path(temp_dir or "/tmp") / f"frames_{subdir_name}"
128
+ frames_dir.mkdir(parents=True, exist_ok=True)
129
+ min_f = max(1, min_frames)
130
+ max_f = min(64, max_frames)
131
+ if video_strategy == "new":
132
+ img_paths = video_to_image_frames_new(
133
+ str(input_path), str(frames_dir),
134
+ min_frames=min_f, max_frames=max_f, fallback_fps=fps,
135
+ )
136
+ else:
137
+ img_paths = video_to_image_frames(str(input_path), str(frames_dir), fps=fps)
138
+ if len(img_paths) > max_f:
139
+ indices = np.linspace(0, len(img_paths) - 1, max_f, dtype=int)
140
+ img_paths = [img_paths[i] for i in indices]
141
+ if not img_paths:
142
+ raise RuntimeError(f"Failed to extract frames from {input_path}")
143
+ img_paths = sorted(img_paths)
144
+ print(f"[Input] Extracted {len(img_paths)} frames from video: {input_path}")
145
+ elif input_path.is_dir():
146
+ subdir_name = input_path.name
147
+ img_paths = []
148
+ for ext in ["*.jpeg", "*.jpg", "*.png", "*.webp"]:
149
+ img_paths.extend(glob.glob(os.path.join(str(input_path), ext)))
150
+ if not img_paths:
151
+ raise FileNotFoundError(f"No images found in {input_path}")
152
+ img_paths = sorted(img_paths)
153
+ print(f"[Input] Loaded {len(img_paths)} images from: {input_path}")
154
+ elif input_path.is_file() and input_path.suffix.lower() in ['.jpg', '.jpeg', '.png', '.webp']:
155
+ subdir_name = input_path.stem
156
+ img_paths = [str(input_path)]
157
+ print(f"[Input] Single image input: {input_path}")
158
+ else:
159
+ raise ValueError(f"Invalid input path: {input_path}")
160
+
161
+ return img_paths, subdir_name
162
+
163
+
164
+ def compute_adaptive_target_size(img_paths, max_target_size=518, patch_size=14):
165
+ """Compute inference resolution = min(image_longest_edge, max_target_size).
166
+
167
+ Rounds down to nearest multiple of patch_size. Avoids upsampling small images.
168
+ """
169
+ first_img = Image.open(img_paths[0])
170
+ orig_w, orig_h = first_img.size
171
+ longest_edge = max(orig_w, orig_h)
172
+ effective = min(longest_edge, max_target_size)
173
+ effective = (effective // patch_size) * patch_size
174
+ return max(effective, patch_size * 2)
175
+
176
+
177
+ # ============================================================
178
+ # Prior Loading
179
+ # ============================================================
180
+
181
+ def compute_preprocessing_transform(img_paths, target_size, patch_size=14):
182
+ """Compute the resize + center-crop transform applied by prepare_images_to_tensor.
183
+
184
+ Returns dict with orig/new/final sizes and scale/crop parameters.
185
+ """
186
+ first_img = Image.open(img_paths[0])
187
+ orig_w, orig_h = first_img.size
188
+ new_w, new_h = _calculate_resize_dims(orig_w, orig_h, target_size, "crop", patch_size)
189
+
190
+ crop_y = (new_h - target_size) // 2 if new_h > target_size else 0
191
+ crop_x = (new_w - target_size) // 2 if new_w > target_size else 0
192
+
193
+ return {
194
+ "orig_w": orig_w, "orig_h": orig_h,
195
+ "new_w": new_w, "new_h": new_h,
196
+ "crop_x": crop_x, "crop_y": crop_y,
197
+ "final_w": min(new_w, target_size), "final_h": min(new_h, target_size),
198
+ "scale_x": new_w / orig_w, "scale_y": new_h / orig_h,
199
+ }
200
+
201
+
202
+ def load_prior_camera(prior_cam_path, img_paths, preprocess_transform=None):
203
+ """Load camera priors from JSON. Returns (extrinsics [1,N,4,4], intrinsics [1,N,3,3])."""
204
+ with open(prior_cam_path, "r") as f:
205
+ cam_data = json.load(f)
206
+
207
+ stem_to_idx = {Path(p).stem: i for i, p in enumerate(img_paths)}
208
+ N = len(img_paths)
209
+
210
+ extrinsics = None
211
+ extr_list = cam_data.get("extrinsics", [])
212
+ if extr_list:
213
+ extr_array = np.zeros((N, 4, 4), dtype=np.float32)
214
+ matched = 0
215
+ for entry in extr_list:
216
+ cam_id = str(entry["camera_id"])
217
+ idx = stem_to_idx.get(cam_id)
218
+ if idx is None and cam_id.isdigit() and int(cam_id) < N:
219
+ idx = int(cam_id)
220
+ if idx is not None:
221
+ extr_array[idx] = np.array(entry["matrix"], dtype=np.float32)
222
+ matched += 1
223
+ if matched == N:
224
+ extrinsics = torch.from_numpy(extr_array).unsqueeze(0)
225
+ print(f"[Prior] Loaded extrinsics for {matched}/{N} cameras")
226
+ else:
227
+ print(f"[Prior] Warning: extrinsics matched {matched}/{N}, disabling")
228
+
229
+ intrinsics = None
230
+ intr_list = cam_data.get("intrinsics", [])
231
+ if intr_list:
232
+ intr_array = np.zeros((N, 3, 3), dtype=np.float32)
233
+ matched = 0
234
+ for entry in intr_list:
235
+ cam_id = str(entry["camera_id"])
236
+ idx = stem_to_idx.get(cam_id)
237
+ if idx is None and cam_id.isdigit() and int(cam_id) < N:
238
+ idx = int(cam_id)
239
+ if idx is not None:
240
+ intr_array[idx] = np.array(entry["matrix"], dtype=np.float32)
241
+ matched += 1
242
+ if matched == N:
243
+ intrinsics = torch.from_numpy(intr_array).unsqueeze(0)
244
+ print(f"[Prior] Loaded intrinsics for {matched}/{N} cameras")
245
+ else:
246
+ print(f"[Prior] Warning: intrinsics matched {matched}/{N}, disabling")
247
+
248
+ if intrinsics is not None and preprocess_transform is not None:
249
+ sx, sy = preprocess_transform["scale_x"], preprocess_transform["scale_y"]
250
+ cx_off, cy_off = preprocess_transform["crop_x"], preprocess_transform["crop_y"]
251
+ intrinsics = intrinsics.clone()
252
+ intrinsics[:, :, 0, :] *= sx
253
+ intrinsics[:, :, 1, :] *= sy
254
+ intrinsics[:, :, 0, 2] -= cx_off
255
+ intrinsics[:, :, 1, 2] -= cy_off
256
+
257
+ return extrinsics, intrinsics
258
+
259
+
260
+ def _read_depth_file(depth_path):
261
+ """Read a single depth file (.npy, .exr, .png). Returns float32 [H, W]."""
262
+ ext = Path(depth_path).suffix.lower()
263
+ if ext == ".npy":
264
+ depthmap = np.load(depth_path).astype(np.float32)
265
+ if depthmap.ndim == 3:
266
+ depthmap = depthmap[:, :, 0]
267
+ elif ext == ".exr":
268
+ depthmap = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH).astype(np.float32)
269
+ if depthmap.ndim == 3:
270
+ depthmap = depthmap[:, :, 0]
271
+ elif ext == ".png":
272
+ depthmap = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
273
+ if depthmap is None:
274
+ raise FileNotFoundError(f"Cannot read depth PNG: {depth_path}")
275
+ depthmap = depthmap.astype(np.float32)
276
+ if depthmap.ndim == 3:
277
+ depthmap = depthmap[:, :, 0]
278
+ if depthmap.max() > 255:
279
+ depthmap = depthmap / 1000.0
280
+ else:
281
+ raise ValueError(f"Unsupported depth format: {ext}")
282
+ return np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)
283
+
284
+
285
+ def load_prior_depth(prior_depth_path, img_paths, target_h, target_w,
286
+ preprocess_transform=None):
287
+ """Load depth priors from a folder. Returns [1, N, H, W] or None."""
288
+ depth_dir = Path(prior_depth_path)
289
+ if not depth_dir.is_dir():
290
+ return None
291
+
292
+ depth_files = {}
293
+ for f in sorted(depth_dir.iterdir()):
294
+ if f.suffix.lower() in (".npy", ".exr", ".png"):
295
+ if f.stem not in depth_files or f.suffix.lower() == ".npy":
296
+ depth_files[f.stem] = str(f)
297
+
298
+ N = len(img_paths)
299
+ depth_maps = []
300
+ for img_p in img_paths:
301
+ img_stem = Path(img_p).stem
302
+ dpath = depth_files.get(img_stem)
303
+ if dpath is None:
304
+ img_nums = ''.join(filter(str.isdigit, img_stem))
305
+ for dstem, dc in depth_files.items():
306
+ if img_nums and img_nums == ''.join(filter(str.isdigit, dstem)):
307
+ dpath = dc
308
+ break
309
+ if dpath is None:
310
+ return None
311
+
312
+ depthmap = _read_depth_file(dpath)
313
+ if preprocess_transform is not None:
314
+ nw, nh = preprocess_transform["new_w"], preprocess_transform["new_h"]
315
+ cx, cy = preprocess_transform["crop_x"], preprocess_transform["crop_y"]
316
+ fw, fh = preprocess_transform["final_w"], preprocess_transform["final_h"]
317
+ if depthmap.shape[:2] != (nh, nw):
318
+ depthmap = cv2.resize(depthmap, (nw, nh), interpolation=cv2.INTER_LINEAR)
319
+ depthmap = depthmap[cy:cy + fh, cx:cx + fw]
320
+ else:
321
+ if depthmap.shape[:2] != (target_h, target_w):
322
+ depthmap = cv2.resize(depthmap, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
323
+ depth_maps.append(depthmap)
324
+
325
+ depth_tensor = torch.from_numpy(np.stack(depth_maps, axis=0)).unsqueeze(0)
326
+ print(f"[Prior] Loaded {N} depth maps from {prior_depth_path}")
327
+ return depth_tensor
328
+
329
+
330
+ # ============================================================
331
+ # Mask Computation
332
+ # ============================================================
333
+
334
+ def create_filter_mask(
335
+ pts3d_conf, depth_preds, normal_preds, sky_mask,
336
+ confidence_percentile=10.0, edge_normal_threshold=5.0,
337
+ edge_depth_threshold=0.03, apply_confidence_mask=True,
338
+ apply_edge_mask=True, apply_sky_mask=False, gs_depth_preds=None,
339
+ ):
340
+ """Create filter mask based on confidence, edges, and sky segmentation.
341
+
342
+ Returns pts_mask [S,H,W] or (pts_mask, gs_mask) tuple if gs_depth_preds given.
343
+ """
344
+ S, H, W = pts3d_conf.shape[:3]
345
+ final_mask_list = []
346
+ gs_mask_list = [] if gs_depth_preds is not None else None
347
+
348
+ for i in range(S):
349
+ final_mask = None
350
+ if apply_confidence_mask:
351
+ threshold = np.quantile(pts3d_conf[i], confidence_percentile / 100.0)
352
+ conf_mask = pts3d_conf[i] >= threshold
353
+ final_mask = conf_mask if final_mask is None else final_mask & conf_mask
354
+
355
+ pre_edge_mask = final_mask
356
+
357
+ if apply_edge_mask:
358
+ n_edges = normals_edge(normal_preds[i], tol=edge_normal_threshold, mask=pre_edge_mask)
359
+ d_edges = depth_edge(depth_preds[i, :, :, 0], rtol=edge_depth_threshold, mask=pre_edge_mask)
360
+ edge_mask = ~(d_edges & n_edges)
361
+ final_mask = edge_mask if final_mask is None else final_mask & edge_mask
362
+
363
+ if gs_depth_preds is not None:
364
+ gs_d_edges = depth_edge(gs_depth_preds[i, :, :, 0], rtol=edge_depth_threshold, mask=pre_edge_mask)
365
+ gs_edge_mask = ~(gs_d_edges & n_edges)
366
+ gs_frame_mask = gs_edge_mask if pre_edge_mask is None else pre_edge_mask & gs_edge_mask
367
+
368
+ if apply_sky_mask:
369
+ final_mask = sky_mask[i] if final_mask is None else final_mask & sky_mask[i]
370
+ if gs_depth_preds is not None and apply_edge_mask:
371
+ gs_frame_mask = gs_frame_mask & sky_mask[i]
372
+
373
+ final_mask_list.append(final_mask)
374
+ if gs_mask_list is not None:
375
+ gs_mask_list.append(gs_frame_mask if apply_edge_mask else final_mask)
376
+
377
+ def _stack(ml):
378
+ return np.stack(ml, axis=0) if ml[0] is not None else np.ones((S, H, W), dtype=bool)
379
+
380
+ pts_mask = _stack(final_mask_list)
381
+ if gs_mask_list is not None:
382
+ return pts_mask, _stack(gs_mask_list)
383
+ return pts_mask
384
+
385
+
386
+ def _compute_sky_mask_from_model(predictions, H, W, S, threshold=0.5):
387
+ """Build sky mask from model predictions. Returns [S,H,W] bool or None."""
388
+ for key in ("gs_depth_mask_logits", "gs_depth_mask", "depth_mask_logits", "depth_mask"):
389
+ if key in predictions:
390
+ prob = predictions[key].sigmoid() if "logits" in key else predictions[key]
391
+ dm = prob[0].detach().cpu()
392
+ if dm.dim() == 4 and dm.shape[-1] == 1:
393
+ dm = dm.squeeze(-1)
394
+ if dm.dim() != 3 or dm.shape[0] != S:
395
+ return None
396
+ mask = (dm > threshold).numpy().astype(bool)
397
+ if mask.shape[1] != H or mask.shape[2] != W:
398
+ mask = np.stack([cv2.resize(mask[i].astype(np.uint8), (W, H),
399
+ interpolation=cv2.INTER_NEAREST) > 0
400
+ for i in range(S)], axis=0)
401
+ return mask
402
+ return None
403
+
404
+
405
+ def compute_sky_mask(img_paths, H, W, S, predictions=None, source="auto",
406
+ model_threshold=0.5, processed_aspect_ratio=None):
407
+ """Compute sky segmentation mask [S,H,W] (True=non-sky, False=sky)."""
408
+ if source == "model":
409
+ mask = _compute_sky_mask_from_model(predictions, H, W, S, model_threshold) if predictions else None
410
+ return mask if mask is not None else np.ones((S, H, W), dtype=bool)
411
+
412
+ skyseg_path = "skyseg.onnx"
413
+ if not os.path.exists(skyseg_path):
414
+ download_file_from_url(
415
+ "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx",
416
+ skyseg_path,
417
+ )
418
+ import onnxruntime
419
+ session = onnxruntime.InferenceSession(skyseg_path)
420
+ sky_list = []
421
+ for i in range(S):
422
+ if processed_aspect_ratio is not None:
423
+ pil_img = Image.open(img_paths[i]).convert("RGB")
424
+ sw, sh = pil_img.size
425
+ if sw / sh > processed_aspect_ratio:
426
+ cw = int(round(sh * processed_aspect_ratio))
427
+ ch = sh
428
+ else:
429
+ cw = sw
430
+ ch = int(round(sw / processed_aspect_ratio))
431
+ left, top = (sw - cw) // 2, (sh - ch) // 2
432
+ pil_img = pil_img.crop((left, top, left + cw, top + ch))
433
+ frame = segment_sky(cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR), session)
434
+ else:
435
+ frame = segment_sky(img_paths[i], session)
436
+ if frame.shape[:2] != (H, W):
437
+ frame = cv2.resize(frame, (W, H))
438
+ sky_list.append(frame)
439
+
440
+ sky_mask = np.stack(sky_list, axis=0) > 0
441
+ if source == "auto" and predictions is not None:
442
+ model_mask = _compute_sky_mask_from_model(predictions, H, W, S, model_threshold)
443
+ if model_mask is not None:
444
+ sky_mask = sky_mask & model_mask
445
+ return sky_mask
446
+
447
+
448
+ def compute_filter_mask(predictions, imgs, img_paths, H, W, S,
449
+ apply_confidence_mask=False, apply_edge_mask=False,
450
+ apply_sky_mask=False, confidence_percentile=10.0,
451
+ edge_normal_threshold=5.0, edge_depth_threshold=0.03,
452
+ sky_mask=None, use_gs_depth=False):
453
+ """Compute unified filter mask. Returns (filter_mask, gs_filter_mask) tuple."""
454
+ if not (apply_confidence_mask or apply_edge_mask or apply_sky_mask):
455
+ return np.ones((S, H, W), dtype=bool), None
456
+
457
+ if apply_sky_mask and sky_mask is None:
458
+ sky_mask = compute_sky_mask(img_paths, H, W, S, processed_aspect_ratio=W / H)
459
+ elif sky_mask is None:
460
+ sky_mask = np.ones((S, H, W), dtype=bool)
461
+
462
+ if "pts3d_conf" in predictions:
463
+ conf_np = predictions["pts3d_conf"][0].detach().cpu().float().numpy()
464
+ elif "depth_conf" in predictions:
465
+ conf_np = predictions["depth_conf"][0].detach().cpu().float().numpy()
466
+ else:
467
+ conf_np = np.ones((S, H, W), dtype=np.float32)
468
+ depth_np = predictions["depth"][0].detach().cpu().float().numpy()
469
+ normal_np = predictions["normals"][0].detach().cpu().float().numpy()
470
+
471
+ gs_depth_np = None
472
+ if use_gs_depth and "gs_depth" in predictions:
473
+ raw = predictions["gs_depth"][0].detach().cpu().float().numpy()
474
+ gs_depth_np = raw if raw.ndim == 4 else raw[..., np.newaxis]
475
+
476
+ result = create_filter_mask(
477
+ conf_np, depth_np, normal_np, sky_mask,
478
+ confidence_percentile=confidence_percentile,
479
+ edge_normal_threshold=edge_normal_threshold,
480
+ edge_depth_threshold=edge_depth_threshold,
481
+ apply_confidence_mask=apply_confidence_mask,
482
+ apply_edge_mask=apply_edge_mask,
483
+ apply_sky_mask=apply_sky_mask,
484
+ gs_depth_preds=gs_depth_np,
485
+ )
486
+
487
+ if gs_depth_np is not None:
488
+ pts_mask, gs_mask = result
489
+ total = pts_mask.size
490
+ print(f"[Mask] Filter: pts kept {pts_mask.sum()}/{total}, gs kept {gs_mask.sum()}/{total}")
491
+ return pts_mask, gs_mask
492
+
493
+ print(f"[Mask] Filter: kept {result.sum()}/{result.size} points")
494
+ return result, None
495
+
496
+
497
+ # ============================================================
498
+ # Save Utilities
499
+ # ============================================================
500
+
501
+ def _timed_call(func, *args, **kwargs):
502
+ t0 = time.perf_counter()
503
+ result = func(*args, **kwargs)
504
+ return result, time.perf_counter() - t0
505
+
506
+
507
+ def _save_depth_parallel(depth_cpu, depth_dir, S):
508
+ def _save_one(i):
509
+ save_depth_png(depth_dir / f"depth_{i:04d}.png", depth_cpu[i, :, :, 0])
510
+ save_depth_npy(depth_dir / f"depth_{i:04d}.npy", depth_cpu[i, :, :, 0])
511
+ with ThreadPoolExecutor(max_workers=_IO_WORKERS) as pool:
512
+ list(pool.map(_save_one, range(S)))
513
+
514
+
515
+ def _save_conf_parallel(depth_conf_cpu, conf_dir, S):
516
+ def _save_one(i):
517
+ conf = depth_conf_cpu[i]
518
+ c_min, c_max = conf.min(), conf.max()
519
+ norm = (conf - c_min) / (c_max - c_min) if c_max - c_min > 1e-8 else torch.ones_like(conf)
520
+ Image.fromarray((norm.clamp(0, 1) * 255).to(torch.uint8).numpy(), mode="L").save(
521
+ str(conf_dir / f"conf_{i+1:04d}.png"))
522
+ with ThreadPoolExecutor(max_workers=_IO_WORKERS) as pool:
523
+ list(pool.map(_save_one, range(S)))
524
+
525
+
526
+ def _save_normal_parallel(normals_cpu, normal_dir, S):
527
+ def _save_one(i):
528
+ save_normal_png(normal_dir / f"normal_{i:04d}.png", normals_cpu[i])
529
+ with ThreadPoolExecutor(max_workers=_IO_WORKERS) as pool:
530
+ list(pool.map(_save_one, range(S)))
531
+
532
+
533
+ def _save_sky_mask_parallel(sky_mask, sky_mask_dir, S):
534
+ def _save_one(i):
535
+ Image.fromarray((~sky_mask[i]).astype(np.uint8) * 255, mode="L").save(
536
+ str(sky_mask_dir / f"sky_mask_{i:04d}.png"))
537
+ with ThreadPoolExecutor(max_workers=_IO_WORKERS) as pool:
538
+ list(pool.map(_save_one, range(S)))
539
+
540
+
541
+ def _voxel_prune_gaussians(means, scales, quats, colors, opacities, weights, voxel_size=0.002):
542
+ """Voxel-based merging of Gaussian splats via weighted average."""
543
+ N = means.shape[0]
544
+ if N == 0:
545
+ return means, scales, quats, colors, opacities
546
+
547
+ voxel_idx = (means / voxel_size).floor().long()
548
+ voxel_idx = voxel_idx - voxel_idx.min(dim=0)[0]
549
+ vmax = voxel_idx.max(dim=0)[0] + 1
550
+ flat = voxel_idx[:, 0] * vmax[1] * vmax[2] + voxel_idx[:, 1] * vmax[2] + voxel_idx[:, 2]
551
+
552
+ unique, inv = torch.unique(flat, return_inverse=True)
553
+ K = len(unique)
554
+ if K == N:
555
+ return means, scales, quats, colors, opacities
556
+
557
+ w = weights
558
+ wsum = torch.zeros(K, dtype=w.dtype).scatter_add_(0, inv, w).clamp(min=1e-8)
559
+
560
+ def _wavg(vals):
561
+ out = torch.zeros(K, *vals.shape[1:], dtype=vals.dtype)
562
+ for d in range(vals.shape[1]):
563
+ out[:, d].scatter_add_(0, inv, vals[:, d] * w)
564
+ return out / wsum.unsqueeze(-1)
565
+
566
+ m_opa = torch.zeros(K, dtype=opacities.dtype).scatter_add_(0, inv, w * w) / wsum
567
+ m_quats = torch.zeros(K, 4, dtype=quats.dtype)
568
+ for d in range(4):
569
+ m_quats[:, d].scatter_add_(0, inv, quats[:, d] * w)
570
+ m_quats = m_quats / m_quats.norm(dim=1, keepdim=True).clamp(min=1e-8)
571
+
572
+ print(f"[Save] Voxel prune: {N} -> {K} gaussians")
573
+ return _wavg(means), _wavg(scales), m_quats, _wavg(colors), m_opa
574
+
575
+
576
+ def _compress_points_voxel_then_sample(pts_np, cols_np, max_points=2_000_000, voxel_size=0.005):
577
+ """Compress point cloud: voxel merge then uniform random sampling."""
578
+ n_in = int(pts_np.shape[0])
579
+ if n_in == 0:
580
+ return pts_np, cols_np
581
+
582
+ if voxel_size > 0:
583
+ voxel = np.floor(pts_np / voxel_size).astype(np.int64)
584
+ voxel -= voxel.min(axis=0, keepdims=True)
585
+ _, inv = np.unique(voxel, axis=0, return_inverse=True)
586
+ k = int(inv.max()) + 1
587
+ if k < n_in:
588
+ counts = np.maximum(np.bincount(inv, minlength=k).astype(np.float32), 1.0)
589
+ pts_np = np.stack([np.bincount(inv, weights=pts_np[:, d], minlength=k)
590
+ for d in range(3)], axis=1).astype(np.float32) / counts[:, None]
591
+ cols_np = np.clip(np.round(
592
+ np.stack([np.bincount(inv, weights=cols_np[:, d].astype(np.float32), minlength=k)
593
+ for d in range(3)], axis=1) / counts[:, None]
594
+ ), 0, 255).astype(np.uint8)
595
+
596
+ if max_points > 0 and pts_np.shape[0] > max_points:
597
+ idx = np.random.default_rng(42).choice(pts_np.shape[0], size=max_points, replace=False)
598
+ pts_np, cols_np = pts_np[idx], cols_np[idx]
599
+ return pts_np, cols_np
600
+
601
+
602
+ def _compute_points_from_depth(depth_pred, imgs, extrinsics, intrinsics, S, H, W, filter_mask=None):
603
+ """Derive 3D point cloud from depth + camera outputs."""
604
+ depth_pred, extrinsics, intrinsics = depth_pred.float(), extrinsics.float(), intrinsics.float()
605
+ points_list, colors_list = [], []
606
+ for i in range(S):
607
+ d = depth_pred[0, i, :, :, 0]
608
+ w2c = torch.cat([extrinsics[i][:3, :4],
609
+ torch.tensor([[0, 0, 0, 1]], device=extrinsics.device)], dim=0)
610
+ c2w = torch.linalg.inv(w2c)[:3, :4]
611
+ pts_i, _, mask = depth_to_world_coords_points(d[None], c2w[None], intrinsics[i][None])
612
+ img_colors = (imgs[0, i].permute(1, 2, 0) * 255).to(torch.uint8)
613
+ valid = mask[0]
614
+ if filter_mask is not None:
615
+ valid = valid & torch.from_numpy(filter_mask[i]).to(valid.device)
616
+ if valid.sum().item() > 0:
617
+ points_list.append(pts_i[0][valid])
618
+ colors_list.append(img_colors[valid])
619
+
620
+ if not points_list:
621
+ return np.empty((0, 3), dtype=np.float32), np.empty((0, 3), dtype=np.uint8)
622
+ return (torch.cat(points_list).detach().cpu().float().numpy(),
623
+ torch.cat(colors_list).detach().cpu().to(torch.uint8).numpy())
624
+
625
+
626
+ def _save_colmap_lightweight(extrinsics, intrinsics, outdir, final_w, final_h, S, image_names):
627
+ """Save lightweight COLMAP reconstruction (cameras + images only)."""
628
+ import pycolmap
629
+ sparse_dir = outdir / "sparse" / "0"
630
+ sparse_dir.mkdir(parents=True, exist_ok=True)
631
+ scene = pycolmap.Reconstruction()
632
+ for i in range(S):
633
+ focal_avg = (intrinsics[i][0, 0] + intrinsics[i][1, 1]) / 2
634
+ camera = pycolmap.Camera(
635
+ model="SIMPLE_PINHOLE", width=final_w, height=final_h,
636
+ params=np.array([focal_avg, intrinsics[i][0, 2], intrinsics[i][1, 2]]),
637
+ camera_id=i + 1,
638
+ )
639
+ scene.add_camera(camera)
640
+ cam_from_world = pycolmap.Rigid3d(
641
+ pycolmap.Rotation3d(extrinsics[i][:3, :3]), extrinsics[i][:3, 3])
642
+ img = pycolmap.Image(id=i + 1, name=image_names[i], camera_id=i + 1,
643
+ cam_from_world=cam_from_world)
644
+ img.registered = True
645
+ scene.add_image(img)
646
+ scene.write(str(sparse_dir))
647
+ print(f"[Save] COLMAP sparse -> {sparse_dir}")
648
+
649
+
650
+ def save_results(predictions, imgs, img_paths, outdir,
651
+ save_depth=True, save_normal=True, save_gs=True,
652
+ save_camera=True, save_colmap=False, save_points=True,
653
+ save_sky_mask=False, save_conf=False, log_time=False,
654
+ max_resolution=1920,
655
+ filter_mask=None, gs_filter_mask=None, sky_mask=None,
656
+ compress_pts=True, compress_pts_max_points=2_000_000,
657
+ compress_pts_voxel_size=0.002,
658
+ compress_gs_max_points=5_000_000):
659
+ """Save all results with parallel I/O. Returns timing dict."""
660
+ timings = {}
661
+ outdir = Path(outdir)
662
+ outdir.mkdir(parents=True, exist_ok=True)
663
+ B, S, C, H, W = imgs.shape
664
+
665
+ ar = W / H
666
+ max_w = max(Image.open(p).size[0] for p in img_paths)
667
+ new_w, new_h = max_w, int(round(max_w / ar))
668
+ longest = max(new_w, new_h)
669
+ if longest > max_resolution:
670
+ sf = max_resolution / longest
671
+ new_w, new_h = int(new_w * sf), int(new_h * sf)
672
+ new_w -= new_w % 2
673
+ new_h -= new_h % 2
674
+ image_names = [f"image_{i+1:04d}.jpg" for i in range(S)]
675
+
676
+ depth_cpu = predictions["depth"][0].detach().cpu() if "depth" in predictions else None
677
+ conf_cpu = predictions.get("depth_conf", [None])[0]
678
+ if conf_cpu is not None:
679
+ conf_cpu = conf_cpu.detach().cpu()
680
+ normals_cpu = predictions["normals"][0].detach().cpu() if "normals" in predictions else None
681
+
682
+ futures = {}
683
+ executor = ThreadPoolExecutor(max_workers=_IO_WORKERS)
684
+
685
+ if save_depth and depth_cpu is not None:
686
+ d_dir = outdir / "depth"
687
+ d_dir.mkdir(exist_ok=True)
688
+ futures["save_depth"] = executor.submit(_timed_call, _save_depth_parallel, depth_cpu, d_dir, S)
689
+
690
+ if save_conf and conf_cpu is not None:
691
+ c_dir = outdir / "depth_conf"
692
+ c_dir.mkdir(exist_ok=True)
693
+ futures["save_conf"] = executor.submit(_timed_call, _save_conf_parallel, conf_cpu, c_dir, S)
694
+
695
+ if save_normal and normals_cpu is not None:
696
+ n_dir = outdir / "normal"
697
+ n_dir.mkdir(exist_ok=True)
698
+ futures["save_normal"] = executor.submit(_timed_call, _save_normal_parallel, normals_cpu, n_dir, S)
699
+
700
+ if save_sky_mask and sky_mask is not None:
701
+ sm_dir = outdir / "sky_mask"
702
+ sm_dir.mkdir(exist_ok=True)
703
+ futures["save_sky_mask"] = executor.submit(_timed_call, _save_sky_mask_parallel, sky_mask, sm_dir, S)
704
+
705
+ if save_gs and "splats" in predictions:
706
+ sp = predictions["splats"]
707
+ means = sp["means"][0].reshape(-1, 3).detach().cpu()
708
+ scales = sp["scales"][0].reshape(-1, 3).detach().cpu()
709
+ quats = sp["quats"][0].reshape(-1, 4).detach().cpu()
710
+ colors = (sp["sh"][0] if "sh" in sp else sp["colors"][0]).reshape(-1, 3).detach().cpu()
711
+ opacities = sp["opacities"][0].reshape(-1).detach().cpu()
712
+ weights = sp["weights"][0].reshape(-1).detach().cpu() if "weights" in sp else torch.ones_like(opacities)
713
+
714
+ keep = None
715
+ if gs_filter_mask is not None:
716
+ keep = torch.from_numpy(gs_filter_mask.reshape(-1)).bool()
717
+ elif filter_mask is not None:
718
+ keep = torch.from_numpy(filter_mask.reshape(-1)).bool()
719
+ if keep is not None:
720
+ means, scales, quats = means[keep], scales[keep], quats[keep]
721
+ colors, opacities, weights = colors[keep], opacities[keep], weights[keep]
722
+
723
+ means, scales, quats, colors, opacities = _voxel_prune_gaussians(
724
+ means, scales, quats, colors, opacities, weights)
725
+ if compress_gs_max_points > 0 and means.shape[0] > compress_gs_max_points:
726
+ idx = torch.from_numpy(
727
+ np.random.default_rng(42).choice(means.shape[0], size=compress_gs_max_points, replace=False)
728
+ ).long()
729
+ means, scales, quats, colors, opacities = means[idx], scales[idx], quats[idx], colors[idx], opacities[idx]
730
+
731
+ futures["save_gs_ply"] = executor.submit(
732
+ _timed_call, save_gs_ply, outdir / "gaussians.ply", means, scales, quats, colors, opacities)
733
+
734
+ if save_camera and "camera_poses" in predictions and "camera_intrs" in predictions:
735
+ cam_p = predictions["camera_poses"][0].detach().cpu().float().numpy()
736
+ cam_i = predictions["camera_intrs"][0].detach().cpu().float().numpy()
737
+ futures["save_camera"] = executor.submit(_timed_call, save_camera_params, cam_p, cam_i, str(outdir))
738
+
739
+ if save_points and "depth" in predictions and "camera_params" in predictions:
740
+ e3x4, intr = vector_to_camera_matrices(predictions["camera_params"], image_hw=(H, W))
741
+ pts_np, cols_np = _compute_points_from_depth(
742
+ predictions["depth"], imgs, e3x4[0], intr[0], S, H, W, filter_mask=filter_mask)
743
+ futures["save_points"] = executor.submit(
744
+ _timed_call, _save_points_artifacts, outdir / "points.ply", pts_np, cols_np,
745
+ compress_pts, compress_pts_max_points, compress_pts_voxel_size)
746
+
747
+ if save_colmap and "camera_params" in predictions:
748
+ e3x4, intr = vector_to_camera_matrices(predictions["camera_params"], image_hw=(new_h, new_w))
749
+ futures["save_colmap"] = executor.submit(
750
+ _timed_call, _save_colmap_lightweight,
751
+ e3x4[0].detach().cpu().float().numpy(), intr[0].detach().cpu().float().numpy(),
752
+ outdir, new_w, new_h, S, image_names)
753
+
754
+ for key, future in futures.items():
755
+ result, elapsed = future.result()
756
+ if log_time:
757
+ timings[key] = elapsed
758
+ if isinstance(result, dict):
759
+ timings.update(result)
760
+
761
+ executor.shutdown(wait=False)
762
+ return timings
763
+
764
+
765
+ def _save_points_artifacts(path, pts_np, cols_np,
766
+ compress=False, max_points=2_000_000,
767
+ voxel_size=0.005):
768
+ timings = {}
769
+ if compress:
770
+ t0 = time.perf_counter()
771
+ pts_np, cols_np = _compress_points_voxel_then_sample(pts_np, cols_np, max_points, voxel_size)
772
+ timings["compress_points"] = time.perf_counter() - t0
773
+ save_points_ply(path, pts_np, cols_np)
774
+ return timings
775
+
776
+
777
+ # ============================================================
778
+ # Timing Report
779
+ # ============================================================
780
+
781
+ def print_and_save_timings(timings, outdir):
782
+ """Print formatted timing table and save to JSON."""
783
+ def _p(label, value, indent=0):
784
+ print(f"{' ' * (indent + 1)}{label:<38s} {value:>10.3f}s")
785
+
786
+ print(f"\n{'='*72}\n TIMING REPORT\n{'='*72}")
787
+
788
+ print(" [Serial Stages]")
789
+ for key, label in [("data_loading", "Data loading"),
790
+ ("inference_preprocess", "Inference preprocess"),
791
+ ("inference", "Model inference"),
792
+ ("compute_mask", "Compute filter mask")]:
793
+ if key in timings:
794
+ _p(label, timings[key], 1)
795
+
796
+ save_wall = timings.get("save_total_wall")
797
+ save_keys = [("save_depth", "Depth"), ("save_conf", "Depth conf"),
798
+ ("save_normal", "Normal"), ("save_sky_mask", "Sky mask"),
799
+ ("save_gs_ply", "Gaussians"), ("save_camera", "Camera"),
800
+ ("save_points", "Points"), ("save_colmap", "COLMAP")]
801
+ present = [(k, n) for k, n in save_keys if k in timings]
802
+ if save_wall is not None or present:
803
+ print(" [Save Stage | Parallel]")
804
+ if save_wall is not None:
805
+ _p("Save wall-clock", save_wall, 1)
806
+ for k, name in present:
807
+ _p(f"- {name}", timings[k], 2)
808
+
809
+ if "case_total" in timings:
810
+ print(" [Total]")
811
+ _p("Case total", timings["case_total"], 1)
812
+
813
+ if "gpu_mem_peak_per_rank_gb" in timings:
814
+ print(" [GPU Memory]")
815
+ for i, p in enumerate(timings["gpu_mem_peak_per_rank_gb"]):
816
+ print(f" Rank {i}: {p:.2f} GB")
817
+ print(f" Average: {timings['gpu_mem_peak_avg_gb']:.2f} GB")
818
+
819
+ print(f"{'='*72}\n")
820
+
821
+ outdir = Path(outdir)
822
+ outdir.mkdir(parents=True, exist_ok=True)
823
+ with open(outdir / "pipeline_timing.json", "w") as f:
824
+ json.dump(timings, f, indent=2)
hyworldmirror/utils/render_utils.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Render interpolated video from Gaussian Splatting predictions.
3
+
4
+ Interpolates smooth camera trajectories using SLERP quaternions,
5
+ renders each frame via gsplat, and saves MP4 videos.
6
+ """
7
+
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+ import torch
12
+ from tqdm import tqdm
13
+
14
+ from ..models.models.rasterization import GaussianSplatRenderer
15
+
16
+
17
+ def rotation_matrix_to_quaternion(R):
18
+ """Convert rotation matrix to quaternion (scalar-first: [w, x, y, z]).
19
+
20
+ Note: This uses the Hamilton convention [w, x, y, z], which differs from
21
+ models/utils/rotation.py that uses PyTorch3D convention [x, y, z, w].
22
+ """
23
+ trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
24
+
25
+ q = torch.zeros(R.shape[:-2] + (4,), device=R.device, dtype=R.dtype)
26
+
27
+ mask1 = trace > 0
28
+ s = torch.sqrt(trace[mask1] + 1.0) * 2
29
+ q[mask1, 0] = 0.25 * s
30
+ q[mask1, 1] = (R[mask1, 2, 1] - R[mask1, 1, 2]) / s
31
+ q[mask1, 2] = (R[mask1, 0, 2] - R[mask1, 2, 0]) / s
32
+ q[mask1, 3] = (R[mask1, 1, 0] - R[mask1, 0, 1]) / s
33
+
34
+ mask2 = (~mask1) & (R[..., 0, 0] > R[..., 1, 1]) & (R[..., 0, 0] > R[..., 2, 2])
35
+ s = torch.sqrt(1.0 + R[mask2, 0, 0] - R[mask2, 1, 1] - R[mask2, 2, 2]) * 2
36
+ q[mask2, 0] = (R[mask2, 2, 1] - R[mask2, 1, 2]) / s
37
+ q[mask2, 1] = 0.25 * s
38
+ q[mask2, 2] = (R[mask2, 0, 1] + R[mask2, 1, 0]) / s
39
+ q[mask2, 3] = (R[mask2, 0, 2] + R[mask2, 2, 0]) / s
40
+
41
+ mask3 = (~mask1) & (~mask2) & (R[..., 1, 1] > R[..., 2, 2])
42
+ s = torch.sqrt(1.0 + R[mask3, 1, 1] - R[mask3, 0, 0] - R[mask3, 2, 2]) * 2
43
+ q[mask3, 0] = (R[mask3, 0, 2] - R[mask3, 2, 0]) / s
44
+ q[mask3, 1] = (R[mask3, 0, 1] + R[mask3, 1, 0]) / s
45
+ q[mask3, 2] = 0.25 * s
46
+ q[mask3, 3] = (R[mask3, 1, 2] + R[mask3, 2, 1]) / s
47
+
48
+ mask4 = (~mask1) & (~mask2) & (~mask3)
49
+ s = torch.sqrt(1.0 + R[mask4, 2, 2] - R[mask4, 0, 0] - R[mask4, 1, 1]) * 2
50
+ q[mask4, 0] = (R[mask4, 1, 0] - R[mask4, 0, 1]) / s
51
+ q[mask4, 1] = (R[mask4, 0, 2] + R[mask4, 2, 0]) / s
52
+ q[mask4, 2] = (R[mask4, 1, 2] + R[mask4, 2, 1]) / s
53
+ q[mask4, 3] = 0.25 * s
54
+
55
+ return q
56
+
57
+
58
+ def quaternion_to_rotation_matrix(q):
59
+ """Convert quaternion (scalar-first: [w, x, y, z]) to rotation matrix."""
60
+ w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
61
+
62
+ norm = torch.sqrt(w*w + x*x + y*y + z*z)
63
+ w, x, y, z = w/norm, x/norm, y/norm, z/norm
64
+
65
+ R = torch.zeros(q.shape[:-1] + (3, 3), device=q.device, dtype=q.dtype)
66
+
67
+ R[..., 0, 0] = 1 - 2*(y*y + z*z)
68
+ R[..., 0, 1] = 2*(x*y - w*z)
69
+ R[..., 0, 2] = 2*(x*z + w*y)
70
+ R[..., 1, 0] = 2*(x*y + w*z)
71
+ R[..., 1, 1] = 1 - 2*(x*x + z*z)
72
+ R[..., 1, 2] = 2*(y*z - w*x)
73
+ R[..., 2, 0] = 2*(x*z - w*y)
74
+ R[..., 2, 1] = 2*(y*z + w*x)
75
+ R[..., 2, 2] = 1 - 2*(x*x + y*y)
76
+
77
+ return R
78
+
79
+
80
+ def slerp_quaternions(q1, q2, t):
81
+ """Spherical linear interpolation between quaternions."""
82
+ dot = (q1 * q2).sum(dim=-1, keepdim=True)
83
+
84
+ mask = dot < 0
85
+ q2 = torch.where(mask, -q2, q2)
86
+ dot = torch.where(mask, -dot, dot)
87
+
88
+ DOT_THRESHOLD = 0.9995
89
+ mask_linear = dot > DOT_THRESHOLD
90
+
91
+ result = torch.zeros_like(q1)
92
+
93
+ if mask_linear.any():
94
+ result_linear = q1 + t * (q2 - q1)
95
+ norm = torch.norm(result_linear, dim=-1, keepdim=True)
96
+ result_linear = result_linear / norm
97
+ result = torch.where(mask_linear, result_linear, result)
98
+
99
+ mask_slerp = ~mask_linear
100
+ if mask_slerp.any():
101
+ theta_0 = torch.acos(torch.abs(dot))
102
+ sin_theta_0 = torch.sin(theta_0)
103
+
104
+ theta = theta_0 * t
105
+ sin_theta = torch.sin(theta)
106
+
107
+ s0 = torch.cos(theta) - dot * sin_theta / sin_theta_0
108
+ s1 = sin_theta / sin_theta_0
109
+
110
+ result_slerp = (s0 * q1) + (s1 * q2)
111
+ result = torch.where(mask_slerp, result_slerp, result)
112
+
113
+ return result
114
+
115
+
116
+ def render_interpolated_video(gs_renderer: GaussianSplatRenderer,
117
+ splats: dict,
118
+ camtoworlds: torch.Tensor,
119
+ intrinsics: torch.Tensor,
120
+ hw: tuple,
121
+ out_path: Path,
122
+ interp_per_pair: int = 20,
123
+ loop_reverse: bool = True,
124
+ save_mode: str = "split",
125
+ frame_times: list = None,
126
+ render_depth: bool = False) -> None:
127
+ """Render an interpolated fly-through video from Gaussian splat predictions.
128
+
129
+ Args:
130
+ gs_renderer: GaussianSplatRenderer instance (from the model).
131
+ splats: Dict with keys 'means', 'scales', 'quats', 'opacities', 'sh'/'colors'.
132
+ camtoworlds: Camera-to-world matrices [B, S, 4, 4].
133
+ intrinsics: Camera intrinsic matrices [B, S, 3, 3].
134
+ hw: Tuple of (height, width) for rendering.
135
+ out_path: Output path (without extension).
136
+ interp_per_pair: Number of interpolated frames per camera pair.
137
+ loop_reverse: Append reversed video for smooth looping.
138
+ save_mode: 'split' (separate rgb/depth) or 'both' (combined).
139
+ frame_times: Optional list of timestamps for adaptive interpolation.
140
+ render_depth: Whether to also render depth video.
141
+ """
142
+ import moviepy.editor as mpy
143
+
144
+ b, s, _, _ = camtoworlds.shape
145
+ h, w = hw
146
+
147
+ def build_interpolated_traj(index, base_interp_per_pair: int):
148
+ exts, ints = [], []
149
+ tmp_camtoworlds = camtoworlds[:, index]
150
+ tmp_intrinsics = intrinsics[:, index]
151
+
152
+ use_time_based = frame_times is not None and len(frame_times) == len(index)
153
+ if use_time_based and len(index) > 1:
154
+ times = np.array([frame_times[i] for i in index], dtype=np.float32)
155
+ gaps = np.diff(times)
156
+ gaps[gaps < 0] = 0.0
157
+ total_gap = float(gaps.sum())
158
+ target_total_interp = max(1, (len(index) - 1) * base_interp_per_pair)
159
+ gap_scale = target_total_interp / total_gap if total_gap > 1e-6 else 0.0
160
+ else:
161
+ gaps = None
162
+ gap_scale = None
163
+
164
+ for i in range(len(index)-1):
165
+ exts.append(tmp_camtoworlds[:, i:i+1])
166
+ ints.append(tmp_intrinsics[:, i:i+1])
167
+ R0, t0 = tmp_camtoworlds[:, i, :3, :3], tmp_camtoworlds[:, i, :3, 3]
168
+ R1, t1 = tmp_camtoworlds[:, i + 1, :3, :3], tmp_camtoworlds[:, i + 1, :3, 3]
169
+
170
+ q0 = rotation_matrix_to_quaternion(R0)
171
+ q1 = rotation_matrix_to_quaternion(R1)
172
+
173
+ if use_time_based:
174
+ gap = float(gaps[i]) if gaps is not None else 0.0
175
+ num_interp = max(0, int(round(gap * gap_scale)))
176
+ else:
177
+ num_interp = base_interp_per_pair
178
+
179
+ for j in range(1, num_interp + 1):
180
+ alpha = j / (num_interp + 1)
181
+ t_interp = (1 - alpha) * t0 + alpha * t1
182
+ q_interp = slerp_quaternions(q0, q1, alpha)
183
+ R_interp = quaternion_to_rotation_matrix(q_interp)
184
+
185
+ ext = torch.eye(4, device=R_interp.device, dtype=R_interp.dtype)[None].repeat(b, 1, 1)
186
+ ext[:, :3, :3] = R_interp
187
+ ext[:, :3, 3] = t_interp
188
+
189
+ K0 = tmp_intrinsics[:, i]
190
+ K1 = tmp_intrinsics[:, i + 1]
191
+ K = (1 - alpha) * K0 + alpha * K1
192
+
193
+ exts.append(ext[:, None])
194
+ ints.append(K[:, None])
195
+
196
+ exts = torch.cat(exts, dim=1)[:1]
197
+ ints = torch.cat(ints, dim=1)[:1]
198
+ return exts, ints
199
+
200
+ def build_wobble_traj(nums, delta):
201
+ if s != 1:
202
+ raise ValueError("Wobble trajectory requires exactly 1 input view")
203
+ t = torch.linspace(0, 1, nums, dtype=torch.float32, device=camtoworlds.device)
204
+ t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
205
+ tf = torch.eye(4, dtype=torch.float32, device=camtoworlds.device)
206
+ radius = delta * 0.15
207
+ tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone()
208
+ radius = radius[..., None]
209
+ radius = radius * t
210
+ tf[..., 0, 3] = torch.sin(2 * torch.pi * t) * radius
211
+ tf[..., 1, 3] = -torch.cos(2 * torch.pi * t) * radius
212
+ exts = camtoworlds @ tf
213
+ ints = intrinsics.repeat(1, exts.shape[1], 1, 1)
214
+ return exts, ints
215
+
216
+ if s > 1:
217
+ all_ext, all_int = build_interpolated_traj([i for i in range(s)], interp_per_pair)
218
+ else:
219
+ all_ext, all_int = build_wobble_traj(interp_per_pair * 12, splats["means"][0].median(dim=0).values.norm(dim=-1)[None])
220
+
221
+ rendered_rgbs, rendered_depths = [], []
222
+ chunk = 40
223
+
224
+ # Always prune splats to remove scale outliers
225
+ try:
226
+ pruned_splats = gs_renderer.prune_gs(splats, gs_renderer.voxel_size)
227
+ except (AttributeError, RuntimeError):
228
+ pruned_splats = splats
229
+
230
+ for st in tqdm(range(0, all_ext.shape[1], chunk)):
231
+ ed = min(st + chunk, all_ext.shape[1])
232
+ colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
233
+ pruned_splats["means"][:1], pruned_splats["quats"][:1], pruned_splats["scales"][:1],
234
+ pruned_splats["opacities"][:1],
235
+ pruned_splats["sh"][:1] if "sh" in pruned_splats else pruned_splats["colors"][:1],
236
+ all_ext[:, st:ed].to(torch.float32), all_int[:, st:ed].to(torch.float32),
237
+ width=w, height=h, sh_degree=gs_renderer.sh_degree if "sh" in pruned_splats else None,
238
+ )
239
+ rendered_rgbs.append(colors)
240
+ if render_depth:
241
+ rendered_depths.append(depths)
242
+
243
+ rgbs = torch.cat(rendered_rgbs, dim=1)[0] # [N, H, W, 3]
244
+ if render_depth:
245
+ depths_all = torch.cat(rendered_depths, dim=1)[0, ..., 0] # [N, H, W]
246
+ del rendered_rgbs, rendered_depths
247
+
248
+ def _depth_vis(d: torch.Tensor) -> torch.Tensor:
249
+ """Simple turbo colormap depth visualization."""
250
+ import matplotlib.pyplot as plt
251
+ valid = d > 0
252
+ if valid.any():
253
+ near = d[valid].float().quantile(0.01).log()
254
+ else:
255
+ near = torch.tensor(0.0, device=d.device)
256
+ far = d.flatten().float().quantile(0.99).log()
257
+ x = d.float().clamp(min=1e-9).log()
258
+ x = 1.0 - (x - near) / (far - near + 1e-9)
259
+ x_np = x.cpu().numpy()
260
+ colored = torch.from_numpy(plt.cm.turbo(x_np)[..., :3]).permute(2, 0, 1).float()
261
+ return colored
262
+
263
+ rgb_frames = []
264
+ depth_frames = []
265
+
266
+ if render_depth:
267
+ for rgb, dep in zip(rgbs, depths_all):
268
+ rgb_frames.append(rgb.permute(2, 0, 1))
269
+ depth_frames.append(_depth_vis(dep))
270
+ else:
271
+ for rgb in rgbs:
272
+ rgb_frames.append(rgb.permute(2, 0, 1))
273
+
274
+ def _make_video(frames, path):
275
+ video = torch.stack([f.cpu() for f in frames]).clamp(0, 1)
276
+ video = video.permute(0, 2, 3, 1)
277
+ video = (video * 255).to(torch.uint8).numpy()
278
+ if loop_reverse and video.shape[0] > 1:
279
+ video = np.concatenate([video, video[::-1][1:-1]], axis=0)
280
+ clip = mpy.ImageSequenceClip(list(video), fps=30)
281
+ clip.write_videofile(str(path), logger=None)
282
+
283
+ out_path = Path(out_path)
284
+ out_path.mkdir(parents=True, exist_ok=True)
285
+ if save_mode == 'split':
286
+ _make_video(rgb_frames, out_path / "rendered_rgb.mp4")
287
+ if render_depth:
288
+ _make_video(depth_frames, out_path / "rendered_depth.mp4")
289
+ elif save_mode == 'both' and render_depth:
290
+ combined = [torch.cat([r, d], dim=1) for r, d in zip(rgb_frames, depth_frames)]
291
+ _make_video(combined, out_path / "rendered.mp4")
292
+
293
+ print(f"Video saved to {out_path} (mode: {save_mode})")
294
+ torch.cuda.empty_cache()
hyworldmirror/utils/save_utils.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for saving images, depths, normals, point clouds, and Gaussian splat data.
3
+ tencent
4
+ """
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from plyfile import PlyData, PlyElement
11
+ from io import BytesIO
12
+ import json
13
+ import os
14
+
15
+ def save_camera_params(extrinsics, intrinsics, target_dir):
16
+ """
17
+ Save camera parameters (extrinsics and intrinsics) in JSON format
18
+
19
+ Args:
20
+ extrinsics: numpy array, shape [N, 4, 4] - extrinsic matrices for N cameras
21
+ intrinsics: numpy array, shape [N, 3, 3] - intrinsic matrices for N cameras
22
+ target_dir: str - directory to save the parameters
23
+
24
+ Returns:
25
+ str: path to the saved file
26
+ """
27
+ camera_data = {
28
+ "num_cameras": int(extrinsics.shape[0]),
29
+ "extrinsics": [],
30
+ "intrinsics": []
31
+ }
32
+
33
+ # Convert each camera's parameters to list format
34
+ for i in range(extrinsics.shape[0]):
35
+ camera_data["extrinsics"].append({
36
+ "camera_id": i,
37
+ "matrix": extrinsics[i].tolist() # [4, 4] -> list
38
+ })
39
+ camera_data["intrinsics"].append({
40
+ "camera_id": i,
41
+ "matrix": intrinsics[i].tolist() # [3, 3] -> list
42
+ })
43
+
44
+ # Save as JSON file
45
+ camera_params_path = os.path.join(target_dir, "camera_params.json")
46
+ with open(camera_params_path, 'w') as f:
47
+ json.dump(camera_data, f, indent=2)
48
+
49
+ return camera_params_path
50
+
51
+ def save_image_png(path: Path, image_tensor: torch.Tensor) -> None:
52
+ # image_tensor: [H, W, 3]
53
+ img = (image_tensor.detach().cpu() * 255.0).to(torch.uint8).numpy()
54
+ Image.fromarray(img).save(str(path))
55
+
56
+
57
+ def save_depth_png(path: Path, depth_tensor: torch.Tensor) -> None:
58
+ # depth_tensor: [H, W]
59
+ d = depth_tensor.detach()
60
+ d = d - d.min()
61
+ d = d / (d.max() + 1e-9)
62
+ img = (d.clamp(0, 1) * 255.0).to(torch.uint8).cpu().numpy()
63
+ Image.fromarray(img, mode="L").save(str(path))
64
+
65
+
66
+ def save_depth_npy(path: Path, depth_tensor: torch.Tensor) -> None:
67
+ # depth_tensor: [H, W]
68
+ # Save actual depth values in numpy format
69
+ d = depth_tensor.detach().cpu().numpy()
70
+ np.save(str(path), d)
71
+
72
+
73
+ def save_normal_png(path: Path, normal_hwc: torch.Tensor) -> None:
74
+ # normal_hwc: [H, W, 3], in [-1, 1]
75
+ n = (normal_hwc.detach().cpu() + 1.0) * 0.5
76
+ img = (n.clamp(0, 1) * 255.0).to(torch.uint8).numpy()
77
+ Image.fromarray(img).save(str(path))
78
+
79
+
80
+ def _build_vertex_ply_element(pts: np.ndarray, colors: np.ndarray) -> PlyElement:
81
+ """Build a PLY vertex element from points and colors arrays.
82
+
83
+ Args:
84
+ pts: Point coordinates, shape [N, 3], dtype float32
85
+ colors: RGB colors, shape [N, 3], dtype uint8
86
+
87
+ Returns:
88
+ PlyElement describing the vertices
89
+ """
90
+ vertex_dtype = [("x", "f4"), ("y", "f4"), ("z", "f4"),
91
+ ("red", "u1"), ("green", "u1"), ("blue", "u1")]
92
+ vertex_elements = np.empty(len(pts), dtype=vertex_dtype)
93
+ vertex_elements["x"] = pts[:, 0]
94
+ vertex_elements["y"] = pts[:, 1]
95
+ vertex_elements["z"] = pts[:, 2]
96
+ vertex_elements["red"] = colors[:, 0]
97
+ vertex_elements["green"] = colors[:, 1]
98
+ vertex_elements["blue"] = colors[:, 2]
99
+ return PlyElement.describe(vertex_elements, "vertex")
100
+
101
+
102
+ def save_scene_ply(path: Path,
103
+ points_xyz: torch.Tensor,
104
+ point_colors: torch.Tensor,
105
+ valid_mask: torch.Tensor = None) -> None:
106
+ """Save point cloud to PLY format"""
107
+ pts = points_xyz.detach().cpu().to(torch.float32).numpy().reshape(-1, 3)
108
+ colors = point_colors.detach().cpu().to(torch.uint8).numpy().reshape(-1, 3)
109
+
110
+ # Filter out invalid points (NaN, Inf)
111
+ if valid_mask is None:
112
+ valid_mask = np.isfinite(pts).all(axis=1)
113
+ else:
114
+ valid_mask = valid_mask.detach().cpu().numpy().reshape(-1)
115
+ pts = pts[valid_mask]
116
+ colors = colors[valid_mask]
117
+
118
+ # Handle empty point cloud
119
+ if len(pts) == 0:
120
+ pts = np.array([[0, 0, 0]], dtype=np.float32)
121
+ colors = np.array([[255, 255, 255]], dtype=np.uint8)
122
+
123
+ PlyData([_build_vertex_ply_element(pts, colors)]).write(str(path))
124
+
125
+
126
+ def save_points_ply(path: Path, pts_np: np.ndarray, cols_np: np.ndarray) -> None:
127
+ """Save point cloud to PLY format from numpy arrays"""
128
+ PlyData([_build_vertex_ply_element(pts_np, cols_np)]).write(str(path))
129
+
130
+
131
+ def _build_gs_ply_data(means, scales, rotations, rgbs, opacities, quantile_threshold):
132
+ """Build Gaussian splat PLY data with scale-based filtering.
133
+
134
+ Args:
135
+ means: Gaussian centers [N, 3]
136
+ scales: Gaussian scales [N, 3]
137
+ rotations: Gaussian rotations as quaternions [N, 4]
138
+ rgbs: RGB colors [N, 3]
139
+ opacities: Opacity values [N]
140
+ quantile_threshold: Percentile threshold for scale filtering (e.g. 0.98 or 0.90)
141
+
142
+ Returns:
143
+ PlyData object ready to be written or returned
144
+ """
145
+ scale_threshold = torch.quantile(scales.max(dim=-1)[0], quantile_threshold, dim=0)
146
+ filter_mask = scales.max(dim=-1)[0] <= scale_threshold
147
+
148
+ means = means[filter_mask].reshape(-1, 3)
149
+ scales = scales[filter_mask].reshape(-1, 3)
150
+ rotations = rotations[filter_mask].reshape(-1, 4)
151
+ rgbs = rgbs[filter_mask].reshape(-1, 3)
152
+ opacities = opacities[filter_mask].reshape(-1)
153
+
154
+ attributes = ["x", "y", "z", "nx", "ny", "nz"]
155
+ for i in range(3):
156
+ attributes.append(f"f_dc_{i}")
157
+ attributes.append("opacity")
158
+ for i in range(3):
159
+ attributes.append(f"scale_{i}")
160
+ for i in range(4):
161
+ attributes.append(f"rot_{i}")
162
+
163
+ dtype_full = [(attribute, "f4") for attribute in attributes]
164
+ elements = np.empty(means.shape[0], dtype=dtype_full)
165
+
166
+ attributes_data = (
167
+ means.float().detach().cpu().numpy(),
168
+ torch.zeros_like(means).float().detach().cpu().numpy(),
169
+ rgbs.detach().cpu().contiguous().numpy(),
170
+ opacities[..., None].detach().cpu().numpy(),
171
+ scales.log().detach().cpu().numpy(),
172
+ rotations.detach().cpu().numpy(),
173
+ )
174
+ attributes_data = np.concatenate(attributes_data, axis=1)
175
+ elements[:] = list(map(tuple, attributes_data))
176
+
177
+ return PlyData([PlyElement.describe(elements, "vertex")])
178
+
179
+
180
+ def save_gs_ply(path: Path,
181
+ means: torch.Tensor,
182
+ scales: torch.Tensor,
183
+ rotations: torch.Tensor,
184
+ rgbs: torch.Tensor,
185
+ opacities: torch.Tensor) -> None:
186
+ """
187
+ Export Gaussian splat data to PLY format.
188
+
189
+ Args:
190
+ path: Output PLY file path
191
+ means: Gaussian centers [N, 3]
192
+ scales: Gaussian scales [N, 3]
193
+ rotations: Gaussian rotations as quaternions [N, 4]
194
+ rgbs: RGB colors [N, 3]
195
+ opacities: Opacity values [N]
196
+ """
197
+ # Ensure float32 for quantile and numpy conversion (bf16 not supported)
198
+ means, scales, rotations, rgbs, opacities = (
199
+ t.float() for t in (means, scales, rotations, rgbs, opacities)
200
+ )
201
+ plydata = _build_gs_ply_data(means, scales, rotations, rgbs, opacities, quantile_threshold=0.98)
202
+ plydata.write(str(path))
203
+
204
+
205
+ def convert_gs_to_ply(means, scales, rotations, rgbs, opacities):
206
+ """
207
+ Export Gaussian splat data to PLY format.
208
+
209
+ Args:
210
+ means: Gaussian centers [N, 3]
211
+ scales: Gaussian scales [N, 3]
212
+ rotations: Gaussian rotations as quaternions [N, 4]
213
+ rgbs: RGB colors [N, 3]
214
+ opacities: Opacity values [N]
215
+ """
216
+ return _build_gs_ply_data(means, scales, rotations, rgbs, opacities, quantile_threshold=0.90)
217
+
218
+
219
+ def process_ply_to_splat(plydata, output_path):
220
+ vert = plydata["vertex"]
221
+ sorted_indices = np.argsort(
222
+ -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
223
+ / (1 + np.exp(-vert["opacity"]))
224
+ )
225
+ buffer = BytesIO()
226
+ for idx in sorted_indices:
227
+ v = plydata["vertex"][idx]
228
+ position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
229
+ scales = np.exp(
230
+ np.array(
231
+ [v["scale_0"], v["scale_1"], v["scale_2"]],
232
+ dtype=np.float32,
233
+ )
234
+ )
235
+ rot = np.array(
236
+ [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
237
+ dtype=np.float32,
238
+ )
239
+ SH_C0 = 0.28209479177387814
240
+ color = np.array(
241
+ [
242
+ 0.5 + SH_C0 * v["f_dc_0"],
243
+ 0.5 + SH_C0 * v["f_dc_1"],
244
+ 0.5 + SH_C0 * v["f_dc_2"],
245
+ 1 / (1 + np.exp(-v["opacity"])),
246
+ ]
247
+ )
248
+ buffer.write(position.tobytes())
249
+ buffer.write(scales.tobytes())
250
+ buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
251
+ buffer.write(
252
+ ((rot / np.linalg.norm(rot)) * 128 + 128)
253
+ .clip(0, 255)
254
+ .astype(np.uint8)
255
+ .tobytes()
256
+ )
257
+ value = buffer.getvalue()
258
+ with open(output_path, "wb") as f:
259
+ f.write(value)
260
+
261
+ return output_path
hyworldmirror/utils/video_utils.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import csv
4
+ import time
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+ import subprocess
10
+
11
+
12
+ def video_to_image_frames(input_video_path, save_directory=None, fps=1):
13
+ """
14
+ Extracts image frames from a video file at the specified frame rate and saves them as JPEG format.
15
+ Supports regular video files, webcam captures, WebM files, and GIF files, including incomplete files.
16
+
17
+ Args:
18
+ input_video_path: Path to the input video file
19
+ save_directory: Directory to save extracted frames (default: None)
20
+ fps: Number of frames to extract per second (default: 1)
21
+
22
+ Returns: List of file paths to extracted frames
23
+ """
24
+ extracted_frame_paths = []
25
+ frame_indices = [] # Track frame indices for metadata
26
+ source_fps = None
27
+
28
+ # For GIF files, use PIL library for better handling
29
+ if input_video_path.lower().endswith('.gif'):
30
+ try:
31
+ print(f"Processing GIF file using PIL: {input_video_path}")
32
+
33
+ with Image.open(input_video_path) as gif_img:
34
+ # Get GIF properties
35
+ frame_duration_ms = gif_img.info.get('duration', 100)
36
+ gif_frame_rate = 1000.0 / frame_duration_ms if frame_duration_ms > 0 else 10.0
37
+ source_fps = gif_frame_rate
38
+
39
+ print(f"GIF properties: {gif_img.n_frames} frames, {gif_frame_rate:.2f} FPS, {frame_duration_ms}ms per frame")
40
+
41
+ sampling_interval = max(1, int(gif_frame_rate / fps)) if fps < gif_frame_rate else 1
42
+
43
+ saved_count = 0
44
+ for current_frame_index in range(gif_img.n_frames):
45
+ gif_img.seek(current_frame_index)
46
+
47
+ if current_frame_index % sampling_interval == 0:
48
+ rgb_frame = gif_img.convert('RGB')
49
+ frame_ndarray = np.array(rgb_frame)
50
+ frame_output_path = os.path.join(save_directory, f"frame_{saved_count:06d}.jpg")
51
+ pil_image = Image.fromarray(frame_ndarray)
52
+ pil_image.save(frame_output_path, 'JPEG', quality=95)
53
+ extracted_frame_paths.append(frame_output_path)
54
+ frame_indices.append(current_frame_index)
55
+ saved_count += 1
56
+
57
+ if extracted_frame_paths:
58
+ print(f"Successfully extracted {len(extracted_frame_paths)} frames from GIF using PIL")
59
+ # Save metadata
60
+ _save_old_metadata(save_directory, frame_indices, source_fps)
61
+ return extracted_frame_paths
62
+
63
+ except Exception as error:
64
+ print(f"PIL GIF extraction error: {str(error)}, falling back to OpenCV")
65
+
66
+ # For WebM files, use FFmpeg directly for more stable processing
67
+ if input_video_path.lower().endswith('.webm'):
68
+ try:
69
+ print(f"Processing WebM file using FFmpeg: {input_video_path}")
70
+
71
+ # Get video FPS first
72
+ cap = cv2.VideoCapture(input_video_path)
73
+ source_fps = cap.get(cv2.CAP_PROP_FPS) or fps
74
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
75
+ cap.release()
76
+
77
+ output_frame_pattern = os.path.join(save_directory, "frame_%04d.jpg")
78
+
79
+ ffmpeg_command = [
80
+ "ffmpeg",
81
+ "-i", input_video_path,
82
+ "-vf", f"fps={fps}",
83
+ "-q:v", "2",
84
+ output_frame_pattern
85
+ ]
86
+
87
+ ffmpeg_process = subprocess.Popen(
88
+ ffmpeg_command,
89
+ stdout=subprocess.PIPE,
90
+ stderr=subprocess.PIPE
91
+ )
92
+ process_stdout, process_stderr = ffmpeg_process.communicate()
93
+
94
+ # Collect all extracted frames and calculate indices
95
+ extracted_frame_paths = []
96
+ for filename in sorted(os.listdir(save_directory)):
97
+ if filename.startswith("frame_") and filename.endswith(".jpg"):
98
+ full_frame_path = os.path.join(save_directory, filename)
99
+ extracted_frame_paths.append(full_frame_path)
100
+ # Extract frame number from filename (frame_XXXX.jpg)
101
+ try:
102
+ frame_num = int(filename.split("_")[1].split(".")[0])
103
+ # Estimate original frame index based on fps ratio
104
+ frame_idx = int(frame_num * source_fps / fps)
105
+ frame_indices.append(frame_idx)
106
+ except:
107
+ frame_indices.append(len(frame_indices))
108
+
109
+ if extracted_frame_paths:
110
+ print(f"Successfully extracted {len(extracted_frame_paths)} frames from WebM using FFmpeg")
111
+ _save_old_metadata(save_directory, frame_indices, source_fps)
112
+ return extracted_frame_paths
113
+
114
+ print("FFmpeg extraction failed, falling back to OpenCV")
115
+ except Exception as error:
116
+ print(f"FFmpeg extraction error: {str(error)}, falling back to OpenCV")
117
+
118
+ # Standard OpenCV method for non-WebM files or as fallback
119
+ try:
120
+ video_capture = cv2.VideoCapture(input_video_path)
121
+
122
+ if input_video_path.lower().endswith('.webm'):
123
+ video_capture.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'VP80'))
124
+
125
+ source_fps = video_capture.get(cv2.CAP_PROP_FPS) or fps
126
+ extraction_interval = max(1, int(source_fps / fps))
127
+ processed_frame_count = 0
128
+
129
+ cv2.setLogLevel(0)
130
+
131
+ while True:
132
+ read_success, current_frame = video_capture.read()
133
+ if not read_success:
134
+ break
135
+
136
+ if processed_frame_count % extraction_interval == 0:
137
+ try:
138
+ if current_frame is not None and current_frame.size > 0:
139
+ rgb_converted_frame = cv2.cvtColor(current_frame, cv2.COLOR_BGR2RGB)
140
+ frame_output_path = os.path.join(save_directory, f"frame_{len(extracted_frame_paths):06d}.jpg")
141
+ cv2.imwrite(frame_output_path, cv2.cvtColor(rgb_converted_frame, cv2.COLOR_RGB2BGR))
142
+ extracted_frame_paths.append(frame_output_path)
143
+ frame_indices.append(processed_frame_count)
144
+ except Exception as error:
145
+ print(f"Warning: Failed to process frame {processed_frame_count}: {str(error)}")
146
+
147
+ processed_frame_count += 1
148
+
149
+ if processed_frame_count > 1000:
150
+ break
151
+
152
+ video_capture.release()
153
+ print(f"Extracted {len(extracted_frame_paths)} frames from video using OpenCV")
154
+
155
+ # Save metadata
156
+ if extracted_frame_paths:
157
+ _save_old_metadata(save_directory, frame_indices, source_fps)
158
+
159
+ except Exception as error:
160
+ print(f"Error extracting frames: {str(error)}")
161
+
162
+ return extracted_frame_paths
163
+
164
+
165
+ def _save_old_metadata(save_directory, frame_indices, fps):
166
+ """Save metadata for old sampling strategy."""
167
+ if not frame_indices or not fps:
168
+ return
169
+
170
+ try:
171
+ meta = {
172
+ "frame_indices": frame_indices,
173
+ "frame_times": [idx / fps for idx in frame_indices],
174
+ "fps": fps,
175
+ "algorithm": "uniform_fps_based"
176
+ }
177
+ metadata_path = os.path.join(save_directory, "frame_metadata.json")
178
+ with open(metadata_path, "w") as f:
179
+ json.dump(meta, f, indent=2)
180
+ except Exception as e:
181
+ print(f"Warning: Failed to save metadata: {e}")
182
+
183
+
184
+ def _resize_for_flow(frame, long_edge=320):
185
+ height, width = frame.shape[:2]
186
+ long_side = max(height, width)
187
+ if long_side <= long_edge:
188
+ return frame
189
+ scale = long_edge / float(long_side)
190
+ new_w = max(1, int(width * scale))
191
+ new_h = max(1, int(height * scale))
192
+ return cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
193
+
194
+
195
+ def _resize_for_clarity(frame, long_edge=480):
196
+ """Resize frame for clarity calculation (480p for better accuracy)."""
197
+ height, width = frame.shape[:2]
198
+ long_side = max(height, width)
199
+ if long_side <= long_edge:
200
+ return frame
201
+ scale = long_edge / float(long_side)
202
+ new_w = max(1, int(width * scale))
203
+ new_h = max(1, int(height * scale))
204
+ return cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
205
+
206
+
207
+ def _create_dis_flow():
208
+ if hasattr(cv2, "optflow") and hasattr(cv2.optflow, "createOptFlow_DIS"):
209
+ return cv2.optflow.createOptFlow_DIS(cv2.optflow.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
210
+ if hasattr(cv2, "DISOpticalFlow_create"):
211
+ return cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
212
+ return None
213
+
214
+
215
+ def _calculate_histogram(image):
216
+ """
217
+ Calculate normalized color histogram for global deduplication.
218
+ Using HSV for better robustness to brightness changes.
219
+ """
220
+ hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
221
+ # 8 bins for H, 4 for S, 4 for V -> 128 dim vector
222
+ hist = cv2.calcHist([hsv], [0, 1, 2], None, [8, 4, 4], [0, 180, 0, 256, 0, 256])
223
+ cv2.normalize(hist, hist)
224
+ return hist.flatten()
225
+
226
+
227
+ def _calculate_hist_similarity(hist1, hist2):
228
+ return cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)
229
+
230
+
231
+ def _advance_cap_to_frame(cap, current_pos, target_idx):
232
+ """Advance cap so that next read() returns frame target_idx. Returns target_idx."""
233
+ dist = target_idx - current_pos
234
+ if dist <= 0:
235
+ cap.set(cv2.CAP_PROP_POS_FRAMES, target_idx)
236
+ return target_idx
237
+ if dist < 50:
238
+ for _ in range(dist):
239
+ cap.grab()
240
+ return target_idx
241
+ cap.set(cv2.CAP_PROP_POS_FRAMES, target_idx)
242
+ return target_idx
243
+
244
+
245
+ def _merge_search_windows(candidate_indices, window_size=3):
246
+ """
247
+ Merge adjacent search windows to reduce disk seeks.
248
+ Returns list of (start_idx, end_idx, target_indices) tuples.
249
+ """
250
+ if not candidate_indices:
251
+ return []
252
+
253
+ merged = []
254
+ sorted_indices = sorted(candidate_indices)
255
+ i = 0
256
+
257
+ while i < len(sorted_indices):
258
+ start_idx = max(0, sorted_indices[i] - window_size)
259
+ end_idx = sorted_indices[i] + window_size
260
+ targets_in_window = [sorted_indices[i]]
261
+
262
+ # Extend window to include adjacent candidates
263
+ j = i + 1
264
+ while j < len(sorted_indices):
265
+ next_start = max(0, sorted_indices[j] - window_size)
266
+ if next_start <= end_idx:
267
+ end_idx = sorted_indices[j] + window_size
268
+ targets_in_window.append(sorted_indices[j])
269
+ j += 1
270
+ else:
271
+ break
272
+
273
+ merged.append((start_idx, end_idx, targets_in_window))
274
+ i = j
275
+
276
+ return merged
277
+
278
+
279
+ def _sparse_motion_analysis(cap, fps, total_frames):
280
+ """Phase 1: Sparse sampling with DIS optical flow."""
281
+ sample_interval = max(1, int(fps * 0.5))
282
+ sparse_samples = []
283
+ dis_flow = _create_dis_flow()
284
+ current_idx = 0
285
+ prev_gray = None
286
+
287
+ while True:
288
+ if current_idx > 0:
289
+ steps_to_skip = sample_interval - 1
290
+ if steps_to_skip > 0:
291
+ current_idx = _advance_cap_to_frame(cap, current_idx, current_idx + steps_to_skip)
292
+ ret, frame = cap.read()
293
+ if not ret:
294
+ break
295
+
296
+ small = _resize_for_flow(frame, long_edge=320)
297
+ gray = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY)
298
+
299
+ motion_mag = 0.0
300
+ if prev_gray is not None:
301
+ if dis_flow is not None:
302
+ flow = dis_flow.calc(prev_gray, gray, None)
303
+ else:
304
+ flow = cv2.calcOpticalFlowFarneback(prev_gray, gray, None, 0.5, 3, 15, 2, 5, 1.2, 0)
305
+ motion_mag = float(np.mean(np.sqrt(flow[..., 0]**2 + flow[..., 1]**2)))
306
+
307
+ sparse_samples.append({
308
+ "idx": current_idx,
309
+ "motion": motion_mag,
310
+ "hist": _calculate_histogram(small)
311
+ })
312
+ prev_gray = gray
313
+ current_idx += 1
314
+
315
+ return sparse_samples
316
+
317
+
318
+ def _adaptive_frame_selection(sparse_samples, fps, max_frames):
319
+ """Phase 2: Adaptive threshold allocation with deduplication."""
320
+ motions = [s["motion"] for s in sparse_samples[1:]]
321
+
322
+ if not motions:
323
+ return [sparse_samples[0]["idx"]]
324
+
325
+ # Calculate adaptive threshold
326
+ static_floor = 1.0
327
+ total_motion = sum(motions)
328
+ estimated_step = total_motion / max_frames if max_frames > 0 else total_motion
329
+ step_threshold = max(static_floor * 5.0, estimated_step)
330
+
331
+ # Select frames based on accumulated motion
332
+ candidate_indices = [sparse_samples[0]["idx"]]
333
+ selected_hists = [sparse_samples[0]["hist"]]
334
+ current_accum = 0.0
335
+ last_selected_idx = sparse_samples[0]["idx"]
336
+
337
+ for i in range(1, len(sparse_samples)):
338
+ s = sparse_samples[i]
339
+ effective_motion = s["motion"] if s["motion"] >= static_floor else 0.0
340
+ current_accum += effective_motion
341
+
342
+ time_gap = s["idx"] - last_selected_idx
343
+ should_select = (current_accum >= step_threshold) or (time_gap > (4.0 * fps))
344
+
345
+ if should_select:
346
+ is_duplicate = any(_calculate_hist_similarity(s["hist"], h) > 0.999 for h in selected_hists)
347
+ if not is_duplicate:
348
+ candidate_indices.append(s["idx"])
349
+ selected_hists.append(s["hist"])
350
+ current_accum = 0.0
351
+ last_selected_idx = s["idx"]
352
+
353
+ # Always check last frame
354
+ if sparse_samples[-1]["idx"] != candidate_indices[-1]:
355
+ last_hist = sparse_samples[-1]["hist"]
356
+ if not any(_calculate_hist_similarity(last_hist, h) > 0.999 for h in selected_hists):
357
+ candidate_indices.append(sparse_samples[-1]["idx"])
358
+
359
+ return sorted(list(set(candidate_indices)))
360
+
361
+
362
+ def _enforce_frame_constraints(candidate_indices, sparse_samples, min_frames, max_frames):
363
+ """Enforce min/max frame constraints."""
364
+ if len(candidate_indices) < min_frames:
365
+ needed = min_frames - len(candidate_indices)
366
+ all_indices = [s["idx"] for s in sparse_samples]
367
+ extras = np.linspace(0, len(all_indices)-1, needed+2)[1:-1]
368
+ candidate_indices.extend([all_indices[int(e)] for e in extras])
369
+ candidate_indices = sorted(list(set(candidate_indices)))
370
+
371
+ if len(candidate_indices) > max_frames:
372
+ indices_to_keep = np.linspace(0, len(candidate_indices)-1, max_frames)
373
+ candidate_indices = [candidate_indices[int(round(i))] for i in indices_to_keep]
374
+
375
+ return candidate_indices
376
+
377
+
378
+ def _read_window_frames(cap, merged_windows, total_frames):
379
+ """Read all frames from merged windows."""
380
+ all_frames = []
381
+ current_pos = -1
382
+
383
+ for window_idx, (window_start, window_end, _) in enumerate(merged_windows):
384
+ current_pos = _advance_cap_to_frame(cap, current_pos, window_start)
385
+ for idx in range(window_start, min(window_end + 1, total_frames)):
386
+ ret, frame = cap.read()
387
+ if not ret:
388
+ break
389
+ all_frames.append((window_idx, idx, frame))
390
+ current_pos = idx + 1
391
+
392
+ return all_frames
393
+
394
+
395
+ def _compute_clarity_parallel(all_frames):
396
+ """Parallel clarity calculation."""
397
+ def _compute(item):
398
+ window_idx, frame_idx, frame = item
399
+ clarity_frame = _resize_for_clarity(frame, long_edge=480)
400
+ gray = cv2.cvtColor(clarity_frame, cv2.COLOR_BGR2GRAY)
401
+ clarity = cv2.Laplacian(gray, cv2.CV_64F).var()
402
+ return (window_idx, frame_idx, frame, clarity)
403
+
404
+ with ThreadPoolExecutor(max_workers=min(8, len(all_frames) or 1)) as ex:
405
+ return list(ex.map(_compute, all_frames))
406
+
407
+
408
+ def _select_best_frames(clarity_results, merged_windows, candidate_indices, search_window_size=3):
409
+ """Select best frame for each candidate based on clarity."""
410
+ # Group by window
411
+ window_frames = {}
412
+ for window_idx, frame_idx, frame, clarity in clarity_results:
413
+ if window_idx not in window_frames:
414
+ window_frames[window_idx] = []
415
+ window_frames[window_idx].append((frame_idx, frame, clarity))
416
+
417
+ # Select best frame for each target
418
+ target_to_best = {}
419
+ for window_idx, (_, _, targets) in enumerate(merged_windows):
420
+ frames = window_frames.get(window_idx, [])
421
+ for target_idx in targets:
422
+ candidates = [(idx, f, c) for idx, f, c in frames
423
+ if abs(idx - target_idx) <= search_window_size]
424
+ if candidates:
425
+ best_idx, best_frame, _ = max(candidates, key=lambda x: x[2])
426
+ target_to_best[target_idx] = (best_idx, best_frame)
427
+ elif frames:
428
+ closest = min(frames, key=lambda x: abs(x[0] - target_idx))
429
+ target_to_best[target_idx] = (closest[0], closest[1])
430
+
431
+ return target_to_best
432
+
433
+
434
+ def _save_frames_parallel(target_to_best, candidate_indices, save_directory):
435
+ """Parallel frame saving."""
436
+ path_frame_list = []
437
+ final_indices = []
438
+
439
+ for target_idx in sorted(candidate_indices):
440
+ if target_idx in target_to_best:
441
+ best_idx, best_frame = target_to_best[target_idx]
442
+ final_indices.append(best_idx)
443
+ path_frame_list.append((
444
+ os.path.join(save_directory, f"frame_{len(path_frame_list):06d}.jpg"),
445
+ best_frame
446
+ ))
447
+
448
+ def _write(p_f):
449
+ cv2.imwrite(p_f[0], p_f[1])
450
+ return p_f[0]
451
+
452
+ with ThreadPoolExecutor(max_workers=min(8, len(path_frame_list) or 1)) as ex:
453
+ paths = list(ex.map(_write, path_frame_list))
454
+
455
+ return final_indices, paths
456
+
457
+
458
+ def video_to_image_frames_new(
459
+ input_video_path,
460
+ save_directory=None,
461
+ min_frames=1,
462
+ max_frames=64,
463
+ fallback_fps=1,
464
+ ):
465
+ """
466
+ Motion-aware frame extraction with local clarity refinement.
467
+
468
+ Strategy:
469
+ 1. Sparse sampling (~0.5s) with DIS optical flow
470
+ 2. Adaptive threshold allocation based on motion
471
+ 3. Local clarity refinement (±3 frames) to avoid blur
472
+ """
473
+ if save_directory is None:
474
+ raise ValueError("save_directory must be provided")
475
+
476
+ max_frames = int(np.clip(max_frames, 1, 64))
477
+ min_frames = int(np.clip(min_frames, 1, max_frames))
478
+
479
+ cap = cv2.VideoCapture(input_video_path)
480
+ if not cap.isOpened():
481
+ print(f"Error: Failed to open video {input_video_path}")
482
+ return []
483
+
484
+ fps = cap.get(cv2.CAP_PROP_FPS) or fallback_fps or 30.0
485
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
486
+ t_start = time.perf_counter()
487
+
488
+ # Phase 1: Sparse motion analysis
489
+ sparse_samples = _sparse_motion_analysis(cap, fps, total_frames)
490
+ cap.release()
491
+
492
+ t_phase1 = time.perf_counter()
493
+ print(f"[Timing] Phase 1 (Sparse Flow): {t_phase1 - t_start:.3f}s, Samples: {len(sparse_samples)}")
494
+
495
+ if not sparse_samples:
496
+ return []
497
+
498
+ # Phase 2: Adaptive frame selection
499
+ candidate_indices = _adaptive_frame_selection(sparse_samples, fps, max_frames)
500
+ candidate_indices = _enforce_frame_constraints(candidate_indices, sparse_samples, min_frames, max_frames)
501
+
502
+ # Phase 3: Local clarity refinement
503
+ cap = cv2.VideoCapture(input_video_path)
504
+ if not cap.isOpened():
505
+ return []
506
+
507
+ t_phase3_start = time.perf_counter()
508
+ search_window_size = 3
509
+ merged_windows = _merge_search_windows(candidate_indices, window_size=search_window_size)
510
+
511
+ # Read frames
512
+ t_read_start = time.perf_counter()
513
+ all_frames = _read_window_frames(cap, merged_windows, total_frames)
514
+ cap.release()
515
+ t_read_end = time.perf_counter()
516
+
517
+ # Parallel clarity calculation
518
+ t_clarity_start = time.perf_counter()
519
+ clarity_results = _compute_clarity_parallel(all_frames)
520
+ t_clarity_end = time.perf_counter()
521
+
522
+ # Select best frames
523
+ target_to_best = _select_best_frames(clarity_results, merged_windows, candidate_indices, search_window_size)
524
+
525
+ # Parallel save
526
+ t_save_start = time.perf_counter()
527
+ final_indices, extracted_paths = _save_frames_parallel(target_to_best, candidate_indices, save_directory)
528
+ t_save_end = time.perf_counter()
529
+
530
+ t_phase3_end = time.perf_counter()
531
+ print(f"[Timing] Phase 3 (Clarity Refinement + Save): {t_phase3_end - t_phase3_start:.3f}s")
532
+ print(f" - Read frames: {t_read_end - t_read_start:.3f}s")
533
+ print(f" - Parallel clarity: {t_clarity_end - t_clarity_start:.3f}s")
534
+ print(f" - Parallel save: {t_save_end - t_save_start:.3f}s, Saved: {len(extracted_paths)}")
535
+
536
+ # Save metadata
537
+ try:
538
+ meta = {
539
+ "frame_indices": final_indices,
540
+ "frame_times": [i/fps for i in final_indices],
541
+ "fps": fps,
542
+ "algorithm": "sparse_dis_clarity_refined"
543
+ }
544
+ with open(os.path.join(save_directory, "frame_metadata.json"), "w") as f:
545
+ json.dump(meta, f, indent=2)
546
+
547
+ with open(os.path.join(save_directory, "frame_metrics.csv"), "w", newline="") as f:
548
+ writer = csv.writer(f)
549
+ writer.writerow(["frame_index", "time_sec", "motion", "selected"])
550
+ for s in sparse_samples:
551
+ writer.writerow([s["idx"], s["idx"]/fps, s["motion"],
552
+ 1 if s["idx"] in final_indices else 0])
553
+ except:
554
+ pass
555
+
556
+ print(f"Extracted {len(extracted_paths)} frames using DIS flow + local clarity refinement.")
557
+ return extracted_paths
hyworldmirror/utils/visual_util.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Visual utilities for HuggingFace integration.
2
+
3
+ References: https://github.com/facebookresearch/vggt
4
+ """
5
+
6
+ import copy
7
+ import os
8
+ from typing import Tuple
9
+
10
+ import cv2
11
+ import matplotlib
12
+ import numpy as np
13
+ import requests
14
+ import trimesh
15
+
16
+ from scipy.spatial.transform import Rotation
17
+
18
+
19
+ def segment_sky(image_or_path, onnx_session):
20
+ """
21
+ Segments sky from an image using an ONNX model.
22
+ Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing
23
+
24
+ Args:
25
+ image_or_path: Path to input image (str) or BGR numpy array (H, W, 3)
26
+ onnx_session: ONNX runtime session with loaded model
27
+
28
+ Returns:
29
+ np.ndarray: Binary mask where 255 indicates non-sky regions
30
+ """
31
+
32
+ if isinstance(image_or_path, (str, os.PathLike)):
33
+ image = cv2.imread(str(image_or_path))
34
+ else:
35
+ image = image_or_path
36
+ result_map = run_skyseg(onnx_session, [320, 320], image)
37
+ # resize the result_map to the original image size
38
+ result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
39
+
40
+ # Fix: Invert the mask so that 255 = non-sky, 0 = sky
41
+ # The model outputs low values for sky, high values for non-sky
42
+ output_mask = np.zeros_like(result_map_original)
43
+ output_mask[result_map_original < 32] = 255 # Use threshold of 32
44
+ return output_mask
45
+
46
+
47
+ def run_skyseg(onnx_session, input_size, image):
48
+ """
49
+ Runs sky segmentation inference using ONNX model.
50
+
51
+ Args:
52
+ onnx_session: ONNX runtime session
53
+ input_size: Target size for model input (width, height)
54
+ image: Input image in BGR format
55
+
56
+ Returns:
57
+ np.ndarray: Segmentation mask
58
+ """
59
+
60
+ # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast
61
+ temp_image = copy.deepcopy(image)
62
+ resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
63
+ x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
64
+ x = np.array(x, dtype=np.float32)
65
+ mean = [0.485, 0.456, 0.406]
66
+ std = [0.229, 0.224, 0.225]
67
+ x = (x / 255 - mean) / std
68
+ x = x.transpose(2, 0, 1)
69
+ x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")
70
+
71
+ # Inference
72
+ input_name = onnx_session.get_inputs()[0].name
73
+ output_name = onnx_session.get_outputs()[0].name
74
+ onnx_result = onnx_session.run([output_name], {input_name: x})
75
+
76
+ # Post process
77
+ onnx_result = np.array(onnx_result).squeeze()
78
+ min_value = np.min(onnx_result)
79
+ max_value = np.max(onnx_result)
80
+ onnx_result = (onnx_result - min_value) / (max_value - min_value)
81
+ onnx_result *= 255
82
+ onnx_result = onnx_result.astype("uint8")
83
+
84
+ return onnx_result
85
+
86
+
87
+ def download_file_from_url(url, filename):
88
+ """Downloads a file from a Hugging Face model repo, handling redirects."""
89
+ try:
90
+ # Get the redirect URL
91
+ response = requests.get(url, allow_redirects=False)
92
+ response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx)
93
+
94
+ if response.status_code == 302: # Expecting a redirect
95
+ redirect_url = response.headers["Location"]
96
+ response = requests.get(redirect_url, stream=True)
97
+ response.raise_for_status()
98
+ else:
99
+ print(f"Unexpected status code: {response.status_code}")
100
+ return
101
+
102
+ with open(filename, "wb") as f:
103
+ for chunk in response.iter_content(chunk_size=8192):
104
+ f.write(chunk)
105
+ print(f"Downloaded {filename} successfully.")
106
+
107
+ except requests.exceptions.RequestException as e:
108
+ print(f"Error downloading file: {e}")
109
+
110
+
111
+ def create_image_mesh(
112
+ *image_data: np.ndarray,
113
+ mask: np.ndarray = None,
114
+ triangulate: bool = False,
115
+ return_vertex_indices: bool = False,
116
+ ) -> Tuple[np.ndarray, ...]:
117
+ """
118
+ Create a mesh from image data using pixel coordinates as vertices and grid connections as faces.
119
+
120
+ Args:
121
+ *image_data (np.ndarray): Image arrays with shape (height, width, [channels])
122
+ mask (np.ndarray, optional): Boolean mask with shape (height, width). Defaults to None.
123
+ triangulate (bool): Convert quad faces to triangular faces. Defaults to False.
124
+ return_vertex_indices (bool): Include vertex indices in output. Defaults to False.
125
+
126
+ Returns:
127
+ faces (np.ndarray): Face connectivity array. Shape (N, 4) for quads or (N, 3) for triangles
128
+ *vertex_data (np.ndarray): Vertex attributes corresponding to input image_data
129
+ vertex_indices (np.ndarray, optional): Original vertex indices if return_vertex_indices=True
130
+ """
131
+ # Validate inputs
132
+ assert (len(image_data) > 0) or (mask is not None), "Need at least one image or mask"
133
+
134
+ if mask is None:
135
+ height, width = image_data[0].shape[:2]
136
+ else:
137
+ height, width = mask.shape
138
+
139
+ # Check all images have same dimensions
140
+ for img in image_data:
141
+ assert img.shape[:2] == (height, width), "All images must have same height and width"
142
+
143
+ # Create quad faces connecting neighboring pixels
144
+ base_quad = np.stack([
145
+ np.arange(0, width - 1, dtype=np.int32), # bottom-left
146
+ np.arange(width, 2 * width - 1, dtype=np.int32), # top-left
147
+ np.arange(1 + width, 2 * width, dtype=np.int32), # top-right
148
+ np.arange(1, width, dtype=np.int32), # bottom-right
149
+ ], axis=1)
150
+
151
+ # Replicate quad pattern for all rows
152
+ row_offsets = np.arange(0, (height - 1) * width, width, dtype=np.int32)
153
+ faces = (row_offsets[:, None, None] + base_quad[None, :, :]).reshape((-1, 4))
154
+
155
+ if mask is None:
156
+ # No masking - use all faces and vertices
157
+ if triangulate:
158
+ faces = _convert_quads_to_triangles(faces)
159
+
160
+ output = [faces]
161
+ for img in image_data:
162
+ output.append(img.reshape(-1, *img.shape[2:]))
163
+
164
+ if return_vertex_indices:
165
+ output.append(np.arange(height * width, dtype=np.int32))
166
+
167
+ return tuple(output)
168
+ else:
169
+ # Apply mask - only keep faces where all 4 corners are valid
170
+ valid_quads = (
171
+ mask[:-1, :-1] & mask[1:, :-1] &
172
+ mask[1:, 1:] & mask[:-1, 1:]
173
+ ).ravel()
174
+ faces = faces[valid_quads]
175
+
176
+ if triangulate:
177
+ faces = _convert_quads_to_triangles(faces)
178
+
179
+ # Remove unused vertices and remap face indices
180
+ num_face_vertices = faces.shape[-1]
181
+ unique_vertices, remapped_indices = np.unique(faces, return_inverse=True)
182
+ faces = remapped_indices.astype(np.int32).reshape(-1, num_face_vertices)
183
+
184
+ output = [faces]
185
+ for img in image_data:
186
+ flattened_img = img.reshape(-1, *img.shape[2:])
187
+ output.append(flattened_img[unique_vertices])
188
+
189
+ if return_vertex_indices:
190
+ output.append(unique_vertices)
191
+
192
+ return tuple(output)
193
+
194
+
195
+ def _convert_quads_to_triangles(quad_faces: np.ndarray) -> np.ndarray:
196
+ """Convert quadrilateral faces to triangular faces."""
197
+ if quad_faces.shape[-1] == 3:
198
+ return quad_faces # Already triangular
199
+
200
+ num_vertices_per_face = quad_faces.shape[-1]
201
+ triangle_indices = np.stack([
202
+ np.zeros(num_vertices_per_face - 2, dtype=int), # First vertex
203
+ np.arange(1, num_vertices_per_face - 1, dtype=int), # Sequential vertices
204
+ np.arange(2, num_vertices_per_face, dtype=int), # Next sequential vertices
205
+ ], axis=1)
206
+
207
+ return quad_faces[:, triangle_indices].reshape((-1, 3))
208
+
209
+
210
+ def convert_predictions_to_glb_scene(
211
+ predictions,
212
+ filter_by_frames="all",
213
+ show_camera=True,
214
+ mask_sky_bg=False,
215
+ mask_ambiguous=False,
216
+ as_mesh=True,
217
+ ) -> trimesh.Scene:
218
+ """
219
+ Converts model predictions to a 3D scene represented as a GLB file.
220
+
221
+ Args:
222
+ predictions (dict): Dictionary containing model predictions with keys:
223
+ - world_points: 3D point coordinates (S, H, W, 3)
224
+ - images: Input images (S, H, W, 3)
225
+ - camera_poses: Camera extrinsic matrices (S, 3, 4)
226
+ filter_by_frames (str): Frame filter specification (default: "all")
227
+ show_camera (bool): Include camera visualization (default: True)
228
+ mask_sky_bg (bool): Mask out sky background pixels (default: False)
229
+ mask_ambiguous (bool): Apply final mask to filter ambiguous predictions (default: False)
230
+ as_mesh (bool): Represent the data as a mesh instead of point cloud (default: False)
231
+
232
+ Returns:
233
+ trimesh.Scene: Processed 3D scene containing point cloud/mesh and cameras
234
+
235
+ Raises:
236
+ ValueError: If input predictions structure is invalid
237
+ """
238
+ if not isinstance(predictions, dict):
239
+ raise ValueError("predictions must be a dictionary")
240
+
241
+ print("Building GLB scene")
242
+
243
+ # Parse frame selection from filter string
244
+ target_frame_index = None
245
+ if filter_by_frames not in ["all", "All"]:
246
+ try:
247
+ # Extract numeric index before colon separator
248
+ target_frame_index = int(filter_by_frames.split(":")[0])
249
+ except (ValueError, IndexError):
250
+ pass
251
+
252
+ # Validate required data in predictions
253
+ print("Using Pointmap Branch")
254
+ if "world_points" not in predictions:
255
+ raise ValueError(
256
+ "world_points not found in predictions. Pointmap Branch requires 'world_points' key. "
257
+ "Depthmap and Camera branches have been removed."
258
+ )
259
+
260
+ # Extract prediction data
261
+ point_cloud_3d = predictions["world_points"]
262
+ input_images = predictions["images"]
263
+ extrinsic_matrices = predictions["camera_poses"]
264
+ ambiguity_mask = predictions["final_mask"]
265
+ sky_region_mask = predictions["sky_mask"]
266
+
267
+ # Filter to single frame if specified
268
+ if target_frame_index is not None:
269
+ point_cloud_3d = point_cloud_3d[target_frame_index][None]
270
+ input_images = input_images[target_frame_index][None]
271
+ extrinsic_matrices = extrinsic_matrices[target_frame_index][None]
272
+ ambiguity_mask = ambiguity_mask[target_frame_index][None]
273
+ sky_region_mask = sky_region_mask[target_frame_index][None]
274
+
275
+ # Flatten 3D points to vertex array
276
+ flattened_vertices = point_cloud_3d.reshape(-1, 3)
277
+
278
+ # Convert images to RGB color array
279
+ if input_images.ndim == 4 and input_images.shape[1] == 3: # NCHW format
280
+ rgb_colors = np.transpose(input_images, (0, 2, 3, 1))
281
+ else: # Already in NHWC format
282
+ rgb_colors = input_images
283
+ rgb_colors = (rgb_colors.reshape(-1, 3) * 255).astype(np.uint8)
284
+
285
+ # Build composite filtering mask
286
+ valid_points_mask = np.ones(len(flattened_vertices), dtype=bool)
287
+
288
+ # Apply ambiguity filtering if requested
289
+ if mask_ambiguous:
290
+ flat_ambiguity_mask = ambiguity_mask.reshape(-1)
291
+ valid_points_mask = valid_points_mask & flat_ambiguity_mask
292
+
293
+ # Apply sky region filtering if requested
294
+ if mask_sky_bg:
295
+ flat_sky_mask = sky_region_mask.reshape(-1)
296
+ valid_points_mask = valid_points_mask & flat_sky_mask
297
+
298
+ # Apply mask to filter vertices and colors
299
+ filtered_vertices = flattened_vertices[valid_points_mask].copy()
300
+ filtered_colors = rgb_colors[valid_points_mask].copy()
301
+
302
+ # Handle empty geometry case
303
+ if filtered_vertices is None or np.asarray(filtered_vertices).size == 0:
304
+ filtered_vertices = np.array([[1, 0, 0]])
305
+ filtered_colors = np.array([[255, 255, 255]])
306
+ scene_scale_factor = 1
307
+ else:
308
+ # Compute scene scale from percentile-based bounding box
309
+ percentile_lower = np.percentile(filtered_vertices, 5, axis=0)
310
+ percentile_upper = np.percentile(filtered_vertices, 95, axis=0)
311
+ scene_scale_factor = np.linalg.norm(percentile_upper - percentile_lower)
312
+
313
+ # Initialize color mapping for cameras
314
+ color_palette = matplotlib.colormaps.get_cmap("gist_rainbow")
315
+
316
+ # Create empty 3D scene container
317
+ output_scene = trimesh.Scene()
318
+
319
+ # Add geometry to scene based on representation type
320
+ if as_mesh:
321
+ # Mesh representation
322
+ if target_frame_index is not None:
323
+ # Single frame mesh generation
324
+ frame_height, frame_width = point_cloud_3d.shape[1:3]
325
+
326
+ # Prepare unfiltered data for mesh construction
327
+ structured_points = point_cloud_3d.reshape(frame_height, frame_width, 3)
328
+
329
+ # Convert image data to proper format
330
+ if input_images.ndim == 4 and input_images.shape[1] == 3: # NCHW format
331
+ structured_colors = np.transpose(input_images[0], (1, 2, 0))
332
+ else: # Already in HWC format
333
+ structured_colors = input_images[0]
334
+ structured_colors *= 255
335
+
336
+ # Get structured mask for mesh creation
337
+ structured_mask = predictions["final_mask"][target_frame_index].reshape(
338
+ frame_height, frame_width
339
+ )
340
+
341
+ # Build filtering mask
342
+ mesh_filter_mask = structured_mask
343
+
344
+ # Check for normal data availability
345
+ mesh_normals = None
346
+ if "normal" in predictions and predictions["normal"] is not None:
347
+ # Extract normals for selected frame
348
+ frame_normal_data = (
349
+ predictions["normal"][target_frame_index]
350
+ if target_frame_index is not None
351
+ else predictions["normal"][0]
352
+ )
353
+
354
+ # Generate mesh with normal information
355
+ mesh_faces, mesh_vertices, mesh_colors, mesh_normals = create_image_mesh(
356
+ structured_points * np.array([1, -1, 1], dtype=np.float32),
357
+ structured_colors / 255.0,
358
+ frame_normal_data * np.array([1, -1, 1], dtype=np.float32),
359
+ mask=mesh_filter_mask,
360
+ triangulate=True,
361
+ return_vertex_indices=False,
362
+ )
363
+
364
+ # Apply coordinate system transformation to normals
365
+ mesh_normals = mesh_normals * np.array([1, -1, 1], dtype=np.float32)
366
+ else:
367
+ # Generate mesh without normal information
368
+ mesh_faces, mesh_vertices, mesh_colors = create_image_mesh(
369
+ structured_points * np.array([1, -1, 1], dtype=np.float32),
370
+ structured_colors / 255.0,
371
+ mask=mesh_filter_mask,
372
+ triangulate=True,
373
+ return_vertex_indices=False,
374
+ )
375
+
376
+ # Construct trimesh object with optional normals
377
+ geometry_mesh = trimesh.Trimesh(
378
+ vertices=mesh_vertices * np.array([1, -1, 1], dtype=np.float32),
379
+ faces=mesh_faces,
380
+ vertex_colors=(mesh_colors * 255).astype(np.uint8),
381
+ vertex_normals=(mesh_normals if mesh_normals is not None else None),
382
+ process=False,
383
+ )
384
+ output_scene.add_geometry(geometry_mesh)
385
+ else:
386
+ # Multi-frame mesh generation
387
+ print("Creating mesh for multi-frame data...")
388
+
389
+ for frame_idx in range(point_cloud_3d.shape[0]):
390
+ frame_height, frame_width = point_cloud_3d.shape[1:3]
391
+
392
+ # Extract per-frame data
393
+ frame_point_data = point_cloud_3d[frame_idx]
394
+ frame_ambiguity_mask = predictions["final_mask"][frame_idx]
395
+ frame_sky_mask = predictions["sky_mask"][frame_idx]
396
+
397
+ # Extract frame image data
398
+ if input_images.ndim == 4 and input_images.shape[1] == 3: # NCHW format
399
+ frame_image_data = np.transpose(input_images[frame_idx], (1, 2, 0))
400
+ else: # Already in HWC format
401
+ frame_image_data = input_images[frame_idx]
402
+ frame_image_data *= 255
403
+
404
+ # Build per-frame filtering mask
405
+ frame_filter_mask = np.ones((frame_height, frame_width), dtype=bool)
406
+
407
+ # Apply ambiguity filtering if enabled
408
+ if mask_ambiguous:
409
+ frame_filter_mask = frame_filter_mask & frame_ambiguity_mask
410
+
411
+ # Apply sky filtering if enabled
412
+ if mask_sky_bg:
413
+ frame_filter_mask = frame_filter_mask & frame_sky_mask
414
+
415
+ # Generate mesh for current frame
416
+ frame_faces, frame_vertices, frame_colors = create_image_mesh(
417
+ frame_point_data * np.array([1, -1, 1], dtype=np.float32),
418
+ frame_image_data / 255.0,
419
+ mask=frame_filter_mask,
420
+ triangulate=True,
421
+ return_vertex_indices=False,
422
+ )
423
+
424
+ frame_vertices = frame_vertices * np.array([1, -1, 1], dtype=np.float32)
425
+
426
+ # Create trimesh object for current frame
427
+ frame_geometry = trimesh.Trimesh(
428
+ vertices=frame_vertices,
429
+ faces=frame_faces,
430
+ vertex_colors=(frame_colors * 255).astype(np.uint8),
431
+ process=False,
432
+ )
433
+ output_scene.add_geometry(frame_geometry)
434
+ else:
435
+ # Point cloud representation
436
+ point_cloud_geometry = trimesh.PointCloud(vertices=filtered_vertices, colors=filtered_colors)
437
+ output_scene.add_geometry(point_cloud_geometry)
438
+
439
+ # Add camera visualizations if requested
440
+ num_camera_views = len(extrinsic_matrices)
441
+
442
+ if show_camera:
443
+ # Iterate through all camera views
444
+ for camera_idx in range(num_camera_views):
445
+ camera_extrinsic = extrinsic_matrices[camera_idx]
446
+ camera_color_rgba = color_palette(camera_idx / num_camera_views)
447
+ camera_color_rgb = tuple(int(255 * x) for x in camera_color_rgba[:3])
448
+
449
+ integrate_camera_into_scene(
450
+ output_scene, camera_extrinsic, camera_color_rgb, scene_scale_factor
451
+ )
452
+
453
+ # Define coordinate system transformation matrices
454
+ opengl_transform = np.eye(4)
455
+ opengl_transform[1, 1] = -1 # Flip Y axis
456
+ opengl_transform[2, 2] = -1 # Flip Z axis
457
+
458
+ # Define alignment rotation (180 degrees around Y-axis)
459
+ alignment_rotation = np.eye(4)
460
+ alignment_rotation[:3, :3] = Rotation.from_euler("y", 0, degrees=True).as_matrix()
461
+
462
+ # Compute and apply final transformation
463
+ scene_transformation = (
464
+ np.linalg.inv(extrinsic_matrices[0])
465
+ @ opengl_transform
466
+ @ alignment_rotation
467
+ )
468
+ output_scene.apply_transform(scene_transformation)
469
+
470
+ print("GLB Scene built")
471
+ return output_scene
472
+
473
+ def integrate_camera_into_scene(
474
+ scene: trimesh.Scene,
475
+ camera_transform: np.ndarray,
476
+ camera_color: tuple,
477
+ scale_factor: float,
478
+ ):
479
+ """
480
+ Adds a camera visualization mesh to the 3D scene.
481
+
482
+ Args:
483
+ scene (trimesh.Scene): The 3D scene to add the camera visualization.
484
+ camera_transform (np.ndarray): 4x4 transformation matrix for camera positioning.
485
+ camera_color (tuple): RGB color tuple for the camera mesh.
486
+ scale_factor (float): Scaling factor for the camera size relative to scene.
487
+ """
488
+ # Define camera dimensions based on scene scale
489
+ camera_base_width = scale_factor * 0.05
490
+ camera_cone_height = scale_factor * 0.1
491
+
492
+ # Create base cone geometry for camera representation
493
+ base_cone = trimesh.creation.cone(camera_base_width, camera_cone_height, sections=4)
494
+
495
+ # Setup rotation transformation (45 degrees around z-axis)
496
+ z_rotation_matrix = np.eye(4)
497
+ z_rotation_matrix[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
498
+ z_rotation_matrix[2, 3] = -camera_cone_height
499
+
500
+ # Setup OpenGL coordinate system conversion
501
+ opengl_coord_transform = np.eye(4)
502
+ opengl_coord_transform[1, 1] = -1 # Flip Y axis
503
+ opengl_coord_transform[2, 2] = -1 # Flip Z axis
504
+
505
+ # Combine all transformations
506
+ final_transform = camera_transform @ opengl_coord_transform @ z_rotation_matrix
507
+
508
+ # Create slight rotation for mesh variation
509
+ minor_rotation = np.eye(4)
510
+ minor_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()
511
+
512
+ # Generate multiple vertex sets for complex camera geometry
513
+ original_vertices = base_cone.vertices
514
+ scaled_vertices = 0.95 * original_vertices
515
+ rotated_vertices = apply_transformation_to_points(minor_rotation, original_vertices)
516
+
517
+ # Combine all vertex sets
518
+ all_vertices = np.concatenate([
519
+ original_vertices,
520
+ scaled_vertices,
521
+ rotated_vertices
522
+ ])
523
+
524
+ # Transform vertices to final position
525
+ transformed_vertices = apply_transformation_to_points(final_transform, all_vertices)
526
+
527
+ # Generate faces for the complete camera mesh
528
+ camera_faces = generate_camera_mesh_faces(base_cone)
529
+
530
+ # Create and configure the camera mesh
531
+ camera_mesh = trimesh.Trimesh(
532
+ vertices=transformed_vertices,
533
+ faces=camera_faces
534
+ )
535
+ camera_mesh.visual.face_colors[:, :3] = camera_color
536
+
537
+ # Add the camera mesh to the scene
538
+ scene.add_geometry(camera_mesh)
539
+
540
+
541
+ def apply_transformation_to_points(
542
+ transform_matrix: np.ndarray, point_array: np.ndarray, output_dim: int = None
543
+ ) -> np.ndarray:
544
+ """
545
+ Applies a 4x4 transformation matrix to a collection of 3D points.
546
+
547
+ Args:
548
+ transform_matrix (np.ndarray): 4x4 transformation matrix to apply.
549
+ point_array (np.ndarray): Array of points to transform.
550
+ output_dim (int, optional): Target dimension for output points.
551
+
552
+ Returns:
553
+ np.ndarray: Array of transformed points.
554
+ """
555
+ point_array = np.asarray(point_array)
556
+ original_shape = point_array.shape[:-1]
557
+ target_dim = output_dim or point_array.shape[-1]
558
+
559
+ # Transpose transformation matrix for matrix multiplication
560
+ transposed_transform = transform_matrix.swapaxes(-1, -2)
561
+
562
+ # Apply rotation/scaling and translation components
563
+ transformed_points = (
564
+ point_array @ transposed_transform[..., :-1, :] +
565
+ transposed_transform[..., -1:, :]
566
+ )
567
+
568
+ # Extract desired dimensions and restore original shape
569
+ final_result = transformed_points[..., :target_dim].reshape(*original_shape, target_dim)
570
+ return final_result
571
+
572
+
573
+ def generate_camera_mesh_faces(base_cone_mesh: trimesh.Trimesh) -> np.ndarray:
574
+ """
575
+ Generates face indices for a complex camera mesh composed of multiple cone layers.
576
+
577
+ Args:
578
+ base_cone_mesh (trimesh.Trimesh): Base cone geometry used as template.
579
+
580
+ Returns:
581
+ np.ndarray: Array of face indices defining the camera mesh topology.
582
+ """
583
+ face_indices = []
584
+ vertex_count_per_cone = len(base_cone_mesh.vertices)
585
+
586
+ # Process each face of the base cone
587
+ for triangle_face in base_cone_mesh.faces:
588
+ # Skip faces that include the cone tip (vertex 0)
589
+ if 0 in triangle_face:
590
+ continue
591
+
592
+ # Get vertex indices for current triangle
593
+ vertex_a, vertex_b, vertex_c = triangle_face
594
+
595
+ # Calculate corresponding vertices in second and third cone layers
596
+ vertex_a_layer2, vertex_b_layer2, vertex_c_layer2 = triangle_face + vertex_count_per_cone
597
+ vertex_a_layer3, vertex_b_layer3, vertex_c_layer3 = triangle_face + 2 * vertex_count_per_cone
598
+
599
+ # Create connecting faces between cone layers
600
+ connecting_faces = [
601
+ (vertex_a, vertex_b, vertex_b_layer2),
602
+ (vertex_a, vertex_a_layer2, vertex_c),
603
+ (vertex_c_layer2, vertex_b, vertex_c),
604
+ (vertex_a, vertex_b, vertex_b_layer3),
605
+ (vertex_a, vertex_a_layer3, vertex_c),
606
+ (vertex_c_layer3, vertex_b, vertex_c),
607
+ ]
608
+
609
+ face_indices.extend(connecting_faces)
610
+
611
+ # Add reverse-winding faces for proper mesh closure
612
+ reversed_faces = [(vertex_c, vertex_b, vertex_a) for vertex_a, vertex_b, vertex_c in face_indices]
613
+ face_indices.extend(reversed_faces)
614
+
615
+ return np.array(face_indices)
616
+
617
+
hyworldmirror/utils/warnings.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Wrapper utilities for warnings.
3
+ """
4
+
5
+ import warnings
6
+ from functools import wraps
7
+
8
+
9
+ class no_warnings:
10
+ def __init__(self, action: str = "ignore", **kwargs):
11
+ self.action = action
12
+ self.filter_kwargs = kwargs
13
+
14
+ def __call__(self, fn):
15
+ @wraps(fn)
16
+ def wrapper(*args, **kwargs):
17
+ with warnings.catch_warnings():
18
+ warnings.simplefilter(self.action, **self.filter_kwargs)
19
+ return fn(*args, **kwargs)
20
+
21
+ return wrapper
22
+
23
+ def __enter__(self):
24
+ self.warnings_manager = warnings.catch_warnings()
25
+ self.warnings_manager.__enter__()
26
+ warnings.simplefilter(self.action, **self.filter_kwargs)
27
+
28
+ def __exit__(self, exc_type, exc_val, exc_tb):
29
+ self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
pipeline.py ADDED
@@ -0,0 +1,847 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HunyuanWorld-Mirror Inference Pipeline
3
+
4
+ Usage:
5
+ # Python API — Single GPU
6
+ from hyworld2.worldrecon.pipeline import WorldMirrorPipeline
7
+ pipeline = WorldMirrorPipeline.from_pretrained('tencent/HY-World-2.0')
8
+ result = pipeline('path/to/images')
9
+
10
+ # Python API — Multi-GPU (in a torchrun script)
11
+ pipeline = WorldMirrorPipeline.from_pretrained(
12
+ 'tencent/HY-World-2.0', use_fsdp=True, enable_bf16=True)
13
+ result = pipeline('path/to/images')
14
+
15
+ # CLI — Single GPU
16
+ python -m hyworld2.worldrecon.pipeline --input_path path/to/images
17
+
18
+ # CLI — Multi-GPU
19
+ torchrun --nproc_per_node=2 -m hyworld2.worldrecon.pipeline --input_path path/to/images --use_fsdp --enable_bf16
20
+ """
21
+
22
+ import argparse
23
+ import functools
24
+ import gc
25
+ import os
26
+ import time
27
+ from datetime import datetime, timedelta
28
+ from pathlib import Path
29
+
30
+ import numpy as np
31
+ import torch
32
+ import torch.distributed as dist
33
+ from omegaconf import OmegaConf
34
+ from safetensors.torch import load_file as load_safetensors
35
+ from torch.distributed.fsdp import (
36
+ FullyShardedDataParallel as FSDP,
37
+ ShardingStrategy,
38
+ CPUOffload,
39
+ )
40
+ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
41
+
42
+ from .hyworldmirror.models.models.worldmirror import WorldMirror
43
+ from .hyworldmirror.models.layers.block import Block, DistBlock
44
+ from .hyworldmirror.models.heads.dense_head import DPTHead
45
+ from .hyworldmirror.models.heads.camera_head import CameraHead
46
+ from .hyworldmirror.utils.inference_utils import (
47
+ prepare_images_to_tensor,
48
+ prepare_input,
49
+ compute_adaptive_target_size,
50
+ compute_preprocessing_transform,
51
+ load_prior_camera,
52
+ load_prior_depth,
53
+ compute_sky_mask,
54
+ compute_filter_mask,
55
+ save_results,
56
+ print_and_save_timings,
57
+ )
58
+ from .hyworldmirror.utils.render_utils import render_interpolated_video
59
+
60
+
61
+ # ============================================================
62
+ # Model loading helpers (checkpoint, config, selective load)
63
+ # ============================================================
64
+
65
+ def _get_model_config_from_yaml(cfg) -> dict:
66
+ if hasattr(cfg, "wrapper") and hasattr(cfg.wrapper, "model"):
67
+ model_cfg = cfg.wrapper.model
68
+ elif hasattr(cfg, "model"):
69
+ model_cfg = cfg.model
70
+ else:
71
+ raise ValueError("No model config found (expect wrapper.model or model).")
72
+ out = OmegaConf.to_container(model_cfg, resolve=True)
73
+ out.pop("_target_", None)
74
+ return out
75
+
76
+
77
+ def _load_checkpoint_state_dict(ckpt_path: str) -> dict:
78
+ if ckpt_path.endswith(".safetensors"):
79
+ return load_safetensors(ckpt_path)
80
+ ckpt = torch.load(ckpt_path, map_location="cpu")
81
+ state = ckpt.get("state_dict", ckpt)
82
+ if "state_dict" in ckpt:
83
+ state = {k.replace("model.", ""): v for k, v in state.items()}
84
+ return state
85
+
86
+
87
+ def _load_state_dict_selective(model, ckpt_state, source_name="checkpoint"):
88
+ current = model.state_dict()
89
+ for key in current:
90
+ if key in ckpt_state and current[key].shape == ckpt_state[key].shape:
91
+ current[key] = ckpt_state[key]
92
+ model.load_state_dict(current, strict=True)
93
+ matched = sum(1 for k in current if k in ckpt_state and current[k].shape == ckpt_state[k].shape)
94
+ print(f" Loaded {matched}/{len(current)} keys from {source_name}")
95
+
96
+
97
+ def _has_model_files(path: str) -> bool:
98
+ """Check whether a directory contains the expected model artifacts."""
99
+ has_weights = os.path.isfile(os.path.join(path, "model.safetensors"))
100
+ has_config = (os.path.isfile(os.path.join(path, "config.yaml"))
101
+ or os.path.isfile(os.path.join(path, "config.json")))
102
+ return has_weights and has_config
103
+
104
+
105
+ def _resolve_model_dir(model_path: str, subfolder: str) -> str:
106
+ """Resolve a local directory containing config + model.safetensors.
107
+
108
+ Resolution order:
109
+ 1. {model_path}/{subfolder} — local repo root with subfolder
110
+ 2. {model_path} — direct local path (backward compat)
111
+ 3. HuggingFace download: snapshot_download(repo_id, allow_patterns=[subfolder/*])
112
+ """
113
+ candidate = os.path.join(model_path, subfolder)
114
+ if os.path.isdir(candidate) and _has_model_files(candidate):
115
+ print(f"[Init] Found local model at {candidate}")
116
+ return candidate
117
+
118
+ if os.path.isdir(model_path) and _has_model_files(model_path):
119
+ print(f"[Init] Found local model at {model_path}")
120
+ return model_path
121
+
122
+ print(f"[Init] Downloading from HuggingFace: {model_path} (subfolder={subfolder})")
123
+ from huggingface_hub import snapshot_download
124
+ repo_root = snapshot_download(
125
+ repo_id=model_path,
126
+ allow_patterns=[f"{subfolder}/*"],
127
+ )
128
+ resolved = os.path.join(repo_root, subfolder)
129
+ if not _has_model_files(resolved):
130
+ raise FileNotFoundError(
131
+ f"Downloaded repo '{model_path}' but subfolder '{subfolder}' "
132
+ f"does not contain model.safetensors + config. "
133
+ f"Check that the repo and subfolder name are correct."
134
+ )
135
+ return resolved
136
+
137
+
138
+ def _load_model_config(model_dir: str) -> dict:
139
+ """Load model constructor kwargs from config.yaml or config.json in model_dir."""
140
+ import json as _json
141
+ yaml_path = os.path.join(model_dir, "config.yaml")
142
+ json_path = os.path.join(model_dir, "config.json")
143
+
144
+ if os.path.isfile(yaml_path):
145
+ cfg = OmegaConf.load(yaml_path)
146
+ return _get_model_config_from_yaml(cfg)
147
+ elif os.path.isfile(json_path):
148
+ with open(json_path) as f:
149
+ return _json.load(f)
150
+ else:
151
+ raise FileNotFoundError(f"No config.yaml or config.json in {model_dir}")
152
+
153
+
154
+ # ============================================================
155
+ # FSDP / bf16 helpers
156
+ # ============================================================
157
+
158
+ def _collect_fp32_critical_modules(model):
159
+ from .hyworldmirror.models.layers.mlp import MlpFP32
160
+ critical = set()
161
+ for _, module in model.named_modules():
162
+ if isinstance(module, MlpFP32) and hasattr(module, 'fc2'):
163
+ if any(p.dtype == torch.float32 for p in module.fc2.parameters()):
164
+ critical.add(module.fc2)
165
+ if hasattr(module, 'scratch') and hasattr(module.scratch, 'output_conv2'):
166
+ oc2 = module.scratch.output_conv2
167
+ if any(p.dtype == torch.float32 for p in oc2.parameters()):
168
+ critical.add(oc2)
169
+ return critical
170
+
171
+
172
+ def _cast_noncritical_fp32_to_bf16(model, critical_modules):
173
+ critical_ids = {id(p) for mod in critical_modules for p in mod.parameters()}
174
+ cast = []
175
+ for name, param in model.named_parameters():
176
+ if param.dtype == torch.float32 and id(param) not in critical_ids:
177
+ param.data = param.data.to(torch.bfloat16)
178
+ cast.append(name)
179
+ for _, buf in model.named_buffers():
180
+ if buf.dtype == torch.float32:
181
+ buf.data = buf.data.to(torch.bfloat16)
182
+
183
+ def _hook(module, args):
184
+ if not args:
185
+ return args
186
+ dtype = next((p.dtype for p in module.parameters(recurse=False)), None)
187
+ if dtype is None:
188
+ return args
189
+ return tuple(a.to(dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() and a.dtype != dtype else a
190
+ for a in args)
191
+
192
+ for name, module in model.named_modules():
193
+ if not any(True for _ in module.children()):
194
+ own = list(module.named_parameters(recurse=False))
195
+ if own and all(p.dtype == torch.bfloat16 for _, p in own):
196
+ pfx = name + "." if name else ""
197
+ if any(c.startswith(pfx) for c in cast):
198
+ module.register_forward_pre_hook(_hook)
199
+
200
+
201
+ def _wrap_model_fsdp(model, sp_group, device, use_cpu_offload=False, enable_bf16=False):
202
+ wrap_cls = {DistBlock, Block, DPTHead, CameraHead}
203
+ if enable_bf16:
204
+ fp32_critical = _collect_fp32_critical_modules(model)
205
+ def policy(module, recurse, nonwrapped_numel, **kw):
206
+ if recurse:
207
+ return True
208
+ return isinstance(module, tuple(wrap_cls)) or module in fp32_critical
209
+ auto_wrap_policy = policy
210
+ else:
211
+ auto_wrap_policy = functools.partial(
212
+ transformer_auto_wrap_policy, transformer_layer_cls=wrap_cls)
213
+
214
+ fsdp_model = FSDP(
215
+ model, process_group=sp_group,
216
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
217
+ auto_wrap_policy=auto_wrap_policy, mixed_precision=None,
218
+ cpu_offload=CPUOffload(offload_params=True) if use_cpu_offload else None,
219
+ device_id=device, use_orig_params=True, sync_module_states=True,
220
+ forward_prefetch=False,
221
+ )
222
+
223
+ rank = dist.get_rank()
224
+ if rank == 0:
225
+ total = sum(p.numel() for p in fsdp_model.parameters())
226
+ local = sum(getattr(p, '_local_tensor', p).numel() for p in fsdp_model.parameters())
227
+ print(f"[FSDP] total={total/1e6:.1f}M, local≈{local/1e6:.1f}M")
228
+ return fsdp_model
229
+
230
+
231
+ # ============================================================
232
+ # WorldMirrorPipeline
233
+ # ============================================================
234
+
235
+ class WorldMirrorPipeline:
236
+ """HunyuanWorld-Mirror inference pipeline.
237
+
238
+ Supports single-GPU and multi-GPU (Sequence Parallel) inference with
239
+ a unified API. Multi-GPU mode is auto-detected from torch.distributed.
240
+ """
241
+
242
+ def __init__(self, model, device, sp_size=1, sp_group=None, rank=0):
243
+ self.model = model
244
+ self.device = device
245
+ self.sp_size = sp_size
246
+ self.sp_group = sp_group
247
+ self.rank = rank
248
+
249
+ @classmethod
250
+ def from_pretrained(
251
+ cls,
252
+ pretrained_model_name_or_path: str = "tencent/HY-World-2.0",
253
+ *,
254
+ subfolder: str = "HY-WorldMirror-2.0",
255
+ config_path: str = None,
256
+ ckpt_path: str = None,
257
+ use_fsdp: bool = False,
258
+ enable_bf16: bool = False,
259
+ fsdp_cpu_offload: bool = False,
260
+ disable_heads: list = None,
261
+ ) -> "WorldMirrorPipeline":
262
+ """Load model and create pipeline instance.
263
+
264
+ Automatically detects distributed mode (torchrun sets WORLD_SIZE).
265
+
266
+ Args:
267
+ pretrained_model_name_or_path: HuggingFace repo ID or local path.
268
+ The model files are expected under ``{path}/{subfolder}/``.
269
+ subfolder: Subfolder inside the repo that contains the WorldMirror
270
+ checkpoint (model.safetensors + config).
271
+ config_path: Training config YAML (used with ckpt_path).
272
+ ckpt_path: Checkpoint file (.ckpt / .safetensors).
273
+ use_fsdp: Shard parameters across GPUs via FSDP.
274
+ enable_bf16: Use bf16 precision (except critical layers).
275
+ fsdp_cpu_offload: Offload FSDP params to CPU.
276
+ disable_heads: List of heads to disable, e.g. ["camera", "depth"].
277
+ """
278
+ is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
279
+
280
+ if is_distributed:
281
+ if not dist.is_initialized():
282
+ dist.init_process_group(backend="nccl")
283
+ rank = dist.get_rank()
284
+ world_size = dist.get_world_size()
285
+ local_rank = int(os.environ.get("LOCAL_RANK", rank))
286
+ torch.cuda.set_device(local_rank)
287
+ device = torch.device("cuda", local_rank)
288
+ sp_size = world_size
289
+ sp_group = dist.new_group(ranks=list(range(sp_size)))
290
+ if rank == 0:
291
+ print(f"[Pipeline] Multi-GPU: world_size={world_size}, sp_size={sp_size}")
292
+ else:
293
+ rank, sp_size = 0, 1
294
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
295
+ sp_group = None
296
+ if use_fsdp:
297
+ print("[Pipeline] Warning: use_fsdp is ignored in single-GPU mode (FSDP requires torchrun with multiple GPUs)")
298
+ use_fsdp = False
299
+ print("[Pipeline] Single-GPU mode")
300
+
301
+ # Load model
302
+ t0 = time.perf_counter()
303
+ if config_path and ckpt_path:
304
+ print(f"[Init] config={config_path}, ckpt={ckpt_path}, sp_size={sp_size}")
305
+ cfg = OmegaConf.load(config_path)
306
+ model_cfg = _get_model_config_from_yaml(cfg)
307
+ if sp_size > 1:
308
+ model_cfg["sp_size"] = sp_size
309
+ if enable_bf16:
310
+ model_cfg["enable_bf16"] = True
311
+ model = WorldMirror(**model_cfg).to(device)
312
+ state = _load_checkpoint_state_dict(ckpt_path)
313
+ _load_state_dict_selective(model, state, source_name=ckpt_path)
314
+ del state; gc.collect(); torch.cuda.empty_cache()
315
+ else:
316
+ model_dir = _resolve_model_dir(pretrained_model_name_or_path, subfolder)
317
+ model_cfg = _load_model_config(model_dir)
318
+ if sp_size > 1:
319
+ model_cfg["sp_size"] = sp_size
320
+ if enable_bf16:
321
+ model_cfg["enable_bf16"] = True
322
+ model = WorldMirror(**model_cfg).to(device)
323
+ state = load_safetensors(os.path.join(model_dir, "model.safetensors"))
324
+ _load_state_dict_selective(model, state, source_name=model_dir)
325
+ del state; gc.collect(); torch.cuda.empty_cache()
326
+
327
+ # bf16 casting — two strategies depending on FSDP:
328
+ #
329
+ # Multi-GPU + FSDP: cast everything to bf16 uniformly (including fc2).
330
+ # FSDP requires uniform dtype per flat-param unit.
331
+ #
332
+ # Single GPU (no FSDP): cast to bf16, then restore critical fp32
333
+ # modules (MlpFP32.fc2, output_conv2) so their .float() calls work.
334
+ # Register input-cast hooks on bf16 leaf modules for dtype boundaries.
335
+ if enable_bf16:
336
+ if use_fsdp and is_distributed:
337
+ model.to(torch.bfloat16)
338
+ crit = _collect_fp32_critical_modules(model)
339
+ _cast_noncritical_fp32_to_bf16(model, crit)
340
+ else:
341
+ crit = _collect_fp32_critical_modules(model)
342
+ model.to(torch.bfloat16)
343
+ for mod in crit:
344
+ mod.to(torch.float32)
345
+
346
+ def _input_cast_hook(module, args):
347
+ if not args:
348
+ return args
349
+ dtype = next((p.dtype for p in module.parameters(recurse=False)), None)
350
+ if dtype is None:
351
+ return args
352
+ return tuple(
353
+ a.to(dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() and a.dtype != dtype else a
354
+ for a in args
355
+ )
356
+
357
+ for _, module in model.named_modules():
358
+ if not any(True for _ in module.children()):
359
+ own = list(module.parameters(recurse=False))
360
+ if own and all(p.dtype == torch.bfloat16 for p in own):
361
+ module.register_forward_pre_hook(_input_cast_hook)
362
+
363
+ model.eval()
364
+
365
+ # Disable unused heads
366
+ if disable_heads:
367
+ _disable_heads(model, disable_heads)
368
+
369
+ # FSDP wrapping
370
+ if use_fsdp and is_distributed:
371
+ model = _wrap_model_fsdp(model, sp_group, device,
372
+ use_cpu_offload=fsdp_cpu_offload,
373
+ enable_bf16=enable_bf16)
374
+ if enable_bf16:
375
+ inner = model.module if hasattr(model, 'module') else model
376
+ inner.to = lambda *a, **kw: inner
377
+
378
+ if rank == 0:
379
+ print(f"[Init] Model ready in {time.perf_counter() - t0:.1f}s")
380
+ if torch.cuda.is_available():
381
+ alloc = torch.cuda.memory_allocated(device) / (1024**3)
382
+ print(f"[Memory] allocated={alloc:.2f}GB")
383
+
384
+ return cls(model, device, sp_size, sp_group, rank)
385
+
386
+ @torch.no_grad()
387
+ def __call__(
388
+ self,
389
+ input_path: str,
390
+ output_path: str = "inference_output",
391
+ *,
392
+ # Inference
393
+ target_size: int = 952,
394
+ fps: int = 1,
395
+ video_strategy: str = "new",
396
+ video_min_frames: int = 1,
397
+ video_max_frames: int = 32,
398
+ # Save
399
+ save_depth: bool = True,
400
+ save_normal: bool = True,
401
+ save_gs: bool = True,
402
+ save_camera: bool = True,
403
+ save_points: bool = True,
404
+ save_colmap: bool = False,
405
+ save_conf: bool = False,
406
+ # Mask
407
+ apply_sky_mask: bool = True,
408
+ apply_edge_mask: bool = True,
409
+ apply_confidence_mask: bool = False,
410
+ save_sky_mask: bool = False,
411
+ sky_mask_source: str = "auto",
412
+ model_sky_threshold: float = 0.45,
413
+ confidence_percentile: float = 10.0,
414
+ edge_normal_threshold: float = 1.0,
415
+ edge_depth_threshold: float = 0.03,
416
+ # Compression
417
+ compress_pts: bool = True,
418
+ compress_pts_max_points: int = 2_000_000,
419
+ compress_pts_voxel_size: float = 0.002,
420
+ max_resolution: int = 1920,
421
+ compress_gs_max_points: int = 5_000_000,
422
+ # Prior
423
+ prior_cam_path: str = None,
424
+ prior_depth_path: str = None,
425
+ # Rendered video
426
+ save_rendered: bool = False,
427
+ render_interp_per_pair: int = 15,
428
+ render_depth: bool = False,
429
+ # Misc
430
+ log_time: bool = True,
431
+ strict_output_path: str = None,
432
+ ) -> str:
433
+ """Run inference on images/video and save results.
434
+
435
+ Args:
436
+ input_path: Directory of images or a video file.
437
+ output_path: Root output directory.
438
+ **kwargs: Override default inference parameters.
439
+
440
+ Returns:
441
+ Path to the output directory (str), or None on skip.
442
+ """
443
+ model = self.model
444
+ device = self.device
445
+ sp_size, sp_group, rank = self.sp_size, self.sp_group, self.rank
446
+ is_distributed = sp_size > 1
447
+
448
+ case_t0 = time.perf_counter()
449
+ timings = {}
450
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
451
+
452
+ # 1. Prepare input
453
+ t0 = time.perf_counter()
454
+ img_paths, subdir_name = prepare_input(
455
+ input_path, target_size=target_size, fps=fps,
456
+ video_strategy=video_strategy,
457
+ min_frames=video_min_frames, max_frames=video_max_frames,
458
+ )
459
+ if log_time:
460
+ timings["data_loading"] = time.perf_counter() - t0
461
+
462
+ if strict_output_path is not None:
463
+ outdir = Path(strict_output_path)
464
+ else:
465
+ outdir = Path(output_path) / subdir_name / timestamp
466
+
467
+ # 2. Adaptive resolution
468
+ effective = compute_adaptive_target_size(img_paths, target_size)
469
+ if rank == 0 and effective != target_size:
470
+ print(f"[Inference] Adaptive resolution: {effective} (max={target_size})")
471
+
472
+ # 3. Inference
473
+ if torch.cuda.is_available():
474
+ torch.cuda.reset_peak_memory_stats(device)
475
+ torch.cuda.synchronize(device)
476
+
477
+ t0_all = time.perf_counter()
478
+ try:
479
+ predictions, imgs, infer_time = self._run_inference(
480
+ img_paths, effective, prior_cam_path, prior_depth_path)
481
+ except ValueError as e:
482
+ if rank == 0:
483
+ print(f"[Pipeline] Skipping '{input_path}': {e}")
484
+ return None
485
+
486
+ if log_time:
487
+ timings["inference"] = infer_time
488
+ timings["inference_preprocess"] = time.perf_counter() - t0_all - infer_time
489
+
490
+ # GPU memory stats (multi-GPU)
491
+ if log_time and torch.cuda.is_available() and is_distributed:
492
+ peak = torch.cuda.max_memory_allocated(device) / (1024**3)
493
+ peak_t = torch.tensor([peak], dtype=torch.float64, device=device)
494
+ gathered = [torch.zeros(1, dtype=torch.float64, device=device) for _ in range(sp_size)]
495
+ dist.all_gather(gathered, peak_t, group=sp_group)
496
+ timings["gpu_mem_peak_per_rank_gb"] = [t.item() for t in gathered]
497
+ timings["gpu_mem_peak_avg_gb"] = sum(timings["gpu_mem_peak_per_rank_gb"]) / sp_size
498
+
499
+ # 4. Post-processing and saving (rank 0 only)
500
+ if rank == 0:
501
+ B, S, C, H, W = imgs.shape
502
+ t0 = time.perf_counter()
503
+
504
+ sky_mask = (compute_sky_mask(
505
+ img_paths, H, W, S, predictions=predictions,
506
+ source=sky_mask_source, model_threshold=model_sky_threshold,
507
+ processed_aspect_ratio=W / H,
508
+ ) if apply_sky_mask else None)
509
+
510
+ filter_mask, gs_filter_mask = None, None
511
+ if apply_confidence_mask or apply_edge_mask or apply_sky_mask:
512
+ filter_mask, gs_filter_mask = compute_filter_mask(
513
+ predictions, imgs, img_paths, H, W, S,
514
+ apply_confidence_mask=apply_confidence_mask,
515
+ apply_edge_mask=apply_edge_mask,
516
+ apply_sky_mask=apply_sky_mask,
517
+ confidence_percentile=confidence_percentile,
518
+ edge_normal_threshold=edge_normal_threshold,
519
+ edge_depth_threshold=edge_depth_threshold,
520
+ sky_mask=sky_mask, use_gs_depth=save_gs,
521
+ )
522
+
523
+ if log_time:
524
+ timings["compute_mask"] = time.perf_counter() - t0
525
+
526
+ t0 = time.perf_counter()
527
+ save_timings = save_results(
528
+ predictions, imgs, img_paths, outdir,
529
+ save_depth=save_depth, save_normal=save_normal,
530
+ save_gs=save_gs, save_camera=save_camera,
531
+ save_points=save_points, save_colmap=save_colmap,
532
+ save_sky_mask=save_sky_mask, save_conf=save_conf,
533
+ log_time=log_time, max_resolution=max_resolution,
534
+ filter_mask=filter_mask, gs_filter_mask=gs_filter_mask,
535
+ sky_mask=sky_mask,
536
+ compress_pts=compress_pts,
537
+ compress_pts_max_points=compress_pts_max_points,
538
+ compress_pts_voxel_size=compress_pts_voxel_size,
539
+ compress_gs_max_points=compress_gs_max_points,
540
+ )
541
+ if log_time:
542
+ timings.update(save_timings or {})
543
+ timings["save_total_wall"] = time.perf_counter() - t0
544
+
545
+ # Render interpolated video from Gaussian splats
546
+ if save_rendered and "splats" in predictions:
547
+ inner_model = model.module if hasattr(model, 'module') else model
548
+ if hasattr(inner_model, 'gs_renderer'):
549
+ t0_render = time.perf_counter()
550
+ try:
551
+ splats_f32 = {k: v.float() if isinstance(v, torch.Tensor) else v
552
+ for k, v in predictions["splats"].items()}
553
+ render_interpolated_video(
554
+ inner_model.gs_renderer,
555
+ splats_f32,
556
+ predictions["camera_poses"].float(),
557
+ predictions["camera_intrs"].float(),
558
+ (H, W),
559
+ outdir / "rendered",
560
+ interp_per_pair=render_interp_per_pair,
561
+ loop_reverse=(S <= 2),
562
+ render_depth=render_depth,
563
+ )
564
+ if log_time:
565
+ timings["render_video"] = time.perf_counter() - t0_render
566
+ except Exception as e:
567
+ print(f"[Pipeline] Warning: video rendering failed: {e}")
568
+
569
+ if not is_distributed:
570
+ del predictions
571
+ torch.cuda.empty_cache()
572
+
573
+ timings["case_total"] = time.perf_counter() - case_t0
574
+ if log_time:
575
+ print_and_save_timings(timings, outdir)
576
+
577
+ print(f"\n{'='*60}\n[Pipeline] Results saved to: {outdir}\n{'='*60}\n")
578
+
579
+ if is_distributed:
580
+ del predictions, imgs
581
+ gc.collect()
582
+ torch.cuda.empty_cache()
583
+ dist.barrier()
584
+
585
+ return str(outdir)
586
+
587
+ def _run_inference(self, img_paths, target_size, prior_cam_path, prior_depth_path):
588
+ """Run model forward pass."""
589
+ device = self.device
590
+ imgs = prepare_images_to_tensor(
591
+ img_paths, target_size=target_size, resize_strategy="crop"
592
+ ).to(device)
593
+ views = {"img": imgs}
594
+ B, S, C, H, W = imgs.shape
595
+
596
+ if self.sp_size > 1 and S < self.sp_size:
597
+ raise ValueError(
598
+ f"Number of input images ({S}) must be >= number of GPUs ({self.sp_size}) "
599
+ f"in multi-GPU mode. Please provide at least {self.sp_size} images, "
600
+ f"or use fewer GPUs."
601
+ )
602
+
603
+ if self.rank == 0:
604
+ print(f"[Inference] {S} images, shape={imgs.shape}, sp_size={self.sp_size}")
605
+
606
+ pp_xform = compute_preprocessing_transform(img_paths, target_size)
607
+ cond_flags = [0, 0, 0]
608
+
609
+ if prior_cam_path and os.path.isfile(prior_cam_path):
610
+ extr, intr = load_prior_camera(prior_cam_path, img_paths, preprocess_transform=pp_xform)
611
+ if extr is not None:
612
+ first = extr[0, 0]
613
+ extr = torch.linalg.inv(first.float()).to(first.dtype).unsqueeze(0).unsqueeze(0) @ extr
614
+ views["camera_poses"] = extr.to(device)
615
+ cond_flags[0] = 1
616
+ if intr is not None:
617
+ views["camera_intrs"] = intr.to(device)
618
+ cond_flags[2] = 1
619
+
620
+ if prior_depth_path and os.path.isdir(prior_depth_path):
621
+ depth = load_prior_depth(prior_depth_path, img_paths, H, W, preprocess_transform=pp_xform)
622
+ if depth is not None:
623
+ views["depthmap"] = depth.to(device)
624
+ cond_flags[1] = 1
625
+
626
+ use_amp = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
627
+ inner = self.model.module if hasattr(self.model, 'module') else self.model
628
+ model_bf16 = getattr(inner, 'enable_bf16', False)
629
+
630
+ t0 = time.perf_counter()
631
+ with torch.amp.autocast("cuda", enabled=(not model_bf16 and use_amp), dtype=torch.bfloat16):
632
+ fwd_kw = dict(views=views, cond_flags=cond_flags, is_inference=True)
633
+ if self.sp_size > 1:
634
+ fwd_kw["sp_size"] = self.sp_size
635
+ fwd_kw["sp_group"] = self.sp_group
636
+ predictions = self.model(**fwd_kw)
637
+ if device.type == "cuda":
638
+ torch.cuda.synchronize()
639
+ infer_time = time.perf_counter() - t0
640
+
641
+ if self.rank == 0:
642
+ print(f"[Inference] Done in {infer_time:.2f}s")
643
+ return predictions, imgs, infer_time
644
+
645
+
646
+ # ============================================================
647
+ # Head disabling helper
648
+ # ============================================================
649
+
650
+ def _disable_heads(model, head_names):
651
+ """Disable and free specified heads. head_names: list of 'camera','depth','normal','points','gs'."""
652
+ mapping = {
653
+ "camera": ("enable_cam", ["cam_head"]),
654
+ "depth": ("enable_depth", ["depth_head"]),
655
+ "normal": ("enable_norm", ["norm_head"]),
656
+ "points": ("enable_pts", ["pts_head"]),
657
+ "gs": ("enable_gs", ["gs_head", "gs_renderer"]),
658
+ }
659
+ freed = 0
660
+ for name in head_names:
661
+ if name not in mapping:
662
+ continue
663
+ attr, modules = mapping[name]
664
+ setattr(model, attr, False)
665
+ for mod_name in modules:
666
+ if hasattr(model, mod_name):
667
+ mod = getattr(model, mod_name)
668
+ freed += sum(p.numel() for p in mod.parameters())
669
+ mod.cpu()
670
+ delattr(model, mod_name)
671
+ del mod
672
+ if freed:
673
+ gc.collect()
674
+ torch.cuda.empty_cache()
675
+ print(f"[Init] Disabled heads: {head_names}, freed ~{freed/1e6:.1f}M params")
676
+
677
+
678
+ # ============================================================
679
+ # CLI entry point
680
+ # ============================================================
681
+
682
+ def _broadcast_string(s, rank, src=0):
683
+ if rank == src:
684
+ data = s.encode("utf-8")
685
+ length = torch.tensor([len(data)], dtype=torch.long, device="cuda")
686
+ else:
687
+ length = torch.tensor([0], dtype=torch.long, device="cuda")
688
+ dist.broadcast(length, src=src)
689
+ n = length.item()
690
+ tensor = torch.tensor(list(data), dtype=torch.uint8, device="cuda") if rank == src else torch.empty(n, dtype=torch.uint8, device="cuda")
691
+ dist.broadcast(tensor, src=src)
692
+ return tensor.cpu().numpy().tobytes().decode("utf-8")
693
+
694
+
695
+ def main():
696
+ parser = argparse.ArgumentParser(description="HunyuanWorld-Mirror Pipeline")
697
+ parser.add_argument("--input_path", type=str, required=True)
698
+ parser.add_argument("--output_path", type=str, default="inference_output")
699
+ parser.add_argument("--strict_output_path", type=str, default=None,
700
+ help="If set, save results directly to this path (no subdir/timestamp)")
701
+ parser.add_argument("--pretrained_model_name_or_path", type=str, default="tencent/HY-World-2.0",
702
+ help="HuggingFace repo ID or local path")
703
+ parser.add_argument("--subfolder", type=str, default="HY-WorldMirror-2.0",
704
+ help="Subfolder inside the repo containing WorldMirror weights")
705
+ parser.add_argument("--config_path", type=str, default=None)
706
+ parser.add_argument("--ckpt_path", type=str, default=None)
707
+ parser.add_argument("--use_fsdp", action="store_true", default=False)
708
+ parser.add_argument("--enable_bf16", action="store_true", default=False)
709
+ parser.add_argument("--fsdp_cpu_offload", action="store_true", default=False)
710
+ parser.add_argument("--target_size", type=int, default=952)
711
+ parser.add_argument("--fps", type=int, default=1)
712
+ parser.add_argument("--video_strategy", type=str, default="new", choices=["old", "new"])
713
+ parser.add_argument("--video_min_frames", type=int, default=1)
714
+ parser.add_argument("--video_max_frames", type=int, default=32)
715
+ parser.add_argument("--no_save_depth", action="store_true")
716
+ parser.add_argument("--no_save_normal", action="store_true")
717
+ parser.add_argument("--no_save_gs", action="store_true")
718
+ parser.add_argument("--no_save_camera", action="store_true")
719
+ parser.add_argument("--no_save_points", action="store_true")
720
+ parser.add_argument("--save_colmap", action="store_true", default=False)
721
+ parser.add_argument("--save_conf", action="store_true", default=False)
722
+ parser.add_argument("--save_sky_mask", action="store_true", default=False)
723
+ parser.add_argument("--apply_sky_mask", action="store_true", default=True)
724
+ parser.add_argument("--no_sky_mask", dest="apply_sky_mask", action="store_false")
725
+ parser.add_argument("--apply_edge_mask", action="store_true", default=True)
726
+ parser.add_argument("--no_edge_mask", dest="apply_edge_mask", action="store_false")
727
+ parser.add_argument("--apply_confidence_mask", action="store_true", default=False)
728
+ parser.add_argument("--sky_mask_source", type=str, default="auto", choices=["auto", "model", "onnx"])
729
+ parser.add_argument("--model_sky_threshold", type=float, default=0.45)
730
+ parser.add_argument("--confidence_percentile", type=float, default=10.0)
731
+ parser.add_argument("--edge_normal_threshold", type=float, default=1.0)
732
+ parser.add_argument("--edge_depth_threshold", type=float, default=0.03)
733
+ parser.add_argument("--compress_pts", action="store_true", default=True)
734
+ parser.add_argument("--no_compress_pts", dest="compress_pts", action="store_false")
735
+ parser.add_argument("--compress_pts_max_points", type=int, default=2_000_000)
736
+ parser.add_argument("--compress_pts_voxel_size", type=float, default=0.002)
737
+ parser.add_argument("--max_resolution", type=int, default=1920)
738
+ parser.add_argument("--compress_gs_max_points", type=int, default=5_000_000)
739
+ parser.add_argument("--prior_cam_path", type=str, default=None)
740
+ parser.add_argument("--prior_depth_path", type=str, default=None)
741
+ parser.add_argument("--disable_heads", type=str, nargs="*", default=None,
742
+ help="Heads to disable: camera depth normal points gs")
743
+ parser.add_argument("--save_rendered", action="store_true", default=False,
744
+ help="Render interpolated video from Gaussian splats")
745
+ parser.add_argument("--render_interp_per_pair", type=int, default=15,
746
+ help="Interpolated frames per camera pair for video rendering")
747
+ parser.add_argument("--render_depth", action="store_true", default=False,
748
+ help="Also render depth video")
749
+ parser.add_argument("--log_time", action="store_true", default=True)
750
+ parser.add_argument("--no_log_time", dest="log_time", action="store_false")
751
+ parser.add_argument("--no_interactive", action="store_true")
752
+ args = parser.parse_args()
753
+
754
+ pipeline = WorldMirrorPipeline.from_pretrained(
755
+ pretrained_model_name_or_path=args.pretrained_model_name_or_path,
756
+ subfolder=args.subfolder,
757
+ config_path=args.config_path, ckpt_path=args.ckpt_path,
758
+ use_fsdp=args.use_fsdp, enable_bf16=args.enable_bf16,
759
+ fsdp_cpu_offload=args.fsdp_cpu_offload,
760
+ disable_heads=args.disable_heads,
761
+ )
762
+
763
+ call_kwargs = dict(
764
+ output_path=args.output_path,
765
+ target_size=args.target_size, fps=args.fps,
766
+ video_strategy=args.video_strategy,
767
+ video_min_frames=args.video_min_frames,
768
+ video_max_frames=args.video_max_frames,
769
+ save_depth=not args.no_save_depth,
770
+ save_normal=not args.no_save_normal,
771
+ save_gs=not args.no_save_gs,
772
+ save_camera=not args.no_save_camera,
773
+ save_points=not args.no_save_points,
774
+ save_colmap=args.save_colmap,
775
+ save_conf=args.save_conf,
776
+ save_sky_mask=args.save_sky_mask,
777
+ apply_sky_mask=args.apply_sky_mask,
778
+ apply_edge_mask=args.apply_edge_mask,
779
+ apply_confidence_mask=args.apply_confidence_mask,
780
+ sky_mask_source=args.sky_mask_source,
781
+ model_sky_threshold=args.model_sky_threshold,
782
+ confidence_percentile=args.confidence_percentile,
783
+ edge_normal_threshold=args.edge_normal_threshold,
784
+ edge_depth_threshold=args.edge_depth_threshold,
785
+ compress_pts=args.compress_pts,
786
+ compress_pts_max_points=args.compress_pts_max_points,
787
+ compress_pts_voxel_size=args.compress_pts_voxel_size,
788
+ max_resolution=args.max_resolution,
789
+ compress_gs_max_points=args.compress_gs_max_points,
790
+ prior_cam_path=args.prior_cam_path,
791
+ prior_depth_path=args.prior_depth_path,
792
+ save_rendered=args.save_rendered,
793
+ render_interp_per_pair=args.render_interp_per_pair,
794
+ render_depth=args.render_depth,
795
+ log_time=args.log_time,
796
+ strict_output_path=args.strict_output_path,
797
+ )
798
+
799
+ try:
800
+ pipeline(args.input_path, **call_kwargs)
801
+
802
+ if args.no_interactive:
803
+ return
804
+
805
+ rank = pipeline.rank
806
+ is_distributed = pipeline.sp_size > 1
807
+
808
+ if rank == 0:
809
+ print("\n[Interactive] Enter new input paths. Type 'quit' to stop.\n")
810
+
811
+ _INF_TIMEOUT = timedelta(days=365)
812
+ _DEF_TIMEOUT = timedelta(minutes=10)
813
+
814
+ while True:
815
+ if is_distributed:
816
+ dist.distributed_c10d._get_default_group()._get_backend(
817
+ torch.device("cuda")).options._timeout = _INF_TIMEOUT
818
+
819
+ new_input = ""
820
+ if rank == 0:
821
+ try:
822
+ new_input = input(">>> ").strip()
823
+ except (EOFError, KeyboardInterrupt):
824
+ new_input = "quit"
825
+
826
+ if is_distributed:
827
+ new_input = _broadcast_string(new_input, rank, src=0)
828
+ dist.distributed_c10d._get_default_group()._get_backend(
829
+ torch.device("cuda")).options._timeout = _DEF_TIMEOUT
830
+
831
+ if not new_input or new_input.lower() in ("quit", "exit", "q"):
832
+ break
833
+
834
+ if rank == 0 and not (Path(new_input).is_dir() or Path(new_input).is_file()):
835
+ print(f" Invalid path: {new_input}")
836
+ continue
837
+
838
+ pipeline.model.to(pipeline.device)
839
+ pipeline.model.eval()
840
+ pipeline(new_input, **call_kwargs)
841
+ finally:
842
+ if dist.is_initialized():
843
+ dist.destroy_process_group()
844
+
845
+
846
+ if __name__ == "__main__":
847
+ main()