Spaces:
Running on Zero
Running on Zero
update app --files
Browse files- hyworldmirror/__init__.py +0 -0
- hyworldmirror/comm/__init__.py +0 -0
- hyworldmirror/comm/communication.py +61 -0
- hyworldmirror/comm/padding.py +134 -0
- hyworldmirror/models/__init__.py +0 -0
- hyworldmirror/models/heads/__init__.py +0 -0
- hyworldmirror/models/heads/camera_head.py +184 -0
- hyworldmirror/models/heads/dense_head.py +672 -0
- hyworldmirror/models/heads/gs_head.py +83 -0
- hyworldmirror/models/layers/__init__.py +5 -0
- hyworldmirror/models/layers/attention.py +131 -0
- hyworldmirror/models/layers/block.py +269 -0
- hyworldmirror/models/layers/drop_path.py +29 -0
- hyworldmirror/models/layers/layer_scale.py +17 -0
- hyworldmirror/models/layers/mlp.py +64 -0
- hyworldmirror/models/layers/norm_rope.py +140 -0
- hyworldmirror/models/layers/patch_embed.py +155 -0
- hyworldmirror/models/layers/rope.py +182 -0
- hyworldmirror/models/layers/swiglu_ffn.py +46 -0
- hyworldmirror/models/layers/vision_transformer.py +394 -0
- hyworldmirror/models/models/__init__.py +0 -0
- hyworldmirror/models/models/rasterization.py +525 -0
- hyworldmirror/models/models/visual_transformer.py +542 -0
- hyworldmirror/models/models/worldmirror.py +685 -0
- hyworldmirror/models/utils/__init__.py +0 -0
- hyworldmirror/models/utils/act_gs.py +22 -0
- hyworldmirror/models/utils/camera_utils.py +75 -0
- hyworldmirror/models/utils/frustum.py +196 -0
- hyworldmirror/models/utils/geometry.py +111 -0
- hyworldmirror/models/utils/grid.py +90 -0
- hyworldmirror/models/utils/priors.py +168 -0
- hyworldmirror/models/utils/rotation.py +126 -0
- hyworldmirror/models/utils/sh_utils.py +116 -0
- hyworldmirror/utils/__init__.py +0 -0
- hyworldmirror/utils/geometry.py +531 -0
- hyworldmirror/utils/inference_utils.py +824 -0
- hyworldmirror/utils/render_utils.py +294 -0
- hyworldmirror/utils/save_utils.py +261 -0
- hyworldmirror/utils/video_utils.py +557 -0
- hyworldmirror/utils/visual_util.py +617 -0
- hyworldmirror/utils/warnings.py +29 -0
- pipeline.py +847 -0
hyworldmirror/__init__.py
ADDED
|
File without changes
|
hyworldmirror/comm/__init__.py
ADDED
|
File without changes
|
hyworldmirror/comm/communication.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
|
| 4 |
+
def all2all(tensor, scatter_dim, gather_dim, cur_group, async_op):
|
| 5 |
+
group_size = dist.get_world_size(group=cur_group)
|
| 6 |
+
scatter_tensor_list = list(chunk.contiguous() for chunk in torch.chunk(tensor, chunks=group_size, dim=scatter_dim))
|
| 7 |
+
gather_tensor_list = [torch.zeros_like(x) for x in scatter_tensor_list]
|
| 8 |
+
comm = dist.all_to_all(gather_tensor_list, scatter_tensor_list, group=cur_group, async_op=async_op)
|
| 9 |
+
if async_op:
|
| 10 |
+
def wait():
|
| 11 |
+
comm.wait()
|
| 12 |
+
recieved_tensor = torch.cat(gather_tensor_list, dim=gather_dim).contiguous()
|
| 13 |
+
return recieved_tensor
|
| 14 |
+
return wait()
|
| 15 |
+
recieved_tensor = torch.cat(gather_tensor_list, dim=gather_dim).contiguous()
|
| 16 |
+
return recieved_tensor
|
| 17 |
+
|
| 18 |
+
def all_gather(tensor, gather_dim, cur_group, async_op):
|
| 19 |
+
tensor = tensor.contiguous()
|
| 20 |
+
group_size = dist.get_world_size(group=cur_group)
|
| 21 |
+
gather_list = [torch.zeros_like(tensor) for _ in range(group_size)]
|
| 22 |
+
comm = dist.all_gather(gather_list, tensor, group=cur_group, async_op=async_op)
|
| 23 |
+
gather_tensor = torch.cat(gather_list, dim=gather_dim)
|
| 24 |
+
if async_op:
|
| 25 |
+
def wait():
|
| 26 |
+
comm.wait()
|
| 27 |
+
gather_tensor = torch.cat(gather_list, dim=gather_dim)
|
| 28 |
+
return gather_tensor
|
| 29 |
+
return wait()
|
| 30 |
+
return gather_tensor
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class _All2All(torch.autograd.Function):
|
| 34 |
+
@staticmethod
|
| 35 |
+
def forward(ctx, tensor, scatter_dim, gather_dim, cur_group, async_op):
|
| 36 |
+
ctx.cur_group = cur_group
|
| 37 |
+
ctx.scatter_dim = scatter_dim
|
| 38 |
+
ctx.gather_dim = gather_dim
|
| 39 |
+
ctx.async_op = async_op
|
| 40 |
+
return all2all(tensor=tensor, scatter_dim=scatter_dim, gather_dim=gather_dim, cur_group=cur_group, async_op=async_op)
|
| 41 |
+
|
| 42 |
+
@staticmethod
|
| 43 |
+
def backward(ctx, grad_outputs):
|
| 44 |
+
input_t = grad_outputs
|
| 45 |
+
return (all2all(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.cur_group, False), None, None, None, None)
|
| 46 |
+
|
| 47 |
+
class _Allgather(torch.autograd.Function):
|
| 48 |
+
@staticmethod
|
| 49 |
+
def forward(ctx, tensor, gather_dim, cur_group, async_op):
|
| 50 |
+
ctx.gather_dim = gather_dim
|
| 51 |
+
ctx.cur_group = cur_group
|
| 52 |
+
ctx.async_op = async_op
|
| 53 |
+
return all_gather(tensor=tensor, gather_dim=gather_dim, cur_group=cur_group, async_op=async_op)
|
| 54 |
+
|
| 55 |
+
@staticmethod
|
| 56 |
+
def backward(ctx, grad_outputs):
|
| 57 |
+
sp_group = ctx.cur_group
|
| 58 |
+
sp_group_size = dist.get_world_size(group=sp_group)
|
| 59 |
+
rank = dist.get_rank()
|
| 60 |
+
rank_in_group = dist.get_group_rank(group=sp_group, global_rank=rank)
|
| 61 |
+
return (grad_outputs.split(grad_outputs.shape[ctx.gather_dim] // sp_group_size, dim=ctx.gather_dim)[rank_in_group], None, None, None)
|
hyworldmirror/comm/padding.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
def minimal_pad_to_divisible(tensor: torch.Tensor, sp_size: int, dim: int = 1, pad_value: float = 0.0):
|
| 5 |
+
"""
|
| 6 |
+
对三维或更高维度的tensor在指定维度进行最小化padding,使其长度能被 sp_size 整除。
|
| 7 |
+
|
| 8 |
+
Args:
|
| 9 |
+
tensor: 输入的PyTorch tensor (例如:[B, L, C] 或 [B, H, W, C] 等)。
|
| 10 |
+
sp_size: 要求的最小分割尺寸。
|
| 11 |
+
dim: 需要进行padding的维度索引(默认为 1,即第二维)。
|
| 12 |
+
pad_value: 填充的值(默认为 0.0)。
|
| 13 |
+
|
| 14 |
+
Returns:
|
| 15 |
+
padded_tensor: 填充后的 tensor。
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
current_size = tensor.size(dim)
|
| 19 |
+
|
| 20 |
+
# 计算需要填充的长度
|
| 21 |
+
# (sp_size - current_size % sp_size) % sp_size
|
| 22 |
+
# 保证了如果 current_size 已经是 sp_size 的倍数,padding_len 为 0。
|
| 23 |
+
# 否则,计算出最小的填充长度。
|
| 24 |
+
padding_len = (sp_size - current_size % sp_size) % sp_size
|
| 25 |
+
|
| 26 |
+
if padding_len == 0:
|
| 27 |
+
# 如果长度已经可以整除,直接返回原 tensor
|
| 28 |
+
return tensor, 0
|
| 29 |
+
|
| 30 |
+
# 构建 pad 元组
|
| 31 |
+
# torch.nn.functional.pad 的 pad 参数是从**最末尾的维度**开始,**成对** (后填充, 前填充) 指定的。
|
| 32 |
+
# 假设你的 tensor 是 [D0, D1, D2]
|
| 33 |
+
# 如果 dim=1 (第二维, D1),pad 应该是 (0, 0, padding_len, 0, 0, 0, ...)
|
| 34 |
+
#
|
| 35 |
+
# 由于我们需要在第二维 (dim=1) 的末尾进行填充,我们需要确定 pad 元组中对应 dim=1 的位置。
|
| 36 |
+
# 维度数量 D = tensor.dim()
|
| 37 |
+
# dim=0 对应 pad 元组的最后两位
|
| 38 |
+
# dim=1 对应 pad 元组的倒数第 4, 3 位
|
| 39 |
+
# dim=2 对应 pad 元组的倒数第 6, 5 位 (对于三维 tensor,即前两位)
|
| 40 |
+
|
| 41 |
+
# 在 dim 维度进行 '后填充' (在末尾添加)
|
| 42 |
+
# padding_dims 是一个长度为 2 * D 的元组,所有维度默认不填充
|
| 43 |
+
padding_dims = [0] * (2 * tensor.dim())
|
| 44 |
+
|
| 45 |
+
# 对应 dim 维度的 '后填充' (即 pad 元组中的偶数索引位置,从后往前数)
|
| 46 |
+
# 填充的位置是 (2 * tensor.dim() - 2 * dim - 2)
|
| 47 |
+
# 例如:D=3, dim=1 -> 2*3 - 2*1 - 2 = 2
|
| 48 |
+
# pad 元组为 (d2_start, d2_end, d1_start, d1_end, d0_start, d0_end)
|
| 49 |
+
# 我们要填充 d1_end,它在索引 2 的位置
|
| 50 |
+
|
| 51 |
+
# F.pad 要求的是 (最后维度 start, 最后维度 end, 倒数第二维度 start, 倒数第二维度 end, ...)
|
| 52 |
+
# 我们的 dim=1 是倒数第 (D - 1 - dim) + 1 = D - dim 个维度
|
| 53 |
+
# 它在 pad 元组中是倒数第 2 * (D - dim) 位和倒数第 2 * (D - dim) - 1 位
|
| 54 |
+
#
|
| 55 |
+
# 填充位置的索引 (从 0 开始, 从左往右):
|
| 56 |
+
# (2 * (tensor.dim() - dim - 1)) 是 '前填充' 的位置
|
| 57 |
+
# (2 * (tensor.dim() - dim - 1) + 1) 是 '后填充' 的位置
|
| 58 |
+
pad_index = 2 * (tensor.dim() - dim - 1) + 1
|
| 59 |
+
|
| 60 |
+
if pad_index < len(padding_dims):
|
| 61 |
+
padding_dims[pad_index] = padding_len
|
| 62 |
+
else:
|
| 63 |
+
raise ValueError("Invalid dimension index.")
|
| 64 |
+
|
| 65 |
+
# 转换回 tuple
|
| 66 |
+
pad = tuple(padding_dims)
|
| 67 |
+
|
| 68 |
+
# 使用 F.pad 进行填充,模式为 'constant'
|
| 69 |
+
padded_tensor = F.pad(tensor, pad=pad, mode='constant', value=pad_value)
|
| 70 |
+
|
| 71 |
+
return padded_tensor, padding_len
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def depad_by_length(padded_tensor: torch.Tensor, depadding_len: int, dim: int = 1) -> torch.Tensor:
|
| 75 |
+
"""
|
| 76 |
+
在指定维度上去除末尾的 padding 部分。
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
padded_tensor: 已经经过 padding 的 PyTorch tensor。
|
| 80 |
+
depadding_len: 需要从末尾去除的长度。
|
| 81 |
+
dim: 需要去除 padding 的维度索引(默认为 1,即第二维)。
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
depadded_tensor: 去除 padding 后的 tensor。
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
# 检查去除长度是否合理
|
| 88 |
+
current_size = padded_tensor.size(dim)
|
| 89 |
+
if depadding_len < 0:
|
| 90 |
+
raise ValueError("depadding_len 必须是非负数。")
|
| 91 |
+
if depadding_len > current_size:
|
| 92 |
+
raise ValueError(f"要去除的长度 {depadding_len} 大于当前维度长度 {current_size}。")
|
| 93 |
+
|
| 94 |
+
# 计算去除 padding 后的目标长度
|
| 95 |
+
target_size = current_size - depadding_len
|
| 96 |
+
|
| 97 |
+
# 构造切片操作所需的索引元组
|
| 98 |
+
# 对于所有维度,我们默认使用完整的切片 `:`
|
| 99 |
+
slices = [slice(None)] * padded_tensor.dim()
|
| 100 |
+
|
| 101 |
+
# 在指定维度 dim 上,我们只取从 0 到 target_size 的部分
|
| 102 |
+
# Python 切片 [0:target_size] 会保留 target_size 个元素,即去除了末尾的 depadding_len
|
| 103 |
+
slices[dim] = slice(0, target_size)
|
| 104 |
+
|
| 105 |
+
# 使用元组解包进行切片操作
|
| 106 |
+
depadded_tensor = padded_tensor[tuple(slices)]
|
| 107 |
+
|
| 108 |
+
return depadded_tensor
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def pad_by_length(padded_tensor: torch.Tensor, padding_len: int, dim: int = 1,pad_value: float = 0.0) -> torch.Tensor:
|
| 114 |
+
|
| 115 |
+
if padding_len < 0:
|
| 116 |
+
raise ValueError("padding_len 必须是非负数。")
|
| 117 |
+
|
| 118 |
+
if dim < 0 or dim >= padded_tensor.dim():
|
| 119 |
+
raise ValueError(f"维度索引 {dim} 超出有效范围 [0, {padded_tensor.dim() - 1}]。")
|
| 120 |
+
|
| 121 |
+
# 构建padding参数
|
| 122 |
+
# F.pad需要为每个维度指定左右两边的padding长度
|
| 123 |
+
# 格式为: (最后一个维度的左边, 最后一个维度的右边, 倒数第二个维度的左边, 倒数第二个维度的右边, ...)
|
| 124 |
+
pad_tuple = [0] * (2 * padded_tensor.dim())
|
| 125 |
+
|
| 126 |
+
# 将指定维度右边的padding长度设置为padding_len
|
| 127 |
+
# F.pad的维度顺序是从最后一个维度开始的,所以需要进行转换
|
| 128 |
+
pad_idx = 2 * (padded_tensor.dim() - 1 - dim) + 1
|
| 129 |
+
pad_tuple[pad_idx] = padding_len
|
| 130 |
+
|
| 131 |
+
# 调用F.pad进行padding
|
| 132 |
+
padded_tensor = F.pad(padded_tensor, pad=tuple(pad_tuple), mode='constant', value=pad_value)
|
| 133 |
+
|
| 134 |
+
return padded_tensor
|
hyworldmirror/models/__init__.py
ADDED
|
File without changes
|
hyworldmirror/models/heads/__init__.py
ADDED
|
File without changes
|
hyworldmirror/models/heads/camera_head.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# inspired by https://github.com/facebookresearch/vggt/blob/main/src/models/heads/camera_head.py
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
from ..layers import Mlp, MlpFP32
|
| 7 |
+
from ..layers.block import Block, DistBlock
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CameraHead(nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
Camera head module: predicts camera parameters from token representations using iterative refinement
|
| 13 |
+
|
| 14 |
+
Processes dedicated camera tokens through a series of transformer blocks
|
| 15 |
+
"""
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
dim_in: int = 2048,
|
| 19 |
+
trunk_depth: int = 4,
|
| 20 |
+
num_heads: int = 16,
|
| 21 |
+
mlp_ratio: int = 4,
|
| 22 |
+
init_values: float = 0.01,
|
| 23 |
+
trans_act: str = "linear",
|
| 24 |
+
quat_act: str = "linear",
|
| 25 |
+
fl_act: str = "relu",
|
| 26 |
+
block_fn: nn.Module = Block,
|
| 27 |
+
):
|
| 28 |
+
super().__init__()
|
| 29 |
+
|
| 30 |
+
self.out_dim = 9
|
| 31 |
+
self.trans_act = trans_act
|
| 32 |
+
self.quat_act = quat_act
|
| 33 |
+
self.fl_act = fl_act
|
| 34 |
+
self.depth = trunk_depth
|
| 35 |
+
|
| 36 |
+
# Build refinement network using transformer block sequence
|
| 37 |
+
self.refine_net = nn.Sequential(
|
| 38 |
+
*[
|
| 39 |
+
block_fn(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
|
| 40 |
+
for _ in range(trunk_depth)
|
| 41 |
+
]
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Normalization for camera tokens and network output
|
| 45 |
+
self.token_norm = nn.LayerNorm(dim_in)
|
| 46 |
+
self.out_norm = nn.LayerNorm(dim_in)
|
| 47 |
+
|
| 48 |
+
# Learnable initial camera parameter token
|
| 49 |
+
self.init_token = nn.Parameter(torch.zeros(1, 1, self.out_dim))
|
| 50 |
+
self.param_embed = nn.Linear(self.out_dim, dim_in)
|
| 51 |
+
|
| 52 |
+
# Generate adaptive normalization parameters: shift, scale, and gate
|
| 53 |
+
self.adapt_norm_gen = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
|
| 54 |
+
|
| 55 |
+
# Adaptive layer normalization (no learnable parameters)
|
| 56 |
+
self.adapt_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
|
| 57 |
+
# self.param_predictor = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.out_dim, drop=0)
|
| 58 |
+
self.param_predictor = MlpFP32(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.out_dim, drop=0)
|
| 59 |
+
|
| 60 |
+
def to(self, *args, **kwargs):
|
| 61 |
+
self.refine_net = self.refine_net.to(*args, **kwargs)
|
| 62 |
+
self.token_norm = self.token_norm.to(*args, **kwargs)
|
| 63 |
+
self.out_norm = self.out_norm.to(*args, **kwargs)
|
| 64 |
+
self.adapt_norm_gen = self.adapt_norm_gen.to(*args, **kwargs)
|
| 65 |
+
self.adapt_norm = self.adapt_norm.to(*args, **kwargs)
|
| 66 |
+
self.param_predictor = self.param_predictor.to(*args, **kwargs)
|
| 67 |
+
|
| 68 |
+
# keep these parameters in FP32
|
| 69 |
+
args, kwargs = MlpFP32.map_to_args_to_float(args, kwargs)
|
| 70 |
+
self.init_token = nn.Parameter(self.init_token.to(*args, **kwargs))
|
| 71 |
+
self.param_embed = self.param_embed.to(*args, **kwargs)
|
| 72 |
+
|
| 73 |
+
return self
|
| 74 |
+
|
| 75 |
+
def forward(self, feat_seq: list, steps: int = 4) -> list:
|
| 76 |
+
"""
|
| 77 |
+
Forward pass to predict camera parameters
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
feat_seq: List of token tensors from network, last one used for prediction
|
| 81 |
+
steps: Number of iterative refinement steps, default 4
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
List of predicted camera encodings (post-activation) from each iteration
|
| 85 |
+
"""
|
| 86 |
+
# Use tokens from last block for camera prediction
|
| 87 |
+
latest_feat = feat_seq[-1]
|
| 88 |
+
|
| 89 |
+
# Extract camera tokens
|
| 90 |
+
cam_tokens = latest_feat[:, :, 0]
|
| 91 |
+
cam_tokens = self.token_norm(cam_tokens)
|
| 92 |
+
|
| 93 |
+
# Iteratively refine camera pose predictions
|
| 94 |
+
b, seq_len, feat_dim = cam_tokens.shape # seq_len expected to be 1
|
| 95 |
+
curr_pred = None
|
| 96 |
+
pred_seq = []
|
| 97 |
+
|
| 98 |
+
for step in range(steps):
|
| 99 |
+
# Use learned initial token for first iteration
|
| 100 |
+
if curr_pred is None:
|
| 101 |
+
net_input = self.param_embed(self.init_token.expand(b, seq_len, -1))
|
| 102 |
+
else:
|
| 103 |
+
curr_pred = curr_pred.detach()
|
| 104 |
+
net_input = self.param_embed(curr_pred)
|
| 105 |
+
net_input = net_input.to(cam_tokens.dtype)
|
| 106 |
+
norm_shift, norm_scale, norm_gate = self.adapt_norm_gen(net_input).chunk(3, dim=-1)
|
| 107 |
+
mod_cam_feat = norm_gate * self.apply_adaptive_modulation(self.adapt_norm(cam_tokens), norm_shift, norm_scale)
|
| 108 |
+
mod_cam_feat = mod_cam_feat + cam_tokens
|
| 109 |
+
|
| 110 |
+
proc_feat = self.refine_net(mod_cam_feat)
|
| 111 |
+
param_delta = self.param_predictor(self.out_norm(proc_feat))
|
| 112 |
+
|
| 113 |
+
if curr_pred is None:
|
| 114 |
+
curr_pred = param_delta
|
| 115 |
+
else:
|
| 116 |
+
curr_pred = curr_pred + param_delta
|
| 117 |
+
|
| 118 |
+
# Apply final activation functions for translation, quaternion, and field-of-view
|
| 119 |
+
activated_params = self.apply_camera_parameter_activation(curr_pred)
|
| 120 |
+
pred_seq.append(activated_params)
|
| 121 |
+
|
| 122 |
+
return pred_seq
|
| 123 |
+
|
| 124 |
+
def apply_camera_parameter_activation(self, params: torch.Tensor) -> torch.Tensor:
|
| 125 |
+
"""
|
| 126 |
+
Apply activation functions to camera parameter components
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
params: Tensor containing camera parameters [translation, quaternion, focal_length]
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
Activated camera parameters tensor
|
| 133 |
+
"""
|
| 134 |
+
trans_vec = params[..., :3]
|
| 135 |
+
quat_vec = params[..., 3:7]
|
| 136 |
+
fl_vec = params[..., 7:] # or field of view
|
| 137 |
+
|
| 138 |
+
trans_vec = self.apply_parameter_activation(trans_vec, self.trans_act)
|
| 139 |
+
quat_vec = self.apply_parameter_activation(quat_vec, self.quat_act)
|
| 140 |
+
fl_vec = self.apply_parameter_activation(fl_vec, self.fl_act)
|
| 141 |
+
|
| 142 |
+
activated_params = torch.cat([trans_vec, quat_vec, fl_vec], dim=-1)
|
| 143 |
+
return activated_params
|
| 144 |
+
|
| 145 |
+
def apply_parameter_activation(self, tensor: torch.Tensor, act_type: str) -> torch.Tensor:
|
| 146 |
+
"""
|
| 147 |
+
Apply specified activation function to parameter tensor
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
tensor: Tensor containing parameter values
|
| 151 |
+
act_type: Activation type ("linear", "inv_log", "exp", "relu")
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
Activated parameter tensor
|
| 155 |
+
"""
|
| 156 |
+
if act_type == "linear":
|
| 157 |
+
return tensor
|
| 158 |
+
elif act_type == "inv_log":
|
| 159 |
+
return self.apply_inverse_logarithm_transform(tensor)
|
| 160 |
+
elif act_type == "exp":
|
| 161 |
+
return torch.exp(tensor)
|
| 162 |
+
elif act_type == "relu":
|
| 163 |
+
return F.relu(tensor)
|
| 164 |
+
else:
|
| 165 |
+
raise ValueError(f"Unknown activation_type: {act_type}")
|
| 166 |
+
|
| 167 |
+
def apply_inverse_logarithm_transform(self, x: torch.Tensor) -> torch.Tensor:
|
| 168 |
+
"""
|
| 169 |
+
Apply inverse logarithm transform: sign(y) * (exp(|y|) - 1)
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
x: Input tensor
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
Transformed tensor
|
| 176 |
+
"""
|
| 177 |
+
return torch.sign(x) * (torch.expm1(torch.abs(x)))
|
| 178 |
+
|
| 179 |
+
def apply_adaptive_modulation(self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
| 180 |
+
"""
|
| 181 |
+
Apply adaptive modulation to input tensor using scaling and shifting parameters
|
| 182 |
+
"""
|
| 183 |
+
# Modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
|
| 184 |
+
return x * (1 + scale) + shift
|
hyworldmirror/models/heads/dense_head.py
ADDED
|
@@ -0,0 +1,672 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# inspired by https://github.com/DepthAnything/Depth-Anything-V2
|
| 2 |
+
from typing import List, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.utils.checkpoint import checkpoint
|
| 8 |
+
|
| 9 |
+
from ..layers.mlp import MlpFP32
|
| 10 |
+
from ..utils.grid import create_uv_grid, position_grid_to_embed
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class _BaseDPTHead(nn.Module):
|
| 14 |
+
"""Base class with shared DPT feature extraction: projects, resize, scratch, and fusion."""
|
| 15 |
+
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
dim_in: int,
|
| 19 |
+
patch_size: int = 14,
|
| 20 |
+
features: int = 256,
|
| 21 |
+
out_channels: List[int] = [256, 512, 1024, 1024],
|
| 22 |
+
pos_embed: bool = True,
|
| 23 |
+
down_ratio: int = 1,
|
| 24 |
+
gradient_checkpoint: bool = False,
|
| 25 |
+
_cast_pos_embed_dtype: bool = True,
|
| 26 |
+
) -> None:
|
| 27 |
+
super().__init__()
|
| 28 |
+
self.patch_size = patch_size
|
| 29 |
+
self.pos_embed = pos_embed
|
| 30 |
+
self.down_ratio = down_ratio
|
| 31 |
+
self.gradient_checkpoint = gradient_checkpoint
|
| 32 |
+
self._cast_pos_embed_dtype = _cast_pos_embed_dtype
|
| 33 |
+
|
| 34 |
+
self.norm = nn.LayerNorm(dim_in)
|
| 35 |
+
self.projects = nn.ModuleList([
|
| 36 |
+
nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0)
|
| 37 |
+
for oc in out_channels
|
| 38 |
+
])
|
| 39 |
+
self.resize_layers = nn.ModuleList([
|
| 40 |
+
nn.ConvTranspose2d(
|
| 41 |
+
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
| 42 |
+
),
|
| 43 |
+
nn.ConvTranspose2d(
|
| 44 |
+
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
| 45 |
+
),
|
| 46 |
+
nn.Identity(),
|
| 47 |
+
nn.Conv2d(
|
| 48 |
+
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
| 49 |
+
),
|
| 50 |
+
])
|
| 51 |
+
self.scratch = _make_scratch(out_channels, features, expand=False)
|
| 52 |
+
self.scratch.stem_transpose = None
|
| 53 |
+
self.scratch.refinenet1 = _make_fusion_block(features)
|
| 54 |
+
self.scratch.refinenet2 = _make_fusion_block(features)
|
| 55 |
+
self.scratch.refinenet3 = _make_fusion_block(features)
|
| 56 |
+
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
|
| 57 |
+
|
| 58 |
+
head_features_1 = features
|
| 59 |
+
self.scratch.output_conv1 = nn.Conv2d(
|
| 60 |
+
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
| 64 |
+
patch_w = x.shape[-1]
|
| 65 |
+
patch_h = x.shape[-2]
|
| 66 |
+
pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
| 67 |
+
pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
|
| 68 |
+
pos_embed = pos_embed * ratio
|
| 69 |
+
pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
|
| 70 |
+
if self._cast_pos_embed_dtype:
|
| 71 |
+
pos_embed = pos_embed.to(x.dtype)
|
| 72 |
+
return x + pos_embed
|
| 73 |
+
|
| 74 |
+
def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
|
| 75 |
+
layer_1, layer_2, layer_3, layer_4 = features
|
| 76 |
+
|
| 77 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
| 78 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
| 79 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
| 80 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
| 81 |
+
|
| 82 |
+
out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
| 83 |
+
del layer_4_rn, layer_4
|
| 84 |
+
|
| 85 |
+
out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
|
| 86 |
+
del layer_3_rn, layer_3
|
| 87 |
+
|
| 88 |
+
out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
|
| 89 |
+
del layer_2_rn, layer_2
|
| 90 |
+
|
| 91 |
+
out = self.scratch.refinenet1(out, layer_1_rn)
|
| 92 |
+
del layer_1_rn, layer_1
|
| 93 |
+
|
| 94 |
+
out = self.scratch.output_conv1(out)
|
| 95 |
+
return out
|
| 96 |
+
|
| 97 |
+
def _extract_fused_features(
|
| 98 |
+
self,
|
| 99 |
+
token_list: List[torch.Tensor],
|
| 100 |
+
B: int,
|
| 101 |
+
S: int,
|
| 102 |
+
H: int,
|
| 103 |
+
W: int,
|
| 104 |
+
patch_start_idx: int,
|
| 105 |
+
frame_start: int = None,
|
| 106 |
+
frame_end: int = None,
|
| 107 |
+
) -> torch.Tensor:
|
| 108 |
+
"""Extract multi-scale features from tokens, fuse via scratch network, and upsample."""
|
| 109 |
+
ph = H // self.patch_size
|
| 110 |
+
pw = W // self.patch_size
|
| 111 |
+
|
| 112 |
+
feats = []
|
| 113 |
+
for proj, resize, tokens in zip(self.projects, self.resize_layers, token_list):
|
| 114 |
+
patch_tokens = tokens[:, :, patch_start_idx:]
|
| 115 |
+
if frame_start is not None and frame_end is not None:
|
| 116 |
+
patch_tokens = patch_tokens[:, frame_start:frame_end]
|
| 117 |
+
|
| 118 |
+
patch_tokens = patch_tokens.reshape(B * S, -1, patch_tokens.shape[-1])
|
| 119 |
+
patch_tokens = self.norm(patch_tokens)
|
| 120 |
+
|
| 121 |
+
feat = patch_tokens.permute(0, 2, 1).reshape(B * S, patch_tokens.shape[-1], ph, pw)
|
| 122 |
+
feat = proj(feat)
|
| 123 |
+
|
| 124 |
+
if self.pos_embed:
|
| 125 |
+
feat = self._apply_pos_embed(feat, W, H)
|
| 126 |
+
feat = resize(feat)
|
| 127 |
+
feats.append(feat)
|
| 128 |
+
|
| 129 |
+
fused = checkpoint(self.scratch_forward, feats, use_reentrant=False) if self.gradient_checkpoint else self.scratch_forward(feats)
|
| 130 |
+
_interpolate_fn = lambda t: custom_interpolate(
|
| 131 |
+
t,
|
| 132 |
+
size=(
|
| 133 |
+
int(ph * self.patch_size / self.down_ratio),
|
| 134 |
+
int(pw * self.patch_size / self.down_ratio)
|
| 135 |
+
),
|
| 136 |
+
mode="bilinear",
|
| 137 |
+
align_corners=True,
|
| 138 |
+
)
|
| 139 |
+
fused = checkpoint(_interpolate_fn, fused, use_reentrant=False) if self.gradient_checkpoint else _interpolate_fn(fused)
|
| 140 |
+
|
| 141 |
+
if self.pos_embed:
|
| 142 |
+
fused = self._apply_pos_embed(fused, W, H)
|
| 143 |
+
|
| 144 |
+
return fused
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class DPTHead(_BaseDPTHead):
|
| 148 |
+
"""
|
| 149 |
+
# DPT Head for dense prediction tasks.
|
| 150 |
+
|
| 151 |
+
# This module implements the DPT (Dense Prediction Transformer) head as proposed in
|
| 152 |
+
# "Vision Transformers for Dense Prediction" (https://arxiv.org/abs/2103.13413).
|
| 153 |
+
# It takes features from a vision transformer backbone and generates dense (per-pixel) predictions
|
| 154 |
+
# by fusing multi-scale features through a series of projection, upsampling, and refinement blocks.
|
| 155 |
+
|
| 156 |
+
# Args:
|
| 157 |
+
# dim_in (int): Number of input feature channels.
|
| 158 |
+
# patch_size (int, optional): Patch size used by the backbone, default is 14.
|
| 159 |
+
# output_dim (int, optional): Number of output channels, default is 4.
|
| 160 |
+
# activation (str, optional): Activation function type for the output head, default is "inv_log".
|
| 161 |
+
# conf_activation (str, optional): Activation function type for the confidence/output uncertainty head, default is "expp1".
|
| 162 |
+
# features (int, optional): Number of channels used in intermediate feature representations, default is 256.
|
| 163 |
+
# out_channels (List[int], optional): Number of channels for each intermediate multi-scale feature.
|
| 164 |
+
# intermediate_layer_idx (List[int], optional): Indices specifying which backbone layers to use for multi-scale fusion.
|
| 165 |
+
# pos_embed (bool, optional): Whether to add positional encoding to the features, default is True.
|
| 166 |
+
# feature_only (bool, optional): If True, only return intermediate features (skip final prediction and activations).
|
| 167 |
+
# down_ratio (int, optional): Downsampling ratio of the output predictions, default is 1 (no downsampling).
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
dim_in: int,
|
| 173 |
+
patch_size: int = 14,
|
| 174 |
+
output_dim: int = 4,
|
| 175 |
+
activation: str = "inv_log+expp1",
|
| 176 |
+
features: int = 256,
|
| 177 |
+
out_channels: List[int] = [256, 512, 1024, 1024],
|
| 178 |
+
pos_embed: bool = True,
|
| 179 |
+
down_ratio: int = 1,
|
| 180 |
+
is_gsdpt: bool = False,
|
| 181 |
+
enable_depth_mask: bool = False,
|
| 182 |
+
gradient_checkpoint: bool = False,
|
| 183 |
+
) -> None:
|
| 184 |
+
super().__init__(
|
| 185 |
+
dim_in=dim_in, patch_size=patch_size, features=features,
|
| 186 |
+
out_channels=out_channels, pos_embed=pos_embed,
|
| 187 |
+
down_ratio=down_ratio, gradient_checkpoint=gradient_checkpoint,
|
| 188 |
+
)
|
| 189 |
+
self.activation = activation
|
| 190 |
+
self.is_gsdpt = is_gsdpt
|
| 191 |
+
self.enable_depth_mask = enable_depth_mask
|
| 192 |
+
|
| 193 |
+
head_features_2 = 32
|
| 194 |
+
conv2_in_channels = features // 2
|
| 195 |
+
|
| 196 |
+
self.scratch.output_conv2 = nn.Sequential(
|
| 197 |
+
nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
|
| 198 |
+
nn.ReLU(inplace=True),
|
| 199 |
+
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
|
| 200 |
+
)
|
| 201 |
+
if self.is_gsdpt:
|
| 202 |
+
self.input_merger = nn.Sequential(
|
| 203 |
+
nn.Conv2d(3, conv2_in_channels, 7, 1, 3),
|
| 204 |
+
nn.ReLU()
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
def to(self, *args, **kwargs):
|
| 208 |
+
self.norm = self.norm.to(*args, **kwargs)
|
| 209 |
+
self.projects = self.projects.to(*args, **kwargs)
|
| 210 |
+
self.resize_layers = self.resize_layers.to(*args, **kwargs)
|
| 211 |
+
if self.is_gsdpt:
|
| 212 |
+
self.input_merger = self.input_merger.to(*args, **kwargs)
|
| 213 |
+
for key in ('layer1_rn', 'layer2_rn', 'layer3_rn', 'layer4_rn',
|
| 214 |
+
'refinenet1', 'refinenet2', 'refinenet3', 'refinenet4',
|
| 215 |
+
'output_conv1'):
|
| 216 |
+
if not hasattr(self.scratch, key):
|
| 217 |
+
continue
|
| 218 |
+
setattr(self.scratch, key, getattr(self.scratch, key).to(*args, **kwargs))
|
| 219 |
+
|
| 220 |
+
# keep output_conv2 in FP32
|
| 221 |
+
args, kwargs = MlpFP32.map_to_args_to_float(args, kwargs)
|
| 222 |
+
self.scratch.output_conv2 = self.scratch.output_conv2.to(*args, **kwargs)
|
| 223 |
+
|
| 224 |
+
return self
|
| 225 |
+
|
| 226 |
+
def forward(
|
| 227 |
+
self,
|
| 228 |
+
token_list: List[torch.Tensor],
|
| 229 |
+
images: torch.Tensor,
|
| 230 |
+
patch_start_idx: int,
|
| 231 |
+
frames_chunk_size: int = 8,
|
| 232 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 233 |
+
"""
|
| 234 |
+
Forward pass with optional frame chunking for memory efficiency.
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
token_list: List of token tensors from transformer, each [B, N, C]
|
| 238 |
+
images: Input images [B, S, 3, H, W], range [0, 1]
|
| 239 |
+
patch_start_idx: Starting index of patch tokens
|
| 240 |
+
frames_chunk_size: Number of frames per chunk. If None or >= S, process all at once
|
| 241 |
+
gradient_checkpoint: Whether to use gradient checkpointing
|
| 242 |
+
|
| 243 |
+
Returns:
|
| 244 |
+
For is_gsdpt: predictions [B, S, ...]
|
| 245 |
+
Otherwise: (predictions, confidence), [B, S, X, H, W] and [B, S, 1, H, W]
|
| 246 |
+
"""
|
| 247 |
+
B, S, _, H, W = images.shape
|
| 248 |
+
|
| 249 |
+
# Process all frames together if chunk size not specified or large enough
|
| 250 |
+
if frames_chunk_size is None or frames_chunk_size >= S:
|
| 251 |
+
return self._forward_impl(token_list, images, patch_start_idx)
|
| 252 |
+
|
| 253 |
+
assert frames_chunk_size > 0
|
| 254 |
+
|
| 255 |
+
# Process frames in chunks
|
| 256 |
+
preds_chunks = []
|
| 257 |
+
conf_chunks = []
|
| 258 |
+
gs_chunks = []
|
| 259 |
+
depth_mask_chunks = []
|
| 260 |
+
|
| 261 |
+
for frame_start in range(0, S, frames_chunk_size):
|
| 262 |
+
frame_end = min(frame_start + frames_chunk_size, S)
|
| 263 |
+
|
| 264 |
+
if self.is_gsdpt:
|
| 265 |
+
if self.enable_depth_mask:
|
| 266 |
+
gs, preds, conf, depth_mask = self._forward_impl(
|
| 267 |
+
token_list, images, patch_start_idx, frame_start, frame_end
|
| 268 |
+
)
|
| 269 |
+
gs_chunks.append(gs)
|
| 270 |
+
preds_chunks.append(preds)
|
| 271 |
+
conf_chunks.append(conf)
|
| 272 |
+
depth_mask_chunks.append(depth_mask)
|
| 273 |
+
else:
|
| 274 |
+
gs, preds, conf = self._forward_impl(
|
| 275 |
+
token_list, images, patch_start_idx, frame_start, frame_end
|
| 276 |
+
)
|
| 277 |
+
gs_chunks.append(gs)
|
| 278 |
+
preds_chunks.append(preds)
|
| 279 |
+
conf_chunks.append(conf)
|
| 280 |
+
else:
|
| 281 |
+
if self.enable_depth_mask:
|
| 282 |
+
preds, conf, depth_mask = self._forward_impl(
|
| 283 |
+
token_list, images, patch_start_idx, frame_start, frame_end
|
| 284 |
+
)
|
| 285 |
+
preds_chunks.append(preds)
|
| 286 |
+
conf_chunks.append(conf)
|
| 287 |
+
depth_mask_chunks.append(depth_mask)
|
| 288 |
+
else:
|
| 289 |
+
preds, conf = self._forward_impl(
|
| 290 |
+
token_list, images, patch_start_idx, frame_start, frame_end
|
| 291 |
+
)
|
| 292 |
+
preds_chunks.append(preds)
|
| 293 |
+
conf_chunks.append(conf)
|
| 294 |
+
|
| 295 |
+
# Concatenate chunks along frame dimension
|
| 296 |
+
if self.is_gsdpt:
|
| 297 |
+
if self.enable_depth_mask:
|
| 298 |
+
return (
|
| 299 |
+
torch.cat(gs_chunks, dim=1),
|
| 300 |
+
torch.cat(preds_chunks, dim=1),
|
| 301 |
+
torch.cat(conf_chunks, dim=1),
|
| 302 |
+
torch.cat(depth_mask_chunks, dim=1),
|
| 303 |
+
)
|
| 304 |
+
return torch.cat(gs_chunks, dim=1), torch.cat(preds_chunks, dim=1), torch.cat(conf_chunks, dim=1)
|
| 305 |
+
else:
|
| 306 |
+
if self.enable_depth_mask:
|
| 307 |
+
return torch.cat(preds_chunks, dim=1), torch.cat(conf_chunks, dim=1), torch.cat(depth_mask_chunks, dim=1)
|
| 308 |
+
else:
|
| 309 |
+
return torch.cat(preds_chunks, dim=1), torch.cat(conf_chunks, dim=1)
|
| 310 |
+
|
| 311 |
+
def _forward_impl(
|
| 312 |
+
self,
|
| 313 |
+
token_list: List[torch.Tensor],
|
| 314 |
+
images: torch.Tensor,
|
| 315 |
+
patch_start_idx: int,
|
| 316 |
+
frame_start: int = None,
|
| 317 |
+
frame_end: int = None,
|
| 318 |
+
) -> torch.Tensor:
|
| 319 |
+
"""
|
| 320 |
+
Core forward implementation for DPT head.
|
| 321 |
+
|
| 322 |
+
Args:
|
| 323 |
+
token_list: List of transformer tokens from each layer, [B, S, N, C]
|
| 324 |
+
images: Input images [B, S, 3, H, W]
|
| 325 |
+
patch_start_idx: Starting index of patch tokens
|
| 326 |
+
frame_start: Start index for frame chunking (optional)
|
| 327 |
+
frame_end: End index for frame chunking (optional)
|
| 328 |
+
|
| 329 |
+
Returns:
|
| 330 |
+
If is_gsdpt: (features, preds, conf)
|
| 331 |
+
Else: (preds, conf)
|
| 332 |
+
"""
|
| 333 |
+
if frame_start is not None and frame_end is not None:
|
| 334 |
+
images = images[:, frame_start:frame_end].contiguous()
|
| 335 |
+
|
| 336 |
+
B, S, _, H, W = images.shape
|
| 337 |
+
|
| 338 |
+
fused = self._extract_fused_features(token_list, B, S, H, W, patch_start_idx, frame_start, frame_end)
|
| 339 |
+
|
| 340 |
+
# Generate predictions and confidence
|
| 341 |
+
if self.is_gsdpt:
|
| 342 |
+
out = self.scratch.output_conv2(fused.float().contiguous())
|
| 343 |
+
if self.enable_depth_mask:
|
| 344 |
+
preds, conf, depth_mask = self.activate_head(out, activation=self.activation)
|
| 345 |
+
else:
|
| 346 |
+
preds, conf = self.activate_head(out, activation=self.activation)
|
| 347 |
+
preds = preds.reshape(B, S, *preds.shape[1:])
|
| 348 |
+
conf = conf.reshape(B, S, *conf.shape[1:])
|
| 349 |
+
|
| 350 |
+
# Merge direct image features
|
| 351 |
+
img_flat = images.reshape(B * S, -1, H, W)
|
| 352 |
+
img_feat = self.input_merger(img_flat)
|
| 353 |
+
fused = fused + img_feat
|
| 354 |
+
fused = fused.reshape(B, S, *fused.shape[1:]).float().contiguous()
|
| 355 |
+
if self.enable_depth_mask:
|
| 356 |
+
depth_mask = depth_mask.reshape(B, S, *depth_mask.shape[1:])
|
| 357 |
+
return fused, preds, conf, depth_mask
|
| 358 |
+
return fused, preds, conf
|
| 359 |
+
else:
|
| 360 |
+
out = self.scratch.output_conv2(fused.float().contiguous())
|
| 361 |
+
if self.enable_depth_mask:
|
| 362 |
+
preds, conf, depth_mask = self.activate_head(out, activation=self.activation)
|
| 363 |
+
preds = preds.reshape(B, S, *preds.shape[1:])
|
| 364 |
+
conf = conf.reshape(B, S, *conf.shape[1:])
|
| 365 |
+
depth_mask = depth_mask.reshape(B, S, *depth_mask.shape[1:])
|
| 366 |
+
return preds, conf, depth_mask
|
| 367 |
+
else:
|
| 368 |
+
preds, conf = self.activate_head(out, activation=self.activation)
|
| 369 |
+
preds = preds.reshape(B, S, *preds.shape[1:])
|
| 370 |
+
conf = conf.reshape(B, S, *conf.shape[1:])
|
| 371 |
+
return preds, conf
|
| 372 |
+
|
| 373 |
+
def activate_head(self, out_head: torch.Tensor, activation: str = "inv_log+expp1") -> Tuple[torch.Tensor, torch.Tensor]:
|
| 374 |
+
"""
|
| 375 |
+
Process network output to extract attribute (e.g. points, depth, etc.) and confidence values.
|
| 376 |
+
|
| 377 |
+
Args:
|
| 378 |
+
out_head: Network output tensor (B, C, H, W)
|
| 379 |
+
activation: Activation type for processing (e.g., "inv_log+expp1")
|
| 380 |
+
|
| 381 |
+
Returns:
|
| 382 |
+
Tuple of (attribute tensor, confidence tensor)
|
| 383 |
+
"""
|
| 384 |
+
# Parse activation string
|
| 385 |
+
if self.enable_depth_mask:
|
| 386 |
+
act_attr, act_conf, act_depth_mask = (activation.split("+") if "+" in activation else (activation, "expp1", "linear"))
|
| 387 |
+
|
| 388 |
+
# (B,C,H,W) -> (B,H,W,C)
|
| 389 |
+
feat = out_head.permute(0, 2, 3, 1)
|
| 390 |
+
attr, conf, depth_mask = feat[..., :-2], feat[..., -2], feat[..., -1]
|
| 391 |
+
else:
|
| 392 |
+
act_attr, act_conf = (activation.split("+") if "+" in activation else (activation, "expp1"))
|
| 393 |
+
|
| 394 |
+
# (B,C,H,W) -> (B,H,W,C)
|
| 395 |
+
feat = out_head.permute(0, 2, 3, 1)
|
| 396 |
+
attr, conf = feat[..., :-1], feat[..., -1]
|
| 397 |
+
|
| 398 |
+
# Map point activations to lambdas for clarity and conciseness
|
| 399 |
+
attr_activations = {
|
| 400 |
+
"norm_exp": lambda x: (x / x.norm(dim=-1, keepdim=True).clamp(min=1e-8)) * torch.expm1(x.norm(dim=-1, keepdim=True)),
|
| 401 |
+
"norm": lambda x: x / x.norm(dim=-1, keepdim=True),
|
| 402 |
+
"exp": torch.exp,
|
| 403 |
+
"relu": F.relu,
|
| 404 |
+
"inv_log": self._apply_inverse_log_transform,
|
| 405 |
+
"xy_inv_log": lambda x: torch.cat([
|
| 406 |
+
x[..., :2] * self._apply_inverse_log_transform(x[..., 2:]),
|
| 407 |
+
self._apply_inverse_log_transform(x[..., 2:])
|
| 408 |
+
], dim=-1),
|
| 409 |
+
"sigmoid": torch.sigmoid,
|
| 410 |
+
"linear": lambda x: x
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
if act_attr not in attr_activations:
|
| 414 |
+
raise ValueError(f"Unknown attribute activation: {act_attr}")
|
| 415 |
+
attr_out = attr_activations[act_attr](attr)
|
| 416 |
+
|
| 417 |
+
# Confidence activation mapping
|
| 418 |
+
conf_activations = {
|
| 419 |
+
"expp1": lambda c: 1 + c.exp(),
|
| 420 |
+
"expp0": torch.exp,
|
| 421 |
+
"sigmoid": torch.sigmoid
|
| 422 |
+
}
|
| 423 |
+
if act_conf not in conf_activations:
|
| 424 |
+
raise ValueError(f"Unknown confidence activation: {act_conf}")
|
| 425 |
+
conf_out = conf_activations[act_conf](conf)
|
| 426 |
+
|
| 427 |
+
if self.enable_depth_mask:
|
| 428 |
+
depth_mask_activations = {
|
| 429 |
+
"sigmoid": torch.sigmoid,
|
| 430 |
+
"linear": lambda x: x,
|
| 431 |
+
}
|
| 432 |
+
if act_depth_mask not in depth_mask_activations:
|
| 433 |
+
raise ValueError(f"Unknown depth mask activation: {act_depth_mask}")
|
| 434 |
+
depth_mask_out = depth_mask_activations[act_depth_mask](depth_mask)
|
| 435 |
+
return attr_out, conf_out, depth_mask_out
|
| 436 |
+
else:
|
| 437 |
+
return attr_out, conf_out
|
| 438 |
+
|
| 439 |
+
def _apply_inverse_log_transform(self, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 440 |
+
"""
|
| 441 |
+
Apply inverse logarithm transform: sign(y) * (exp(|y|) - 1)
|
| 442 |
+
|
| 443 |
+
Args:
|
| 444 |
+
input_tensor: Input tensor
|
| 445 |
+
|
| 446 |
+
Returns:
|
| 447 |
+
Transformed tensor
|
| 448 |
+
"""
|
| 449 |
+
return torch.sign(input_tensor) * (torch.expm1(torch.abs(input_tensor)))
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
################################################################################
|
| 454 |
+
# DPT Modules
|
| 455 |
+
################################################################################
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
|
| 459 |
+
return FeatureFusionBlock(
|
| 460 |
+
features,
|
| 461 |
+
nn.ReLU(inplace=True),
|
| 462 |
+
deconv=False,
|
| 463 |
+
bn=False,
|
| 464 |
+
expand=False,
|
| 465 |
+
align_corners=True,
|
| 466 |
+
size=size,
|
| 467 |
+
has_residual=has_residual,
|
| 468 |
+
groups=groups,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
|
| 473 |
+
scratch = nn.Module()
|
| 474 |
+
out_shape1 = out_shape
|
| 475 |
+
out_shape2 = out_shape
|
| 476 |
+
out_shape3 = out_shape
|
| 477 |
+
if len(in_shape) >= 4:
|
| 478 |
+
out_shape4 = out_shape
|
| 479 |
+
|
| 480 |
+
if expand:
|
| 481 |
+
out_shape1 = out_shape
|
| 482 |
+
out_shape2 = out_shape * 2
|
| 483 |
+
out_shape3 = out_shape * 4
|
| 484 |
+
if len(in_shape) >= 4:
|
| 485 |
+
out_shape4 = out_shape * 8
|
| 486 |
+
|
| 487 |
+
scratch.layer1_rn = nn.Conv2d(
|
| 488 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 489 |
+
)
|
| 490 |
+
scratch.layer2_rn = nn.Conv2d(
|
| 491 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 492 |
+
)
|
| 493 |
+
scratch.layer3_rn = nn.Conv2d(
|
| 494 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 495 |
+
)
|
| 496 |
+
if len(in_shape) >= 4:
|
| 497 |
+
scratch.layer4_rn = nn.Conv2d(
|
| 498 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 499 |
+
)
|
| 500 |
+
return scratch
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class ResidualConvUnit(nn.Module):
|
| 504 |
+
"""Residual convolution module with skip connection."""
|
| 505 |
+
|
| 506 |
+
def __init__(self, features, activation, bn, groups=1):
|
| 507 |
+
"""Initialize ResidualConvUnit.
|
| 508 |
+
|
| 509 |
+
Args:
|
| 510 |
+
features (int): Number of input/output feature channels
|
| 511 |
+
activation: Activation function to use
|
| 512 |
+
bn (bool): Whether to use batch normalization (currently unused)
|
| 513 |
+
groups (int): Number of groups for grouped convolution
|
| 514 |
+
"""
|
| 515 |
+
super().__init__()
|
| 516 |
+
|
| 517 |
+
self.bn = bn
|
| 518 |
+
self.groups = groups
|
| 519 |
+
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
| 520 |
+
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
| 521 |
+
|
| 522 |
+
self.norm1 = None
|
| 523 |
+
self.norm2 = None
|
| 524 |
+
|
| 525 |
+
self.activation = activation
|
| 526 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 527 |
+
|
| 528 |
+
def forward(self, x):
|
| 529 |
+
"""Forward pass with residual connection.
|
| 530 |
+
|
| 531 |
+
Args:
|
| 532 |
+
x (tensor): Input tensor of shape (B, C, H, W)
|
| 533 |
+
|
| 534 |
+
Returns:
|
| 535 |
+
tensor: Output tensor of shape (B, C, H, W) with residual added
|
| 536 |
+
"""
|
| 537 |
+
|
| 538 |
+
out = self.activation(x)
|
| 539 |
+
out = self.conv1(out)
|
| 540 |
+
if self.norm1 is not None:
|
| 541 |
+
out = self.norm1(out)
|
| 542 |
+
|
| 543 |
+
out = self.activation(out)
|
| 544 |
+
out = self.conv2(out)
|
| 545 |
+
if self.norm2 is not None:
|
| 546 |
+
out = self.norm2(out)
|
| 547 |
+
|
| 548 |
+
return self.skip_add.add(out, x)
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
class FeatureFusionBlock(nn.Module):
|
| 552 |
+
"""Feature fusion block."""
|
| 553 |
+
|
| 554 |
+
def __init__(
|
| 555 |
+
self,
|
| 556 |
+
features,
|
| 557 |
+
activation,
|
| 558 |
+
deconv=False,
|
| 559 |
+
bn=False,
|
| 560 |
+
expand=False,
|
| 561 |
+
align_corners=True,
|
| 562 |
+
size=None,
|
| 563 |
+
has_residual=True,
|
| 564 |
+
groups=1,
|
| 565 |
+
):
|
| 566 |
+
"""Initialize FeatureFusionBlock.
|
| 567 |
+
|
| 568 |
+
Args:
|
| 569 |
+
features (int): Number of input/output feature channels
|
| 570 |
+
activation: Activation function to use
|
| 571 |
+
deconv (bool): Whether to use deconvolution
|
| 572 |
+
bn (bool): Whether to use batch normalization
|
| 573 |
+
expand (bool): Whether to expand features (halve output channels)
|
| 574 |
+
align_corners (bool): Align corners for interpolation
|
| 575 |
+
size: Target size for upsampling
|
| 576 |
+
has_residual (bool): Whether to include residual connection
|
| 577 |
+
groups (int): Number of groups for grouped convolution
|
| 578 |
+
"""
|
| 579 |
+
super(FeatureFusionBlock, self).__init__()
|
| 580 |
+
|
| 581 |
+
self.deconv = deconv
|
| 582 |
+
self.align_corners = align_corners
|
| 583 |
+
self.groups = groups
|
| 584 |
+
self.expand = expand
|
| 585 |
+
out_features = features
|
| 586 |
+
if self.expand == True:
|
| 587 |
+
out_features = features // 2
|
| 588 |
+
|
| 589 |
+
self.out_conv = nn.Conv2d(
|
| 590 |
+
features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
|
| 591 |
+
)
|
| 592 |
+
|
| 593 |
+
if has_residual:
|
| 594 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
| 595 |
+
|
| 596 |
+
self.has_residual = has_residual
|
| 597 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
| 598 |
+
|
| 599 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 600 |
+
self.size = size
|
| 601 |
+
|
| 602 |
+
def forward(self, *xs, size=None):
|
| 603 |
+
"""Forward pass through the feature fusion block.
|
| 604 |
+
|
| 605 |
+
Args:
|
| 606 |
+
*xs: Variable number of input tensors. First tensor is the main input,
|
| 607 |
+
second tensor (if present) is used for residual connection.
|
| 608 |
+
size: Optional target size for upsampling. If None, uses self.size or scale_factor=2.
|
| 609 |
+
|
| 610 |
+
Returns:
|
| 611 |
+
torch.Tensor: Fused and upsampled output tensor.
|
| 612 |
+
"""
|
| 613 |
+
output = xs[0]
|
| 614 |
+
|
| 615 |
+
if self.has_residual:
|
| 616 |
+
res = self.resConfUnit1(xs[1])
|
| 617 |
+
output = self.skip_add.add(output, res)
|
| 618 |
+
|
| 619 |
+
output = self.resConfUnit2(output)
|
| 620 |
+
|
| 621 |
+
if (size is None) and (self.size is None):
|
| 622 |
+
modifier = {"scale_factor": 2}
|
| 623 |
+
elif size is None:
|
| 624 |
+
modifier = {"size": self.size}
|
| 625 |
+
else:
|
| 626 |
+
modifier = {"size": size}
|
| 627 |
+
|
| 628 |
+
output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
| 629 |
+
output = self.out_conv(output)
|
| 630 |
+
|
| 631 |
+
return output
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def custom_interpolate(
|
| 635 |
+
x: torch.Tensor,
|
| 636 |
+
size: Tuple[int, int] = None,
|
| 637 |
+
scale_factor: float = None,
|
| 638 |
+
mode: str = "bilinear",
|
| 639 |
+
align_corners: bool = True,
|
| 640 |
+
) -> torch.Tensor:
|
| 641 |
+
"""
|
| 642 |
+
Custom interpolation function to handle large tensors by chunking.
|
| 643 |
+
|
| 644 |
+
Avoids INT_MAX overflow issues in nn.functional.interpolate when dealing with
|
| 645 |
+
very large input tensors by splitting them into smaller chunks.
|
| 646 |
+
|
| 647 |
+
Args:
|
| 648 |
+
x: Input tensor to interpolate
|
| 649 |
+
size: Target output size (H, W)
|
| 650 |
+
scale_factor: Scaling factor if size is not provided
|
| 651 |
+
mode: Interpolation mode (default: "bilinear")
|
| 652 |
+
align_corners: Whether to align corners in interpolation
|
| 653 |
+
|
| 654 |
+
Returns:
|
| 655 |
+
Interpolated tensor
|
| 656 |
+
"""
|
| 657 |
+
if size is None:
|
| 658 |
+
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
| 659 |
+
|
| 660 |
+
INT_MAX = 1610612736
|
| 661 |
+
|
| 662 |
+
input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
|
| 663 |
+
|
| 664 |
+
if input_elements > INT_MAX:
|
| 665 |
+
chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
|
| 666 |
+
interpolated_chunks = [
|
| 667 |
+
nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
|
| 668 |
+
]
|
| 669 |
+
x = torch.cat(interpolated_chunks, dim=0)
|
| 670 |
+
return x.contiguous()
|
| 671 |
+
else:
|
| 672 |
+
return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
hyworldmirror/models/heads/gs_head.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from .dense_head import _BaseDPTHead
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class GSFeatHead(_BaseDPTHead):
|
| 10 |
+
"""
|
| 11 |
+
GS feature head that only outputs fused GS features.
|
| 12 |
+
|
| 13 |
+
This head is used when gs depth is disabled. It skips the prediction
|
| 14 |
+
conv (output_conv2) and returns only the fused GS feature map.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
dim_in: int,
|
| 20 |
+
patch_size: int = 14,
|
| 21 |
+
features: int = 256,
|
| 22 |
+
out_channels: List[int] = [256, 512, 1024, 1024],
|
| 23 |
+
pos_embed: bool = True,
|
| 24 |
+
down_ratio: int = 1,
|
| 25 |
+
gradient_checkpoint: bool = False,
|
| 26 |
+
) -> None:
|
| 27 |
+
super().__init__(
|
| 28 |
+
dim_in=dim_in, patch_size=patch_size, features=features,
|
| 29 |
+
out_channels=out_channels, pos_embed=pos_embed,
|
| 30 |
+
down_ratio=down_ratio, gradient_checkpoint=gradient_checkpoint,
|
| 31 |
+
_cast_pos_embed_dtype=False,
|
| 32 |
+
)
|
| 33 |
+
conv2_in_channels = features // 2
|
| 34 |
+
self.input_merger = nn.Sequential(
|
| 35 |
+
nn.Conv2d(3, conv2_in_channels, 7, 1, 3),
|
| 36 |
+
nn.ReLU(),
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
def forward(
|
| 40 |
+
self,
|
| 41 |
+
token_list: List[torch.Tensor],
|
| 42 |
+
images: torch.Tensor,
|
| 43 |
+
patch_start_idx: int,
|
| 44 |
+
frames_chunk_size: int = 8,
|
| 45 |
+
) -> torch.Tensor:
|
| 46 |
+
B, S, _, H, W = images.shape
|
| 47 |
+
|
| 48 |
+
if frames_chunk_size is None or frames_chunk_size >= S:
|
| 49 |
+
return self._forward_impl(token_list, images, patch_start_idx)
|
| 50 |
+
|
| 51 |
+
assert frames_chunk_size > 0
|
| 52 |
+
gs_chunks = []
|
| 53 |
+
for frame_start in range(0, S, frames_chunk_size):
|
| 54 |
+
frame_end = min(frame_start + frames_chunk_size, S)
|
| 55 |
+
gs = self._forward_impl(
|
| 56 |
+
token_list, images, patch_start_idx, frame_start, frame_end
|
| 57 |
+
)
|
| 58 |
+
gs_chunks.append(gs)
|
| 59 |
+
|
| 60 |
+
return torch.cat(gs_chunks, dim=1)
|
| 61 |
+
|
| 62 |
+
def _forward_impl(
|
| 63 |
+
self,
|
| 64 |
+
token_list: List[torch.Tensor],
|
| 65 |
+
images: torch.Tensor,
|
| 66 |
+
patch_start_idx: int,
|
| 67 |
+
frame_start: int = None,
|
| 68 |
+
frame_end: int = None,
|
| 69 |
+
) -> torch.Tensor:
|
| 70 |
+
if frame_start is not None and frame_end is not None:
|
| 71 |
+
images = images[:, frame_start:frame_end].contiguous()
|
| 72 |
+
|
| 73 |
+
B, S, _, H, W = images.shape
|
| 74 |
+
|
| 75 |
+
fused = self._extract_fused_features(
|
| 76 |
+
token_list, B, S, H, W, patch_start_idx, frame_start, frame_end
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
img_flat = images.reshape(B * S, -1, H, W)
|
| 80 |
+
img_feat = self.input_merger(img_flat)
|
| 81 |
+
fused = fused + img_feat
|
| 82 |
+
fused = fused.reshape(B, S, *fused.shape[1:])
|
| 83 |
+
return fused
|
hyworldmirror/models/layers/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .mlp import Mlp, MlpFP32
|
| 2 |
+
from .patch_embed import PatchEmbed, PatchEmbed_Mlp
|
| 3 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
| 4 |
+
from .block import NestedTensorBlock
|
| 5 |
+
from .attention import MemEffAttention
|
hyworldmirror/models/layers/attention.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# References:
|
| 2 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 4 |
+
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from torch import nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
|
| 12 |
+
_USE_FLASH_ATTN_V3 = True
|
| 13 |
+
except ImportError:
|
| 14 |
+
from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_func_v2
|
| 15 |
+
_USE_FLASH_ATTN_V3 = False
|
| 16 |
+
from ...comm.padding import minimal_pad_to_divisible, depad_by_length, pad_by_length
|
| 17 |
+
import torch.distributed as dist
|
| 18 |
+
from ...comm.communication import _All2All, _Allgather
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Attention(nn.Module):
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
dim: int,
|
| 25 |
+
num_heads: int = 8,
|
| 26 |
+
qkv_bias: bool = True,
|
| 27 |
+
proj_bias: bool = True,
|
| 28 |
+
attn_drop: float = 0.0,
|
| 29 |
+
proj_drop: float = 0.0,
|
| 30 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 31 |
+
qk_norm: bool = False,
|
| 32 |
+
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
|
| 33 |
+
rope=None,
|
| 34 |
+
) -> None:
|
| 35 |
+
super().__init__()
|
| 36 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
| 37 |
+
self.num_heads = num_heads
|
| 38 |
+
self.head_dim = dim // num_heads
|
| 39 |
+
self.scale = self.head_dim**-0.5
|
| 40 |
+
self.fused_attn = fused_attn
|
| 41 |
+
|
| 42 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 43 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 44 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 45 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 46 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 47 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 48 |
+
self.rope = rope
|
| 49 |
+
|
| 50 |
+
def _compute_qkv(self, x: Tensor):
|
| 51 |
+
B, N, C = x.shape
|
| 52 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 53 |
+
q, k, v = qkv.unbind(0)
|
| 54 |
+
q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
|
| 55 |
+
return q, k, v, B, N, C
|
| 56 |
+
|
| 57 |
+
def _apply_attention(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
| 58 |
+
if q.dtype==torch.bfloat16 or q.dtype==torch.float16:
|
| 59 |
+
if q.is_contiguous():
|
| 60 |
+
q = q.transpose(1,2)
|
| 61 |
+
else:
|
| 62 |
+
q = q.transpose(1, 2).contiguous()
|
| 63 |
+
if k.is_contiguous():
|
| 64 |
+
k = k.transpose(1, 2)
|
| 65 |
+
else:
|
| 66 |
+
k = k.transpose(1, 2).contiguous()
|
| 67 |
+
if v.is_contiguous():
|
| 68 |
+
v = v.transpose(1, 2)
|
| 69 |
+
else:
|
| 70 |
+
v = v.transpose(1, 2).contiguous()
|
| 71 |
+
if _USE_FLASH_ATTN_V3:
|
| 72 |
+
x = flash_attn_func_v3(q, k, v)
|
| 73 |
+
else:
|
| 74 |
+
x = flash_attn_func_v2(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
|
| 75 |
+
if x.is_contiguous():
|
| 76 |
+
x = x.transpose(1, 2)
|
| 77 |
+
else:
|
| 78 |
+
x = x.transpose(1, 2).contiguous()
|
| 79 |
+
else:
|
| 80 |
+
x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
|
| 81 |
+
return x
|
| 82 |
+
|
| 83 |
+
def _project_output(self, x: Tensor, B: int, N: int, C: int) -> Tensor:
|
| 84 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 85 |
+
x = self.proj(x)
|
| 86 |
+
x = self.proj_drop(x)
|
| 87 |
+
return x
|
| 88 |
+
|
| 89 |
+
def forward(self, x: Tensor, pos=None) -> Tensor:
|
| 90 |
+
q, k, v, B, N, C = self._compute_qkv(x)
|
| 91 |
+
|
| 92 |
+
if self.rope is not None:
|
| 93 |
+
q = self.rope(q, pos)
|
| 94 |
+
k = self.rope(k, pos)
|
| 95 |
+
|
| 96 |
+
x = self._apply_attention(q, k, v)
|
| 97 |
+
return self._project_output(x, B, N, C)
|
| 98 |
+
|
| 99 |
+
class DistAttention(Attention):
|
| 100 |
+
def forward(self, x: Tensor, pos=None, sp_size=1, sp_group=None, padding_tokens=0) -> Tensor:
|
| 101 |
+
|
| 102 |
+
q, k, v, B, N, C = self._compute_qkv(x)
|
| 103 |
+
|
| 104 |
+
if sp_size>1:
|
| 105 |
+
|
| 106 |
+
q = _All2All.apply(q,1,2,sp_group,False)
|
| 107 |
+
k = _All2All.apply(k,1,2,sp_group,False)
|
| 108 |
+
v = _All2All.apply(v,1,2,sp_group,False)
|
| 109 |
+
q = depad_by_length(q,padding_tokens,2)
|
| 110 |
+
k = depad_by_length(k,padding_tokens,2)
|
| 111 |
+
v = depad_by_length(v,padding_tokens,2)
|
| 112 |
+
|
| 113 |
+
if self.rope is not None:
|
| 114 |
+
q = self.rope(q, pos)
|
| 115 |
+
k = self.rope(k, pos)
|
| 116 |
+
|
| 117 |
+
x = self._apply_attention(q, k, v)
|
| 118 |
+
|
| 119 |
+
if sp_size>1:
|
| 120 |
+
x = pad_by_length(x,padding_tokens,2,0)
|
| 121 |
+
x = _All2All.apply(x,2,1,sp_group,False)
|
| 122 |
+
|
| 123 |
+
return self._project_output(x, B, N, C)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class MemEffAttention(Attention):
|
| 127 |
+
def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
|
| 128 |
+
assert pos is None
|
| 129 |
+
if attn_bias is not None:
|
| 130 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
| 131 |
+
return super().forward(x)
|
hyworldmirror/models/layers/block.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# References:
|
| 2 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 4 |
+
|
| 5 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn, Tensor
|
| 9 |
+
|
| 10 |
+
from .attention import Attention, DistAttention
|
| 11 |
+
from .drop_path import DropPath
|
| 12 |
+
from .layer_scale import LayerScale
|
| 13 |
+
from .mlp import Mlp
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
XFORMERS_AVAILABLE = False
|
| 17 |
+
|
| 18 |
+
def modulate(x, shift, scale):
|
| 19 |
+
return x * (1 + scale.unsqueeze(2)) + shift.unsqueeze(2)
|
| 20 |
+
|
| 21 |
+
class Block(nn.Module):
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
dim: int,
|
| 25 |
+
num_heads: int,
|
| 26 |
+
mlp_ratio: float = 4.0,
|
| 27 |
+
qkv_bias: bool = True,
|
| 28 |
+
proj_bias: bool = True,
|
| 29 |
+
ffn_bias: bool = True,
|
| 30 |
+
drop: float = 0.0,
|
| 31 |
+
attn_drop: float = 0.0,
|
| 32 |
+
init_values=None,
|
| 33 |
+
drop_path: float = 0.0,
|
| 34 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 35 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 36 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 37 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 38 |
+
qk_norm: bool = False,
|
| 39 |
+
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
|
| 40 |
+
rope=None
|
| 41 |
+
) -> None:
|
| 42 |
+
super().__init__()
|
| 43 |
+
|
| 44 |
+
self.norm1 = norm_layer(dim)
|
| 45 |
+
|
| 46 |
+
self.attn = attn_class(
|
| 47 |
+
dim,
|
| 48 |
+
num_heads=num_heads,
|
| 49 |
+
qkv_bias=qkv_bias,
|
| 50 |
+
proj_bias=proj_bias,
|
| 51 |
+
attn_drop=attn_drop,
|
| 52 |
+
proj_drop=drop,
|
| 53 |
+
qk_norm=qk_norm,
|
| 54 |
+
fused_attn=fused_attn,
|
| 55 |
+
rope=rope,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 59 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 60 |
+
|
| 61 |
+
self.norm2 = norm_layer(dim)
|
| 62 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 63 |
+
self.mlp = ffn_layer(
|
| 64 |
+
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
|
| 65 |
+
)
|
| 66 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 67 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 68 |
+
|
| 69 |
+
self.sample_drop_ratio = drop_path
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def forward(self, x: Tensor, pos=None) -> Tensor:
|
| 73 |
+
def attn_residual_func(x: Tensor, pos=None) -> Tensor:
|
| 74 |
+
return self.ls1(self.attn(self.norm1(x), pos=pos))
|
| 75 |
+
|
| 76 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 77 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 78 |
+
|
| 79 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
| 80 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
| 81 |
+
x = drop_add_residual_stochastic_depth(
|
| 82 |
+
x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
|
| 83 |
+
)
|
| 84 |
+
x = drop_add_residual_stochastic_depth(
|
| 85 |
+
x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
|
| 86 |
+
)
|
| 87 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
| 88 |
+
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
|
| 89 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
| 90 |
+
else:
|
| 91 |
+
x = x + attn_residual_func(x, pos=pos)
|
| 92 |
+
x = x + ffn_residual_func(x)
|
| 93 |
+
return x
|
| 94 |
+
|
| 95 |
+
class DistBlock(Block):
|
| 96 |
+
def __init__(self, *args, attn_class: Callable[..., nn.Module] = DistAttention, **kwargs):
|
| 97 |
+
super().__init__(*args, attn_class=attn_class, **kwargs)
|
| 98 |
+
|
| 99 |
+
def forward(self, x: Tensor, pos=None, sp_size=1,sp_group=None,padding_tokens=0,block_type = None, token_shape=None) -> Tensor:
|
| 100 |
+
def attn_residual_func(x: Tensor, pos=None, sp_size=1,sp_group=None,padding_tokens=0) -> Tensor:
|
| 101 |
+
return self.ls1(self.attn(self.norm1(x), pos=pos,sp_size=sp_size,sp_group=sp_group,padding_tokens=padding_tokens))
|
| 102 |
+
|
| 103 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 104 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 105 |
+
|
| 106 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
| 107 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
| 108 |
+
x = drop_add_residual_stochastic_depth(
|
| 109 |
+
x, pos=pos, sp_size=sp_size,sp_group=sp_group,padding_tokens=padding_tokens,residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
|
| 110 |
+
)
|
| 111 |
+
x = drop_add_residual_stochastic_depth(
|
| 112 |
+
x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
|
| 113 |
+
)
|
| 114 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
| 115 |
+
x = x + self.drop_path1(attn_residual_func(x, pos=pos,sp_size=sp_size,sp_group=sp_group,padding_tokens=padding_tokens))
|
| 116 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
| 117 |
+
else:
|
| 118 |
+
x = x + attn_residual_func(x, pos=pos,sp_size=sp_size,sp_group=sp_group,padding_tokens=padding_tokens)
|
| 119 |
+
x = x + ffn_residual_func(x)
|
| 120 |
+
return x
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def drop_add_residual_stochastic_depth(
|
| 124 |
+
x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
|
| 125 |
+
) -> Tensor:
|
| 126 |
+
# 1) extract subset using permutation
|
| 127 |
+
b, n, d = x.shape
|
| 128 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 129 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 130 |
+
x_subset = x[brange]
|
| 131 |
+
|
| 132 |
+
# 2) apply residual_func to get residual
|
| 133 |
+
if pos is not None:
|
| 134 |
+
# if necessary, apply rope to the subset
|
| 135 |
+
pos = pos[brange]
|
| 136 |
+
residual = residual_func(x_subset, pos=pos)
|
| 137 |
+
else:
|
| 138 |
+
residual = residual_func(x_subset)
|
| 139 |
+
|
| 140 |
+
x_flat = x.flatten(1)
|
| 141 |
+
residual = residual.flatten(1)
|
| 142 |
+
|
| 143 |
+
residual_scale_factor = b / sample_subset_size
|
| 144 |
+
|
| 145 |
+
# 3) add the residual
|
| 146 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 147 |
+
return x_plus_residual.view_as(x)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
| 151 |
+
b, n, d = x.shape
|
| 152 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 153 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 154 |
+
residual_scale_factor = b / sample_subset_size
|
| 155 |
+
return brange, residual_scale_factor
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
| 159 |
+
if scaling_vector is None:
|
| 160 |
+
x_flat = x.flatten(1)
|
| 161 |
+
residual = residual.flatten(1)
|
| 162 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 163 |
+
else:
|
| 164 |
+
x_plus_residual = scaled_index_add(
|
| 165 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
| 166 |
+
)
|
| 167 |
+
return x_plus_residual
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
| 174 |
+
"""
|
| 175 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
| 176 |
+
"""
|
| 177 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
| 178 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
| 179 |
+
if all_shapes not in attn_bias_cache.keys():
|
| 180 |
+
seqlens = []
|
| 181 |
+
for b, x in zip(batch_sizes, x_list):
|
| 182 |
+
for _ in range(b):
|
| 183 |
+
seqlens.append(x.shape[1])
|
| 184 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
| 185 |
+
attn_bias._batch_sizes = batch_sizes
|
| 186 |
+
attn_bias_cache[all_shapes] = attn_bias
|
| 187 |
+
|
| 188 |
+
if branges is not None:
|
| 189 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
| 190 |
+
else:
|
| 191 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
| 192 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
| 193 |
+
|
| 194 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def drop_add_residual_stochastic_depth_list(
|
| 198 |
+
x_list: List[Tensor],
|
| 199 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
| 200 |
+
sample_drop_ratio: float = 0.0,
|
| 201 |
+
scaling_vector=None,
|
| 202 |
+
) -> Tensor:
|
| 203 |
+
# 1) generate random set of indices for dropping samples in the batch
|
| 204 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
| 205 |
+
branges = [s[0] for s in branges_scales]
|
| 206 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
| 207 |
+
|
| 208 |
+
# 2) get attention bias and index+concat the tensors
|
| 209 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
| 210 |
+
|
| 211 |
+
# 3) apply residual_func to get residual, and split the result
|
| 212 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
| 213 |
+
|
| 214 |
+
outputs = []
|
| 215 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
| 216 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
| 217 |
+
return outputs
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class NestedTensorBlock(Block):
|
| 221 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
| 222 |
+
"""
|
| 223 |
+
x_list contains a list of tensors to nest together and run
|
| 224 |
+
"""
|
| 225 |
+
assert isinstance(self.attn, MemEffAttention)
|
| 226 |
+
|
| 227 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
| 228 |
+
|
| 229 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 230 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
| 231 |
+
|
| 232 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 233 |
+
return self.mlp(self.norm2(x))
|
| 234 |
+
|
| 235 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 236 |
+
x_list,
|
| 237 |
+
residual_func=attn_residual_func,
|
| 238 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 239 |
+
scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None),
|
| 240 |
+
)
|
| 241 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 242 |
+
x_list,
|
| 243 |
+
residual_func=ffn_residual_func,
|
| 244 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 245 |
+
scaling_vector=(self.ls2.gamma if isinstance(self.ls1, LayerScale) else None),
|
| 246 |
+
)
|
| 247 |
+
return x_list
|
| 248 |
+
else:
|
| 249 |
+
|
| 250 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 251 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
| 252 |
+
|
| 253 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 254 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 255 |
+
|
| 256 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
| 257 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
| 258 |
+
x = x + ffn_residual_func(x)
|
| 259 |
+
return attn_bias.split(x)
|
| 260 |
+
|
| 261 |
+
def forward(self, x_or_x_list):
|
| 262 |
+
if isinstance(x_or_x_list, Tensor):
|
| 263 |
+
return super().forward(x_or_x_list)
|
| 264 |
+
elif isinstance(x_or_x_list, list):
|
| 265 |
+
if not XFORMERS_AVAILABLE:
|
| 266 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
| 267 |
+
return self.forward_nested(x_or_x_list)
|
| 268 |
+
else:
|
| 269 |
+
raise AssertionError
|
hyworldmirror/models/layers/drop_path.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# References:
|
| 2 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
| 10 |
+
if drop_prob == 0.0 or not training:
|
| 11 |
+
return x
|
| 12 |
+
keep_prob = 1 - drop_prob
|
| 13 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 14 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 15 |
+
if keep_prob > 0.0:
|
| 16 |
+
random_tensor.div_(keep_prob)
|
| 17 |
+
output = x * random_tensor
|
| 18 |
+
return output
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DropPath(nn.Module):
|
| 22 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 23 |
+
|
| 24 |
+
def __init__(self, drop_prob=None):
|
| 25 |
+
super(DropPath, self).__init__()
|
| 26 |
+
self.drop_prob = drop_prob
|
| 27 |
+
|
| 28 |
+
def forward(self, x):
|
| 29 |
+
return drop_path(x, self.drop_prob, self.training)
|
hyworldmirror/models/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 2 |
+
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class LayerScale(nn.Module):
|
| 11 |
+
def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.inplace = inplace
|
| 14 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 15 |
+
|
| 16 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 17 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
hyworldmirror/models/layers/mlp.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# References:
|
| 2 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from typing import Callable, Optional
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor, nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Mlp(nn.Module):
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
in_features: int,
|
| 15 |
+
hidden_features: Optional[int] = None,
|
| 16 |
+
out_features: Optional[int] = None,
|
| 17 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 18 |
+
drop: float = 0.0,
|
| 19 |
+
bias: bool = True,
|
| 20 |
+
) -> None:
|
| 21 |
+
super().__init__()
|
| 22 |
+
out_features = out_features or in_features
|
| 23 |
+
hidden_features = hidden_features or in_features
|
| 24 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
| 25 |
+
self.act = act_layer()
|
| 26 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 27 |
+
self.drop = nn.Dropout(drop)
|
| 28 |
+
|
| 29 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 30 |
+
x = self.fc1(x)
|
| 31 |
+
x = self.act(x)
|
| 32 |
+
x = self.drop(x)
|
| 33 |
+
x = self.fc2(x)
|
| 34 |
+
x = self.drop(x)
|
| 35 |
+
return x
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class MlpFP32(Mlp):
|
| 39 |
+
@staticmethod
|
| 40 |
+
def map_to_args_to_float(args, kwargs):
|
| 41 |
+
args = tuple(
|
| 42 |
+
torch.float32 if isinstance(arg, torch.dtype) else arg
|
| 43 |
+
for arg in args
|
| 44 |
+
)
|
| 45 |
+
kwargs = dict(kwargs)
|
| 46 |
+
for key in kwargs:
|
| 47 |
+
if key == "dtype":
|
| 48 |
+
kwargs[key] = torch.float32
|
| 49 |
+
return args, kwargs
|
| 50 |
+
|
| 51 |
+
def to(self, *args, **kwargs):
|
| 52 |
+
self.fc1 = self.fc1.to(*args, **kwargs)
|
| 53 |
+
args, kwargs = self.map_to_args_to_float(args, kwargs)
|
| 54 |
+
self.fc2 = self.fc2.to(*args, **kwargs)
|
| 55 |
+
return self
|
| 56 |
+
|
| 57 |
+
def forward_infer(self, x):
|
| 58 |
+
x = self.fc1(x)
|
| 59 |
+
x = 0.5 * x * (1 + torch.erf(x * 2**-0.5))
|
| 60 |
+
x = self.fc2(x.float())
|
| 61 |
+
return x
|
| 62 |
+
|
| 63 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 64 |
+
return self.forward_infer(x)
|
hyworldmirror/models/layers/norm_rope.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Dict, Literal, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class PositionGetter:
|
| 11 |
+
"""Generates and caches 2D spatial positions for patches in a grid."""
|
| 12 |
+
|
| 13 |
+
def __init__(self) -> None:
|
| 14 |
+
self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
|
| 15 |
+
|
| 16 |
+
def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
|
| 17 |
+
if (height, width) not in self.position_cache:
|
| 18 |
+
y_coords = torch.arange(height, device=device)
|
| 19 |
+
x_coords = torch.arange(width, device=device)
|
| 20 |
+
self.position_cache[height, width] = torch.cartesian_prod(y_coords, x_coords)
|
| 21 |
+
|
| 22 |
+
cached_positions = self.position_cache[height, width]
|
| 23 |
+
return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 27 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 28 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class NormalizedRotaryPositionEmbedding2D(nn.Module):
|
| 32 |
+
"""DINOv3-aligned 2D Rotary Position Embedding."""
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
*,
|
| 37 |
+
head_dim: int,
|
| 38 |
+
base: float = 100.0,
|
| 39 |
+
normalize_coords: Literal["min", "max", "separate"] = "separate",
|
| 40 |
+
shift_coords: Union[float, None] = None,
|
| 41 |
+
jitter_coords: Union[float, None] = None,
|
| 42 |
+
rescale_coords: Union[float, None] = None,
|
| 43 |
+
dtype: Union[torch.dtype, None] = None,
|
| 44 |
+
device: Union[torch.device, None] = None,
|
| 45 |
+
**ignored_kwargs,
|
| 46 |
+
) -> None:
|
| 47 |
+
super().__init__()
|
| 48 |
+
if len(ignored_kwargs) > 0:
|
| 49 |
+
# maintain parity with DINOv3 implementation that warns on ignored kwargs
|
| 50 |
+
pass
|
| 51 |
+
|
| 52 |
+
if head_dim % 4 != 0:
|
| 53 |
+
raise ValueError("head_dim must be divisible by 4 for 2D RoPE")
|
| 54 |
+
|
| 55 |
+
self.head_dim = head_dim
|
| 56 |
+
self.base = base
|
| 57 |
+
self.normalize_coords = normalize_coords
|
| 58 |
+
self.shift_coords = shift_coords
|
| 59 |
+
self.jitter_coords = jitter_coords
|
| 60 |
+
self.rescale_coords = rescale_coords
|
| 61 |
+
self.dtype = dtype
|
| 62 |
+
|
| 63 |
+
quarter_dim = head_dim // 4
|
| 64 |
+
self.register_buffer(
|
| 65 |
+
"periods",
|
| 66 |
+
torch.empty(quarter_dim, device=device, dtype=dtype),
|
| 67 |
+
persistent=True,
|
| 68 |
+
)
|
| 69 |
+
self._init_periods()
|
| 70 |
+
|
| 71 |
+
def _init_periods(self) -> None:
|
| 72 |
+
quarter_dim = self.periods.shape[0]
|
| 73 |
+
half_dim = self.head_dim // 2
|
| 74 |
+
exponents = 2 * torch.arange(quarter_dim, device=self.periods.device, dtype=self.dtype) / half_dim
|
| 75 |
+
periods = self.base ** exponents
|
| 76 |
+
self.periods.data.copy_(periods)
|
| 77 |
+
|
| 78 |
+
def _get_sincos_for_grid(self, H: int, W: int, device: torch.device, dtype: torch.dtype) -> Tuple[Tensor, Tensor]:
|
| 79 |
+
dd = {"device": device, "dtype": dtype}
|
| 80 |
+
|
| 81 |
+
if self.normalize_coords == "max":
|
| 82 |
+
max_hw = max(H, W)
|
| 83 |
+
coords_h = torch.arange(0.5, H, **dd) / max_hw
|
| 84 |
+
coords_w = torch.arange(0.5, W, **dd) / max_hw
|
| 85 |
+
elif self.normalize_coords == "min":
|
| 86 |
+
min_hw = min(H, W)
|
| 87 |
+
coords_h = torch.arange(0.5, H, **dd) / min_hw
|
| 88 |
+
coords_w = torch.arange(0.5, W, **dd) / min_hw
|
| 89 |
+
elif self.normalize_coords == "separate":
|
| 90 |
+
coords_h = torch.arange(0.5, H, **dd) / H
|
| 91 |
+
coords_w = torch.arange(0.5, W, **dd) / W
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
|
| 94 |
+
|
| 95 |
+
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # [H, W, 2]
|
| 96 |
+
coords = coords.flatten(0, 1) # [HW, 2]
|
| 97 |
+
coords = 2.0 * coords - 1.0
|
| 98 |
+
|
| 99 |
+
if self.training:
|
| 100 |
+
if self.shift_coords is not None:
|
| 101 |
+
shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords)
|
| 102 |
+
coords += shift_hw[None, :]
|
| 103 |
+
if self.jitter_coords is not None:
|
| 104 |
+
jitter_max = np.log(self.jitter_coords)
|
| 105 |
+
jitter_hw = torch.empty(2, **dd).uniform_(-jitter_max, jitter_max).exp()
|
| 106 |
+
coords *= jitter_hw[None, :]
|
| 107 |
+
if self.rescale_coords is not None:
|
| 108 |
+
rescale_max = np.log(self.rescale_coords)
|
| 109 |
+
rescale_hw = torch.empty(1, **dd).uniform_(-rescale_max, rescale_max).exp()
|
| 110 |
+
coords *= rescale_hw
|
| 111 |
+
|
| 112 |
+
periods = self.periods.to(device=device, dtype=dtype)
|
| 113 |
+
angles = (2 * math.pi * coords[:, :, None]) / periods[None, None, :] # [HW, 2, D/4]
|
| 114 |
+
angles = angles.flatten(1, 2) # [HW, D/2]
|
| 115 |
+
angles = torch.cat((angles, angles), dim=-1) # [HW, D]
|
| 116 |
+
|
| 117 |
+
cos = torch.cos(angles)
|
| 118 |
+
sin = torch.sin(angles)
|
| 119 |
+
return sin, cos
|
| 120 |
+
|
| 121 |
+
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
| 122 |
+
# Validate inputs
|
| 123 |
+
assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
|
| 124 |
+
assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
|
| 125 |
+
|
| 126 |
+
B, _, N, C_head = tokens.shape
|
| 127 |
+
if C_head != self.head_dim:
|
| 128 |
+
raise ValueError(f"Head dim {C_head} doesn't match configured {self.head_dim}")
|
| 129 |
+
|
| 130 |
+
H = int(positions[..., 0].max().item() + 1)
|
| 131 |
+
W = int(positions[..., 1].max().item() + 1)
|
| 132 |
+
|
| 133 |
+
sin, cos = self._get_sincos_for_grid(H, W, tokens.device, tokens.dtype)
|
| 134 |
+
|
| 135 |
+
indices = (positions[..., 0] * W + positions[..., 1]).long()
|
| 136 |
+
flat_indices = indices.view(-1)
|
| 137 |
+
gathered_sin = sin[flat_indices].view(B, 1, N, C_head)
|
| 138 |
+
gathered_cos = cos[flat_indices].view(B, 1, N, C_head)
|
| 139 |
+
return (tokens * gathered_cos) + (_rotate_half(tokens) * gathered_sin)
|
| 140 |
+
|
hyworldmirror/models/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# References:
|
| 2 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 4 |
+
|
| 5 |
+
from typing import Callable, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from itertools import repeat
|
| 12 |
+
import collections.abc
|
| 13 |
+
|
| 14 |
+
def make_2tuple(x):
|
| 15 |
+
if isinstance(x, tuple):
|
| 16 |
+
assert len(x) == 2
|
| 17 |
+
return x
|
| 18 |
+
|
| 19 |
+
assert isinstance(x, int)
|
| 20 |
+
return (x, x)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class PatchEmbed(nn.Module):
|
| 24 |
+
"""
|
| 25 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
img_size: Image size.
|
| 29 |
+
patch_size: Patch token size.
|
| 30 |
+
in_chans: Number of input image channels.
|
| 31 |
+
embed_dim: Number of linear projection output channels.
|
| 32 |
+
norm_layer: Normalization layer.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
| 38 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
| 39 |
+
in_chans: int = 3,
|
| 40 |
+
embed_dim: int = 768,
|
| 41 |
+
norm_layer: Optional[Callable] = None,
|
| 42 |
+
flatten_embedding: bool = True,
|
| 43 |
+
) -> None:
|
| 44 |
+
super().__init__()
|
| 45 |
+
|
| 46 |
+
image_HW = make_2tuple(img_size)
|
| 47 |
+
patch_HW = make_2tuple(patch_size)
|
| 48 |
+
patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
|
| 49 |
+
|
| 50 |
+
self.img_size = image_HW
|
| 51 |
+
self.patch_size = patch_HW
|
| 52 |
+
self.patches_resolution = patch_grid_size
|
| 53 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
| 54 |
+
|
| 55 |
+
self.in_chans = in_chans
|
| 56 |
+
self.embed_dim = embed_dim
|
| 57 |
+
|
| 58 |
+
self.flatten_embedding = flatten_embedding
|
| 59 |
+
|
| 60 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
| 61 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 62 |
+
|
| 63 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 64 |
+
_, _, H, W = x.shape
|
| 65 |
+
patch_H, patch_W = self.patch_size
|
| 66 |
+
|
| 67 |
+
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
| 68 |
+
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
| 69 |
+
|
| 70 |
+
x = self.proj(x) # B C H W
|
| 71 |
+
H, W = x.size(2), x.size(3)
|
| 72 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
| 73 |
+
x = self.norm(x)
|
| 74 |
+
if not self.flatten_embedding:
|
| 75 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
| 76 |
+
return x
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class PatchEmbed_Mlp(PatchEmbed):
|
| 80 |
+
def __init__(self, img_size=224,
|
| 81 |
+
patch_size=16,
|
| 82 |
+
in_chans=3,
|
| 83 |
+
embed_dim=768,
|
| 84 |
+
norm_layer=None,
|
| 85 |
+
flatten_embedding=True):
|
| 86 |
+
super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten_embedding)
|
| 87 |
+
|
| 88 |
+
self.proj = nn.Sequential(
|
| 89 |
+
PixelUnshuffle(patch_size),
|
| 90 |
+
Permute((0,2,3,1)),
|
| 91 |
+
Mlp(in_chans * patch_size**2, 4*embed_dim, embed_dim),
|
| 92 |
+
Permute((0,3,1,2)),
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class PixelUnshuffle (nn.Module):
|
| 97 |
+
def __init__(self, downscale_factor):
|
| 98 |
+
super().__init__()
|
| 99 |
+
self.downscale_factor = downscale_factor
|
| 100 |
+
|
| 101 |
+
def forward(self, input):
|
| 102 |
+
if input.numel() == 0:
|
| 103 |
+
# this is not in the original torch implementation
|
| 104 |
+
C,H,W = input.shape[-3:]
|
| 105 |
+
assert H and W and H % self.downscale_factor == W%self.downscale_factor == 0
|
| 106 |
+
return input.view(*input.shape[:-3], C*self.downscale_factor**2, H//self.downscale_factor, W//self.downscale_factor)
|
| 107 |
+
else:
|
| 108 |
+
return F.pixel_unshuffle(input, self.downscale_factor)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class Permute(torch.nn.Module):
|
| 112 |
+
dims: tuple[int, ...]
|
| 113 |
+
def __init__(self, dims: tuple[int, ...]) -> None:
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.dims = tuple(dims)
|
| 116 |
+
|
| 117 |
+
def __repr__(self):
|
| 118 |
+
return f"Permute{self.dims}"
|
| 119 |
+
|
| 120 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
return input.permute(*self.dims)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _ntuple(n):
|
| 125 |
+
def parse(x):
|
| 126 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
| 127 |
+
return x
|
| 128 |
+
return tuple(repeat(x, n))
|
| 129 |
+
return parse
|
| 130 |
+
to_2tuple = _ntuple(2)
|
| 131 |
+
|
| 132 |
+
class Mlp(nn.Module):
|
| 133 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
| 134 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
|
| 135 |
+
super().__init__()
|
| 136 |
+
out_features = out_features or in_features
|
| 137 |
+
hidden_features = hidden_features or in_features
|
| 138 |
+
bias = to_2tuple(bias)
|
| 139 |
+
drop_probs = to_2tuple(drop)
|
| 140 |
+
|
| 141 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
|
| 142 |
+
self.act = act_layer()
|
| 143 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
| 144 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
|
| 145 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
| 146 |
+
|
| 147 |
+
def forward(self, x):
|
| 148 |
+
x = self.fc1(x)
|
| 149 |
+
x = self.act(x)
|
| 150 |
+
x = self.drop1(x)
|
| 151 |
+
x = self.fc2(x)
|
| 152 |
+
x = self.drop2(x)
|
| 153 |
+
return x
|
| 154 |
+
|
| 155 |
+
|
hyworldmirror/models/layers/rope.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Implementation of 2D Rotary Position Embeddings (RoPE).
|
| 2 |
+
|
| 3 |
+
# This module provides a clean implementation of 2D Rotary Position Embeddings,
|
| 4 |
+
# which extends the original RoPE concept to handle 2D spatial positions.
|
| 5 |
+
|
| 6 |
+
# Inspired by:
|
| 7 |
+
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
|
| 8 |
+
# https://github.com/naver-ai/rope-vit
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from typing import Dict, Tuple
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class PositionGetter:
|
| 19 |
+
"""Generates and caches 2D spatial positions for patches in a grid.
|
| 20 |
+
|
| 21 |
+
This class efficiently manages the generation of spatial coordinates for patches
|
| 22 |
+
in a 2D grid, caching results to avoid redundant computations.
|
| 23 |
+
|
| 24 |
+
Attributes:
|
| 25 |
+
position_cache: Dictionary storing precomputed position tensors for different
|
| 26 |
+
grid dimensions.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self):
|
| 30 |
+
"""Initializes the position generator with an empty cache."""
|
| 31 |
+
self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
|
| 32 |
+
|
| 33 |
+
def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
|
| 34 |
+
"""Generates spatial positions for a batch of patches.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
batch_size: Number of samples in the batch.
|
| 38 |
+
height: Height of the grid in patches.
|
| 39 |
+
width: Width of the grid in patches.
|
| 40 |
+
device: Target device for the position tensor.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
|
| 44 |
+
for each position in the grid, repeated for each batch item.
|
| 45 |
+
"""
|
| 46 |
+
if (height, width) not in self.position_cache:
|
| 47 |
+
y_coords = torch.arange(height, device=device)
|
| 48 |
+
x_coords = torch.arange(width, device=device)
|
| 49 |
+
positions = torch.cartesian_prod(y_coords, x_coords)
|
| 50 |
+
self.position_cache[height, width] = positions
|
| 51 |
+
|
| 52 |
+
cached_positions = self.position_cache[height, width]
|
| 53 |
+
return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class RotaryPositionEmbedding2D(nn.Module):
|
| 57 |
+
"""2D Rotary Position Embedding implementation.
|
| 58 |
+
|
| 59 |
+
This module applies rotary position embeddings to input tokens based on their
|
| 60 |
+
2D spatial positions. It handles the position-dependent rotation of features
|
| 61 |
+
separately for vertical and horizontal dimensions.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
frequency: Base frequency for the position embeddings. Default: 100.0
|
| 65 |
+
scaling_factor: Scaling factor for frequency computation. Default: 1.0
|
| 66 |
+
|
| 67 |
+
Attributes:
|
| 68 |
+
base_frequency: Base frequency for computing position embeddings.
|
| 69 |
+
scaling_factor: Factor to scale the computed frequencies.
|
| 70 |
+
frequency_cache: Cache for storing precomputed frequency components.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0,):
|
| 74 |
+
"""Initializes the 2D RoPE module."""
|
| 75 |
+
super().__init__()
|
| 76 |
+
self.base_frequency = frequency
|
| 77 |
+
self.scaling_factor = scaling_factor
|
| 78 |
+
self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
|
| 79 |
+
|
| 80 |
+
def _compute_frequency_components(
|
| 81 |
+
self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
|
| 82 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 83 |
+
"""Computes frequency components for rotary embeddings.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
dim: Feature dimension (must be even).
|
| 87 |
+
seq_len: Maximum sequence length.
|
| 88 |
+
device: Target device for computations.
|
| 89 |
+
dtype: Data type for the computed tensors.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
Tuple of (cosine, sine) tensors for frequency components.
|
| 93 |
+
"""
|
| 94 |
+
cache_key = (dim, seq_len, device, dtype)
|
| 95 |
+
if cache_key not in self.frequency_cache:
|
| 96 |
+
# Compute frequency bands
|
| 97 |
+
exponents = torch.arange(0, dim, 2, device=device).float() / dim
|
| 98 |
+
inv_freq = 1.0 / (self.base_frequency**exponents)
|
| 99 |
+
|
| 100 |
+
# Generate position-dependent frequencies
|
| 101 |
+
positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
| 102 |
+
angles = torch.einsum("i,j->ij", positions, inv_freq)
|
| 103 |
+
|
| 104 |
+
# Compute and cache frequency components
|
| 105 |
+
angles = angles.to(dtype)
|
| 106 |
+
angles = torch.cat((angles, angles), dim=-1)
|
| 107 |
+
cos_components = angles.cos().to(dtype)
|
| 108 |
+
sin_components = angles.sin().to(dtype)
|
| 109 |
+
self.frequency_cache[cache_key] = (cos_components, sin_components)
|
| 110 |
+
|
| 111 |
+
return self.frequency_cache[cache_key]
|
| 112 |
+
|
| 113 |
+
@staticmethod
|
| 114 |
+
def _rotate_features(x: torch.Tensor) -> torch.Tensor:
|
| 115 |
+
"""Performs feature rotation by splitting and recombining feature dimensions.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
x: Input tensor to rotate.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
Rotated feature tensor.
|
| 122 |
+
"""
|
| 123 |
+
feature_dim = x.shape[-1]
|
| 124 |
+
x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
|
| 125 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 126 |
+
|
| 127 |
+
def _apply_1d_rope(
|
| 128 |
+
self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
|
| 129 |
+
) -> torch.Tensor:
|
| 130 |
+
"""Applies 1D rotary position embeddings along one dimension.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
tokens: Input token features.
|
| 134 |
+
positions: Position indices.
|
| 135 |
+
cos_comp: Cosine components for rotation.
|
| 136 |
+
sin_comp: Sine components for rotation.
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
Tokens with applied rotary position embeddings.
|
| 140 |
+
"""
|
| 141 |
+
# Embed positions with frequency components
|
| 142 |
+
cos = F.embedding(positions, cos_comp)[:, None, :, :]
|
| 143 |
+
sin = F.embedding(positions, sin_comp)[:, None, :, :]
|
| 144 |
+
|
| 145 |
+
# Apply rotation
|
| 146 |
+
return (tokens * cos) + (self._rotate_features(tokens) * sin)
|
| 147 |
+
|
| 148 |
+
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
| 149 |
+
"""Applies 2D rotary position embeddings to input tokens.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
|
| 153 |
+
The feature dimension (dim) must be divisible by 4.
|
| 154 |
+
positions: Position tensor of shape (batch_size, n_tokens, 2) containing
|
| 155 |
+
the y and x coordinates for each token.
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
Tensor of same shape as input with applied 2D rotary position embeddings.
|
| 159 |
+
|
| 160 |
+
Raises:
|
| 161 |
+
AssertionError: If input dimensions are invalid or positions are malformed.
|
| 162 |
+
"""
|
| 163 |
+
# Validate inputs
|
| 164 |
+
assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
|
| 165 |
+
assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
|
| 166 |
+
|
| 167 |
+
# Compute feature dimension for each spatial direction
|
| 168 |
+
feature_dim = tokens.size(-1) // 2
|
| 169 |
+
|
| 170 |
+
# Get frequency components
|
| 171 |
+
max_position = int(positions.max()) + 1
|
| 172 |
+
cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
|
| 173 |
+
|
| 174 |
+
# Split features for vertical and horizontal processing
|
| 175 |
+
vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
|
| 176 |
+
|
| 177 |
+
# Apply RoPE separately for each dimension
|
| 178 |
+
vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
|
| 179 |
+
horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
|
| 180 |
+
|
| 181 |
+
# Combine processed features
|
| 182 |
+
return torch.cat((vertical_features, horizontal_features), dim=-1)
|
hyworldmirror/models/layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Optional
|
| 2 |
+
|
| 3 |
+
from torch import Tensor, nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class SwiGLUFFN(nn.Module):
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
in_features: int,
|
| 11 |
+
hidden_features: Optional[int] = None,
|
| 12 |
+
out_features: Optional[int] = None,
|
| 13 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 14 |
+
drop: float = 0.0,
|
| 15 |
+
bias: bool = True,
|
| 16 |
+
) -> None:
|
| 17 |
+
super().__init__()
|
| 18 |
+
out_features = out_features or in_features
|
| 19 |
+
hidden_features = hidden_features or in_features
|
| 20 |
+
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
| 21 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 22 |
+
|
| 23 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 24 |
+
x12 = self.w12(x)
|
| 25 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
| 26 |
+
hidden = F.silu(x1) * x2
|
| 27 |
+
return self.w3(hidden)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
SwiGLU = SwiGLUFFN
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class SwiGLUFFNFused(SwiGLU):
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
in_features: int,
|
| 37 |
+
hidden_features: Optional[int] = None,
|
| 38 |
+
out_features: Optional[int] = None,
|
| 39 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 40 |
+
drop: float = 0.0,
|
| 41 |
+
bias: bool = True,
|
| 42 |
+
) -> None:
|
| 43 |
+
out_features = out_features or in_features
|
| 44 |
+
hidden_features = hidden_features or in_features
|
| 45 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
| 46 |
+
super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias)
|
hyworldmirror/models/layers/vision_transformer.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# References:
|
| 2 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 4 |
+
|
| 5 |
+
from functools import partial
|
| 6 |
+
import math
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Sequence, Tuple, Union, Callable
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch.utils.checkpoint import checkpoint
|
| 13 |
+
from torch.nn.init import trunc_normal_
|
| 14 |
+
from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
| 15 |
+
|
| 16 |
+
log = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
| 20 |
+
if not depth_first and include_root:
|
| 21 |
+
fn(module=module, name=name)
|
| 22 |
+
for child_name, child_module in module.named_children():
|
| 23 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
| 24 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
| 25 |
+
if depth_first and include_root:
|
| 26 |
+
fn(module=module, name=name)
|
| 27 |
+
return module
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class BlockChunk(nn.ModuleList):
|
| 31 |
+
def forward(self, x):
|
| 32 |
+
for b in self:
|
| 33 |
+
x = b(x)
|
| 34 |
+
return x
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class DinoVisionTransformer(nn.Module):
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
img_size=224,
|
| 41 |
+
patch_size=16,
|
| 42 |
+
in_chans=3,
|
| 43 |
+
embed_dim=768,
|
| 44 |
+
depth=12,
|
| 45 |
+
num_heads=12,
|
| 46 |
+
mlp_ratio=4.0,
|
| 47 |
+
qkv_bias=True,
|
| 48 |
+
ffn_bias=True,
|
| 49 |
+
proj_bias=True,
|
| 50 |
+
drop_path_rate=0.0,
|
| 51 |
+
drop_path_uniform=False,
|
| 52 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
| 53 |
+
embed_layer=PatchEmbed,
|
| 54 |
+
act_layer=nn.GELU,
|
| 55 |
+
block_fn=Block,
|
| 56 |
+
ffn_layer="mlp",
|
| 57 |
+
block_chunks=1,
|
| 58 |
+
num_register_tokens=0,
|
| 59 |
+
interpolate_antialias=False,
|
| 60 |
+
interpolate_offset=0.1,
|
| 61 |
+
qk_norm=False,
|
| 62 |
+
):
|
| 63 |
+
"""
|
| 64 |
+
Args:
|
| 65 |
+
img_size (int, tuple): input image size
|
| 66 |
+
patch_size (int, tuple): patch size
|
| 67 |
+
in_chans (int): number of input channels
|
| 68 |
+
embed_dim (int): embedding dimension
|
| 69 |
+
depth (int): depth of transformer
|
| 70 |
+
num_heads (int): number of attention heads
|
| 71 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
| 72 |
+
qkv_bias (bool): enable bias for qkv if True
|
| 73 |
+
proj_bias (bool): enable bias for proj in attn if True
|
| 74 |
+
ffn_bias (bool): enable bias for ffn if True
|
| 75 |
+
drop_path_rate (float): stochastic depth rate
|
| 76 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
| 77 |
+
weight_init (str): weight init scheme
|
| 78 |
+
init_values (float): layer-scale init values
|
| 79 |
+
embed_layer (nn.Module): patch embedding layer
|
| 80 |
+
act_layer (nn.Module): MLP activation layer
|
| 81 |
+
block_fn (nn.Module): transformer block class
|
| 82 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
| 83 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
| 84 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
| 85 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
| 86 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
| 87 |
+
"""
|
| 88 |
+
super().__init__()
|
| 89 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
| 90 |
+
|
| 91 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
| 92 |
+
self.num_tokens = 1
|
| 93 |
+
self.n_blocks = depth
|
| 94 |
+
self.num_heads = num_heads
|
| 95 |
+
self.patch_size = patch_size
|
| 96 |
+
self.num_register_tokens = num_register_tokens
|
| 97 |
+
self.interpolate_antialias = interpolate_antialias
|
| 98 |
+
self.interpolate_offset = interpolate_offset
|
| 99 |
+
self.use_reentrant = False # hardcoded to False
|
| 100 |
+
|
| 101 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 102 |
+
num_patches = self.patch_embed.num_patches
|
| 103 |
+
|
| 104 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 105 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
| 106 |
+
assert num_register_tokens >= 0
|
| 107 |
+
self.register_tokens = (
|
| 108 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
if drop_path_uniform is True:
|
| 112 |
+
dpr = [drop_path_rate] * depth
|
| 113 |
+
else:
|
| 114 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 115 |
+
|
| 116 |
+
if ffn_layer == "mlp":
|
| 117 |
+
log.info("using MLP layer as FFN")
|
| 118 |
+
ffn_layer = Mlp
|
| 119 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
| 120 |
+
log.info("using SwiGLU layer as FFN")
|
| 121 |
+
ffn_layer = SwiGLUFFNFused
|
| 122 |
+
elif ffn_layer == "identity":
|
| 123 |
+
log.info("using Identity layer as FFN")
|
| 124 |
+
|
| 125 |
+
def f(*args, **kwargs):
|
| 126 |
+
return nn.Identity()
|
| 127 |
+
|
| 128 |
+
ffn_layer = f
|
| 129 |
+
else:
|
| 130 |
+
raise NotImplementedError
|
| 131 |
+
|
| 132 |
+
blocks_list = [
|
| 133 |
+
block_fn(
|
| 134 |
+
dim=embed_dim,
|
| 135 |
+
num_heads=num_heads,
|
| 136 |
+
mlp_ratio=mlp_ratio,
|
| 137 |
+
qkv_bias=qkv_bias,
|
| 138 |
+
proj_bias=proj_bias,
|
| 139 |
+
ffn_bias=ffn_bias,
|
| 140 |
+
drop_path=dpr[i],
|
| 141 |
+
norm_layer=norm_layer,
|
| 142 |
+
act_layer=act_layer,
|
| 143 |
+
ffn_layer=ffn_layer,
|
| 144 |
+
init_values=init_values,
|
| 145 |
+
qk_norm=qk_norm,
|
| 146 |
+
)
|
| 147 |
+
for i in range(depth)
|
| 148 |
+
]
|
| 149 |
+
if block_chunks > 0:
|
| 150 |
+
self.chunked_blocks = True
|
| 151 |
+
chunked_blocks = []
|
| 152 |
+
chunksize = depth // block_chunks
|
| 153 |
+
for i in range(0, depth, chunksize):
|
| 154 |
+
# this is to keep the block index consistent if we chunk the block list
|
| 155 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
| 156 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
| 157 |
+
else:
|
| 158 |
+
self.chunked_blocks = False
|
| 159 |
+
self.blocks = nn.ModuleList(blocks_list)
|
| 160 |
+
|
| 161 |
+
self.norm = norm_layer(embed_dim)
|
| 162 |
+
self.head = nn.Identity()
|
| 163 |
+
|
| 164 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
| 165 |
+
|
| 166 |
+
self.init_weights()
|
| 167 |
+
|
| 168 |
+
def init_weights(self):
|
| 169 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
| 170 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
| 171 |
+
if self.register_tokens is not None:
|
| 172 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
| 173 |
+
named_apply(init_weights_vit_timm, self)
|
| 174 |
+
|
| 175 |
+
def interpolate_pos_encoding(self, x, w, h):
|
| 176 |
+
previous_dtype = x.dtype
|
| 177 |
+
npatch = x.shape[1] - 1
|
| 178 |
+
N = self.pos_embed.shape[1] - 1
|
| 179 |
+
if npatch == N and w == h:
|
| 180 |
+
return self.pos_embed
|
| 181 |
+
pos_embed = self.pos_embed.float()
|
| 182 |
+
class_pos_embed = pos_embed[:, 0]
|
| 183 |
+
patch_pos_embed = pos_embed[:, 1:]
|
| 184 |
+
dim = x.shape[-1]
|
| 185 |
+
w0 = w // self.patch_size
|
| 186 |
+
h0 = h // self.patch_size
|
| 187 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
| 188 |
+
assert N == M * M
|
| 189 |
+
kwargs = {}
|
| 190 |
+
if self.interpolate_offset:
|
| 191 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
| 192 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
| 193 |
+
sx = float(w0 + self.interpolate_offset) / M
|
| 194 |
+
sy = float(h0 + self.interpolate_offset) / M
|
| 195 |
+
kwargs["scale_factor"] = (sx, sy)
|
| 196 |
+
else:
|
| 197 |
+
# Simply specify an output size instead of a scale factor
|
| 198 |
+
kwargs["size"] = (w0, h0)
|
| 199 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 200 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
| 201 |
+
mode="bicubic",
|
| 202 |
+
antialias=self.interpolate_antialias,
|
| 203 |
+
**kwargs,
|
| 204 |
+
)
|
| 205 |
+
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
| 206 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 207 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
| 208 |
+
|
| 209 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
| 210 |
+
B, nc, w, h = x.shape
|
| 211 |
+
x = self.patch_embed(x)
|
| 212 |
+
if masks is not None:
|
| 213 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
| 214 |
+
|
| 215 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
| 216 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
| 217 |
+
|
| 218 |
+
if self.register_tokens is not None:
|
| 219 |
+
x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1)
|
| 220 |
+
|
| 221 |
+
return x
|
| 222 |
+
|
| 223 |
+
def forward_features_list(self, x_list, masks_list):
|
| 224 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
| 225 |
+
|
| 226 |
+
for blk in self.blocks:
|
| 227 |
+
if self.training:
|
| 228 |
+
# x = blk(x)
|
| 229 |
+
x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
|
| 230 |
+
else:
|
| 231 |
+
x = blk(x)
|
| 232 |
+
|
| 233 |
+
all_x = x
|
| 234 |
+
output = []
|
| 235 |
+
for x, masks in zip(all_x, masks_list):
|
| 236 |
+
x_norm = self.norm(x)
|
| 237 |
+
output.append(
|
| 238 |
+
{
|
| 239 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 240 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 241 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 242 |
+
"x_prenorm": x,
|
| 243 |
+
"masks": masks,
|
| 244 |
+
}
|
| 245 |
+
)
|
| 246 |
+
return output
|
| 247 |
+
|
| 248 |
+
def forward_features(self, x, masks=None):
|
| 249 |
+
if isinstance(x, list):
|
| 250 |
+
return self.forward_features_list(x, masks)
|
| 251 |
+
|
| 252 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
| 253 |
+
|
| 254 |
+
for blk in self.blocks:
|
| 255 |
+
if self.training:
|
| 256 |
+
# x = blk(x)
|
| 257 |
+
x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
|
| 258 |
+
else:
|
| 259 |
+
x = blk(x)
|
| 260 |
+
|
| 261 |
+
x_norm = self.norm(x)
|
| 262 |
+
return {
|
| 263 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 264 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 265 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 266 |
+
"x_prenorm": x,
|
| 267 |
+
"masks": masks,
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
| 271 |
+
x = self.prepare_tokens_with_masks(x)
|
| 272 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 273 |
+
output, total_block_len = [], len(self.blocks)
|
| 274 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 275 |
+
for i, blk in enumerate(self.blocks):
|
| 276 |
+
x = blk(x)
|
| 277 |
+
if i in blocks_to_take:
|
| 278 |
+
output.append(x)
|
| 279 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 280 |
+
return output
|
| 281 |
+
|
| 282 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
| 283 |
+
x = self.prepare_tokens_with_masks(x)
|
| 284 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
| 285 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 286 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 287 |
+
for block_chunk in self.blocks:
|
| 288 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
| 289 |
+
x = blk(x)
|
| 290 |
+
if i in blocks_to_take:
|
| 291 |
+
output.append(x)
|
| 292 |
+
i += 1
|
| 293 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 294 |
+
return output
|
| 295 |
+
|
| 296 |
+
def get_intermediate_layers(
|
| 297 |
+
self,
|
| 298 |
+
x: torch.Tensor,
|
| 299 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
| 300 |
+
reshape: bool = False,
|
| 301 |
+
return_class_token: bool = False,
|
| 302 |
+
norm=True,
|
| 303 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
| 304 |
+
if self.chunked_blocks:
|
| 305 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
| 306 |
+
else:
|
| 307 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
| 308 |
+
if norm:
|
| 309 |
+
outputs = [self.norm(out) for out in outputs]
|
| 310 |
+
class_tokens = [out[:, 0] for out in outputs]
|
| 311 |
+
outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
|
| 312 |
+
if reshape:
|
| 313 |
+
B, _, w, h = x.shape
|
| 314 |
+
outputs = [
|
| 315 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
| 316 |
+
for out in outputs
|
| 317 |
+
]
|
| 318 |
+
if return_class_token:
|
| 319 |
+
return tuple(zip(outputs, class_tokens))
|
| 320 |
+
return tuple(outputs)
|
| 321 |
+
|
| 322 |
+
def forward(self, *args, is_training=True, **kwargs):
|
| 323 |
+
ret = self.forward_features(*args, **kwargs)
|
| 324 |
+
if is_training:
|
| 325 |
+
return ret
|
| 326 |
+
else:
|
| 327 |
+
return self.head(ret["x_norm_clstoken"])
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
| 331 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 332 |
+
if isinstance(module, nn.Linear):
|
| 333 |
+
trunc_normal_(module.weight, std=0.02)
|
| 334 |
+
if module.bias is not None:
|
| 335 |
+
nn.init.zeros_(module.bias)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
| 339 |
+
model = DinoVisionTransformer(
|
| 340 |
+
patch_size=patch_size,
|
| 341 |
+
embed_dim=384,
|
| 342 |
+
depth=12,
|
| 343 |
+
num_heads=6,
|
| 344 |
+
mlp_ratio=4,
|
| 345 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 346 |
+
num_register_tokens=num_register_tokens,
|
| 347 |
+
**kwargs,
|
| 348 |
+
)
|
| 349 |
+
return model
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
| 353 |
+
model = DinoVisionTransformer(
|
| 354 |
+
patch_size=patch_size,
|
| 355 |
+
embed_dim=768,
|
| 356 |
+
depth=12,
|
| 357 |
+
num_heads=12,
|
| 358 |
+
mlp_ratio=4,
|
| 359 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 360 |
+
num_register_tokens=num_register_tokens,
|
| 361 |
+
**kwargs,
|
| 362 |
+
)
|
| 363 |
+
return model
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
| 367 |
+
model = DinoVisionTransformer(
|
| 368 |
+
patch_size=patch_size,
|
| 369 |
+
embed_dim=1024,
|
| 370 |
+
depth=24,
|
| 371 |
+
num_heads=16,
|
| 372 |
+
mlp_ratio=4,
|
| 373 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 374 |
+
num_register_tokens=num_register_tokens,
|
| 375 |
+
**kwargs,
|
| 376 |
+
)
|
| 377 |
+
return model
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
|
| 381 |
+
"""
|
| 382 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
| 383 |
+
"""
|
| 384 |
+
model = DinoVisionTransformer(
|
| 385 |
+
patch_size=patch_size,
|
| 386 |
+
embed_dim=1536,
|
| 387 |
+
depth=40,
|
| 388 |
+
num_heads=24,
|
| 389 |
+
mlp_ratio=4,
|
| 390 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 391 |
+
num_register_tokens=num_register_tokens,
|
| 392 |
+
**kwargs,
|
| 393 |
+
)
|
| 394 |
+
return model
|
hyworldmirror/models/models/__init__.py
ADDED
|
File without changes
|
hyworldmirror/models/models/rasterization.py
ADDED
|
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Tuple
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
|
| 8 |
+
from gsplat.rendering import rasterization
|
| 9 |
+
from gsplat.strategy import DefaultStrategy
|
| 10 |
+
|
| 11 |
+
from ..utils.frustum import calculate_unprojected_mask
|
| 12 |
+
from ..utils.geometry import depth_to_world_coords_points
|
| 13 |
+
from ..utils import sh_utils, act_gs
|
| 14 |
+
|
| 15 |
+
from typing import List
|
| 16 |
+
|
| 17 |
+
class Rasterizer:
|
| 18 |
+
def __init__(self, rasterization_mode="classic", packed=True, abs_grad=True, with_eval3d=False,
|
| 19 |
+
camera_model="pinhole", sparse_grad=False, distributed=False, grad_strategy=DefaultStrategy):
|
| 20 |
+
self.rasterization_mode = rasterization_mode
|
| 21 |
+
self.packed = packed
|
| 22 |
+
self.abs_grad = abs_grad
|
| 23 |
+
self.camera_model = camera_model
|
| 24 |
+
self.sparse_grad = sparse_grad
|
| 25 |
+
self.grad_strategy = grad_strategy
|
| 26 |
+
self.distributed = distributed
|
| 27 |
+
self.with_eval3d = with_eval3d
|
| 28 |
+
|
| 29 |
+
def rasterize_splats(
|
| 30 |
+
self,
|
| 31 |
+
means,
|
| 32 |
+
quats,
|
| 33 |
+
scales,
|
| 34 |
+
opacities,
|
| 35 |
+
colors,
|
| 36 |
+
camtoworlds: Tensor,
|
| 37 |
+
Ks: Tensor,
|
| 38 |
+
width: int,
|
| 39 |
+
height: int,
|
| 40 |
+
**kwargs,
|
| 41 |
+
) -> Tuple[Tensor, Tensor, Dict]:
|
| 42 |
+
render_colors, render_alphas, _ = rasterization(
|
| 43 |
+
means=means,
|
| 44 |
+
quats=quats,
|
| 45 |
+
scales=scales,
|
| 46 |
+
opacities=opacities,
|
| 47 |
+
colors=colors,
|
| 48 |
+
viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
|
| 49 |
+
Ks=Ks, # [C, 3, 3]
|
| 50 |
+
width=width,
|
| 51 |
+
height=height,
|
| 52 |
+
packed=self.packed,
|
| 53 |
+
absgrad=(
|
| 54 |
+
self.abs_grad
|
| 55 |
+
if isinstance(self.grad_strategy, DefaultStrategy)
|
| 56 |
+
else False
|
| 57 |
+
),
|
| 58 |
+
sparse_grad=self.sparse_grad,
|
| 59 |
+
rasterize_mode=self.rasterization_mode,
|
| 60 |
+
distributed=self.distributed,
|
| 61 |
+
camera_model=self.camera_model,
|
| 62 |
+
with_eval3d=self.with_eval3d,
|
| 63 |
+
render_mode="RGB+ED",
|
| 64 |
+
**kwargs,
|
| 65 |
+
)
|
| 66 |
+
return render_colors[..., :3], render_colors[..., 3:], render_alphas
|
| 67 |
+
|
| 68 |
+
def rasterize_batches(self, means, quats, scales, opacities, colors, viewmats, Ks, width, height, **kwargs):
|
| 69 |
+
rendered_colors, rendered_depths, rendered_alphas = [], [], []
|
| 70 |
+
batch_size = len(means)
|
| 71 |
+
for i in range(batch_size):
|
| 72 |
+
means_i = means[i] # [N, 4]
|
| 73 |
+
quats_i = quats[i] # [N, 4]
|
| 74 |
+
scales_i = scales[i] # [N, 3]
|
| 75 |
+
opacities_i = opacities[i] # [N,]
|
| 76 |
+
colors_i = colors[i] # [N, 3]
|
| 77 |
+
viewmats_i = viewmats[i] # [V, 4, 4]
|
| 78 |
+
Ks_i = Ks[i] # [V, 3, 3]
|
| 79 |
+
render_colors_i, render_depths_i, render_alphas_i = self.rasterize_splats(
|
| 80 |
+
means_i, quats_i, scales_i, opacities_i, colors_i, viewmats_i, Ks_i, width, height, **kwargs
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
rendered_colors.append(render_colors_i) # V H W 3
|
| 84 |
+
rendered_depths.append(render_depths_i) # V H W 1
|
| 85 |
+
rendered_alphas.append(render_alphas_i) # V H W 1
|
| 86 |
+
|
| 87 |
+
rendered_colors = torch.stack(rendered_colors, dim=0) # B V H W 3
|
| 88 |
+
rendered_depths = torch.stack(rendered_depths, dim=0) # B V H W 1
|
| 89 |
+
rendered_alphas = torch.stack(rendered_alphas, dim=0) # B V H W 1
|
| 90 |
+
|
| 91 |
+
return rendered_colors, rendered_depths, rendered_alphas
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class GaussianSplatRenderer(nn.Module):
|
| 95 |
+
def __init__(
|
| 96 |
+
self,
|
| 97 |
+
feature_dim: int = 256, # Output channels of gs_feat_head
|
| 98 |
+
sh_degree: int = 0,
|
| 99 |
+
enable_prune: bool = True,
|
| 100 |
+
voxel_size: float = 0.002, # Default voxel size for prune_gs
|
| 101 |
+
enable_conf_filter: bool = False, # Enable confidence filtering
|
| 102 |
+
conf_threshold_percent: float = 30.0, # Confidence threshold percentage
|
| 103 |
+
max_gaussians: int = 5000000, # Maximum number of Gaussians
|
| 104 |
+
):
|
| 105 |
+
super().__init__()
|
| 106 |
+
|
| 107 |
+
self.feature_dim = feature_dim
|
| 108 |
+
self.sh_degree = sh_degree
|
| 109 |
+
self.nums_sh = (sh_degree + 1) ** 2
|
| 110 |
+
self.voxel_size = voxel_size
|
| 111 |
+
self.enable_prune = enable_prune
|
| 112 |
+
self.enable_conf_filter = enable_conf_filter
|
| 113 |
+
self.conf_threshold_percent = conf_threshold_percent
|
| 114 |
+
self.max_gaussians = max_gaussians
|
| 115 |
+
|
| 116 |
+
# Predict Gaussian parameters from GS features (quaternions/scales/opacities/SH/weights)
|
| 117 |
+
splits_and_inits = [
|
| 118 |
+
(4, 1.0, 0.0), # quats
|
| 119 |
+
(3, 0.00003, -7.0), # scales
|
| 120 |
+
(1, 1.0, -2.0), # opacities
|
| 121 |
+
(3 * self.nums_sh, 1.0, 0.0), # residual_sh
|
| 122 |
+
(1, 1.0, -2.0), # weights
|
| 123 |
+
]
|
| 124 |
+
gaussian_raw_channels = 4 + 3 + 1 + self.nums_sh * 3 + 1
|
| 125 |
+
|
| 126 |
+
self.gs_head = nn.Sequential(
|
| 127 |
+
nn.Conv2d(feature_dim // 2, feature_dim, kernel_size=3, padding=1, bias=False),
|
| 128 |
+
nn.ReLU(True),
|
| 129 |
+
nn.Conv2d(feature_dim, gaussian_raw_channels, kernel_size=1),
|
| 130 |
+
)
|
| 131 |
+
# Initialize weights and biases of the final layer by segments
|
| 132 |
+
final_conv_layer = self.gs_head[-1]
|
| 133 |
+
start_channels = 0
|
| 134 |
+
for out_channel, s, b in splits_and_inits:
|
| 135 |
+
nn.init.xavier_uniform_(final_conv_layer.weight[start_channels:start_channels+out_channel], s)
|
| 136 |
+
nn.init.constant_(final_conv_layer.bias[start_channels:start_channels+out_channel], b)
|
| 137 |
+
start_channels += out_channel
|
| 138 |
+
|
| 139 |
+
# Rasterizer
|
| 140 |
+
self.rasterizer = Rasterizer()
|
| 141 |
+
|
| 142 |
+
# ======== Main entry point: Complete GS rendering and fill results back to predictions ========
|
| 143 |
+
def render(
|
| 144 |
+
self,
|
| 145 |
+
gs_feats: torch.Tensor, # [B, S, 3, H, W]
|
| 146 |
+
images: torch.Tensor, # [B, S+V, 3, H, W]
|
| 147 |
+
predictions: Dict[str, torch.Tensor], # From WorldMirror: pose/depth/pts3d etc
|
| 148 |
+
views: Dict[str, torch.Tensor],
|
| 149 |
+
context_predictions: Dict[str, torch.Tensor],
|
| 150 |
+
is_inference: bool=True,
|
| 151 |
+
) -> Dict[str, torch.Tensor]:
|
| 152 |
+
"""
|
| 153 |
+
Returns predictions with the following fields filled:
|
| 154 |
+
- rendered_colors / rendered_depths / (rendered_alphas during training)
|
| 155 |
+
- gt_colors / gt_depths / valid_masks
|
| 156 |
+
- splats / rendered_extrinsics / rendered_intrinsics
|
| 157 |
+
"""
|
| 158 |
+
B, _, _, H, W = images.shape
|
| 159 |
+
S = context_predictions.get("imgs", images).shape[1] # context view nums
|
| 160 |
+
V = images.shape[1] - S # target view nums
|
| 161 |
+
|
| 162 |
+
# 1) Predict GS features from tokens, then convert to Gaussian parameters
|
| 163 |
+
gs_feats_reshape = rearrange(gs_feats, "b s c h w -> (b s) c h w")
|
| 164 |
+
# Align input dtype with gs_head weights (handles fp32 input from
|
| 165 |
+
# precision-critical DPT output_conv2 when model runs in bf16 mode)
|
| 166 |
+
head_dtype = next(self.gs_head.parameters()).dtype
|
| 167 |
+
if gs_feats_reshape.dtype != head_dtype:
|
| 168 |
+
gs_feats_reshape = gs_feats_reshape.to(head_dtype)
|
| 169 |
+
gs_params = self.gs_head(gs_feats_reshape)
|
| 170 |
+
gt_colors = images.permute(0, 1, 3, 4, 2)
|
| 171 |
+
|
| 172 |
+
# 2) Select rendering cameras
|
| 173 |
+
if self.training:
|
| 174 |
+
# Using all gt cameras
|
| 175 |
+
render_viewmats, render_Ks = self.prepare_cameras(views, S + V)
|
| 176 |
+
gt_valid_masks_src = views["valid_mask"][:, :S] # [B, S, H, W]
|
| 177 |
+
gt_valid_masks_tgt = views["valid_mask"][:, S:] # [B, V, H, W]
|
| 178 |
+
unproject_masks = calculate_unprojected_mask(views, S) # [B, V, H, W]
|
| 179 |
+
valid_masks = torch.cat([gt_valid_masks_src, (gt_valid_masks_tgt & unproject_masks)], dim=1)
|
| 180 |
+
else:
|
| 181 |
+
# Re-predict the camera for novel views and perform translation scale alignment
|
| 182 |
+
pred_all_extrinsic, pred_all_intrinsic = self.prepare_cameras(predictions, S + V)
|
| 183 |
+
scale_factor = torch.ones(B, device=images.device)
|
| 184 |
+
if "camera_poses" in context_predictions:
|
| 185 |
+
pred_context_extrinsic, _ = self.prepare_cameras(context_predictions, S)
|
| 186 |
+
scale_factor = pred_context_extrinsic[:, :, :3, 3].norm(dim=-1).mean(dim=1, keepdim=True) / (
|
| 187 |
+
pred_all_extrinsic[:, :S, :3, 3].norm(dim=-1).mean(dim=1, keepdim=True) + 1e-6
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
pred_all_extrinsic[..., :3, 3] = pred_all_extrinsic[..., :3, 3] * scale_factor.unsqueeze(-1)
|
| 191 |
+
render_viewmats, render_Ks = pred_all_extrinsic, pred_all_intrinsic
|
| 192 |
+
valid_masks = views.get("valid_mask", torch.ones(B, S + V, H, W, dtype=bool, device=images.device))
|
| 193 |
+
|
| 194 |
+
# 3) Generate splats from gs_params + predictions, and perform voxel merging
|
| 195 |
+
if self.training:
|
| 196 |
+
splats = self.prepare_splats(
|
| 197 |
+
views,
|
| 198 |
+
predictions,
|
| 199 |
+
images,
|
| 200 |
+
gs_params,
|
| 201 |
+
S,
|
| 202 |
+
position_from="gsdepth+gtcamera",
|
| 203 |
+
)
|
| 204 |
+
elif not is_inference:
|
| 205 |
+
splats = self.prepare_splats(
|
| 206 |
+
views,
|
| 207 |
+
predictions,
|
| 208 |
+
images,
|
| 209 |
+
gs_params,
|
| 210 |
+
S,
|
| 211 |
+
context_predictions,
|
| 212 |
+
position_from="gsdepth+predcamera",
|
| 213 |
+
)
|
| 214 |
+
else:
|
| 215 |
+
splats = self.prepare_splats(
|
| 216 |
+
views,
|
| 217 |
+
predictions,
|
| 218 |
+
images,
|
| 219 |
+
gs_params,
|
| 220 |
+
S,
|
| 221 |
+
position_from="gsdepth+predcamera",
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
if is_inference:
|
| 225 |
+
predictions["splats"] = splats
|
| 226 |
+
return predictions
|
| 227 |
+
|
| 228 |
+
# Apply confidence filtering before pruning
|
| 229 |
+
if self.enable_conf_filter and "gs_depth_conf" in predictions:
|
| 230 |
+
splats = self.apply_confidence_filter(splats, predictions["gs_depth_conf"])
|
| 231 |
+
|
| 232 |
+
if self.enable_prune:
|
| 233 |
+
splats = self.prune_gs(splats, voxel_size=self.voxel_size)
|
| 234 |
+
|
| 235 |
+
predictions["splats"] = splats
|
| 236 |
+
|
| 237 |
+
# 4) Rasterization rendering (training: chunked rendering + novel view valid mask correction; evaluation: view-by-view)
|
| 238 |
+
|
| 239 |
+
# Prevent OOM by using chunked rendering
|
| 240 |
+
rendered_colors_list, rendered_depths_list, rendered_alphas_list = [], [], []
|
| 241 |
+
chunk_size = 2
|
| 242 |
+
for i in range(0, gt_colors.shape[1], chunk_size):
|
| 243 |
+
end_idx = min(i + chunk_size, gt_colors.shape[1])
|
| 244 |
+
viewmats_i = render_viewmats[:, i:end_idx]
|
| 245 |
+
Ks_i = render_Ks[:, i:end_idx]
|
| 246 |
+
|
| 247 |
+
rendered_colors, rendered_depths, rendered_alphas = self.rasterizer.rasterize_batches(
|
| 248 |
+
splats["means"], splats["quats"], splats["scales"], splats["opacities"],
|
| 249 |
+
splats["sh"] if "sh" in splats else splats["colors"],
|
| 250 |
+
viewmats_i.detach(), Ks_i.detach(),
|
| 251 |
+
width=images.shape[-1], height=images.shape[-2],
|
| 252 |
+
sh_degree=min(self.sh_degree, 0) if "sh" in splats else None,
|
| 253 |
+
)
|
| 254 |
+
rendered_colors_list.append(rendered_colors)
|
| 255 |
+
rendered_depths_list.append(rendered_depths)
|
| 256 |
+
rendered_alphas_list.append(rendered_alphas)
|
| 257 |
+
|
| 258 |
+
rendered_colors = torch.cat(rendered_colors_list, dim=1)
|
| 259 |
+
rendered_depths = torch.cat(rendered_depths_list, dim=1)
|
| 260 |
+
rendered_alphas = torch.cat(rendered_alphas_list, dim=1)
|
| 261 |
+
|
| 262 |
+
if self.training and V > 0:
|
| 263 |
+
nvs_rendered_mask = rendered_alphas[:, S:, ..., 0].detach() > 0.1
|
| 264 |
+
valid_masks[:, S:] = nvs_rendered_mask & valid_masks[:, S:]
|
| 265 |
+
|
| 266 |
+
# 5) return predictions
|
| 267 |
+
predictions["rendered_colors"] = rendered_colors
|
| 268 |
+
predictions["rendered_depths"] = rendered_depths
|
| 269 |
+
predictions["rendered_alphas"] = rendered_alphas
|
| 270 |
+
predictions["gt_colors"] = gt_colors.float()
|
| 271 |
+
predictions["gt_depths"] = views.get("depthmap")
|
| 272 |
+
predictions["valid_masks"] = valid_masks.bool()
|
| 273 |
+
predictions["rendered_extrinsics"] = render_viewmats
|
| 274 |
+
predictions["rendered_intrinsics"] = render_Ks
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
return predictions
|
| 278 |
+
|
| 279 |
+
def apply_confidence_filter(self, splats, gs_depth_conf):
|
| 280 |
+
"""
|
| 281 |
+
Apply confidence filtering to Gaussian splats before pruning.
|
| 282 |
+
Discard bottom p% confidence points, keep top (100-p)%.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
splats: Dictionary containing Gaussian parameters
|
| 286 |
+
gs_depth_conf: Confidence tensor [B, S, H, W]
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
Filtered splats dictionary
|
| 290 |
+
"""
|
| 291 |
+
if not self.enable_conf_filter or gs_depth_conf is None:
|
| 292 |
+
return splats
|
| 293 |
+
|
| 294 |
+
device = splats["means"].device
|
| 295 |
+
B, N = splats["means"].shape[:2]
|
| 296 |
+
|
| 297 |
+
# Flatten confidence: [B, S, H, W] -> [B, N]
|
| 298 |
+
conf = gs_depth_conf.flatten(1).to(device)
|
| 299 |
+
# Mask invalid/very small values
|
| 300 |
+
conf = conf.masked_fill(conf <= 1e-5, float("-inf"))
|
| 301 |
+
|
| 302 |
+
# Keep top (100-p)% points, discard bottom p%
|
| 303 |
+
if self.conf_threshold_percent > 0:
|
| 304 |
+
keep_from_percent = int(np.ceil(N * (100.0 - self.conf_threshold_percent) / 100.0))
|
| 305 |
+
else:
|
| 306 |
+
keep_from_percent = N
|
| 307 |
+
K = max(1, min(self.max_gaussians, keep_from_percent))
|
| 308 |
+
|
| 309 |
+
# Select top-K indices for each batch (deterministic, no randomness)
|
| 310 |
+
topk_idx = torch.topk(conf, K, dim=1, largest=True, sorted=False).indices # [B, K]
|
| 311 |
+
|
| 312 |
+
filtered = {}
|
| 313 |
+
mask_keys = ["means", "quats", "scales", "opacities", "sh", "weights"]
|
| 314 |
+
|
| 315 |
+
for key in splats.keys():
|
| 316 |
+
if key in mask_keys and key in splats:
|
| 317 |
+
x = splats[key]
|
| 318 |
+
if x.ndim == 2: # [B, N]
|
| 319 |
+
filtered[key] = torch.gather(x, 1, topk_idx)
|
| 320 |
+
else:
|
| 321 |
+
# Expand indices to match tensor dimensions
|
| 322 |
+
expand_idx = topk_idx.clone()
|
| 323 |
+
for i in range(x.ndim - 2):
|
| 324 |
+
expand_idx = expand_idx.unsqueeze(-1)
|
| 325 |
+
expand_idx = expand_idx.expand(-1, -1, *x.shape[2:])
|
| 326 |
+
filtered[key] = torch.gather(x, 1, expand_idx)
|
| 327 |
+
else:
|
| 328 |
+
filtered[key] = splats[key]
|
| 329 |
+
|
| 330 |
+
return filtered
|
| 331 |
+
|
| 332 |
+
def prune_gs(self, splats, voxel_size=0.002, filter_mask=None):
|
| 333 |
+
"""
|
| 334 |
+
Prune Gaussian splats by optional mask filtering + voxel merging.
|
| 335 |
+
|
| 336 |
+
Args:
|
| 337 |
+
splats: Dictionary containing Gaussian parameters.
|
| 338 |
+
Each value is [B, S*H*W, ...] (batch of per-pixel gaussians).
|
| 339 |
+
voxel_size: Size of voxels for spatial grouping.
|
| 340 |
+
filter_mask: Optional bool tensor [B, S*H*W] or numpy [S, H, W].
|
| 341 |
+
True = keep, False = discard. Applied before voxel merge.
|
| 342 |
+
|
| 343 |
+
Returns:
|
| 344 |
+
Dictionary with pruned/merged splats (list-of-tensors per batch).
|
| 345 |
+
"""
|
| 346 |
+
B = splats["means"].shape[0]
|
| 347 |
+
merged_splats_list = []
|
| 348 |
+
device = splats["means"].device
|
| 349 |
+
|
| 350 |
+
for i in range(B):
|
| 351 |
+
# Extract splats for current batch
|
| 352 |
+
splats_i = {k: splats[k][i] for k in ["means", "quats", "scales", "opacities", "sh", "weights"]}
|
| 353 |
+
|
| 354 |
+
# --- Apply filter_mask (discard unwanted gaussians before merge) ---
|
| 355 |
+
if filter_mask is not None:
|
| 356 |
+
if isinstance(filter_mask, np.ndarray):
|
| 357 |
+
fm = torch.from_numpy(filter_mask.reshape(-1)).to(device)
|
| 358 |
+
elif filter_mask.dim() == 3:
|
| 359 |
+
# [S, H, W] -> flatten
|
| 360 |
+
fm = filter_mask.reshape(-1).to(device)
|
| 361 |
+
else:
|
| 362 |
+
fm = filter_mask[i].to(device)
|
| 363 |
+
fm = fm.bool()
|
| 364 |
+
splats_i = {k: v[fm] for k, v in splats_i.items()}
|
| 365 |
+
|
| 366 |
+
N_in = splats_i["means"].shape[0]
|
| 367 |
+
if N_in == 0:
|
| 368 |
+
# All filtered out — push empty tensors
|
| 369 |
+
merged_splats_list.append({
|
| 370 |
+
"means": torch.zeros((0, 3), device=device),
|
| 371 |
+
"quats": torch.zeros((0, 4), device=device),
|
| 372 |
+
"scales": torch.zeros((0, 3), device=device),
|
| 373 |
+
"opacities": torch.zeros(0, device=device),
|
| 374 |
+
"sh": torch.zeros((0, self.nums_sh, 3), device=device),
|
| 375 |
+
})
|
| 376 |
+
continue
|
| 377 |
+
|
| 378 |
+
# Compute voxel indices
|
| 379 |
+
coords = splats_i["means"]
|
| 380 |
+
voxel_indices = (coords / voxel_size).floor().long()
|
| 381 |
+
min_indices = voxel_indices.min(dim=0)[0]
|
| 382 |
+
voxel_indices = voxel_indices - min_indices
|
| 383 |
+
max_dims = voxel_indices.max(dim=0)[0] + 1
|
| 384 |
+
|
| 385 |
+
# Flatten 3D voxel indices to 1D
|
| 386 |
+
flat_indices = (voxel_indices[:, 0] * max_dims[1] * max_dims[2] +
|
| 387 |
+
voxel_indices[:, 1] * max_dims[2] +
|
| 388 |
+
voxel_indices[:, 2])
|
| 389 |
+
|
| 390 |
+
# Find unique voxels and inverse mapping
|
| 391 |
+
unique_voxels, inverse_indices = torch.unique(flat_indices, return_inverse=True)
|
| 392 |
+
K = len(unique_voxels)
|
| 393 |
+
|
| 394 |
+
# Initialize merged splats
|
| 395 |
+
merged = {
|
| 396 |
+
"means": torch.zeros((K, 3), device=device),
|
| 397 |
+
"quats": torch.zeros((K, 4), device=device),
|
| 398 |
+
"scales": torch.zeros((K, 3), device=device),
|
| 399 |
+
"opacities": torch.zeros(K, device=device),
|
| 400 |
+
"sh": torch.zeros((K, self.nums_sh, 3), device=device)
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
# Get weights and compute weight sums per voxel
|
| 404 |
+
weights = splats_i["weights"]
|
| 405 |
+
weight_sums = torch.zeros(K, device=device)
|
| 406 |
+
weight_sums.scatter_add_(0, inverse_indices, weights)
|
| 407 |
+
weight_sums = torch.clamp(weight_sums, min=1e-8)
|
| 408 |
+
|
| 409 |
+
# Merge means (weighted average)
|
| 410 |
+
for d in range(3):
|
| 411 |
+
merged["means"][:, d].scatter_add_(0, inverse_indices,
|
| 412 |
+
splats_i["means"][:, d] * weights)
|
| 413 |
+
merged["means"] = merged["means"] / weight_sums.unsqueeze(1)
|
| 414 |
+
|
| 415 |
+
# Merge spherical harmonics (weighted average)
|
| 416 |
+
for d in range(3):
|
| 417 |
+
merged["sh"][:, 0, d].scatter_add_(0, inverse_indices,
|
| 418 |
+
splats_i["sh"][:, 0, d] * weights)
|
| 419 |
+
merged["sh"] = merged["sh"] / weight_sums.unsqueeze(-1).unsqueeze(-1)
|
| 420 |
+
|
| 421 |
+
# Merge opacities (weighted sum of squares)
|
| 422 |
+
merged["opacities"].scatter_add_(0, inverse_indices, weights * weights)
|
| 423 |
+
merged["opacities"] = merged["opacities"] / weight_sums
|
| 424 |
+
|
| 425 |
+
# Merge scales (weighted average)
|
| 426 |
+
for d in range(3):
|
| 427 |
+
merged["scales"][:, d].scatter_add_(0, inverse_indices,
|
| 428 |
+
splats_i["scales"][:, d] * weights)
|
| 429 |
+
merged["scales"] = merged["scales"] / weight_sums.unsqueeze(1)
|
| 430 |
+
|
| 431 |
+
# Merge quaternions (weighted average + normalization)
|
| 432 |
+
for d in range(4):
|
| 433 |
+
merged["quats"][:, d].scatter_add_(0, inverse_indices,
|
| 434 |
+
splats_i["quats"][:, d] * weights)
|
| 435 |
+
quat_norms = torch.norm(merged["quats"], dim=1, keepdim=True)
|
| 436 |
+
merged["quats"] = merged["quats"] / torch.clamp(quat_norms, min=1e-8)
|
| 437 |
+
|
| 438 |
+
merged_splats_list.append(merged)
|
| 439 |
+
|
| 440 |
+
# Reorganize output
|
| 441 |
+
output = {}
|
| 442 |
+
for key in ["means", "sh", "opacities", "scales", "quats"]:
|
| 443 |
+
output[key] = [merged[key] for merged in merged_splats_list]
|
| 444 |
+
|
| 445 |
+
return output
|
| 446 |
+
|
| 447 |
+
def prepare_splats(self, views, predictions, images, gs_params, context_nums,
|
| 448 |
+
context_predictions={}, position_from="gsdepth+gtcamera"):
|
| 449 |
+
"""
|
| 450 |
+
Prepare Gaussian splats from model predictions and input data.
|
| 451 |
+
|
| 452 |
+
Args:
|
| 453 |
+
views: Dictionary containing view data (camera poses, intrinsics, etc.)
|
| 454 |
+
predictions: Model predictions including depth, pose_enc, etc.
|
| 455 |
+
images: Input images [B, S_all, 3, H, W]
|
| 456 |
+
gs_params: Gaussian splatting parameters from model
|
| 457 |
+
context_predictions: Optional context predictions for camera poses
|
| 458 |
+
position_from: Method to compute 3D positions ("pts3d", "gsdepth+gtcamera", "gsdepth+predcamera",
|
| 459 |
+
"depth_head+gtcamera", "depth_head+predcamera")
|
| 460 |
+
debug: Whether to use debug mode with ground truth data
|
| 461 |
+
|
| 462 |
+
Returns:
|
| 463 |
+
splats: Dictionary containing prepared Gaussian splat parameters
|
| 464 |
+
"""
|
| 465 |
+
B, _, _, H, W = images.shape
|
| 466 |
+
S = context_nums
|
| 467 |
+
splats = {}
|
| 468 |
+
|
| 469 |
+
# Only take parameters from source view branch
|
| 470 |
+
gs_params = rearrange(gs_params, "(b s) c h w -> b s h w c", b=B)
|
| 471 |
+
splats["gs_feats"] = gs_params.reshape(B, S*H*W, -1)
|
| 472 |
+
|
| 473 |
+
# Split Gaussian parameters
|
| 474 |
+
quats, scales, opacities, residual_sh, weights = torch.split(
|
| 475 |
+
gs_params, [4, 3, 1, self.nums_sh * 3, 1], dim=-1
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
# Apply activation functions to Gaussian parameters
|
| 479 |
+
splats["quats"] = act_gs.reg_dense_rotation(quats.reshape(B, S * H * W, 4))
|
| 480 |
+
splats["scales"] = act_gs.reg_dense_scales(scales.reshape(B, S * H * W, 3)).clamp_max(0.3)
|
| 481 |
+
splats["opacities"] = act_gs.reg_dense_opacities(opacities.reshape(B, S * H * W))
|
| 482 |
+
residual_sh = act_gs.reg_dense_sh(residual_sh.reshape(B, S * H * W, self.nums_sh * 3))
|
| 483 |
+
|
| 484 |
+
# Handle spherical harmonics (SH) coefficients
|
| 485 |
+
new_sh = torch.zeros_like(residual_sh)
|
| 486 |
+
new_sh[..., 0, :] = sh_utils.RGB2SH(
|
| 487 |
+
images[:, :S].permute(0, 1, 3, 4, 2).reshape(B, S * H * W, 3)
|
| 488 |
+
)
|
| 489 |
+
splats['sh'] = new_sh + residual_sh
|
| 490 |
+
splats['residual_sh'] = residual_sh
|
| 491 |
+
|
| 492 |
+
splats["weights"] = act_gs.reg_dense_weights(weights.reshape(B, S * H * W))
|
| 493 |
+
|
| 494 |
+
# Compute 3D positions based on specified method
|
| 495 |
+
if position_from == "pts3d":
|
| 496 |
+
pts3d = predictions["pts3d"][:, :S].reshape(B, S * H * W, 3)
|
| 497 |
+
splats["means"] = pts3d
|
| 498 |
+
elif position_from == "gsdepth+gtcamera":
|
| 499 |
+
depth = predictions["gs_depth"][:, :S].reshape(B * S, H, W)
|
| 500 |
+
pose4x4 = views["camera_poses"][:, :S].reshape(B * S, 4, 4)
|
| 501 |
+
intrinsic = views["camera_intrs"][:, :S].reshape(B * S, 3, 3)
|
| 502 |
+
pts3d, _, _ = depth_to_world_coords_points(depth, pose4x4, intrinsic)
|
| 503 |
+
pts3d = pts3d.reshape(B, S * H * W, 3)
|
| 504 |
+
splats["means"] = pts3d
|
| 505 |
+
|
| 506 |
+
elif position_from == "gsdepth+predcamera":
|
| 507 |
+
depth = predictions["gs_depth"][:, :S].reshape(B * S , H, W)
|
| 508 |
+
pose4x4 = context_predictions.get("camera_poses", predictions["camera_poses"])[:, :S].reshape(B * S, 4, 4)
|
| 509 |
+
intrinsic = context_predictions.get("camera_intrs", predictions["camera_intrs"])[:, :S].reshape(B * S, 3, 3)
|
| 510 |
+
pts3d, _, _ = depth_to_world_coords_points(depth, pose4x4.detach(), intrinsic.detach())
|
| 511 |
+
pts3d = pts3d.reshape(B, S * H * W, 3)
|
| 512 |
+
splats["means"] = pts3d
|
| 513 |
+
else:
|
| 514 |
+
raise ValueError(f"Invalid position_from={position_from}")
|
| 515 |
+
|
| 516 |
+
return splats
|
| 517 |
+
|
| 518 |
+
def prepare_cameras(self, views, nums):
|
| 519 |
+
viewmats = views['camera_poses'][:, :nums]
|
| 520 |
+
Ks = views['camera_intrs'][:, :nums]
|
| 521 |
+
return viewmats, Ks
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
|
hyworldmirror/models/models/visual_transformer.py
ADDED
|
@@ -0,0 +1,542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import random
|
| 3 |
+
from typing import Tuple, List
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.utils.checkpoint import checkpoint
|
| 8 |
+
|
| 9 |
+
from ..layers import PatchEmbed, PatchEmbed_Mlp
|
| 10 |
+
from ..layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
|
| 11 |
+
from ..layers.block import Block, DistBlock
|
| 12 |
+
from ...comm.padding import minimal_pad_to_divisible,depad_by_length,pad_by_length
|
| 13 |
+
import torch.distributed as dist
|
| 14 |
+
from ...comm.communication import _All2All,_Allgather
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
_RESNET_MEAN = [0.485, 0.456, 0.406]
|
| 19 |
+
_RESNET_STD = [0.229, 0.224, 0.225]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class VisualGeometryTransformer(nn.Module):
|
| 23 |
+
"""
|
| 24 |
+
The VisualGeometryTransformer applies alternating-attention over input frames,
|
| 25 |
+
as described in VGGT: Visual Geometry Grounded Transformer.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
img_size (int): Image size in pixels.
|
| 29 |
+
patch_size (int): Size of each patch for PatchEmbed.
|
| 30 |
+
embed_dim (int): Dimension of the token embeddings.
|
| 31 |
+
depth (int): Number of blocks.
|
| 32 |
+
num_heads (int): Number of attention heads.
|
| 33 |
+
mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
|
| 34 |
+
num_register_tokens (int): Number of register tokens.
|
| 35 |
+
block_fn (nn.Module): The block type used for attention (Block by default).
|
| 36 |
+
qkv_bias (bool): Whether to include bias in QKV projections.
|
| 37 |
+
proj_bias (bool): Whether to include bias in the output projection.
|
| 38 |
+
ffn_bias (bool): Whether to include bias in MLP layers.
|
| 39 |
+
patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
|
| 40 |
+
aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
|
| 41 |
+
qk_norm (bool): Whether to apply QK normalization.
|
| 42 |
+
rope_base (int): Base frequency for rotary embedding.
|
| 43 |
+
rope_normalize_coords (str): Normalize coordinates for rotary embedding.
|
| 44 |
+
rope_shift_coords (float): Shift coordinates for rotary embedding.
|
| 45 |
+
rope_jitter_coords (float): Jitter coordinates for rotary embedding.
|
| 46 |
+
rope_rescale_coords (float): Rescale coordinates for rotary embedding.
|
| 47 |
+
init_values (float): Init scale for layer scale.
|
| 48 |
+
enable_condition (bool): Whether to enable conditioning inputs.
|
| 49 |
+
sampling_strategy (str): Sampling strategy for patches.
|
| 50 |
+
fixed_patch_embed (bool): Whether to fix patch embedding weights.
|
| 51 |
+
condition_strategy (list[str]): Strategy for each conditioning input.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
img_size=518,
|
| 57 |
+
patch_size=14,
|
| 58 |
+
embed_dim=1024,
|
| 59 |
+
depth=24,
|
| 60 |
+
num_heads=16,
|
| 61 |
+
mlp_ratio=4.0,
|
| 62 |
+
num_register_tokens=4,
|
| 63 |
+
block_fn=Block,
|
| 64 |
+
qkv_bias=True,
|
| 65 |
+
proj_bias=True,
|
| 66 |
+
ffn_bias=True,
|
| 67 |
+
patch_embed="dinov2_vitl14_reg",
|
| 68 |
+
qk_norm=True,
|
| 69 |
+
rope_base=100.0,
|
| 70 |
+
normalized_rope=False,
|
| 71 |
+
rope_normalize_coords="separate",
|
| 72 |
+
rope_shift_coords=None,
|
| 73 |
+
rope_jitter_coords=None,
|
| 74 |
+
rope_rescale_coords=None,
|
| 75 |
+
init_values=0.01,
|
| 76 |
+
enable_cond=False,
|
| 77 |
+
sampling_strategy="uniform",
|
| 78 |
+
fixed_patch_embed=False,
|
| 79 |
+
condition_strategy=["token", "pow3r", "token"],
|
| 80 |
+
intermediate_idxs: List[int] = [4, 11, 17, 23]
|
| 81 |
+
):
|
| 82 |
+
super().__init__()
|
| 83 |
+
# Store config parameters
|
| 84 |
+
self.enable_cond = enable_cond
|
| 85 |
+
self.sampling_strategy = sampling_strategy
|
| 86 |
+
self.cond_methods = condition_strategy
|
| 87 |
+
self.intermediate_idxs = intermediate_idxs
|
| 88 |
+
self.depth = depth
|
| 89 |
+
self.patch_size = patch_size
|
| 90 |
+
|
| 91 |
+
# Initialize patch embedding module
|
| 92 |
+
self.patch_embed = self._init_patch_embedding_module(
|
| 93 |
+
patch_embed, img_size, patch_size, num_register_tokens,
|
| 94 |
+
embed_dim=embed_dim, is_fixed=fixed_patch_embed
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# Initialize conditioning embeddings if enabled
|
| 98 |
+
if self.enable_cond:
|
| 99 |
+
self._init_cond_embeddings(embed_dim, img_size, patch_size, num_register_tokens)
|
| 100 |
+
|
| 101 |
+
# Initialize rotary position embedding
|
| 102 |
+
self._init_rotary_position_embedding(rope_base, normalized_rope, embed_dim // num_heads, rope_normalize_coords, rope_shift_coords, rope_jitter_coords, rope_rescale_coords)
|
| 103 |
+
|
| 104 |
+
# Initialize transformer blocks
|
| 105 |
+
self._init_transformer_blocks(block_fn, embed_dim, num_heads, mlp_ratio, qkv_bias, proj_bias, ffn_bias, init_values, qk_norm)
|
| 106 |
+
|
| 107 |
+
# Initialize learnable tokens
|
| 108 |
+
self._init_learnable_tokens(embed_dim, num_register_tokens)
|
| 109 |
+
|
| 110 |
+
# Calculate patch start index based on conditioning
|
| 111 |
+
if self.enable_cond:
|
| 112 |
+
self.patch_start_idx = 1 + num_register_tokens + 1 + 1 # camera + register + pose + rays
|
| 113 |
+
else:
|
| 114 |
+
self.patch_start_idx = 1 + num_register_tokens # camera + register
|
| 115 |
+
|
| 116 |
+
# Register normalization constants
|
| 117 |
+
for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
|
| 118 |
+
self.register_buffer(name, torch.FloatTensor(value).reshape(1, 1, 3, 1, 1), persistent=False)
|
| 119 |
+
|
| 120 |
+
self.use_reentrant = False
|
| 121 |
+
|
| 122 |
+
def _init_patch_embedding_module(
|
| 123 |
+
self,
|
| 124 |
+
patch_embed_type,
|
| 125 |
+
img_size,
|
| 126 |
+
patch_size,
|
| 127 |
+
num_reg_tokens,
|
| 128 |
+
interpolate_antialias=True,
|
| 129 |
+
interpolate_offset=0.0,
|
| 130 |
+
block_chunks=0,
|
| 131 |
+
init_values=1.0,
|
| 132 |
+
embed_dim=1024,
|
| 133 |
+
is_fixed=False,
|
| 134 |
+
in_chans=3
|
| 135 |
+
):
|
| 136 |
+
"""
|
| 137 |
+
Create the patch embedding module. If 'conv', we use a
|
| 138 |
+
simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
|
| 139 |
+
"""
|
| 140 |
+
if "conv" in patch_embed_type:
|
| 141 |
+
if 'mlp' in patch_embed_type:
|
| 142 |
+
patch_embed_module = PatchEmbed_Mlp(
|
| 143 |
+
img_size=img_size,
|
| 144 |
+
patch_size=patch_size,
|
| 145 |
+
in_chans=in_chans,
|
| 146 |
+
embed_dim=embed_dim
|
| 147 |
+
)
|
| 148 |
+
else:
|
| 149 |
+
patch_embed_module = PatchEmbed(
|
| 150 |
+
img_size=img_size,
|
| 151 |
+
patch_size=patch_size,
|
| 152 |
+
in_chans=in_chans,
|
| 153 |
+
embed_dim=embed_dim
|
| 154 |
+
)
|
| 155 |
+
else:
|
| 156 |
+
vit_models = {
|
| 157 |
+
"dinov2_vitl14_reg": vit_large,
|
| 158 |
+
"dinov2_vitb14_reg": vit_base,
|
| 159 |
+
"dinov2_vits14_reg": vit_small,
|
| 160 |
+
"dinov2_vitg2_reg": vit_giant2,
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
patch_embed_module = vit_models[patch_embed_type](
|
| 164 |
+
img_size=img_size,
|
| 165 |
+
patch_size=patch_size,
|
| 166 |
+
num_register_tokens=num_reg_tokens,
|
| 167 |
+
interpolate_antialias=interpolate_antialias,
|
| 168 |
+
interpolate_offset=interpolate_offset,
|
| 169 |
+
block_chunks=block_chunks,
|
| 170 |
+
init_values=init_values,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# Disable gradient updates for mask token
|
| 174 |
+
if hasattr(patch_embed_module, "mask_token"):
|
| 175 |
+
patch_embed_module.mask_token.requires_grad_(False)
|
| 176 |
+
|
| 177 |
+
if is_fixed:
|
| 178 |
+
for param in patch_embed_module.parameters():
|
| 179 |
+
param.requires_grad_(False)
|
| 180 |
+
|
| 181 |
+
return patch_embed_module
|
| 182 |
+
|
| 183 |
+
def _init_cond_embeddings(self, embed_dim, img_size, patch_size, num_reg_tokens):
|
| 184 |
+
"""Initialize conditioning embeddings for camera, depth, and rays."""
|
| 185 |
+
assert self.cond_methods is not None
|
| 186 |
+
assert self.cond_methods[0] == "token"
|
| 187 |
+
|
| 188 |
+
# Camera pose embedding
|
| 189 |
+
if self.cond_methods[0] == "token":
|
| 190 |
+
self.pose_embed = nn.Sequential(
|
| 191 |
+
nn.Linear(7, embed_dim, bias=True),
|
| 192 |
+
nn.SiLU(),
|
| 193 |
+
nn.Linear(embed_dim, embed_dim, bias=True)
|
| 194 |
+
)
|
| 195 |
+
else:
|
| 196 |
+
raise NotImplementedError
|
| 197 |
+
|
| 198 |
+
# Depth map embedding
|
| 199 |
+
if self.cond_methods[1] == "pow3r":
|
| 200 |
+
self.depth_embed = self._init_patch_embedding_module(
|
| 201 |
+
"conv+mlp", img_size, patch_size, num_reg_tokens,
|
| 202 |
+
embed_dim=embed_dim, in_chans=1
|
| 203 |
+
)
|
| 204 |
+
else:
|
| 205 |
+
raise NotImplementedError
|
| 206 |
+
|
| 207 |
+
# Ray direction embedding
|
| 208 |
+
if self.cond_methods[2] == "token":
|
| 209 |
+
self.ray_embed = nn.Sequential(
|
| 210 |
+
nn.Linear(4, embed_dim, bias=True),
|
| 211 |
+
nn.SiLU(),
|
| 212 |
+
nn.Linear(embed_dim, embed_dim, bias=True)
|
| 213 |
+
)
|
| 214 |
+
else:
|
| 215 |
+
raise NotImplementedError
|
| 216 |
+
|
| 217 |
+
def _init_rotary_position_embedding(self, rope_base, normalized_rope, head_dim, rope_normalize_coords, rope_shift_coords, rope_jitter_coords, rope_rescale_coords):
|
| 218 |
+
if normalized_rope:
|
| 219 |
+
print("[INFO] Using normalized RoPE!")
|
| 220 |
+
from ..layers.norm_rope import NormalizedRotaryPositionEmbedding2D, PositionGetter
|
| 221 |
+
if head_dim % 4 != 0:
|
| 222 |
+
raise ValueError("RoPE requires head_dim divisible by 4 (embed_dim must be divisible by 4*num_heads)")
|
| 223 |
+
self.rope = NormalizedRotaryPositionEmbedding2D(
|
| 224 |
+
head_dim=head_dim,
|
| 225 |
+
base=rope_base,
|
| 226 |
+
normalize_coords=rope_normalize_coords,
|
| 227 |
+
shift_coords=rope_shift_coords,
|
| 228 |
+
jitter_coords=rope_jitter_coords,
|
| 229 |
+
rescale_coords=rope_rescale_coords,
|
| 230 |
+
) if rope_base > 0 else None
|
| 231 |
+
self.pos_getter = PositionGetter() if self.rope is not None else None
|
| 232 |
+
else:
|
| 233 |
+
from ..layers.rope import RotaryPositionEmbedding2D, PositionGetter
|
| 234 |
+
print("[INFO] Using standard RoPE!")
|
| 235 |
+
self.rope = RotaryPositionEmbedding2D(
|
| 236 |
+
frequency=rope_base,
|
| 237 |
+
) if rope_base > 0 else None
|
| 238 |
+
self.pos_getter = PositionGetter() if self.rope is not None else None
|
| 239 |
+
|
| 240 |
+
def _init_transformer_blocks(self, block_fn, embed_dim, num_heads, mlp_ratio, qkv_bias, proj_bias, ffn_bias, init_values, qk_norm):
|
| 241 |
+
self.frame_blocks = nn.ModuleList([
|
| 242 |
+
block_fn(
|
| 243 |
+
dim=embed_dim,
|
| 244 |
+
num_heads=num_heads,
|
| 245 |
+
mlp_ratio=mlp_ratio,
|
| 246 |
+
qkv_bias=qkv_bias,
|
| 247 |
+
proj_bias=proj_bias,
|
| 248 |
+
ffn_bias=ffn_bias,
|
| 249 |
+
init_values=init_values,
|
| 250 |
+
qk_norm=qk_norm,
|
| 251 |
+
rope=self.rope,
|
| 252 |
+
)
|
| 253 |
+
for _ in range(self.depth)
|
| 254 |
+
])
|
| 255 |
+
|
| 256 |
+
self.global_blocks = nn.ModuleList([
|
| 257 |
+
block_fn(
|
| 258 |
+
dim=embed_dim,
|
| 259 |
+
num_heads=num_heads,
|
| 260 |
+
mlp_ratio=mlp_ratio,
|
| 261 |
+
qkv_bias=qkv_bias,
|
| 262 |
+
proj_bias=proj_bias,
|
| 263 |
+
ffn_bias=ffn_bias,
|
| 264 |
+
init_values=init_values,
|
| 265 |
+
qk_norm=qk_norm,
|
| 266 |
+
rope=self.rope
|
| 267 |
+
)
|
| 268 |
+
for _ in range(self.depth)
|
| 269 |
+
])
|
| 270 |
+
|
| 271 |
+
def _init_learnable_tokens(self, embed_dim, num_reg_tokens):
|
| 272 |
+
"""Initialize learnable tokens."""
|
| 273 |
+
self.cam_token = nn.Parameter(torch.zeros(1, 2, 1, embed_dim))
|
| 274 |
+
self.reg_token = nn.Parameter(torch.zeros(1, 2, num_reg_tokens, embed_dim))
|
| 275 |
+
nn.init.normal_(self.cam_token, std=1e-6)
|
| 276 |
+
nn.init.normal_(self.reg_token, std=1e-6)
|
| 277 |
+
|
| 278 |
+
def forward(self, images: torch.Tensor, priors: List | None=None, cond_flags: List[int]=[0,0,0], ctx_frames: int=None, enable_bf16=False, sp_size: int=1, sp_group: torch._C._distributed_c10d.ProcessGroup=None) -> Tuple[List[torch.Tensor], int]:
|
| 279 |
+
"""
|
| 280 |
+
Args:
|
| 281 |
+
images: Input images with shape [B, S, 3, H, W], in range [0, 1]
|
| 282 |
+
priors: Optional tuple of (depth, rays, poses) for conditioning
|
| 283 |
+
cond_flags: List indicating which conditions to use [pose, depth, rays]
|
| 284 |
+
ctx_frames: Number of context frames to use
|
| 285 |
+
|
| 286 |
+
Returns:
|
| 287 |
+
(list[torch.Tensor], int): List of attention block outputs and patch_start_idx
|
| 288 |
+
"""
|
| 289 |
+
depth_maps, ray_dirs, poses = priors if priors is not None else (None, None, None)
|
| 290 |
+
|
| 291 |
+
# Slice to context frames if specified
|
| 292 |
+
if ctx_frames is not None:
|
| 293 |
+
for var_name in ['images', 'depth_maps', 'ray_dirs', 'poses']:
|
| 294 |
+
var = locals()[var_name]
|
| 295 |
+
if var is not None:
|
| 296 |
+
locals()[var_name] = var[:, :ctx_frames].clone()
|
| 297 |
+
|
| 298 |
+
# Process image tokens
|
| 299 |
+
b, seq_len, ch, h, w = images.shape
|
| 300 |
+
if ch != 3:
|
| 301 |
+
raise ValueError(f"Expected 3 input channels, got {ch}")
|
| 302 |
+
|
| 303 |
+
with torch.amp.autocast('cuda', enabled=(not enable_bf16), dtype=torch.bfloat16):
|
| 304 |
+
images = (images - self._resnet_mean) / self._resnet_std
|
| 305 |
+
images = images.reshape(b * seq_len, ch, h, w)
|
| 306 |
+
patch_tokens = self.patch_embed(images)
|
| 307 |
+
if isinstance(patch_tokens, dict):
|
| 308 |
+
patch_tokens = patch_tokens["x_norm_patchtokens"]
|
| 309 |
+
|
| 310 |
+
_, patch_count, embed_dim = patch_tokens.shape
|
| 311 |
+
|
| 312 |
+
# Prepare special tokens
|
| 313 |
+
cam_tokens = expand_and_flatten_special_tokens(self.cam_token, b, seq_len)
|
| 314 |
+
reg_tokens = expand_and_flatten_special_tokens(self.reg_token, b, seq_len)
|
| 315 |
+
|
| 316 |
+
# Process all tokens (optional conditioning)
|
| 317 |
+
if self.enable_cond:
|
| 318 |
+
pose_tokens, depth_tokens, ray_tokens = self._process_conditioning(depth_maps, ray_dirs, poses, b, seq_len, patch_count, embed_dim, images, cond_flags)
|
| 319 |
+
# Add condition tokens to patch tokens
|
| 320 |
+
patch_tokens = patch_tokens + depth_tokens
|
| 321 |
+
all_tokens = torch.cat([cam_tokens, reg_tokens, pose_tokens, ray_tokens, patch_tokens], dim=1)
|
| 322 |
+
else:
|
| 323 |
+
all_tokens = torch.cat([cam_tokens, reg_tokens, patch_tokens], dim=1)
|
| 324 |
+
|
| 325 |
+
_, patch_count, embed_dim = all_tokens.shape
|
| 326 |
+
|
| 327 |
+
# Position embedding
|
| 328 |
+
pos_emb = None
|
| 329 |
+
if self.rope is not None:
|
| 330 |
+
pos_emb = self.pos_getter(b * seq_len, h // self.patch_size, w // self.patch_size, device=images.device)
|
| 331 |
+
if self.patch_start_idx > 0:
|
| 332 |
+
pos_emb = pos_emb + 1
|
| 333 |
+
special_pos = torch.zeros(b * seq_len, self.patch_start_idx, 2, device=images.device, dtype=pos_emb.dtype)
|
| 334 |
+
pos_emb = torch.cat([special_pos, pos_emb], dim=1)
|
| 335 |
+
|
| 336 |
+
if sp_size>1:
|
| 337 |
+
rank_in_sp_group = dist.get_group_rank(sp_group,dist.get_rank())
|
| 338 |
+
all_tokens,tk_padding_len = minimal_pad_to_divisible(all_tokens, sp_size, dim=1,pad_value=0)
|
| 339 |
+
all_tokens = torch.chunk(all_tokens, sp_size,dim=1)[rank_in_sp_group]
|
| 340 |
+
|
| 341 |
+
_, patch_count, embed_dim = all_tokens.shape
|
| 342 |
+
token_shape = (b, seq_len, patch_count, embed_dim)
|
| 343 |
+
# Forward through attention blocks
|
| 344 |
+
with torch.amp.autocast('cuda', enabled=(not enable_bf16), dtype=torch.bfloat16):
|
| 345 |
+
outputs = []
|
| 346 |
+
global_tokens = None
|
| 347 |
+
if sp_size>1:
|
| 348 |
+
for idx in range(self.depth):
|
| 349 |
+
local_tokens = self._process_dist_attention_blocks(
|
| 350 |
+
tokens=all_tokens if global_tokens is None else global_tokens,
|
| 351 |
+
b=b,
|
| 352 |
+
seq_len=seq_len,
|
| 353 |
+
patch_count=patch_count,
|
| 354 |
+
embed_dim=embed_dim,
|
| 355 |
+
block_idx=idx,
|
| 356 |
+
blocks=self.frame_blocks,
|
| 357 |
+
block_type='frame',
|
| 358 |
+
pos=pos_emb,
|
| 359 |
+
sp_size = sp_size,
|
| 360 |
+
sp_group = sp_group,
|
| 361 |
+
padding_tokens = tk_padding_len
|
| 362 |
+
)
|
| 363 |
+
global_tokens = self._process_dist_attention_blocks(
|
| 364 |
+
tokens=local_tokens,
|
| 365 |
+
b=b,
|
| 366 |
+
seq_len=seq_len,
|
| 367 |
+
patch_count=patch_count,
|
| 368 |
+
embed_dim=embed_dim,
|
| 369 |
+
block_idx=idx,
|
| 370 |
+
blocks=self.global_blocks,
|
| 371 |
+
block_type='global',
|
| 372 |
+
pos=pos_emb,
|
| 373 |
+
sp_size = sp_size,
|
| 374 |
+
sp_group = sp_group,
|
| 375 |
+
padding_tokens = tk_padding_len
|
| 376 |
+
)
|
| 377 |
+
global_tokens = global_tokens.reshape(b,-1,embed_dim)
|
| 378 |
+
global_tokens = _Allgather.apply(global_tokens,1,sp_group,False)
|
| 379 |
+
global_tokens = depad_by_length(global_tokens,tk_padding_len*seq_len,1)
|
| 380 |
+
global_tokens = global_tokens.reshape(b,seq_len,-1,embed_dim)
|
| 381 |
+
global_tokens = pad_by_length(global_tokens,tk_padding_len,2)
|
| 382 |
+
global_tokens = torch.chunk(global_tokens, sp_size,dim=2)[rank_in_sp_group]
|
| 383 |
+
|
| 384 |
+
# Combine frame and global intermediates
|
| 385 |
+
if idx in self.intermediate_idxs:
|
| 386 |
+
local_tokens = _Allgather.apply(local_tokens,2,sp_group,False)
|
| 387 |
+
local_tokens = depad_by_length(local_tokens,tk_padding_len,2)
|
| 388 |
+
global_tokens = _Allgather.apply(global_tokens,2,sp_group,False)
|
| 389 |
+
global_tokens = depad_by_length(global_tokens,tk_padding_len,2)
|
| 390 |
+
combined_out = torch.cat([local_tokens, global_tokens], dim=-1)
|
| 391 |
+
outputs.append(combined_out)
|
| 392 |
+
global_tokens = pad_by_length(global_tokens,tk_padding_len,2)
|
| 393 |
+
global_tokens = torch.chunk(global_tokens, sp_size,dim=2)[rank_in_sp_group]
|
| 394 |
+
else:
|
| 395 |
+
for idx in range(self.depth):
|
| 396 |
+
local_tokens = self._process_attention_blocks(
|
| 397 |
+
tokens=all_tokens if global_tokens is None else global_tokens,
|
| 398 |
+
b=b,
|
| 399 |
+
seq_len=seq_len,
|
| 400 |
+
patch_count=patch_count,
|
| 401 |
+
embed_dim=embed_dim,
|
| 402 |
+
block_idx=idx,
|
| 403 |
+
blocks=self.frame_blocks,
|
| 404 |
+
block_type='frame',
|
| 405 |
+
pos=pos_emb,
|
| 406 |
+
)
|
| 407 |
+
global_tokens = self._process_attention_blocks(
|
| 408 |
+
tokens=local_tokens,
|
| 409 |
+
b=b,
|
| 410 |
+
seq_len=seq_len,
|
| 411 |
+
patch_count=patch_count,
|
| 412 |
+
embed_dim=embed_dim,
|
| 413 |
+
block_idx=idx,
|
| 414 |
+
blocks=self.global_blocks,
|
| 415 |
+
block_type='global',
|
| 416 |
+
pos=pos_emb,
|
| 417 |
+
)
|
| 418 |
+
# Combine frame and global intermediates
|
| 419 |
+
if idx in self.intermediate_idxs:
|
| 420 |
+
combined_out = torch.cat([local_tokens, global_tokens], dim=-1)
|
| 421 |
+
outputs.append(combined_out)
|
| 422 |
+
|
| 423 |
+
# Combine frame and global intermediates
|
| 424 |
+
if idx in self.intermediate_idxs:
|
| 425 |
+
combined_out = torch.cat([local_tokens, global_tokens], dim=-1)
|
| 426 |
+
outputs.append(combined_out)
|
| 427 |
+
|
| 428 |
+
return outputs, self.patch_start_idx
|
| 429 |
+
|
| 430 |
+
def _process_conditioning(self, depth_maps, ray_dirs, poses, b, seq_len, patch_count, embed_dim, images, cond_flags):
|
| 431 |
+
"""Process conditioning inputs."""
|
| 432 |
+
h, w = images.shape[-2:]
|
| 433 |
+
|
| 434 |
+
# Process camera pose embedding
|
| 435 |
+
use_poses = (cond_flags[0] == 1 and poses is not None)
|
| 436 |
+
if use_poses:
|
| 437 |
+
poses = poses.reshape(b*seq_len, -1)
|
| 438 |
+
pose_tokens = self.pose_embed(poses).unsqueeze(1)
|
| 439 |
+
else:
|
| 440 |
+
pose_tokens = torch.zeros((b*seq_len, 1, embed_dim), device=images.device, dtype=images.dtype)
|
| 441 |
+
|
| 442 |
+
# Process depth map embedding
|
| 443 |
+
use_depth = cond_flags[1] == 1 and depth_maps is not None
|
| 444 |
+
if use_depth:
|
| 445 |
+
depth_maps = depth_maps.reshape(b*seq_len, 1, h, w)
|
| 446 |
+
depth_tokens = self.depth_embed(depth_maps).reshape(b * seq_len, patch_count, embed_dim)
|
| 447 |
+
else:
|
| 448 |
+
depth_tokens = torch.zeros((b*seq_len, patch_count, embed_dim), device=images.device, dtype=images.dtype)
|
| 449 |
+
|
| 450 |
+
# Process ray direction embedding
|
| 451 |
+
use_rays = cond_flags[2] == 1 and ray_dirs is not None
|
| 452 |
+
if use_rays:
|
| 453 |
+
ray_dirs = ray_dirs.reshape(b*seq_len, -1)
|
| 454 |
+
ray_tokens = self.ray_embed(ray_dirs).unsqueeze(1)
|
| 455 |
+
else:
|
| 456 |
+
ray_tokens = torch.zeros((b*seq_len, 1, embed_dim), device=images.device, dtype=images.dtype)
|
| 457 |
+
|
| 458 |
+
return pose_tokens, depth_tokens, ray_tokens
|
| 459 |
+
|
| 460 |
+
def _process_attention_blocks(self, tokens, b, seq_len, patch_count, embed_dim, block_idx, blocks, block_type, pos=None):
|
| 461 |
+
"""Process attention blocks with tokens in shape (B*S, P, C)."""
|
| 462 |
+
token_shape = (b, seq_len, patch_count, embed_dim)
|
| 463 |
+
if block_type == 'frame': # local
|
| 464 |
+
target_shape = (b * seq_len, patch_count, embed_dim)
|
| 465 |
+
pos_target_shape = (b * seq_len, patch_count, 2) if pos is not None else None
|
| 466 |
+
else: # global
|
| 467 |
+
target_shape = (b, seq_len * patch_count, embed_dim)
|
| 468 |
+
pos_target_shape = (b, seq_len * patch_count, 2) if pos is not None else None
|
| 469 |
+
|
| 470 |
+
if tokens.shape != target_shape:
|
| 471 |
+
tokens = tokens.reshape(*target_shape)
|
| 472 |
+
|
| 473 |
+
if pos is not None and pos.shape != pos_target_shape:
|
| 474 |
+
pos = pos.reshape(*pos_target_shape)
|
| 475 |
+
|
| 476 |
+
if self.training:
|
| 477 |
+
# tokens = blocks[block_idx](tokens, pos=pos)
|
| 478 |
+
tokens = checkpoint(blocks[block_idx], tokens, pos=pos, use_reentrant=self.use_reentrant)
|
| 479 |
+
else:
|
| 480 |
+
tokens = blocks[block_idx](tokens, pos=pos)
|
| 481 |
+
|
| 482 |
+
return tokens.reshape(*token_shape)
|
| 483 |
+
|
| 484 |
+
def _process_dist_attention_blocks(self, tokens, b, seq_len, patch_count, embed_dim, block_idx, blocks, block_type, pos=None,
|
| 485 |
+
sp_size = 1,
|
| 486 |
+
sp_group = None,
|
| 487 |
+
padding_tokens = 0):
|
| 488 |
+
"""Process attention blocks with tokens in shape (B*S, P, C)."""
|
| 489 |
+
token_shape = (b, seq_len, patch_count, embed_dim)
|
| 490 |
+
if block_type == 'frame': # local
|
| 491 |
+
target_shape = (b * seq_len, patch_count, embed_dim)
|
| 492 |
+
pos_target_shape = (b * seq_len, patch_count*sp_size-padding_tokens, 2) if pos is not None else None
|
| 493 |
+
else: # global
|
| 494 |
+
target_shape = (b, seq_len * patch_count, embed_dim)
|
| 495 |
+
pos_target_shape = (b, seq_len * (patch_count*sp_size-padding_tokens), 2) if pos is not None else None
|
| 496 |
+
# padding_tokens = padding_tokens*seq_len
|
| 497 |
+
|
| 498 |
+
if block_type=="global":
|
| 499 |
+
rank_in_sp_group = dist.get_group_rank(sp_group,dist.get_rank())
|
| 500 |
+
tokens = _Allgather.apply(tokens,2,sp_group,False) #(1,7,4*146,64)
|
| 501 |
+
tokens = depad_by_length(tokens,padding_tokens,2) #(1,7,4*146-2,64)
|
| 502 |
+
tokens = tokens.reshape(b,-1,embed_dim) #(1,7*(4*146-2),64)
|
| 503 |
+
padding_tokens = padding_tokens*seq_len
|
| 504 |
+
tokens = pad_by_length(tokens,padding_tokens,1) #(1,4088,1024)
|
| 505 |
+
tokens = torch.chunk(tokens, sp_size,dim=1)[rank_in_sp_group]
|
| 506 |
+
|
| 507 |
+
if tokens.shape != target_shape:
|
| 508 |
+
tokens = tokens.reshape(*target_shape)
|
| 509 |
+
|
| 510 |
+
if pos is not None and pos.shape != pos_target_shape:
|
| 511 |
+
pos = pos.reshape(*pos_target_shape)
|
| 512 |
+
|
| 513 |
+
if self.training:
|
| 514 |
+
# tokens = blocks[block_idx](tokens, pos=pos)
|
| 515 |
+
tokens = checkpoint(blocks[block_idx], tokens, pos=pos, use_reentrant=self.use_reentrant, sp_size=sp_size, sp_group=sp_group, padding_tokens=padding_tokens, block_type =block_type,token_shape=token_shape)
|
| 516 |
+
else:
|
| 517 |
+
tokens = blocks[block_idx](tokens, pos=pos, sp_size=sp_size, sp_group=sp_group, padding_tokens=padding_tokens, block_type =block_type,token_shape=token_shape)
|
| 518 |
+
|
| 519 |
+
return tokens.reshape(*token_shape)
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def expand_and_flatten_special_tokens(token_tensor, b, seq_len):
|
| 524 |
+
"""
|
| 525 |
+
Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing.
|
| 526 |
+
Uses first position for frame 0, second position for remaining frames.
|
| 527 |
+
|
| 528 |
+
Args:
|
| 529 |
+
token_tensor: Input tensor with shape (1, 2, X, C)
|
| 530 |
+
b: Batch size
|
| 531 |
+
seq_len: Sequence length
|
| 532 |
+
|
| 533 |
+
Returns:
|
| 534 |
+
torch.Tensor: Processed tokens with shape (B*S, X, C)
|
| 535 |
+
"""
|
| 536 |
+
# First frame uses position 0, remaining frames use position 1
|
| 537 |
+
first_frame_tokens = token_tensor[:, 0:1, ...].expand(b, 1, *token_tensor.shape[2:])
|
| 538 |
+
remaining_frame_tokens = token_tensor[:, 1:, ...].expand(b, seq_len - 1, *token_tensor.shape[2:])
|
| 539 |
+
|
| 540 |
+
# Concatenate and flatten
|
| 541 |
+
combined_tokens = torch.cat([first_frame_tokens, remaining_frame_tokens], dim=1)
|
| 542 |
+
return combined_tokens.reshape(b * seq_len, *combined_tokens.shape[2:])
|
hyworldmirror/models/models/worldmirror.py
ADDED
|
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from .visual_transformer import VisualGeometryTransformer
|
| 8 |
+
from ..heads.camera_head import CameraHead
|
| 9 |
+
from ..heads.dense_head import DPTHead
|
| 10 |
+
from ..heads.gs_head import GSFeatHead
|
| 11 |
+
from .rasterization import GaussianSplatRenderer
|
| 12 |
+
from ..utils.camera_utils import (
|
| 13 |
+
vector_to_camera_matrices,
|
| 14 |
+
extrinsics_to_vector,
|
| 15 |
+
)
|
| 16 |
+
from ..utils.priors import normalize_depth, normalize_poses
|
| 17 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 18 |
+
|
| 19 |
+
from ..layers.block import Block, DistBlock
|
| 20 |
+
import torch.distributed as dist
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class WorldMirror(nn.Module, PyTorchModelHubMixin):
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
img_size=518,
|
| 27 |
+
patch_size=14,
|
| 28 |
+
model_size="large",
|
| 29 |
+
embed_dim=1024,
|
| 30 |
+
depth=24,
|
| 31 |
+
num_heads=16,
|
| 32 |
+
mlp_ratio=4.0,
|
| 33 |
+
gs_dim=256,
|
| 34 |
+
num_register_tokens=4,
|
| 35 |
+
enable_cond=True,
|
| 36 |
+
enable_cam=True,
|
| 37 |
+
enable_pts=True,
|
| 38 |
+
enable_depth=True,
|
| 39 |
+
enable_depth_mask=True,
|
| 40 |
+
enable_norm=True,
|
| 41 |
+
enable_gs=True,
|
| 42 |
+
enable_bf16=False,
|
| 43 |
+
patch_embed="dinov2_vitl14_reg",
|
| 44 |
+
fixed_patch_embed=False,
|
| 45 |
+
sampling_strategy="uniform",
|
| 46 |
+
dpt_gradient_checkpoint=False,
|
| 47 |
+
condition_strategy=["token", "pow3r", "token"],
|
| 48 |
+
rope_base=100.0,
|
| 49 |
+
normalized_rope=True,
|
| 50 |
+
rope_normalize_coords="separate",
|
| 51 |
+
rope_shift_coords=None,
|
| 52 |
+
rope_jitter_coords=None,
|
| 53 |
+
rope_rescale_coords=None,
|
| 54 |
+
sp_size=1,
|
| 55 |
+
# Legacy parameters (ignored, kept for checkpoint compatibility)
|
| 56 |
+
set_sky_region_to_maxdepth=False,
|
| 57 |
+
disable_gs_depth=False,
|
| 58 |
+
):
|
| 59 |
+
|
| 60 |
+
super().__init__()
|
| 61 |
+
|
| 62 |
+
self.intermediate_layer_idx = {
|
| 63 |
+
"small": [2, 5, 8, 11],
|
| 64 |
+
"base": [2, 5, 8, 11],
|
| 65 |
+
"large": [4, 11, 17, 23],
|
| 66 |
+
"giant": [9, 19, 29, 39],
|
| 67 |
+
}
|
| 68 |
+
self.model_size = model_size
|
| 69 |
+
if model_size == "large":
|
| 70 |
+
embed_dim = 1024
|
| 71 |
+
depth = 24
|
| 72 |
+
num_heads = 16
|
| 73 |
+
mlp_ratio = 4.0
|
| 74 |
+
gs_dim = 256
|
| 75 |
+
num_register_tokens = 4
|
| 76 |
+
elif model_size == "base":
|
| 77 |
+
embed_dim = 768
|
| 78 |
+
depth = 12
|
| 79 |
+
num_heads = 12
|
| 80 |
+
mlp_ratio = 4.0
|
| 81 |
+
gs_dim = 256
|
| 82 |
+
num_register_tokens = 4
|
| 83 |
+
elif model_size == "small":
|
| 84 |
+
embed_dim = 384
|
| 85 |
+
depth = 12
|
| 86 |
+
num_heads = 6
|
| 87 |
+
mlp_ratio = 4.0
|
| 88 |
+
gs_dim = 128
|
| 89 |
+
num_register_tokens = 4
|
| 90 |
+
elif model_size is None:
|
| 91 |
+
pass
|
| 92 |
+
print(
|
| 93 |
+
f"[WorldMirror] model_size: {model_size}, embed_dim: {embed_dim}, "
|
| 94 |
+
f"depth: {depth}, num_heads: {num_heads}, mlp_ratio: {mlp_ratio}, "
|
| 95 |
+
f"gs_dim: {gs_dim}, num_register_tokens: {num_register_tokens}"
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
self.img_size = img_size
|
| 99 |
+
self.patch_size = patch_size
|
| 100 |
+
self.embed_dim = embed_dim
|
| 101 |
+
self.depth = depth
|
| 102 |
+
self.num_heads = num_heads
|
| 103 |
+
self.mlp_ratio = mlp_ratio
|
| 104 |
+
self.gs_dim = gs_dim
|
| 105 |
+
self.num_register_tokens = num_register_tokens
|
| 106 |
+
|
| 107 |
+
self.normalized_rope = normalized_rope
|
| 108 |
+
self.rope_normalize_coords = rope_normalize_coords
|
| 109 |
+
self.rope_shift_coords = rope_shift_coords
|
| 110 |
+
self.rope_jitter_coords = rope_jitter_coords
|
| 111 |
+
self.rope_rescale_coords = rope_rescale_coords
|
| 112 |
+
|
| 113 |
+
self.enable_cam = enable_cam
|
| 114 |
+
self.enable_pts = enable_pts
|
| 115 |
+
self.enable_depth = enable_depth
|
| 116 |
+
self.enable_depth_mask = enable_depth_mask
|
| 117 |
+
self.enable_cond = enable_cond
|
| 118 |
+
self.enable_norm = enable_norm
|
| 119 |
+
self.enable_gs = enable_gs
|
| 120 |
+
self.enable_bf16 = enable_bf16
|
| 121 |
+
self.patch_embed = patch_embed
|
| 122 |
+
self.sampling = sampling_strategy
|
| 123 |
+
self.dpt_checkpoint = dpt_gradient_checkpoint
|
| 124 |
+
self.cond_methods = condition_strategy
|
| 125 |
+
self.config = self._store_config()
|
| 126 |
+
self.sp_size = sp_size
|
| 127 |
+
|
| 128 |
+
self.visual_geometry_transformer = VisualGeometryTransformer(
|
| 129 |
+
img_size=img_size,
|
| 130 |
+
patch_size=patch_size,
|
| 131 |
+
embed_dim=embed_dim,
|
| 132 |
+
depth=depth,
|
| 133 |
+
num_heads=num_heads,
|
| 134 |
+
mlp_ratio=mlp_ratio,
|
| 135 |
+
num_register_tokens=num_register_tokens,
|
| 136 |
+
block_fn=Block if self.sp_size == 1 else DistBlock,
|
| 137 |
+
normalized_rope=normalized_rope,
|
| 138 |
+
rope_normalize_coords=rope_normalize_coords,
|
| 139 |
+
rope_shift_coords=rope_shift_coords,
|
| 140 |
+
rope_jitter_coords=rope_jitter_coords,
|
| 141 |
+
rope_rescale_coords=rope_rescale_coords,
|
| 142 |
+
enable_cond=enable_cond,
|
| 143 |
+
sampling_strategy=sampling_strategy,
|
| 144 |
+
patch_embed=patch_embed,
|
| 145 |
+
fixed_patch_embed=fixed_patch_embed,
|
| 146 |
+
condition_strategy=condition_strategy,
|
| 147 |
+
intermediate_idxs=self.intermediate_layer_idx[model_size],
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
self._init_heads(embed_dim, patch_size, gs_dim)
|
| 151 |
+
|
| 152 |
+
if enable_bf16:
|
| 153 |
+
self.to = self._bf16_to
|
| 154 |
+
|
| 155 |
+
def _store_config(self):
|
| 156 |
+
"""Save the model configuration."""
|
| 157 |
+
return {
|
| 158 |
+
"img_size": self.img_size,
|
| 159 |
+
"patch_size": self.patch_size,
|
| 160 |
+
"embed_dim": self.embed_dim,
|
| 161 |
+
"depth": self.depth,
|
| 162 |
+
"num_heads": self.num_heads,
|
| 163 |
+
"mlp_ratio": self.mlp_ratio,
|
| 164 |
+
"gs_dim": self.gs_dim,
|
| 165 |
+
"num_register_tokens": self.num_register_tokens,
|
| 166 |
+
"normalized_rope": self.normalized_rope,
|
| 167 |
+
"rope_normalize_coords": self.rope_normalize_coords,
|
| 168 |
+
"rope_shift_coords": self.rope_shift_coords,
|
| 169 |
+
"rope_jitter_coords": self.rope_jitter_coords,
|
| 170 |
+
"rope_rescale_coords": self.rope_rescale_coords,
|
| 171 |
+
"enable_cam": self.enable_cam,
|
| 172 |
+
"enable_pts": self.enable_pts,
|
| 173 |
+
"enable_depth": self.enable_depth,
|
| 174 |
+
"enable_depth_mask": self.enable_depth_mask,
|
| 175 |
+
"enable_norm": self.enable_norm,
|
| 176 |
+
"enable_gs": self.enable_gs,
|
| 177 |
+
"patch_embed": self.patch_embed,
|
| 178 |
+
"sampling_strategy": self.sampling,
|
| 179 |
+
"dpt_gradient_checkpoint": self.dpt_checkpoint,
|
| 180 |
+
"condition_strategy": self.cond_methods,
|
| 181 |
+
"model_size": self.model_size,
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
def _init_heads(self, dim, patch_size, gs_dim):
|
| 185 |
+
"""Initialize all prediction heads."""
|
| 186 |
+
|
| 187 |
+
if self.enable_cam:
|
| 188 |
+
self.cam_head = CameraHead(
|
| 189 |
+
dim_in=2 * dim,
|
| 190 |
+
block_fn=Block if self.sp_size == 1 else DistBlock,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
if self.enable_pts:
|
| 194 |
+
self.pts_head = DPTHead(
|
| 195 |
+
dim_in=2 * dim,
|
| 196 |
+
output_dim=4,
|
| 197 |
+
patch_size=patch_size,
|
| 198 |
+
activation="inv_log+expp1",
|
| 199 |
+
gradient_checkpoint=self.dpt_checkpoint,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
if self.enable_depth:
|
| 203 |
+
self.depth_head = DPTHead(
|
| 204 |
+
dim_in=2 * dim,
|
| 205 |
+
output_dim=2 if not self.enable_depth_mask else 3,
|
| 206 |
+
patch_size=patch_size,
|
| 207 |
+
activation="exp+expp1" if not self.enable_depth_mask else "exp+expp1+linear",
|
| 208 |
+
enable_depth_mask=self.enable_depth_mask,
|
| 209 |
+
gradient_checkpoint=self.dpt_checkpoint,
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
if self.enable_norm:
|
| 213 |
+
self.norm_head = DPTHead(
|
| 214 |
+
dim_in=2 * dim,
|
| 215 |
+
output_dim=4,
|
| 216 |
+
patch_size=patch_size,
|
| 217 |
+
activation="norm+expp1",
|
| 218 |
+
gradient_checkpoint=self.dpt_checkpoint,
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
if self.enable_gs:
|
| 222 |
+
self.gs_head = DPTHead(
|
| 223 |
+
dim_in=2 * dim,
|
| 224 |
+
output_dim=2 if not self.enable_depth_mask else 3,
|
| 225 |
+
patch_size=patch_size,
|
| 226 |
+
features=gs_dim,
|
| 227 |
+
is_gsdpt=True,
|
| 228 |
+
activation="exp+expp1" if not self.enable_depth_mask else "exp+expp1+linear",
|
| 229 |
+
enable_depth_mask=self.enable_depth_mask,
|
| 230 |
+
gradient_checkpoint=self.dpt_checkpoint,
|
| 231 |
+
)
|
| 232 |
+
self.gs_renderer = GaussianSplatRenderer(
|
| 233 |
+
feature_dim=gs_dim,
|
| 234 |
+
sh_degree=0,
|
| 235 |
+
enable_prune=True,
|
| 236 |
+
voxel_size=0.002,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
def _bf16_to(self, *args, **kwargs):
|
| 240 |
+
"""Custom to() for bf16 mode: selectively move heads to target device/dtype."""
|
| 241 |
+
self.visual_geometry_transformer = self.visual_geometry_transformer.to(*args, **kwargs)
|
| 242 |
+
if self.enable_cam:
|
| 243 |
+
self.cam_head = self.cam_head.to(*args, **kwargs)
|
| 244 |
+
if self.enable_pts:
|
| 245 |
+
self.pts_head = self.pts_head.to(*args, **kwargs)
|
| 246 |
+
if self.enable_depth:
|
| 247 |
+
self.depth_head = self.depth_head.to(*args, **kwargs)
|
| 248 |
+
if self.enable_norm:
|
| 249 |
+
self.norm_head = self.norm_head.to(*args, **kwargs)
|
| 250 |
+
if self.enable_gs:
|
| 251 |
+
self.gs_head = self.gs_head.to(*args, **kwargs)
|
| 252 |
+
self.gs_renderer = self.gs_renderer.to(*args, **kwargs)
|
| 253 |
+
return self
|
| 254 |
+
|
| 255 |
+
def forward(
|
| 256 |
+
self,
|
| 257 |
+
views: Dict[str, torch.Tensor],
|
| 258 |
+
cond_flags: List[int] = [0, 0, 0],
|
| 259 |
+
is_inference=True,
|
| 260 |
+
sp_size=1,
|
| 261 |
+
sp_group=None,
|
| 262 |
+
):
|
| 263 |
+
"""Execute forward pass through the WorldMirror model.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
views: Input data dictionary containing 'img' and optional priors.
|
| 267 |
+
cond_flags: Conditioning flags [pose, depth, intrinsics].
|
| 268 |
+
is_inference: Whether running in inference mode.
|
| 269 |
+
sp_size: Sequence parallel size (>1 for multi-GPU).
|
| 270 |
+
sp_group: Process group for SP communication.
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
dict: Prediction results dictionary.
|
| 274 |
+
"""
|
| 275 |
+
if self.enable_bf16:
|
| 276 |
+
views['img'] = views['img'].to(torch.bfloat16)
|
| 277 |
+
|
| 278 |
+
imgs = views["img"]
|
| 279 |
+
use_cond = sum(cond_flags) > 0
|
| 280 |
+
|
| 281 |
+
if use_cond:
|
| 282 |
+
priors = self.extract_priors(views)
|
| 283 |
+
token_list, patch_start_idx = self.visual_geometry_transformer(
|
| 284 |
+
imgs, priors, cond_flags=cond_flags,
|
| 285 |
+
enable_bf16=self.enable_bf16, sp_size=sp_size, sp_group=sp_group,
|
| 286 |
+
)
|
| 287 |
+
else:
|
| 288 |
+
token_list, patch_start_idx = self.visual_geometry_transformer(
|
| 289 |
+
imgs, enable_bf16=self.enable_bf16, sp_size=sp_size, sp_group=sp_group,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
with torch.amp.autocast('cuda', enabled=(not self.enable_bf16), dtype=torch.float32):
|
| 293 |
+
if sp_size > 1:
|
| 294 |
+
preds = self._gen_all_preds_frame_sp(
|
| 295 |
+
token_list, imgs, patch_start_idx, views, cond_flags,
|
| 296 |
+
is_inference, sp_size, sp_group,
|
| 297 |
+
)
|
| 298 |
+
else:
|
| 299 |
+
preds = self._gen_all_preds(
|
| 300 |
+
token_list, imgs, patch_start_idx, views, cond_flags, is_inference,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
return preds
|
| 304 |
+
|
| 305 |
+
def _gen_all_preds_frame_sp(
|
| 306 |
+
self, token_list, imgs, patch_start_idx, views, cond_flags, is_inference,
|
| 307 |
+
sp_size, sp_group,
|
| 308 |
+
):
|
| 309 |
+
"""Generate predictions with frame-parallel DPT heads for SP inference.
|
| 310 |
+
|
| 311 |
+
Splits S frames across sp_size ranks. Each rank processes S/sp_size frames
|
| 312 |
+
through ALL head types, then Allgather to reconstruct full results.
|
| 313 |
+
CameraHead runs on all frames on every rank (cross-view attention needed).
|
| 314 |
+
"""
|
| 315 |
+
preds = {}
|
| 316 |
+
rank = dist.get_rank()
|
| 317 |
+
rank_in_sp = dist.get_group_rank(sp_group, rank)
|
| 318 |
+
|
| 319 |
+
B, S, C_img, H, W = imgs.shape
|
| 320 |
+
|
| 321 |
+
# Determine frame assignment for this rank
|
| 322 |
+
if S >= sp_size:
|
| 323 |
+
base_chunk = S // sp_size
|
| 324 |
+
remainder = S % sp_size
|
| 325 |
+
if rank_in_sp < remainder:
|
| 326 |
+
my_count = base_chunk + 1
|
| 327 |
+
my_start = rank_in_sp * (base_chunk + 1)
|
| 328 |
+
else:
|
| 329 |
+
my_count = base_chunk
|
| 330 |
+
my_start = remainder * (base_chunk + 1) + (rank_in_sp - remainder) * base_chunk
|
| 331 |
+
else:
|
| 332 |
+
if rank_in_sp < S:
|
| 333 |
+
my_count = 1
|
| 334 |
+
my_start = rank_in_sp
|
| 335 |
+
else:
|
| 336 |
+
my_count = 0
|
| 337 |
+
my_start = S
|
| 338 |
+
|
| 339 |
+
my_end = my_start + my_count
|
| 340 |
+
has_frames = my_count > 0
|
| 341 |
+
|
| 342 |
+
if has_frames:
|
| 343 |
+
token_list_chunk = [t[:, my_start:my_end].contiguous() for t in token_list]
|
| 344 |
+
imgs_chunk = imgs[:, my_start:my_end].contiguous()
|
| 345 |
+
|
| 346 |
+
# Camera head: runs on ALL frames on every rank (cross-view attention)
|
| 347 |
+
if self.enable_cam:
|
| 348 |
+
cam_seq = self.cam_head(token_list)
|
| 349 |
+
cam_params = cam_seq[-1]
|
| 350 |
+
preds["camera_params"] = cam_params
|
| 351 |
+
c2w_mat, int_mat = self.transform_camera_vector(cam_params, H, W)
|
| 352 |
+
preds["camera_poses"] = c2w_mat
|
| 353 |
+
preds["camera_intrs"] = int_mat
|
| 354 |
+
|
| 355 |
+
# DPT heads: frame-parallel
|
| 356 |
+
if self.enable_depth:
|
| 357 |
+
if has_frames:
|
| 358 |
+
if self.enable_depth_mask:
|
| 359 |
+
depth_chunk, depth_conf_chunk, depth_mask_logits_chunk = self.depth_head(
|
| 360 |
+
token_list_chunk, images=imgs_chunk, patch_start_idx=patch_start_idx,
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
depth_chunk, depth_conf_chunk = self.depth_head(
|
| 364 |
+
token_list_chunk, images=imgs_chunk, patch_start_idx=patch_start_idx,
|
| 365 |
+
)
|
| 366 |
+
else:
|
| 367 |
+
depth_chunk = torch.zeros(B, 0, H, W, 1, dtype=imgs.dtype, device=imgs.device)
|
| 368 |
+
depth_conf_chunk = torch.zeros(B, 0, H, W, dtype=imgs.dtype, device=imgs.device)
|
| 369 |
+
if self.enable_depth_mask:
|
| 370 |
+
depth_mask_logits_chunk = torch.zeros(B, 0, H, W, dtype=imgs.dtype, device=imgs.device)
|
| 371 |
+
|
| 372 |
+
preds["depth"] = self._frame_allgather_variable(depth_chunk, my_count, S, sp_size, sp_group, dim=1)
|
| 373 |
+
preds["depth_conf"] = self._frame_allgather_variable(depth_conf_chunk, my_count, S, sp_size, sp_group, dim=1)
|
| 374 |
+
if self.enable_depth_mask:
|
| 375 |
+
depth_mask_logits_full = self._frame_allgather_variable(
|
| 376 |
+
depth_mask_logits_chunk, my_count, S, sp_size, sp_group, dim=1,
|
| 377 |
+
)
|
| 378 |
+
preds["depth_mask_logits"] = depth_mask_logits_full
|
| 379 |
+
preds["depth_mask"] = depth_mask_logits_full.sigmoid()
|
| 380 |
+
|
| 381 |
+
if self.enable_pts:
|
| 382 |
+
if has_frames:
|
| 383 |
+
pts_chunk, pts_conf_chunk = self.pts_head(
|
| 384 |
+
token_list_chunk, images=imgs_chunk, patch_start_idx=patch_start_idx,
|
| 385 |
+
)
|
| 386 |
+
else:
|
| 387 |
+
pts_chunk = torch.zeros(B, 0, H, W, 3, dtype=imgs.dtype, device=imgs.device)
|
| 388 |
+
pts_conf_chunk = torch.zeros(B, 0, H, W, dtype=imgs.dtype, device=imgs.device)
|
| 389 |
+
|
| 390 |
+
preds["pts3d"] = self._frame_allgather_variable(pts_chunk, my_count, S, sp_size, sp_group, dim=1)
|
| 391 |
+
preds["pts3d_conf"] = self._frame_allgather_variable(pts_conf_chunk, my_count, S, sp_size, sp_group, dim=1)
|
| 392 |
+
|
| 393 |
+
if self.enable_norm:
|
| 394 |
+
if has_frames:
|
| 395 |
+
normals_chunk, norm_conf_chunk = self.norm_head(
|
| 396 |
+
token_list_chunk, images=imgs_chunk, patch_start_idx=patch_start_idx,
|
| 397 |
+
)
|
| 398 |
+
else:
|
| 399 |
+
normals_chunk = torch.zeros(B, 0, H, W, 3, dtype=imgs.dtype, device=imgs.device)
|
| 400 |
+
norm_conf_chunk = torch.zeros(B, 0, H, W, dtype=imgs.dtype, device=imgs.device)
|
| 401 |
+
|
| 402 |
+
preds["normals"] = self._frame_allgather_variable(normals_chunk, my_count, S, sp_size, sp_group, dim=1)
|
| 403 |
+
preds["normals_conf"] = self._frame_allgather_variable(norm_conf_chunk, my_count, S, sp_size, sp_group, dim=1)
|
| 404 |
+
|
| 405 |
+
# GS head: frame-parallel, then render on full gathered data
|
| 406 |
+
if self.enable_gs:
|
| 407 |
+
context_preds, context_nums = self.prepare_contexts(views, cond_flags, is_inference)
|
| 408 |
+
gs_token_list = context_preds.get("token_list", token_list)
|
| 409 |
+
gs_imgs = context_preds.get("imgs", imgs)
|
| 410 |
+
gs_S = gs_imgs.shape[1]
|
| 411 |
+
|
| 412 |
+
if gs_S == S and has_frames:
|
| 413 |
+
gs_token_chunk = [t[:, my_start:my_end].contiguous() for t in gs_token_list]
|
| 414 |
+
gs_imgs_chunk = gs_imgs[:, my_start:my_end].contiguous()
|
| 415 |
+
|
| 416 |
+
if self.enable_depth_mask:
|
| 417 |
+
gs_feat_chunk, gs_depth_chunk, gs_depth_conf_chunk, gs_dmask_chunk = self.gs_head(
|
| 418 |
+
gs_token_chunk, images=gs_imgs_chunk, patch_start_idx=patch_start_idx,
|
| 419 |
+
)
|
| 420 |
+
else:
|
| 421 |
+
gs_feat_chunk, gs_depth_chunk, gs_depth_conf_chunk = self.gs_head(
|
| 422 |
+
gs_token_chunk, images=gs_imgs_chunk, patch_start_idx=patch_start_idx,
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
gs_feat = self._frame_allgather_variable(gs_feat_chunk, my_count, gs_S, sp_size, sp_group, dim=1)
|
| 426 |
+
gs_depth = self._frame_allgather_variable(gs_depth_chunk, my_count, gs_S, sp_size, sp_group, dim=1)
|
| 427 |
+
gs_depth_conf = self._frame_allgather_variable(gs_depth_conf_chunk, my_count, gs_S, sp_size, sp_group, dim=1)
|
| 428 |
+
if self.enable_depth_mask:
|
| 429 |
+
gs_depth_mask_logits = self._frame_allgather_variable(
|
| 430 |
+
gs_dmask_chunk, my_count, gs_S, sp_size, sp_group, dim=1,
|
| 431 |
+
)
|
| 432 |
+
preds["gs_depth_mask_logits"] = gs_depth_mask_logits
|
| 433 |
+
preds["gs_depth_mask"] = gs_depth_mask_logits.sigmoid()
|
| 434 |
+
|
| 435 |
+
elif gs_S == S and not has_frames:
|
| 436 |
+
gs_feat_c = self.gs_dim // 2
|
| 437 |
+
gs_feat_chunk = torch.zeros(B, 0, gs_feat_c, H, W, dtype=imgs.dtype, device=imgs.device)
|
| 438 |
+
gs_depth_chunk = torch.zeros(B, 0, H, W, 1, dtype=imgs.dtype, device=imgs.device)
|
| 439 |
+
gs_depth_conf_chunk = torch.zeros(B, 0, H, W, dtype=imgs.dtype, device=imgs.device)
|
| 440 |
+
|
| 441 |
+
gs_feat = self._frame_allgather_variable(gs_feat_chunk, 0, gs_S, sp_size, sp_group, dim=1)
|
| 442 |
+
gs_depth = self._frame_allgather_variable(gs_depth_chunk, 0, gs_S, sp_size, sp_group, dim=1)
|
| 443 |
+
gs_depth_conf = self._frame_allgather_variable(gs_depth_conf_chunk, 0, gs_S, sp_size, sp_group, dim=1)
|
| 444 |
+
if self.enable_depth_mask:
|
| 445 |
+
gs_dmask_chunk = torch.zeros(B, 0, H, W, dtype=imgs.dtype, device=imgs.device)
|
| 446 |
+
gs_depth_mask_logits = self._frame_allgather_variable(
|
| 447 |
+
gs_dmask_chunk, 0, gs_S, sp_size, sp_group, dim=1,
|
| 448 |
+
)
|
| 449 |
+
preds["gs_depth_mask_logits"] = gs_depth_mask_logits
|
| 450 |
+
preds["gs_depth_mask"] = gs_depth_mask_logits.sigmoid()
|
| 451 |
+
else:
|
| 452 |
+
if self.enable_depth_mask:
|
| 453 |
+
gs_feat, gs_depth, gs_depth_conf, gs_depth_mask_logits = self.gs_head(
|
| 454 |
+
gs_token_list, images=gs_imgs, patch_start_idx=patch_start_idx,
|
| 455 |
+
)
|
| 456 |
+
preds["gs_depth_mask_logits"] = gs_depth_mask_logits
|
| 457 |
+
preds["gs_depth_mask"] = gs_depth_mask_logits.sigmoid()
|
| 458 |
+
else:
|
| 459 |
+
gs_feat, gs_depth, gs_depth_conf = self.gs_head(
|
| 460 |
+
gs_token_list, images=gs_imgs, patch_start_idx=patch_start_idx,
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
preds["gs_depth"] = gs_depth
|
| 464 |
+
preds["gs_depth_conf"] = gs_depth_conf
|
| 465 |
+
|
| 466 |
+
preds = self.gs_renderer.render(
|
| 467 |
+
gs_feats=gs_feat,
|
| 468 |
+
images=imgs,
|
| 469 |
+
predictions=preds,
|
| 470 |
+
views=views,
|
| 471 |
+
context_predictions=context_preds,
|
| 472 |
+
is_inference=is_inference,
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
return preds
|
| 476 |
+
|
| 477 |
+
def _frame_allgather_variable(self, chunk, my_count, total_S, sp_size, sp_group, dim=1):
|
| 478 |
+
"""Allgather tensors with potentially variable chunk sizes across ranks.
|
| 479 |
+
|
| 480 |
+
Pads each chunk to max_chunk_size, allgathers, then extracts valid frames
|
| 481 |
+
from each rank's chunk to reconstruct the correct frame order.
|
| 482 |
+
"""
|
| 483 |
+
if sp_size <= 1:
|
| 484 |
+
return chunk
|
| 485 |
+
|
| 486 |
+
if total_S >= sp_size:
|
| 487 |
+
base_chunk = total_S // sp_size
|
| 488 |
+
remainder = total_S % sp_size
|
| 489 |
+
counts = [(base_chunk + 1) if r < remainder else base_chunk
|
| 490 |
+
for r in range(sp_size)]
|
| 491 |
+
else:
|
| 492 |
+
counts = [1 if r < total_S else 0 for r in range(sp_size)]
|
| 493 |
+
|
| 494 |
+
max_chunk = max(counts)
|
| 495 |
+
|
| 496 |
+
current_size = chunk.shape[dim]
|
| 497 |
+
if current_size < max_chunk:
|
| 498 |
+
pad_size = max_chunk - current_size
|
| 499 |
+
pad_shape = list(chunk.shape)
|
| 500 |
+
pad_shape[dim] = pad_size
|
| 501 |
+
padding = torch.zeros(pad_shape, dtype=chunk.dtype, device=chunk.device)
|
| 502 |
+
chunk = torch.cat([chunk, padding], dim=dim)
|
| 503 |
+
|
| 504 |
+
chunk = chunk.contiguous()
|
| 505 |
+
gathered_list = [torch.zeros_like(chunk) for _ in range(sp_size)]
|
| 506 |
+
dist.all_gather(gathered_list, chunk, group=sp_group)
|
| 507 |
+
|
| 508 |
+
valid_chunks = []
|
| 509 |
+
for r in range(sp_size):
|
| 510 |
+
cnt = counts[r]
|
| 511 |
+
if cnt > 0:
|
| 512 |
+
slices = [slice(None)] * gathered_list[r].dim()
|
| 513 |
+
slices[dim] = slice(0, cnt)
|
| 514 |
+
valid_chunks.append(gathered_list[r][tuple(slices)])
|
| 515 |
+
|
| 516 |
+
return torch.cat(valid_chunks, dim=dim).contiguous()
|
| 517 |
+
|
| 518 |
+
def _gen_all_preds(
|
| 519 |
+
self, token_list, imgs, patch_start_idx, views, cond_flags, is_inference
|
| 520 |
+
):
|
| 521 |
+
"""Generate all enabled predictions (single-GPU path)."""
|
| 522 |
+
preds = {}
|
| 523 |
+
|
| 524 |
+
if self.enable_cam:
|
| 525 |
+
cam_seq = self.cam_head(token_list)
|
| 526 |
+
cam_params = cam_seq[-1]
|
| 527 |
+
preds["camera_params"] = cam_params
|
| 528 |
+
c2w_mat, int_mat = self.transform_camera_vector(
|
| 529 |
+
cam_params, imgs.shape[-2], imgs.shape[-1]
|
| 530 |
+
)
|
| 531 |
+
preds["camera_poses"] = c2w_mat
|
| 532 |
+
preds["camera_intrs"] = int_mat
|
| 533 |
+
|
| 534 |
+
if self.enable_depth:
|
| 535 |
+
if self.enable_depth_mask:
|
| 536 |
+
depth, depth_conf, depth_mask_logits = self.depth_head(
|
| 537 |
+
token_list, images=imgs, patch_start_idx=patch_start_idx,
|
| 538 |
+
)
|
| 539 |
+
preds["depth_mask_logits"] = depth_mask_logits
|
| 540 |
+
preds["depth_mask"] = depth_mask_logits.sigmoid()
|
| 541 |
+
else:
|
| 542 |
+
depth, depth_conf = self.depth_head(
|
| 543 |
+
token_list, images=imgs, patch_start_idx=patch_start_idx,
|
| 544 |
+
)
|
| 545 |
+
preds["depth"] = depth
|
| 546 |
+
preds["depth_conf"] = depth_conf
|
| 547 |
+
|
| 548 |
+
if self.enable_pts:
|
| 549 |
+
pts, pts_conf = self.pts_head(
|
| 550 |
+
token_list, images=imgs, patch_start_idx=patch_start_idx,
|
| 551 |
+
)
|
| 552 |
+
preds["pts3d"] = pts
|
| 553 |
+
preds["pts3d_conf"] = pts_conf
|
| 554 |
+
|
| 555 |
+
if self.enable_norm:
|
| 556 |
+
normals, norm_conf = self.norm_head(
|
| 557 |
+
token_list, images=imgs, patch_start_idx=patch_start_idx,
|
| 558 |
+
)
|
| 559 |
+
preds["normals"] = normals
|
| 560 |
+
preds["normals_conf"] = norm_conf
|
| 561 |
+
|
| 562 |
+
if self.enable_gs:
|
| 563 |
+
context_preds, context_nums = self.prepare_contexts(views, cond_flags, is_inference)
|
| 564 |
+
if self.enable_depth_mask:
|
| 565 |
+
gs_feat, gs_depth, gs_depth_conf, gs_depth_mask_logits = self.gs_head(
|
| 566 |
+
context_preds.get("token_list", token_list),
|
| 567 |
+
images=context_preds.get("imgs", imgs),
|
| 568 |
+
patch_start_idx=patch_start_idx,
|
| 569 |
+
)
|
| 570 |
+
preds["gs_depth_mask_logits"] = gs_depth_mask_logits
|
| 571 |
+
preds["gs_depth_mask"] = gs_depth_mask_logits.sigmoid()
|
| 572 |
+
else:
|
| 573 |
+
gs_feat, gs_depth, gs_depth_conf = self.gs_head(
|
| 574 |
+
context_preds.get("token_list", token_list),
|
| 575 |
+
images=context_preds.get("imgs", imgs),
|
| 576 |
+
patch_start_idx=patch_start_idx,
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
preds["gs_depth"] = gs_depth
|
| 580 |
+
preds["gs_depth_conf"] = gs_depth_conf
|
| 581 |
+
|
| 582 |
+
preds = self.gs_renderer.render(
|
| 583 |
+
gs_feats=gs_feat,
|
| 584 |
+
images=imgs,
|
| 585 |
+
predictions=preds,
|
| 586 |
+
views=views,
|
| 587 |
+
context_predictions=context_preds,
|
| 588 |
+
is_inference=is_inference,
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
return preds
|
| 592 |
+
|
| 593 |
+
def extract_priors(self, views):
|
| 594 |
+
"""Extract and normalize geometric priors from input views.
|
| 595 |
+
|
| 596 |
+
Returns (depths, rays, poses) tuple — each may be None if unavailable.
|
| 597 |
+
"""
|
| 598 |
+
h, w = views["img"].shape[-2:]
|
| 599 |
+
depths = rays = poses = None
|
| 600 |
+
|
| 601 |
+
if "camera_poses" in views:
|
| 602 |
+
extrinsics = views["camera_poses"][:, :, :3]
|
| 603 |
+
extrinsics = normalize_poses(extrinsics)
|
| 604 |
+
cam_params = extrinsics_to_vector(extrinsics)
|
| 605 |
+
poses = cam_params[:, :, :7]
|
| 606 |
+
if self.enable_bf16:
|
| 607 |
+
poses = poses.to(torch.bfloat16)
|
| 608 |
+
|
| 609 |
+
if "depthmap" in views:
|
| 610 |
+
depth_h, depth_w = views["depthmap"].shape[-2:]
|
| 611 |
+
depths = views["depthmap"]
|
| 612 |
+
if depth_h != h or depth_w != w:
|
| 613 |
+
depths = F.interpolate(depths, size=(h, w), mode="bilinear", align_corners=False)
|
| 614 |
+
depths = normalize_depth(depths)
|
| 615 |
+
if self.enable_bf16:
|
| 616 |
+
depths = depths.to(torch.bfloat16)
|
| 617 |
+
|
| 618 |
+
if "camera_intrs" in views:
|
| 619 |
+
intrinsics = views["camera_intrs"][:, :, :3, :3]
|
| 620 |
+
fx, fy = intrinsics[:, :, 0, 0] / w, intrinsics[:, :, 1, 1] / h
|
| 621 |
+
cx, cy = intrinsics[:, :, 0, 2] / w, intrinsics[:, :, 1, 2] / h
|
| 622 |
+
rays = torch.stack([fx, fy, cx, cy], dim=-1)
|
| 623 |
+
if self.enable_bf16:
|
| 624 |
+
rays = rays.to(torch.bfloat16)
|
| 625 |
+
|
| 626 |
+
return (depths, rays, poses)
|
| 627 |
+
|
| 628 |
+
def transform_camera_vector(self, camera_params, h, w):
|
| 629 |
+
"""Convert camera parameter vector to c2w and intrinsic matrices."""
|
| 630 |
+
ext_mat, int_mat = vector_to_camera_matrices(camera_params, image_hw=(h, w))
|
| 631 |
+
homo_row = torch.tensor([0, 0, 0, 1], device=ext_mat.device).view(1, 1, 1, 4)
|
| 632 |
+
homo_row = homo_row.repeat(ext_mat.shape[0], ext_mat.shape[1], 1, 1)
|
| 633 |
+
w2c_mat = torch.cat([ext_mat, homo_row], dim=2)
|
| 634 |
+
try:
|
| 635 |
+
c2w_mat = torch.linalg.inv(w2c_mat)
|
| 636 |
+
except Exception as e:
|
| 637 |
+
print(f"[WorldMirror] linalg.inv fallback to CPU: {e}")
|
| 638 |
+
c2w_mat = torch.linalg.inv(w2c_mat.cpu()).to(camera_params.device)
|
| 639 |
+
return c2w_mat, int_mat
|
| 640 |
+
|
| 641 |
+
def prepare_contexts(self, views, cond_flags, is_inference):
|
| 642 |
+
"""Prepare context views for GS rendering (training only, passthrough in inference)."""
|
| 643 |
+
context_preds = {}
|
| 644 |
+
if is_inference:
|
| 645 |
+
return context_preds, views["img"].shape[1]
|
| 646 |
+
|
| 647 |
+
assert self.enable_cam and self.enable_gs
|
| 648 |
+
if "is_target" not in views:
|
| 649 |
+
context_nums = views["img"].shape[1]
|
| 650 |
+
else:
|
| 651 |
+
context_nums = (views["is_target"][0] == False).sum().item()
|
| 652 |
+
context_imgs = views["img"][:, :context_nums]
|
| 653 |
+
|
| 654 |
+
use_cond = sum(cond_flags) > 0
|
| 655 |
+
|
| 656 |
+
if self.enable_bf16:
|
| 657 |
+
context_imgs = context_imgs.to(torch.bfloat16)
|
| 658 |
+
|
| 659 |
+
with torch.amp.autocast('cuda', enabled=(not self.enable_bf16), dtype=torch.bfloat16):
|
| 660 |
+
if use_cond:
|
| 661 |
+
priors = self.extract_priors(views)
|
| 662 |
+
context_priors = (
|
| 663 |
+
prior[:, :context_nums] if prior is not None else None
|
| 664 |
+
for prior in priors
|
| 665 |
+
)
|
| 666 |
+
context_token_list, _ = self.visual_geometry_transformer(
|
| 667 |
+
context_imgs, context_priors, cond_flags=cond_flags,
|
| 668 |
+
enable_bf16=self.enable_bf16,
|
| 669 |
+
)
|
| 670 |
+
else:
|
| 671 |
+
context_token_list, _ = self.visual_geometry_transformer(
|
| 672 |
+
context_imgs, enable_bf16=self.enable_bf16,
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
context_cam_seq = self.cam_head(context_token_list)
|
| 676 |
+
context_cam_params = context_cam_seq[-1]
|
| 677 |
+
context_c2w_mat, context_int_mat = self.transform_camera_vector(
|
| 678 |
+
context_cam_params, context_imgs.shape[-2], context_imgs.shape[-1]
|
| 679 |
+
)
|
| 680 |
+
context_preds["camera_poses"] = context_c2w_mat
|
| 681 |
+
context_preds["camera_intrs"] = context_int_mat
|
| 682 |
+
context_preds["token_list"] = context_token_list
|
| 683 |
+
context_preds["imgs"] = context_imgs
|
| 684 |
+
|
| 685 |
+
return context_preds, context_nums
|
hyworldmirror/models/utils/__init__.py
ADDED
|
File without changes
|
hyworldmirror/models/utils/act_gs.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def reg_dense_offsets(xyz, shift=6.0):
|
| 6 |
+
d = xyz.norm(dim=-1, keepdim=True)
|
| 7 |
+
return xyz / d.clamp(min=1e-8) * (torch.exp(d - shift) - torch.exp(-shift))
|
| 8 |
+
|
| 9 |
+
def reg_dense_scales(scales):
|
| 10 |
+
return scales.exp()
|
| 11 |
+
|
| 12 |
+
def reg_dense_rotation(rotations, eps=1e-8):
|
| 13 |
+
return rotations / (rotations.norm(dim=-1, keepdim=True) + eps)
|
| 14 |
+
|
| 15 |
+
def reg_dense_sh(sh):
|
| 16 |
+
return rearrange(sh, '... (d_sh xyz) -> ... d_sh xyz', xyz=3)
|
| 17 |
+
|
| 18 |
+
def reg_dense_opacities(opacities):
|
| 19 |
+
return opacities.sigmoid()
|
| 20 |
+
|
| 21 |
+
def reg_dense_weights(weights):
|
| 22 |
+
return weights.sigmoid()
|
hyworldmirror/models/utils/camera_utils.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from .rotation import quat_to_rotmat, rotmat_to_quat
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def camera_params_to_vector(
|
| 6 |
+
ext, intr, image_hw=None
|
| 7 |
+
):
|
| 8 |
+
"""Convert camera matrices to a compact vector."""
|
| 9 |
+
# ext: (..., 3, 4): Camera-to-world extrinsic [R|t]
|
| 10 |
+
# intr: (..., 3, 3): Intrinsics
|
| 11 |
+
# image_hw: (h, w)
|
| 12 |
+
R = ext[..., :3, :3] # Rotation part
|
| 13 |
+
t = ext[..., :3, 3] # Translation part
|
| 14 |
+
q = rotmat_to_quat(R) # Quaternion (wxyz)
|
| 15 |
+
h, w = image_hw
|
| 16 |
+
fov_v = 2.0 * torch.atan(h * 0.5 / intr[..., 1, 1]) # Vertical FOV
|
| 17 |
+
fov_u = 2.0 * torch.atan(w * 0.5 / intr[..., 0, 0]) # Horizontal FOV
|
| 18 |
+
vec = torch.stack([
|
| 19 |
+
t[..., 0], t[..., 1], t[..., 2],
|
| 20 |
+
q[..., 0], q[..., 1], q[..., 2], q[..., 3],
|
| 21 |
+
fov_v, fov_u
|
| 22 |
+
], dim=-1).float()
|
| 23 |
+
return vec
|
| 24 |
+
|
| 25 |
+
def extrinsics_to_vector(ext):
|
| 26 |
+
"""Convert extrinsics to [t, q] vector."""
|
| 27 |
+
# ext: (..., 3, 4)
|
| 28 |
+
R = ext[..., :3, :3]
|
| 29 |
+
t = ext[..., :3, 3]
|
| 30 |
+
q = rotmat_to_quat(R)
|
| 31 |
+
vec = torch.stack([
|
| 32 |
+
t[..., 0], t[..., 1], t[..., 2],
|
| 33 |
+
q[..., 0], q[..., 1], q[..., 2], q[..., 3]
|
| 34 |
+
], dim=-1).float()
|
| 35 |
+
return vec
|
| 36 |
+
|
| 37 |
+
def vector_to_extrinsics(cam_vec):
|
| 38 |
+
"""Convert [t, q] vector to extrinsic matrix."""
|
| 39 |
+
# cam_vec: (..., 7)
|
| 40 |
+
q = cam_vec[..., 3:7]
|
| 41 |
+
t = cam_vec[..., :3]
|
| 42 |
+
R = quat_to_rotmat(q)
|
| 43 |
+
ext = torch.cat([R, t.unsqueeze(-1)], dim=-1)
|
| 44 |
+
return ext
|
| 45 |
+
|
| 46 |
+
def vector_to_camera_matrices(
|
| 47 |
+
cam_vec, image_hw=None, build_intr=True
|
| 48 |
+
):
|
| 49 |
+
"""Reconstruct extrinsic and intrinsic matrix from vector."""
|
| 50 |
+
# cam_vec: (..., 9)
|
| 51 |
+
intr = None
|
| 52 |
+
# Decompose vector
|
| 53 |
+
t = cam_vec[..., 0:3]
|
| 54 |
+
q = cam_vec[..., 3:7]
|
| 55 |
+
fov_v = cam_vec[..., 7]
|
| 56 |
+
fov_u = cam_vec[..., 8]
|
| 57 |
+
|
| 58 |
+
# Build extrinsic: [R|t]
|
| 59 |
+
R = quat_to_rotmat(q)
|
| 60 |
+
ext = torch.cat([R, t.unsqueeze(-1)], dim=-1)
|
| 61 |
+
|
| 62 |
+
# Build intrinsic if needed
|
| 63 |
+
if build_intr:
|
| 64 |
+
h, w = image_hw
|
| 65 |
+
fy = h * 0.5 / torch.tan(fov_v * 0.5)
|
| 66 |
+
fx = w * 0.5 / torch.tan(fov_u * 0.5)
|
| 67 |
+
shape = cam_vec.shape[:-1] + (3, 3)
|
| 68 |
+
intr = torch.zeros(shape, device=cam_vec.device, dtype=cam_vec.dtype)
|
| 69 |
+
intr[..., 0, 0] = fx
|
| 70 |
+
intr[..., 1, 1] = fy
|
| 71 |
+
intr[..., 0, 2] = w * 0.5
|
| 72 |
+
intr[..., 1, 2] = h * 0.5
|
| 73 |
+
intr[..., 2, 2] = 1.0
|
| 74 |
+
|
| 75 |
+
return ext, intr
|
hyworldmirror/models/utils/frustum.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import einops
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Calculate the loss mask for the target views in the batch
|
| 6 |
+
@torch.no_grad()
|
| 7 |
+
def calculate_unprojected_mask(views, context_nums):
|
| 8 |
+
'''Calcuate the loss mask for the target views in the batch'''
|
| 9 |
+
target_depth = views["depthmap"][:, context_nums:]
|
| 10 |
+
target_intrinsics = views["camera_intrs"][:, context_nums:]
|
| 11 |
+
target_c2w = views["camera_poses"][:, context_nums:]
|
| 12 |
+
context_depth = views["depthmap"][:, :context_nums]
|
| 13 |
+
context_intrinsics = views["camera_intrs"][:, :context_nums]
|
| 14 |
+
context_c2w = views["camera_poses"][:, :context_nums]
|
| 15 |
+
|
| 16 |
+
target_intrinsics = target_intrinsics[..., :3, :3]
|
| 17 |
+
context_intrinsics = context_intrinsics[..., :3, :3]
|
| 18 |
+
|
| 19 |
+
mask = calculate_in_frustum_mask(
|
| 20 |
+
target_depth, target_intrinsics, target_c2w,
|
| 21 |
+
context_depth, context_intrinsics, context_c2w
|
| 22 |
+
)
|
| 23 |
+
return mask
|
| 24 |
+
|
| 25 |
+
@torch.no_grad()
|
| 26 |
+
def calculate_in_frustum_mask(depth_1, intrinsics_1, c2w_1, depth_2, intrinsics_2, c2w_2):
|
| 27 |
+
"""
|
| 28 |
+
A function that takes in the depth, intrinsics and c2w matrices of two sets
|
| 29 |
+
of views, and then works out which of the pixels in the first set of views
|
| 30 |
+
has a direct corresponding pixel in any of views in the second set
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
depth_1: (b, v1, h, w)
|
| 34 |
+
intrinsics_1: (b, v1, 3, 3)
|
| 35 |
+
c2w_1: (b, v1, 4, 4)
|
| 36 |
+
depth_2: (b, v2, h, w)
|
| 37 |
+
intrinsics_2: (b, v2, 3, 3)
|
| 38 |
+
c2w_2: (b, v2, 4, 4)
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
torch.Tensor: valid mask with shape (b, v1, v2, h, w).
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
_, v1, h, w = depth_1.shape
|
| 45 |
+
_, v2, _, _ = depth_2.shape
|
| 46 |
+
|
| 47 |
+
# Unproject the depth to get the 3D points in world space
|
| 48 |
+
points_3d = unproject_depth(depth_1[..., None], intrinsics_1, c2w_1) # (b, v1, h, w, 3)
|
| 49 |
+
|
| 50 |
+
# Project the 3D points into the pixel space of all the second views simultaneously
|
| 51 |
+
camera_points = world_space_to_camera_space(points_3d, c2w_2) # (b, v1, v2, h, w, 3)
|
| 52 |
+
points_2d = camera_space_to_pixel_space(camera_points, intrinsics_2) # (b, v1, v2, h, w, 2)
|
| 53 |
+
|
| 54 |
+
# Calculate the depth of each point
|
| 55 |
+
rendered_depth = camera_points[..., 2] # (b, v1, v2, h, w)
|
| 56 |
+
|
| 57 |
+
# We use three conditions to determine if a point should be masked
|
| 58 |
+
|
| 59 |
+
# Condition 1: Check if the points are in the frustum of any of the v2 views
|
| 60 |
+
in_frustum_mask = (
|
| 61 |
+
(points_2d[..., 0] > 0) &
|
| 62 |
+
(points_2d[..., 0] < w) &
|
| 63 |
+
(points_2d[..., 1] > 0) &
|
| 64 |
+
(points_2d[..., 1] < h)
|
| 65 |
+
) # (b, v1, v2, h, w)
|
| 66 |
+
in_frustum_mask = in_frustum_mask.any(dim=-3) # (b, v1, h, w)
|
| 67 |
+
|
| 68 |
+
# Condition 2: Check if the points have non-zero (i.e. valid) depth in the input view
|
| 69 |
+
non_zero_depth = depth_1 > 1e-6
|
| 70 |
+
|
| 71 |
+
# Condition 3: Check if the points have matching depth to any of the v2
|
| 72 |
+
# views torch.nn.functional.grid_sample expects the input coordinates to
|
| 73 |
+
# be normalized to the range [-1, 1], so we normalize first
|
| 74 |
+
points_2d[..., 0] /= w
|
| 75 |
+
points_2d[..., 1] /= h
|
| 76 |
+
points_2d = points_2d * 2 - 1
|
| 77 |
+
matching_depth = torch.ones_like(rendered_depth, dtype=torch.bool)
|
| 78 |
+
for b in range(depth_1.shape[0]):
|
| 79 |
+
for i in range(v1):
|
| 80 |
+
for j in range(v2):
|
| 81 |
+
depth = einops.rearrange(depth_2[b, j], 'h w -> 1 1 h w')
|
| 82 |
+
coords = einops.rearrange(points_2d[b, i, j], 'h w c -> 1 h w c')
|
| 83 |
+
sampled_depths = torch.nn.functional.grid_sample(depth, coords, align_corners=False)[0, 0]
|
| 84 |
+
matching_depth[b, i, j] = torch.isclose(rendered_depth[b, i, j], sampled_depths, atol=1e-1)
|
| 85 |
+
|
| 86 |
+
matching_depth = matching_depth.any(dim=-3) # (..., v1, h, w)
|
| 87 |
+
|
| 88 |
+
mask = in_frustum_mask & non_zero_depth & matching_depth
|
| 89 |
+
return mask
|
| 90 |
+
|
| 91 |
+
# --- Projections ---
|
| 92 |
+
def homogenize_points(points):
|
| 93 |
+
"""Append a '1' along the final dimension of the tensor (i.e. convert xyz->xyz1)"""
|
| 94 |
+
return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def normalize_homogenous_points(points):
|
| 98 |
+
"""Normalize the point vectors"""
|
| 99 |
+
return points / points[..., -1:]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def pixel_space_to_camera_space(pixel_space_points, depth, intrinsics):
|
| 103 |
+
"""
|
| 104 |
+
Convert pixel space points to camera space points.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
pixel_space_points (torch.Tensor): Pixel space points with shape (h, w, 2)
|
| 108 |
+
depth (torch.Tensor): Depth map with shape (b, v, h, w, 1)
|
| 109 |
+
intrinsics (torch.Tensor): Camera intrinsics with shape (b, v, 3, 3)
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
torch.Tensor: Camera space points with shape (b, v, h, w, 3).
|
| 113 |
+
"""
|
| 114 |
+
pixel_space_points = homogenize_points(pixel_space_points)
|
| 115 |
+
camera_space_points = torch.einsum('b v i j , h w j -> b v h w i', intrinsics.inverse(), pixel_space_points)
|
| 116 |
+
camera_space_points = camera_space_points * depth
|
| 117 |
+
return camera_space_points
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def camera_space_to_world_space(camera_space_points, c2w):
|
| 121 |
+
"""
|
| 122 |
+
Convert camera space points to world space points.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
camera_space_points (torch.Tensor): Camera space points with shape (b, v, h, w, 3)
|
| 126 |
+
c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v, 4, 4)
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
torch.Tensor: World space points with shape (b, v, h, w, 3).
|
| 130 |
+
"""
|
| 131 |
+
camera_space_points = homogenize_points(camera_space_points)
|
| 132 |
+
world_space_points = torch.einsum('b v i j , b v h w j -> b v h w i', c2w, camera_space_points)
|
| 133 |
+
return world_space_points[..., :3]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def camera_space_to_pixel_space(camera_space_points, intrinsics):
|
| 137 |
+
"""
|
| 138 |
+
Convert camera space points to pixel space points.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
camera_space_points (torch.Tensor): Camera space points with shape (b, v1, v2, h, w, 3)
|
| 142 |
+
c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v2, 3, 3)
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
torch.Tensor: World space points with shape (b, v1, v2, h, w, 2).
|
| 146 |
+
"""
|
| 147 |
+
camera_space_points = normalize_homogenous_points(camera_space_points)
|
| 148 |
+
pixel_space_points = torch.einsum('b u i j , b v u h w j -> b v u h w i', intrinsics, camera_space_points)
|
| 149 |
+
return pixel_space_points[..., :2]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def world_space_to_camera_space(world_space_points, c2w):
|
| 153 |
+
"""
|
| 154 |
+
Convert world space points to pixel space points.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
world_space_points (torch.Tensor): World space points with shape (b, v1, h, w, 3)
|
| 158 |
+
c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v2, 4, 4)
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
torch.Tensor: Camera space points with shape (b, v1, v2, h, w, 3).
|
| 162 |
+
"""
|
| 163 |
+
world_space_points = homogenize_points(world_space_points)
|
| 164 |
+
camera_space_points = torch.einsum('b u i j , b v h w j -> b v u h w i', c2w.inverse(), world_space_points)
|
| 165 |
+
return camera_space_points[..., :3]
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def unproject_depth(depth, intrinsics, c2w):
|
| 169 |
+
"""
|
| 170 |
+
Turn the depth map into a 3D point cloud in world space
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
depth: (b, v, h, w, 1)
|
| 174 |
+
intrinsics: (b, v, 3, 3)
|
| 175 |
+
c2w: (b, v, 4, 4)
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
torch.Tensor: World space points with shape (b, v, h, w, 3).
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
# Compute indices of pixels
|
| 182 |
+
h, w = depth.shape[-3], depth.shape[-2]
|
| 183 |
+
x_grid, y_grid = torch.meshgrid(
|
| 184 |
+
torch.arange(w, device=depth.device, dtype=torch.float32),
|
| 185 |
+
torch.arange(h, device=depth.device, dtype=torch.float32),
|
| 186 |
+
indexing='xy'
|
| 187 |
+
) # (h, w), (h, w)
|
| 188 |
+
|
| 189 |
+
# Compute coordinates of pixels in camera space
|
| 190 |
+
pixel_space_points = torch.stack((x_grid, y_grid), dim=-1) # (..., h, w, 2)
|
| 191 |
+
camera_points = pixel_space_to_camera_space(pixel_space_points, depth, intrinsics) # (..., h, w, 3)
|
| 192 |
+
|
| 193 |
+
# Convert points to world space
|
| 194 |
+
world_points = camera_space_to_world_space(camera_points, c2w) # (..., h, w, 3)
|
| 195 |
+
|
| 196 |
+
return world_points
|
hyworldmirror/models/utils/geometry.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
def depth_to_camera_coords(depthmap, camera_intrinsics):
|
| 6 |
+
"""
|
| 7 |
+
Convert depth map to 3D camera coordinates.
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
depthmap (BxHxW tensor): Batch of depth maps
|
| 11 |
+
camera_intrinsics (Bx3x3 tensor): Camera intrinsics matrix for each camera
|
| 12 |
+
|
| 13 |
+
Returns:
|
| 14 |
+
X_cam (BxHxWx3 tensor): 3D points in camera coordinates
|
| 15 |
+
valid_mask (BxHxW tensor): Mask indicating valid depth pixels
|
| 16 |
+
"""
|
| 17 |
+
B, H, W = depthmap.shape
|
| 18 |
+
device = depthmap.device
|
| 19 |
+
dtype = depthmap.dtype
|
| 20 |
+
|
| 21 |
+
# Ensure intrinsics are float
|
| 22 |
+
camera_intrinsics = camera_intrinsics.float()
|
| 23 |
+
|
| 24 |
+
# Extract focal lengths and principal points
|
| 25 |
+
fx = camera_intrinsics[:, 0, 0] # (B,)
|
| 26 |
+
fy = camera_intrinsics[:, 1, 1] # (B,)
|
| 27 |
+
cx = camera_intrinsics[:, 0, 2] # (B,)
|
| 28 |
+
cy = camera_intrinsics[:, 1, 2] # (B,)
|
| 29 |
+
|
| 30 |
+
# Generate pixel grid
|
| 31 |
+
v_grid, u_grid = torch.meshgrid(
|
| 32 |
+
torch.arange(H, dtype=dtype, device=device),
|
| 33 |
+
torch.arange(W, dtype=dtype, device=device),
|
| 34 |
+
indexing='ij'
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Reshape for broadcasting: (1, H, W)
|
| 38 |
+
u_grid = u_grid.unsqueeze(0)
|
| 39 |
+
v_grid = v_grid.unsqueeze(0)
|
| 40 |
+
|
| 41 |
+
# Compute 3D camera coordinates
|
| 42 |
+
# X = (u - cx) * Z / fx
|
| 43 |
+
# Y = (v - cy) * Z / fy
|
| 44 |
+
# Z = depth
|
| 45 |
+
z_cam = depthmap # (B, H, W)
|
| 46 |
+
x_cam = (u_grid - cx.view(B, 1, 1)) * z_cam / fx.view(B, 1, 1)
|
| 47 |
+
y_cam = (v_grid - cy.view(B, 1, 1)) * z_cam / fy.view(B, 1, 1)
|
| 48 |
+
|
| 49 |
+
# Stack to form (B, H, W, 3)
|
| 50 |
+
X_cam = torch.stack([x_cam, y_cam, z_cam], dim=-1)
|
| 51 |
+
|
| 52 |
+
# Valid depth mask
|
| 53 |
+
valid_mask = depthmap > 0.0
|
| 54 |
+
|
| 55 |
+
return X_cam, valid_mask
|
| 56 |
+
|
| 57 |
+
def depth_to_world_coords_points(
|
| 58 |
+
depth_map: torch.Tensor, extrinsic: torch.Tensor, intrinsic: torch.Tensor, eps=1e-8
|
| 59 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 60 |
+
"""
|
| 61 |
+
Convert a batch of depth maps to world coordinates.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
depth_map (torch.Tensor): (B, H, W) Depth map
|
| 65 |
+
extrinsic (torch.Tensor): (B, 4, 4) Camera extrinsic matrix (camera-to-world transformation)
|
| 66 |
+
intrinsic (torch.Tensor): (B, 3, 3) Camera intrinsic matrix
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
world_coords_points (torch.Tensor): (B, H, W, 3) World coordinates
|
| 70 |
+
camera_points (torch.Tensor): (B, H, W, 3) Camera coordinates
|
| 71 |
+
point_mask (torch.Tensor): (B, H, W) Valid depth mask
|
| 72 |
+
"""
|
| 73 |
+
if depth_map is None:
|
| 74 |
+
return None, None, None
|
| 75 |
+
|
| 76 |
+
# Valid depth mask (B, H, W)
|
| 77 |
+
point_mask = depth_map > eps
|
| 78 |
+
|
| 79 |
+
# Convert depth map to camera coordinates (B, H, W, 3)
|
| 80 |
+
camera_points, _ = depth_to_camera_coords(depth_map, intrinsic)
|
| 81 |
+
|
| 82 |
+
# Apply extrinsic matrix (camera -> world)
|
| 83 |
+
R_cam_to_world = extrinsic[:, :3, :3] # (B, 3, 3)
|
| 84 |
+
t_cam_to_world = extrinsic[:, :3, 3] # (B, 3)
|
| 85 |
+
|
| 86 |
+
# Transform (B, H, W, 3) x (B, 3, 3)^T + (B, 3) -> (B, H, W, 3)
|
| 87 |
+
world_coords_points = torch.einsum('bhwi,bji->bhwj', camera_points, R_cam_to_world) + t_cam_to_world[:, None, None, :]
|
| 88 |
+
|
| 89 |
+
return world_coords_points, camera_points, point_mask
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def closed_form_inverse_se3(se3: torch.Tensor) -> torch.Tensor:
|
| 93 |
+
"""
|
| 94 |
+
Efficiently invert batched SE(3) matrices of shape (B, 4, 4).
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
se3 (torch.Tensor): (B, 4, 4) Transformation matrices
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
out (torch.Tensor): (B, 4, 4) Inverse transformation matrices
|
| 101 |
+
"""
|
| 102 |
+
assert se3.ndim == 3 and se3.shape[1:] == (4, 4), f"se3 must be (B, 4, 4), got {se3.shape}"
|
| 103 |
+
R = se3[:, :3, :3] # (B, 3, 3)
|
| 104 |
+
t = se3[:, :3, 3] # (B, 3)
|
| 105 |
+
Rt = R.transpose(1, 2) # (B, 3, 3)
|
| 106 |
+
t_inv = -torch.bmm(Rt, t.unsqueeze(-1)).squeeze(-1) # (B, 3)
|
| 107 |
+
out = se3.new_zeros(se3.shape)
|
| 108 |
+
out[:, :3, :3] = Rt
|
| 109 |
+
out[:, :3, 3] = t_inv
|
| 110 |
+
out[:, 3, 3] = 1.0
|
| 111 |
+
return out
|
hyworldmirror/models/utils/grid.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
|
| 5 |
+
"""
|
| 6 |
+
Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
|
| 7 |
+
|
| 8 |
+
Args:
|
| 9 |
+
pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
|
| 10 |
+
embed_dim: Output channel dimension for embeddings
|
| 11 |
+
omega_0: Base frequency for sinusoidal encoding
|
| 12 |
+
|
| 13 |
+
Returns:
|
| 14 |
+
Tensor of shape (H, W, embed_dim) with positional embeddings
|
| 15 |
+
"""
|
| 16 |
+
H, W, grid_dim = pos_grid.shape
|
| 17 |
+
assert grid_dim == 2
|
| 18 |
+
assert embed_dim % 2 == 0
|
| 19 |
+
|
| 20 |
+
device = pos_grid.device
|
| 21 |
+
pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
|
| 22 |
+
|
| 23 |
+
# Generate frequency bands
|
| 24 |
+
omega = torch.arange(embed_dim // 4, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
|
| 25 |
+
omega /= embed_dim / 4.0
|
| 26 |
+
omega = 1.0 / omega_0**omega # (D/4,)
|
| 27 |
+
|
| 28 |
+
# Process x and y coordinates separately
|
| 29 |
+
pos_x = pos_flat[:, 0].reshape(-1) # (H*W,)
|
| 30 |
+
pos_y = pos_flat[:, 1].reshape(-1) # (H*W,)
|
| 31 |
+
|
| 32 |
+
# Compute outer products
|
| 33 |
+
out_x = torch.einsum("m,d->md", pos_x, omega) # (H*W, D/4)
|
| 34 |
+
out_y = torch.einsum("m,d->md", pos_y, omega) # (H*W, D/4)
|
| 35 |
+
|
| 36 |
+
# Apply sin and cos
|
| 37 |
+
emb_x = torch.cat([torch.sin(out_x), torch.cos(out_x)], dim=1) # (H*W, D/2)
|
| 38 |
+
emb_y = torch.cat([torch.sin(out_y), torch.cos(out_y)], dim=1) # (H*W, D/2)
|
| 39 |
+
|
| 40 |
+
# Combine x and y embeddings
|
| 41 |
+
emb = torch.cat([emb_x, emb_y], dim=-1) # (H*W, D)
|
| 42 |
+
|
| 43 |
+
return emb.float().view(H, W, embed_dim) # [H, W, D]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# Inspired by https://github.com/microsoft/moge
|
| 47 |
+
def create_uv_grid(
|
| 48 |
+
width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
|
| 49 |
+
) -> torch.Tensor:
|
| 50 |
+
"""
|
| 51 |
+
Create a normalized UV grid of shape (width, height, 2).
|
| 52 |
+
|
| 53 |
+
The grid spans horizontally and vertically according to an aspect ratio,
|
| 54 |
+
ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
|
| 55 |
+
corner is at (x_span, y_span), normalized by the diagonal of the plane.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
width (int): Number of points horizontally.
|
| 59 |
+
height (int): Number of points vertically.
|
| 60 |
+
aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
|
| 61 |
+
dtype (torch.dtype, optional): Data type of the resulting tensor.
|
| 62 |
+
device (torch.device, optional): Device on which the tensor is created.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
torch.Tensor: A (width, height, 2) tensor of UV coordinates.
|
| 66 |
+
"""
|
| 67 |
+
# Derive aspect ratio if not explicitly provided
|
| 68 |
+
if aspect_ratio is None:
|
| 69 |
+
aspect_ratio = float(width) / float(height)
|
| 70 |
+
|
| 71 |
+
# Compute normalized spans for X and Y
|
| 72 |
+
diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
|
| 73 |
+
span_x = aspect_ratio / diag_factor
|
| 74 |
+
span_y = 1.0 / diag_factor
|
| 75 |
+
|
| 76 |
+
# Establish the linspace boundaries
|
| 77 |
+
left_x = -span_x * (width - 1) / width
|
| 78 |
+
right_x = span_x * (width - 1) / width
|
| 79 |
+
top_y = -span_y * (height - 1) / height
|
| 80 |
+
bottom_y = span_y * (height - 1) / height
|
| 81 |
+
|
| 82 |
+
# Generate 1D coordinates
|
| 83 |
+
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
|
| 84 |
+
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
|
| 85 |
+
|
| 86 |
+
# Create 2D meshgrid (width x height) and stack into UV
|
| 87 |
+
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
|
| 88 |
+
uv_grid = torch.stack((uu, vv), dim=-1)
|
| 89 |
+
|
| 90 |
+
return uv_grid
|
hyworldmirror/models/utils/priors.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def normalize_poses(extrinsics, padding=0.1, return_stats=False):
|
| 5 |
+
"""
|
| 6 |
+
Normalize camera positions to unit cube, processing each batch separately
|
| 7 |
+
|
| 8 |
+
Args:
|
| 9 |
+
extrinsics: Camera extrinsic matrices with shape (B, S, 3, 4)
|
| 10 |
+
padding: Boundary space within [0,1] range to prevent values near boundaries
|
| 11 |
+
return_stats: Whether to return normalization statistics
|
| 12 |
+
|
| 13 |
+
Returns:
|
| 14 |
+
normalized_extrinsics: Normalized extrinsic matrices
|
| 15 |
+
(optional) stats: Dictionary containing scale and translation information
|
| 16 |
+
"""
|
| 17 |
+
B, S, _, _ = extrinsics.shape
|
| 18 |
+
device = extrinsics.device
|
| 19 |
+
|
| 20 |
+
# Check input validity and handle NaN/Inf values
|
| 21 |
+
for i in range(B):
|
| 22 |
+
if torch.isnan(extrinsics[i]).any() or torch.isinf(extrinsics[i]).any():
|
| 23 |
+
print(f"Warning: dataset sample has NaN/Inf in extrinsics")
|
| 24 |
+
extrinsics[i] = torch.nan_to_num(
|
| 25 |
+
extrinsics[i], nan=0.0, posinf=1e6, neginf=-1e6
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
normalized_extrinsics = extrinsics.clone()
|
| 29 |
+
|
| 30 |
+
# Store normalization parameters if needed
|
| 31 |
+
if return_stats:
|
| 32 |
+
stats = {
|
| 33 |
+
'scale_factors': torch.zeros(B, device=device),
|
| 34 |
+
'translation_vectors': torch.zeros(B, 3, device=device)
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
for b in range(B):
|
| 38 |
+
# Extract camera positions for this batch
|
| 39 |
+
positions = extrinsics[b, :, :3, 3] # (S, 3)
|
| 40 |
+
|
| 41 |
+
# Filter valid positions to ignore outliers
|
| 42 |
+
valid_mask = torch.isfinite(positions).all(dim=1) # (S,)
|
| 43 |
+
|
| 44 |
+
if valid_mask.sum() == 0:
|
| 45 |
+
# No valid positions, use default values
|
| 46 |
+
print(f"Warning: Batch {b} has no valid camera positions")
|
| 47 |
+
normalized_extrinsics[b, :, :3, 3] = 0.5 # Place at center
|
| 48 |
+
if return_stats:
|
| 49 |
+
stats['scale_factors'][b] = 1.0
|
| 50 |
+
stats['translation_vectors'][b] = 0.0
|
| 51 |
+
continue
|
| 52 |
+
|
| 53 |
+
valid_positions = positions[valid_mask]
|
| 54 |
+
|
| 55 |
+
# Calculate bounds using percentiles for robustness
|
| 56 |
+
if valid_positions.shape[0] > 10:
|
| 57 |
+
# Use 5% and 95% percentiles instead of min/max
|
| 58 |
+
min_pos = torch.quantile(valid_positions, 0.05, dim=0)
|
| 59 |
+
max_pos = torch.quantile(valid_positions, 0.95, dim=0)
|
| 60 |
+
else:
|
| 61 |
+
# Too few samples, use min/max
|
| 62 |
+
min_pos = torch.min(valid_positions, dim=0)[0]
|
| 63 |
+
max_pos = torch.max(valid_positions, dim=0)[0]
|
| 64 |
+
|
| 65 |
+
# Calculate scale factor considering all dimensions
|
| 66 |
+
pos_range = max_pos - min_pos
|
| 67 |
+
|
| 68 |
+
# Add small epsilon to prevent dimension collapse
|
| 69 |
+
eps = torch.maximum(
|
| 70 |
+
torch.tensor(1e-6, device=device),
|
| 71 |
+
torch.abs(max_pos) * 1e-6
|
| 72 |
+
)
|
| 73 |
+
pos_range = torch.maximum(pos_range, eps)
|
| 74 |
+
|
| 75 |
+
# Use maximum range as scale factor for uniform scaling
|
| 76 |
+
scale_factor = torch.max(pos_range)
|
| 77 |
+
scale_factor = torch.clamp(scale_factor, min=1e-6, max=1e6)
|
| 78 |
+
|
| 79 |
+
# Calculate center point for centering
|
| 80 |
+
center = (min_pos + max_pos) / 2.0
|
| 81 |
+
|
| 82 |
+
# Normalize: center first, then scale with padding
|
| 83 |
+
actual_scale = scale_factor / (1 - 2 * padding)
|
| 84 |
+
normalized_positions = (positions - center) / actual_scale + 0.5
|
| 85 |
+
|
| 86 |
+
# Ensure all values are within valid range
|
| 87 |
+
normalized_positions = torch.clamp(normalized_positions, 0.0, 1.0)
|
| 88 |
+
|
| 89 |
+
# Handle invalid positions by setting them to scene center
|
| 90 |
+
invalid_mask = ~torch.isfinite(positions).all(dim=1)
|
| 91 |
+
if invalid_mask.any():
|
| 92 |
+
normalized_positions[invalid_mask] = 0.5
|
| 93 |
+
|
| 94 |
+
normalized_extrinsics[b, :, :3, 3] = normalized_positions
|
| 95 |
+
|
| 96 |
+
if return_stats:
|
| 97 |
+
stats['scale_factors'][b] = actual_scale
|
| 98 |
+
stats['translation_vectors'][b] = center
|
| 99 |
+
|
| 100 |
+
# Final validation
|
| 101 |
+
assert torch.isfinite(normalized_extrinsics).all(), "Output contains non-finite values"
|
| 102 |
+
|
| 103 |
+
if return_stats:
|
| 104 |
+
return normalized_extrinsics, stats
|
| 105 |
+
return normalized_extrinsics
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def normalize_depth(depth, eps=1e-6, min_percentile=1, max_percentile=99):
|
| 109 |
+
"""
|
| 110 |
+
Normalize depth values to [0, 1] range using percentile-based scaling.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
depth: Input depth tensor with shape (B, S, H, W)
|
| 114 |
+
eps: Small epsilon value to prevent division by zero
|
| 115 |
+
min_percentile: Lower percentile for robust min calculation (default: 1)
|
| 116 |
+
max_percentile: Upper percentile for robust max calculation (default: 99)
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
normalized_depth: Depth tensor normalized to [0, 1] range with same shape (B, S, H, W)
|
| 120 |
+
"""
|
| 121 |
+
B, S, H, W = depth.shape
|
| 122 |
+
depth = depth.flatten(0,1) # [B*S, H, W]
|
| 123 |
+
|
| 124 |
+
# Handle invalid values
|
| 125 |
+
depth = torch.nan_to_num(depth, nan=0.0, posinf=1e6, neginf=0.0)
|
| 126 |
+
|
| 127 |
+
normalized_list = []
|
| 128 |
+
for i in range(depth.shape[0]):
|
| 129 |
+
depth_img = depth[i] # [H, W]
|
| 130 |
+
depth_flat = depth_img.flatten()
|
| 131 |
+
|
| 132 |
+
# Filter out zero values if needed
|
| 133 |
+
non_zero_mask = depth_flat > 0
|
| 134 |
+
if non_zero_mask.sum() > 0:
|
| 135 |
+
values_to_use = depth_flat[non_zero_mask]
|
| 136 |
+
else:
|
| 137 |
+
values_to_use = depth_flat
|
| 138 |
+
|
| 139 |
+
# Only calculate percentiles when there are enough values
|
| 140 |
+
if values_to_use.numel() > 100: # Ensure enough samples for percentile calculation
|
| 141 |
+
# Calculate min and max percentiles
|
| 142 |
+
depth_min = torch.quantile(values_to_use, min_percentile/100.0)
|
| 143 |
+
depth_max = torch.quantile(values_to_use, max_percentile/100.0)
|
| 144 |
+
else:
|
| 145 |
+
# If too few samples, use min/max values
|
| 146 |
+
depth_min = values_to_use.min()
|
| 147 |
+
depth_max = values_to_use.max()
|
| 148 |
+
|
| 149 |
+
# Handle case where max equals min
|
| 150 |
+
if depth_max == depth_min:
|
| 151 |
+
depth_max = depth_min + 1.0
|
| 152 |
+
|
| 153 |
+
# Use relative epsilon
|
| 154 |
+
scale = torch.abs(depth_max - depth_min)
|
| 155 |
+
eps_val = max(eps, scale.item() * eps)
|
| 156 |
+
|
| 157 |
+
# Perform normalization
|
| 158 |
+
depth_norm_img = (depth_img - depth_min) / (depth_max - depth_min + eps_val)
|
| 159 |
+
|
| 160 |
+
# Ensure output is within [0,1] range
|
| 161 |
+
depth_norm_img = torch.clamp(depth_norm_img, 0.0, 1.0)
|
| 162 |
+
|
| 163 |
+
normalized_list.append(depth_norm_img)
|
| 164 |
+
|
| 165 |
+
# Recombine all normalized images
|
| 166 |
+
depth_norm = torch.stack(normalized_list)
|
| 167 |
+
|
| 168 |
+
return depth_norm.reshape(B, S, H, W)
|
hyworldmirror/models/utils/rotation.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def quat_to_rotmat(quaternions: torch.Tensor) -> torch.Tensor:
|
| 9 |
+
"""
|
| 10 |
+
Quaternion Order: XYZW or say ijkr, scalar-last
|
| 11 |
+
|
| 12 |
+
Convert rotations given as quaternions to rotation matrices.
|
| 13 |
+
Args:
|
| 14 |
+
quaternions: quaternions with real part last,
|
| 15 |
+
as tensor of shape (..., 4).
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 19 |
+
"""
|
| 20 |
+
i, j, k, r = torch.unbind(quaternions, -1)
|
| 21 |
+
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
| 22 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
| 23 |
+
|
| 24 |
+
o = torch.stack(
|
| 25 |
+
(
|
| 26 |
+
1 - two_s * (j * j + k * k),
|
| 27 |
+
two_s * (i * j - k * r),
|
| 28 |
+
two_s * (i * k + j * r),
|
| 29 |
+
two_s * (i * j + k * r),
|
| 30 |
+
1 - two_s * (i * i + k * k),
|
| 31 |
+
two_s * (j * k - i * r),
|
| 32 |
+
two_s * (i * k - j * r),
|
| 33 |
+
two_s * (j * k + i * r),
|
| 34 |
+
1 - two_s * (i * i + j * j),
|
| 35 |
+
),
|
| 36 |
+
-1,
|
| 37 |
+
)
|
| 38 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def rotmat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
|
| 42 |
+
"""
|
| 43 |
+
Convert rotations given as rotation matrices to quaternions.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
quaternions with real part last, as tensor of shape (..., 4).
|
| 50 |
+
Quaternion Order: XYZW or say ijkr, scalar-last
|
| 51 |
+
"""
|
| 52 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
| 53 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
| 54 |
+
|
| 55 |
+
batch_dim = matrix.shape[:-2]
|
| 56 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
|
| 57 |
+
|
| 58 |
+
q_abs = _sqrt_positive_part(
|
| 59 |
+
torch.stack(
|
| 60 |
+
[1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
|
| 61 |
+
)
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# we produce the desired quaternion multiplied by each of r, i, j, k
|
| 65 |
+
quat_by_rijk = torch.stack(
|
| 66 |
+
[
|
| 67 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 68 |
+
# `int`.
|
| 69 |
+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
| 70 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 71 |
+
# `int`.
|
| 72 |
+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
| 73 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 74 |
+
# `int`.
|
| 75 |
+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
| 76 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 77 |
+
# `int`.
|
| 78 |
+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
| 79 |
+
],
|
| 80 |
+
dim=-2,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
| 84 |
+
# the candidate won't be picked.
|
| 85 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
| 86 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
| 87 |
+
|
| 88 |
+
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
| 89 |
+
# forall i; we pick the best-conditioned one (with the largest denominator)
|
| 90 |
+
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
|
| 91 |
+
|
| 92 |
+
# Convert from rijk to ijkr
|
| 93 |
+
out = out[..., [1, 2, 3, 0]]
|
| 94 |
+
|
| 95 |
+
out = standardize_quaternion(out)
|
| 96 |
+
|
| 97 |
+
return out
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
| 101 |
+
"""
|
| 102 |
+
Returns torch.sqrt(torch.max(0, x))
|
| 103 |
+
but with a zero subgradient where x is 0.
|
| 104 |
+
"""
|
| 105 |
+
ret = torch.zeros_like(x)
|
| 106 |
+
positive_mask = x > 0
|
| 107 |
+
if torch.is_grad_enabled():
|
| 108 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 109 |
+
else:
|
| 110 |
+
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
| 111 |
+
return ret
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
| 115 |
+
"""
|
| 116 |
+
Convert a unit quaternion to a standard form: one in which the real
|
| 117 |
+
part is non negative.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
quaternions: Quaternions with real part last,
|
| 121 |
+
as tensor of shape (..., 4).
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
Standardized quaternions as tensor of shape (..., 4).
|
| 125 |
+
"""
|
| 126 |
+
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
|
hyworldmirror/models/utils/sh_utils.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The PlenOctree Authors.
|
| 2 |
+
# Redistribution and use in source and binary forms, with or without
|
| 3 |
+
# modification, are permitted provided that the following conditions are met:
|
| 4 |
+
#
|
| 5 |
+
# 1. Redistributions of source code must retain the above copyright notice,
|
| 6 |
+
# this list of conditions and the following disclaimer.
|
| 7 |
+
#
|
| 8 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 9 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 10 |
+
# and/or other materials provided with the distribution.
|
| 11 |
+
#
|
| 12 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 13 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 14 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
| 15 |
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
| 16 |
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
| 17 |
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
| 18 |
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
| 19 |
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
| 20 |
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
| 21 |
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
| 22 |
+
# POSSIBILITY OF SUCH DAMAGE.
|
| 23 |
+
|
| 24 |
+
C0 = 0.28209479177387814
|
| 25 |
+
C1 = 0.4886025119029199
|
| 26 |
+
C2 = [
|
| 27 |
+
1.0925484305920792,
|
| 28 |
+
-1.0925484305920792,
|
| 29 |
+
0.31539156525252005,
|
| 30 |
+
-1.0925484305920792,
|
| 31 |
+
0.5462742152960396
|
| 32 |
+
]
|
| 33 |
+
C3 = [
|
| 34 |
+
-0.5900435899266435,
|
| 35 |
+
2.890611442640554,
|
| 36 |
+
-0.4570457994644658,
|
| 37 |
+
0.3731763325901154,
|
| 38 |
+
-0.4570457994644658,
|
| 39 |
+
1.445305721320277,
|
| 40 |
+
-0.5900435899266435
|
| 41 |
+
]
|
| 42 |
+
C4 = [
|
| 43 |
+
2.5033429417967046,
|
| 44 |
+
-1.7701307697799304,
|
| 45 |
+
0.9461746957575601,
|
| 46 |
+
-0.6690465435572892,
|
| 47 |
+
0.10578554691520431,
|
| 48 |
+
-0.6690465435572892,
|
| 49 |
+
0.47308734787878004,
|
| 50 |
+
-1.7701307697799304,
|
| 51 |
+
0.6258357354491761,
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def eval_sh(deg, sh, dirs):
|
| 56 |
+
"""
|
| 57 |
+
Evaluate spherical harmonics at unit directions
|
| 58 |
+
using hardcoded SH polynomials.
|
| 59 |
+
Works with torch/np/jnp.
|
| 60 |
+
... Can be 0 or more batch dimensions.
|
| 61 |
+
Args:
|
| 62 |
+
deg: int SH deg. Currently, 0-3 supported
|
| 63 |
+
sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
|
| 64 |
+
dirs: jnp.ndarray unit directions [..., 3]
|
| 65 |
+
Returns:
|
| 66 |
+
[..., C]
|
| 67 |
+
"""
|
| 68 |
+
assert deg <= 4 and deg >= 0
|
| 69 |
+
coeff = (deg + 1) ** 2
|
| 70 |
+
assert sh.shape[-1] >= coeff
|
| 71 |
+
|
| 72 |
+
result = C0 * sh[..., 0]
|
| 73 |
+
if deg > 0:
|
| 74 |
+
x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
|
| 75 |
+
result = (result -
|
| 76 |
+
C1 * y * sh[..., 1] +
|
| 77 |
+
C1 * z * sh[..., 2] -
|
| 78 |
+
C1 * x * sh[..., 3])
|
| 79 |
+
|
| 80 |
+
if deg > 1:
|
| 81 |
+
xx, yy, zz = x * x, y * y, z * z
|
| 82 |
+
xy, yz, xz = x * y, y * z, x * z
|
| 83 |
+
result = (result +
|
| 84 |
+
C2[0] * xy * sh[..., 4] +
|
| 85 |
+
C2[1] * yz * sh[..., 5] +
|
| 86 |
+
C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
|
| 87 |
+
C2[3] * xz * sh[..., 7] +
|
| 88 |
+
C2[4] * (xx - yy) * sh[..., 8])
|
| 89 |
+
|
| 90 |
+
if deg > 2:
|
| 91 |
+
result = (result +
|
| 92 |
+
C3[0] * y * (3 * xx - yy) * sh[..., 9] +
|
| 93 |
+
C3[1] * xy * z * sh[..., 10] +
|
| 94 |
+
C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
|
| 95 |
+
C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
|
| 96 |
+
C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
|
| 97 |
+
C3[5] * z * (xx - yy) * sh[..., 14] +
|
| 98 |
+
C3[6] * x * (xx - 3 * yy) * sh[..., 15])
|
| 99 |
+
|
| 100 |
+
if deg > 3:
|
| 101 |
+
result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
|
| 102 |
+
C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
|
| 103 |
+
C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
|
| 104 |
+
C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
|
| 105 |
+
C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
|
| 106 |
+
C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
|
| 107 |
+
C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
|
| 108 |
+
C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
|
| 109 |
+
C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
|
| 110 |
+
return result
|
| 111 |
+
|
| 112 |
+
def RGB2SH(rgb):
|
| 113 |
+
return (rgb - 0.5) / C0
|
| 114 |
+
|
| 115 |
+
def SH2RGB(sh):
|
| 116 |
+
return sh * C0 + 0.5
|
hyworldmirror/utils/__init__.py
ADDED
|
File without changes
|
hyworldmirror/utils/geometry.py
ADDED
|
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for geometry operations.
|
| 3 |
+
|
| 4 |
+
References: DUSt3R, MoGe
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from numbers import Number
|
| 8 |
+
from typing import Tuple, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from .warnings import no_warnings
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def colmap_to_opencv_intrinsics(K):
|
| 15 |
+
"""
|
| 16 |
+
Modify camera intrinsics to follow a different convention.
|
| 17 |
+
Coordinates of the center of the top-left pixels are by default:
|
| 18 |
+
- (0.5, 0.5) in Colmap
|
| 19 |
+
- (0,0) in OpenCV
|
| 20 |
+
"""
|
| 21 |
+
K = K.copy()
|
| 22 |
+
K[0, 2] -= 0.5
|
| 23 |
+
K[1, 2] -= 0.5
|
| 24 |
+
|
| 25 |
+
return K
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def opencv_to_colmap_intrinsics(K):
|
| 29 |
+
"""
|
| 30 |
+
Modify camera intrinsics to follow a different convention.
|
| 31 |
+
Coordinates of the center of the top-left pixels are by default:
|
| 32 |
+
- (0.5, 0.5) in Colmap
|
| 33 |
+
- (0,0) in OpenCV
|
| 34 |
+
"""
|
| 35 |
+
K = K.copy()
|
| 36 |
+
K[0, 2] += 0.5
|
| 37 |
+
K[1, 2] += 0.5
|
| 38 |
+
|
| 39 |
+
return K
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def angle_diff_vec3_numpy(v1: np.ndarray, v2: np.ndarray, eps: float = 1e-12):
|
| 43 |
+
"""
|
| 44 |
+
Compute angle difference between 3D vectors using NumPy.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
v1 (np.ndarray): First vector of shape (..., 3)
|
| 48 |
+
v2 (np.ndarray): Second vector of shape (..., 3)
|
| 49 |
+
eps (float, optional): Small epsilon value for numerical stability. Defaults to 1e-12.
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
np.ndarray: Angle differences in radians
|
| 53 |
+
"""
|
| 54 |
+
return np.arctan2(
|
| 55 |
+
np.linalg.norm(np.cross(v1, v2, axis=-1), axis=-1) + eps, (v1 * v2).sum(axis=-1)
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@no_warnings(category=RuntimeWarning)
|
| 60 |
+
def points_to_normals(
|
| 61 |
+
point: np.ndarray, mask: np.ndarray = None, edge_threshold: float = None
|
| 62 |
+
) -> np.ndarray:
|
| 63 |
+
"""
|
| 64 |
+
Calculate normal map from point map. Value range is [-1, 1].
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
point (np.ndarray): shape (height, width, 3), point map
|
| 68 |
+
mask (optional, np.ndarray): shape (height, width), dtype=bool. Mask of valid depth pixels. Defaults to None.
|
| 69 |
+
edge_threshold (optional, float): threshold for the angle (in degrees) between the normal and the view direction. Defaults to None.
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
normal (np.ndarray): shape (height, width, 3), normal map.
|
| 73 |
+
"""
|
| 74 |
+
height, width = point.shape[-3:-1]
|
| 75 |
+
has_mask = mask is not None
|
| 76 |
+
|
| 77 |
+
if mask is None:
|
| 78 |
+
mask = np.ones_like(point[..., 0], dtype=bool)
|
| 79 |
+
mask_pad = np.zeros((height + 2, width + 2), dtype=bool)
|
| 80 |
+
mask_pad[1:-1, 1:-1] = mask
|
| 81 |
+
mask = mask_pad
|
| 82 |
+
|
| 83 |
+
pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype)
|
| 84 |
+
pts[1:-1, 1:-1, :] = point
|
| 85 |
+
up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :]
|
| 86 |
+
left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :]
|
| 87 |
+
down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :]
|
| 88 |
+
right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :]
|
| 89 |
+
normal = np.stack(
|
| 90 |
+
[
|
| 91 |
+
np.cross(up, left, axis=-1),
|
| 92 |
+
np.cross(left, down, axis=-1),
|
| 93 |
+
np.cross(down, right, axis=-1),
|
| 94 |
+
np.cross(right, up, axis=-1),
|
| 95 |
+
]
|
| 96 |
+
)
|
| 97 |
+
normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
|
| 98 |
+
|
| 99 |
+
valid = (
|
| 100 |
+
np.stack(
|
| 101 |
+
[
|
| 102 |
+
mask[:-2, 1:-1] & mask[1:-1, :-2],
|
| 103 |
+
mask[1:-1, :-2] & mask[2:, 1:-1],
|
| 104 |
+
mask[2:, 1:-1] & mask[1:-1, 2:],
|
| 105 |
+
mask[1:-1, 2:] & mask[:-2, 1:-1],
|
| 106 |
+
]
|
| 107 |
+
)
|
| 108 |
+
& mask[None, 1:-1, 1:-1]
|
| 109 |
+
)
|
| 110 |
+
if edge_threshold is not None:
|
| 111 |
+
view_angle = angle_diff_vec3_numpy(pts[None, 1:-1, 1:-1, :], normal)
|
| 112 |
+
view_angle = np.minimum(view_angle, np.pi - view_angle)
|
| 113 |
+
valid = valid & (view_angle < np.deg2rad(edge_threshold))
|
| 114 |
+
|
| 115 |
+
normal = (normal * valid[..., None]).sum(axis=0)
|
| 116 |
+
normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
|
| 117 |
+
|
| 118 |
+
if has_mask:
|
| 119 |
+
normal_mask = valid.any(axis=0)
|
| 120 |
+
normal = np.where(normal_mask[..., None], normal, 0)
|
| 121 |
+
return normal, normal_mask
|
| 122 |
+
else:
|
| 123 |
+
return normal
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1):
|
| 127 |
+
"""
|
| 128 |
+
Create a sliding window view of the input array along a specified axis.
|
| 129 |
+
|
| 130 |
+
This function creates a memory-efficient view of the input array with sliding windows
|
| 131 |
+
of the specified size and stride. The window dimension is appended to the end of the
|
| 132 |
+
output array's shape. This is useful for operations like convolution, pooling, or
|
| 133 |
+
any analysis that requires examining local neighborhoods in the data.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
x (np.ndarray): Input array with shape (..., axis_size, ...)
|
| 137 |
+
window_size (int): Size of the sliding window
|
| 138 |
+
stride (int): Stride of the sliding window (step size between consecutive windows)
|
| 139 |
+
axis (int, optional): Axis to perform sliding window over. Defaults to -1 (last axis)
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
np.ndarray: View of the input array with shape (..., n_windows, ..., window_size),
|
| 143 |
+
where n_windows = (axis_size - window_size + 1) // stride
|
| 144 |
+
|
| 145 |
+
Raises:
|
| 146 |
+
AssertionError: If window_size is larger than the size of the specified axis
|
| 147 |
+
|
| 148 |
+
Example:
|
| 149 |
+
>>> x = np.array([1, 2, 3, 4, 5, 6])
|
| 150 |
+
>>> sliding_window_1d(x, window_size=3, stride=2)
|
| 151 |
+
array([[1, 2, 3],
|
| 152 |
+
[3, 4, 5]])
|
| 153 |
+
"""
|
| 154 |
+
assert x.shape[axis] >= window_size, (
|
| 155 |
+
f"kernel_size ({window_size}) is larger than axis_size ({x.shape[axis]})"
|
| 156 |
+
)
|
| 157 |
+
axis = axis % x.ndim
|
| 158 |
+
shape = (
|
| 159 |
+
*x.shape[:axis],
|
| 160 |
+
(x.shape[axis] - window_size + 1) // stride,
|
| 161 |
+
*x.shape[axis + 1 :],
|
| 162 |
+
window_size,
|
| 163 |
+
)
|
| 164 |
+
strides = (
|
| 165 |
+
*x.strides[:axis],
|
| 166 |
+
stride * x.strides[axis],
|
| 167 |
+
*x.strides[axis + 1 :],
|
| 168 |
+
x.strides[axis],
|
| 169 |
+
)
|
| 170 |
+
x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
|
| 171 |
+
return x_sliding
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def sliding_window_nd(
|
| 175 |
+
x: np.ndarray,
|
| 176 |
+
window_size: Tuple[int, ...],
|
| 177 |
+
stride: Tuple[int, ...],
|
| 178 |
+
axis: Tuple[int, ...],
|
| 179 |
+
) -> np.ndarray:
|
| 180 |
+
"""
|
| 181 |
+
Create sliding windows along multiple dimensions of the input array.
|
| 182 |
+
|
| 183 |
+
This function applies sliding_window_1d sequentially along multiple axes to create
|
| 184 |
+
N-dimensional sliding windows. This is useful for operations that need to examine
|
| 185 |
+
local neighborhoods in multiple dimensions simultaneously.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
x (np.ndarray): Input array
|
| 189 |
+
window_size (Tuple[int, ...]): Size of the sliding window for each axis
|
| 190 |
+
stride (Tuple[int, ...]): Stride of the sliding window for each axis
|
| 191 |
+
axis (Tuple[int, ...]): Axes to perform sliding window over
|
| 192 |
+
|
| 193 |
+
Returns:
|
| 194 |
+
np.ndarray: Array with sliding windows along the specified dimensions.
|
| 195 |
+
The window dimensions are appended to the end of the shape.
|
| 196 |
+
|
| 197 |
+
Note:
|
| 198 |
+
The length of window_size, stride, and axis tuples must be equal.
|
| 199 |
+
|
| 200 |
+
Example:
|
| 201 |
+
>>> x = np.random.rand(10, 10)
|
| 202 |
+
>>> windows = sliding_window_nd(x, window_size=(3, 3), stride=(2, 2), axis=(-2, -1))
|
| 203 |
+
>>> # Creates 3x3 sliding windows with stride 2 in both dimensions
|
| 204 |
+
"""
|
| 205 |
+
axis = [axis[i] % x.ndim for i in range(len(axis))]
|
| 206 |
+
for i in range(len(axis)):
|
| 207 |
+
x = sliding_window_1d(x, window_size[i], stride[i], axis[i])
|
| 208 |
+
return x
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def sliding_window_2d(
|
| 212 |
+
x: np.ndarray,
|
| 213 |
+
window_size: Union[int, Tuple[int, int]],
|
| 214 |
+
stride: Union[int, Tuple[int, int]],
|
| 215 |
+
axis: Tuple[int, int] = (-2, -1),
|
| 216 |
+
) -> np.ndarray:
|
| 217 |
+
"""
|
| 218 |
+
Create 2D sliding windows over the input array.
|
| 219 |
+
|
| 220 |
+
Convenience function for creating 2D sliding windows, commonly used for image
|
| 221 |
+
processing operations like convolution, pooling, or patch extraction.
|
| 222 |
+
|
| 223 |
+
Args:
|
| 224 |
+
x (np.ndarray): Input array
|
| 225 |
+
window_size (Union[int, Tuple[int, int]]): Size of the 2D sliding window.
|
| 226 |
+
If int, same size is used for both dimensions.
|
| 227 |
+
stride (Union[int, Tuple[int, int]]): Stride of the 2D sliding window.
|
| 228 |
+
If int, same stride is used for both dimensions.
|
| 229 |
+
axis (Tuple[int, int], optional): Two axes to perform sliding window over.
|
| 230 |
+
Defaults to (-2, -1) (last two dimensions).
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
np.ndarray: Array with 2D sliding windows. The window dimensions (height, width)
|
| 234 |
+
are appended to the end of the shape.
|
| 235 |
+
|
| 236 |
+
Example:
|
| 237 |
+
>>> image = np.random.rand(100, 100)
|
| 238 |
+
>>> patches = sliding_window_2d(image, window_size=8, stride=4)
|
| 239 |
+
>>> # Creates 8x8 patches with stride 4 from the image
|
| 240 |
+
"""
|
| 241 |
+
if isinstance(window_size, int):
|
| 242 |
+
window_size = (window_size, window_size)
|
| 243 |
+
if isinstance(stride, int):
|
| 244 |
+
stride = (stride, stride)
|
| 245 |
+
return sliding_window_nd(x, window_size, stride, axis)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def max_pool_1d(
|
| 249 |
+
x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1
|
| 250 |
+
):
|
| 251 |
+
"""
|
| 252 |
+
Perform 1D max pooling on the input array.
|
| 253 |
+
|
| 254 |
+
Max pooling reduces the dimensionality of the input by taking the maximum value
|
| 255 |
+
within each sliding window. This is commonly used in neural networks and signal
|
| 256 |
+
processing for downsampling and feature extraction.
|
| 257 |
+
|
| 258 |
+
Args:
|
| 259 |
+
x (np.ndarray): Input array
|
| 260 |
+
kernel_size (int): Size of the pooling kernel
|
| 261 |
+
stride (int): Stride of the pooling operation
|
| 262 |
+
padding (int, optional): Amount of padding to add on both sides. Defaults to 0.
|
| 263 |
+
axis (int, optional): Axis to perform max pooling over. Defaults to -1.
|
| 264 |
+
|
| 265 |
+
Returns:
|
| 266 |
+
np.ndarray: Max pooled array with reduced size along the specified axis
|
| 267 |
+
|
| 268 |
+
Note:
|
| 269 |
+
- For floating point arrays, padding is done with np.nan values
|
| 270 |
+
- For integer arrays, padding is done with the minimum value of the dtype
|
| 271 |
+
- np.nanmax is used to handle NaN values in the computation
|
| 272 |
+
|
| 273 |
+
Example:
|
| 274 |
+
>>> x = np.array([1, 3, 2, 4, 5, 1, 2])
|
| 275 |
+
>>> max_pool_1d(x, kernel_size=3, stride=2)
|
| 276 |
+
array([3, 5, 2])
|
| 277 |
+
"""
|
| 278 |
+
axis = axis % x.ndim
|
| 279 |
+
if padding > 0:
|
| 280 |
+
fill_value = np.nan if x.dtype.kind == "f" else np.iinfo(x.dtype).min
|
| 281 |
+
padding_arr = np.full(
|
| 282 |
+
(*x.shape[:axis], padding, *x.shape[axis + 1 :]),
|
| 283 |
+
fill_value=fill_value,
|
| 284 |
+
dtype=x.dtype,
|
| 285 |
+
)
|
| 286 |
+
x = np.concatenate([padding_arr, x, padding_arr], axis=axis)
|
| 287 |
+
a_sliding = sliding_window_1d(x, kernel_size, stride, axis)
|
| 288 |
+
max_pool = np.nanmax(a_sliding, axis=-1)
|
| 289 |
+
return max_pool
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def max_pool_nd(
|
| 293 |
+
x: np.ndarray,
|
| 294 |
+
kernel_size: Tuple[int, ...],
|
| 295 |
+
stride: Tuple[int, ...],
|
| 296 |
+
padding: Tuple[int, ...],
|
| 297 |
+
axis: Tuple[int, ...],
|
| 298 |
+
) -> np.ndarray:
|
| 299 |
+
"""
|
| 300 |
+
Perform N-dimensional max pooling on the input array.
|
| 301 |
+
|
| 302 |
+
This function applies max_pool_1d sequentially along multiple axes to perform
|
| 303 |
+
multi-dimensional max pooling. This is useful for downsampling multi-dimensional
|
| 304 |
+
data while preserving the most important features.
|
| 305 |
+
|
| 306 |
+
Args:
|
| 307 |
+
x (np.ndarray): Input array
|
| 308 |
+
kernel_size (Tuple[int, ...]): Size of the pooling kernel for each axis
|
| 309 |
+
stride (Tuple[int, ...]): Stride of the pooling operation for each axis
|
| 310 |
+
padding (Tuple[int, ...]): Amount of padding for each axis
|
| 311 |
+
axis (Tuple[int, ...]): Axes to perform max pooling over
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
np.ndarray: Max pooled array with reduced size along the specified axes
|
| 315 |
+
|
| 316 |
+
Note:
|
| 317 |
+
The length of kernel_size, stride, padding, and axis tuples must be equal.
|
| 318 |
+
Max pooling is applied sequentially along each axis in the order specified.
|
| 319 |
+
|
| 320 |
+
Example:
|
| 321 |
+
>>> x = np.random.rand(10, 10, 10)
|
| 322 |
+
>>> pooled = max_pool_nd(x, kernel_size=(2, 2, 2), stride=(2, 2, 2),
|
| 323 |
+
... padding=(0, 0, 0), axis=(-3, -2, -1))
|
| 324 |
+
>>> # Reduces each dimension by half with 2x2x2 max pooling
|
| 325 |
+
"""
|
| 326 |
+
for i in range(len(axis)):
|
| 327 |
+
x = max_pool_1d(x, kernel_size[i], stride[i], padding[i], axis[i])
|
| 328 |
+
return x
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def max_pool_2d(
|
| 332 |
+
x: np.ndarray,
|
| 333 |
+
kernel_size: Union[int, Tuple[int, int]],
|
| 334 |
+
stride: Union[int, Tuple[int, int]],
|
| 335 |
+
padding: Union[int, Tuple[int, int]],
|
| 336 |
+
axis: Tuple[int, int] = (-2, -1),
|
| 337 |
+
):
|
| 338 |
+
"""
|
| 339 |
+
Perform 2D max pooling on the input array.
|
| 340 |
+
|
| 341 |
+
Convenience function for 2D max pooling, commonly used in computer vision
|
| 342 |
+
and image processing for downsampling images while preserving important features.
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
x (np.ndarray): Input array
|
| 346 |
+
kernel_size (Union[int, Tuple[int, int]]): Size of the 2D pooling kernel.
|
| 347 |
+
If int, same size is used for both dimensions.
|
| 348 |
+
stride (Union[int, Tuple[int, int]]): Stride of the 2D pooling operation.
|
| 349 |
+
If int, same stride is used for both dimensions.
|
| 350 |
+
padding (Union[int, Tuple[int, int]]): Amount of padding for both dimensions.
|
| 351 |
+
If int, same padding is used for both dimensions.
|
| 352 |
+
axis (Tuple[int, int], optional): Two axes to perform max pooling over.
|
| 353 |
+
Defaults to (-2, -1) (last two dimensions).
|
| 354 |
+
|
| 355 |
+
Returns:
|
| 356 |
+
np.ndarray: 2D max pooled array with reduced size along the specified axes
|
| 357 |
+
|
| 358 |
+
Example:
|
| 359 |
+
>>> image = np.random.rand(64, 64)
|
| 360 |
+
>>> pooled = max_pool_2d(image, kernel_size=2, stride=2, padding=0)
|
| 361 |
+
>>> # Reduces image size from 64x64 to 32x32 with 2x2 max pooling
|
| 362 |
+
"""
|
| 363 |
+
if isinstance(kernel_size, Number):
|
| 364 |
+
kernel_size = (kernel_size, kernel_size)
|
| 365 |
+
if isinstance(stride, Number):
|
| 366 |
+
stride = (stride, stride)
|
| 367 |
+
if isinstance(padding, Number):
|
| 368 |
+
padding = (padding, padding)
|
| 369 |
+
axis = tuple(axis)
|
| 370 |
+
return max_pool_nd(x, kernel_size, stride, padding, axis)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
@no_warnings(category=RuntimeWarning)
|
| 374 |
+
def depth_edge(
|
| 375 |
+
depth: np.ndarray,
|
| 376 |
+
atol: float = None,
|
| 377 |
+
rtol: float = None,
|
| 378 |
+
kernel_size: int = 3,
|
| 379 |
+
mask: np.ndarray = None,
|
| 380 |
+
) -> np.ndarray:
|
| 381 |
+
"""
|
| 382 |
+
Compute the edge mask from depth map. The edge is defined as the pixels whose neighbors have large difference in depth.
|
| 383 |
+
|
| 384 |
+
Args:
|
| 385 |
+
depth (np.ndarray): shape (..., height, width), linear depth map
|
| 386 |
+
atol (float): absolute tolerance
|
| 387 |
+
rtol (float): relative tolerance
|
| 388 |
+
|
| 389 |
+
Returns:
|
| 390 |
+
edge (np.ndarray): shape (..., height, width) of dtype torch.bool
|
| 391 |
+
"""
|
| 392 |
+
if mask is None:
|
| 393 |
+
diff = max_pool_2d(
|
| 394 |
+
depth, kernel_size, stride=1, padding=kernel_size // 2
|
| 395 |
+
) + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)
|
| 396 |
+
else:
|
| 397 |
+
diff = max_pool_2d(
|
| 398 |
+
np.where(mask, depth, -np.inf),
|
| 399 |
+
kernel_size,
|
| 400 |
+
stride=1,
|
| 401 |
+
padding=kernel_size // 2,
|
| 402 |
+
) + max_pool_2d(
|
| 403 |
+
np.where(mask, -depth, -np.inf),
|
| 404 |
+
kernel_size,
|
| 405 |
+
stride=1,
|
| 406 |
+
padding=kernel_size // 2,
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
edge = np.zeros_like(depth, dtype=bool)
|
| 410 |
+
if atol is not None:
|
| 411 |
+
edge |= diff > atol
|
| 412 |
+
|
| 413 |
+
if rtol is not None:
|
| 414 |
+
edge |= diff / depth > rtol
|
| 415 |
+
return edge
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def depth_aliasing(
|
| 419 |
+
depth: np.ndarray,
|
| 420 |
+
atol: float = None,
|
| 421 |
+
rtol: float = None,
|
| 422 |
+
kernel_size: int = 3,
|
| 423 |
+
mask: np.ndarray = None,
|
| 424 |
+
) -> np.ndarray:
|
| 425 |
+
"""
|
| 426 |
+
Compute the map that indicates the aliasing of x depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors.
|
| 427 |
+
Args:
|
| 428 |
+
depth (np.ndarray): shape (..., height, width), linear depth map
|
| 429 |
+
atol (float): absolute tolerance
|
| 430 |
+
rtol (float): relative tolerance
|
| 431 |
+
|
| 432 |
+
Returns:
|
| 433 |
+
edge (np.ndarray): shape (..., height, width) of dtype torch.bool
|
| 434 |
+
"""
|
| 435 |
+
if mask is None:
|
| 436 |
+
diff_max = (
|
| 437 |
+
max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth
|
| 438 |
+
)
|
| 439 |
+
diff_min = (
|
| 440 |
+
max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth
|
| 441 |
+
)
|
| 442 |
+
else:
|
| 443 |
+
diff_max = (
|
| 444 |
+
max_pool_2d(
|
| 445 |
+
np.where(mask, depth, -np.inf),
|
| 446 |
+
kernel_size,
|
| 447 |
+
stride=1,
|
| 448 |
+
padding=kernel_size // 2,
|
| 449 |
+
)
|
| 450 |
+
- depth
|
| 451 |
+
)
|
| 452 |
+
diff_min = (
|
| 453 |
+
max_pool_2d(
|
| 454 |
+
np.where(mask, -depth, -np.inf),
|
| 455 |
+
kernel_size,
|
| 456 |
+
stride=1,
|
| 457 |
+
padding=kernel_size // 2,
|
| 458 |
+
)
|
| 459 |
+
+ depth
|
| 460 |
+
)
|
| 461 |
+
diff = np.minimum(diff_max, diff_min)
|
| 462 |
+
|
| 463 |
+
edge = np.zeros_like(depth, dtype=bool)
|
| 464 |
+
if atol is not None:
|
| 465 |
+
edge |= diff > atol
|
| 466 |
+
if rtol is not None:
|
| 467 |
+
edge |= diff / depth > rtol
|
| 468 |
+
return edge
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
@no_warnings(category=RuntimeWarning)
|
| 472 |
+
def normals_edge(
|
| 473 |
+
normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None
|
| 474 |
+
) -> np.ndarray:
|
| 475 |
+
"""
|
| 476 |
+
Compute the edge mask from normal map.
|
| 477 |
+
|
| 478 |
+
Args:
|
| 479 |
+
normal (np.ndarray): shape (..., height, width, 3), normal map
|
| 480 |
+
tol (float): tolerance in degrees
|
| 481 |
+
|
| 482 |
+
Returns:
|
| 483 |
+
edge (np.ndarray): shape (..., height, width) of dtype torch.bool
|
| 484 |
+
"""
|
| 485 |
+
assert normals.ndim >= 3 and normals.shape[-1] == 3, (
|
| 486 |
+
"normal should be of shape (..., height, width, 3)"
|
| 487 |
+
)
|
| 488 |
+
normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12)
|
| 489 |
+
|
| 490 |
+
padding = kernel_size // 2
|
| 491 |
+
normals_window = sliding_window_2d(
|
| 492 |
+
np.pad(
|
| 493 |
+
normals,
|
| 494 |
+
(
|
| 495 |
+
*([(0, 0)] * (normals.ndim - 3)),
|
| 496 |
+
(padding, padding),
|
| 497 |
+
(padding, padding),
|
| 498 |
+
(0, 0),
|
| 499 |
+
),
|
| 500 |
+
mode="edge",
|
| 501 |
+
),
|
| 502 |
+
window_size=kernel_size,
|
| 503 |
+
stride=1,
|
| 504 |
+
axis=(-3, -2),
|
| 505 |
+
)
|
| 506 |
+
if mask is None:
|
| 507 |
+
angle_diff = np.arccos(
|
| 508 |
+
(normals[..., None, None] * normals_window).sum(axis=-3)
|
| 509 |
+
).max(axis=(-2, -1))
|
| 510 |
+
else:
|
| 511 |
+
mask_window = sliding_window_2d(
|
| 512 |
+
np.pad(
|
| 513 |
+
mask,
|
| 514 |
+
(*([(0, 0)] * (mask.ndim - 3)), (padding, padding), (padding, padding)),
|
| 515 |
+
mode="edge",
|
| 516 |
+
),
|
| 517 |
+
window_size=kernel_size,
|
| 518 |
+
stride=1,
|
| 519 |
+
axis=(-3, -2),
|
| 520 |
+
)
|
| 521 |
+
angle_diff = np.where(
|
| 522 |
+
mask_window,
|
| 523 |
+
np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)),
|
| 524 |
+
0,
|
| 525 |
+
).max(axis=(-2, -1))
|
| 526 |
+
|
| 527 |
+
angle_diff = max_pool_2d(
|
| 528 |
+
angle_diff, kernel_size, stride=1, padding=kernel_size // 2
|
| 529 |
+
)
|
| 530 |
+
edge = angle_diff > np.deg2rad(tol)
|
| 531 |
+
return edge
|
hyworldmirror/utils/inference_utils.py
ADDED
|
@@ -0,0 +1,824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference utilities for WorldMirror pipeline.
|
| 3 |
+
|
| 4 |
+
Includes: image preprocessing, input preparation, prior loading, mask computation,
|
| 5 |
+
result saving, and timing utilities.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import glob
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import time
|
| 12 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import cv2
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
from PIL import Image
|
| 19 |
+
from torchvision import transforms
|
| 20 |
+
|
| 21 |
+
from ..models.utils.camera_utils import vector_to_camera_matrices
|
| 22 |
+
from ..models.utils.geometry import depth_to_world_coords_points
|
| 23 |
+
from .save_utils import (
|
| 24 |
+
save_depth_png, save_depth_npy, save_normal_png,
|
| 25 |
+
save_gs_ply, save_points_ply, save_camera_params,
|
| 26 |
+
)
|
| 27 |
+
from .video_utils import video_to_image_frames, video_to_image_frames_new
|
| 28 |
+
from .visual_util import segment_sky, download_file_from_url
|
| 29 |
+
from .geometry import depth_edge, normals_edge
|
| 30 |
+
|
| 31 |
+
_IO_WORKERS = 8
|
| 32 |
+
|
| 33 |
+
# ============================================================
|
| 34 |
+
# Image Preprocessing
|
| 35 |
+
# ============================================================
|
| 36 |
+
|
| 37 |
+
def _handle_alpha_channel(img_data):
|
| 38 |
+
"""Process RGBA images by blending with white background."""
|
| 39 |
+
if img_data.mode == "RGBA":
|
| 40 |
+
white_bg = Image.new("RGBA", img_data.size, (255, 255, 255, 255))
|
| 41 |
+
img_data = Image.alpha_composite(white_bg, img_data)
|
| 42 |
+
return img_data.convert("RGB")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _calculate_resize_dims(orig_w, orig_h, max_dim, resize_strategy, patch_size=14):
|
| 46 |
+
"""Calculate new dimensions based on resize strategy."""
|
| 47 |
+
if orig_w >= orig_h:
|
| 48 |
+
new_w = max_dim
|
| 49 |
+
new_h = round(orig_h * (new_w / orig_w) / patch_size) * patch_size
|
| 50 |
+
else:
|
| 51 |
+
new_h = max_dim
|
| 52 |
+
new_w = round(orig_w * (new_h / orig_h) / patch_size) * patch_size
|
| 53 |
+
return new_w, new_h
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _apply_padding(tensor_img, target_dim):
|
| 57 |
+
"""Apply padding to make tensor square."""
|
| 58 |
+
h_pad = target_dim - tensor_img.shape[1]
|
| 59 |
+
w_pad = target_dim - tensor_img.shape[2]
|
| 60 |
+
if h_pad > 0 or w_pad > 0:
|
| 61 |
+
pad_top, pad_bottom = h_pad // 2, h_pad - h_pad // 2
|
| 62 |
+
pad_left, pad_right = w_pad // 2, w_pad - w_pad // 2
|
| 63 |
+
return torch.nn.functional.pad(
|
| 64 |
+
tensor_img, (pad_left, pad_right, pad_top, pad_bottom),
|
| 65 |
+
mode="constant", value=1.0,
|
| 66 |
+
)
|
| 67 |
+
return tensor_img
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def prepare_images_to_tensor(file_paths, resize_strategy="crop", target_size=518):
|
| 71 |
+
"""Process image files into uniform tensor batch [1, N, 3, H, W]."""
|
| 72 |
+
if not file_paths:
|
| 73 |
+
raise ValueError("At least 1 image is required")
|
| 74 |
+
if resize_strategy not in ["crop", "pad"]:
|
| 75 |
+
raise ValueError("Strategy must be 'crop' or 'pad'")
|
| 76 |
+
|
| 77 |
+
tensor_list = []
|
| 78 |
+
converter = transforms.ToTensor()
|
| 79 |
+
|
| 80 |
+
for file_path in file_paths:
|
| 81 |
+
img_data = Image.open(file_path)
|
| 82 |
+
img_data = _handle_alpha_channel(img_data)
|
| 83 |
+
orig_w, orig_h = img_data.size
|
| 84 |
+
new_w, new_h = _calculate_resize_dims(orig_w, orig_h, target_size, resize_strategy)
|
| 85 |
+
|
| 86 |
+
img_data = img_data.resize((new_w, new_h), Image.Resampling.BICUBIC)
|
| 87 |
+
tensor_img = converter(img_data)
|
| 88 |
+
|
| 89 |
+
if resize_strategy == "crop":
|
| 90 |
+
if new_h > target_size:
|
| 91 |
+
crop_start = (new_h - target_size) // 2
|
| 92 |
+
tensor_img = tensor_img[:, crop_start:crop_start + target_size, :]
|
| 93 |
+
if new_w > target_size:
|
| 94 |
+
crop_start = (new_w - target_size) // 2
|
| 95 |
+
tensor_img = tensor_img[:, :, crop_start:crop_start + target_size]
|
| 96 |
+
elif resize_strategy == "pad":
|
| 97 |
+
tensor_img = _apply_padding(tensor_img, target_size)
|
| 98 |
+
|
| 99 |
+
tensor_list.append(tensor_img)
|
| 100 |
+
|
| 101 |
+
shapes = set((t.shape[1], t.shape[2]) for t in tensor_list)
|
| 102 |
+
if len(shapes) > 1:
|
| 103 |
+
raise ValueError(
|
| 104 |
+
f"Inconsistent resolutions after preprocessing: {shapes}. "
|
| 105 |
+
f"All input images must have the same aspect ratio."
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
batch_tensor = torch.stack(tensor_list)
|
| 109 |
+
if batch_tensor.dim() == 3:
|
| 110 |
+
batch_tensor = batch_tensor.unsqueeze(0)
|
| 111 |
+
return batch_tensor.unsqueeze(0)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# ============================================================
|
| 115 |
+
# Input Preparation
|
| 116 |
+
# ============================================================
|
| 117 |
+
|
| 118 |
+
def prepare_input(input_path, target_size=518, fps=1,
|
| 119 |
+
video_strategy="new", min_frames=1, max_frames=64,
|
| 120 |
+
temp_dir=None):
|
| 121 |
+
"""Read images or extract video frames. Returns (img_paths, subdir_name)."""
|
| 122 |
+
input_path = Path(input_path)
|
| 123 |
+
video_exts = ['.mp4', '.avi', '.mov', '.webm', '.gif']
|
| 124 |
+
|
| 125 |
+
if input_path.is_file() and input_path.suffix.lower() in video_exts:
|
| 126 |
+
subdir_name = input_path.stem
|
| 127 |
+
frames_dir = Path(temp_dir or "/tmp") / f"frames_{subdir_name}"
|
| 128 |
+
frames_dir.mkdir(parents=True, exist_ok=True)
|
| 129 |
+
min_f = max(1, min_frames)
|
| 130 |
+
max_f = min(64, max_frames)
|
| 131 |
+
if video_strategy == "new":
|
| 132 |
+
img_paths = video_to_image_frames_new(
|
| 133 |
+
str(input_path), str(frames_dir),
|
| 134 |
+
min_frames=min_f, max_frames=max_f, fallback_fps=fps,
|
| 135 |
+
)
|
| 136 |
+
else:
|
| 137 |
+
img_paths = video_to_image_frames(str(input_path), str(frames_dir), fps=fps)
|
| 138 |
+
if len(img_paths) > max_f:
|
| 139 |
+
indices = np.linspace(0, len(img_paths) - 1, max_f, dtype=int)
|
| 140 |
+
img_paths = [img_paths[i] for i in indices]
|
| 141 |
+
if not img_paths:
|
| 142 |
+
raise RuntimeError(f"Failed to extract frames from {input_path}")
|
| 143 |
+
img_paths = sorted(img_paths)
|
| 144 |
+
print(f"[Input] Extracted {len(img_paths)} frames from video: {input_path}")
|
| 145 |
+
elif input_path.is_dir():
|
| 146 |
+
subdir_name = input_path.name
|
| 147 |
+
img_paths = []
|
| 148 |
+
for ext in ["*.jpeg", "*.jpg", "*.png", "*.webp"]:
|
| 149 |
+
img_paths.extend(glob.glob(os.path.join(str(input_path), ext)))
|
| 150 |
+
if not img_paths:
|
| 151 |
+
raise FileNotFoundError(f"No images found in {input_path}")
|
| 152 |
+
img_paths = sorted(img_paths)
|
| 153 |
+
print(f"[Input] Loaded {len(img_paths)} images from: {input_path}")
|
| 154 |
+
elif input_path.is_file() and input_path.suffix.lower() in ['.jpg', '.jpeg', '.png', '.webp']:
|
| 155 |
+
subdir_name = input_path.stem
|
| 156 |
+
img_paths = [str(input_path)]
|
| 157 |
+
print(f"[Input] Single image input: {input_path}")
|
| 158 |
+
else:
|
| 159 |
+
raise ValueError(f"Invalid input path: {input_path}")
|
| 160 |
+
|
| 161 |
+
return img_paths, subdir_name
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def compute_adaptive_target_size(img_paths, max_target_size=518, patch_size=14):
|
| 165 |
+
"""Compute inference resolution = min(image_longest_edge, max_target_size).
|
| 166 |
+
|
| 167 |
+
Rounds down to nearest multiple of patch_size. Avoids upsampling small images.
|
| 168 |
+
"""
|
| 169 |
+
first_img = Image.open(img_paths[0])
|
| 170 |
+
orig_w, orig_h = first_img.size
|
| 171 |
+
longest_edge = max(orig_w, orig_h)
|
| 172 |
+
effective = min(longest_edge, max_target_size)
|
| 173 |
+
effective = (effective // patch_size) * patch_size
|
| 174 |
+
return max(effective, patch_size * 2)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# ============================================================
|
| 178 |
+
# Prior Loading
|
| 179 |
+
# ============================================================
|
| 180 |
+
|
| 181 |
+
def compute_preprocessing_transform(img_paths, target_size, patch_size=14):
|
| 182 |
+
"""Compute the resize + center-crop transform applied by prepare_images_to_tensor.
|
| 183 |
+
|
| 184 |
+
Returns dict with orig/new/final sizes and scale/crop parameters.
|
| 185 |
+
"""
|
| 186 |
+
first_img = Image.open(img_paths[0])
|
| 187 |
+
orig_w, orig_h = first_img.size
|
| 188 |
+
new_w, new_h = _calculate_resize_dims(orig_w, orig_h, target_size, "crop", patch_size)
|
| 189 |
+
|
| 190 |
+
crop_y = (new_h - target_size) // 2 if new_h > target_size else 0
|
| 191 |
+
crop_x = (new_w - target_size) // 2 if new_w > target_size else 0
|
| 192 |
+
|
| 193 |
+
return {
|
| 194 |
+
"orig_w": orig_w, "orig_h": orig_h,
|
| 195 |
+
"new_w": new_w, "new_h": new_h,
|
| 196 |
+
"crop_x": crop_x, "crop_y": crop_y,
|
| 197 |
+
"final_w": min(new_w, target_size), "final_h": min(new_h, target_size),
|
| 198 |
+
"scale_x": new_w / orig_w, "scale_y": new_h / orig_h,
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def load_prior_camera(prior_cam_path, img_paths, preprocess_transform=None):
|
| 203 |
+
"""Load camera priors from JSON. Returns (extrinsics [1,N,4,4], intrinsics [1,N,3,3])."""
|
| 204 |
+
with open(prior_cam_path, "r") as f:
|
| 205 |
+
cam_data = json.load(f)
|
| 206 |
+
|
| 207 |
+
stem_to_idx = {Path(p).stem: i for i, p in enumerate(img_paths)}
|
| 208 |
+
N = len(img_paths)
|
| 209 |
+
|
| 210 |
+
extrinsics = None
|
| 211 |
+
extr_list = cam_data.get("extrinsics", [])
|
| 212 |
+
if extr_list:
|
| 213 |
+
extr_array = np.zeros((N, 4, 4), dtype=np.float32)
|
| 214 |
+
matched = 0
|
| 215 |
+
for entry in extr_list:
|
| 216 |
+
cam_id = str(entry["camera_id"])
|
| 217 |
+
idx = stem_to_idx.get(cam_id)
|
| 218 |
+
if idx is None and cam_id.isdigit() and int(cam_id) < N:
|
| 219 |
+
idx = int(cam_id)
|
| 220 |
+
if idx is not None:
|
| 221 |
+
extr_array[idx] = np.array(entry["matrix"], dtype=np.float32)
|
| 222 |
+
matched += 1
|
| 223 |
+
if matched == N:
|
| 224 |
+
extrinsics = torch.from_numpy(extr_array).unsqueeze(0)
|
| 225 |
+
print(f"[Prior] Loaded extrinsics for {matched}/{N} cameras")
|
| 226 |
+
else:
|
| 227 |
+
print(f"[Prior] Warning: extrinsics matched {matched}/{N}, disabling")
|
| 228 |
+
|
| 229 |
+
intrinsics = None
|
| 230 |
+
intr_list = cam_data.get("intrinsics", [])
|
| 231 |
+
if intr_list:
|
| 232 |
+
intr_array = np.zeros((N, 3, 3), dtype=np.float32)
|
| 233 |
+
matched = 0
|
| 234 |
+
for entry in intr_list:
|
| 235 |
+
cam_id = str(entry["camera_id"])
|
| 236 |
+
idx = stem_to_idx.get(cam_id)
|
| 237 |
+
if idx is None and cam_id.isdigit() and int(cam_id) < N:
|
| 238 |
+
idx = int(cam_id)
|
| 239 |
+
if idx is not None:
|
| 240 |
+
intr_array[idx] = np.array(entry["matrix"], dtype=np.float32)
|
| 241 |
+
matched += 1
|
| 242 |
+
if matched == N:
|
| 243 |
+
intrinsics = torch.from_numpy(intr_array).unsqueeze(0)
|
| 244 |
+
print(f"[Prior] Loaded intrinsics for {matched}/{N} cameras")
|
| 245 |
+
else:
|
| 246 |
+
print(f"[Prior] Warning: intrinsics matched {matched}/{N}, disabling")
|
| 247 |
+
|
| 248 |
+
if intrinsics is not None and preprocess_transform is not None:
|
| 249 |
+
sx, sy = preprocess_transform["scale_x"], preprocess_transform["scale_y"]
|
| 250 |
+
cx_off, cy_off = preprocess_transform["crop_x"], preprocess_transform["crop_y"]
|
| 251 |
+
intrinsics = intrinsics.clone()
|
| 252 |
+
intrinsics[:, :, 0, :] *= sx
|
| 253 |
+
intrinsics[:, :, 1, :] *= sy
|
| 254 |
+
intrinsics[:, :, 0, 2] -= cx_off
|
| 255 |
+
intrinsics[:, :, 1, 2] -= cy_off
|
| 256 |
+
|
| 257 |
+
return extrinsics, intrinsics
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def _read_depth_file(depth_path):
|
| 261 |
+
"""Read a single depth file (.npy, .exr, .png). Returns float32 [H, W]."""
|
| 262 |
+
ext = Path(depth_path).suffix.lower()
|
| 263 |
+
if ext == ".npy":
|
| 264 |
+
depthmap = np.load(depth_path).astype(np.float32)
|
| 265 |
+
if depthmap.ndim == 3:
|
| 266 |
+
depthmap = depthmap[:, :, 0]
|
| 267 |
+
elif ext == ".exr":
|
| 268 |
+
depthmap = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
| 269 |
+
if depthmap.ndim == 3:
|
| 270 |
+
depthmap = depthmap[:, :, 0]
|
| 271 |
+
elif ext == ".png":
|
| 272 |
+
depthmap = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
|
| 273 |
+
if depthmap is None:
|
| 274 |
+
raise FileNotFoundError(f"Cannot read depth PNG: {depth_path}")
|
| 275 |
+
depthmap = depthmap.astype(np.float32)
|
| 276 |
+
if depthmap.ndim == 3:
|
| 277 |
+
depthmap = depthmap[:, :, 0]
|
| 278 |
+
if depthmap.max() > 255:
|
| 279 |
+
depthmap = depthmap / 1000.0
|
| 280 |
+
else:
|
| 281 |
+
raise ValueError(f"Unsupported depth format: {ext}")
|
| 282 |
+
return np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def load_prior_depth(prior_depth_path, img_paths, target_h, target_w,
|
| 286 |
+
preprocess_transform=None):
|
| 287 |
+
"""Load depth priors from a folder. Returns [1, N, H, W] or None."""
|
| 288 |
+
depth_dir = Path(prior_depth_path)
|
| 289 |
+
if not depth_dir.is_dir():
|
| 290 |
+
return None
|
| 291 |
+
|
| 292 |
+
depth_files = {}
|
| 293 |
+
for f in sorted(depth_dir.iterdir()):
|
| 294 |
+
if f.suffix.lower() in (".npy", ".exr", ".png"):
|
| 295 |
+
if f.stem not in depth_files or f.suffix.lower() == ".npy":
|
| 296 |
+
depth_files[f.stem] = str(f)
|
| 297 |
+
|
| 298 |
+
N = len(img_paths)
|
| 299 |
+
depth_maps = []
|
| 300 |
+
for img_p in img_paths:
|
| 301 |
+
img_stem = Path(img_p).stem
|
| 302 |
+
dpath = depth_files.get(img_stem)
|
| 303 |
+
if dpath is None:
|
| 304 |
+
img_nums = ''.join(filter(str.isdigit, img_stem))
|
| 305 |
+
for dstem, dc in depth_files.items():
|
| 306 |
+
if img_nums and img_nums == ''.join(filter(str.isdigit, dstem)):
|
| 307 |
+
dpath = dc
|
| 308 |
+
break
|
| 309 |
+
if dpath is None:
|
| 310 |
+
return None
|
| 311 |
+
|
| 312 |
+
depthmap = _read_depth_file(dpath)
|
| 313 |
+
if preprocess_transform is not None:
|
| 314 |
+
nw, nh = preprocess_transform["new_w"], preprocess_transform["new_h"]
|
| 315 |
+
cx, cy = preprocess_transform["crop_x"], preprocess_transform["crop_y"]
|
| 316 |
+
fw, fh = preprocess_transform["final_w"], preprocess_transform["final_h"]
|
| 317 |
+
if depthmap.shape[:2] != (nh, nw):
|
| 318 |
+
depthmap = cv2.resize(depthmap, (nw, nh), interpolation=cv2.INTER_LINEAR)
|
| 319 |
+
depthmap = depthmap[cy:cy + fh, cx:cx + fw]
|
| 320 |
+
else:
|
| 321 |
+
if depthmap.shape[:2] != (target_h, target_w):
|
| 322 |
+
depthmap = cv2.resize(depthmap, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
|
| 323 |
+
depth_maps.append(depthmap)
|
| 324 |
+
|
| 325 |
+
depth_tensor = torch.from_numpy(np.stack(depth_maps, axis=0)).unsqueeze(0)
|
| 326 |
+
print(f"[Prior] Loaded {N} depth maps from {prior_depth_path}")
|
| 327 |
+
return depth_tensor
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
# ============================================================
|
| 331 |
+
# Mask Computation
|
| 332 |
+
# ============================================================
|
| 333 |
+
|
| 334 |
+
def create_filter_mask(
|
| 335 |
+
pts3d_conf, depth_preds, normal_preds, sky_mask,
|
| 336 |
+
confidence_percentile=10.0, edge_normal_threshold=5.0,
|
| 337 |
+
edge_depth_threshold=0.03, apply_confidence_mask=True,
|
| 338 |
+
apply_edge_mask=True, apply_sky_mask=False, gs_depth_preds=None,
|
| 339 |
+
):
|
| 340 |
+
"""Create filter mask based on confidence, edges, and sky segmentation.
|
| 341 |
+
|
| 342 |
+
Returns pts_mask [S,H,W] or (pts_mask, gs_mask) tuple if gs_depth_preds given.
|
| 343 |
+
"""
|
| 344 |
+
S, H, W = pts3d_conf.shape[:3]
|
| 345 |
+
final_mask_list = []
|
| 346 |
+
gs_mask_list = [] if gs_depth_preds is not None else None
|
| 347 |
+
|
| 348 |
+
for i in range(S):
|
| 349 |
+
final_mask = None
|
| 350 |
+
if apply_confidence_mask:
|
| 351 |
+
threshold = np.quantile(pts3d_conf[i], confidence_percentile / 100.0)
|
| 352 |
+
conf_mask = pts3d_conf[i] >= threshold
|
| 353 |
+
final_mask = conf_mask if final_mask is None else final_mask & conf_mask
|
| 354 |
+
|
| 355 |
+
pre_edge_mask = final_mask
|
| 356 |
+
|
| 357 |
+
if apply_edge_mask:
|
| 358 |
+
n_edges = normals_edge(normal_preds[i], tol=edge_normal_threshold, mask=pre_edge_mask)
|
| 359 |
+
d_edges = depth_edge(depth_preds[i, :, :, 0], rtol=edge_depth_threshold, mask=pre_edge_mask)
|
| 360 |
+
edge_mask = ~(d_edges & n_edges)
|
| 361 |
+
final_mask = edge_mask if final_mask is None else final_mask & edge_mask
|
| 362 |
+
|
| 363 |
+
if gs_depth_preds is not None:
|
| 364 |
+
gs_d_edges = depth_edge(gs_depth_preds[i, :, :, 0], rtol=edge_depth_threshold, mask=pre_edge_mask)
|
| 365 |
+
gs_edge_mask = ~(gs_d_edges & n_edges)
|
| 366 |
+
gs_frame_mask = gs_edge_mask if pre_edge_mask is None else pre_edge_mask & gs_edge_mask
|
| 367 |
+
|
| 368 |
+
if apply_sky_mask:
|
| 369 |
+
final_mask = sky_mask[i] if final_mask is None else final_mask & sky_mask[i]
|
| 370 |
+
if gs_depth_preds is not None and apply_edge_mask:
|
| 371 |
+
gs_frame_mask = gs_frame_mask & sky_mask[i]
|
| 372 |
+
|
| 373 |
+
final_mask_list.append(final_mask)
|
| 374 |
+
if gs_mask_list is not None:
|
| 375 |
+
gs_mask_list.append(gs_frame_mask if apply_edge_mask else final_mask)
|
| 376 |
+
|
| 377 |
+
def _stack(ml):
|
| 378 |
+
return np.stack(ml, axis=0) if ml[0] is not None else np.ones((S, H, W), dtype=bool)
|
| 379 |
+
|
| 380 |
+
pts_mask = _stack(final_mask_list)
|
| 381 |
+
if gs_mask_list is not None:
|
| 382 |
+
return pts_mask, _stack(gs_mask_list)
|
| 383 |
+
return pts_mask
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _compute_sky_mask_from_model(predictions, H, W, S, threshold=0.5):
|
| 387 |
+
"""Build sky mask from model predictions. Returns [S,H,W] bool or None."""
|
| 388 |
+
for key in ("gs_depth_mask_logits", "gs_depth_mask", "depth_mask_logits", "depth_mask"):
|
| 389 |
+
if key in predictions:
|
| 390 |
+
prob = predictions[key].sigmoid() if "logits" in key else predictions[key]
|
| 391 |
+
dm = prob[0].detach().cpu()
|
| 392 |
+
if dm.dim() == 4 and dm.shape[-1] == 1:
|
| 393 |
+
dm = dm.squeeze(-1)
|
| 394 |
+
if dm.dim() != 3 or dm.shape[0] != S:
|
| 395 |
+
return None
|
| 396 |
+
mask = (dm > threshold).numpy().astype(bool)
|
| 397 |
+
if mask.shape[1] != H or mask.shape[2] != W:
|
| 398 |
+
mask = np.stack([cv2.resize(mask[i].astype(np.uint8), (W, H),
|
| 399 |
+
interpolation=cv2.INTER_NEAREST) > 0
|
| 400 |
+
for i in range(S)], axis=0)
|
| 401 |
+
return mask
|
| 402 |
+
return None
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def compute_sky_mask(img_paths, H, W, S, predictions=None, source="auto",
|
| 406 |
+
model_threshold=0.5, processed_aspect_ratio=None):
|
| 407 |
+
"""Compute sky segmentation mask [S,H,W] (True=non-sky, False=sky)."""
|
| 408 |
+
if source == "model":
|
| 409 |
+
mask = _compute_sky_mask_from_model(predictions, H, W, S, model_threshold) if predictions else None
|
| 410 |
+
return mask if mask is not None else np.ones((S, H, W), dtype=bool)
|
| 411 |
+
|
| 412 |
+
skyseg_path = "skyseg.onnx"
|
| 413 |
+
if not os.path.exists(skyseg_path):
|
| 414 |
+
download_file_from_url(
|
| 415 |
+
"https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx",
|
| 416 |
+
skyseg_path,
|
| 417 |
+
)
|
| 418 |
+
import onnxruntime
|
| 419 |
+
session = onnxruntime.InferenceSession(skyseg_path)
|
| 420 |
+
sky_list = []
|
| 421 |
+
for i in range(S):
|
| 422 |
+
if processed_aspect_ratio is not None:
|
| 423 |
+
pil_img = Image.open(img_paths[i]).convert("RGB")
|
| 424 |
+
sw, sh = pil_img.size
|
| 425 |
+
if sw / sh > processed_aspect_ratio:
|
| 426 |
+
cw = int(round(sh * processed_aspect_ratio))
|
| 427 |
+
ch = sh
|
| 428 |
+
else:
|
| 429 |
+
cw = sw
|
| 430 |
+
ch = int(round(sw / processed_aspect_ratio))
|
| 431 |
+
left, top = (sw - cw) // 2, (sh - ch) // 2
|
| 432 |
+
pil_img = pil_img.crop((left, top, left + cw, top + ch))
|
| 433 |
+
frame = segment_sky(cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR), session)
|
| 434 |
+
else:
|
| 435 |
+
frame = segment_sky(img_paths[i], session)
|
| 436 |
+
if frame.shape[:2] != (H, W):
|
| 437 |
+
frame = cv2.resize(frame, (W, H))
|
| 438 |
+
sky_list.append(frame)
|
| 439 |
+
|
| 440 |
+
sky_mask = np.stack(sky_list, axis=0) > 0
|
| 441 |
+
if source == "auto" and predictions is not None:
|
| 442 |
+
model_mask = _compute_sky_mask_from_model(predictions, H, W, S, model_threshold)
|
| 443 |
+
if model_mask is not None:
|
| 444 |
+
sky_mask = sky_mask & model_mask
|
| 445 |
+
return sky_mask
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def compute_filter_mask(predictions, imgs, img_paths, H, W, S,
|
| 449 |
+
apply_confidence_mask=False, apply_edge_mask=False,
|
| 450 |
+
apply_sky_mask=False, confidence_percentile=10.0,
|
| 451 |
+
edge_normal_threshold=5.0, edge_depth_threshold=0.03,
|
| 452 |
+
sky_mask=None, use_gs_depth=False):
|
| 453 |
+
"""Compute unified filter mask. Returns (filter_mask, gs_filter_mask) tuple."""
|
| 454 |
+
if not (apply_confidence_mask or apply_edge_mask or apply_sky_mask):
|
| 455 |
+
return np.ones((S, H, W), dtype=bool), None
|
| 456 |
+
|
| 457 |
+
if apply_sky_mask and sky_mask is None:
|
| 458 |
+
sky_mask = compute_sky_mask(img_paths, H, W, S, processed_aspect_ratio=W / H)
|
| 459 |
+
elif sky_mask is None:
|
| 460 |
+
sky_mask = np.ones((S, H, W), dtype=bool)
|
| 461 |
+
|
| 462 |
+
if "pts3d_conf" in predictions:
|
| 463 |
+
conf_np = predictions["pts3d_conf"][0].detach().cpu().float().numpy()
|
| 464 |
+
elif "depth_conf" in predictions:
|
| 465 |
+
conf_np = predictions["depth_conf"][0].detach().cpu().float().numpy()
|
| 466 |
+
else:
|
| 467 |
+
conf_np = np.ones((S, H, W), dtype=np.float32)
|
| 468 |
+
depth_np = predictions["depth"][0].detach().cpu().float().numpy()
|
| 469 |
+
normal_np = predictions["normals"][0].detach().cpu().float().numpy()
|
| 470 |
+
|
| 471 |
+
gs_depth_np = None
|
| 472 |
+
if use_gs_depth and "gs_depth" in predictions:
|
| 473 |
+
raw = predictions["gs_depth"][0].detach().cpu().float().numpy()
|
| 474 |
+
gs_depth_np = raw if raw.ndim == 4 else raw[..., np.newaxis]
|
| 475 |
+
|
| 476 |
+
result = create_filter_mask(
|
| 477 |
+
conf_np, depth_np, normal_np, sky_mask,
|
| 478 |
+
confidence_percentile=confidence_percentile,
|
| 479 |
+
edge_normal_threshold=edge_normal_threshold,
|
| 480 |
+
edge_depth_threshold=edge_depth_threshold,
|
| 481 |
+
apply_confidence_mask=apply_confidence_mask,
|
| 482 |
+
apply_edge_mask=apply_edge_mask,
|
| 483 |
+
apply_sky_mask=apply_sky_mask,
|
| 484 |
+
gs_depth_preds=gs_depth_np,
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
if gs_depth_np is not None:
|
| 488 |
+
pts_mask, gs_mask = result
|
| 489 |
+
total = pts_mask.size
|
| 490 |
+
print(f"[Mask] Filter: pts kept {pts_mask.sum()}/{total}, gs kept {gs_mask.sum()}/{total}")
|
| 491 |
+
return pts_mask, gs_mask
|
| 492 |
+
|
| 493 |
+
print(f"[Mask] Filter: kept {result.sum()}/{result.size} points")
|
| 494 |
+
return result, None
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
# ============================================================
|
| 498 |
+
# Save Utilities
|
| 499 |
+
# ============================================================
|
| 500 |
+
|
| 501 |
+
def _timed_call(func, *args, **kwargs):
|
| 502 |
+
t0 = time.perf_counter()
|
| 503 |
+
result = func(*args, **kwargs)
|
| 504 |
+
return result, time.perf_counter() - t0
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def _save_depth_parallel(depth_cpu, depth_dir, S):
|
| 508 |
+
def _save_one(i):
|
| 509 |
+
save_depth_png(depth_dir / f"depth_{i:04d}.png", depth_cpu[i, :, :, 0])
|
| 510 |
+
save_depth_npy(depth_dir / f"depth_{i:04d}.npy", depth_cpu[i, :, :, 0])
|
| 511 |
+
with ThreadPoolExecutor(max_workers=_IO_WORKERS) as pool:
|
| 512 |
+
list(pool.map(_save_one, range(S)))
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def _save_conf_parallel(depth_conf_cpu, conf_dir, S):
|
| 516 |
+
def _save_one(i):
|
| 517 |
+
conf = depth_conf_cpu[i]
|
| 518 |
+
c_min, c_max = conf.min(), conf.max()
|
| 519 |
+
norm = (conf - c_min) / (c_max - c_min) if c_max - c_min > 1e-8 else torch.ones_like(conf)
|
| 520 |
+
Image.fromarray((norm.clamp(0, 1) * 255).to(torch.uint8).numpy(), mode="L").save(
|
| 521 |
+
str(conf_dir / f"conf_{i+1:04d}.png"))
|
| 522 |
+
with ThreadPoolExecutor(max_workers=_IO_WORKERS) as pool:
|
| 523 |
+
list(pool.map(_save_one, range(S)))
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def _save_normal_parallel(normals_cpu, normal_dir, S):
|
| 527 |
+
def _save_one(i):
|
| 528 |
+
save_normal_png(normal_dir / f"normal_{i:04d}.png", normals_cpu[i])
|
| 529 |
+
with ThreadPoolExecutor(max_workers=_IO_WORKERS) as pool:
|
| 530 |
+
list(pool.map(_save_one, range(S)))
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def _save_sky_mask_parallel(sky_mask, sky_mask_dir, S):
|
| 534 |
+
def _save_one(i):
|
| 535 |
+
Image.fromarray((~sky_mask[i]).astype(np.uint8) * 255, mode="L").save(
|
| 536 |
+
str(sky_mask_dir / f"sky_mask_{i:04d}.png"))
|
| 537 |
+
with ThreadPoolExecutor(max_workers=_IO_WORKERS) as pool:
|
| 538 |
+
list(pool.map(_save_one, range(S)))
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def _voxel_prune_gaussians(means, scales, quats, colors, opacities, weights, voxel_size=0.002):
|
| 542 |
+
"""Voxel-based merging of Gaussian splats via weighted average."""
|
| 543 |
+
N = means.shape[0]
|
| 544 |
+
if N == 0:
|
| 545 |
+
return means, scales, quats, colors, opacities
|
| 546 |
+
|
| 547 |
+
voxel_idx = (means / voxel_size).floor().long()
|
| 548 |
+
voxel_idx = voxel_idx - voxel_idx.min(dim=0)[0]
|
| 549 |
+
vmax = voxel_idx.max(dim=0)[0] + 1
|
| 550 |
+
flat = voxel_idx[:, 0] * vmax[1] * vmax[2] + voxel_idx[:, 1] * vmax[2] + voxel_idx[:, 2]
|
| 551 |
+
|
| 552 |
+
unique, inv = torch.unique(flat, return_inverse=True)
|
| 553 |
+
K = len(unique)
|
| 554 |
+
if K == N:
|
| 555 |
+
return means, scales, quats, colors, opacities
|
| 556 |
+
|
| 557 |
+
w = weights
|
| 558 |
+
wsum = torch.zeros(K, dtype=w.dtype).scatter_add_(0, inv, w).clamp(min=1e-8)
|
| 559 |
+
|
| 560 |
+
def _wavg(vals):
|
| 561 |
+
out = torch.zeros(K, *vals.shape[1:], dtype=vals.dtype)
|
| 562 |
+
for d in range(vals.shape[1]):
|
| 563 |
+
out[:, d].scatter_add_(0, inv, vals[:, d] * w)
|
| 564 |
+
return out / wsum.unsqueeze(-1)
|
| 565 |
+
|
| 566 |
+
m_opa = torch.zeros(K, dtype=opacities.dtype).scatter_add_(0, inv, w * w) / wsum
|
| 567 |
+
m_quats = torch.zeros(K, 4, dtype=quats.dtype)
|
| 568 |
+
for d in range(4):
|
| 569 |
+
m_quats[:, d].scatter_add_(0, inv, quats[:, d] * w)
|
| 570 |
+
m_quats = m_quats / m_quats.norm(dim=1, keepdim=True).clamp(min=1e-8)
|
| 571 |
+
|
| 572 |
+
print(f"[Save] Voxel prune: {N} -> {K} gaussians")
|
| 573 |
+
return _wavg(means), _wavg(scales), m_quats, _wavg(colors), m_opa
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
def _compress_points_voxel_then_sample(pts_np, cols_np, max_points=2_000_000, voxel_size=0.005):
|
| 577 |
+
"""Compress point cloud: voxel merge then uniform random sampling."""
|
| 578 |
+
n_in = int(pts_np.shape[0])
|
| 579 |
+
if n_in == 0:
|
| 580 |
+
return pts_np, cols_np
|
| 581 |
+
|
| 582 |
+
if voxel_size > 0:
|
| 583 |
+
voxel = np.floor(pts_np / voxel_size).astype(np.int64)
|
| 584 |
+
voxel -= voxel.min(axis=0, keepdims=True)
|
| 585 |
+
_, inv = np.unique(voxel, axis=0, return_inverse=True)
|
| 586 |
+
k = int(inv.max()) + 1
|
| 587 |
+
if k < n_in:
|
| 588 |
+
counts = np.maximum(np.bincount(inv, minlength=k).astype(np.float32), 1.0)
|
| 589 |
+
pts_np = np.stack([np.bincount(inv, weights=pts_np[:, d], minlength=k)
|
| 590 |
+
for d in range(3)], axis=1).astype(np.float32) / counts[:, None]
|
| 591 |
+
cols_np = np.clip(np.round(
|
| 592 |
+
np.stack([np.bincount(inv, weights=cols_np[:, d].astype(np.float32), minlength=k)
|
| 593 |
+
for d in range(3)], axis=1) / counts[:, None]
|
| 594 |
+
), 0, 255).astype(np.uint8)
|
| 595 |
+
|
| 596 |
+
if max_points > 0 and pts_np.shape[0] > max_points:
|
| 597 |
+
idx = np.random.default_rng(42).choice(pts_np.shape[0], size=max_points, replace=False)
|
| 598 |
+
pts_np, cols_np = pts_np[idx], cols_np[idx]
|
| 599 |
+
return pts_np, cols_np
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
def _compute_points_from_depth(depth_pred, imgs, extrinsics, intrinsics, S, H, W, filter_mask=None):
|
| 603 |
+
"""Derive 3D point cloud from depth + camera outputs."""
|
| 604 |
+
depth_pred, extrinsics, intrinsics = depth_pred.float(), extrinsics.float(), intrinsics.float()
|
| 605 |
+
points_list, colors_list = [], []
|
| 606 |
+
for i in range(S):
|
| 607 |
+
d = depth_pred[0, i, :, :, 0]
|
| 608 |
+
w2c = torch.cat([extrinsics[i][:3, :4],
|
| 609 |
+
torch.tensor([[0, 0, 0, 1]], device=extrinsics.device)], dim=0)
|
| 610 |
+
c2w = torch.linalg.inv(w2c)[:3, :4]
|
| 611 |
+
pts_i, _, mask = depth_to_world_coords_points(d[None], c2w[None], intrinsics[i][None])
|
| 612 |
+
img_colors = (imgs[0, i].permute(1, 2, 0) * 255).to(torch.uint8)
|
| 613 |
+
valid = mask[0]
|
| 614 |
+
if filter_mask is not None:
|
| 615 |
+
valid = valid & torch.from_numpy(filter_mask[i]).to(valid.device)
|
| 616 |
+
if valid.sum().item() > 0:
|
| 617 |
+
points_list.append(pts_i[0][valid])
|
| 618 |
+
colors_list.append(img_colors[valid])
|
| 619 |
+
|
| 620 |
+
if not points_list:
|
| 621 |
+
return np.empty((0, 3), dtype=np.float32), np.empty((0, 3), dtype=np.uint8)
|
| 622 |
+
return (torch.cat(points_list).detach().cpu().float().numpy(),
|
| 623 |
+
torch.cat(colors_list).detach().cpu().to(torch.uint8).numpy())
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def _save_colmap_lightweight(extrinsics, intrinsics, outdir, final_w, final_h, S, image_names):
|
| 627 |
+
"""Save lightweight COLMAP reconstruction (cameras + images only)."""
|
| 628 |
+
import pycolmap
|
| 629 |
+
sparse_dir = outdir / "sparse" / "0"
|
| 630 |
+
sparse_dir.mkdir(parents=True, exist_ok=True)
|
| 631 |
+
scene = pycolmap.Reconstruction()
|
| 632 |
+
for i in range(S):
|
| 633 |
+
focal_avg = (intrinsics[i][0, 0] + intrinsics[i][1, 1]) / 2
|
| 634 |
+
camera = pycolmap.Camera(
|
| 635 |
+
model="SIMPLE_PINHOLE", width=final_w, height=final_h,
|
| 636 |
+
params=np.array([focal_avg, intrinsics[i][0, 2], intrinsics[i][1, 2]]),
|
| 637 |
+
camera_id=i + 1,
|
| 638 |
+
)
|
| 639 |
+
scene.add_camera(camera)
|
| 640 |
+
cam_from_world = pycolmap.Rigid3d(
|
| 641 |
+
pycolmap.Rotation3d(extrinsics[i][:3, :3]), extrinsics[i][:3, 3])
|
| 642 |
+
img = pycolmap.Image(id=i + 1, name=image_names[i], camera_id=i + 1,
|
| 643 |
+
cam_from_world=cam_from_world)
|
| 644 |
+
img.registered = True
|
| 645 |
+
scene.add_image(img)
|
| 646 |
+
scene.write(str(sparse_dir))
|
| 647 |
+
print(f"[Save] COLMAP sparse -> {sparse_dir}")
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def save_results(predictions, imgs, img_paths, outdir,
|
| 651 |
+
save_depth=True, save_normal=True, save_gs=True,
|
| 652 |
+
save_camera=True, save_colmap=False, save_points=True,
|
| 653 |
+
save_sky_mask=False, save_conf=False, log_time=False,
|
| 654 |
+
max_resolution=1920,
|
| 655 |
+
filter_mask=None, gs_filter_mask=None, sky_mask=None,
|
| 656 |
+
compress_pts=True, compress_pts_max_points=2_000_000,
|
| 657 |
+
compress_pts_voxel_size=0.002,
|
| 658 |
+
compress_gs_max_points=5_000_000):
|
| 659 |
+
"""Save all results with parallel I/O. Returns timing dict."""
|
| 660 |
+
timings = {}
|
| 661 |
+
outdir = Path(outdir)
|
| 662 |
+
outdir.mkdir(parents=True, exist_ok=True)
|
| 663 |
+
B, S, C, H, W = imgs.shape
|
| 664 |
+
|
| 665 |
+
ar = W / H
|
| 666 |
+
max_w = max(Image.open(p).size[0] for p in img_paths)
|
| 667 |
+
new_w, new_h = max_w, int(round(max_w / ar))
|
| 668 |
+
longest = max(new_w, new_h)
|
| 669 |
+
if longest > max_resolution:
|
| 670 |
+
sf = max_resolution / longest
|
| 671 |
+
new_w, new_h = int(new_w * sf), int(new_h * sf)
|
| 672 |
+
new_w -= new_w % 2
|
| 673 |
+
new_h -= new_h % 2
|
| 674 |
+
image_names = [f"image_{i+1:04d}.jpg" for i in range(S)]
|
| 675 |
+
|
| 676 |
+
depth_cpu = predictions["depth"][0].detach().cpu() if "depth" in predictions else None
|
| 677 |
+
conf_cpu = predictions.get("depth_conf", [None])[0]
|
| 678 |
+
if conf_cpu is not None:
|
| 679 |
+
conf_cpu = conf_cpu.detach().cpu()
|
| 680 |
+
normals_cpu = predictions["normals"][0].detach().cpu() if "normals" in predictions else None
|
| 681 |
+
|
| 682 |
+
futures = {}
|
| 683 |
+
executor = ThreadPoolExecutor(max_workers=_IO_WORKERS)
|
| 684 |
+
|
| 685 |
+
if save_depth and depth_cpu is not None:
|
| 686 |
+
d_dir = outdir / "depth"
|
| 687 |
+
d_dir.mkdir(exist_ok=True)
|
| 688 |
+
futures["save_depth"] = executor.submit(_timed_call, _save_depth_parallel, depth_cpu, d_dir, S)
|
| 689 |
+
|
| 690 |
+
if save_conf and conf_cpu is not None:
|
| 691 |
+
c_dir = outdir / "depth_conf"
|
| 692 |
+
c_dir.mkdir(exist_ok=True)
|
| 693 |
+
futures["save_conf"] = executor.submit(_timed_call, _save_conf_parallel, conf_cpu, c_dir, S)
|
| 694 |
+
|
| 695 |
+
if save_normal and normals_cpu is not None:
|
| 696 |
+
n_dir = outdir / "normal"
|
| 697 |
+
n_dir.mkdir(exist_ok=True)
|
| 698 |
+
futures["save_normal"] = executor.submit(_timed_call, _save_normal_parallel, normals_cpu, n_dir, S)
|
| 699 |
+
|
| 700 |
+
if save_sky_mask and sky_mask is not None:
|
| 701 |
+
sm_dir = outdir / "sky_mask"
|
| 702 |
+
sm_dir.mkdir(exist_ok=True)
|
| 703 |
+
futures["save_sky_mask"] = executor.submit(_timed_call, _save_sky_mask_parallel, sky_mask, sm_dir, S)
|
| 704 |
+
|
| 705 |
+
if save_gs and "splats" in predictions:
|
| 706 |
+
sp = predictions["splats"]
|
| 707 |
+
means = sp["means"][0].reshape(-1, 3).detach().cpu()
|
| 708 |
+
scales = sp["scales"][0].reshape(-1, 3).detach().cpu()
|
| 709 |
+
quats = sp["quats"][0].reshape(-1, 4).detach().cpu()
|
| 710 |
+
colors = (sp["sh"][0] if "sh" in sp else sp["colors"][0]).reshape(-1, 3).detach().cpu()
|
| 711 |
+
opacities = sp["opacities"][0].reshape(-1).detach().cpu()
|
| 712 |
+
weights = sp["weights"][0].reshape(-1).detach().cpu() if "weights" in sp else torch.ones_like(opacities)
|
| 713 |
+
|
| 714 |
+
keep = None
|
| 715 |
+
if gs_filter_mask is not None:
|
| 716 |
+
keep = torch.from_numpy(gs_filter_mask.reshape(-1)).bool()
|
| 717 |
+
elif filter_mask is not None:
|
| 718 |
+
keep = torch.from_numpy(filter_mask.reshape(-1)).bool()
|
| 719 |
+
if keep is not None:
|
| 720 |
+
means, scales, quats = means[keep], scales[keep], quats[keep]
|
| 721 |
+
colors, opacities, weights = colors[keep], opacities[keep], weights[keep]
|
| 722 |
+
|
| 723 |
+
means, scales, quats, colors, opacities = _voxel_prune_gaussians(
|
| 724 |
+
means, scales, quats, colors, opacities, weights)
|
| 725 |
+
if compress_gs_max_points > 0 and means.shape[0] > compress_gs_max_points:
|
| 726 |
+
idx = torch.from_numpy(
|
| 727 |
+
np.random.default_rng(42).choice(means.shape[0], size=compress_gs_max_points, replace=False)
|
| 728 |
+
).long()
|
| 729 |
+
means, scales, quats, colors, opacities = means[idx], scales[idx], quats[idx], colors[idx], opacities[idx]
|
| 730 |
+
|
| 731 |
+
futures["save_gs_ply"] = executor.submit(
|
| 732 |
+
_timed_call, save_gs_ply, outdir / "gaussians.ply", means, scales, quats, colors, opacities)
|
| 733 |
+
|
| 734 |
+
if save_camera and "camera_poses" in predictions and "camera_intrs" in predictions:
|
| 735 |
+
cam_p = predictions["camera_poses"][0].detach().cpu().float().numpy()
|
| 736 |
+
cam_i = predictions["camera_intrs"][0].detach().cpu().float().numpy()
|
| 737 |
+
futures["save_camera"] = executor.submit(_timed_call, save_camera_params, cam_p, cam_i, str(outdir))
|
| 738 |
+
|
| 739 |
+
if save_points and "depth" in predictions and "camera_params" in predictions:
|
| 740 |
+
e3x4, intr = vector_to_camera_matrices(predictions["camera_params"], image_hw=(H, W))
|
| 741 |
+
pts_np, cols_np = _compute_points_from_depth(
|
| 742 |
+
predictions["depth"], imgs, e3x4[0], intr[0], S, H, W, filter_mask=filter_mask)
|
| 743 |
+
futures["save_points"] = executor.submit(
|
| 744 |
+
_timed_call, _save_points_artifacts, outdir / "points.ply", pts_np, cols_np,
|
| 745 |
+
compress_pts, compress_pts_max_points, compress_pts_voxel_size)
|
| 746 |
+
|
| 747 |
+
if save_colmap and "camera_params" in predictions:
|
| 748 |
+
e3x4, intr = vector_to_camera_matrices(predictions["camera_params"], image_hw=(new_h, new_w))
|
| 749 |
+
futures["save_colmap"] = executor.submit(
|
| 750 |
+
_timed_call, _save_colmap_lightweight,
|
| 751 |
+
e3x4[0].detach().cpu().float().numpy(), intr[0].detach().cpu().float().numpy(),
|
| 752 |
+
outdir, new_w, new_h, S, image_names)
|
| 753 |
+
|
| 754 |
+
for key, future in futures.items():
|
| 755 |
+
result, elapsed = future.result()
|
| 756 |
+
if log_time:
|
| 757 |
+
timings[key] = elapsed
|
| 758 |
+
if isinstance(result, dict):
|
| 759 |
+
timings.update(result)
|
| 760 |
+
|
| 761 |
+
executor.shutdown(wait=False)
|
| 762 |
+
return timings
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
def _save_points_artifacts(path, pts_np, cols_np,
|
| 766 |
+
compress=False, max_points=2_000_000,
|
| 767 |
+
voxel_size=0.005):
|
| 768 |
+
timings = {}
|
| 769 |
+
if compress:
|
| 770 |
+
t0 = time.perf_counter()
|
| 771 |
+
pts_np, cols_np = _compress_points_voxel_then_sample(pts_np, cols_np, max_points, voxel_size)
|
| 772 |
+
timings["compress_points"] = time.perf_counter() - t0
|
| 773 |
+
save_points_ply(path, pts_np, cols_np)
|
| 774 |
+
return timings
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
# ============================================================
|
| 778 |
+
# Timing Report
|
| 779 |
+
# ============================================================
|
| 780 |
+
|
| 781 |
+
def print_and_save_timings(timings, outdir):
|
| 782 |
+
"""Print formatted timing table and save to JSON."""
|
| 783 |
+
def _p(label, value, indent=0):
|
| 784 |
+
print(f"{' ' * (indent + 1)}{label:<38s} {value:>10.3f}s")
|
| 785 |
+
|
| 786 |
+
print(f"\n{'='*72}\n TIMING REPORT\n{'='*72}")
|
| 787 |
+
|
| 788 |
+
print(" [Serial Stages]")
|
| 789 |
+
for key, label in [("data_loading", "Data loading"),
|
| 790 |
+
("inference_preprocess", "Inference preprocess"),
|
| 791 |
+
("inference", "Model inference"),
|
| 792 |
+
("compute_mask", "Compute filter mask")]:
|
| 793 |
+
if key in timings:
|
| 794 |
+
_p(label, timings[key], 1)
|
| 795 |
+
|
| 796 |
+
save_wall = timings.get("save_total_wall")
|
| 797 |
+
save_keys = [("save_depth", "Depth"), ("save_conf", "Depth conf"),
|
| 798 |
+
("save_normal", "Normal"), ("save_sky_mask", "Sky mask"),
|
| 799 |
+
("save_gs_ply", "Gaussians"), ("save_camera", "Camera"),
|
| 800 |
+
("save_points", "Points"), ("save_colmap", "COLMAP")]
|
| 801 |
+
present = [(k, n) for k, n in save_keys if k in timings]
|
| 802 |
+
if save_wall is not None or present:
|
| 803 |
+
print(" [Save Stage | Parallel]")
|
| 804 |
+
if save_wall is not None:
|
| 805 |
+
_p("Save wall-clock", save_wall, 1)
|
| 806 |
+
for k, name in present:
|
| 807 |
+
_p(f"- {name}", timings[k], 2)
|
| 808 |
+
|
| 809 |
+
if "case_total" in timings:
|
| 810 |
+
print(" [Total]")
|
| 811 |
+
_p("Case total", timings["case_total"], 1)
|
| 812 |
+
|
| 813 |
+
if "gpu_mem_peak_per_rank_gb" in timings:
|
| 814 |
+
print(" [GPU Memory]")
|
| 815 |
+
for i, p in enumerate(timings["gpu_mem_peak_per_rank_gb"]):
|
| 816 |
+
print(f" Rank {i}: {p:.2f} GB")
|
| 817 |
+
print(f" Average: {timings['gpu_mem_peak_avg_gb']:.2f} GB")
|
| 818 |
+
|
| 819 |
+
print(f"{'='*72}\n")
|
| 820 |
+
|
| 821 |
+
outdir = Path(outdir)
|
| 822 |
+
outdir.mkdir(parents=True, exist_ok=True)
|
| 823 |
+
with open(outdir / "pipeline_timing.json", "w") as f:
|
| 824 |
+
json.dump(timings, f, indent=2)
|
hyworldmirror/utils/render_utils.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Render interpolated video from Gaussian Splatting predictions.
|
| 3 |
+
|
| 4 |
+
Interpolates smooth camera trajectories using SLERP quaternions,
|
| 5 |
+
renders each frame via gsplat, and saves MP4 videos.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from ..models.models.rasterization import GaussianSplatRenderer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def rotation_matrix_to_quaternion(R):
|
| 18 |
+
"""Convert rotation matrix to quaternion (scalar-first: [w, x, y, z]).
|
| 19 |
+
|
| 20 |
+
Note: This uses the Hamilton convention [w, x, y, z], which differs from
|
| 21 |
+
models/utils/rotation.py that uses PyTorch3D convention [x, y, z, w].
|
| 22 |
+
"""
|
| 23 |
+
trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
|
| 24 |
+
|
| 25 |
+
q = torch.zeros(R.shape[:-2] + (4,), device=R.device, dtype=R.dtype)
|
| 26 |
+
|
| 27 |
+
mask1 = trace > 0
|
| 28 |
+
s = torch.sqrt(trace[mask1] + 1.0) * 2
|
| 29 |
+
q[mask1, 0] = 0.25 * s
|
| 30 |
+
q[mask1, 1] = (R[mask1, 2, 1] - R[mask1, 1, 2]) / s
|
| 31 |
+
q[mask1, 2] = (R[mask1, 0, 2] - R[mask1, 2, 0]) / s
|
| 32 |
+
q[mask1, 3] = (R[mask1, 1, 0] - R[mask1, 0, 1]) / s
|
| 33 |
+
|
| 34 |
+
mask2 = (~mask1) & (R[..., 0, 0] > R[..., 1, 1]) & (R[..., 0, 0] > R[..., 2, 2])
|
| 35 |
+
s = torch.sqrt(1.0 + R[mask2, 0, 0] - R[mask2, 1, 1] - R[mask2, 2, 2]) * 2
|
| 36 |
+
q[mask2, 0] = (R[mask2, 2, 1] - R[mask2, 1, 2]) / s
|
| 37 |
+
q[mask2, 1] = 0.25 * s
|
| 38 |
+
q[mask2, 2] = (R[mask2, 0, 1] + R[mask2, 1, 0]) / s
|
| 39 |
+
q[mask2, 3] = (R[mask2, 0, 2] + R[mask2, 2, 0]) / s
|
| 40 |
+
|
| 41 |
+
mask3 = (~mask1) & (~mask2) & (R[..., 1, 1] > R[..., 2, 2])
|
| 42 |
+
s = torch.sqrt(1.0 + R[mask3, 1, 1] - R[mask3, 0, 0] - R[mask3, 2, 2]) * 2
|
| 43 |
+
q[mask3, 0] = (R[mask3, 0, 2] - R[mask3, 2, 0]) / s
|
| 44 |
+
q[mask3, 1] = (R[mask3, 0, 1] + R[mask3, 1, 0]) / s
|
| 45 |
+
q[mask3, 2] = 0.25 * s
|
| 46 |
+
q[mask3, 3] = (R[mask3, 1, 2] + R[mask3, 2, 1]) / s
|
| 47 |
+
|
| 48 |
+
mask4 = (~mask1) & (~mask2) & (~mask3)
|
| 49 |
+
s = torch.sqrt(1.0 + R[mask4, 2, 2] - R[mask4, 0, 0] - R[mask4, 1, 1]) * 2
|
| 50 |
+
q[mask4, 0] = (R[mask4, 1, 0] - R[mask4, 0, 1]) / s
|
| 51 |
+
q[mask4, 1] = (R[mask4, 0, 2] + R[mask4, 2, 0]) / s
|
| 52 |
+
q[mask4, 2] = (R[mask4, 1, 2] + R[mask4, 2, 1]) / s
|
| 53 |
+
q[mask4, 3] = 0.25 * s
|
| 54 |
+
|
| 55 |
+
return q
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def quaternion_to_rotation_matrix(q):
|
| 59 |
+
"""Convert quaternion (scalar-first: [w, x, y, z]) to rotation matrix."""
|
| 60 |
+
w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
|
| 61 |
+
|
| 62 |
+
norm = torch.sqrt(w*w + x*x + y*y + z*z)
|
| 63 |
+
w, x, y, z = w/norm, x/norm, y/norm, z/norm
|
| 64 |
+
|
| 65 |
+
R = torch.zeros(q.shape[:-1] + (3, 3), device=q.device, dtype=q.dtype)
|
| 66 |
+
|
| 67 |
+
R[..., 0, 0] = 1 - 2*(y*y + z*z)
|
| 68 |
+
R[..., 0, 1] = 2*(x*y - w*z)
|
| 69 |
+
R[..., 0, 2] = 2*(x*z + w*y)
|
| 70 |
+
R[..., 1, 0] = 2*(x*y + w*z)
|
| 71 |
+
R[..., 1, 1] = 1 - 2*(x*x + z*z)
|
| 72 |
+
R[..., 1, 2] = 2*(y*z - w*x)
|
| 73 |
+
R[..., 2, 0] = 2*(x*z - w*y)
|
| 74 |
+
R[..., 2, 1] = 2*(y*z + w*x)
|
| 75 |
+
R[..., 2, 2] = 1 - 2*(x*x + y*y)
|
| 76 |
+
|
| 77 |
+
return R
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def slerp_quaternions(q1, q2, t):
|
| 81 |
+
"""Spherical linear interpolation between quaternions."""
|
| 82 |
+
dot = (q1 * q2).sum(dim=-1, keepdim=True)
|
| 83 |
+
|
| 84 |
+
mask = dot < 0
|
| 85 |
+
q2 = torch.where(mask, -q2, q2)
|
| 86 |
+
dot = torch.where(mask, -dot, dot)
|
| 87 |
+
|
| 88 |
+
DOT_THRESHOLD = 0.9995
|
| 89 |
+
mask_linear = dot > DOT_THRESHOLD
|
| 90 |
+
|
| 91 |
+
result = torch.zeros_like(q1)
|
| 92 |
+
|
| 93 |
+
if mask_linear.any():
|
| 94 |
+
result_linear = q1 + t * (q2 - q1)
|
| 95 |
+
norm = torch.norm(result_linear, dim=-1, keepdim=True)
|
| 96 |
+
result_linear = result_linear / norm
|
| 97 |
+
result = torch.where(mask_linear, result_linear, result)
|
| 98 |
+
|
| 99 |
+
mask_slerp = ~mask_linear
|
| 100 |
+
if mask_slerp.any():
|
| 101 |
+
theta_0 = torch.acos(torch.abs(dot))
|
| 102 |
+
sin_theta_0 = torch.sin(theta_0)
|
| 103 |
+
|
| 104 |
+
theta = theta_0 * t
|
| 105 |
+
sin_theta = torch.sin(theta)
|
| 106 |
+
|
| 107 |
+
s0 = torch.cos(theta) - dot * sin_theta / sin_theta_0
|
| 108 |
+
s1 = sin_theta / sin_theta_0
|
| 109 |
+
|
| 110 |
+
result_slerp = (s0 * q1) + (s1 * q2)
|
| 111 |
+
result = torch.where(mask_slerp, result_slerp, result)
|
| 112 |
+
|
| 113 |
+
return result
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def render_interpolated_video(gs_renderer: GaussianSplatRenderer,
|
| 117 |
+
splats: dict,
|
| 118 |
+
camtoworlds: torch.Tensor,
|
| 119 |
+
intrinsics: torch.Tensor,
|
| 120 |
+
hw: tuple,
|
| 121 |
+
out_path: Path,
|
| 122 |
+
interp_per_pair: int = 20,
|
| 123 |
+
loop_reverse: bool = True,
|
| 124 |
+
save_mode: str = "split",
|
| 125 |
+
frame_times: list = None,
|
| 126 |
+
render_depth: bool = False) -> None:
|
| 127 |
+
"""Render an interpolated fly-through video from Gaussian splat predictions.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
gs_renderer: GaussianSplatRenderer instance (from the model).
|
| 131 |
+
splats: Dict with keys 'means', 'scales', 'quats', 'opacities', 'sh'/'colors'.
|
| 132 |
+
camtoworlds: Camera-to-world matrices [B, S, 4, 4].
|
| 133 |
+
intrinsics: Camera intrinsic matrices [B, S, 3, 3].
|
| 134 |
+
hw: Tuple of (height, width) for rendering.
|
| 135 |
+
out_path: Output path (without extension).
|
| 136 |
+
interp_per_pair: Number of interpolated frames per camera pair.
|
| 137 |
+
loop_reverse: Append reversed video for smooth looping.
|
| 138 |
+
save_mode: 'split' (separate rgb/depth) or 'both' (combined).
|
| 139 |
+
frame_times: Optional list of timestamps for adaptive interpolation.
|
| 140 |
+
render_depth: Whether to also render depth video.
|
| 141 |
+
"""
|
| 142 |
+
import moviepy.editor as mpy
|
| 143 |
+
|
| 144 |
+
b, s, _, _ = camtoworlds.shape
|
| 145 |
+
h, w = hw
|
| 146 |
+
|
| 147 |
+
def build_interpolated_traj(index, base_interp_per_pair: int):
|
| 148 |
+
exts, ints = [], []
|
| 149 |
+
tmp_camtoworlds = camtoworlds[:, index]
|
| 150 |
+
tmp_intrinsics = intrinsics[:, index]
|
| 151 |
+
|
| 152 |
+
use_time_based = frame_times is not None and len(frame_times) == len(index)
|
| 153 |
+
if use_time_based and len(index) > 1:
|
| 154 |
+
times = np.array([frame_times[i] for i in index], dtype=np.float32)
|
| 155 |
+
gaps = np.diff(times)
|
| 156 |
+
gaps[gaps < 0] = 0.0
|
| 157 |
+
total_gap = float(gaps.sum())
|
| 158 |
+
target_total_interp = max(1, (len(index) - 1) * base_interp_per_pair)
|
| 159 |
+
gap_scale = target_total_interp / total_gap if total_gap > 1e-6 else 0.0
|
| 160 |
+
else:
|
| 161 |
+
gaps = None
|
| 162 |
+
gap_scale = None
|
| 163 |
+
|
| 164 |
+
for i in range(len(index)-1):
|
| 165 |
+
exts.append(tmp_camtoworlds[:, i:i+1])
|
| 166 |
+
ints.append(tmp_intrinsics[:, i:i+1])
|
| 167 |
+
R0, t0 = tmp_camtoworlds[:, i, :3, :3], tmp_camtoworlds[:, i, :3, 3]
|
| 168 |
+
R1, t1 = tmp_camtoworlds[:, i + 1, :3, :3], tmp_camtoworlds[:, i + 1, :3, 3]
|
| 169 |
+
|
| 170 |
+
q0 = rotation_matrix_to_quaternion(R0)
|
| 171 |
+
q1 = rotation_matrix_to_quaternion(R1)
|
| 172 |
+
|
| 173 |
+
if use_time_based:
|
| 174 |
+
gap = float(gaps[i]) if gaps is not None else 0.0
|
| 175 |
+
num_interp = max(0, int(round(gap * gap_scale)))
|
| 176 |
+
else:
|
| 177 |
+
num_interp = base_interp_per_pair
|
| 178 |
+
|
| 179 |
+
for j in range(1, num_interp + 1):
|
| 180 |
+
alpha = j / (num_interp + 1)
|
| 181 |
+
t_interp = (1 - alpha) * t0 + alpha * t1
|
| 182 |
+
q_interp = slerp_quaternions(q0, q1, alpha)
|
| 183 |
+
R_interp = quaternion_to_rotation_matrix(q_interp)
|
| 184 |
+
|
| 185 |
+
ext = torch.eye(4, device=R_interp.device, dtype=R_interp.dtype)[None].repeat(b, 1, 1)
|
| 186 |
+
ext[:, :3, :3] = R_interp
|
| 187 |
+
ext[:, :3, 3] = t_interp
|
| 188 |
+
|
| 189 |
+
K0 = tmp_intrinsics[:, i]
|
| 190 |
+
K1 = tmp_intrinsics[:, i + 1]
|
| 191 |
+
K = (1 - alpha) * K0 + alpha * K1
|
| 192 |
+
|
| 193 |
+
exts.append(ext[:, None])
|
| 194 |
+
ints.append(K[:, None])
|
| 195 |
+
|
| 196 |
+
exts = torch.cat(exts, dim=1)[:1]
|
| 197 |
+
ints = torch.cat(ints, dim=1)[:1]
|
| 198 |
+
return exts, ints
|
| 199 |
+
|
| 200 |
+
def build_wobble_traj(nums, delta):
|
| 201 |
+
if s != 1:
|
| 202 |
+
raise ValueError("Wobble trajectory requires exactly 1 input view")
|
| 203 |
+
t = torch.linspace(0, 1, nums, dtype=torch.float32, device=camtoworlds.device)
|
| 204 |
+
t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
|
| 205 |
+
tf = torch.eye(4, dtype=torch.float32, device=camtoworlds.device)
|
| 206 |
+
radius = delta * 0.15
|
| 207 |
+
tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone()
|
| 208 |
+
radius = radius[..., None]
|
| 209 |
+
radius = radius * t
|
| 210 |
+
tf[..., 0, 3] = torch.sin(2 * torch.pi * t) * radius
|
| 211 |
+
tf[..., 1, 3] = -torch.cos(2 * torch.pi * t) * radius
|
| 212 |
+
exts = camtoworlds @ tf
|
| 213 |
+
ints = intrinsics.repeat(1, exts.shape[1], 1, 1)
|
| 214 |
+
return exts, ints
|
| 215 |
+
|
| 216 |
+
if s > 1:
|
| 217 |
+
all_ext, all_int = build_interpolated_traj([i for i in range(s)], interp_per_pair)
|
| 218 |
+
else:
|
| 219 |
+
all_ext, all_int = build_wobble_traj(interp_per_pair * 12, splats["means"][0].median(dim=0).values.norm(dim=-1)[None])
|
| 220 |
+
|
| 221 |
+
rendered_rgbs, rendered_depths = [], []
|
| 222 |
+
chunk = 40
|
| 223 |
+
|
| 224 |
+
# Always prune splats to remove scale outliers
|
| 225 |
+
try:
|
| 226 |
+
pruned_splats = gs_renderer.prune_gs(splats, gs_renderer.voxel_size)
|
| 227 |
+
except (AttributeError, RuntimeError):
|
| 228 |
+
pruned_splats = splats
|
| 229 |
+
|
| 230 |
+
for st in tqdm(range(0, all_ext.shape[1], chunk)):
|
| 231 |
+
ed = min(st + chunk, all_ext.shape[1])
|
| 232 |
+
colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
|
| 233 |
+
pruned_splats["means"][:1], pruned_splats["quats"][:1], pruned_splats["scales"][:1],
|
| 234 |
+
pruned_splats["opacities"][:1],
|
| 235 |
+
pruned_splats["sh"][:1] if "sh" in pruned_splats else pruned_splats["colors"][:1],
|
| 236 |
+
all_ext[:, st:ed].to(torch.float32), all_int[:, st:ed].to(torch.float32),
|
| 237 |
+
width=w, height=h, sh_degree=gs_renderer.sh_degree if "sh" in pruned_splats else None,
|
| 238 |
+
)
|
| 239 |
+
rendered_rgbs.append(colors)
|
| 240 |
+
if render_depth:
|
| 241 |
+
rendered_depths.append(depths)
|
| 242 |
+
|
| 243 |
+
rgbs = torch.cat(rendered_rgbs, dim=1)[0] # [N, H, W, 3]
|
| 244 |
+
if render_depth:
|
| 245 |
+
depths_all = torch.cat(rendered_depths, dim=1)[0, ..., 0] # [N, H, W]
|
| 246 |
+
del rendered_rgbs, rendered_depths
|
| 247 |
+
|
| 248 |
+
def _depth_vis(d: torch.Tensor) -> torch.Tensor:
|
| 249 |
+
"""Simple turbo colormap depth visualization."""
|
| 250 |
+
import matplotlib.pyplot as plt
|
| 251 |
+
valid = d > 0
|
| 252 |
+
if valid.any():
|
| 253 |
+
near = d[valid].float().quantile(0.01).log()
|
| 254 |
+
else:
|
| 255 |
+
near = torch.tensor(0.0, device=d.device)
|
| 256 |
+
far = d.flatten().float().quantile(0.99).log()
|
| 257 |
+
x = d.float().clamp(min=1e-9).log()
|
| 258 |
+
x = 1.0 - (x - near) / (far - near + 1e-9)
|
| 259 |
+
x_np = x.cpu().numpy()
|
| 260 |
+
colored = torch.from_numpy(plt.cm.turbo(x_np)[..., :3]).permute(2, 0, 1).float()
|
| 261 |
+
return colored
|
| 262 |
+
|
| 263 |
+
rgb_frames = []
|
| 264 |
+
depth_frames = []
|
| 265 |
+
|
| 266 |
+
if render_depth:
|
| 267 |
+
for rgb, dep in zip(rgbs, depths_all):
|
| 268 |
+
rgb_frames.append(rgb.permute(2, 0, 1))
|
| 269 |
+
depth_frames.append(_depth_vis(dep))
|
| 270 |
+
else:
|
| 271 |
+
for rgb in rgbs:
|
| 272 |
+
rgb_frames.append(rgb.permute(2, 0, 1))
|
| 273 |
+
|
| 274 |
+
def _make_video(frames, path):
|
| 275 |
+
video = torch.stack([f.cpu() for f in frames]).clamp(0, 1)
|
| 276 |
+
video = video.permute(0, 2, 3, 1)
|
| 277 |
+
video = (video * 255).to(torch.uint8).numpy()
|
| 278 |
+
if loop_reverse and video.shape[0] > 1:
|
| 279 |
+
video = np.concatenate([video, video[::-1][1:-1]], axis=0)
|
| 280 |
+
clip = mpy.ImageSequenceClip(list(video), fps=30)
|
| 281 |
+
clip.write_videofile(str(path), logger=None)
|
| 282 |
+
|
| 283 |
+
out_path = Path(out_path)
|
| 284 |
+
out_path.mkdir(parents=True, exist_ok=True)
|
| 285 |
+
if save_mode == 'split':
|
| 286 |
+
_make_video(rgb_frames, out_path / "rendered_rgb.mp4")
|
| 287 |
+
if render_depth:
|
| 288 |
+
_make_video(depth_frames, out_path / "rendered_depth.mp4")
|
| 289 |
+
elif save_mode == 'both' and render_depth:
|
| 290 |
+
combined = [torch.cat([r, d], dim=1) for r, d in zip(rgb_frames, depth_frames)]
|
| 291 |
+
_make_video(combined, out_path / "rendered.mp4")
|
| 292 |
+
|
| 293 |
+
print(f"Video saved to {out_path} (mode: {save_mode})")
|
| 294 |
+
torch.cuda.empty_cache()
|
hyworldmirror/utils/save_utils.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for saving images, depths, normals, point clouds, and Gaussian splat data.
|
| 3 |
+
tencent
|
| 4 |
+
"""
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from plyfile import PlyData, PlyElement
|
| 11 |
+
from io import BytesIO
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
def save_camera_params(extrinsics, intrinsics, target_dir):
|
| 16 |
+
"""
|
| 17 |
+
Save camera parameters (extrinsics and intrinsics) in JSON format
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
extrinsics: numpy array, shape [N, 4, 4] - extrinsic matrices for N cameras
|
| 21 |
+
intrinsics: numpy array, shape [N, 3, 3] - intrinsic matrices for N cameras
|
| 22 |
+
target_dir: str - directory to save the parameters
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
str: path to the saved file
|
| 26 |
+
"""
|
| 27 |
+
camera_data = {
|
| 28 |
+
"num_cameras": int(extrinsics.shape[0]),
|
| 29 |
+
"extrinsics": [],
|
| 30 |
+
"intrinsics": []
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
# Convert each camera's parameters to list format
|
| 34 |
+
for i in range(extrinsics.shape[0]):
|
| 35 |
+
camera_data["extrinsics"].append({
|
| 36 |
+
"camera_id": i,
|
| 37 |
+
"matrix": extrinsics[i].tolist() # [4, 4] -> list
|
| 38 |
+
})
|
| 39 |
+
camera_data["intrinsics"].append({
|
| 40 |
+
"camera_id": i,
|
| 41 |
+
"matrix": intrinsics[i].tolist() # [3, 3] -> list
|
| 42 |
+
})
|
| 43 |
+
|
| 44 |
+
# Save as JSON file
|
| 45 |
+
camera_params_path = os.path.join(target_dir, "camera_params.json")
|
| 46 |
+
with open(camera_params_path, 'w') as f:
|
| 47 |
+
json.dump(camera_data, f, indent=2)
|
| 48 |
+
|
| 49 |
+
return camera_params_path
|
| 50 |
+
|
| 51 |
+
def save_image_png(path: Path, image_tensor: torch.Tensor) -> None:
|
| 52 |
+
# image_tensor: [H, W, 3]
|
| 53 |
+
img = (image_tensor.detach().cpu() * 255.0).to(torch.uint8).numpy()
|
| 54 |
+
Image.fromarray(img).save(str(path))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def save_depth_png(path: Path, depth_tensor: torch.Tensor) -> None:
|
| 58 |
+
# depth_tensor: [H, W]
|
| 59 |
+
d = depth_tensor.detach()
|
| 60 |
+
d = d - d.min()
|
| 61 |
+
d = d / (d.max() + 1e-9)
|
| 62 |
+
img = (d.clamp(0, 1) * 255.0).to(torch.uint8).cpu().numpy()
|
| 63 |
+
Image.fromarray(img, mode="L").save(str(path))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def save_depth_npy(path: Path, depth_tensor: torch.Tensor) -> None:
|
| 67 |
+
# depth_tensor: [H, W]
|
| 68 |
+
# Save actual depth values in numpy format
|
| 69 |
+
d = depth_tensor.detach().cpu().numpy()
|
| 70 |
+
np.save(str(path), d)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def save_normal_png(path: Path, normal_hwc: torch.Tensor) -> None:
|
| 74 |
+
# normal_hwc: [H, W, 3], in [-1, 1]
|
| 75 |
+
n = (normal_hwc.detach().cpu() + 1.0) * 0.5
|
| 76 |
+
img = (n.clamp(0, 1) * 255.0).to(torch.uint8).numpy()
|
| 77 |
+
Image.fromarray(img).save(str(path))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _build_vertex_ply_element(pts: np.ndarray, colors: np.ndarray) -> PlyElement:
|
| 81 |
+
"""Build a PLY vertex element from points and colors arrays.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
pts: Point coordinates, shape [N, 3], dtype float32
|
| 85 |
+
colors: RGB colors, shape [N, 3], dtype uint8
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
PlyElement describing the vertices
|
| 89 |
+
"""
|
| 90 |
+
vertex_dtype = [("x", "f4"), ("y", "f4"), ("z", "f4"),
|
| 91 |
+
("red", "u1"), ("green", "u1"), ("blue", "u1")]
|
| 92 |
+
vertex_elements = np.empty(len(pts), dtype=vertex_dtype)
|
| 93 |
+
vertex_elements["x"] = pts[:, 0]
|
| 94 |
+
vertex_elements["y"] = pts[:, 1]
|
| 95 |
+
vertex_elements["z"] = pts[:, 2]
|
| 96 |
+
vertex_elements["red"] = colors[:, 0]
|
| 97 |
+
vertex_elements["green"] = colors[:, 1]
|
| 98 |
+
vertex_elements["blue"] = colors[:, 2]
|
| 99 |
+
return PlyElement.describe(vertex_elements, "vertex")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def save_scene_ply(path: Path,
|
| 103 |
+
points_xyz: torch.Tensor,
|
| 104 |
+
point_colors: torch.Tensor,
|
| 105 |
+
valid_mask: torch.Tensor = None) -> None:
|
| 106 |
+
"""Save point cloud to PLY format"""
|
| 107 |
+
pts = points_xyz.detach().cpu().to(torch.float32).numpy().reshape(-1, 3)
|
| 108 |
+
colors = point_colors.detach().cpu().to(torch.uint8).numpy().reshape(-1, 3)
|
| 109 |
+
|
| 110 |
+
# Filter out invalid points (NaN, Inf)
|
| 111 |
+
if valid_mask is None:
|
| 112 |
+
valid_mask = np.isfinite(pts).all(axis=1)
|
| 113 |
+
else:
|
| 114 |
+
valid_mask = valid_mask.detach().cpu().numpy().reshape(-1)
|
| 115 |
+
pts = pts[valid_mask]
|
| 116 |
+
colors = colors[valid_mask]
|
| 117 |
+
|
| 118 |
+
# Handle empty point cloud
|
| 119 |
+
if len(pts) == 0:
|
| 120 |
+
pts = np.array([[0, 0, 0]], dtype=np.float32)
|
| 121 |
+
colors = np.array([[255, 255, 255]], dtype=np.uint8)
|
| 122 |
+
|
| 123 |
+
PlyData([_build_vertex_ply_element(pts, colors)]).write(str(path))
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def save_points_ply(path: Path, pts_np: np.ndarray, cols_np: np.ndarray) -> None:
|
| 127 |
+
"""Save point cloud to PLY format from numpy arrays"""
|
| 128 |
+
PlyData([_build_vertex_ply_element(pts_np, cols_np)]).write(str(path))
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _build_gs_ply_data(means, scales, rotations, rgbs, opacities, quantile_threshold):
|
| 132 |
+
"""Build Gaussian splat PLY data with scale-based filtering.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
means: Gaussian centers [N, 3]
|
| 136 |
+
scales: Gaussian scales [N, 3]
|
| 137 |
+
rotations: Gaussian rotations as quaternions [N, 4]
|
| 138 |
+
rgbs: RGB colors [N, 3]
|
| 139 |
+
opacities: Opacity values [N]
|
| 140 |
+
quantile_threshold: Percentile threshold for scale filtering (e.g. 0.98 or 0.90)
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
PlyData object ready to be written or returned
|
| 144 |
+
"""
|
| 145 |
+
scale_threshold = torch.quantile(scales.max(dim=-1)[0], quantile_threshold, dim=0)
|
| 146 |
+
filter_mask = scales.max(dim=-1)[0] <= scale_threshold
|
| 147 |
+
|
| 148 |
+
means = means[filter_mask].reshape(-1, 3)
|
| 149 |
+
scales = scales[filter_mask].reshape(-1, 3)
|
| 150 |
+
rotations = rotations[filter_mask].reshape(-1, 4)
|
| 151 |
+
rgbs = rgbs[filter_mask].reshape(-1, 3)
|
| 152 |
+
opacities = opacities[filter_mask].reshape(-1)
|
| 153 |
+
|
| 154 |
+
attributes = ["x", "y", "z", "nx", "ny", "nz"]
|
| 155 |
+
for i in range(3):
|
| 156 |
+
attributes.append(f"f_dc_{i}")
|
| 157 |
+
attributes.append("opacity")
|
| 158 |
+
for i in range(3):
|
| 159 |
+
attributes.append(f"scale_{i}")
|
| 160 |
+
for i in range(4):
|
| 161 |
+
attributes.append(f"rot_{i}")
|
| 162 |
+
|
| 163 |
+
dtype_full = [(attribute, "f4") for attribute in attributes]
|
| 164 |
+
elements = np.empty(means.shape[0], dtype=dtype_full)
|
| 165 |
+
|
| 166 |
+
attributes_data = (
|
| 167 |
+
means.float().detach().cpu().numpy(),
|
| 168 |
+
torch.zeros_like(means).float().detach().cpu().numpy(),
|
| 169 |
+
rgbs.detach().cpu().contiguous().numpy(),
|
| 170 |
+
opacities[..., None].detach().cpu().numpy(),
|
| 171 |
+
scales.log().detach().cpu().numpy(),
|
| 172 |
+
rotations.detach().cpu().numpy(),
|
| 173 |
+
)
|
| 174 |
+
attributes_data = np.concatenate(attributes_data, axis=1)
|
| 175 |
+
elements[:] = list(map(tuple, attributes_data))
|
| 176 |
+
|
| 177 |
+
return PlyData([PlyElement.describe(elements, "vertex")])
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def save_gs_ply(path: Path,
|
| 181 |
+
means: torch.Tensor,
|
| 182 |
+
scales: torch.Tensor,
|
| 183 |
+
rotations: torch.Tensor,
|
| 184 |
+
rgbs: torch.Tensor,
|
| 185 |
+
opacities: torch.Tensor) -> None:
|
| 186 |
+
"""
|
| 187 |
+
Export Gaussian splat data to PLY format.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
path: Output PLY file path
|
| 191 |
+
means: Gaussian centers [N, 3]
|
| 192 |
+
scales: Gaussian scales [N, 3]
|
| 193 |
+
rotations: Gaussian rotations as quaternions [N, 4]
|
| 194 |
+
rgbs: RGB colors [N, 3]
|
| 195 |
+
opacities: Opacity values [N]
|
| 196 |
+
"""
|
| 197 |
+
# Ensure float32 for quantile and numpy conversion (bf16 not supported)
|
| 198 |
+
means, scales, rotations, rgbs, opacities = (
|
| 199 |
+
t.float() for t in (means, scales, rotations, rgbs, opacities)
|
| 200 |
+
)
|
| 201 |
+
plydata = _build_gs_ply_data(means, scales, rotations, rgbs, opacities, quantile_threshold=0.98)
|
| 202 |
+
plydata.write(str(path))
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def convert_gs_to_ply(means, scales, rotations, rgbs, opacities):
|
| 206 |
+
"""
|
| 207 |
+
Export Gaussian splat data to PLY format.
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
means: Gaussian centers [N, 3]
|
| 211 |
+
scales: Gaussian scales [N, 3]
|
| 212 |
+
rotations: Gaussian rotations as quaternions [N, 4]
|
| 213 |
+
rgbs: RGB colors [N, 3]
|
| 214 |
+
opacities: Opacity values [N]
|
| 215 |
+
"""
|
| 216 |
+
return _build_gs_ply_data(means, scales, rotations, rgbs, opacities, quantile_threshold=0.90)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def process_ply_to_splat(plydata, output_path):
|
| 220 |
+
vert = plydata["vertex"]
|
| 221 |
+
sorted_indices = np.argsort(
|
| 222 |
+
-np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
|
| 223 |
+
/ (1 + np.exp(-vert["opacity"]))
|
| 224 |
+
)
|
| 225 |
+
buffer = BytesIO()
|
| 226 |
+
for idx in sorted_indices:
|
| 227 |
+
v = plydata["vertex"][idx]
|
| 228 |
+
position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
|
| 229 |
+
scales = np.exp(
|
| 230 |
+
np.array(
|
| 231 |
+
[v["scale_0"], v["scale_1"], v["scale_2"]],
|
| 232 |
+
dtype=np.float32,
|
| 233 |
+
)
|
| 234 |
+
)
|
| 235 |
+
rot = np.array(
|
| 236 |
+
[v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
|
| 237 |
+
dtype=np.float32,
|
| 238 |
+
)
|
| 239 |
+
SH_C0 = 0.28209479177387814
|
| 240 |
+
color = np.array(
|
| 241 |
+
[
|
| 242 |
+
0.5 + SH_C0 * v["f_dc_0"],
|
| 243 |
+
0.5 + SH_C0 * v["f_dc_1"],
|
| 244 |
+
0.5 + SH_C0 * v["f_dc_2"],
|
| 245 |
+
1 / (1 + np.exp(-v["opacity"])),
|
| 246 |
+
]
|
| 247 |
+
)
|
| 248 |
+
buffer.write(position.tobytes())
|
| 249 |
+
buffer.write(scales.tobytes())
|
| 250 |
+
buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
|
| 251 |
+
buffer.write(
|
| 252 |
+
((rot / np.linalg.norm(rot)) * 128 + 128)
|
| 253 |
+
.clip(0, 255)
|
| 254 |
+
.astype(np.uint8)
|
| 255 |
+
.tobytes()
|
| 256 |
+
)
|
| 257 |
+
value = buffer.getvalue()
|
| 258 |
+
with open(output_path, "wb") as f:
|
| 259 |
+
f.write(value)
|
| 260 |
+
|
| 261 |
+
return output_path
|
hyworldmirror/utils/video_utils.py
ADDED
|
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import csv
|
| 4 |
+
import time
|
| 5 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import subprocess
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def video_to_image_frames(input_video_path, save_directory=None, fps=1):
|
| 13 |
+
"""
|
| 14 |
+
Extracts image frames from a video file at the specified frame rate and saves them as JPEG format.
|
| 15 |
+
Supports regular video files, webcam captures, WebM files, and GIF files, including incomplete files.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
input_video_path: Path to the input video file
|
| 19 |
+
save_directory: Directory to save extracted frames (default: None)
|
| 20 |
+
fps: Number of frames to extract per second (default: 1)
|
| 21 |
+
|
| 22 |
+
Returns: List of file paths to extracted frames
|
| 23 |
+
"""
|
| 24 |
+
extracted_frame_paths = []
|
| 25 |
+
frame_indices = [] # Track frame indices for metadata
|
| 26 |
+
source_fps = None
|
| 27 |
+
|
| 28 |
+
# For GIF files, use PIL library for better handling
|
| 29 |
+
if input_video_path.lower().endswith('.gif'):
|
| 30 |
+
try:
|
| 31 |
+
print(f"Processing GIF file using PIL: {input_video_path}")
|
| 32 |
+
|
| 33 |
+
with Image.open(input_video_path) as gif_img:
|
| 34 |
+
# Get GIF properties
|
| 35 |
+
frame_duration_ms = gif_img.info.get('duration', 100)
|
| 36 |
+
gif_frame_rate = 1000.0 / frame_duration_ms if frame_duration_ms > 0 else 10.0
|
| 37 |
+
source_fps = gif_frame_rate
|
| 38 |
+
|
| 39 |
+
print(f"GIF properties: {gif_img.n_frames} frames, {gif_frame_rate:.2f} FPS, {frame_duration_ms}ms per frame")
|
| 40 |
+
|
| 41 |
+
sampling_interval = max(1, int(gif_frame_rate / fps)) if fps < gif_frame_rate else 1
|
| 42 |
+
|
| 43 |
+
saved_count = 0
|
| 44 |
+
for current_frame_index in range(gif_img.n_frames):
|
| 45 |
+
gif_img.seek(current_frame_index)
|
| 46 |
+
|
| 47 |
+
if current_frame_index % sampling_interval == 0:
|
| 48 |
+
rgb_frame = gif_img.convert('RGB')
|
| 49 |
+
frame_ndarray = np.array(rgb_frame)
|
| 50 |
+
frame_output_path = os.path.join(save_directory, f"frame_{saved_count:06d}.jpg")
|
| 51 |
+
pil_image = Image.fromarray(frame_ndarray)
|
| 52 |
+
pil_image.save(frame_output_path, 'JPEG', quality=95)
|
| 53 |
+
extracted_frame_paths.append(frame_output_path)
|
| 54 |
+
frame_indices.append(current_frame_index)
|
| 55 |
+
saved_count += 1
|
| 56 |
+
|
| 57 |
+
if extracted_frame_paths:
|
| 58 |
+
print(f"Successfully extracted {len(extracted_frame_paths)} frames from GIF using PIL")
|
| 59 |
+
# Save metadata
|
| 60 |
+
_save_old_metadata(save_directory, frame_indices, source_fps)
|
| 61 |
+
return extracted_frame_paths
|
| 62 |
+
|
| 63 |
+
except Exception as error:
|
| 64 |
+
print(f"PIL GIF extraction error: {str(error)}, falling back to OpenCV")
|
| 65 |
+
|
| 66 |
+
# For WebM files, use FFmpeg directly for more stable processing
|
| 67 |
+
if input_video_path.lower().endswith('.webm'):
|
| 68 |
+
try:
|
| 69 |
+
print(f"Processing WebM file using FFmpeg: {input_video_path}")
|
| 70 |
+
|
| 71 |
+
# Get video FPS first
|
| 72 |
+
cap = cv2.VideoCapture(input_video_path)
|
| 73 |
+
source_fps = cap.get(cv2.CAP_PROP_FPS) or fps
|
| 74 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 75 |
+
cap.release()
|
| 76 |
+
|
| 77 |
+
output_frame_pattern = os.path.join(save_directory, "frame_%04d.jpg")
|
| 78 |
+
|
| 79 |
+
ffmpeg_command = [
|
| 80 |
+
"ffmpeg",
|
| 81 |
+
"-i", input_video_path,
|
| 82 |
+
"-vf", f"fps={fps}",
|
| 83 |
+
"-q:v", "2",
|
| 84 |
+
output_frame_pattern
|
| 85 |
+
]
|
| 86 |
+
|
| 87 |
+
ffmpeg_process = subprocess.Popen(
|
| 88 |
+
ffmpeg_command,
|
| 89 |
+
stdout=subprocess.PIPE,
|
| 90 |
+
stderr=subprocess.PIPE
|
| 91 |
+
)
|
| 92 |
+
process_stdout, process_stderr = ffmpeg_process.communicate()
|
| 93 |
+
|
| 94 |
+
# Collect all extracted frames and calculate indices
|
| 95 |
+
extracted_frame_paths = []
|
| 96 |
+
for filename in sorted(os.listdir(save_directory)):
|
| 97 |
+
if filename.startswith("frame_") and filename.endswith(".jpg"):
|
| 98 |
+
full_frame_path = os.path.join(save_directory, filename)
|
| 99 |
+
extracted_frame_paths.append(full_frame_path)
|
| 100 |
+
# Extract frame number from filename (frame_XXXX.jpg)
|
| 101 |
+
try:
|
| 102 |
+
frame_num = int(filename.split("_")[1].split(".")[0])
|
| 103 |
+
# Estimate original frame index based on fps ratio
|
| 104 |
+
frame_idx = int(frame_num * source_fps / fps)
|
| 105 |
+
frame_indices.append(frame_idx)
|
| 106 |
+
except:
|
| 107 |
+
frame_indices.append(len(frame_indices))
|
| 108 |
+
|
| 109 |
+
if extracted_frame_paths:
|
| 110 |
+
print(f"Successfully extracted {len(extracted_frame_paths)} frames from WebM using FFmpeg")
|
| 111 |
+
_save_old_metadata(save_directory, frame_indices, source_fps)
|
| 112 |
+
return extracted_frame_paths
|
| 113 |
+
|
| 114 |
+
print("FFmpeg extraction failed, falling back to OpenCV")
|
| 115 |
+
except Exception as error:
|
| 116 |
+
print(f"FFmpeg extraction error: {str(error)}, falling back to OpenCV")
|
| 117 |
+
|
| 118 |
+
# Standard OpenCV method for non-WebM files or as fallback
|
| 119 |
+
try:
|
| 120 |
+
video_capture = cv2.VideoCapture(input_video_path)
|
| 121 |
+
|
| 122 |
+
if input_video_path.lower().endswith('.webm'):
|
| 123 |
+
video_capture.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'VP80'))
|
| 124 |
+
|
| 125 |
+
source_fps = video_capture.get(cv2.CAP_PROP_FPS) or fps
|
| 126 |
+
extraction_interval = max(1, int(source_fps / fps))
|
| 127 |
+
processed_frame_count = 0
|
| 128 |
+
|
| 129 |
+
cv2.setLogLevel(0)
|
| 130 |
+
|
| 131 |
+
while True:
|
| 132 |
+
read_success, current_frame = video_capture.read()
|
| 133 |
+
if not read_success:
|
| 134 |
+
break
|
| 135 |
+
|
| 136 |
+
if processed_frame_count % extraction_interval == 0:
|
| 137 |
+
try:
|
| 138 |
+
if current_frame is not None and current_frame.size > 0:
|
| 139 |
+
rgb_converted_frame = cv2.cvtColor(current_frame, cv2.COLOR_BGR2RGB)
|
| 140 |
+
frame_output_path = os.path.join(save_directory, f"frame_{len(extracted_frame_paths):06d}.jpg")
|
| 141 |
+
cv2.imwrite(frame_output_path, cv2.cvtColor(rgb_converted_frame, cv2.COLOR_RGB2BGR))
|
| 142 |
+
extracted_frame_paths.append(frame_output_path)
|
| 143 |
+
frame_indices.append(processed_frame_count)
|
| 144 |
+
except Exception as error:
|
| 145 |
+
print(f"Warning: Failed to process frame {processed_frame_count}: {str(error)}")
|
| 146 |
+
|
| 147 |
+
processed_frame_count += 1
|
| 148 |
+
|
| 149 |
+
if processed_frame_count > 1000:
|
| 150 |
+
break
|
| 151 |
+
|
| 152 |
+
video_capture.release()
|
| 153 |
+
print(f"Extracted {len(extracted_frame_paths)} frames from video using OpenCV")
|
| 154 |
+
|
| 155 |
+
# Save metadata
|
| 156 |
+
if extracted_frame_paths:
|
| 157 |
+
_save_old_metadata(save_directory, frame_indices, source_fps)
|
| 158 |
+
|
| 159 |
+
except Exception as error:
|
| 160 |
+
print(f"Error extracting frames: {str(error)}")
|
| 161 |
+
|
| 162 |
+
return extracted_frame_paths
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _save_old_metadata(save_directory, frame_indices, fps):
|
| 166 |
+
"""Save metadata for old sampling strategy."""
|
| 167 |
+
if not frame_indices or not fps:
|
| 168 |
+
return
|
| 169 |
+
|
| 170 |
+
try:
|
| 171 |
+
meta = {
|
| 172 |
+
"frame_indices": frame_indices,
|
| 173 |
+
"frame_times": [idx / fps for idx in frame_indices],
|
| 174 |
+
"fps": fps,
|
| 175 |
+
"algorithm": "uniform_fps_based"
|
| 176 |
+
}
|
| 177 |
+
metadata_path = os.path.join(save_directory, "frame_metadata.json")
|
| 178 |
+
with open(metadata_path, "w") as f:
|
| 179 |
+
json.dump(meta, f, indent=2)
|
| 180 |
+
except Exception as e:
|
| 181 |
+
print(f"Warning: Failed to save metadata: {e}")
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _resize_for_flow(frame, long_edge=320):
|
| 185 |
+
height, width = frame.shape[:2]
|
| 186 |
+
long_side = max(height, width)
|
| 187 |
+
if long_side <= long_edge:
|
| 188 |
+
return frame
|
| 189 |
+
scale = long_edge / float(long_side)
|
| 190 |
+
new_w = max(1, int(width * scale))
|
| 191 |
+
new_h = max(1, int(height * scale))
|
| 192 |
+
return cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _resize_for_clarity(frame, long_edge=480):
|
| 196 |
+
"""Resize frame for clarity calculation (480p for better accuracy)."""
|
| 197 |
+
height, width = frame.shape[:2]
|
| 198 |
+
long_side = max(height, width)
|
| 199 |
+
if long_side <= long_edge:
|
| 200 |
+
return frame
|
| 201 |
+
scale = long_edge / float(long_side)
|
| 202 |
+
new_w = max(1, int(width * scale))
|
| 203 |
+
new_h = max(1, int(height * scale))
|
| 204 |
+
return cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _create_dis_flow():
|
| 208 |
+
if hasattr(cv2, "optflow") and hasattr(cv2.optflow, "createOptFlow_DIS"):
|
| 209 |
+
return cv2.optflow.createOptFlow_DIS(cv2.optflow.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
|
| 210 |
+
if hasattr(cv2, "DISOpticalFlow_create"):
|
| 211 |
+
return cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
|
| 212 |
+
return None
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _calculate_histogram(image):
|
| 216 |
+
"""
|
| 217 |
+
Calculate normalized color histogram for global deduplication.
|
| 218 |
+
Using HSV for better robustness to brightness changes.
|
| 219 |
+
"""
|
| 220 |
+
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
| 221 |
+
# 8 bins for H, 4 for S, 4 for V -> 128 dim vector
|
| 222 |
+
hist = cv2.calcHist([hsv], [0, 1, 2], None, [8, 4, 4], [0, 180, 0, 256, 0, 256])
|
| 223 |
+
cv2.normalize(hist, hist)
|
| 224 |
+
return hist.flatten()
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _calculate_hist_similarity(hist1, hist2):
|
| 228 |
+
return cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _advance_cap_to_frame(cap, current_pos, target_idx):
|
| 232 |
+
"""Advance cap so that next read() returns frame target_idx. Returns target_idx."""
|
| 233 |
+
dist = target_idx - current_pos
|
| 234 |
+
if dist <= 0:
|
| 235 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, target_idx)
|
| 236 |
+
return target_idx
|
| 237 |
+
if dist < 50:
|
| 238 |
+
for _ in range(dist):
|
| 239 |
+
cap.grab()
|
| 240 |
+
return target_idx
|
| 241 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, target_idx)
|
| 242 |
+
return target_idx
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def _merge_search_windows(candidate_indices, window_size=3):
|
| 246 |
+
"""
|
| 247 |
+
Merge adjacent search windows to reduce disk seeks.
|
| 248 |
+
Returns list of (start_idx, end_idx, target_indices) tuples.
|
| 249 |
+
"""
|
| 250 |
+
if not candidate_indices:
|
| 251 |
+
return []
|
| 252 |
+
|
| 253 |
+
merged = []
|
| 254 |
+
sorted_indices = sorted(candidate_indices)
|
| 255 |
+
i = 0
|
| 256 |
+
|
| 257 |
+
while i < len(sorted_indices):
|
| 258 |
+
start_idx = max(0, sorted_indices[i] - window_size)
|
| 259 |
+
end_idx = sorted_indices[i] + window_size
|
| 260 |
+
targets_in_window = [sorted_indices[i]]
|
| 261 |
+
|
| 262 |
+
# Extend window to include adjacent candidates
|
| 263 |
+
j = i + 1
|
| 264 |
+
while j < len(sorted_indices):
|
| 265 |
+
next_start = max(0, sorted_indices[j] - window_size)
|
| 266 |
+
if next_start <= end_idx:
|
| 267 |
+
end_idx = sorted_indices[j] + window_size
|
| 268 |
+
targets_in_window.append(sorted_indices[j])
|
| 269 |
+
j += 1
|
| 270 |
+
else:
|
| 271 |
+
break
|
| 272 |
+
|
| 273 |
+
merged.append((start_idx, end_idx, targets_in_window))
|
| 274 |
+
i = j
|
| 275 |
+
|
| 276 |
+
return merged
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def _sparse_motion_analysis(cap, fps, total_frames):
|
| 280 |
+
"""Phase 1: Sparse sampling with DIS optical flow."""
|
| 281 |
+
sample_interval = max(1, int(fps * 0.5))
|
| 282 |
+
sparse_samples = []
|
| 283 |
+
dis_flow = _create_dis_flow()
|
| 284 |
+
current_idx = 0
|
| 285 |
+
prev_gray = None
|
| 286 |
+
|
| 287 |
+
while True:
|
| 288 |
+
if current_idx > 0:
|
| 289 |
+
steps_to_skip = sample_interval - 1
|
| 290 |
+
if steps_to_skip > 0:
|
| 291 |
+
current_idx = _advance_cap_to_frame(cap, current_idx, current_idx + steps_to_skip)
|
| 292 |
+
ret, frame = cap.read()
|
| 293 |
+
if not ret:
|
| 294 |
+
break
|
| 295 |
+
|
| 296 |
+
small = _resize_for_flow(frame, long_edge=320)
|
| 297 |
+
gray = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY)
|
| 298 |
+
|
| 299 |
+
motion_mag = 0.0
|
| 300 |
+
if prev_gray is not None:
|
| 301 |
+
if dis_flow is not None:
|
| 302 |
+
flow = dis_flow.calc(prev_gray, gray, None)
|
| 303 |
+
else:
|
| 304 |
+
flow = cv2.calcOpticalFlowFarneback(prev_gray, gray, None, 0.5, 3, 15, 2, 5, 1.2, 0)
|
| 305 |
+
motion_mag = float(np.mean(np.sqrt(flow[..., 0]**2 + flow[..., 1]**2)))
|
| 306 |
+
|
| 307 |
+
sparse_samples.append({
|
| 308 |
+
"idx": current_idx,
|
| 309 |
+
"motion": motion_mag,
|
| 310 |
+
"hist": _calculate_histogram(small)
|
| 311 |
+
})
|
| 312 |
+
prev_gray = gray
|
| 313 |
+
current_idx += 1
|
| 314 |
+
|
| 315 |
+
return sparse_samples
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def _adaptive_frame_selection(sparse_samples, fps, max_frames):
|
| 319 |
+
"""Phase 2: Adaptive threshold allocation with deduplication."""
|
| 320 |
+
motions = [s["motion"] for s in sparse_samples[1:]]
|
| 321 |
+
|
| 322 |
+
if not motions:
|
| 323 |
+
return [sparse_samples[0]["idx"]]
|
| 324 |
+
|
| 325 |
+
# Calculate adaptive threshold
|
| 326 |
+
static_floor = 1.0
|
| 327 |
+
total_motion = sum(motions)
|
| 328 |
+
estimated_step = total_motion / max_frames if max_frames > 0 else total_motion
|
| 329 |
+
step_threshold = max(static_floor * 5.0, estimated_step)
|
| 330 |
+
|
| 331 |
+
# Select frames based on accumulated motion
|
| 332 |
+
candidate_indices = [sparse_samples[0]["idx"]]
|
| 333 |
+
selected_hists = [sparse_samples[0]["hist"]]
|
| 334 |
+
current_accum = 0.0
|
| 335 |
+
last_selected_idx = sparse_samples[0]["idx"]
|
| 336 |
+
|
| 337 |
+
for i in range(1, len(sparse_samples)):
|
| 338 |
+
s = sparse_samples[i]
|
| 339 |
+
effective_motion = s["motion"] if s["motion"] >= static_floor else 0.0
|
| 340 |
+
current_accum += effective_motion
|
| 341 |
+
|
| 342 |
+
time_gap = s["idx"] - last_selected_idx
|
| 343 |
+
should_select = (current_accum >= step_threshold) or (time_gap > (4.0 * fps))
|
| 344 |
+
|
| 345 |
+
if should_select:
|
| 346 |
+
is_duplicate = any(_calculate_hist_similarity(s["hist"], h) > 0.999 for h in selected_hists)
|
| 347 |
+
if not is_duplicate:
|
| 348 |
+
candidate_indices.append(s["idx"])
|
| 349 |
+
selected_hists.append(s["hist"])
|
| 350 |
+
current_accum = 0.0
|
| 351 |
+
last_selected_idx = s["idx"]
|
| 352 |
+
|
| 353 |
+
# Always check last frame
|
| 354 |
+
if sparse_samples[-1]["idx"] != candidate_indices[-1]:
|
| 355 |
+
last_hist = sparse_samples[-1]["hist"]
|
| 356 |
+
if not any(_calculate_hist_similarity(last_hist, h) > 0.999 for h in selected_hists):
|
| 357 |
+
candidate_indices.append(sparse_samples[-1]["idx"])
|
| 358 |
+
|
| 359 |
+
return sorted(list(set(candidate_indices)))
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def _enforce_frame_constraints(candidate_indices, sparse_samples, min_frames, max_frames):
|
| 363 |
+
"""Enforce min/max frame constraints."""
|
| 364 |
+
if len(candidate_indices) < min_frames:
|
| 365 |
+
needed = min_frames - len(candidate_indices)
|
| 366 |
+
all_indices = [s["idx"] for s in sparse_samples]
|
| 367 |
+
extras = np.linspace(0, len(all_indices)-1, needed+2)[1:-1]
|
| 368 |
+
candidate_indices.extend([all_indices[int(e)] for e in extras])
|
| 369 |
+
candidate_indices = sorted(list(set(candidate_indices)))
|
| 370 |
+
|
| 371 |
+
if len(candidate_indices) > max_frames:
|
| 372 |
+
indices_to_keep = np.linspace(0, len(candidate_indices)-1, max_frames)
|
| 373 |
+
candidate_indices = [candidate_indices[int(round(i))] for i in indices_to_keep]
|
| 374 |
+
|
| 375 |
+
return candidate_indices
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def _read_window_frames(cap, merged_windows, total_frames):
|
| 379 |
+
"""Read all frames from merged windows."""
|
| 380 |
+
all_frames = []
|
| 381 |
+
current_pos = -1
|
| 382 |
+
|
| 383 |
+
for window_idx, (window_start, window_end, _) in enumerate(merged_windows):
|
| 384 |
+
current_pos = _advance_cap_to_frame(cap, current_pos, window_start)
|
| 385 |
+
for idx in range(window_start, min(window_end + 1, total_frames)):
|
| 386 |
+
ret, frame = cap.read()
|
| 387 |
+
if not ret:
|
| 388 |
+
break
|
| 389 |
+
all_frames.append((window_idx, idx, frame))
|
| 390 |
+
current_pos = idx + 1
|
| 391 |
+
|
| 392 |
+
return all_frames
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def _compute_clarity_parallel(all_frames):
|
| 396 |
+
"""Parallel clarity calculation."""
|
| 397 |
+
def _compute(item):
|
| 398 |
+
window_idx, frame_idx, frame = item
|
| 399 |
+
clarity_frame = _resize_for_clarity(frame, long_edge=480)
|
| 400 |
+
gray = cv2.cvtColor(clarity_frame, cv2.COLOR_BGR2GRAY)
|
| 401 |
+
clarity = cv2.Laplacian(gray, cv2.CV_64F).var()
|
| 402 |
+
return (window_idx, frame_idx, frame, clarity)
|
| 403 |
+
|
| 404 |
+
with ThreadPoolExecutor(max_workers=min(8, len(all_frames) or 1)) as ex:
|
| 405 |
+
return list(ex.map(_compute, all_frames))
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def _select_best_frames(clarity_results, merged_windows, candidate_indices, search_window_size=3):
|
| 409 |
+
"""Select best frame for each candidate based on clarity."""
|
| 410 |
+
# Group by window
|
| 411 |
+
window_frames = {}
|
| 412 |
+
for window_idx, frame_idx, frame, clarity in clarity_results:
|
| 413 |
+
if window_idx not in window_frames:
|
| 414 |
+
window_frames[window_idx] = []
|
| 415 |
+
window_frames[window_idx].append((frame_idx, frame, clarity))
|
| 416 |
+
|
| 417 |
+
# Select best frame for each target
|
| 418 |
+
target_to_best = {}
|
| 419 |
+
for window_idx, (_, _, targets) in enumerate(merged_windows):
|
| 420 |
+
frames = window_frames.get(window_idx, [])
|
| 421 |
+
for target_idx in targets:
|
| 422 |
+
candidates = [(idx, f, c) for idx, f, c in frames
|
| 423 |
+
if abs(idx - target_idx) <= search_window_size]
|
| 424 |
+
if candidates:
|
| 425 |
+
best_idx, best_frame, _ = max(candidates, key=lambda x: x[2])
|
| 426 |
+
target_to_best[target_idx] = (best_idx, best_frame)
|
| 427 |
+
elif frames:
|
| 428 |
+
closest = min(frames, key=lambda x: abs(x[0] - target_idx))
|
| 429 |
+
target_to_best[target_idx] = (closest[0], closest[1])
|
| 430 |
+
|
| 431 |
+
return target_to_best
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def _save_frames_parallel(target_to_best, candidate_indices, save_directory):
|
| 435 |
+
"""Parallel frame saving."""
|
| 436 |
+
path_frame_list = []
|
| 437 |
+
final_indices = []
|
| 438 |
+
|
| 439 |
+
for target_idx in sorted(candidate_indices):
|
| 440 |
+
if target_idx in target_to_best:
|
| 441 |
+
best_idx, best_frame = target_to_best[target_idx]
|
| 442 |
+
final_indices.append(best_idx)
|
| 443 |
+
path_frame_list.append((
|
| 444 |
+
os.path.join(save_directory, f"frame_{len(path_frame_list):06d}.jpg"),
|
| 445 |
+
best_frame
|
| 446 |
+
))
|
| 447 |
+
|
| 448 |
+
def _write(p_f):
|
| 449 |
+
cv2.imwrite(p_f[0], p_f[1])
|
| 450 |
+
return p_f[0]
|
| 451 |
+
|
| 452 |
+
with ThreadPoolExecutor(max_workers=min(8, len(path_frame_list) or 1)) as ex:
|
| 453 |
+
paths = list(ex.map(_write, path_frame_list))
|
| 454 |
+
|
| 455 |
+
return final_indices, paths
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def video_to_image_frames_new(
|
| 459 |
+
input_video_path,
|
| 460 |
+
save_directory=None,
|
| 461 |
+
min_frames=1,
|
| 462 |
+
max_frames=64,
|
| 463 |
+
fallback_fps=1,
|
| 464 |
+
):
|
| 465 |
+
"""
|
| 466 |
+
Motion-aware frame extraction with local clarity refinement.
|
| 467 |
+
|
| 468 |
+
Strategy:
|
| 469 |
+
1. Sparse sampling (~0.5s) with DIS optical flow
|
| 470 |
+
2. Adaptive threshold allocation based on motion
|
| 471 |
+
3. Local clarity refinement (±3 frames) to avoid blur
|
| 472 |
+
"""
|
| 473 |
+
if save_directory is None:
|
| 474 |
+
raise ValueError("save_directory must be provided")
|
| 475 |
+
|
| 476 |
+
max_frames = int(np.clip(max_frames, 1, 64))
|
| 477 |
+
min_frames = int(np.clip(min_frames, 1, max_frames))
|
| 478 |
+
|
| 479 |
+
cap = cv2.VideoCapture(input_video_path)
|
| 480 |
+
if not cap.isOpened():
|
| 481 |
+
print(f"Error: Failed to open video {input_video_path}")
|
| 482 |
+
return []
|
| 483 |
+
|
| 484 |
+
fps = cap.get(cv2.CAP_PROP_FPS) or fallback_fps or 30.0
|
| 485 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 486 |
+
t_start = time.perf_counter()
|
| 487 |
+
|
| 488 |
+
# Phase 1: Sparse motion analysis
|
| 489 |
+
sparse_samples = _sparse_motion_analysis(cap, fps, total_frames)
|
| 490 |
+
cap.release()
|
| 491 |
+
|
| 492 |
+
t_phase1 = time.perf_counter()
|
| 493 |
+
print(f"[Timing] Phase 1 (Sparse Flow): {t_phase1 - t_start:.3f}s, Samples: {len(sparse_samples)}")
|
| 494 |
+
|
| 495 |
+
if not sparse_samples:
|
| 496 |
+
return []
|
| 497 |
+
|
| 498 |
+
# Phase 2: Adaptive frame selection
|
| 499 |
+
candidate_indices = _adaptive_frame_selection(sparse_samples, fps, max_frames)
|
| 500 |
+
candidate_indices = _enforce_frame_constraints(candidate_indices, sparse_samples, min_frames, max_frames)
|
| 501 |
+
|
| 502 |
+
# Phase 3: Local clarity refinement
|
| 503 |
+
cap = cv2.VideoCapture(input_video_path)
|
| 504 |
+
if not cap.isOpened():
|
| 505 |
+
return []
|
| 506 |
+
|
| 507 |
+
t_phase3_start = time.perf_counter()
|
| 508 |
+
search_window_size = 3
|
| 509 |
+
merged_windows = _merge_search_windows(candidate_indices, window_size=search_window_size)
|
| 510 |
+
|
| 511 |
+
# Read frames
|
| 512 |
+
t_read_start = time.perf_counter()
|
| 513 |
+
all_frames = _read_window_frames(cap, merged_windows, total_frames)
|
| 514 |
+
cap.release()
|
| 515 |
+
t_read_end = time.perf_counter()
|
| 516 |
+
|
| 517 |
+
# Parallel clarity calculation
|
| 518 |
+
t_clarity_start = time.perf_counter()
|
| 519 |
+
clarity_results = _compute_clarity_parallel(all_frames)
|
| 520 |
+
t_clarity_end = time.perf_counter()
|
| 521 |
+
|
| 522 |
+
# Select best frames
|
| 523 |
+
target_to_best = _select_best_frames(clarity_results, merged_windows, candidate_indices, search_window_size)
|
| 524 |
+
|
| 525 |
+
# Parallel save
|
| 526 |
+
t_save_start = time.perf_counter()
|
| 527 |
+
final_indices, extracted_paths = _save_frames_parallel(target_to_best, candidate_indices, save_directory)
|
| 528 |
+
t_save_end = time.perf_counter()
|
| 529 |
+
|
| 530 |
+
t_phase3_end = time.perf_counter()
|
| 531 |
+
print(f"[Timing] Phase 3 (Clarity Refinement + Save): {t_phase3_end - t_phase3_start:.3f}s")
|
| 532 |
+
print(f" - Read frames: {t_read_end - t_read_start:.3f}s")
|
| 533 |
+
print(f" - Parallel clarity: {t_clarity_end - t_clarity_start:.3f}s")
|
| 534 |
+
print(f" - Parallel save: {t_save_end - t_save_start:.3f}s, Saved: {len(extracted_paths)}")
|
| 535 |
+
|
| 536 |
+
# Save metadata
|
| 537 |
+
try:
|
| 538 |
+
meta = {
|
| 539 |
+
"frame_indices": final_indices,
|
| 540 |
+
"frame_times": [i/fps for i in final_indices],
|
| 541 |
+
"fps": fps,
|
| 542 |
+
"algorithm": "sparse_dis_clarity_refined"
|
| 543 |
+
}
|
| 544 |
+
with open(os.path.join(save_directory, "frame_metadata.json"), "w") as f:
|
| 545 |
+
json.dump(meta, f, indent=2)
|
| 546 |
+
|
| 547 |
+
with open(os.path.join(save_directory, "frame_metrics.csv"), "w", newline="") as f:
|
| 548 |
+
writer = csv.writer(f)
|
| 549 |
+
writer.writerow(["frame_index", "time_sec", "motion", "selected"])
|
| 550 |
+
for s in sparse_samples:
|
| 551 |
+
writer.writerow([s["idx"], s["idx"]/fps, s["motion"],
|
| 552 |
+
1 if s["idx"] in final_indices else 0])
|
| 553 |
+
except:
|
| 554 |
+
pass
|
| 555 |
+
|
| 556 |
+
print(f"Extracted {len(extracted_paths)} frames using DIS flow + local clarity refinement.")
|
| 557 |
+
return extracted_paths
|
hyworldmirror/utils/visual_util.py
ADDED
|
@@ -0,0 +1,617 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Visual utilities for HuggingFace integration.
|
| 2 |
+
|
| 3 |
+
References: https://github.com/facebookresearch/vggt
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import os
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
import cv2
|
| 11 |
+
import matplotlib
|
| 12 |
+
import numpy as np
|
| 13 |
+
import requests
|
| 14 |
+
import trimesh
|
| 15 |
+
|
| 16 |
+
from scipy.spatial.transform import Rotation
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def segment_sky(image_or_path, onnx_session):
|
| 20 |
+
"""
|
| 21 |
+
Segments sky from an image using an ONNX model.
|
| 22 |
+
Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
image_or_path: Path to input image (str) or BGR numpy array (H, W, 3)
|
| 26 |
+
onnx_session: ONNX runtime session with loaded model
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
np.ndarray: Binary mask where 255 indicates non-sky regions
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
if isinstance(image_or_path, (str, os.PathLike)):
|
| 33 |
+
image = cv2.imread(str(image_or_path))
|
| 34 |
+
else:
|
| 35 |
+
image = image_or_path
|
| 36 |
+
result_map = run_skyseg(onnx_session, [320, 320], image)
|
| 37 |
+
# resize the result_map to the original image size
|
| 38 |
+
result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
|
| 39 |
+
|
| 40 |
+
# Fix: Invert the mask so that 255 = non-sky, 0 = sky
|
| 41 |
+
# The model outputs low values for sky, high values for non-sky
|
| 42 |
+
output_mask = np.zeros_like(result_map_original)
|
| 43 |
+
output_mask[result_map_original < 32] = 255 # Use threshold of 32
|
| 44 |
+
return output_mask
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def run_skyseg(onnx_session, input_size, image):
|
| 48 |
+
"""
|
| 49 |
+
Runs sky segmentation inference using ONNX model.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
onnx_session: ONNX runtime session
|
| 53 |
+
input_size: Target size for model input (width, height)
|
| 54 |
+
image: Input image in BGR format
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
np.ndarray: Segmentation mask
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
# Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast
|
| 61 |
+
temp_image = copy.deepcopy(image)
|
| 62 |
+
resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
|
| 63 |
+
x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
|
| 64 |
+
x = np.array(x, dtype=np.float32)
|
| 65 |
+
mean = [0.485, 0.456, 0.406]
|
| 66 |
+
std = [0.229, 0.224, 0.225]
|
| 67 |
+
x = (x / 255 - mean) / std
|
| 68 |
+
x = x.transpose(2, 0, 1)
|
| 69 |
+
x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")
|
| 70 |
+
|
| 71 |
+
# Inference
|
| 72 |
+
input_name = onnx_session.get_inputs()[0].name
|
| 73 |
+
output_name = onnx_session.get_outputs()[0].name
|
| 74 |
+
onnx_result = onnx_session.run([output_name], {input_name: x})
|
| 75 |
+
|
| 76 |
+
# Post process
|
| 77 |
+
onnx_result = np.array(onnx_result).squeeze()
|
| 78 |
+
min_value = np.min(onnx_result)
|
| 79 |
+
max_value = np.max(onnx_result)
|
| 80 |
+
onnx_result = (onnx_result - min_value) / (max_value - min_value)
|
| 81 |
+
onnx_result *= 255
|
| 82 |
+
onnx_result = onnx_result.astype("uint8")
|
| 83 |
+
|
| 84 |
+
return onnx_result
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def download_file_from_url(url, filename):
|
| 88 |
+
"""Downloads a file from a Hugging Face model repo, handling redirects."""
|
| 89 |
+
try:
|
| 90 |
+
# Get the redirect URL
|
| 91 |
+
response = requests.get(url, allow_redirects=False)
|
| 92 |
+
response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx)
|
| 93 |
+
|
| 94 |
+
if response.status_code == 302: # Expecting a redirect
|
| 95 |
+
redirect_url = response.headers["Location"]
|
| 96 |
+
response = requests.get(redirect_url, stream=True)
|
| 97 |
+
response.raise_for_status()
|
| 98 |
+
else:
|
| 99 |
+
print(f"Unexpected status code: {response.status_code}")
|
| 100 |
+
return
|
| 101 |
+
|
| 102 |
+
with open(filename, "wb") as f:
|
| 103 |
+
for chunk in response.iter_content(chunk_size=8192):
|
| 104 |
+
f.write(chunk)
|
| 105 |
+
print(f"Downloaded {filename} successfully.")
|
| 106 |
+
|
| 107 |
+
except requests.exceptions.RequestException as e:
|
| 108 |
+
print(f"Error downloading file: {e}")
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def create_image_mesh(
|
| 112 |
+
*image_data: np.ndarray,
|
| 113 |
+
mask: np.ndarray = None,
|
| 114 |
+
triangulate: bool = False,
|
| 115 |
+
return_vertex_indices: bool = False,
|
| 116 |
+
) -> Tuple[np.ndarray, ...]:
|
| 117 |
+
"""
|
| 118 |
+
Create a mesh from image data using pixel coordinates as vertices and grid connections as faces.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
*image_data (np.ndarray): Image arrays with shape (height, width, [channels])
|
| 122 |
+
mask (np.ndarray, optional): Boolean mask with shape (height, width). Defaults to None.
|
| 123 |
+
triangulate (bool): Convert quad faces to triangular faces. Defaults to False.
|
| 124 |
+
return_vertex_indices (bool): Include vertex indices in output. Defaults to False.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
faces (np.ndarray): Face connectivity array. Shape (N, 4) for quads or (N, 3) for triangles
|
| 128 |
+
*vertex_data (np.ndarray): Vertex attributes corresponding to input image_data
|
| 129 |
+
vertex_indices (np.ndarray, optional): Original vertex indices if return_vertex_indices=True
|
| 130 |
+
"""
|
| 131 |
+
# Validate inputs
|
| 132 |
+
assert (len(image_data) > 0) or (mask is not None), "Need at least one image or mask"
|
| 133 |
+
|
| 134 |
+
if mask is None:
|
| 135 |
+
height, width = image_data[0].shape[:2]
|
| 136 |
+
else:
|
| 137 |
+
height, width = mask.shape
|
| 138 |
+
|
| 139 |
+
# Check all images have same dimensions
|
| 140 |
+
for img in image_data:
|
| 141 |
+
assert img.shape[:2] == (height, width), "All images must have same height and width"
|
| 142 |
+
|
| 143 |
+
# Create quad faces connecting neighboring pixels
|
| 144 |
+
base_quad = np.stack([
|
| 145 |
+
np.arange(0, width - 1, dtype=np.int32), # bottom-left
|
| 146 |
+
np.arange(width, 2 * width - 1, dtype=np.int32), # top-left
|
| 147 |
+
np.arange(1 + width, 2 * width, dtype=np.int32), # top-right
|
| 148 |
+
np.arange(1, width, dtype=np.int32), # bottom-right
|
| 149 |
+
], axis=1)
|
| 150 |
+
|
| 151 |
+
# Replicate quad pattern for all rows
|
| 152 |
+
row_offsets = np.arange(0, (height - 1) * width, width, dtype=np.int32)
|
| 153 |
+
faces = (row_offsets[:, None, None] + base_quad[None, :, :]).reshape((-1, 4))
|
| 154 |
+
|
| 155 |
+
if mask is None:
|
| 156 |
+
# No masking - use all faces and vertices
|
| 157 |
+
if triangulate:
|
| 158 |
+
faces = _convert_quads_to_triangles(faces)
|
| 159 |
+
|
| 160 |
+
output = [faces]
|
| 161 |
+
for img in image_data:
|
| 162 |
+
output.append(img.reshape(-1, *img.shape[2:]))
|
| 163 |
+
|
| 164 |
+
if return_vertex_indices:
|
| 165 |
+
output.append(np.arange(height * width, dtype=np.int32))
|
| 166 |
+
|
| 167 |
+
return tuple(output)
|
| 168 |
+
else:
|
| 169 |
+
# Apply mask - only keep faces where all 4 corners are valid
|
| 170 |
+
valid_quads = (
|
| 171 |
+
mask[:-1, :-1] & mask[1:, :-1] &
|
| 172 |
+
mask[1:, 1:] & mask[:-1, 1:]
|
| 173 |
+
).ravel()
|
| 174 |
+
faces = faces[valid_quads]
|
| 175 |
+
|
| 176 |
+
if triangulate:
|
| 177 |
+
faces = _convert_quads_to_triangles(faces)
|
| 178 |
+
|
| 179 |
+
# Remove unused vertices and remap face indices
|
| 180 |
+
num_face_vertices = faces.shape[-1]
|
| 181 |
+
unique_vertices, remapped_indices = np.unique(faces, return_inverse=True)
|
| 182 |
+
faces = remapped_indices.astype(np.int32).reshape(-1, num_face_vertices)
|
| 183 |
+
|
| 184 |
+
output = [faces]
|
| 185 |
+
for img in image_data:
|
| 186 |
+
flattened_img = img.reshape(-1, *img.shape[2:])
|
| 187 |
+
output.append(flattened_img[unique_vertices])
|
| 188 |
+
|
| 189 |
+
if return_vertex_indices:
|
| 190 |
+
output.append(unique_vertices)
|
| 191 |
+
|
| 192 |
+
return tuple(output)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _convert_quads_to_triangles(quad_faces: np.ndarray) -> np.ndarray:
|
| 196 |
+
"""Convert quadrilateral faces to triangular faces."""
|
| 197 |
+
if quad_faces.shape[-1] == 3:
|
| 198 |
+
return quad_faces # Already triangular
|
| 199 |
+
|
| 200 |
+
num_vertices_per_face = quad_faces.shape[-1]
|
| 201 |
+
triangle_indices = np.stack([
|
| 202 |
+
np.zeros(num_vertices_per_face - 2, dtype=int), # First vertex
|
| 203 |
+
np.arange(1, num_vertices_per_face - 1, dtype=int), # Sequential vertices
|
| 204 |
+
np.arange(2, num_vertices_per_face, dtype=int), # Next sequential vertices
|
| 205 |
+
], axis=1)
|
| 206 |
+
|
| 207 |
+
return quad_faces[:, triangle_indices].reshape((-1, 3))
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def convert_predictions_to_glb_scene(
|
| 211 |
+
predictions,
|
| 212 |
+
filter_by_frames="all",
|
| 213 |
+
show_camera=True,
|
| 214 |
+
mask_sky_bg=False,
|
| 215 |
+
mask_ambiguous=False,
|
| 216 |
+
as_mesh=True,
|
| 217 |
+
) -> trimesh.Scene:
|
| 218 |
+
"""
|
| 219 |
+
Converts model predictions to a 3D scene represented as a GLB file.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
predictions (dict): Dictionary containing model predictions with keys:
|
| 223 |
+
- world_points: 3D point coordinates (S, H, W, 3)
|
| 224 |
+
- images: Input images (S, H, W, 3)
|
| 225 |
+
- camera_poses: Camera extrinsic matrices (S, 3, 4)
|
| 226 |
+
filter_by_frames (str): Frame filter specification (default: "all")
|
| 227 |
+
show_camera (bool): Include camera visualization (default: True)
|
| 228 |
+
mask_sky_bg (bool): Mask out sky background pixels (default: False)
|
| 229 |
+
mask_ambiguous (bool): Apply final mask to filter ambiguous predictions (default: False)
|
| 230 |
+
as_mesh (bool): Represent the data as a mesh instead of point cloud (default: False)
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
trimesh.Scene: Processed 3D scene containing point cloud/mesh and cameras
|
| 234 |
+
|
| 235 |
+
Raises:
|
| 236 |
+
ValueError: If input predictions structure is invalid
|
| 237 |
+
"""
|
| 238 |
+
if not isinstance(predictions, dict):
|
| 239 |
+
raise ValueError("predictions must be a dictionary")
|
| 240 |
+
|
| 241 |
+
print("Building GLB scene")
|
| 242 |
+
|
| 243 |
+
# Parse frame selection from filter string
|
| 244 |
+
target_frame_index = None
|
| 245 |
+
if filter_by_frames not in ["all", "All"]:
|
| 246 |
+
try:
|
| 247 |
+
# Extract numeric index before colon separator
|
| 248 |
+
target_frame_index = int(filter_by_frames.split(":")[0])
|
| 249 |
+
except (ValueError, IndexError):
|
| 250 |
+
pass
|
| 251 |
+
|
| 252 |
+
# Validate required data in predictions
|
| 253 |
+
print("Using Pointmap Branch")
|
| 254 |
+
if "world_points" not in predictions:
|
| 255 |
+
raise ValueError(
|
| 256 |
+
"world_points not found in predictions. Pointmap Branch requires 'world_points' key. "
|
| 257 |
+
"Depthmap and Camera branches have been removed."
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
# Extract prediction data
|
| 261 |
+
point_cloud_3d = predictions["world_points"]
|
| 262 |
+
input_images = predictions["images"]
|
| 263 |
+
extrinsic_matrices = predictions["camera_poses"]
|
| 264 |
+
ambiguity_mask = predictions["final_mask"]
|
| 265 |
+
sky_region_mask = predictions["sky_mask"]
|
| 266 |
+
|
| 267 |
+
# Filter to single frame if specified
|
| 268 |
+
if target_frame_index is not None:
|
| 269 |
+
point_cloud_3d = point_cloud_3d[target_frame_index][None]
|
| 270 |
+
input_images = input_images[target_frame_index][None]
|
| 271 |
+
extrinsic_matrices = extrinsic_matrices[target_frame_index][None]
|
| 272 |
+
ambiguity_mask = ambiguity_mask[target_frame_index][None]
|
| 273 |
+
sky_region_mask = sky_region_mask[target_frame_index][None]
|
| 274 |
+
|
| 275 |
+
# Flatten 3D points to vertex array
|
| 276 |
+
flattened_vertices = point_cloud_3d.reshape(-1, 3)
|
| 277 |
+
|
| 278 |
+
# Convert images to RGB color array
|
| 279 |
+
if input_images.ndim == 4 and input_images.shape[1] == 3: # NCHW format
|
| 280 |
+
rgb_colors = np.transpose(input_images, (0, 2, 3, 1))
|
| 281 |
+
else: # Already in NHWC format
|
| 282 |
+
rgb_colors = input_images
|
| 283 |
+
rgb_colors = (rgb_colors.reshape(-1, 3) * 255).astype(np.uint8)
|
| 284 |
+
|
| 285 |
+
# Build composite filtering mask
|
| 286 |
+
valid_points_mask = np.ones(len(flattened_vertices), dtype=bool)
|
| 287 |
+
|
| 288 |
+
# Apply ambiguity filtering if requested
|
| 289 |
+
if mask_ambiguous:
|
| 290 |
+
flat_ambiguity_mask = ambiguity_mask.reshape(-1)
|
| 291 |
+
valid_points_mask = valid_points_mask & flat_ambiguity_mask
|
| 292 |
+
|
| 293 |
+
# Apply sky region filtering if requested
|
| 294 |
+
if mask_sky_bg:
|
| 295 |
+
flat_sky_mask = sky_region_mask.reshape(-1)
|
| 296 |
+
valid_points_mask = valid_points_mask & flat_sky_mask
|
| 297 |
+
|
| 298 |
+
# Apply mask to filter vertices and colors
|
| 299 |
+
filtered_vertices = flattened_vertices[valid_points_mask].copy()
|
| 300 |
+
filtered_colors = rgb_colors[valid_points_mask].copy()
|
| 301 |
+
|
| 302 |
+
# Handle empty geometry case
|
| 303 |
+
if filtered_vertices is None or np.asarray(filtered_vertices).size == 0:
|
| 304 |
+
filtered_vertices = np.array([[1, 0, 0]])
|
| 305 |
+
filtered_colors = np.array([[255, 255, 255]])
|
| 306 |
+
scene_scale_factor = 1
|
| 307 |
+
else:
|
| 308 |
+
# Compute scene scale from percentile-based bounding box
|
| 309 |
+
percentile_lower = np.percentile(filtered_vertices, 5, axis=0)
|
| 310 |
+
percentile_upper = np.percentile(filtered_vertices, 95, axis=0)
|
| 311 |
+
scene_scale_factor = np.linalg.norm(percentile_upper - percentile_lower)
|
| 312 |
+
|
| 313 |
+
# Initialize color mapping for cameras
|
| 314 |
+
color_palette = matplotlib.colormaps.get_cmap("gist_rainbow")
|
| 315 |
+
|
| 316 |
+
# Create empty 3D scene container
|
| 317 |
+
output_scene = trimesh.Scene()
|
| 318 |
+
|
| 319 |
+
# Add geometry to scene based on representation type
|
| 320 |
+
if as_mesh:
|
| 321 |
+
# Mesh representation
|
| 322 |
+
if target_frame_index is not None:
|
| 323 |
+
# Single frame mesh generation
|
| 324 |
+
frame_height, frame_width = point_cloud_3d.shape[1:3]
|
| 325 |
+
|
| 326 |
+
# Prepare unfiltered data for mesh construction
|
| 327 |
+
structured_points = point_cloud_3d.reshape(frame_height, frame_width, 3)
|
| 328 |
+
|
| 329 |
+
# Convert image data to proper format
|
| 330 |
+
if input_images.ndim == 4 and input_images.shape[1] == 3: # NCHW format
|
| 331 |
+
structured_colors = np.transpose(input_images[0], (1, 2, 0))
|
| 332 |
+
else: # Already in HWC format
|
| 333 |
+
structured_colors = input_images[0]
|
| 334 |
+
structured_colors *= 255
|
| 335 |
+
|
| 336 |
+
# Get structured mask for mesh creation
|
| 337 |
+
structured_mask = predictions["final_mask"][target_frame_index].reshape(
|
| 338 |
+
frame_height, frame_width
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# Build filtering mask
|
| 342 |
+
mesh_filter_mask = structured_mask
|
| 343 |
+
|
| 344 |
+
# Check for normal data availability
|
| 345 |
+
mesh_normals = None
|
| 346 |
+
if "normal" in predictions and predictions["normal"] is not None:
|
| 347 |
+
# Extract normals for selected frame
|
| 348 |
+
frame_normal_data = (
|
| 349 |
+
predictions["normal"][target_frame_index]
|
| 350 |
+
if target_frame_index is not None
|
| 351 |
+
else predictions["normal"][0]
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# Generate mesh with normal information
|
| 355 |
+
mesh_faces, mesh_vertices, mesh_colors, mesh_normals = create_image_mesh(
|
| 356 |
+
structured_points * np.array([1, -1, 1], dtype=np.float32),
|
| 357 |
+
structured_colors / 255.0,
|
| 358 |
+
frame_normal_data * np.array([1, -1, 1], dtype=np.float32),
|
| 359 |
+
mask=mesh_filter_mask,
|
| 360 |
+
triangulate=True,
|
| 361 |
+
return_vertex_indices=False,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
# Apply coordinate system transformation to normals
|
| 365 |
+
mesh_normals = mesh_normals * np.array([1, -1, 1], dtype=np.float32)
|
| 366 |
+
else:
|
| 367 |
+
# Generate mesh without normal information
|
| 368 |
+
mesh_faces, mesh_vertices, mesh_colors = create_image_mesh(
|
| 369 |
+
structured_points * np.array([1, -1, 1], dtype=np.float32),
|
| 370 |
+
structured_colors / 255.0,
|
| 371 |
+
mask=mesh_filter_mask,
|
| 372 |
+
triangulate=True,
|
| 373 |
+
return_vertex_indices=False,
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
# Construct trimesh object with optional normals
|
| 377 |
+
geometry_mesh = trimesh.Trimesh(
|
| 378 |
+
vertices=mesh_vertices * np.array([1, -1, 1], dtype=np.float32),
|
| 379 |
+
faces=mesh_faces,
|
| 380 |
+
vertex_colors=(mesh_colors * 255).astype(np.uint8),
|
| 381 |
+
vertex_normals=(mesh_normals if mesh_normals is not None else None),
|
| 382 |
+
process=False,
|
| 383 |
+
)
|
| 384 |
+
output_scene.add_geometry(geometry_mesh)
|
| 385 |
+
else:
|
| 386 |
+
# Multi-frame mesh generation
|
| 387 |
+
print("Creating mesh for multi-frame data...")
|
| 388 |
+
|
| 389 |
+
for frame_idx in range(point_cloud_3d.shape[0]):
|
| 390 |
+
frame_height, frame_width = point_cloud_3d.shape[1:3]
|
| 391 |
+
|
| 392 |
+
# Extract per-frame data
|
| 393 |
+
frame_point_data = point_cloud_3d[frame_idx]
|
| 394 |
+
frame_ambiguity_mask = predictions["final_mask"][frame_idx]
|
| 395 |
+
frame_sky_mask = predictions["sky_mask"][frame_idx]
|
| 396 |
+
|
| 397 |
+
# Extract frame image data
|
| 398 |
+
if input_images.ndim == 4 and input_images.shape[1] == 3: # NCHW format
|
| 399 |
+
frame_image_data = np.transpose(input_images[frame_idx], (1, 2, 0))
|
| 400 |
+
else: # Already in HWC format
|
| 401 |
+
frame_image_data = input_images[frame_idx]
|
| 402 |
+
frame_image_data *= 255
|
| 403 |
+
|
| 404 |
+
# Build per-frame filtering mask
|
| 405 |
+
frame_filter_mask = np.ones((frame_height, frame_width), dtype=bool)
|
| 406 |
+
|
| 407 |
+
# Apply ambiguity filtering if enabled
|
| 408 |
+
if mask_ambiguous:
|
| 409 |
+
frame_filter_mask = frame_filter_mask & frame_ambiguity_mask
|
| 410 |
+
|
| 411 |
+
# Apply sky filtering if enabled
|
| 412 |
+
if mask_sky_bg:
|
| 413 |
+
frame_filter_mask = frame_filter_mask & frame_sky_mask
|
| 414 |
+
|
| 415 |
+
# Generate mesh for current frame
|
| 416 |
+
frame_faces, frame_vertices, frame_colors = create_image_mesh(
|
| 417 |
+
frame_point_data * np.array([1, -1, 1], dtype=np.float32),
|
| 418 |
+
frame_image_data / 255.0,
|
| 419 |
+
mask=frame_filter_mask,
|
| 420 |
+
triangulate=True,
|
| 421 |
+
return_vertex_indices=False,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
frame_vertices = frame_vertices * np.array([1, -1, 1], dtype=np.float32)
|
| 425 |
+
|
| 426 |
+
# Create trimesh object for current frame
|
| 427 |
+
frame_geometry = trimesh.Trimesh(
|
| 428 |
+
vertices=frame_vertices,
|
| 429 |
+
faces=frame_faces,
|
| 430 |
+
vertex_colors=(frame_colors * 255).astype(np.uint8),
|
| 431 |
+
process=False,
|
| 432 |
+
)
|
| 433 |
+
output_scene.add_geometry(frame_geometry)
|
| 434 |
+
else:
|
| 435 |
+
# Point cloud representation
|
| 436 |
+
point_cloud_geometry = trimesh.PointCloud(vertices=filtered_vertices, colors=filtered_colors)
|
| 437 |
+
output_scene.add_geometry(point_cloud_geometry)
|
| 438 |
+
|
| 439 |
+
# Add camera visualizations if requested
|
| 440 |
+
num_camera_views = len(extrinsic_matrices)
|
| 441 |
+
|
| 442 |
+
if show_camera:
|
| 443 |
+
# Iterate through all camera views
|
| 444 |
+
for camera_idx in range(num_camera_views):
|
| 445 |
+
camera_extrinsic = extrinsic_matrices[camera_idx]
|
| 446 |
+
camera_color_rgba = color_palette(camera_idx / num_camera_views)
|
| 447 |
+
camera_color_rgb = tuple(int(255 * x) for x in camera_color_rgba[:3])
|
| 448 |
+
|
| 449 |
+
integrate_camera_into_scene(
|
| 450 |
+
output_scene, camera_extrinsic, camera_color_rgb, scene_scale_factor
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# Define coordinate system transformation matrices
|
| 454 |
+
opengl_transform = np.eye(4)
|
| 455 |
+
opengl_transform[1, 1] = -1 # Flip Y axis
|
| 456 |
+
opengl_transform[2, 2] = -1 # Flip Z axis
|
| 457 |
+
|
| 458 |
+
# Define alignment rotation (180 degrees around Y-axis)
|
| 459 |
+
alignment_rotation = np.eye(4)
|
| 460 |
+
alignment_rotation[:3, :3] = Rotation.from_euler("y", 0, degrees=True).as_matrix()
|
| 461 |
+
|
| 462 |
+
# Compute and apply final transformation
|
| 463 |
+
scene_transformation = (
|
| 464 |
+
np.linalg.inv(extrinsic_matrices[0])
|
| 465 |
+
@ opengl_transform
|
| 466 |
+
@ alignment_rotation
|
| 467 |
+
)
|
| 468 |
+
output_scene.apply_transform(scene_transformation)
|
| 469 |
+
|
| 470 |
+
print("GLB Scene built")
|
| 471 |
+
return output_scene
|
| 472 |
+
|
| 473 |
+
def integrate_camera_into_scene(
|
| 474 |
+
scene: trimesh.Scene,
|
| 475 |
+
camera_transform: np.ndarray,
|
| 476 |
+
camera_color: tuple,
|
| 477 |
+
scale_factor: float,
|
| 478 |
+
):
|
| 479 |
+
"""
|
| 480 |
+
Adds a camera visualization mesh to the 3D scene.
|
| 481 |
+
|
| 482 |
+
Args:
|
| 483 |
+
scene (trimesh.Scene): The 3D scene to add the camera visualization.
|
| 484 |
+
camera_transform (np.ndarray): 4x4 transformation matrix for camera positioning.
|
| 485 |
+
camera_color (tuple): RGB color tuple for the camera mesh.
|
| 486 |
+
scale_factor (float): Scaling factor for the camera size relative to scene.
|
| 487 |
+
"""
|
| 488 |
+
# Define camera dimensions based on scene scale
|
| 489 |
+
camera_base_width = scale_factor * 0.05
|
| 490 |
+
camera_cone_height = scale_factor * 0.1
|
| 491 |
+
|
| 492 |
+
# Create base cone geometry for camera representation
|
| 493 |
+
base_cone = trimesh.creation.cone(camera_base_width, camera_cone_height, sections=4)
|
| 494 |
+
|
| 495 |
+
# Setup rotation transformation (45 degrees around z-axis)
|
| 496 |
+
z_rotation_matrix = np.eye(4)
|
| 497 |
+
z_rotation_matrix[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
|
| 498 |
+
z_rotation_matrix[2, 3] = -camera_cone_height
|
| 499 |
+
|
| 500 |
+
# Setup OpenGL coordinate system conversion
|
| 501 |
+
opengl_coord_transform = np.eye(4)
|
| 502 |
+
opengl_coord_transform[1, 1] = -1 # Flip Y axis
|
| 503 |
+
opengl_coord_transform[2, 2] = -1 # Flip Z axis
|
| 504 |
+
|
| 505 |
+
# Combine all transformations
|
| 506 |
+
final_transform = camera_transform @ opengl_coord_transform @ z_rotation_matrix
|
| 507 |
+
|
| 508 |
+
# Create slight rotation for mesh variation
|
| 509 |
+
minor_rotation = np.eye(4)
|
| 510 |
+
minor_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()
|
| 511 |
+
|
| 512 |
+
# Generate multiple vertex sets for complex camera geometry
|
| 513 |
+
original_vertices = base_cone.vertices
|
| 514 |
+
scaled_vertices = 0.95 * original_vertices
|
| 515 |
+
rotated_vertices = apply_transformation_to_points(minor_rotation, original_vertices)
|
| 516 |
+
|
| 517 |
+
# Combine all vertex sets
|
| 518 |
+
all_vertices = np.concatenate([
|
| 519 |
+
original_vertices,
|
| 520 |
+
scaled_vertices,
|
| 521 |
+
rotated_vertices
|
| 522 |
+
])
|
| 523 |
+
|
| 524 |
+
# Transform vertices to final position
|
| 525 |
+
transformed_vertices = apply_transformation_to_points(final_transform, all_vertices)
|
| 526 |
+
|
| 527 |
+
# Generate faces for the complete camera mesh
|
| 528 |
+
camera_faces = generate_camera_mesh_faces(base_cone)
|
| 529 |
+
|
| 530 |
+
# Create and configure the camera mesh
|
| 531 |
+
camera_mesh = trimesh.Trimesh(
|
| 532 |
+
vertices=transformed_vertices,
|
| 533 |
+
faces=camera_faces
|
| 534 |
+
)
|
| 535 |
+
camera_mesh.visual.face_colors[:, :3] = camera_color
|
| 536 |
+
|
| 537 |
+
# Add the camera mesh to the scene
|
| 538 |
+
scene.add_geometry(camera_mesh)
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def apply_transformation_to_points(
|
| 542 |
+
transform_matrix: np.ndarray, point_array: np.ndarray, output_dim: int = None
|
| 543 |
+
) -> np.ndarray:
|
| 544 |
+
"""
|
| 545 |
+
Applies a 4x4 transformation matrix to a collection of 3D points.
|
| 546 |
+
|
| 547 |
+
Args:
|
| 548 |
+
transform_matrix (np.ndarray): 4x4 transformation matrix to apply.
|
| 549 |
+
point_array (np.ndarray): Array of points to transform.
|
| 550 |
+
output_dim (int, optional): Target dimension for output points.
|
| 551 |
+
|
| 552 |
+
Returns:
|
| 553 |
+
np.ndarray: Array of transformed points.
|
| 554 |
+
"""
|
| 555 |
+
point_array = np.asarray(point_array)
|
| 556 |
+
original_shape = point_array.shape[:-1]
|
| 557 |
+
target_dim = output_dim or point_array.shape[-1]
|
| 558 |
+
|
| 559 |
+
# Transpose transformation matrix for matrix multiplication
|
| 560 |
+
transposed_transform = transform_matrix.swapaxes(-1, -2)
|
| 561 |
+
|
| 562 |
+
# Apply rotation/scaling and translation components
|
| 563 |
+
transformed_points = (
|
| 564 |
+
point_array @ transposed_transform[..., :-1, :] +
|
| 565 |
+
transposed_transform[..., -1:, :]
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
# Extract desired dimensions and restore original shape
|
| 569 |
+
final_result = transformed_points[..., :target_dim].reshape(*original_shape, target_dim)
|
| 570 |
+
return final_result
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
def generate_camera_mesh_faces(base_cone_mesh: trimesh.Trimesh) -> np.ndarray:
|
| 574 |
+
"""
|
| 575 |
+
Generates face indices for a complex camera mesh composed of multiple cone layers.
|
| 576 |
+
|
| 577 |
+
Args:
|
| 578 |
+
base_cone_mesh (trimesh.Trimesh): Base cone geometry used as template.
|
| 579 |
+
|
| 580 |
+
Returns:
|
| 581 |
+
np.ndarray: Array of face indices defining the camera mesh topology.
|
| 582 |
+
"""
|
| 583 |
+
face_indices = []
|
| 584 |
+
vertex_count_per_cone = len(base_cone_mesh.vertices)
|
| 585 |
+
|
| 586 |
+
# Process each face of the base cone
|
| 587 |
+
for triangle_face in base_cone_mesh.faces:
|
| 588 |
+
# Skip faces that include the cone tip (vertex 0)
|
| 589 |
+
if 0 in triangle_face:
|
| 590 |
+
continue
|
| 591 |
+
|
| 592 |
+
# Get vertex indices for current triangle
|
| 593 |
+
vertex_a, vertex_b, vertex_c = triangle_face
|
| 594 |
+
|
| 595 |
+
# Calculate corresponding vertices in second and third cone layers
|
| 596 |
+
vertex_a_layer2, vertex_b_layer2, vertex_c_layer2 = triangle_face + vertex_count_per_cone
|
| 597 |
+
vertex_a_layer3, vertex_b_layer3, vertex_c_layer3 = triangle_face + 2 * vertex_count_per_cone
|
| 598 |
+
|
| 599 |
+
# Create connecting faces between cone layers
|
| 600 |
+
connecting_faces = [
|
| 601 |
+
(vertex_a, vertex_b, vertex_b_layer2),
|
| 602 |
+
(vertex_a, vertex_a_layer2, vertex_c),
|
| 603 |
+
(vertex_c_layer2, vertex_b, vertex_c),
|
| 604 |
+
(vertex_a, vertex_b, vertex_b_layer3),
|
| 605 |
+
(vertex_a, vertex_a_layer3, vertex_c),
|
| 606 |
+
(vertex_c_layer3, vertex_b, vertex_c),
|
| 607 |
+
]
|
| 608 |
+
|
| 609 |
+
face_indices.extend(connecting_faces)
|
| 610 |
+
|
| 611 |
+
# Add reverse-winding faces for proper mesh closure
|
| 612 |
+
reversed_faces = [(vertex_c, vertex_b, vertex_a) for vertex_a, vertex_b, vertex_c in face_indices]
|
| 613 |
+
face_indices.extend(reversed_faces)
|
| 614 |
+
|
| 615 |
+
return np.array(face_indices)
|
| 616 |
+
|
| 617 |
+
|
hyworldmirror/utils/warnings.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Wrapper utilities for warnings.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import warnings
|
| 6 |
+
from functools import wraps
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class no_warnings:
|
| 10 |
+
def __init__(self, action: str = "ignore", **kwargs):
|
| 11 |
+
self.action = action
|
| 12 |
+
self.filter_kwargs = kwargs
|
| 13 |
+
|
| 14 |
+
def __call__(self, fn):
|
| 15 |
+
@wraps(fn)
|
| 16 |
+
def wrapper(*args, **kwargs):
|
| 17 |
+
with warnings.catch_warnings():
|
| 18 |
+
warnings.simplefilter(self.action, **self.filter_kwargs)
|
| 19 |
+
return fn(*args, **kwargs)
|
| 20 |
+
|
| 21 |
+
return wrapper
|
| 22 |
+
|
| 23 |
+
def __enter__(self):
|
| 24 |
+
self.warnings_manager = warnings.catch_warnings()
|
| 25 |
+
self.warnings_manager.__enter__()
|
| 26 |
+
warnings.simplefilter(self.action, **self.filter_kwargs)
|
| 27 |
+
|
| 28 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 29 |
+
self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
|
pipeline.py
ADDED
|
@@ -0,0 +1,847 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HunyuanWorld-Mirror Inference Pipeline
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
# Python API — Single GPU
|
| 6 |
+
from hyworld2.worldrecon.pipeline import WorldMirrorPipeline
|
| 7 |
+
pipeline = WorldMirrorPipeline.from_pretrained('tencent/HY-World-2.0')
|
| 8 |
+
result = pipeline('path/to/images')
|
| 9 |
+
|
| 10 |
+
# Python API — Multi-GPU (in a torchrun script)
|
| 11 |
+
pipeline = WorldMirrorPipeline.from_pretrained(
|
| 12 |
+
'tencent/HY-World-2.0', use_fsdp=True, enable_bf16=True)
|
| 13 |
+
result = pipeline('path/to/images')
|
| 14 |
+
|
| 15 |
+
# CLI — Single GPU
|
| 16 |
+
python -m hyworld2.worldrecon.pipeline --input_path path/to/images
|
| 17 |
+
|
| 18 |
+
# CLI — Multi-GPU
|
| 19 |
+
torchrun --nproc_per_node=2 -m hyworld2.worldrecon.pipeline --input_path path/to/images --use_fsdp --enable_bf16
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import functools
|
| 24 |
+
import gc
|
| 25 |
+
import os
|
| 26 |
+
import time
|
| 27 |
+
from datetime import datetime, timedelta
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
import torch
|
| 32 |
+
import torch.distributed as dist
|
| 33 |
+
from omegaconf import OmegaConf
|
| 34 |
+
from safetensors.torch import load_file as load_safetensors
|
| 35 |
+
from torch.distributed.fsdp import (
|
| 36 |
+
FullyShardedDataParallel as FSDP,
|
| 37 |
+
ShardingStrategy,
|
| 38 |
+
CPUOffload,
|
| 39 |
+
)
|
| 40 |
+
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
| 41 |
+
|
| 42 |
+
from .hyworldmirror.models.models.worldmirror import WorldMirror
|
| 43 |
+
from .hyworldmirror.models.layers.block import Block, DistBlock
|
| 44 |
+
from .hyworldmirror.models.heads.dense_head import DPTHead
|
| 45 |
+
from .hyworldmirror.models.heads.camera_head import CameraHead
|
| 46 |
+
from .hyworldmirror.utils.inference_utils import (
|
| 47 |
+
prepare_images_to_tensor,
|
| 48 |
+
prepare_input,
|
| 49 |
+
compute_adaptive_target_size,
|
| 50 |
+
compute_preprocessing_transform,
|
| 51 |
+
load_prior_camera,
|
| 52 |
+
load_prior_depth,
|
| 53 |
+
compute_sky_mask,
|
| 54 |
+
compute_filter_mask,
|
| 55 |
+
save_results,
|
| 56 |
+
print_and_save_timings,
|
| 57 |
+
)
|
| 58 |
+
from .hyworldmirror.utils.render_utils import render_interpolated_video
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ============================================================
|
| 62 |
+
# Model loading helpers (checkpoint, config, selective load)
|
| 63 |
+
# ============================================================
|
| 64 |
+
|
| 65 |
+
def _get_model_config_from_yaml(cfg) -> dict:
|
| 66 |
+
if hasattr(cfg, "wrapper") and hasattr(cfg.wrapper, "model"):
|
| 67 |
+
model_cfg = cfg.wrapper.model
|
| 68 |
+
elif hasattr(cfg, "model"):
|
| 69 |
+
model_cfg = cfg.model
|
| 70 |
+
else:
|
| 71 |
+
raise ValueError("No model config found (expect wrapper.model or model).")
|
| 72 |
+
out = OmegaConf.to_container(model_cfg, resolve=True)
|
| 73 |
+
out.pop("_target_", None)
|
| 74 |
+
return out
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _load_checkpoint_state_dict(ckpt_path: str) -> dict:
|
| 78 |
+
if ckpt_path.endswith(".safetensors"):
|
| 79 |
+
return load_safetensors(ckpt_path)
|
| 80 |
+
ckpt = torch.load(ckpt_path, map_location="cpu")
|
| 81 |
+
state = ckpt.get("state_dict", ckpt)
|
| 82 |
+
if "state_dict" in ckpt:
|
| 83 |
+
state = {k.replace("model.", ""): v for k, v in state.items()}
|
| 84 |
+
return state
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _load_state_dict_selective(model, ckpt_state, source_name="checkpoint"):
|
| 88 |
+
current = model.state_dict()
|
| 89 |
+
for key in current:
|
| 90 |
+
if key in ckpt_state and current[key].shape == ckpt_state[key].shape:
|
| 91 |
+
current[key] = ckpt_state[key]
|
| 92 |
+
model.load_state_dict(current, strict=True)
|
| 93 |
+
matched = sum(1 for k in current if k in ckpt_state and current[k].shape == ckpt_state[k].shape)
|
| 94 |
+
print(f" Loaded {matched}/{len(current)} keys from {source_name}")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _has_model_files(path: str) -> bool:
|
| 98 |
+
"""Check whether a directory contains the expected model artifacts."""
|
| 99 |
+
has_weights = os.path.isfile(os.path.join(path, "model.safetensors"))
|
| 100 |
+
has_config = (os.path.isfile(os.path.join(path, "config.yaml"))
|
| 101 |
+
or os.path.isfile(os.path.join(path, "config.json")))
|
| 102 |
+
return has_weights and has_config
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _resolve_model_dir(model_path: str, subfolder: str) -> str:
|
| 106 |
+
"""Resolve a local directory containing config + model.safetensors.
|
| 107 |
+
|
| 108 |
+
Resolution order:
|
| 109 |
+
1. {model_path}/{subfolder} — local repo root with subfolder
|
| 110 |
+
2. {model_path} — direct local path (backward compat)
|
| 111 |
+
3. HuggingFace download: snapshot_download(repo_id, allow_patterns=[subfolder/*])
|
| 112 |
+
"""
|
| 113 |
+
candidate = os.path.join(model_path, subfolder)
|
| 114 |
+
if os.path.isdir(candidate) and _has_model_files(candidate):
|
| 115 |
+
print(f"[Init] Found local model at {candidate}")
|
| 116 |
+
return candidate
|
| 117 |
+
|
| 118 |
+
if os.path.isdir(model_path) and _has_model_files(model_path):
|
| 119 |
+
print(f"[Init] Found local model at {model_path}")
|
| 120 |
+
return model_path
|
| 121 |
+
|
| 122 |
+
print(f"[Init] Downloading from HuggingFace: {model_path} (subfolder={subfolder})")
|
| 123 |
+
from huggingface_hub import snapshot_download
|
| 124 |
+
repo_root = snapshot_download(
|
| 125 |
+
repo_id=model_path,
|
| 126 |
+
allow_patterns=[f"{subfolder}/*"],
|
| 127 |
+
)
|
| 128 |
+
resolved = os.path.join(repo_root, subfolder)
|
| 129 |
+
if not _has_model_files(resolved):
|
| 130 |
+
raise FileNotFoundError(
|
| 131 |
+
f"Downloaded repo '{model_path}' but subfolder '{subfolder}' "
|
| 132 |
+
f"does not contain model.safetensors + config. "
|
| 133 |
+
f"Check that the repo and subfolder name are correct."
|
| 134 |
+
)
|
| 135 |
+
return resolved
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _load_model_config(model_dir: str) -> dict:
|
| 139 |
+
"""Load model constructor kwargs from config.yaml or config.json in model_dir."""
|
| 140 |
+
import json as _json
|
| 141 |
+
yaml_path = os.path.join(model_dir, "config.yaml")
|
| 142 |
+
json_path = os.path.join(model_dir, "config.json")
|
| 143 |
+
|
| 144 |
+
if os.path.isfile(yaml_path):
|
| 145 |
+
cfg = OmegaConf.load(yaml_path)
|
| 146 |
+
return _get_model_config_from_yaml(cfg)
|
| 147 |
+
elif os.path.isfile(json_path):
|
| 148 |
+
with open(json_path) as f:
|
| 149 |
+
return _json.load(f)
|
| 150 |
+
else:
|
| 151 |
+
raise FileNotFoundError(f"No config.yaml or config.json in {model_dir}")
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ============================================================
|
| 155 |
+
# FSDP / bf16 helpers
|
| 156 |
+
# ============================================================
|
| 157 |
+
|
| 158 |
+
def _collect_fp32_critical_modules(model):
|
| 159 |
+
from .hyworldmirror.models.layers.mlp import MlpFP32
|
| 160 |
+
critical = set()
|
| 161 |
+
for _, module in model.named_modules():
|
| 162 |
+
if isinstance(module, MlpFP32) and hasattr(module, 'fc2'):
|
| 163 |
+
if any(p.dtype == torch.float32 for p in module.fc2.parameters()):
|
| 164 |
+
critical.add(module.fc2)
|
| 165 |
+
if hasattr(module, 'scratch') and hasattr(module.scratch, 'output_conv2'):
|
| 166 |
+
oc2 = module.scratch.output_conv2
|
| 167 |
+
if any(p.dtype == torch.float32 for p in oc2.parameters()):
|
| 168 |
+
critical.add(oc2)
|
| 169 |
+
return critical
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _cast_noncritical_fp32_to_bf16(model, critical_modules):
|
| 173 |
+
critical_ids = {id(p) for mod in critical_modules for p in mod.parameters()}
|
| 174 |
+
cast = []
|
| 175 |
+
for name, param in model.named_parameters():
|
| 176 |
+
if param.dtype == torch.float32 and id(param) not in critical_ids:
|
| 177 |
+
param.data = param.data.to(torch.bfloat16)
|
| 178 |
+
cast.append(name)
|
| 179 |
+
for _, buf in model.named_buffers():
|
| 180 |
+
if buf.dtype == torch.float32:
|
| 181 |
+
buf.data = buf.data.to(torch.bfloat16)
|
| 182 |
+
|
| 183 |
+
def _hook(module, args):
|
| 184 |
+
if not args:
|
| 185 |
+
return args
|
| 186 |
+
dtype = next((p.dtype for p in module.parameters(recurse=False)), None)
|
| 187 |
+
if dtype is None:
|
| 188 |
+
return args
|
| 189 |
+
return tuple(a.to(dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() and a.dtype != dtype else a
|
| 190 |
+
for a in args)
|
| 191 |
+
|
| 192 |
+
for name, module in model.named_modules():
|
| 193 |
+
if not any(True for _ in module.children()):
|
| 194 |
+
own = list(module.named_parameters(recurse=False))
|
| 195 |
+
if own and all(p.dtype == torch.bfloat16 for _, p in own):
|
| 196 |
+
pfx = name + "." if name else ""
|
| 197 |
+
if any(c.startswith(pfx) for c in cast):
|
| 198 |
+
module.register_forward_pre_hook(_hook)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _wrap_model_fsdp(model, sp_group, device, use_cpu_offload=False, enable_bf16=False):
|
| 202 |
+
wrap_cls = {DistBlock, Block, DPTHead, CameraHead}
|
| 203 |
+
if enable_bf16:
|
| 204 |
+
fp32_critical = _collect_fp32_critical_modules(model)
|
| 205 |
+
def policy(module, recurse, nonwrapped_numel, **kw):
|
| 206 |
+
if recurse:
|
| 207 |
+
return True
|
| 208 |
+
return isinstance(module, tuple(wrap_cls)) or module in fp32_critical
|
| 209 |
+
auto_wrap_policy = policy
|
| 210 |
+
else:
|
| 211 |
+
auto_wrap_policy = functools.partial(
|
| 212 |
+
transformer_auto_wrap_policy, transformer_layer_cls=wrap_cls)
|
| 213 |
+
|
| 214 |
+
fsdp_model = FSDP(
|
| 215 |
+
model, process_group=sp_group,
|
| 216 |
+
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
| 217 |
+
auto_wrap_policy=auto_wrap_policy, mixed_precision=None,
|
| 218 |
+
cpu_offload=CPUOffload(offload_params=True) if use_cpu_offload else None,
|
| 219 |
+
device_id=device, use_orig_params=True, sync_module_states=True,
|
| 220 |
+
forward_prefetch=False,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
rank = dist.get_rank()
|
| 224 |
+
if rank == 0:
|
| 225 |
+
total = sum(p.numel() for p in fsdp_model.parameters())
|
| 226 |
+
local = sum(getattr(p, '_local_tensor', p).numel() for p in fsdp_model.parameters())
|
| 227 |
+
print(f"[FSDP] total={total/1e6:.1f}M, local≈{local/1e6:.1f}M")
|
| 228 |
+
return fsdp_model
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# ============================================================
|
| 232 |
+
# WorldMirrorPipeline
|
| 233 |
+
# ============================================================
|
| 234 |
+
|
| 235 |
+
class WorldMirrorPipeline:
|
| 236 |
+
"""HunyuanWorld-Mirror inference pipeline.
|
| 237 |
+
|
| 238 |
+
Supports single-GPU and multi-GPU (Sequence Parallel) inference with
|
| 239 |
+
a unified API. Multi-GPU mode is auto-detected from torch.distributed.
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
def __init__(self, model, device, sp_size=1, sp_group=None, rank=0):
|
| 243 |
+
self.model = model
|
| 244 |
+
self.device = device
|
| 245 |
+
self.sp_size = sp_size
|
| 246 |
+
self.sp_group = sp_group
|
| 247 |
+
self.rank = rank
|
| 248 |
+
|
| 249 |
+
@classmethod
|
| 250 |
+
def from_pretrained(
|
| 251 |
+
cls,
|
| 252 |
+
pretrained_model_name_or_path: str = "tencent/HY-World-2.0",
|
| 253 |
+
*,
|
| 254 |
+
subfolder: str = "HY-WorldMirror-2.0",
|
| 255 |
+
config_path: str = None,
|
| 256 |
+
ckpt_path: str = None,
|
| 257 |
+
use_fsdp: bool = False,
|
| 258 |
+
enable_bf16: bool = False,
|
| 259 |
+
fsdp_cpu_offload: bool = False,
|
| 260 |
+
disable_heads: list = None,
|
| 261 |
+
) -> "WorldMirrorPipeline":
|
| 262 |
+
"""Load model and create pipeline instance.
|
| 263 |
+
|
| 264 |
+
Automatically detects distributed mode (torchrun sets WORLD_SIZE).
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
pretrained_model_name_or_path: HuggingFace repo ID or local path.
|
| 268 |
+
The model files are expected under ``{path}/{subfolder}/``.
|
| 269 |
+
subfolder: Subfolder inside the repo that contains the WorldMirror
|
| 270 |
+
checkpoint (model.safetensors + config).
|
| 271 |
+
config_path: Training config YAML (used with ckpt_path).
|
| 272 |
+
ckpt_path: Checkpoint file (.ckpt / .safetensors).
|
| 273 |
+
use_fsdp: Shard parameters across GPUs via FSDP.
|
| 274 |
+
enable_bf16: Use bf16 precision (except critical layers).
|
| 275 |
+
fsdp_cpu_offload: Offload FSDP params to CPU.
|
| 276 |
+
disable_heads: List of heads to disable, e.g. ["camera", "depth"].
|
| 277 |
+
"""
|
| 278 |
+
is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
|
| 279 |
+
|
| 280 |
+
if is_distributed:
|
| 281 |
+
if not dist.is_initialized():
|
| 282 |
+
dist.init_process_group(backend="nccl")
|
| 283 |
+
rank = dist.get_rank()
|
| 284 |
+
world_size = dist.get_world_size()
|
| 285 |
+
local_rank = int(os.environ.get("LOCAL_RANK", rank))
|
| 286 |
+
torch.cuda.set_device(local_rank)
|
| 287 |
+
device = torch.device("cuda", local_rank)
|
| 288 |
+
sp_size = world_size
|
| 289 |
+
sp_group = dist.new_group(ranks=list(range(sp_size)))
|
| 290 |
+
if rank == 0:
|
| 291 |
+
print(f"[Pipeline] Multi-GPU: world_size={world_size}, sp_size={sp_size}")
|
| 292 |
+
else:
|
| 293 |
+
rank, sp_size = 0, 1
|
| 294 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 295 |
+
sp_group = None
|
| 296 |
+
if use_fsdp:
|
| 297 |
+
print("[Pipeline] Warning: use_fsdp is ignored in single-GPU mode (FSDP requires torchrun with multiple GPUs)")
|
| 298 |
+
use_fsdp = False
|
| 299 |
+
print("[Pipeline] Single-GPU mode")
|
| 300 |
+
|
| 301 |
+
# Load model
|
| 302 |
+
t0 = time.perf_counter()
|
| 303 |
+
if config_path and ckpt_path:
|
| 304 |
+
print(f"[Init] config={config_path}, ckpt={ckpt_path}, sp_size={sp_size}")
|
| 305 |
+
cfg = OmegaConf.load(config_path)
|
| 306 |
+
model_cfg = _get_model_config_from_yaml(cfg)
|
| 307 |
+
if sp_size > 1:
|
| 308 |
+
model_cfg["sp_size"] = sp_size
|
| 309 |
+
if enable_bf16:
|
| 310 |
+
model_cfg["enable_bf16"] = True
|
| 311 |
+
model = WorldMirror(**model_cfg).to(device)
|
| 312 |
+
state = _load_checkpoint_state_dict(ckpt_path)
|
| 313 |
+
_load_state_dict_selective(model, state, source_name=ckpt_path)
|
| 314 |
+
del state; gc.collect(); torch.cuda.empty_cache()
|
| 315 |
+
else:
|
| 316 |
+
model_dir = _resolve_model_dir(pretrained_model_name_or_path, subfolder)
|
| 317 |
+
model_cfg = _load_model_config(model_dir)
|
| 318 |
+
if sp_size > 1:
|
| 319 |
+
model_cfg["sp_size"] = sp_size
|
| 320 |
+
if enable_bf16:
|
| 321 |
+
model_cfg["enable_bf16"] = True
|
| 322 |
+
model = WorldMirror(**model_cfg).to(device)
|
| 323 |
+
state = load_safetensors(os.path.join(model_dir, "model.safetensors"))
|
| 324 |
+
_load_state_dict_selective(model, state, source_name=model_dir)
|
| 325 |
+
del state; gc.collect(); torch.cuda.empty_cache()
|
| 326 |
+
|
| 327 |
+
# bf16 casting — two strategies depending on FSDP:
|
| 328 |
+
#
|
| 329 |
+
# Multi-GPU + FSDP: cast everything to bf16 uniformly (including fc2).
|
| 330 |
+
# FSDP requires uniform dtype per flat-param unit.
|
| 331 |
+
#
|
| 332 |
+
# Single GPU (no FSDP): cast to bf16, then restore critical fp32
|
| 333 |
+
# modules (MlpFP32.fc2, output_conv2) so their .float() calls work.
|
| 334 |
+
# Register input-cast hooks on bf16 leaf modules for dtype boundaries.
|
| 335 |
+
if enable_bf16:
|
| 336 |
+
if use_fsdp and is_distributed:
|
| 337 |
+
model.to(torch.bfloat16)
|
| 338 |
+
crit = _collect_fp32_critical_modules(model)
|
| 339 |
+
_cast_noncritical_fp32_to_bf16(model, crit)
|
| 340 |
+
else:
|
| 341 |
+
crit = _collect_fp32_critical_modules(model)
|
| 342 |
+
model.to(torch.bfloat16)
|
| 343 |
+
for mod in crit:
|
| 344 |
+
mod.to(torch.float32)
|
| 345 |
+
|
| 346 |
+
def _input_cast_hook(module, args):
|
| 347 |
+
if not args:
|
| 348 |
+
return args
|
| 349 |
+
dtype = next((p.dtype for p in module.parameters(recurse=False)), None)
|
| 350 |
+
if dtype is None:
|
| 351 |
+
return args
|
| 352 |
+
return tuple(
|
| 353 |
+
a.to(dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() and a.dtype != dtype else a
|
| 354 |
+
for a in args
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
for _, module in model.named_modules():
|
| 358 |
+
if not any(True for _ in module.children()):
|
| 359 |
+
own = list(module.parameters(recurse=False))
|
| 360 |
+
if own and all(p.dtype == torch.bfloat16 for p in own):
|
| 361 |
+
module.register_forward_pre_hook(_input_cast_hook)
|
| 362 |
+
|
| 363 |
+
model.eval()
|
| 364 |
+
|
| 365 |
+
# Disable unused heads
|
| 366 |
+
if disable_heads:
|
| 367 |
+
_disable_heads(model, disable_heads)
|
| 368 |
+
|
| 369 |
+
# FSDP wrapping
|
| 370 |
+
if use_fsdp and is_distributed:
|
| 371 |
+
model = _wrap_model_fsdp(model, sp_group, device,
|
| 372 |
+
use_cpu_offload=fsdp_cpu_offload,
|
| 373 |
+
enable_bf16=enable_bf16)
|
| 374 |
+
if enable_bf16:
|
| 375 |
+
inner = model.module if hasattr(model, 'module') else model
|
| 376 |
+
inner.to = lambda *a, **kw: inner
|
| 377 |
+
|
| 378 |
+
if rank == 0:
|
| 379 |
+
print(f"[Init] Model ready in {time.perf_counter() - t0:.1f}s")
|
| 380 |
+
if torch.cuda.is_available():
|
| 381 |
+
alloc = torch.cuda.memory_allocated(device) / (1024**3)
|
| 382 |
+
print(f"[Memory] allocated={alloc:.2f}GB")
|
| 383 |
+
|
| 384 |
+
return cls(model, device, sp_size, sp_group, rank)
|
| 385 |
+
|
| 386 |
+
@torch.no_grad()
|
| 387 |
+
def __call__(
|
| 388 |
+
self,
|
| 389 |
+
input_path: str,
|
| 390 |
+
output_path: str = "inference_output",
|
| 391 |
+
*,
|
| 392 |
+
# Inference
|
| 393 |
+
target_size: int = 952,
|
| 394 |
+
fps: int = 1,
|
| 395 |
+
video_strategy: str = "new",
|
| 396 |
+
video_min_frames: int = 1,
|
| 397 |
+
video_max_frames: int = 32,
|
| 398 |
+
# Save
|
| 399 |
+
save_depth: bool = True,
|
| 400 |
+
save_normal: bool = True,
|
| 401 |
+
save_gs: bool = True,
|
| 402 |
+
save_camera: bool = True,
|
| 403 |
+
save_points: bool = True,
|
| 404 |
+
save_colmap: bool = False,
|
| 405 |
+
save_conf: bool = False,
|
| 406 |
+
# Mask
|
| 407 |
+
apply_sky_mask: bool = True,
|
| 408 |
+
apply_edge_mask: bool = True,
|
| 409 |
+
apply_confidence_mask: bool = False,
|
| 410 |
+
save_sky_mask: bool = False,
|
| 411 |
+
sky_mask_source: str = "auto",
|
| 412 |
+
model_sky_threshold: float = 0.45,
|
| 413 |
+
confidence_percentile: float = 10.0,
|
| 414 |
+
edge_normal_threshold: float = 1.0,
|
| 415 |
+
edge_depth_threshold: float = 0.03,
|
| 416 |
+
# Compression
|
| 417 |
+
compress_pts: bool = True,
|
| 418 |
+
compress_pts_max_points: int = 2_000_000,
|
| 419 |
+
compress_pts_voxel_size: float = 0.002,
|
| 420 |
+
max_resolution: int = 1920,
|
| 421 |
+
compress_gs_max_points: int = 5_000_000,
|
| 422 |
+
# Prior
|
| 423 |
+
prior_cam_path: str = None,
|
| 424 |
+
prior_depth_path: str = None,
|
| 425 |
+
# Rendered video
|
| 426 |
+
save_rendered: bool = False,
|
| 427 |
+
render_interp_per_pair: int = 15,
|
| 428 |
+
render_depth: bool = False,
|
| 429 |
+
# Misc
|
| 430 |
+
log_time: bool = True,
|
| 431 |
+
strict_output_path: str = None,
|
| 432 |
+
) -> str:
|
| 433 |
+
"""Run inference on images/video and save results.
|
| 434 |
+
|
| 435 |
+
Args:
|
| 436 |
+
input_path: Directory of images or a video file.
|
| 437 |
+
output_path: Root output directory.
|
| 438 |
+
**kwargs: Override default inference parameters.
|
| 439 |
+
|
| 440 |
+
Returns:
|
| 441 |
+
Path to the output directory (str), or None on skip.
|
| 442 |
+
"""
|
| 443 |
+
model = self.model
|
| 444 |
+
device = self.device
|
| 445 |
+
sp_size, sp_group, rank = self.sp_size, self.sp_group, self.rank
|
| 446 |
+
is_distributed = sp_size > 1
|
| 447 |
+
|
| 448 |
+
case_t0 = time.perf_counter()
|
| 449 |
+
timings = {}
|
| 450 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 451 |
+
|
| 452 |
+
# 1. Prepare input
|
| 453 |
+
t0 = time.perf_counter()
|
| 454 |
+
img_paths, subdir_name = prepare_input(
|
| 455 |
+
input_path, target_size=target_size, fps=fps,
|
| 456 |
+
video_strategy=video_strategy,
|
| 457 |
+
min_frames=video_min_frames, max_frames=video_max_frames,
|
| 458 |
+
)
|
| 459 |
+
if log_time:
|
| 460 |
+
timings["data_loading"] = time.perf_counter() - t0
|
| 461 |
+
|
| 462 |
+
if strict_output_path is not None:
|
| 463 |
+
outdir = Path(strict_output_path)
|
| 464 |
+
else:
|
| 465 |
+
outdir = Path(output_path) / subdir_name / timestamp
|
| 466 |
+
|
| 467 |
+
# 2. Adaptive resolution
|
| 468 |
+
effective = compute_adaptive_target_size(img_paths, target_size)
|
| 469 |
+
if rank == 0 and effective != target_size:
|
| 470 |
+
print(f"[Inference] Adaptive resolution: {effective} (max={target_size})")
|
| 471 |
+
|
| 472 |
+
# 3. Inference
|
| 473 |
+
if torch.cuda.is_available():
|
| 474 |
+
torch.cuda.reset_peak_memory_stats(device)
|
| 475 |
+
torch.cuda.synchronize(device)
|
| 476 |
+
|
| 477 |
+
t0_all = time.perf_counter()
|
| 478 |
+
try:
|
| 479 |
+
predictions, imgs, infer_time = self._run_inference(
|
| 480 |
+
img_paths, effective, prior_cam_path, prior_depth_path)
|
| 481 |
+
except ValueError as e:
|
| 482 |
+
if rank == 0:
|
| 483 |
+
print(f"[Pipeline] Skipping '{input_path}': {e}")
|
| 484 |
+
return None
|
| 485 |
+
|
| 486 |
+
if log_time:
|
| 487 |
+
timings["inference"] = infer_time
|
| 488 |
+
timings["inference_preprocess"] = time.perf_counter() - t0_all - infer_time
|
| 489 |
+
|
| 490 |
+
# GPU memory stats (multi-GPU)
|
| 491 |
+
if log_time and torch.cuda.is_available() and is_distributed:
|
| 492 |
+
peak = torch.cuda.max_memory_allocated(device) / (1024**3)
|
| 493 |
+
peak_t = torch.tensor([peak], dtype=torch.float64, device=device)
|
| 494 |
+
gathered = [torch.zeros(1, dtype=torch.float64, device=device) for _ in range(sp_size)]
|
| 495 |
+
dist.all_gather(gathered, peak_t, group=sp_group)
|
| 496 |
+
timings["gpu_mem_peak_per_rank_gb"] = [t.item() for t in gathered]
|
| 497 |
+
timings["gpu_mem_peak_avg_gb"] = sum(timings["gpu_mem_peak_per_rank_gb"]) / sp_size
|
| 498 |
+
|
| 499 |
+
# 4. Post-processing and saving (rank 0 only)
|
| 500 |
+
if rank == 0:
|
| 501 |
+
B, S, C, H, W = imgs.shape
|
| 502 |
+
t0 = time.perf_counter()
|
| 503 |
+
|
| 504 |
+
sky_mask = (compute_sky_mask(
|
| 505 |
+
img_paths, H, W, S, predictions=predictions,
|
| 506 |
+
source=sky_mask_source, model_threshold=model_sky_threshold,
|
| 507 |
+
processed_aspect_ratio=W / H,
|
| 508 |
+
) if apply_sky_mask else None)
|
| 509 |
+
|
| 510 |
+
filter_mask, gs_filter_mask = None, None
|
| 511 |
+
if apply_confidence_mask or apply_edge_mask or apply_sky_mask:
|
| 512 |
+
filter_mask, gs_filter_mask = compute_filter_mask(
|
| 513 |
+
predictions, imgs, img_paths, H, W, S,
|
| 514 |
+
apply_confidence_mask=apply_confidence_mask,
|
| 515 |
+
apply_edge_mask=apply_edge_mask,
|
| 516 |
+
apply_sky_mask=apply_sky_mask,
|
| 517 |
+
confidence_percentile=confidence_percentile,
|
| 518 |
+
edge_normal_threshold=edge_normal_threshold,
|
| 519 |
+
edge_depth_threshold=edge_depth_threshold,
|
| 520 |
+
sky_mask=sky_mask, use_gs_depth=save_gs,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
if log_time:
|
| 524 |
+
timings["compute_mask"] = time.perf_counter() - t0
|
| 525 |
+
|
| 526 |
+
t0 = time.perf_counter()
|
| 527 |
+
save_timings = save_results(
|
| 528 |
+
predictions, imgs, img_paths, outdir,
|
| 529 |
+
save_depth=save_depth, save_normal=save_normal,
|
| 530 |
+
save_gs=save_gs, save_camera=save_camera,
|
| 531 |
+
save_points=save_points, save_colmap=save_colmap,
|
| 532 |
+
save_sky_mask=save_sky_mask, save_conf=save_conf,
|
| 533 |
+
log_time=log_time, max_resolution=max_resolution,
|
| 534 |
+
filter_mask=filter_mask, gs_filter_mask=gs_filter_mask,
|
| 535 |
+
sky_mask=sky_mask,
|
| 536 |
+
compress_pts=compress_pts,
|
| 537 |
+
compress_pts_max_points=compress_pts_max_points,
|
| 538 |
+
compress_pts_voxel_size=compress_pts_voxel_size,
|
| 539 |
+
compress_gs_max_points=compress_gs_max_points,
|
| 540 |
+
)
|
| 541 |
+
if log_time:
|
| 542 |
+
timings.update(save_timings or {})
|
| 543 |
+
timings["save_total_wall"] = time.perf_counter() - t0
|
| 544 |
+
|
| 545 |
+
# Render interpolated video from Gaussian splats
|
| 546 |
+
if save_rendered and "splats" in predictions:
|
| 547 |
+
inner_model = model.module if hasattr(model, 'module') else model
|
| 548 |
+
if hasattr(inner_model, 'gs_renderer'):
|
| 549 |
+
t0_render = time.perf_counter()
|
| 550 |
+
try:
|
| 551 |
+
splats_f32 = {k: v.float() if isinstance(v, torch.Tensor) else v
|
| 552 |
+
for k, v in predictions["splats"].items()}
|
| 553 |
+
render_interpolated_video(
|
| 554 |
+
inner_model.gs_renderer,
|
| 555 |
+
splats_f32,
|
| 556 |
+
predictions["camera_poses"].float(),
|
| 557 |
+
predictions["camera_intrs"].float(),
|
| 558 |
+
(H, W),
|
| 559 |
+
outdir / "rendered",
|
| 560 |
+
interp_per_pair=render_interp_per_pair,
|
| 561 |
+
loop_reverse=(S <= 2),
|
| 562 |
+
render_depth=render_depth,
|
| 563 |
+
)
|
| 564 |
+
if log_time:
|
| 565 |
+
timings["render_video"] = time.perf_counter() - t0_render
|
| 566 |
+
except Exception as e:
|
| 567 |
+
print(f"[Pipeline] Warning: video rendering failed: {e}")
|
| 568 |
+
|
| 569 |
+
if not is_distributed:
|
| 570 |
+
del predictions
|
| 571 |
+
torch.cuda.empty_cache()
|
| 572 |
+
|
| 573 |
+
timings["case_total"] = time.perf_counter() - case_t0
|
| 574 |
+
if log_time:
|
| 575 |
+
print_and_save_timings(timings, outdir)
|
| 576 |
+
|
| 577 |
+
print(f"\n{'='*60}\n[Pipeline] Results saved to: {outdir}\n{'='*60}\n")
|
| 578 |
+
|
| 579 |
+
if is_distributed:
|
| 580 |
+
del predictions, imgs
|
| 581 |
+
gc.collect()
|
| 582 |
+
torch.cuda.empty_cache()
|
| 583 |
+
dist.barrier()
|
| 584 |
+
|
| 585 |
+
return str(outdir)
|
| 586 |
+
|
| 587 |
+
def _run_inference(self, img_paths, target_size, prior_cam_path, prior_depth_path):
|
| 588 |
+
"""Run model forward pass."""
|
| 589 |
+
device = self.device
|
| 590 |
+
imgs = prepare_images_to_tensor(
|
| 591 |
+
img_paths, target_size=target_size, resize_strategy="crop"
|
| 592 |
+
).to(device)
|
| 593 |
+
views = {"img": imgs}
|
| 594 |
+
B, S, C, H, W = imgs.shape
|
| 595 |
+
|
| 596 |
+
if self.sp_size > 1 and S < self.sp_size:
|
| 597 |
+
raise ValueError(
|
| 598 |
+
f"Number of input images ({S}) must be >= number of GPUs ({self.sp_size}) "
|
| 599 |
+
f"in multi-GPU mode. Please provide at least {self.sp_size} images, "
|
| 600 |
+
f"or use fewer GPUs."
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
if self.rank == 0:
|
| 604 |
+
print(f"[Inference] {S} images, shape={imgs.shape}, sp_size={self.sp_size}")
|
| 605 |
+
|
| 606 |
+
pp_xform = compute_preprocessing_transform(img_paths, target_size)
|
| 607 |
+
cond_flags = [0, 0, 0]
|
| 608 |
+
|
| 609 |
+
if prior_cam_path and os.path.isfile(prior_cam_path):
|
| 610 |
+
extr, intr = load_prior_camera(prior_cam_path, img_paths, preprocess_transform=pp_xform)
|
| 611 |
+
if extr is not None:
|
| 612 |
+
first = extr[0, 0]
|
| 613 |
+
extr = torch.linalg.inv(first.float()).to(first.dtype).unsqueeze(0).unsqueeze(0) @ extr
|
| 614 |
+
views["camera_poses"] = extr.to(device)
|
| 615 |
+
cond_flags[0] = 1
|
| 616 |
+
if intr is not None:
|
| 617 |
+
views["camera_intrs"] = intr.to(device)
|
| 618 |
+
cond_flags[2] = 1
|
| 619 |
+
|
| 620 |
+
if prior_depth_path and os.path.isdir(prior_depth_path):
|
| 621 |
+
depth = load_prior_depth(prior_depth_path, img_paths, H, W, preprocess_transform=pp_xform)
|
| 622 |
+
if depth is not None:
|
| 623 |
+
views["depthmap"] = depth.to(device)
|
| 624 |
+
cond_flags[1] = 1
|
| 625 |
+
|
| 626 |
+
use_amp = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
|
| 627 |
+
inner = self.model.module if hasattr(self.model, 'module') else self.model
|
| 628 |
+
model_bf16 = getattr(inner, 'enable_bf16', False)
|
| 629 |
+
|
| 630 |
+
t0 = time.perf_counter()
|
| 631 |
+
with torch.amp.autocast("cuda", enabled=(not model_bf16 and use_amp), dtype=torch.bfloat16):
|
| 632 |
+
fwd_kw = dict(views=views, cond_flags=cond_flags, is_inference=True)
|
| 633 |
+
if self.sp_size > 1:
|
| 634 |
+
fwd_kw["sp_size"] = self.sp_size
|
| 635 |
+
fwd_kw["sp_group"] = self.sp_group
|
| 636 |
+
predictions = self.model(**fwd_kw)
|
| 637 |
+
if device.type == "cuda":
|
| 638 |
+
torch.cuda.synchronize()
|
| 639 |
+
infer_time = time.perf_counter() - t0
|
| 640 |
+
|
| 641 |
+
if self.rank == 0:
|
| 642 |
+
print(f"[Inference] Done in {infer_time:.2f}s")
|
| 643 |
+
return predictions, imgs, infer_time
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
# ============================================================
|
| 647 |
+
# Head disabling helper
|
| 648 |
+
# ============================================================
|
| 649 |
+
|
| 650 |
+
def _disable_heads(model, head_names):
|
| 651 |
+
"""Disable and free specified heads. head_names: list of 'camera','depth','normal','points','gs'."""
|
| 652 |
+
mapping = {
|
| 653 |
+
"camera": ("enable_cam", ["cam_head"]),
|
| 654 |
+
"depth": ("enable_depth", ["depth_head"]),
|
| 655 |
+
"normal": ("enable_norm", ["norm_head"]),
|
| 656 |
+
"points": ("enable_pts", ["pts_head"]),
|
| 657 |
+
"gs": ("enable_gs", ["gs_head", "gs_renderer"]),
|
| 658 |
+
}
|
| 659 |
+
freed = 0
|
| 660 |
+
for name in head_names:
|
| 661 |
+
if name not in mapping:
|
| 662 |
+
continue
|
| 663 |
+
attr, modules = mapping[name]
|
| 664 |
+
setattr(model, attr, False)
|
| 665 |
+
for mod_name in modules:
|
| 666 |
+
if hasattr(model, mod_name):
|
| 667 |
+
mod = getattr(model, mod_name)
|
| 668 |
+
freed += sum(p.numel() for p in mod.parameters())
|
| 669 |
+
mod.cpu()
|
| 670 |
+
delattr(model, mod_name)
|
| 671 |
+
del mod
|
| 672 |
+
if freed:
|
| 673 |
+
gc.collect()
|
| 674 |
+
torch.cuda.empty_cache()
|
| 675 |
+
print(f"[Init] Disabled heads: {head_names}, freed ~{freed/1e6:.1f}M params")
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
# ============================================================
|
| 679 |
+
# CLI entry point
|
| 680 |
+
# ============================================================
|
| 681 |
+
|
| 682 |
+
def _broadcast_string(s, rank, src=0):
|
| 683 |
+
if rank == src:
|
| 684 |
+
data = s.encode("utf-8")
|
| 685 |
+
length = torch.tensor([len(data)], dtype=torch.long, device="cuda")
|
| 686 |
+
else:
|
| 687 |
+
length = torch.tensor([0], dtype=torch.long, device="cuda")
|
| 688 |
+
dist.broadcast(length, src=src)
|
| 689 |
+
n = length.item()
|
| 690 |
+
tensor = torch.tensor(list(data), dtype=torch.uint8, device="cuda") if rank == src else torch.empty(n, dtype=torch.uint8, device="cuda")
|
| 691 |
+
dist.broadcast(tensor, src=src)
|
| 692 |
+
return tensor.cpu().numpy().tobytes().decode("utf-8")
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
def main():
|
| 696 |
+
parser = argparse.ArgumentParser(description="HunyuanWorld-Mirror Pipeline")
|
| 697 |
+
parser.add_argument("--input_path", type=str, required=True)
|
| 698 |
+
parser.add_argument("--output_path", type=str, default="inference_output")
|
| 699 |
+
parser.add_argument("--strict_output_path", type=str, default=None,
|
| 700 |
+
help="If set, save results directly to this path (no subdir/timestamp)")
|
| 701 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, default="tencent/HY-World-2.0",
|
| 702 |
+
help="HuggingFace repo ID or local path")
|
| 703 |
+
parser.add_argument("--subfolder", type=str, default="HY-WorldMirror-2.0",
|
| 704 |
+
help="Subfolder inside the repo containing WorldMirror weights")
|
| 705 |
+
parser.add_argument("--config_path", type=str, default=None)
|
| 706 |
+
parser.add_argument("--ckpt_path", type=str, default=None)
|
| 707 |
+
parser.add_argument("--use_fsdp", action="store_true", default=False)
|
| 708 |
+
parser.add_argument("--enable_bf16", action="store_true", default=False)
|
| 709 |
+
parser.add_argument("--fsdp_cpu_offload", action="store_true", default=False)
|
| 710 |
+
parser.add_argument("--target_size", type=int, default=952)
|
| 711 |
+
parser.add_argument("--fps", type=int, default=1)
|
| 712 |
+
parser.add_argument("--video_strategy", type=str, default="new", choices=["old", "new"])
|
| 713 |
+
parser.add_argument("--video_min_frames", type=int, default=1)
|
| 714 |
+
parser.add_argument("--video_max_frames", type=int, default=32)
|
| 715 |
+
parser.add_argument("--no_save_depth", action="store_true")
|
| 716 |
+
parser.add_argument("--no_save_normal", action="store_true")
|
| 717 |
+
parser.add_argument("--no_save_gs", action="store_true")
|
| 718 |
+
parser.add_argument("--no_save_camera", action="store_true")
|
| 719 |
+
parser.add_argument("--no_save_points", action="store_true")
|
| 720 |
+
parser.add_argument("--save_colmap", action="store_true", default=False)
|
| 721 |
+
parser.add_argument("--save_conf", action="store_true", default=False)
|
| 722 |
+
parser.add_argument("--save_sky_mask", action="store_true", default=False)
|
| 723 |
+
parser.add_argument("--apply_sky_mask", action="store_true", default=True)
|
| 724 |
+
parser.add_argument("--no_sky_mask", dest="apply_sky_mask", action="store_false")
|
| 725 |
+
parser.add_argument("--apply_edge_mask", action="store_true", default=True)
|
| 726 |
+
parser.add_argument("--no_edge_mask", dest="apply_edge_mask", action="store_false")
|
| 727 |
+
parser.add_argument("--apply_confidence_mask", action="store_true", default=False)
|
| 728 |
+
parser.add_argument("--sky_mask_source", type=str, default="auto", choices=["auto", "model", "onnx"])
|
| 729 |
+
parser.add_argument("--model_sky_threshold", type=float, default=0.45)
|
| 730 |
+
parser.add_argument("--confidence_percentile", type=float, default=10.0)
|
| 731 |
+
parser.add_argument("--edge_normal_threshold", type=float, default=1.0)
|
| 732 |
+
parser.add_argument("--edge_depth_threshold", type=float, default=0.03)
|
| 733 |
+
parser.add_argument("--compress_pts", action="store_true", default=True)
|
| 734 |
+
parser.add_argument("--no_compress_pts", dest="compress_pts", action="store_false")
|
| 735 |
+
parser.add_argument("--compress_pts_max_points", type=int, default=2_000_000)
|
| 736 |
+
parser.add_argument("--compress_pts_voxel_size", type=float, default=0.002)
|
| 737 |
+
parser.add_argument("--max_resolution", type=int, default=1920)
|
| 738 |
+
parser.add_argument("--compress_gs_max_points", type=int, default=5_000_000)
|
| 739 |
+
parser.add_argument("--prior_cam_path", type=str, default=None)
|
| 740 |
+
parser.add_argument("--prior_depth_path", type=str, default=None)
|
| 741 |
+
parser.add_argument("--disable_heads", type=str, nargs="*", default=None,
|
| 742 |
+
help="Heads to disable: camera depth normal points gs")
|
| 743 |
+
parser.add_argument("--save_rendered", action="store_true", default=False,
|
| 744 |
+
help="Render interpolated video from Gaussian splats")
|
| 745 |
+
parser.add_argument("--render_interp_per_pair", type=int, default=15,
|
| 746 |
+
help="Interpolated frames per camera pair for video rendering")
|
| 747 |
+
parser.add_argument("--render_depth", action="store_true", default=False,
|
| 748 |
+
help="Also render depth video")
|
| 749 |
+
parser.add_argument("--log_time", action="store_true", default=True)
|
| 750 |
+
parser.add_argument("--no_log_time", dest="log_time", action="store_false")
|
| 751 |
+
parser.add_argument("--no_interactive", action="store_true")
|
| 752 |
+
args = parser.parse_args()
|
| 753 |
+
|
| 754 |
+
pipeline = WorldMirrorPipeline.from_pretrained(
|
| 755 |
+
pretrained_model_name_or_path=args.pretrained_model_name_or_path,
|
| 756 |
+
subfolder=args.subfolder,
|
| 757 |
+
config_path=args.config_path, ckpt_path=args.ckpt_path,
|
| 758 |
+
use_fsdp=args.use_fsdp, enable_bf16=args.enable_bf16,
|
| 759 |
+
fsdp_cpu_offload=args.fsdp_cpu_offload,
|
| 760 |
+
disable_heads=args.disable_heads,
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
call_kwargs = dict(
|
| 764 |
+
output_path=args.output_path,
|
| 765 |
+
target_size=args.target_size, fps=args.fps,
|
| 766 |
+
video_strategy=args.video_strategy,
|
| 767 |
+
video_min_frames=args.video_min_frames,
|
| 768 |
+
video_max_frames=args.video_max_frames,
|
| 769 |
+
save_depth=not args.no_save_depth,
|
| 770 |
+
save_normal=not args.no_save_normal,
|
| 771 |
+
save_gs=not args.no_save_gs,
|
| 772 |
+
save_camera=not args.no_save_camera,
|
| 773 |
+
save_points=not args.no_save_points,
|
| 774 |
+
save_colmap=args.save_colmap,
|
| 775 |
+
save_conf=args.save_conf,
|
| 776 |
+
save_sky_mask=args.save_sky_mask,
|
| 777 |
+
apply_sky_mask=args.apply_sky_mask,
|
| 778 |
+
apply_edge_mask=args.apply_edge_mask,
|
| 779 |
+
apply_confidence_mask=args.apply_confidence_mask,
|
| 780 |
+
sky_mask_source=args.sky_mask_source,
|
| 781 |
+
model_sky_threshold=args.model_sky_threshold,
|
| 782 |
+
confidence_percentile=args.confidence_percentile,
|
| 783 |
+
edge_normal_threshold=args.edge_normal_threshold,
|
| 784 |
+
edge_depth_threshold=args.edge_depth_threshold,
|
| 785 |
+
compress_pts=args.compress_pts,
|
| 786 |
+
compress_pts_max_points=args.compress_pts_max_points,
|
| 787 |
+
compress_pts_voxel_size=args.compress_pts_voxel_size,
|
| 788 |
+
max_resolution=args.max_resolution,
|
| 789 |
+
compress_gs_max_points=args.compress_gs_max_points,
|
| 790 |
+
prior_cam_path=args.prior_cam_path,
|
| 791 |
+
prior_depth_path=args.prior_depth_path,
|
| 792 |
+
save_rendered=args.save_rendered,
|
| 793 |
+
render_interp_per_pair=args.render_interp_per_pair,
|
| 794 |
+
render_depth=args.render_depth,
|
| 795 |
+
log_time=args.log_time,
|
| 796 |
+
strict_output_path=args.strict_output_path,
|
| 797 |
+
)
|
| 798 |
+
|
| 799 |
+
try:
|
| 800 |
+
pipeline(args.input_path, **call_kwargs)
|
| 801 |
+
|
| 802 |
+
if args.no_interactive:
|
| 803 |
+
return
|
| 804 |
+
|
| 805 |
+
rank = pipeline.rank
|
| 806 |
+
is_distributed = pipeline.sp_size > 1
|
| 807 |
+
|
| 808 |
+
if rank == 0:
|
| 809 |
+
print("\n[Interactive] Enter new input paths. Type 'quit' to stop.\n")
|
| 810 |
+
|
| 811 |
+
_INF_TIMEOUT = timedelta(days=365)
|
| 812 |
+
_DEF_TIMEOUT = timedelta(minutes=10)
|
| 813 |
+
|
| 814 |
+
while True:
|
| 815 |
+
if is_distributed:
|
| 816 |
+
dist.distributed_c10d._get_default_group()._get_backend(
|
| 817 |
+
torch.device("cuda")).options._timeout = _INF_TIMEOUT
|
| 818 |
+
|
| 819 |
+
new_input = ""
|
| 820 |
+
if rank == 0:
|
| 821 |
+
try:
|
| 822 |
+
new_input = input(">>> ").strip()
|
| 823 |
+
except (EOFError, KeyboardInterrupt):
|
| 824 |
+
new_input = "quit"
|
| 825 |
+
|
| 826 |
+
if is_distributed:
|
| 827 |
+
new_input = _broadcast_string(new_input, rank, src=0)
|
| 828 |
+
dist.distributed_c10d._get_default_group()._get_backend(
|
| 829 |
+
torch.device("cuda")).options._timeout = _DEF_TIMEOUT
|
| 830 |
+
|
| 831 |
+
if not new_input or new_input.lower() in ("quit", "exit", "q"):
|
| 832 |
+
break
|
| 833 |
+
|
| 834 |
+
if rank == 0 and not (Path(new_input).is_dir() or Path(new_input).is_file()):
|
| 835 |
+
print(f" Invalid path: {new_input}")
|
| 836 |
+
continue
|
| 837 |
+
|
| 838 |
+
pipeline.model.to(pipeline.device)
|
| 839 |
+
pipeline.model.eval()
|
| 840 |
+
pipeline(new_input, **call_kwargs)
|
| 841 |
+
finally:
|
| 842 |
+
if dist.is_initialized():
|
| 843 |
+
dist.destroy_process_group()
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
if __name__ == "__main__":
|
| 847 |
+
main()
|