| import math |
| from typing import Literal |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from einops import rearrange |
| from torch import Tensor, nn |
|
|
| SampleMods = Literal[ |
| "conv", |
| "pixelshuffledirect", |
| "pixelshuffle", |
| "nearest+conv", |
| "dysample", |
| "transpose+conv", |
| "lda", |
| "pa_up", |
| ] |
|
|
|
|
| def ICNR(tensor, initializer, upscale_factor=2, *args, **kwargs): |
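    """ICNR initialisation for a convolution feeding nn.PixelShuffle.

    A sub-kernel is initialised with `initializer` and repeated
    `upscale_factor**2` times along the output-channel dimension, so the
    shuffled output starts out as nearest-neighbour upsampling and avoids
    checkerboard artifacts.
    """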
| upscale_factor_squared = upscale_factor * upscale_factor |
| assert tensor.shape[0] % upscale_factor_squared == 0, ( |
| "The size of the first dimension: " |
| f"tensor.shape[0] = {tensor.shape[0]}" |
| " is not divisible by square of upscale_factor: " |
| f"upscale_factor = {upscale_factor}" |
| ) |
| sub_kernel = torch.empty( |
| tensor.shape[0] // upscale_factor_squared, *tensor.shape[1:] |
| ) |
| sub_kernel = initializer(sub_kernel, *args, **kwargs) |
| return sub_kernel.repeat_interleave(upscale_factor_squared, dim=0) |
|
|
|
|
| class DySample(nn.Module): |
| """Adapted from 'Learning to Upsample by Learning to Sample': |
| https://arxiv.org/abs/2308.15085 |
| https://github.com/tiny-smart/dysample |
| """ |
|
|
| def __init__( |
| self, |
| in_channels: int = 64, |
| out_ch: int = 3, |
| scale: int = 2, |
| groups: int = 4, |
| end_convolution: bool = True, |
| end_kernel=1, |
| ) -> None: |
| super().__init__() |
|
|
| if in_channels <= groups or in_channels % groups != 0: |
| msg = "Incorrect in_channels and groups values." |
| raise ValueError(msg) |
|
|
| out_channels = 2 * groups * scale**2 |
| self.scale = scale |
| self.groups = groups |
| self.end_convolution = end_convolution |
| if end_convolution: |
| self.end_conv = nn.Conv2d( |
| in_channels, out_ch, end_kernel, 1, end_kernel // 2 |
| ) |
| self.offset = nn.Conv2d(in_channels, out_channels, 1) |
| self.scope = nn.Conv2d(in_channels, out_channels, 1, bias=False) |
| if self.training: |
| nn.init.trunc_normal_(self.offset.weight, std=0.02) |
| nn.init.constant_(self.scope.weight, val=0) |
|
|
| self.register_buffer("init_pos", self._init_pos()) |
|
|
| def _init_pos(self) -> Tensor: |
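        # Static sub-pixel offsets of the scale x scale output positions relative
        # to each input pixel centre, repeated for every offset group.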
| h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale |
| return ( |
| torch.stack(torch.meshgrid([h, h], indexing="ij")) |
| .transpose(1, 2) |
| .repeat(1, self.groups, 1) |
| .reshape(1, -1, 1, 1) |
| ) |
|
|
| def forward(self, x: Tensor) -> Tensor: |
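        # Predict content-aware offsets, gated by the sigmoid "scope" and added
        # to the static sub-pixel positions, then build a normalised grid and
        # bilinearly resample each channel group at the offset locations.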
| offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos |
| B, _, H, W = offset.shape |
| offset = offset.view(B, 2, -1, H, W) |
| coords_h = torch.arange(H) + 0.5 |
| coords_w = torch.arange(W) + 0.5 |
|
|
| coords = ( |
| torch.stack(torch.meshgrid([coords_w, coords_h], indexing="ij")) |
| .transpose(1, 2) |
| .unsqueeze(1) |
| .unsqueeze(0) |
| .type(x.dtype) |
| .to(x.device, non_blocking=True) |
| ) |
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(
            1, 2, 1, 1, 1
        )
| coords = 2 * (coords + offset) / normalizer - 1 |
|
|
| coords = ( |
| F.pixel_shuffle(coords.reshape(B, -1, H, W), self.scale) |
| .view(B, 2, -1, self.scale * H, self.scale * W) |
| .permute(0, 2, 3, 4, 1) |
| .contiguous() |
| .flatten(0, 1) |
| ) |
| output = F.grid_sample( |
| x.reshape(B * self.groups, -1, H, W), |
| coords, |
| mode="bilinear", |
| align_corners=False, |
| padding_mode="border", |
| ).view(B, -1, self.scale * H, self.scale * W) |
|
|
| if self.end_convolution: |
| output = self.end_conv(output) |
|
|
| return output |
|
|
|
|
| class LayerNorm(nn.Module): |
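    """LayerNorm over the channel dimension of NCHW tensors.

    Uses F.layer_norm on a channels-last view when the input is already in
    channels_last memory format, otherwise normalises manually over dim 1.
    """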
| def __init__(self, dim: int = 64, eps: float = 1e-6) -> None: |
| super().__init__() |
| self.weight = nn.Parameter(torch.ones(dim)) |
| self.bias = nn.Parameter(torch.zeros(dim)) |
| self.eps = eps |
| self.dim = (dim,) |
|
|
| def forward(self, x): |
| if x.is_contiguous(memory_format=torch.channels_last): |
| return F.layer_norm( |
| x.permute(0, 2, 3, 1), self.dim, self.weight, self.bias, self.eps |
| ).permute(0, 3, 1, 2) |
| u = x.mean(1, keepdim=True) |
| s = (x - u).pow(2).mean(1, keepdim=True) |
| x = (x - u) / torch.sqrt(s + self.eps) |
| return self.weight[:, None, None] * x + self.bias[:, None, None] |
|
|
|
|
| class LDA_AQU(nn.Module): |
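    """Local deformable attention upsampler (LDA-AQU style).

    Queries are bilinearly upsampled to the target resolution, a small conv
    branch predicts per-group offsets for a k_u x k_u sampling neighbourhood,
    and keys/values gathered at the deformed locations are aggregated with
    multi-head attention.
    """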
| def __init__( |
| self, |
| in_channels=48, |
| reduction_factor=4, |
| nh=1, |
| scale_factor=2.0, |
| k_e=3, |
| k_u=3, |
| n_groups=2, |
| range_factor=11, |
| rpb=True, |
| ) -> None: |
| super().__init__() |
| self.k_u = k_u |
| self.num_head = nh |
| self.scale_factor = scale_factor |
| self.n_groups = n_groups |
| self.offset_range_factor = range_factor |
|
|
| self.attn_dim = in_channels // (reduction_factor * self.num_head) |
| self.scale = self.attn_dim**-0.5 |
| self.rpb = rpb |
| self.hidden_dim = in_channels // reduction_factor |
| self.proj_q = nn.Conv2d( |
| in_channels, self.hidden_dim, kernel_size=1, stride=1, padding=0, bias=False |
| ) |
|
|
| self.proj_k = nn.Conv2d( |
| in_channels, self.hidden_dim, kernel_size=1, stride=1, padding=0, bias=False |
| ) |
|
|
| self.group_channel = in_channels // (reduction_factor * self.n_groups) |
| |
| self.conv_offset = nn.Sequential( |
| nn.Conv2d( |
| self.group_channel, |
| self.group_channel, |
| 3, |
| 1, |
| 1, |
| groups=self.group_channel, |
| bias=False, |
| ), |
| LayerNorm(self.group_channel), |
| nn.SiLU(), |
| nn.Conv2d(self.group_channel, 2 * k_u**2, k_e, 1, k_e // 2), |
| ) |
| self.layer_norm = LayerNorm(in_channels) |
|
|
| self.pad = int((self.k_u - 1) / 2) |
| base = np.arange(-self.pad, self.pad + 1).astype(np.float32) |
| base_y = np.repeat(base, self.k_u) |
| base_x = np.tile(base, self.k_u) |
| base_offset = np.stack([base_y, base_x], axis=1).flatten() |
| base_offset = torch.tensor(base_offset).view(1, -1, 1, 1) |
| self.register_buffer("base_offset", base_offset, persistent=False) |
|
|
| if self.rpb: |
| self.relative_position_bias_table = nn.Parameter( |
| torch.zeros( |
| 1, self.num_head, 1, self.k_u**2, self.hidden_dim // self.num_head |
| ) |
| ) |
| nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) |
|
|
| def init_weights(self) -> None: |
| for m in self.modules(): |
| if isinstance(m, nn.Conv2d): |
                nn.init.xavier_uniform_(m.weight)
| elif isinstance(m, nn.LayerNorm): |
| nn.init.constant_(m.bias, 0) |
| nn.init.constant_(m.weight, 1.0) |
| nn.init.constant_(self.conv_offset[-1].weight, 0) |
| nn.init.constant_(self.conv_offset[-1].bias, 0) |
|
|
| def get_offset(self, offset, Hout, Wout): |
| B, _, _, _ = offset.shape |
| device = offset.device |
| row_indices = torch.arange(Hout, device=device) |
| col_indices = torch.arange(Wout, device=device) |
        row_indices, col_indices = torch.meshgrid(row_indices, col_indices, indexing="ij")
| index_tensor = torch.stack((row_indices, col_indices), dim=-1).view( |
| 1, Hout, Wout, 2 |
| ) |
| offset = rearrange( |
| offset, "b (kh kw d) h w -> b kh h kw w d", kh=self.k_u, kw=self.k_u |
| ) |
| offset = offset + index_tensor.view(1, 1, Hout, 1, Wout, 2) |
| offset = offset.contiguous().view(B, self.k_u * Hout, self.k_u * Wout, 2) |
|
|
| offset[..., 0] = 2 * offset[..., 0] / (Hout - 1) - 1 |
| offset[..., 1] = 2 * offset[..., 1] / (Wout - 1) - 1 |
| offset = offset.flip(-1) |
| return offset |
|
|
| def extract_feats(self, x, offset, ks=3): |
| out = nn.functional.grid_sample( |
| x, offset, mode="bilinear", padding_mode="zeros", align_corners=True |
| ) |
| out = rearrange(out, "b c (ksh h) (ksw w) -> b (ksh ksw) c h w", ksh=ks, ksw=ks) |
| return out |
|
|
| def forward(self, x): |
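        # LayerNorm the input and project queries/keys; upsample the queries to
        # the target size, predict per-group deformable offsets, gather keys and
        # values over the deformed k_u x k_u neighbourhood with grid_sample, and
        # aggregate them with multi-head attention.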
| B, C, H, W = x.shape |
| out_H, out_W = int(H * self.scale_factor), int(W * self.scale_factor) |
| v = x |
| x = self.layer_norm(x) |
| q = self.proj_q(x) |
| k = self.proj_k(x) |
|
|
| q = torch.nn.functional.interpolate( |
| q, (out_H, out_W), mode="bilinear", align_corners=True |
| ) |
| q_off = q.view(B * self.n_groups, -1, out_H, out_W) |
| pred_offset = self.conv_offset(q_off) |
| offset = pred_offset.tanh().mul(self.offset_range_factor) + self.base_offset.to( |
| x.dtype |
| ) |
|
|
| k = k.view(B * self.n_groups, self.hidden_dim // self.n_groups, H, W) |
| v = v.view(B * self.n_groups, C // self.n_groups, H, W) |
| offset = self.get_offset(offset, out_H, out_W) |
| k = self.extract_feats(k, offset=offset) |
| v = self.extract_feats(v, offset=offset) |
|
|
| q = rearrange(q, "b (nh c) h w -> b nh (h w) () c", nh=self.num_head) |
| k = rearrange(k, "(b g) n c h w -> b (h w) n (g c)", g=self.n_groups) |
| v = rearrange(v, "(b g) n c h w -> b (h w) n (g c)", g=self.n_groups) |
| k = rearrange(k, "b n1 n (nh c) -> b nh n1 n c", nh=self.num_head) |
| v = rearrange(v, "b n1 n (nh c) -> b nh n1 n c", nh=self.num_head) |
|
|
| if self.rpb: |
| k = k + self.relative_position_bias_table |
|
|
| q = q * self.scale |
| attn = q @ k.transpose(-1, -2) |
| attn = attn.softmax(dim=-1) |
| out = attn @ v |
|
|
| out = rearrange(out, "b nh (h w) t c -> b (nh c) (t h) w", h=out_H) |
| return out |
|
|
|
|
| class PA(nn.Module): |
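    """Pixel attention: a 1x1 convolution followed by a sigmoid gates the input per pixel."""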
| def __init__(self, dim) -> None: |
| super().__init__() |
| self.conv = nn.Sequential(nn.Conv2d(dim, dim, 1), nn.Sigmoid()) |
|
|
| def forward(self, x): |
| return x.mul(self.conv(x)) |
|
|
|
|
| class UniUpsampleV3(nn.Sequential): |
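    """Unified upsampling head.

    Builds one of the `SampleMods` pipelines (plain conv, pixel-shuffle
    variants, nearest/transposed conv, DySample, LDA_AQU or PA-gated
    upsampling) for the requested scale. The configuration is stored in the
    `MetaUpsample` buffer so it is preserved in checkpoints.
    """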
| def __init__( |
| self, |
| upsample: SampleMods = "pa_up", |
| scale: int = 2, |
| in_dim: int = 48, |
| out_dim: int = 3, |
| mid_dim: int = 48, |
| group: int = 4, |
| dysample_end_kernel=1, |
| ) -> None: |
| m = [] |
|
|
| if scale == 1 or upsample == "conv": |
| m.append(nn.Conv2d(in_dim, out_dim, 3, 1, 1)) |
| elif upsample == "pixelshuffledirect": |
| m.extend( |
| [nn.Conv2d(in_dim, out_dim * scale**2, 3, 1, 1), nn.PixelShuffle(scale)] |
| ) |
| elif upsample == "pixelshuffle": |
| m.extend([nn.Conv2d(in_dim, mid_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)]) |
| if (scale & (scale - 1)) == 0: |
| for _ in range(int(math.log2(scale))): |
| m.extend( |
| [nn.Conv2d(mid_dim, 4 * mid_dim, 3, 1, 1), nn.PixelShuffle(2)] |
| ) |
| elif scale == 3: |
| m.extend([nn.Conv2d(mid_dim, 9 * mid_dim, 3, 1, 1), nn.PixelShuffle(3)]) |
| else: |
| raise ValueError( |
| f"scale {scale} is not supported. Supported scales: 2^n and 3." |
| ) |
| m.append(nn.Conv2d(mid_dim, out_dim, 3, 1, 1)) |
| elif upsample == "nearest+conv": |
| if (scale & (scale - 1)) == 0: |
| for _ in range(int(math.log2(scale))): |
| m.extend( |
| ( |
| nn.Conv2d(in_dim, in_dim, 3, 1, 1), |
| nn.Upsample(scale_factor=2), |
| nn.LeakyReLU(negative_slope=0.2, inplace=True), |
| ) |
| ) |
| m.extend( |
| ( |
| nn.Conv2d(in_dim, in_dim, 3, 1, 1), |
| nn.LeakyReLU(negative_slope=0.2, inplace=True), |
| ) |
| ) |
| elif scale == 3: |
| m.extend( |
| ( |
| nn.Conv2d(in_dim, in_dim, 3, 1, 1), |
| nn.Upsample(scale_factor=scale), |
| nn.LeakyReLU(negative_slope=0.2, inplace=True), |
| nn.Conv2d(in_dim, in_dim, 3, 1, 1), |
| nn.LeakyReLU(negative_slope=0.2, inplace=True), |
| ) |
| ) |
| else: |
| raise ValueError( |
| f"scale {scale} is not supported. Supported scales: 2^n and 3." |
| ) |
| m.append(nn.Conv2d(in_dim, out_dim, 3, 1, 1)) |
| elif upsample == "dysample": |
| if mid_dim != in_dim: |
| m.extend( |
| [nn.Conv2d(in_dim, mid_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)] |
| ) |
| m.append( |
| DySample(mid_dim, out_dim, scale, group, end_kernel=dysample_end_kernel) |
| ) |
| |
| elif upsample == "transpose+conv": |
| if scale == 2: |
| m.append(nn.ConvTranspose2d(in_dim, out_dim, 4, 2, 1)) |
| elif scale == 3: |
| m.append(nn.ConvTranspose2d(in_dim, out_dim, 3, 3, 0)) |
| elif scale == 4: |
| m.extend( |
| [ |
| nn.ConvTranspose2d(in_dim, in_dim, 4, 2, 1), |
| nn.GELU(), |
| nn.ConvTranspose2d(in_dim, out_dim, 4, 2, 1), |
| ] |
| ) |
| else: |
| raise ValueError( |
| f"scale {scale} is not supported. Supported scales: 2, 3, 4" |
| ) |
| m.append(nn.Conv2d(out_dim, out_dim, 3, 1, 1)) |
| elif upsample == "lda": |
| if mid_dim != in_dim: |
| m.extend( |
| [nn.Conv2d(in_dim, mid_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)] |
| ) |
| m.append(LDA_AQU(mid_dim, scale_factor=scale)) |
| m.append(nn.Conv2d(mid_dim, out_dim, 3, 1, 1)) |
| elif upsample == "pa_up": |
| if (scale & (scale - 1)) == 0: |
| for _ in range(int(math.log2(scale))): |
| m.extend( |
| [ |
| nn.Upsample(scale_factor=2), |
| nn.Conv2d(in_dim, mid_dim, 3, 1, 1), |
| PA(mid_dim), |
| nn.LeakyReLU(negative_slope=0.2, inplace=True), |
| nn.Conv2d(mid_dim, mid_dim, 3, 1, 1), |
| nn.LeakyReLU(negative_slope=0.2, inplace=True), |
| ] |
| ) |
| in_dim = mid_dim |
| elif scale == 3: |
| m.extend( |
| [ |
| nn.Upsample(scale_factor=3), |
| nn.Conv2d(in_dim, mid_dim, 3, 1, 1), |
| PA(mid_dim), |
| nn.LeakyReLU(negative_slope=0.2, inplace=True), |
| nn.Conv2d(mid_dim, mid_dim, 3, 1, 1), |
| nn.LeakyReLU(negative_slope=0.2, inplace=True), |
| ] |
| ) |
| else: |
| raise ValueError( |
| f"scale {scale} is not supported. Supported scales: 2^n and 3." |
| ) |
| m.append(nn.Conv2d(mid_dim, out_dim, 3, 1, 1)) |
| else: |
            raise ValueError(
                f"Invalid upsampler {upsample!r}. Please choose one of {SampleMods.__args__}."
            )
| super().__init__(*m) |
|
|
| self.register_buffer( |
| "MetaUpsample", |
| torch.tensor( |
| [ |
| 3, |
| list(SampleMods.__args__).index(upsample), |
| scale, |
| in_dim, |
| out_dim, |
| mid_dim, |
| group, |
| ], |
| dtype=torch.uint8, |
| ), |
| ) |
|
|
|
|
| class RMSNorm(nn.Module): |
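    """RMS normalisation over the channel dimension with learnable scale and offset.

    `eps` and `1/sqrt(dim)` are stored as frozen parameters so they are kept
    in the state dict.
    """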
| def __init__(self, dim: int, eps: float = 1e-6) -> None: |
| super().__init__() |
| self.scale = nn.Parameter(torch.ones(dim)) |
| self.offset = nn.Parameter(torch.zeros(dim)) |
        self.eps = nn.Parameter(torch.ones(1) * eps, requires_grad=False)
        self.rms = nn.Parameter(torch.ones(1) * (dim**-0.5), requires_grad=False)
|
|
| def forward(self, x: Tensor) -> Tensor: |
| norm_x = torch.addcmul(self.eps, x.norm(2, dim=1, keepdim=True), self.rms) |
| return torch.addcmul( |
| self.offset[:, None, None], x.div(norm_x), self.scale[:, None, None] |
| ) |
|
|
|
|
| class CustomRFFT2(torch.autograd.Function): |
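    """Ortho-normalised rfft2 over the spatial dims, returned via view_as_real.

    The custom `symbolic` lowers the op to ONNX DFT nodes so models using it
    remain exportable to ONNX, where the native torch.fft ops are not directly
    supported.
    """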
| @staticmethod |
| def forward(ctx, x: torch.Tensor): |
| y = torch.fft.rfft2(x, dim=(2, 3), norm="ortho") |
| return torch.view_as_real(y) |
|
|
| @staticmethod |
| def symbolic(g, x: torch.Value): |
| shp = g.op("Shape", x) |
| iH = g.op("Constant", value_t=torch.tensor([2], dtype=torch.int64)) |
| iW = g.op("Constant", value_t=torch.tensor([3], dtype=torch.int64)) |
| nH = g.op("Gather", shp, iH, axis_i=0) |
| nW = g.op("Gather", shp, iW, axis_i=0) |
|
|
| axes_last = g.op("Constant", value_t=torch.tensor([4], dtype=torch.int64)) |
| x_u = g.op("Unsqueeze", x, axes_last) |
| zero = g.op("Sub", x_u, x_u) |
| x_c = g.op("Concat", x_u, zero, axis_i=4) |
|
|
| Hf = g.op("Cast", nH, to_i=torch.onnx.TensorProtoDataType.FLOAT) |
| Wf = g.op("Cast", nW, to_i=torch.onnx.TensorProtoDataType.FLOAT) |
|
|
| y = g.op("DFT", x_c, nW, axis_i=3, onesided_i=1) |
| y = g.op("Div", y, g.op("Sqrt", Wf)) |
|
|
| y = g.op("DFT", y, nH, axis_i=2, onesided_i=0) |
| y = g.op("Div", y, g.op("Sqrt", Hf)) |
|
|
| return y |
|
|
|
|
| class CustomIRFFT2(torch.autograd.Function): |
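    """Inverse of CustomRFFT2.

    Eager mode calls torch.fft.irfft2; the ONNX symbolic reconstructs the full
    spectrum from the one-sided half via conjugate symmetry before applying
    inverse DFT nodes.
    """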
| @staticmethod |
| def forward(ctx, x_ri: torch.Tensor): |
| x_c = torch.view_as_complex(x_ri) |
| return torch.fft.irfft2(x_c, dim=(2, 3), norm="ortho") |
|
|
| @staticmethod |
| def symbolic(g, x: torch.Value): |
| shp = g.op("Shape", x) |
| iH = g.op("Constant", value_t=torch.tensor([2], dtype=torch.int64)) |
| iWr = g.op("Constant", value_t=torch.tensor([3], dtype=torch.int64)) |
| nH = g.op("Gather", shp, iH, axis_i=0) |
| nWr = g.op("Gather", shp, iWr, axis_i=0) |
|
|
| one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) |
| two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64)) |
| nW = g.op("Mul", g.op("Sub", nWr, one), two) |
| Hf = g.op("Cast", nH, to_i=torch.onnx.TensorProtoDataType.FLOAT) |
| Wf = g.op("Cast", nW, to_i=torch.onnx.TensorProtoDataType.FLOAT) |
|
|
| yH = g.op("DFT", x, nH, axis_i=2, inverse_i=1, onesided_i=0) |
| yH = g.op("Mul", yH, g.op("Sqrt", Hf)) |
|
|
| start = g.op("Sub", nWr, two) |
| start = g.op( |
| "Squeeze", |
| start, |
| g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), |
| ) |
| limit = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) |
| step = g.op("Constant", value_t=torch.tensor(-1, dtype=torch.int64)) |
| idx_r = g.op("Range", start, limit, step) |
|
|
| mirW = g.op("Gather", yH, idx_r, axis_i=3) |
| maskW = g.op("Constant", value_t=torch.tensor([1.0, -1.0], dtype=torch.float32)) |
| maskW = g.op( |
| "Unsqueeze", |
| maskW, |
| g.op("Constant", value_t=torch.tensor([0, 1, 2, 3], dtype=torch.int64)), |
| ) |
| mirWc = g.op("Mul", mirW, maskW) |
| x_full = g.op("Concat", yH, mirWc, axis_i=3) |
|
|
| y = g.op("DFT", x_full, nW, axis_i=3, inverse_i=1, onesided_i=0) |
| y = g.op("Mul", y, g.op("Sqrt", Wf)) |
|
|
| s0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)) |
| s1 = g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) |
| axC = g.op("Constant", value_t=torch.tensor([4], dtype=torch.int64)) |
| y = g.op("Slice", y, s0, s1, axC) |
| y = g.op("Squeeze", y, axC) |
|
|
| return y |
|
|
|
|
| class CustomRfft2Wrap(nn.Module): |
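    """Uses the differentiable torch.fft.rfft2 while training and the ONNX-friendly CustomRFFT2 in eval mode."""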
| def __init__(self) -> None: |
| super().__init__() |
|
|
| def forward(self, x): |
| if self.training: |
| y = torch.fft.rfft2(x, dim=(2, 3), norm="ortho") |
| return torch.view_as_real(y) |
| else: |
            return CustomRFFT2.apply(x)
|
|
|
|
| class CustomIrfft2Wrap(nn.Module): |
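    """Like CustomRfft2Wrap, but for the inverse transform (torch.fft.irfft2 / CustomIRFFT2)."""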
| def __init__(self) -> None: |
| super().__init__() |
|
|
| def forward(self, x): |
| if self.training: |
| x_c = torch.view_as_complex(x) |
| return torch.fft.irfft2(x_c, dim=(2, 3), norm="ortho") |
| else: |
            return CustomIRFFT2.apply(x)
|
|
|
|
| class FourierUnit(nn.Module): |
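    """Spectral mixing block.

    rfft2 -> real/imag stacked as channels -> RMSNorm, depthwise 3x3 conv
    (residual), 1x1 channel mixing, GELU -> irfft2 -> RMSNorm. The FFTs run in
    float32 and the result is cast back to the input dtype.
    """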
| def __init__(self, in_channels: int = 48, out_channels: int = 48) -> None: |
| super().__init__() |
| self.rn = RMSNorm(out_channels * 2) |
| self.post_norm = RMSNorm(out_channels) |
|
|
| self.fdc = nn.Conv2d( |
| in_channels=in_channels * 2, |
| out_channels=out_channels * 2, |
| kernel_size=1, |
| bias=True, |
| ) |
|
|
| self.fpe = nn.Conv2d( |
| in_channels=in_channels * 2, |
| out_channels=in_channels * 2, |
| kernel_size=3, |
| padding=1, |
| groups=in_channels * 2, |
| bias=True, |
| ) |
| self.gelu = nn.GELU() |
| self.irfft2 = CustomIrfft2Wrap() |
| self.rfft2 = CustomRfft2Wrap() |
|
|
| def forward(self, x: Tensor) -> Tensor: |
| orig_dtype = x.dtype |
| x = x.to(torch.float32) |
| b, c, h, w = x.shape |
| ffted = self.rfft2(x) |
| ffted = ffted.permute(0, 4, 1, 2, 3).contiguous() |
| ffted = ffted.view(b, c * 2, h, -1).to(orig_dtype) |
| ffted = self.rn(ffted) |
| ffted = self.fpe(ffted) + ffted |
| ffted = self.fdc(ffted) |
| ffted = self.gelu(ffted) |
| ffted = ffted.view(b, c, 2, h, -1).permute(0, 1, 3, 4, 2).contiguous().float() |
| out = self.irfft2(ffted) |
| out = self.post_norm(out.to(orig_dtype)) |
| return out |
|
|
|
|
| class InceptionConv2d(nn.Module): |
| """Inception convolution""" |
|
|
| def __init__( |
| self, |
| fu_dim: int = 24, |
| gc: int = 8, |
| square_kernel_size: int = 13, |
| band_kernel_size: int = 17, |
| ) -> None: |
| super().__init__() |
|
|
| self.fu = FourierUnit(fu_dim, fu_dim) |
| self.convhw = nn.Conv2d( |
| gc, gc, square_kernel_size, padding=square_kernel_size // 2 |
| ) |
| self.convw = nn.Conv2d( |
| gc, |
| gc, |
| kernel_size=(1, band_kernel_size), |
| padding=(0, band_kernel_size // 2), |
| ) |
| self.convh = nn.Conv2d( |
| gc, |
| gc, |
| kernel_size=(band_kernel_size, 1), |
| padding=(band_kernel_size // 2, 0), |
| ) |
|
|
| def forward( |
| self, x: Tensor, x_hw: Tensor, x_w: Tensor, xh: Tensor |
| ) -> tuple[Tensor, Tensor, Tensor, Tensor]: |
| return self.fu(x), self.convhw(x_hw), self.convw(x_w), self.convh(xh) |
|
|
|
|
| class GatedCNNBlock(nn.Module): |
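    """Gated convolutional block with an Inception-style mixer.

    fc1 expands the input into a gate and a value path; the value path is
    split into an identity part plus Fourier, square-conv and band-conv
    branches (InceptionConv2d). The SiLU-activated gate multiplies the mixed
    features, which fc2 projects back before the residual connection.
    """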
| def __init__( |
| self, |
| dim: int = 64, |
| expansion_ratio: float = 8 / 3, |
| gc: int = 8, |
| square_kernel_size: int = 13, |
| band_kernel_size: int = 17, |
| ) -> None: |
| super().__init__() |
| hidden = int(expansion_ratio * dim) // 8 * 8 |
| self.norm = RMSNorm(dim) |
| self.fc1 = nn.Conv2d(dim, hidden * 2, 3, 1, 1) |
| self.act = nn.SiLU() |
| self.split_indices = [hidden, hidden - dim, dim - gc * 3, gc, gc, gc] |
| self.conv = InceptionConv2d( |
| dim - gc * 3, gc, square_kernel_size, band_kernel_size |
| ) |
| self.fc2 = nn.Conv2d(hidden, dim, 3, 1, 1) |
|
|
| def gated_forward(self, x: Tensor) -> Tensor: |
| x = self.norm(x) |
| x = self.fc1(x) |
| g, i, c, c_hw, c_w, c_h = torch.split(x, self.split_indices, dim=1) |
| c, c_hw, c_w, c_h = self.conv(c, c_hw, c_w, c_h) |
| x = self.fc2(self.act(g) * torch.cat((i, c, c_hw, c_w, c_h), dim=1)) |
| return x |
|
|
| def forward(self, x: Tensor) -> Tensor: |
| return self.gated_forward(x) + x |
|
|
|
|
| |
| class FIGSR(nn.Module): |
| """Fourier Inception Gated Super Resolution""" |
|
|
| def __init__( |
| self, |
| in_nc: int = 3, |
| dim: int = 48, |
| expansion_ratio: float = 8 / 3, |
| scale: int = 4, |
| out_nc: int = 3, |
| upsampler: SampleMods = "pixelshuffledirect", |
| mid_dim: int = 32, |
| n_blocks: int = 24, |
| gc: int = 8, |
| square_kernel_size: int = 13, |
| band_kernel_size: int = 17, |
| **kwargs, |
| ) -> None: |
| super().__init__() |
| self.in_to_dim = nn.Conv2d(in_nc, dim, 3, 1, 1) |
| self.pad = 2 |
| self.gfisr_body_half = nn.Sequential( |
| *[ |
| GatedCNNBlock( |
| dim, expansion_ratio, gc, square_kernel_size, band_kernel_size |
| ) |
| for _ in range(n_blocks // 2) |
| ] |
| ) |
| self.gfisr_body_half_2 = nn.Sequential( |
| *[ |
| GatedCNNBlock( |
| dim, expansion_ratio, gc, square_kernel_size, band_kernel_size |
| ) |
| for _ in range(n_blocks - n_blocks // 2) |
| ] |
| + [nn.Conv2d(dim, dim, 3, 1, 1)] |
| ) |
| self.cat_to_dim = nn.Conv2d(dim * 3, dim, 1) |
| self.upscale = UniUpsampleV3( |
| upsampler, scale, dim, out_nc, mid_dim, dysample_end_kernel=3 |
| ) |
| if upsampler == "pixelshuffledirect": |
| weight = ICNR( |
| self.upscale[0].weight, |
| initializer=nn.init.kaiming_normal_, |
| upscale_factor=scale, |
| ) |
| self.upscale[0].weight.data.copy_(weight) |
|
|
| self.scale = scale |
| self.shift = nn.Parameter(torch.ones(1, 3, 1, 1) * 0.5, requires_grad=True) |
| self.scale_norm = nn.Parameter(torch.ones(1, 3, 1, 1) / 6, requires_grad=True) |
|
|
| def load_state_dict(self, state_dict, strict=True, assign=True): |
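        # Always use this model's MetaUpsample buffer instead of the checkpoint's copy.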
| state_dict["upscale.MetaUpsample"] = self.upscale.MetaUpsample |
| return super().load_state_dict(state_dict, strict, assign) |
|
|
| def forward(self, x: Tensor) -> Tensor: |
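        # Normalise, pad H/W to a multiple of self.pad, extract shallow features,
        # run both body halves, fuse shallow/mid/deep features with a 1x1 conv,
        # upsample, crop the padding, and undo the normalisation.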
| x = (x - self.shift) / self.scale_norm |
|
|
| _, _, H, W = x.shape |
| mod_pad_h = (self.pad - H % self.pad) % self.pad |
| mod_pad_w = (self.pad - W % self.pad) % self.pad |
| x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect") |
|
|
| x = self.in_to_dim(x) |
| x0 = self.gfisr_body_half(x) |
| x1 = self.gfisr_body_half_2(x0) |
|
|
| x = self.cat_to_dim(torch.cat([x1, x, x0], dim=1)) |
| x = self.upscale(x)[:, :, : H * self.scale, : W * self.scale] |
| return x * self.scale_norm + self.shift |
|
|
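
if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the training or
    # export pipeline; the sizes and hyper-parameters below are arbitrary).
    model = FIGSR(dim=48, n_blocks=4, scale=2, upsampler="pixelshuffledirect")
    with torch.no_grad():
        out = model(torch.randn(1, 3, 32, 48))
    print(out.shape)  # expected: torch.Size([1, 3, 64, 96])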