From 4bd03c72f31e014c8e5b1d6fb55bb1bc224c7aab Mon Sep 17 00:00:00 2001 From: xiaoyuxi Date: Tue, 8 Jul 2025 15:56:48 +0800 Subject: [PATCH] first-commit --- .../models/depth_refiner/backbone.py | 472 ------- .../models/depth_refiner/decode_head.py | 619 --------- .../models/depth_refiner/depth_refiner.py | 115 -- .../models/depth_refiner/network.py | 429 ------ .../depth_refiner/stablilization_attention.py | 1187 ----------------- .../models/depth_refiner/stablizer.py | 342 ----- 6 files changed, 3164 deletions(-) delete mode 100644 models/SpaTrackV2/models/depth_refiner/backbone.py delete mode 100644 models/SpaTrackV2/models/depth_refiner/decode_head.py delete mode 100644 models/SpaTrackV2/models/depth_refiner/depth_refiner.py delete mode 100644 models/SpaTrackV2/models/depth_refiner/network.py delete mode 100644 models/SpaTrackV2/models/depth_refiner/stablilization_attention.py delete mode 100644 models/SpaTrackV2/models/depth_refiner/stablizer.py diff --git a/models/SpaTrackV2/models/depth_refiner/backbone.py b/models/SpaTrackV2/models/depth_refiner/backbone.py deleted file mode 100644 index 8ccec44..0000000 --- a/models/SpaTrackV2/models/depth_refiner/backbone.py +++ /dev/null @@ -1,472 +0,0 @@ -# --------------------------------------------------------------- -# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. -# -# This work is licensed under the NVIDIA Source Code License -# --------------------------------------------------------------- -import torch -import torch.nn as nn -import torch.nn.functional as F -from functools import partial - -from timm.layers import DropPath, to_2tuple, trunc_normal_ -from timm.models import register_model -from timm.models.vision_transformer import _cfg -import math - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.dwconv = DWConv(hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x, H, W): - x = self.fc1(x) - x = self.dwconv(x, H, W) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): - super().__init__() - assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
- - self.dim = dim - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - self.q = nn.Linear(dim, dim, bias=qkv_bias) - self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - self.sr_ratio = sr_ratio - if sr_ratio > 1: - self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) - self.norm = nn.LayerNorm(dim) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x, H, W): - B, N, C = x.shape - q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - - if self.sr_ratio > 1: - x_ = x.permute(0, 2, 1).reshape(B, C, H, W) - x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) - x_ = self.norm(x_) - kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - else: - kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - k, v = kv[0], kv[1] - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - - return x - - -class Block(nn.Module): - - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x, H, W): - x = x + self.drop_path(self.attn(self.norm1(x), H, W)) - x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) - - return x - - -class OverlapPatchEmbed(nn.Module): - """ Image to Patch Embedding - """ - - def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - - self.img_size = img_size - self.patch_size = patch_size - self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] - self.num_patches = self.H * self.W - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, - padding=(patch_size[0] // 2, patch_size[1] // 2)) - self.norm = nn.LayerNorm(embed_dim) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x): - x = self.proj(x) - _, _, H, W = x.shape - x = x.flatten(2).transpose(1, 2) - x = self.norm(x) - - return x, H, W - - - - -class OverlapPatchEmbed43(nn.Module): - """ Image to Patch Embedding - """ - - def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - - self.img_size = img_size - self.patch_size = patch_size - self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] - self.num_patches = self.H * self.W - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, - padding=(patch_size[0] // 2, patch_size[1] // 2)) - self.norm = nn.LayerNorm(embed_dim) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x): - if x.shape[1]==4: - x = self.proj_4c(x) - else: - x = self.proj(x) - _, _, H, W = x.shape - x = x.flatten(2).transpose(1, 2) - x = self.norm(x) - - return x, H, W - -class 
MixVisionTransformer(nn.Module): - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], - num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., - attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, - depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): - super().__init__() - self.num_classes = num_classes - self.depths = depths - - # patch_embed 43 - self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, - embed_dim=embed_dims[0]) - self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], - embed_dim=embed_dims[1]) - self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], - embed_dim=embed_dims[2]) - self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], - embed_dim=embed_dims[3]) - - # transformer encoder - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - cur = 0 - self.block1 = nn.ModuleList([Block( - dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[0]) - for i in range(depths[0])]) - self.norm1 = norm_layer(embed_dims[0]) - - cur += depths[0] - self.block2 = nn.ModuleList([Block( - dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[1]) - for i in range(depths[1])]) - self.norm2 = norm_layer(embed_dims[1]) - - cur += depths[1] - self.block3 = nn.ModuleList([Block( - dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[2]) - for i in range(depths[2])]) - self.norm3 = norm_layer(embed_dims[2]) - - cur += depths[2] - self.block4 = nn.ModuleList([Block( - dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[3]) - for i in range(depths[3])]) - self.norm4 = norm_layer(embed_dims[3]) - - # classification head - # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def init_weights(self, pretrained=None): - if isinstance(pretrained, str): - logger = get_root_logger() - load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) - - def reset_drop_path(self, drop_path_rate): - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] - cur = 0 - for i in 
range(self.depths[0]): - self.block1[i].drop_path.drop_prob = dpr[cur + i] - - cur += self.depths[0] - for i in range(self.depths[1]): - self.block2[i].drop_path.drop_prob = dpr[cur + i] - - cur += self.depths[1] - for i in range(self.depths[2]): - self.block3[i].drop_path.drop_prob = dpr[cur + i] - - cur += self.depths[2] - for i in range(self.depths[3]): - self.block4[i].drop_path.drop_prob = dpr[cur + i] - - def freeze_patch_emb(self): - self.patch_embed1.requires_grad = False - - @torch.jit.ignore - def no_weight_decay(self): - return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better - - def get_classifier(self): - return self.head - - def reset_classifier(self, num_classes, global_pool=''): - self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - - def forward_features(self, x): - B = x.shape[0] - outs = [] - - # stage 1 - x, H, W = self.patch_embed1(x) - for i, blk in enumerate(self.block1): - x = blk(x, H, W) - x = self.norm1(x) - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - outs.append(x) - - # stage 2 - x, H, W = self.patch_embed2(x) - for i, blk in enumerate(self.block2): - x = blk(x, H, W) - x = self.norm2(x) - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - outs.append(x) - - # stage 3 - x, H, W = self.patch_embed3(x) - for i, blk in enumerate(self.block3): - x = blk(x, H, W) - x = self.norm3(x) - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - outs.append(x) - - # stage 4 - x, H, W = self.patch_embed4(x) - for i, blk in enumerate(self.block4): - x = blk(x, H, W) - x = self.norm4(x) - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - outs.append(x) - - return outs - - def forward(self, x): - if x.dim() == 5: - x = x.reshape(x.shape[0]*x.shape[1],x.shape[2],x.shape[3],x.shape[4]) - x = self.forward_features(x) - # x = self.head(x) - - return x - - -class DWConv(nn.Module): - def __init__(self, dim=768): - super(DWConv, self).__init__() - self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) - - def forward(self, x, H, W): - B, N, C = x.shape - x = x.transpose(1, 2).view(B, C, H, W) - x = self.dwconv(x) - x = x.flatten(2).transpose(1, 2) - - return x - - - -#@BACKBONES.register_module() -class mit_b0(MixVisionTransformer): - def __init__(self, **kwargs): - super(mit_b0, self).__init__( - patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - - -#@BACKBONES.register_module() -class mit_b1(MixVisionTransformer): - def __init__(self, **kwargs): - super(mit_b1, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - - -#@BACKBONES.register_module() -class mit_b2(MixVisionTransformer): - def __init__(self, **kwargs): - super(mit_b2, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - - -#@BACKBONES.register_module() -class mit_b3(MixVisionTransformer): - def __init__(self, **kwargs): - super(mit_b3, self).__init__( - 
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - - -#@BACKBONES.register_module() -class mit_b4(MixVisionTransformer): - def __init__(self, **kwargs): - super(mit_b4, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - - -#@BACKBONES.register_module() -class mit_b5(MixVisionTransformer): - def __init__(self, **kwargs): - super(mit_b5, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - - diff --git a/models/SpaTrackV2/models/depth_refiner/decode_head.py b/models/SpaTrackV2/models/depth_refiner/decode_head.py deleted file mode 100644 index b733c36..0000000 --- a/models/SpaTrackV2/models/depth_refiner/decode_head.py +++ /dev/null @@ -1,619 +0,0 @@ -from abc import ABCMeta, abstractmethod - -import torch -import torch.nn as nn - -# from mmcv.cnn import normal_init -# from mmcv.runner import auto_fp16, force_fp32 - -# from mmseg.core import build_pixel_sampler -# from mmseg.ops import resize - - -class BaseDecodeHead(nn.Module, metaclass=ABCMeta): - """Base class for BaseDecodeHead. - - Args: - in_channels (int|Sequence[int]): Input channels. - channels (int): Channels after modules, before conv_seg. - num_classes (int): Number of classes. - dropout_ratio (float): Ratio of dropout layer. Default: 0.1. - conv_cfg (dict|None): Config of conv layers. Default: None. - norm_cfg (dict|None): Config of norm layers. Default: None. - act_cfg (dict): Config of activation layers. - Default: dict(type='ReLU') - in_index (int|Sequence[int]): Input feature index. Default: -1 - input_transform (str|None): Transformation type of input features. - Options: 'resize_concat', 'multiple_select', None. - 'resize_concat': Multiple feature maps will be resize to the - same size as first one and than concat together. - Usually used in FCN head of HRNet. - 'multiple_select': Multiple feature maps will be bundle into - a list and passed into decode head. - None: Only one select feature map is allowed. - Default: None. - loss_decode (dict): Config of decode loss. - Default: dict(type='CrossEntropyLoss'). - ignore_index (int | None): The label index to be ignored. When using - masked BCE loss, ignore_index should be set to None. Default: 255 - sampler (dict|None): The config of segmentation map sampler. - Default: None. - align_corners (bool): align_corners argument of F.interpolate. - Default: False. 
- """ - - def __init__(self, - in_channels, - channels, - *, - num_classes, - dropout_ratio=0.1, - conv_cfg=None, - norm_cfg=None, - act_cfg=dict(type='ReLU'), - in_index=-1, - input_transform=None, - loss_decode=dict( - type='CrossEntropyLoss', - use_sigmoid=False, - loss_weight=1.0), - decoder_params=None, - ignore_index=255, - sampler=None, - align_corners=False): - super(BaseDecodeHead, self).__init__() - self._init_inputs(in_channels, in_index, input_transform) - self.channels = channels - self.num_classes = num_classes - self.dropout_ratio = dropout_ratio - self.conv_cfg = conv_cfg - self.norm_cfg = norm_cfg - self.act_cfg = act_cfg - self.in_index = in_index - self.ignore_index = ignore_index - self.align_corners = align_corners - - if sampler is not None: - self.sampler = build_pixel_sampler(sampler, context=self) - else: - self.sampler = None - - self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) - if dropout_ratio > 0: - self.dropout = nn.Dropout2d(dropout_ratio) - else: - self.dropout = None - self.fp16_enabled = False - - def extra_repr(self): - """Extra repr.""" - s = f'input_transform={self.input_transform}, ' \ - f'ignore_index={self.ignore_index}, ' \ - f'align_corners={self.align_corners}' - return s - - def _init_inputs(self, in_channels, in_index, input_transform): - """Check and initialize input transforms. - - The in_channels, in_index and input_transform must match. - Specifically, when input_transform is None, only single feature map - will be selected. So in_channels and in_index must be of type int. - When input_transform - - Args: - in_channels (int|Sequence[int]): Input channels. - in_index (int|Sequence[int]): Input feature index. - input_transform (str|None): Transformation type of input features. - Options: 'resize_concat', 'multiple_select', None. - 'resize_concat': Multiple feature maps will be resize to the - same size as first one and than concat together. - Usually used in FCN head of HRNet. - 'multiple_select': Multiple feature maps will be bundle into - a list and passed into decode head. - None: Only one select feature map is allowed. - """ - - if input_transform is not None: - assert input_transform in ['resize_concat', 'multiple_select'] - self.input_transform = input_transform - self.in_index = in_index - if input_transform is not None: - assert isinstance(in_channels, (list, tuple)) - assert isinstance(in_index, (list, tuple)) - assert len(in_channels) == len(in_index) - if input_transform == 'resize_concat': - self.in_channels = sum(in_channels) - else: - self.in_channels = in_channels - else: - assert isinstance(in_channels, int) - assert isinstance(in_index, int) - self.in_channels = in_channels - - def init_weights(self): - """Initialize weights of classification layer.""" - normal_init(self.conv_seg, mean=0, std=0.01) - - def _transform_inputs(self, inputs): - """Transform inputs for decoder. - - Args: - inputs (list[Tensor]): List of multi-level img features. 
- - Returns: - Tensor: The transformed inputs - """ - - if self.input_transform == 'resize_concat': - inputs = [inputs[i] for i in self.in_index] - upsampled_inputs = [ - resize( - input=x, - size=inputs[0].shape[2:], - mode='bilinear', - align_corners=self.align_corners) for x in inputs - ] - inputs = torch.cat(upsampled_inputs, dim=1) - elif self.input_transform == 'multiple_select': - inputs = [inputs[i] for i in self.in_index] - else: - inputs = inputs[self.in_index] - - return inputs - - # @auto_fp16() - @abstractmethod - def forward(self, inputs): - """Placeholder of forward function.""" - pass - - def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): - """Forward function for training. - Args: - inputs (list[Tensor]): List of multi-level img features. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `mmseg/datasets/pipelines/formatting.py:Collect`. - gt_semantic_seg (Tensor): Semantic segmentation masks - used if the architecture supports semantic segmentation task. - train_cfg (dict): The training config. - - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - seg_logits = self.forward(inputs) - losses = self.losses(seg_logits, gt_semantic_seg) - return losses - - def forward_test(self, inputs, img_metas, test_cfg): - """Forward function for testing. - - Args: - inputs (list[Tensor]): List of multi-level img features. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `mmseg/datasets/pipelines/formatting.py:Collect`. - test_cfg (dict): The testing config. - - Returns: - Tensor: Output segmentation map. - """ - return self.forward(inputs) - - def cls_seg(self, feat): - """Classify each pixel.""" - if self.dropout is not None: - feat = self.dropout(feat) - output = self.conv_seg(feat) - return output - - -class BaseDecodeHead_clips(nn.Module, metaclass=ABCMeta): - """Base class for BaseDecodeHead_clips. - - Args: - in_channels (int|Sequence[int]): Input channels. - channels (int): Channels after modules, before conv_seg. - num_classes (int): Number of classes. - dropout_ratio (float): Ratio of dropout layer. Default: 0.1. - conv_cfg (dict|None): Config of conv layers. Default: None. - norm_cfg (dict|None): Config of norm layers. Default: None. - act_cfg (dict): Config of activation layers. - Default: dict(type='ReLU') - in_index (int|Sequence[int]): Input feature index. Default: -1 - input_transform (str|None): Transformation type of input features. - Options: 'resize_concat', 'multiple_select', None. - 'resize_concat': Multiple feature maps will be resize to the - same size as first one and than concat together. - Usually used in FCN head of HRNet. - 'multiple_select': Multiple feature maps will be bundle into - a list and passed into decode head. - None: Only one select feature map is allowed. - Default: None. - loss_decode (dict): Config of decode loss. - Default: dict(type='CrossEntropyLoss'). - ignore_index (int | None): The label index to be ignored. When using - masked BCE loss, ignore_index should be set to None. Default: 255 - sampler (dict|None): The config of segmentation map sampler. - Default: None. - align_corners (bool): align_corners argument of F.interpolate. 
- Default: False. - """ - - def __init__(self, - in_channels, - channels, - *, - num_classes, - dropout_ratio=0.1, - conv_cfg=None, - norm_cfg=None, - act_cfg=dict(type='ReLU'), - in_index=-1, - input_transform=None, - loss_decode=dict( - type='CrossEntropyLoss', - use_sigmoid=False, - loss_weight=1.0), - decoder_params=None, - ignore_index=255, - sampler=None, - align_corners=False, - num_clips=5): - super(BaseDecodeHead_clips, self).__init__() - self._init_inputs(in_channels, in_index, input_transform) - self.channels = channels - self.num_classes = num_classes - self.dropout_ratio = dropout_ratio - self.conv_cfg = conv_cfg - self.norm_cfg = norm_cfg - self.act_cfg = act_cfg - self.in_index = in_index - self.ignore_index = ignore_index - self.align_corners = align_corners - self.num_clips=num_clips - - if sampler is not None: - self.sampler = build_pixel_sampler(sampler, context=self) - else: - self.sampler = None - - self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) - if dropout_ratio > 0: - self.dropout = nn.Dropout2d(dropout_ratio) - else: - self.dropout = None - self.fp16_enabled = False - - def extra_repr(self): - """Extra repr.""" - s = f'input_transform={self.input_transform}, ' \ - f'ignore_index={self.ignore_index}, ' \ - f'align_corners={self.align_corners}' - return s - - def _init_inputs(self, in_channels, in_index, input_transform): - """Check and initialize input transforms. - - The in_channels, in_index and input_transform must match. - Specifically, when input_transform is None, only single feature map - will be selected. So in_channels and in_index must be of type int. - When input_transform - - Args: - in_channels (int|Sequence[int]): Input channels. - in_index (int|Sequence[int]): Input feature index. - input_transform (str|None): Transformation type of input features. - Options: 'resize_concat', 'multiple_select', None. - 'resize_concat': Multiple feature maps will be resize to the - same size as first one and than concat together. - Usually used in FCN head of HRNet. - 'multiple_select': Multiple feature maps will be bundle into - a list and passed into decode head. - None: Only one select feature map is allowed. - """ - - if input_transform is not None: - assert input_transform in ['resize_concat', 'multiple_select'] - self.input_transform = input_transform - self.in_index = in_index - if input_transform is not None: - assert isinstance(in_channels, (list, tuple)) - assert isinstance(in_index, (list, tuple)) - assert len(in_channels) == len(in_index) - if input_transform == 'resize_concat': - self.in_channels = sum(in_channels) - else: - self.in_channels = in_channels - else: - assert isinstance(in_channels, int) - assert isinstance(in_index, int) - self.in_channels = in_channels - - def init_weights(self): - """Initialize weights of classification layer.""" - normal_init(self.conv_seg, mean=0, std=0.01) - - def _transform_inputs(self, inputs): - """Transform inputs for decoder. - - Args: - inputs (list[Tensor]): List of multi-level img features. 
- - Returns: - Tensor: The transformed inputs - """ - - if self.input_transform == 'resize_concat': - inputs = [inputs[i] for i in self.in_index] - upsampled_inputs = [ - resize( - input=x, - size=inputs[0].shape[2:], - mode='bilinear', - align_corners=self.align_corners) for x in inputs - ] - inputs = torch.cat(upsampled_inputs, dim=1) - elif self.input_transform == 'multiple_select': - inputs = [inputs[i] for i in self.in_index] - else: - inputs = inputs[self.in_index] - - return inputs - - # @auto_fp16() - @abstractmethod - def forward(self, inputs): - """Placeholder of forward function.""" - pass - - def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips): - """Forward function for training. - Args: - inputs (list[Tensor]): List of multi-level img features. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `mmseg/datasets/pipelines/formatting.py:Collect`. - gt_semantic_seg (Tensor): Semantic segmentation masks - used if the architecture supports semantic segmentation task. - train_cfg (dict): The training config. - - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - seg_logits = self.forward(inputs,batch_size, num_clips) - losses = self.losses(seg_logits, gt_semantic_seg) - return losses - - def forward_test(self, inputs, img_metas, test_cfg, batch_size, num_clips): - """Forward function for testing. - - Args: - inputs (list[Tensor]): List of multi-level img features. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `mmseg/datasets/pipelines/formatting.py:Collect`. - test_cfg (dict): The testing config. - - Returns: - Tensor: Output segmentation map. - """ - return self.forward(inputs, batch_size, num_clips) - - def cls_seg(self, feat): - """Classify each pixel.""" - if self.dropout is not None: - feat = self.dropout(feat) - output = self.conv_seg(feat) - return output - -class BaseDecodeHead_clips_flow(nn.Module, metaclass=ABCMeta): - """Base class for BaseDecodeHead_clips_flow. - - Args: - in_channels (int|Sequence[int]): Input channels. - channels (int): Channels after modules, before conv_seg. - num_classes (int): Number of classes. - dropout_ratio (float): Ratio of dropout layer. Default: 0.1. - conv_cfg (dict|None): Config of conv layers. Default: None. - norm_cfg (dict|None): Config of norm layers. Default: None. - act_cfg (dict): Config of activation layers. - Default: dict(type='ReLU') - in_index (int|Sequence[int]): Input feature index. Default: -1 - input_transform (str|None): Transformation type of input features. - Options: 'resize_concat', 'multiple_select', None. - 'resize_concat': Multiple feature maps will be resize to the - same size as first one and than concat together. - Usually used in FCN head of HRNet. - 'multiple_select': Multiple feature maps will be bundle into - a list and passed into decode head. - None: Only one select feature map is allowed. - Default: None. - loss_decode (dict): Config of decode loss. - Default: dict(type='CrossEntropyLoss'). - ignore_index (int | None): The label index to be ignored. When using - masked BCE loss, ignore_index should be set to None. 
Default: 255 - sampler (dict|None): The config of segmentation map sampler. - Default: None. - align_corners (bool): align_corners argument of F.interpolate. - Default: False. - """ - - def __init__(self, - in_channels, - channels, - *, - num_classes, - dropout_ratio=0.1, - conv_cfg=None, - norm_cfg=None, - act_cfg=dict(type='ReLU'), - in_index=-1, - input_transform=None, - loss_decode=dict( - type='CrossEntropyLoss', - use_sigmoid=False, - loss_weight=1.0), - decoder_params=None, - ignore_index=255, - sampler=None, - align_corners=False, - num_clips=5): - super(BaseDecodeHead_clips_flow, self).__init__() - self._init_inputs(in_channels, in_index, input_transform) - self.channels = channels - self.num_classes = num_classes - self.dropout_ratio = dropout_ratio - self.conv_cfg = conv_cfg - self.norm_cfg = norm_cfg - self.act_cfg = act_cfg - self.in_index = in_index - self.ignore_index = ignore_index - self.align_corners = align_corners - self.num_clips=num_clips - - if sampler is not None: - self.sampler = build_pixel_sampler(sampler, context=self) - else: - self.sampler = None - - self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) - if dropout_ratio > 0: - self.dropout = nn.Dropout2d(dropout_ratio) - else: - self.dropout = None - self.fp16_enabled = False - - def extra_repr(self): - """Extra repr.""" - s = f'input_transform={self.input_transform}, ' \ - f'ignore_index={self.ignore_index}, ' \ - f'align_corners={self.align_corners}' - return s - - def _init_inputs(self, in_channels, in_index, input_transform): - """Check and initialize input transforms. - - The in_channels, in_index and input_transform must match. - Specifically, when input_transform is None, only single feature map - will be selected. So in_channels and in_index must be of type int. - When input_transform - - Args: - in_channels (int|Sequence[int]): Input channels. - in_index (int|Sequence[int]): Input feature index. - input_transform (str|None): Transformation type of input features. - Options: 'resize_concat', 'multiple_select', None. - 'resize_concat': Multiple feature maps will be resize to the - same size as first one and than concat together. - Usually used in FCN head of HRNet. - 'multiple_select': Multiple feature maps will be bundle into - a list and passed into decode head. - None: Only one select feature map is allowed. - """ - - if input_transform is not None: - assert input_transform in ['resize_concat', 'multiple_select'] - self.input_transform = input_transform - self.in_index = in_index - if input_transform is not None: - assert isinstance(in_channels, (list, tuple)) - assert isinstance(in_index, (list, tuple)) - assert len(in_channels) == len(in_index) - if input_transform == 'resize_concat': - self.in_channels = sum(in_channels) - else: - self.in_channels = in_channels - else: - assert isinstance(in_channels, int) - assert isinstance(in_index, int) - self.in_channels = in_channels - - def init_weights(self): - """Initialize weights of classification layer.""" - normal_init(self.conv_seg, mean=0, std=0.01) - - def _transform_inputs(self, inputs): - """Transform inputs for decoder. - - Args: - inputs (list[Tensor]): List of multi-level img features. 
- - Returns: - Tensor: The transformed inputs - """ - - if self.input_transform == 'resize_concat': - inputs = [inputs[i] for i in self.in_index] - upsampled_inputs = [ - resize( - input=x, - size=inputs[0].shape[2:], - mode='bilinear', - align_corners=self.align_corners) for x in inputs - ] - inputs = torch.cat(upsampled_inputs, dim=1) - elif self.input_transform == 'multiple_select': - inputs = [inputs[i] for i in self.in_index] - else: - inputs = inputs[self.in_index] - - return inputs - - # @auto_fp16() - @abstractmethod - def forward(self, inputs): - """Placeholder of forward function.""" - pass - - def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips,img=None): - """Forward function for training. - Args: - inputs (list[Tensor]): List of multi-level img features. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `mmseg/datasets/pipelines/formatting.py:Collect`. - gt_semantic_seg (Tensor): Semantic segmentation masks - used if the architecture supports semantic segmentation task. - train_cfg (dict): The training config. - - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - seg_logits = self.forward(inputs,batch_size, num_clips,img) - losses = self.losses(seg_logits, gt_semantic_seg) - return losses - - def forward_test(self, inputs, img_metas, test_cfg, batch_size=None, num_clips=None, img=None): - """Forward function for testing. - - Args: - inputs (list[Tensor]): List of multi-level img features. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `mmseg/datasets/pipelines/formatting.py:Collect`. - test_cfg (dict): The testing config. - - Returns: - Tensor: Output segmentation map. 
- """ - return self.forward(inputs, batch_size, num_clips,img) - - def cls_seg(self, feat): - """Classify each pixel.""" - if self.dropout is not None: - feat = self.dropout(feat) - output = self.conv_seg(feat) - return output \ No newline at end of file diff --git a/models/SpaTrackV2/models/depth_refiner/depth_refiner.py b/models/SpaTrackV2/models/depth_refiner/depth_refiner.py deleted file mode 100644 index 4a98a82..0000000 --- a/models/SpaTrackV2/models/depth_refiner/depth_refiner.py +++ /dev/null @@ -1,115 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from models.monoD.depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed -from models.SpaTrackV2.models.depth_refiner.backbone import mit_b3 -from models.SpaTrackV2.models.depth_refiner.stablizer import Stabilization_Network_Cross_Attention -from einops import rearrange -class TrackStablizer(nn.Module): - def __init__(self): - super().__init__() - - self.backbone = mit_b3() - - old_conv = self.backbone.patch_embed1.proj - new_conv = nn.Conv2d(old_conv.in_channels + 4, old_conv.out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding) - - new_conv.weight[:, :3, :, :].data.copy_(old_conv.weight.clone()) - self.backbone.patch_embed1.proj = new_conv - - self.Track_Stabilizer = Stabilization_Network_Cross_Attention(in_channels=[64, 128, 320, 512], - in_index=[0, 1, 2, 3], - feature_strides=[4, 8, 16, 32], - channels=128, - dropout_ratio=0.1, - num_classes=1, - align_corners=False, - decoder_params=dict(embed_dim=256, depths=4), - num_clips=16, - norm_cfg = dict(type='SyncBN', requires_grad=True)) - - self.edge_conv = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=64, kernel_size=3, padding=1, stride=1, bias=True),\ - nn.ReLU(inplace=True)) - self.edge_conv1 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2, bias=True),\ - nn.ReLU(inplace=True)) - self.success = False - self.x = None - - def buffer_forward(self, inputs, num_clips=16): - """ - buffer forward for getting the pointmap and image features - """ - B, T, C, H, W = inputs.shape - self.x = self.backbone(inputs) - scale, shift = self.Track_Stabilizer.buffer_forward(self.x, num_clips=num_clips) - self.success = True - return scale, shift - - def forward(self, inputs, tracks, tracks_uvd, num_clips=16, imgs=None, vis_track=None): - - """ - Args: - inputs: [B, T, C, H, W], RGB + PointMap + Mask - tracks: [B, T, N, 4], 3D tracks in camera coordinate + visibility - num_clips: int, number of clips to use - """ - B, T, C, H, W = inputs.shape - edge_feat = self.edge_conv(inputs.view(B*T,4,H,W)) - edge_feat1 = self.edge_conv1(edge_feat) - - if not self.success: - scale, shift = self.Track_Stabilizer.buffer_forward(self.x,num_clips=num_clips) - self.success = True - update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track) - else: - update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track) - - return update - - def reset_success(self): - self.success = False - self.x = None - self.Track_Stabilizer.reset_success() - - -if __name__ == "__main__": - # Create test input tensors - batch_size = 1 - seq_len = 16 - channels = 7 # 3 for RGB + 3 for PointMap + 1 for Mask - height = 384 - width = 512 - - # Create random input tensor with shape [B, T, C, H, W] - inputs = torch.randn(batch_size, seq_len, channels, height, width) - - # Create 
random tracks - tracks = torch.randn(batch_size, seq_len, 1024, 4) - - # Create random test images - test_imgs = torch.randn(batch_size, seq_len, 3, height, width) - - # Initialize model and move to GPU - model = TrackStablizer().cuda() - - # Move inputs to GPU and run forward pass - inputs = inputs.cuda() - tracks = tracks.cuda() - outputs = model.buffer_forward(inputs, num_clips=seq_len) - import time - start_time = time.time() - outputs = model(inputs, tracks, num_clips=seq_len) - end_time = time.time() - print(f"Time taken: {end_time - start_time} seconds") - import pdb; pdb.set_trace() - # # Print shapes for verification - # print(f"Input shape: {inputs.shape}") - # print(f"Output shape: {outputs.shape}") - - # # Basic tests - # assert outputs.shape[0] == batch_size, "Batch size mismatch" - # assert len(outputs.shape) == 4, "Output should be 4D: [B,C,H,W]" - # assert torch.all(outputs >= 0), "Output should be non-negative after ReLU" - - # print("All tests passed!") - diff --git a/models/SpaTrackV2/models/depth_refiner/network.py b/models/SpaTrackV2/models/depth_refiner/network.py deleted file mode 100644 index a9e70f0..0000000 --- a/models/SpaTrackV2/models/depth_refiner/network.py +++ /dev/null @@ -1,429 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -''' -Author: Ke Xian -Email: kexian@hust.edu.cn -Date: 2020/07/20 -''' - -import torch -import torch.nn as nn -import torch.nn.init as init - -# ============================================================================================================== - -class FTB(nn.Module): - def __init__(self, inchannels, midchannels=512): - super(FTB, self).__init__() - self.in1 = inchannels - self.mid = midchannels - - self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True) - self.conv_branch = nn.Sequential(nn.ReLU(inplace=True),\ - nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),\ - #nn.BatchNorm2d(num_features=self.mid),\ - nn.ReLU(inplace=True),\ - nn.Conv2d(in_channels=self.mid, out_channels= self.mid, kernel_size=3, padding=1, stride=1, bias=True)) - self.relu = nn.ReLU(inplace=True) - - self.init_params() - - def forward(self, x): - x = self.conv1(x) - x = x + self.conv_branch(x) - x = self.relu(x) - - return x - - def init_params(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - init.normal_(m.weight, std=0.01) - # init.xavier_normal_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.ConvTranspose2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - init.normal_(m.weight, std=0.01) - # init.xavier_normal_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d - init.constant_(m.weight, 1) - init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - init.normal_(m.weight, std=0.01) - if m.bias is not None: - init.constant_(m.bias, 0) - -class ATA(nn.Module): - def __init__(self, inchannels, reduction = 8): - super(ATA, self).__init__() - self.inchannels = inchannels - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc = nn.Sequential(nn.Linear(self.inchannels*2, self.inchannels // reduction), - nn.ReLU(inplace=True), - nn.Linear(self.inchannels // reduction, self.inchannels), - nn.Sigmoid()) - self.init_params() - - def forward(self, low_x, high_x): - n, c, _, _ = low_x.size() - x = torch.cat([low_x, high_x], 1) - x = self.avg_pool(x) - x = 
x.view(n, -1) - x = self.fc(x).view(n,c,1,1) - x = low_x * x + high_x - - return x - - def init_params(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - #init.normal(m.weight, std=0.01) - init.xavier_normal_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.ConvTranspose2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - #init.normal_(m.weight, std=0.01) - init.xavier_normal_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d - init.constant_(m.weight, 1) - init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - init.normal_(m.weight, std=0.01) - if m.bias is not None: - init.constant_(m.bias, 0) - - -class FFM(nn.Module): - def __init__(self, inchannels, midchannels, outchannels, upfactor=2): - super(FFM, self).__init__() - self.inchannels = inchannels - self.midchannels = midchannels - self.outchannels = outchannels - self.upfactor = upfactor - - self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels) - self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels) - - self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) - - self.init_params() - #self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False) - #self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False) - #self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False) - - def forward(self, low_x, high_x): - x = self.ftb1(low_x) - - ''' - x = torch.cat((x,high_x),1) - if x.shape[2] == 12: - x = self.p1(x) - elif x.shape[2] == 24: - x = self.p2(x) - elif x.shape[2] == 48: - x = self.p3(x) - ''' - x = x + high_x ###high_x - x = self.ftb2(x) - x = self.upsample(x) - - return x - - def init_params(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - init.normal_(m.weight, std=0.01) - #init.xavier_normal_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.ConvTranspose2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - init.normal_(m.weight, std=0.01) - #init.xavier_normal_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d - init.constant_(m.weight, 1) - init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - init.normal_(m.weight, std=0.01) - if m.bias is not None: - init.constant_(m.bias, 0) - - - -class noFFM(nn.Module): - def __init__(self, inchannels, midchannels, outchannels, upfactor=2): - super(noFFM, self).__init__() - self.inchannels = inchannels - self.midchannels = midchannels - self.outchannels = outchannels - self.upfactor = upfactor - - self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels) - - self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) - - self.init_params() - #self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False) - #self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False) - #self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False) - - def forward(self, low_x, high_x): - - #x = self.ftb1(low_x) - x = high_x ###high_x - x = self.ftb2(x) - x = self.upsample(x) - - return x - - def init_params(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - init.normal_(m.weight, std=0.01) - #init.xavier_normal_(m.weight) 
- if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.ConvTranspose2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - init.normal_(m.weight, std=0.01) - #init.xavier_normal_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d - init.constant_(m.weight, 1) - init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - init.normal_(m.weight, std=0.01) - if m.bias is not None: - init.constant_(m.bias, 0) - - - - -class AO(nn.Module): - # Adaptive output module - def __init__(self, inchannels, outchannels, upfactor=2): - super(AO, self).__init__() - self.inchannels = inchannels - self.outchannels = outchannels - self.upfactor = upfactor - - """ - self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\ - nn.BatchNorm2d(num_features=self.inchannels//2),\ - nn.ReLU(inplace=True),\ - nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True),\ - nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) )#,\ - #nn.ReLU(inplace=True)) ## get positive values - """ - self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\ - #nn.BatchNorm2d(num_features=self.inchannels//2),\ - nn.ReLU(inplace=True),\ - nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True), \ - nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=1, padding=0, stride=1)) - - #nn.ReLU(inplace=True)) ## get positive values - - self.init_params() - - def forward(self, x): - x = self.adapt_conv(x) - return x - - def init_params(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - init.normal_(m.weight, std=0.01) - #init.xavier_normal_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.ConvTranspose2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - init.normal_(m.weight, std=0.01) - #init.xavier_normal_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d - init.constant_(m.weight, 1) - init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - init.normal_(m.weight, std=0.01) - if m.bias is not None: - init.constant_(m.bias, 0) - -class ASPP(nn.Module): - def __init__(self, inchannels=256, planes=128, rates = [1, 6, 12, 18]): - super(ASPP, self).__init__() - self.inchannels = inchannels - self.planes = planes - self.rates = rates - self.kernel_sizes = [] - self.paddings = [] - for rate in self.rates: - if rate == 1: - self.kernel_sizes.append(1) - self.paddings.append(0) - else: - self.kernel_sizes.append(3) - self.paddings.append(rate) - self.atrous_0 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[0], - stride=1, padding=self.paddings[0], dilation=self.rates[0], bias=True), - nn.ReLU(inplace=True), - nn.BatchNorm2d(num_features=self.planes) - ) - self.atrous_1 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[1], - stride=1, padding=self.paddings[1], dilation=self.rates[1], bias=True), - nn.ReLU(inplace=True), - nn.BatchNorm2d(num_features=self.planes), - ) - self.atrous_2 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, 
out_channels=self.planes, kernel_size=self.kernel_sizes[2], - stride=1, padding=self.paddings[2], dilation=self.rates[2], bias=True), - nn.ReLU(inplace=True), - nn.BatchNorm2d(num_features=self.planes), - ) - self.atrous_3 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[3], - stride=1, padding=self.paddings[3], dilation=self.rates[3], bias=True), - nn.ReLU(inplace=True), - nn.BatchNorm2d(num_features=self.planes), - ) - - #self.conv = nn.Conv2d(in_channels=self.planes * 4, out_channels=self.inchannels, kernel_size=3, padding=1, stride=1, bias=True) - def forward(self, x): - x = torch.cat([self.atrous_0(x), self.atrous_1(x), self.atrous_2(x), self.atrous_3(x)],1) - #x = self.conv(x) - - return x - -# ============================================================================================================== - - -class ResidualConv(nn.Module): - def __init__(self, inchannels): - super(ResidualConv, self).__init__() - #nn.BatchNorm2d - self.conv = nn.Sequential( - #nn.BatchNorm2d(num_features=inchannels), - nn.ReLU(inplace=False), - #nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True), - #nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True) - nn.Conv2d(in_channels=inchannels, out_channels=inchannels//2, kernel_size=3, padding=1, stride=1, bias=False), - nn.BatchNorm2d(num_features=inchannels//2), - nn.ReLU(inplace=False), - nn.Conv2d(in_channels=inchannels//2, out_channels=inchannels, kernel_size=3, padding=1, stride=1, bias=False) - ) - self.init_params() - - def forward(self, x): - x = self.conv(x)+x - return x - - def init_params(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - init.normal_(m.weight, std=0.01) - #init.xavier_normal_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.ConvTranspose2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - init.normal_(m.weight, std=0.01) - #init.xavier_normal_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d - init.constant_(m.weight, 1) - init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - init.normal_(m.weight, std=0.01) - if m.bias is not None: - init.constant_(m.bias, 0) - - -class FeatureFusion(nn.Module): - def __init__(self, inchannels, outchannels): - super(FeatureFusion, self).__init__() - self.conv = ResidualConv(inchannels=inchannels) - #nn.BatchNorm2d - self.up = nn.Sequential(ResidualConv(inchannels=inchannels), - nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3,stride=2, padding=1, output_padding=1), - nn.BatchNorm2d(num_features=outchannels), - nn.ReLU(inplace=True)) - - def forward(self, lowfeat, highfeat): - return self.up(highfeat + self.conv(lowfeat)) - - def init_params(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - init.normal_(m.weight, std=0.01) - #init.xavier_normal_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.ConvTranspose2d): - #init.kaiming_normal_(m.weight, mode='fan_out') - init.normal_(m.weight, std=0.01) - #init.xavier_normal_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d - init.constant_(m.weight, 1) - init.constant_(m.bias, 
0) - elif isinstance(m, nn.Linear): - init.normal_(m.weight, std=0.01) - if m.bias is not None: - init.constant_(m.bias, 0) - - -class SenceUnderstand(nn.Module): - def __init__(self, channels): - super(SenceUnderstand, self).__init__() - self.channels = channels - self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1), - nn.ReLU(inplace = True)) - self.pool = nn.AdaptiveAvgPool2d(8) - self.fc = nn.Sequential(nn.Linear(512*8*8, self.channels), - nn.ReLU(inplace = True)) - self.conv2 = nn.Sequential(nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0), - nn.ReLU(inplace=True)) - self.initial_params() - - def forward(self, x): - n,c,h,w = x.size() - x = self.conv1(x) - x = self.pool(x) - x = x.view(n,-1) - x = self.fc(x) - x = x.view(n, self.channels, 1, 1) - x = self.conv2(x) - x = x.repeat(1,1,h,w) - return x - - def initial_params(self, dev=0.01): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - #print torch.sum(m.weight) - m.weight.data.normal_(0, dev) - if m.bias is not None: - m.bias.data.fill_(0) - elif isinstance(m, nn.ConvTranspose2d): - #print torch.sum(m.weight) - m.weight.data.normal_(0, dev) - if m.bias is not None: - m.bias.data.fill_(0) - elif isinstance(m, nn.Linear): - m.weight.data.normal_(0, dev) diff --git a/models/SpaTrackV2/models/depth_refiner/stablilization_attention.py b/models/SpaTrackV2/models/depth_refiner/stablilization_attention.py deleted file mode 100644 index 0b2f27a..0000000 --- a/models/SpaTrackV2/models/depth_refiner/stablilization_attention.py +++ /dev/null @@ -1,1187 +0,0 @@ -import math -import time -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.layers import DropPath, to_2tuple, trunc_normal_ -from einops import rearrange - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - -def window_partition_noreshape(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (B, num_windows_h, num_windows_w, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() - return windows - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, 
window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - -def get_roll_masks(H, W, window_size, shift_size): - ##################################### - # move to top-left - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, H-window_size), - slice(H-window_size, H-shift_size), - slice(H-shift_size, H)) - w_slices = (slice(0, W-window_size), - slice(W-window_size, W-shift_size), - slice(W-shift_size, W)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, window_size * window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask_tl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - #################################### - # move to top right - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, H-window_size), - slice(H-window_size, H-shift_size), - slice(H-shift_size, H)) - w_slices = (slice(0, shift_size), - slice(shift_size, window_size), - slice(window_size, W)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, window_size * window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask_tr = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - #################################### - # move to bottom left - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, shift_size), - slice(shift_size, window_size), - slice(window_size, H)) - w_slices = (slice(0, W-window_size), - slice(W-window_size, W-shift_size), - slice(W-shift_size, W)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, window_size * window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask_bl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - #################################### - # move to bottom right - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, shift_size), - slice(shift_size, window_size), - slice(window_size, H)) - w_slices = (slice(0, shift_size), - slice(shift_size, window_size), - slice(window_size, W)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, window_size * window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask_br = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - # append all - attn_mask_all = torch.cat((attn_mask_tl, attn_mask_tr, attn_mask_bl, attn_mask_br), -1) - return attn_mask_all - -def get_relative_position_index(q_windows, k_windows): - """ - Args: - q_windows: tuple (query_window_height, query_window_width) - k_windows: tuple (key_window_height, key_window_width) - - Returns: - relative_position_index: query_window_height*query_window_width, key_window_height*key_window_width 
- """ - # get pair-wise relative position index for each token inside the window - coords_h_q = torch.arange(q_windows[0]) - coords_w_q = torch.arange(q_windows[1]) - coords_q = torch.stack(torch.meshgrid([coords_h_q, coords_w_q])) # 2, Wh_q, Ww_q - - coords_h_k = torch.arange(k_windows[0]) - coords_w_k = torch.arange(k_windows[1]) - coords_k = torch.stack(torch.meshgrid([coords_h_k, coords_w_k])) # 2, Wh, Ww - - coords_flatten_q = torch.flatten(coords_q, 1) # 2, Wh_q*Ww_q - coords_flatten_k = torch.flatten(coords_k, 1) # 2, Wh_k*Ww_k - - relative_coords = coords_flatten_q[:, :, None] - coords_flatten_k[:, None, :] # 2, Wh_q*Ww_q, Wh_k*Ww_k - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh_q*Ww_q, Wh_k*Ww_k, 2 - relative_coords[:, :, 0] += k_windows[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += k_windows[1] - 1 - relative_coords[:, :, 0] *= (q_windows[1] + k_windows[1]) - 1 - relative_position_index = relative_coords.sum(-1) # Wh_q*Ww_q, Wh_k*Ww_k - return relative_position_index - -def get_relative_position_index3d(q_windows, k_windows, num_clips): - """ - Args: - q_windows: tuple (query_window_height, query_window_width) - k_windows: tuple (key_window_height, key_window_width) - - Returns: - relative_position_index: query_window_height*query_window_width, key_window_height*key_window_width - """ - # get pair-wise relative position index for each token inside the window - coords_d_q = torch.arange(num_clips) - coords_h_q = torch.arange(q_windows[0]) - coords_w_q = torch.arange(q_windows[1]) - coords_q = torch.stack(torch.meshgrid([coords_d_q, coords_h_q, coords_w_q])) # 2, Wh_q, Ww_q - - coords_d_k = torch.arange(num_clips) - coords_h_k = torch.arange(k_windows[0]) - coords_w_k = torch.arange(k_windows[1]) - coords_k = torch.stack(torch.meshgrid([coords_d_k, coords_h_k, coords_w_k])) # 2, Wh, Ww - - coords_flatten_q = torch.flatten(coords_q, 1) # 2, Wh_q*Ww_q - coords_flatten_k = torch.flatten(coords_k, 1) # 2, Wh_k*Ww_k - - relative_coords = coords_flatten_q[:, :, None] - coords_flatten_k[:, None, :] # 2, Wh_q*Ww_q, Wh_k*Ww_k - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh_q*Ww_q, Wh_k*Ww_k, 2 - relative_coords[:, :, 0] += num_clips - 1 # shift to start from 0 - relative_coords[:, :, 1] += k_windows[0] - 1 - relative_coords[:, :, 2] += k_windows[1] - 1 - relative_coords[:, :, 0] *= (q_windows[0] + k_windows[0] - 1)*(q_windows[1] + k_windows[1] - 1) - relative_coords[:, :, 1] *= (q_windows[1] + k_windows[1] - 1) - relative_position_index = relative_coords.sum(-1) # Wh_q*Ww_q, Wh_k*Ww_k - return relative_position_index - - -class WindowAttention3d3(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - - Args: - dim (int): Number of input channels. - expand_size (int): The expand size at focal level 1. - window_size (tuple[int]): The height and width of the window. - focal_window (int): Focal region size. - focal_level (int): Focal attention level. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - pool_method (str): window pooling method. 
Default: none - """ - - def __init__(self, dim, expand_size, window_size, focal_window, focal_level, num_heads, - qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pool_method="none", focal_l_clips=[7,1,2], focal_kernel_clips=[7,5,3]): - - super().__init__() - self.dim = dim - self.expand_size = expand_size - self.window_size = window_size # Wh, Ww - self.pool_method = pool_method - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - self.focal_level = focal_level - self.focal_window = focal_window - - # define a parameter table of relative position bias for each window - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - num_clips=4 - # # define a parameter table of relative position bias - # self.relative_position_bias_table = nn.Parameter( - # torch.zeros((2 * num_clips - 1) * (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH - - # # get pair-wise relative position index for each token inside the window - # coords_d = torch.arange(num_clips) - # coords_h = torch.arange(self.window_size[0]) - # coords_w = torch.arange(self.window_size[1]) - # coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww - # coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww - # relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww - # relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 - # relative_coords[:, :, 0] += num_clips - 1 # shift to start from 0 - # relative_coords[:, :, 1] += self.window_size[0] - 1 - # relative_coords[:, :, 2] += self.window_size[1] - 1 - - # relative_coords[:, :, 0] *= (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) - # relative_coords[:, :, 1] *= (2 * self.window_size[1] - 1) - # relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww - # self.register_buffer("relative_position_index", relative_position_index) - - - if self.expand_size > 0 and focal_level > 0: - # define a parameter table of position bias between window and its fine-grained surroundings - self.window_size_of_key = self.window_size[0] * self.window_size[1] if self.expand_size == 0 else \ - (4 * self.window_size[0] * self.window_size[1] - 4 * (self.window_size[0] - self.expand_size) * (self.window_size[0] - self.expand_size)) - self.relative_position_bias_table_to_neighbors = nn.Parameter( - torch.zeros(1, num_heads, self.window_size[0] * self.window_size[1], self.window_size_of_key)) # Wh*Ww, nH, nSurrounding - trunc_normal_(self.relative_position_bias_table_to_neighbors, 
std=.02) - - # get mask for rolled k and rolled v - mask_tl = torch.ones(self.window_size[0], self.window_size[1]); mask_tl[:-self.expand_size, :-self.expand_size] = 0 - mask_tr = torch.ones(self.window_size[0], self.window_size[1]); mask_tr[:-self.expand_size, self.expand_size:] = 0 - mask_bl = torch.ones(self.window_size[0], self.window_size[1]); mask_bl[self.expand_size:, :-self.expand_size] = 0 - mask_br = torch.ones(self.window_size[0], self.window_size[1]); mask_br[self.expand_size:, self.expand_size:] = 0 - mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0) - self.register_buffer("valid_ind_rolled", mask_rolled.nonzero().view(-1)) - - if pool_method != "none" and focal_level > 1: - #self.relative_position_bias_table_to_windows = nn.ParameterList() - #self.relative_position_bias_table_to_windows_clips = nn.ParameterList() - #self.register_parameter('relative_position_bias_table_to_windows',[]) - #self.register_parameter('relative_position_bias_table_to_windows_clips',[]) - self.unfolds = nn.ModuleList() - self.unfolds_clips=nn.ModuleList() - - # build relative position bias between local patch and pooled windows - for k in range(focal_level-1): - stride = 2**k - kernel_size = 2*(self.focal_window // 2) + 2**k + (2**k-1) - # define unfolding operations - self.unfolds += [nn.Unfold( - kernel_size=(kernel_size, kernel_size), - stride=stride, padding=kernel_size // 2) - ] - - # define relative position bias table - relative_position_bias_table_to_windows = nn.Parameter( - torch.zeros( - self.num_heads, - (self.window_size[0] + self.focal_window + 2**k - 2) * (self.window_size[1] + self.focal_window + 2**k - 2), - ) - ) - trunc_normal_(relative_position_bias_table_to_windows, std=.02) - #self.relative_position_bias_table_to_windows.append(relative_position_bias_table_to_windows) - self.register_parameter('relative_position_bias_table_to_windows_{}'.format(k),relative_position_bias_table_to_windows) - - # define relative position bias index - relative_position_index_k = get_relative_position_index(self.window_size, to_2tuple(self.focal_window + 2**k - 1)) - # relative_position_index_k = get_relative_position_index3d(self.window_size, to_2tuple(self.focal_window + 2**k - 1), num_clips) - self.register_buffer("relative_position_index_{}".format(k), relative_position_index_k) - - # define unfolding index for focal_level > 0 - if k > 0: - mask = torch.zeros(kernel_size, kernel_size); mask[(2**k)-1:, (2**k)-1:] = 1 - self.register_buffer("valid_ind_unfold_{}".format(k), mask.flatten(0).nonzero().view(-1)) - - for k in range(len(focal_l_clips)): - # kernel_size=focal_kernel_clips[k] - focal_l_big_flag=False - if focal_l_clips[k]>self.window_size[0]: - stride=1 - padding=0 - kernel_size=focal_kernel_clips[k] - kernel_size_true=kernel_size - focal_l_big_flag=True - # stride=math.ceil(self.window_size/focal_l_clips[k]) - # padding=(kernel_size-stride)/2 - else: - stride = focal_l_clips[k] - # kernel_size - # kernel_size = 2*(focal_kernel_clips[k]// 2) + 2**focal_l_clips[k] + (2**focal_l_clips[k]-1) - kernel_size = focal_kernel_clips[k] ## kernel_size must be jishu - assert kernel_size%2==1 - padding=kernel_size // 2 - # kernel_size_true=focal_kernel_clips[k]+2**focal_l_clips[k]-1 - kernel_size_true=kernel_size - # stride=math.ceil(self.window_size/focal_l_clips[k]) - - self.unfolds_clips += [nn.Unfold( - kernel_size=(kernel_size, kernel_size), - stride=stride, - padding=padding) - ] - relative_position_bias_table_to_windows = nn.Parameter( - torch.zeros( - self.num_heads, 
- (self.window_size[0] + kernel_size_true - 1) * (self.window_size[0] + kernel_size_true - 1), - ) - ) - trunc_normal_(relative_position_bias_table_to_windows, std=.02) - #self.relative_position_bias_table_to_windows_clips.append(relative_position_bias_table_to_windows) - self.register_parameter('relative_position_bias_table_to_windows_clips_{}'.format(k),relative_position_bias_table_to_windows) - relative_position_index_k = get_relative_position_index(self.window_size, to_2tuple(kernel_size_true)) - self.register_buffer("relative_position_index_clips_{}".format(k), relative_position_index_k) - # if (not focal_l_big_flag) and focal_l_clips[k]>0: - # mask = torch.zeros(kernel_size, kernel_size); mask[(2**focal_l_clips[k])-1:, (2**focal_l_clips[k])-1:] = 1 - # self.register_buffer("valid_ind_unfold_clips_{}".format(k), mask.flatten(0).nonzero().view(-1)) - - - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - self.softmax = nn.Softmax(dim=-1) - self.focal_l_clips=focal_l_clips - self.focal_kernel_clips=focal_kernel_clips - - def forward(self, x_all, mask_all=None, batch_size=None, num_clips=None): - """ - Args: - x_all (list[Tensors]): input features at different granularity - mask_all (list[Tensors/None]): masks for input features at different granularity - """ - x = x_all[0][0] # - - B0, nH, nW, C = x.shape - # assert B==batch_size*num_clips - assert B0==batch_size - qkv = self.qkv(x).reshape(B0, nH, nW, 3, C).permute(3, 0, 1, 2, 4).contiguous() - q, k, v = qkv[0], qkv[1], qkv[2] # B0, nH, nW, C - - # partition q map - # print("x.shape: ", x.shape) - # print("q.shape: ", q.shape) # [4, 126, 126, 256] - (q_windows, k_windows, v_windows) = map( - lambda t: window_partition(t, self.window_size[0]).view( - -1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads - ).transpose(1, 2), - (q, k, v) - ) - - # q_dim0, q_dim1, q_dim2, q_dim3=q_windows.shape - # q_windows=q_windows.view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]), q_dim1, q_dim2, q_dim3) - # q_windows=q_windows[:,-1].contiguous().view(-1, q_dim1, q_dim2, q_dim3) # query for the last frame (target frame) - - # k_windows.shape [1296, 8, 49, 32] - - if self.expand_size > 0 and self.focal_level > 0: - (k_tl, v_tl) = map( - lambda t: torch.roll(t, shifts=(-self.expand_size, -self.expand_size), dims=(1, 2)), (k, v) - ) - (k_tr, v_tr) = map( - lambda t: torch.roll(t, shifts=(-self.expand_size, self.expand_size), dims=(1, 2)), (k, v) - ) - (k_bl, v_bl) = map( - lambda t: torch.roll(t, shifts=(self.expand_size, -self.expand_size), dims=(1, 2)), (k, v) - ) - (k_br, v_br) = map( - lambda t: torch.roll(t, shifts=(self.expand_size, self.expand_size), dims=(1, 2)), (k, v) - ) - - (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map( - lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads), - (k_tl, k_tr, k_bl, k_br) - ) - (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map( - lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads), - (v_tl, v_tr, v_bl, v_br) - ) - k_rolled = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 1).transpose(1, 2) - v_rolled = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 1).transpose(1, 2) - - # mask out tokens in 
current window - # print("self.valid_ind_rolled.shape: ", self.valid_ind_rolled.shape) # [132] - # print("k_rolled.shape: ", k_rolled.shape) # [1296, 8, 196, 32] - k_rolled = k_rolled[:, :, self.valid_ind_rolled] - v_rolled = v_rolled[:, :, self.valid_ind_rolled] - k_rolled = torch.cat((k_windows, k_rolled), 2) - v_rolled = torch.cat((v_windows, v_rolled), 2) - else: - k_rolled = k_windows; v_rolled = v_windows; - - # print("k_rolled.shape: ", k_rolled.shape) # [1296, 8, 181, 32] - - if self.pool_method != "none" and self.focal_level > 1: - k_pooled = [] - v_pooled = [] - for k in range(self.focal_level-1): - stride = 2**k - x_window_pooled = x_all[0][k+1] # B0, nWh, nWw, C - nWh, nWw = x_window_pooled.shape[1:3] - - # generate mask for pooled windows - # print("x_window_pooled.shape: ", x_window_pooled.shape) - mask = x_window_pooled.new(nWh, nWw).fill_(1) - # print("here: ",x_window_pooled.shape, self.unfolds[k].kernel_size, self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).shape) - # print(mask.unique()) - unfolded_mask = self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).view( - 1, 1, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\ - view(nWh*nWw // stride // stride, -1, 1) - - if k > 0: - valid_ind_unfold_k = getattr(self, "valid_ind_unfold_{}".format(k)) - unfolded_mask = unfolded_mask[:, valid_ind_unfold_k] - - # print("unfolded_mask.shape: ", unfolded_mask.shape, unfolded_mask.unique()) - x_window_masks = unfolded_mask.flatten(1).unsqueeze(0) - # print((x_window_masks == 0).sum(), (x_window_masks > 0).sum(), x_window_masks.unique()) - x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill(x_window_masks > 0, float(0.0)) - # print(x_window_masks.shape) - mask_all[0][k+1] = x_window_masks - - # generate k and v for pooled windows - qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous() - k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B0, C, nWh, nWw - - - (k_pooled_k, v_pooled_k) = map( - lambda t: self.unfolds[k](t).view( - B0, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\ - view(-1, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2), - (k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim - ) - - # print("k_pooled_k.shape: ", k_pooled_k.shape) - # print("valid_ind_unfold_k.shape: ", valid_ind_unfold_k.shape) - - if k > 0: - (k_pooled_k, v_pooled_k) = map( - lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k) - ) - - # print("k_pooled_k.shape: ", k_pooled_k.shape) - - k_pooled += [k_pooled_k] - v_pooled += [v_pooled_k] - - for k in range(len(self.focal_l_clips)): - focal_l_big_flag=False - if self.focal_l_clips[k]>self.window_size[0]: - stride=1 - focal_l_big_flag=True - else: - stride = self.focal_l_clips[k] - # if self.window_size>=focal_l_clips[k]: - # stride=math.ceil(self.window_size/focal_l_clips[k]) - # # padding=(kernel_size-stride)/2 - # else: - # stride=1 - # padding=0 - x_window_pooled = x_all[k+1] - nWh, nWw = x_window_pooled.shape[1:3] - mask = x_window_pooled.new(nWh, nWw).fill_(1) - - # import pdb; pdb.set_trace() - # print(x_window_pooled.shape, self.unfolds_clips[k].kernel_size, self.unfolds_clips[k](mask.unsqueeze(0).unsqueeze(1)).shape) - - unfolded_mask = self.unfolds_clips[k](mask.unsqueeze(0).unsqueeze(1)).view( - 1, 1, 
self.unfolds_clips[k].kernel_size[0], self.unfolds_clips[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\ - view(nWh*nWw // stride // stride, -1, 1) - - # if (not focal_l_big_flag) and self.focal_l_clips[k]>0: - # valid_ind_unfold_k = getattr(self, "valid_ind_unfold_clips_{}".format(k)) - # unfolded_mask = unfolded_mask[:, valid_ind_unfold_k] - - # print("unfolded_mask.shape: ", unfolded_mask.shape, unfolded_mask.unique()) - x_window_masks = unfolded_mask.flatten(1).unsqueeze(0) - # print((x_window_masks == 0).sum(), (x_window_masks > 0).sum(), x_window_masks.unique()) - x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill(x_window_masks > 0, float(0.0)) - # print(x_window_masks.shape) - mask_all[k+1] = x_window_masks - - # generate k and v for pooled windows - qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous() - k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B0, C, nWh, nWw - - if (not focal_l_big_flag): - (k_pooled_k, v_pooled_k) = map( - lambda t: self.unfolds_clips[k](t).view( - B0, C, self.unfolds_clips[k].kernel_size[0], self.unfolds_clips[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\ - view(-1, self.unfolds_clips[k].kernel_size[0]*self.unfolds_clips[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2), - (k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim - ) - else: - - (k_pooled_k, v_pooled_k) = map( - lambda t: self.unfolds_clips[k](t), - (k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim - ) - LLL=k_pooled_k.size(2) - LLL_h=int(LLL**0.5) - assert LLL_h**2==LLL - k_pooled_k=k_pooled_k.reshape(B0, -1, LLL_h, LLL_h) - v_pooled_k=v_pooled_k.reshape(B0, -1, LLL_h, LLL_h) - - - - # print("k_pooled_k.shape: ", k_pooled_k.shape) - # print("valid_ind_unfold_k.shape: ", valid_ind_unfold_k.shape) - # if (not focal_l_big_flag) and self.focal_l_clips[k]: - # (k_pooled_k, v_pooled_k) = map( - # lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k) - # ) - - # print("k_pooled_k.shape: ", k_pooled_k.shape) - - k_pooled += [k_pooled_k] - v_pooled += [v_pooled_k] - - # qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous() - # k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B0, C, nWh, nWw - # (k_pooled_k, v_pooled_k) = map( - # lambda t: self.unfolds[k](t).view( - # B0, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\ - # view(-1, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2), - # (k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim - # ) - # k_pooled += [k_pooled_k] - # v_pooled += [v_pooled_k] - - - k_all = torch.cat([k_rolled] + k_pooled, 2) - v_all = torch.cat([v_rolled] + v_pooled, 2) - else: - k_all = k_rolled - v_all = v_rolled - - N = k_all.shape[-2] - q_windows = q_windows * self.scale - # print(q_windows.shape, k_all.shape, v_all.shape) - # exit() - # k_all_dim0, k_all_dim1, k_all_dim2, k_all_dim3=k_all.shape - # k_all=k_all.contiguous().view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]), - # k_all_dim1, k_all_dim2, k_all_dim3).permute(0,2,3,4,1,5).contiguous().view(-1, k_all_dim1, k_all_dim2*num_clips, k_all_dim3) - # v_all=v_all.contiguous().view(batch_size, num_clips, 
(nH//self.window_size[0])*(nW//self.window_size[1]), - # k_all_dim1, k_all_dim2, k_all_dim3).permute(0,2,3,4,1,5).contiguous().view(-1, k_all_dim1, k_all_dim2*num_clips, k_all_dim3) - - # print(q_windows.shape, k_all.shape, v_all.shape, k_rolled.shape) - # exit() - attn = (q_windows @ k_all.transpose(-2, -1)) # B0*nW, nHead, window_size*window_size, focal_window_size*focal_window_size - - window_area = self.window_size[0] * self.window_size[1] - # window_area_clips= num_clips*self.window_size[0] * self.window_size[1] - window_area_rolled = k_rolled.shape[2] - - # add relative position bias for tokens inside window - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - # print(relative_position_bias.shape, attn.shape) - attn[:, :, :window_area, :window_area] = attn[:, :, :window_area, :window_area] + relative_position_bias.unsqueeze(0) - - # relative_position_bias = self.relative_position_bias_table[self.relative_position_index[-window_area:, :window_area_clips].reshape(-1)].view( - # window_area, window_area_clips, -1) # Wh*Ww,Wd*Wh*Ww,nH - # relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().view(self.num_heads,window_area,num_clips,window_area - # ).permute(0,1,3,2).contiguous().view(self.num_heads,window_area,window_area_clips).contiguous() # nH, Wh*Ww, Wh*Ww*Wd - # # attn_dim0, attn_dim1, attn_dim2, attn_dim3=attn.shape - # # attn=attn.view(attn_dim0,attn_dim1,attn_dim2,num_clips,-1) - # # print(attn.shape, relative_position_bias.shape) - # attn[:,:,:window_area, :window_area_clips]=attn[:,:,:window_area, :window_area_clips] + relative_position_bias.unsqueeze(0) - # attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N - - # add relative position bias for patches inside a window - if self.expand_size > 0 and self.focal_level > 0: - attn[:, :, :window_area, window_area:window_area_rolled] = attn[:, :, :window_area, window_area:window_area_rolled] + self.relative_position_bias_table_to_neighbors - - if self.pool_method != "none" and self.focal_level > 1: - # add relative position bias for different windows in an image - offset = window_area_rolled - # print(offset) - for k in range(self.focal_level-1): - # add relative position bias - relative_position_index_k = getattr(self, 'relative_position_index_{}'.format(k)) - relative_position_bias_to_windows = getattr(self,'relative_position_bias_table_to_windows_{}'.format(k))[:, relative_position_index_k.view(-1)].view( - -1, self.window_size[0] * self.window_size[1], (self.focal_window+2**k-1)**2, - ) # nH, NWh*NWw,focal_region*focal_region - attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \ - attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0) - # add attentional mask - if mask_all[0][k+1] is not None: - attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \ - attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + \ - mask_all[0][k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[0][k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[0][k+1].shape[-1]) - - offset += (self.focal_window+2**k-1)**2 - # print(offset) - for k in range(len(self.focal_l_clips)): - focal_l_big_flag=False - if 
self.focal_l_clips[k]>self.window_size[0]: - stride=1 - padding=0 - kernel_size=self.focal_kernel_clips[k] - kernel_size_true=kernel_size - focal_l_big_flag=True - # stride=math.ceil(self.window_size/focal_l_clips[k]) - # padding=(kernel_size-stride)/2 - else: - stride = self.focal_l_clips[k] - # kernel_size - # kernel_size = 2*(self.focal_kernel_clips[k]// 2) + 2**self.focal_l_clips[k] + (2**self.focal_l_clips[k]-1) - kernel_size = self.focal_kernel_clips[k] - padding=kernel_size // 2 - # kernel_size_true=self.focal_kernel_clips[k]+2**self.focal_l_clips[k]-1 - kernel_size_true=kernel_size - relative_position_index_k = getattr(self, 'relative_position_index_clips_{}'.format(k)) - relative_position_bias_to_windows = getattr(self,'relative_position_bias_table_to_windows_clips_{}'.format(k))[:, relative_position_index_k.view(-1)].view( - -1, self.window_size[0] * self.window_size[1], (kernel_size_true)**2, - ) - attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] = \ - attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] + relative_position_bias_to_windows.unsqueeze(0) - if mask_all[k+1] is not None: - attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] = \ - attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] + \ - mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1]) - offset += (kernel_size_true)**2 - # print(offset) - # relative_position_index_k = getattr(self, 'relative_position_index_{}'.format(k)) - # # relative_position_bias_to_windows = self.relative_position_bias_table_to_windows[k][:, relative_position_index_k.view(-1)].view( - # # -1, self.window_size[0] * self.window_size[1], (self.focal_window+2**k-1)**2, - # # ) # nH, NWh*NWw,focal_region*focal_region - # # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \ - # # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0) - # relative_position_bias_to_windows = self.relative_position_bias_table_to_windows[k][:, relative_position_index_k[-window_area:, :].view(-1)].view( - # -1, self.window_size[0] * self.window_size[1], num_clips*(self.focal_window+2**k-1)**2, - # ).contiguous() # nH, NWh*NWw, num_clips*focal_region*focal_region - # relative_position_bias_to_windows = relative_position_bias_to_windows.view(self.num_heads, - # window_area,num_clips,-1).permute(0,1,3,2).contiguous().view(self.num_heads,window_area,-1) - # attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] = \ - # attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0) - # # add attentional mask - # if mask_all[k+1] is not None: - # # print("inside the mask, be careful 1") - # # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \ - # # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + \ - # # mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1]) - # # print("here: ", mask_all[k+1].shape, mask_all[k+1][:, :, None, None, :].shape) - - # attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] = \ - # attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] + \ - # mask_all[k+1][:, :, None, None, :,None].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 
1, 1, 1, num_clips).view(-1, 1, 1, mask_all[k+1].shape[-1]*num_clips) - # # print() - - # offset += (self.focal_window+2**k-1)**2 - - # print("mask_all[0]: ", mask_all[0]) - # exit() - if mask_all[0][0] is not None: - print("inside the mask, be careful 0") - nW = mask_all[0].shape[0] - attn = attn.view(attn.shape[0] // nW, nW, self.num_heads, window_area, N) - attn[:, :, :, :, :window_area] = attn[:, :, :, :, :window_area] + mask_all[0][None, :, None, :, :] - attn = attn.view(-1, self.num_heads, window_area, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area, C) - x = self.proj(x) - x = self.proj_drop(x) - # print(x.shape) - # x = x.view(B/num_clips, nH, nW, C ) - # exit() - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N, window_size, unfold_size): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - if self.pool_method != "none" and self.focal_level > 1: - flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size) - if self.expand_size > 0 and self.focal_level > 0: - flops += self.num_heads * N * (self.dim // self.num_heads) * ((window_size + 2*self.expand_size)**2-window_size**2) - - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - if self.pool_method != "none" and self.focal_level > 1: - flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size) - if self.expand_size > 0 and self.focal_level > 0: - flops += self.num_heads * N * (self.dim // self.num_heads) * ((window_size + 2*self.expand_size)**2-window_size**2) - - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class CffmTransformerBlock3d3(nn.Module): - r""" Focal Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - expand_size (int): expand size at first focal level (finest level). - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - pool_method (str): window pooling method. Default: none, options: [none|fc|conv] - focal_level (int): number of focal levels. Default: 1. - focal_window (int): region size of focal attention. Default: 1 - use_layerscale (bool): whether use layer scale for training stability. Default: False - layerscale_value (float): scaling value for layer scale. 
Default: 1e-4 - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, expand_size=0, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, pool_method="none", - focal_level=1, focal_window=1, use_layerscale=False, layerscale_value=1e-4, focal_l_clips=[7,2,4], focal_kernel_clips=[7,5,3]): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.expand_size = expand_size - self.mlp_ratio = mlp_ratio - self.pool_method = pool_method - self.focal_level = focal_level - self.focal_window = focal_window - self.use_layerscale = use_layerscale - self.focal_l_clips=focal_l_clips - self.focal_kernel_clips=focal_kernel_clips - - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.expand_size = 0 - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.window_size_glo = self.window_size - - self.pool_layers = nn.ModuleList() - self.pool_layers_clips = nn.ModuleList() - if self.pool_method != "none": - for k in range(self.focal_level-1): - window_size_glo = math.floor(self.window_size_glo / (2 ** k)) - if self.pool_method == "fc": - self.pool_layers.append(nn.Linear(window_size_glo * window_size_glo, 1)) - self.pool_layers[-1].weight.data.fill_(1./(window_size_glo * window_size_glo)) - self.pool_layers[-1].bias.data.fill_(0) - elif self.pool_method == "conv": - self.pool_layers.append(nn.Conv2d(dim, dim, kernel_size=window_size_glo, stride=window_size_glo, groups=dim)) - for k in range(len(focal_l_clips)): - # window_size_glo = math.floor(self.window_size_glo / (2 ** k)) - if focal_l_clips[k]>self.window_size: - window_size_glo = focal_l_clips[k] - else: - window_size_glo = math.floor(self.window_size_glo / (focal_l_clips[k])) - # window_size_glo = focal_l_clips[k] - if self.pool_method == "fc": - self.pool_layers_clips.append(nn.Linear(window_size_glo * window_size_glo, 1)) - self.pool_layers_clips[-1].weight.data.fill_(1./(window_size_glo * window_size_glo)) - self.pool_layers_clips[-1].bias.data.fill_(0) - elif self.pool_method == "conv": - self.pool_layers_clips.append(nn.Conv2d(dim, dim, kernel_size=window_size_glo, stride=window_size_glo, groups=dim)) - - self.norm1 = norm_layer(dim) - - self.attn = WindowAttention3d3( - dim, expand_size=self.expand_size, window_size=to_2tuple(self.window_size), - focal_window=focal_window, focal_level=focal_level, num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pool_method=pool_method, focal_l_clips=focal_l_clips, focal_kernel_clips=focal_kernel_clips) - - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - # print("******self.shift_size: ", self.shift_size) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - # print("here mask none") - attn_mask = None - self.register_buffer("attn_mask", attn_mask) - - if self.use_layerscale: - self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) - self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) - - def forward(self, x): - H0, W0 = self.input_resolution - # B, L, C = x.shape - B0, D0, H0, W0, C = x.shape - shortcut = x - # assert L == H * W, "input feature has wrong size" - x=x.reshape(B0*D0,H0,W0,C).reshape(B0*D0,H0*W0,C) - - - x = self.norm1(x) - x = x.reshape(B0*D0, H0, W0, C) - # print("here") - # exit() - - # pad feature maps to multiples of window size - pad_l = pad_t = 0 - pad_r = (self.window_size - W0 % self.window_size) % self.window_size - pad_b = (self.window_size - H0 % self.window_size) % self.window_size - if pad_r > 0 or pad_b > 0: - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) - - B, H, W, C = x.shape ## B=B0*D0 - - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # print("shifted_x.shape: ", shifted_x.shape) - shifted_x=shifted_x.view(B0,D0,H,W,C) - x_windows_all = [shifted_x[:,-1]] - x_windows_all_clips=[] - x_window_masks_all = [self.attn_mask] - x_window_masks_all_clips=[] - - if self.focal_level > 1 and self.pool_method != "none": - # if we add coarser granularity and the pool method is not none - # pooling_index=0 - for k in range(self.focal_level-1): - window_size_glo = math.floor(self.window_size_glo / (2 ** k)) - pooled_h = math.ceil(H / self.window_size) * (2 ** k) - pooled_w = math.ceil(W / self.window_size) * (2 ** k) - H_pool = pooled_h * window_size_glo - W_pool = pooled_w * window_size_glo - - x_level_k = shifted_x[:,-1] - # trim or pad shifted_x depending on the required size - if H > H_pool: - trim_t = (H - H_pool) // 2 - trim_b = H - H_pool - trim_t - x_level_k = x_level_k[:, trim_t:-trim_b] - elif H < H_pool: - pad_t = (H_pool - H) // 2 - pad_b = H_pool - H - pad_t - x_level_k = F.pad(x_level_k, (0,0,0,0,pad_t,pad_b)) - - if W > W_pool: - trim_l = (W - W_pool) // 2 - trim_r = W - W_pool - trim_l - x_level_k = x_level_k[:, :, trim_l:-trim_r] - elif W < W_pool: - pad_l = (W_pool - W) // 2 - pad_r = W_pool - W - pad_l - x_level_k = F.pad(x_level_k, (0,0,pad_l,pad_r)) - - x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(), window_size_glo) # B0, nw, nw, window_size, window_size, C 
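                # Each coarser focal level k pools the (trimmed/padded) target frame
                # into one token per window_size_glo x window_size_glo window using the
                # mean/max/fc/conv branches below; the pooled maps are appended to
                # x_windows_all and later act as extra key/value context inside
                # WindowAttention3d3.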
- nWh, nWw = x_windows_noreshape.shape[1:3] - if self.pool_method == "mean": - x_windows_pooled = x_windows_noreshape.mean([3, 4]) # B0, nWh, nWw, C - elif self.pool_method == "max": - x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B0, nWh, nWw, C) # B0, nWh, nWw, C - elif self.pool_method == "fc": - x_windows_noreshape = x_windows_noreshape.view(B0, nWh, nWw, window_size_glo*window_size_glo, C).transpose(3, 4) # B0, nWh, nWw, C, wsize**2 - x_windows_pooled = self.pool_layers[k](x_windows_noreshape).flatten(-2) # B0, nWh, nWw, C - elif self.pool_method == "conv": - x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0, 3, 1, 2).contiguous() # B0 * nw * nw, C, wsize, wsize - x_windows_pooled = self.pool_layers[k](x_windows_noreshape).view(B0, nWh, nWw, C) # B0, nWh, nWw, C - - x_windows_all += [x_windows_pooled] - # print(x_windows_pooled.shape) - x_window_masks_all += [None] - # pooling_index=pooling_index+1 - - x_windows_all_clips += [x_windows_all] - x_window_masks_all_clips += [x_window_masks_all] - for k in range(len(self.focal_l_clips)): - if self.focal_l_clips[k]>self.window_size: - window_size_glo = self.focal_l_clips[k] - else: - window_size_glo = math.floor(self.window_size_glo / (self.focal_l_clips[k])) - - pooled_h = math.ceil(H / self.window_size) * (self.focal_l_clips[k]) - pooled_w = math.ceil(W / self.window_size) * (self.focal_l_clips[k]) - - H_pool = pooled_h * window_size_glo - W_pool = pooled_w * window_size_glo - - x_level_k = shifted_x[:,k] - if H!=H_pool or W!=W_pool: - x_level_k=F.interpolate(x_level_k.permute(0,3,1,2), size=(H_pool, W_pool), mode='bilinear').permute(0,2,3,1) - - # print(x_level_k.shape) - x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(), window_size_glo) # B0, nw, nw, window_size, window_size, C - nWh, nWw = x_windows_noreshape.shape[1:3] - if self.pool_method == "mean": - x_windows_pooled = x_windows_noreshape.mean([3, 4]) # B0, nWh, nWw, C - elif self.pool_method == "max": - x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B0, nWh, nWw, C) # B0, nWh, nWw, C - elif self.pool_method == "fc": - x_windows_noreshape = x_windows_noreshape.view(B0, nWh, nWw, window_size_glo*window_size_glo, C).transpose(3, 4) # B0, nWh, nWw, C, wsize**2 - x_windows_pooled = self.pool_layers_clips[k](x_windows_noreshape).flatten(-2) # B0, nWh, nWw, C - elif self.pool_method == "conv": - x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0, 3, 1, 2).contiguous() # B0 * nw * nw, C, wsize, wsize - x_windows_pooled = self.pool_layers_clips[k](x_windows_noreshape).view(B0, nWh, nWw, C) # B0, nWh, nWw, C - - x_windows_all_clips += [x_windows_pooled] - # print(x_windows_pooled.shape) - x_window_masks_all_clips += [None] - # pooling_index=pooling_index+1 - # exit() - - attn_windows = self.attn(x_windows_all_clips, mask_all=x_window_masks_all_clips, batch_size=B0, num_clips=D0) # nW*B0, window_size*window_size, C - - attn_windows = attn_windows[:, :self.window_size ** 2] - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H(padded) W(padded) C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - # x = x[:, :self.input_resolution[0], :self.input_resolution[1]].contiguous().view(B, -1, C) - x = x[:, :H0, 
:W0].contiguous().view(B0, -1, C) - - # FFN - # x = shortcut + self.drop_path(x if (not self.use_layerscale) else (self.gamma_1 * x)) - # x = x + self.drop_path(self.mlp(self.norm2(x)) if (not self.use_layerscale) else (self.gamma_2 * self.mlp(self.norm2(x)))) - - # print(x.shape, shortcut[:,-1].view(B0, -1, C).shape) - x = shortcut[:,-1].view(B0, -1, C) + self.drop_path(x if (not self.use_layerscale) else (self.gamma_1 * x)) - x = x + self.drop_path(self.mlp(self.norm2(x)) if (not self.use_layerscale) else (self.gamma_2 * self.mlp(self.norm2(x)))) - - # x=torch.cat([shortcut[:,:-1],x.view(B0,self.input_resolution[0],self.input_resolution[1],C).unsqueeze(1)],1) - x=torch.cat([shortcut[:,:-1],x.view(B0,H0,W0,C).unsqueeze(1)],1) - - assert x.shape==shortcut.shape - - # exit() - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size, self.window_size, self.focal_window) - - if self.pool_method != "none" and self.focal_level > 1: - for k in range(self.focal_level-1): - window_size_glo = math.floor(self.window_size_glo / (2 ** k)) - nW_glo = nW * (2**k) - # (sub)-window pooling - flops += nW_glo * self.dim * window_size_glo * window_size_glo - # qkv for global levels - # NOTE: in our implementation, we pass the pooled window embedding to qkv embedding layer, - # but theoritically, we only need to compute k and v. - flops += nW_glo * self.dim * 3 * self.dim - - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class BasicLayer3d3(nn.Module): - """ A basic Focal Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - expand_size (int): expand size for focal level 1. - expand_layer (str): expand layer. Default: all - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - pool_method (str): Window pooling method. Default: none. - focal_level (int): Number of focal levels. Default: 1. - focal_window (int): region size at each focal level. Default: 1. - use_conv_embed (bool): whether use overlapped convolutional patch embedding layer. Default: False - use_shift (bool): Whether use window shift as in Swin Transformer. Default: False - use_pre_norm (bool): Whether use pre-norm before patch embedding projection for stability. Default: False - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
- use_layerscale (bool): Whether use layer scale for stability. Default: False. - layerscale_value (float): Layerscale value. Default: 1e-4. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, expand_size, expand_layer="all", - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, pool_method="none", - focal_level=1, focal_window=1, use_conv_embed=False, use_shift=False, use_pre_norm=False, - downsample=None, use_checkpoint=False, use_layerscale=False, layerscale_value=1e-4, focal_l_clips=[16,8,2], focal_kernel_clips=[7,5,3]): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - if expand_layer == "even": - expand_factor = 0 - elif expand_layer == "odd": - expand_factor = 1 - elif expand_layer == "all": - expand_factor = -1 - - # build blocks - self.blocks = nn.ModuleList([ - CffmTransformerBlock3d3(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=(0 if (i % 2 == 0) else window_size // 2) if use_shift else 0, - expand_size=0 if (i % 2 == expand_factor) else expand_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - pool_method=pool_method, - focal_level=focal_level, - focal_window=focal_window, - use_layerscale=use_layerscale, - layerscale_value=layerscale_value, - focal_l_clips=focal_l_clips, - focal_kernel_clips=focal_kernel_clips) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample( - img_size=input_resolution, patch_size=2, in_chans=dim, embed_dim=2*dim, - use_conv_embed=use_conv_embed, norm_layer=norm_layer, use_pre_norm=use_pre_norm, - is_stem=False - ) - else: - self.downsample = None - - def forward(self, x, batch_size=None, num_clips=None, reg_tokens=None): - B, D, C, H, W = x.shape - x = rearrange(x, 'b d c h w -> b d h w c') - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x) - - if self.downsample is not None: - x = x.view(x.shape[0], self.input_resolution[0], self.input_resolution[1], -1).permute(0, 3, 1, 2).contiguous() - x = self.downsample(x) - x = rearrange(x, 'b d h w c -> b d c h w') - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops diff --git a/models/SpaTrackV2/models/depth_refiner/stablizer.py b/models/SpaTrackV2/models/depth_refiner/stablizer.py deleted file mode 100644 index 5465610..0000000 --- a/models/SpaTrackV2/models/depth_refiner/stablizer.py +++ /dev/null @@ -1,342 +0,0 @@ -import numpy as np -import torch.nn as nn -import torch -# from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule -from collections import OrderedDict -# from mmseg.ops import resize -from torch.nn.functional import interpolate as resize -# from builder import HEADS -from models.SpaTrackV2.models.depth_refiner.decode_head import BaseDecodeHead, BaseDecodeHead_clips, BaseDecodeHead_clips_flow -# from mmseg.models.utils import * -import attr -from IPython import embed -from models.SpaTrackV2.models.depth_refiner.stablilization_attention import 
BasicLayer3d3 -import cv2 -from models.SpaTrackV2.models.depth_refiner.network import * -import warnings -# from mmcv.utils import Registry, build_from_cfg -from torch import nn -from einops import rearrange -import torch.nn.functional as F -from models.SpaTrackV2.models.blocks import ( - AttnBlock, CrossAttnBlock, Mlp -) - -class MLP(nn.Module): - """ - Linear Embedding - """ - def __init__(self, input_dim=2048, embed_dim=768): - super().__init__() - self.proj = nn.Linear(input_dim, embed_dim) - - def forward(self, x): - x = x.flatten(2).transpose(1, 2) - x = self.proj(x) - return x - - -def scatter_multiscale_fast( - track2d: torch.Tensor, - trackfeature: torch.Tensor, - H: int, - W: int, - kernel_sizes = [1] -) -> torch.Tensor: - """ - Scatter sparse track features onto a dense image grid with weighted multi-scale pooling to handle zero-value gaps. - - This function scatters sparse track features into a dense image grid and applies multi-scale average pooling - while excluding zero-value holes. The weight mask ensures that only valid feature regions contribute to the pooling, - avoiding dilution by empty pixels. - - Args: - track2d (torch.Tensor): Float tensor of shape (B, T, N, 2) containing (x, y) pixel coordinates - for each track point across batches, frames, and points. - trackfeature (torch.Tensor): Float tensor of shape (B, T, N, C) with C-dimensional features - for each track point. - H (int): Height of the target output image. - W (int): Width of the target output image. - kernel_sizes (List[int]): List of odd integers for average pooling kernel sizes. Default: [3, 5, 7]. - - Returns: - torch.Tensor: Multi-scale fused feature map of shape (B, T, C, H, W) with hole-resistant pooling. - """ - B, T, N, C = trackfeature.shape - device = trackfeature.device - - # 1. Flatten coordinates and filter valid points within image bounds - coords_flat = track2d.round().long().reshape(-1, 2) # (B*T*N, 2) - x = coords_flat[:, 0] # x coordinates - y = coords_flat[:, 1] # y coordinates - feat_flat = trackfeature.reshape(-1, C) # Flatten features - - valid_mask = (x >= 0) & (x < W) & (y >= 0) & (y < H) - x = x[valid_mask] - y = y[valid_mask] - feat_flat = feat_flat[valid_mask] - valid_count = x.shape[0] - - if valid_count == 0: - return torch.zeros(B, T, C, H, W, device=device) # Handle no-valid-point case - - # 2. Calculate linear indices and batch-frame indices for scattering - lin_idx = y * W + x # Linear index within a single frame (H*W range) - - # Generate batch-frame indices (e.g., 0~B*T-1 for each frame in batch) - bt_idx_raw = ( - torch.arange(B * T, device=device) - .view(B, T, 1) - .expand(B, T, N) - .reshape(-1) - ) - bt_idx = bt_idx_raw[valid_mask] # Indices for valid points across batch and frames - - # 3. Create accumulation buffers for features and weights - total_space = B * T * H * W - img_accum_flat = torch.zeros(total_space, C, device=device) # Feature accumulator - weight_accum_flat = torch.zeros(total_space, 1, device=device) # Weight accumulator (counts) - - # 4. Scatter features and weights into accumulation buffers - idx_in_accum = bt_idx * (H * W) + lin_idx # Global index: batch_frame * H*W + pixel_index - - # Add features to corresponding indices (index_add_ is efficient for sparse updates) - img_accum_flat.index_add_(0, idx_in_accum, feat_flat) - weight_accum_flat.index_add_(0, idx_in_accum, torch.ones((valid_count, 1), device=device)) - - # 5. 
Normalize features by valid weights, keep zeros for invalid regions - valid_mask_flat = weight_accum_flat > 0 # Binary mask for valid pixels - img_accum_flat = img_accum_flat / (weight_accum_flat + 1e-6) # Avoid division by zero - img_accum_flat = img_accum_flat * valid_mask_flat.float() # Mask out invalid regions - - # 6. Reshape to (B, T, C, H, W) for further processing - img = ( - img_accum_flat.view(B, T, H, W, C) - .permute(0, 1, 4, 2, 3) - .contiguous() - ) # Shape: (B, T, C, H, W) - - # 7. Multi-scale pooling with weight masking to exclude zero holes - blurred_outputs = [] - for k in kernel_sizes: - pad = k // 2 - img_bt = img.view(B*T, C, H, W) # Flatten batch and time for pooling - - # Create weight mask for valid regions (1 where features exist, 0 otherwise) - weight_mask = ( - weight_accum_flat.view(B, T, 1, H, W) > 0 - ).float().view(B*T, 1, H, W) # Shape: (B*T, 1, H, W) - - # Calculate number of valid neighbors in each pooling window - weight_sum = F.conv2d( - weight_mask, - torch.ones((1, 1, k, k), device=device), - stride=1, - padding=pad - ) # Shape: (B*T, 1, H, W) - - # Sum features only in valid regions - feat_sum = F.conv2d( - img_bt * weight_mask, # Mask out invalid regions before summing - torch.ones((1, 1, k, k), device=device).expand(C, 1, k, k), - stride=1, - padding=pad, - groups=C - ) # Shape: (B*T, C, H, W) - - # Compute average only over valid neighbors - feat_avg = feat_sum / (weight_sum + 1e-6) - blurred_outputs.append(feat_avg) - - # 8. Fuse multi-scale results by averaging across kernel sizes - fused = torch.stack(blurred_outputs).mean(dim=0) # Average over kernel sizes - return fused.view(B, T, C, H, W) # Restore original shape - -#@HEADS.register_module() -class Stabilization_Network_Cross_Attention(BaseDecodeHead_clips_flow): - - def __init__(self, feature_strides, **kwargs): - super(Stabilization_Network_Cross_Attention, self).__init__(input_transform='multiple_select', **kwargs) - self.training = False - assert len(feature_strides) == len(self.in_channels) - assert min(feature_strides) == feature_strides[0] - self.feature_strides = feature_strides - - c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels - - decoder_params = kwargs['decoder_params'] - embedding_dim = decoder_params['embed_dim'] - - self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) - self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) - self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) - self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) - - self.linear_fuse = nn.Sequential(nn.Conv2d(embedding_dim*4, embedding_dim, kernel_size=(1, 1), stride=(1, 1), bias=False),\ - nn.ReLU(inplace=True)) - - self.proj_track = nn.Conv2d(100, 128, kernel_size=(1, 1), stride=(1, 1), bias=True) - - depths = decoder_params['depths'] - - self.reg_tokens = nn.Parameter(torch.zeros(1, 2, embedding_dim)) - self.global_patch = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=(8, 8), stride=(8, 8), bias=True) - - self.att_temporal = nn.ModuleList( - [ - AttnBlock(embedding_dim, 8, - mlp_ratio=4, flash=True, ckpt_fwd=True) - for _ in range(8) - ] - ) - self.att_spatial = nn.ModuleList( - [ - AttnBlock(embedding_dim, 8, - mlp_ratio=4, flash=True, ckpt_fwd=True) - for _ in range(8) - ] - ) - self.scale_shift_head = nn.Sequential(nn.Linear(embedding_dim, embedding_dim), nn.GELU(), nn.Linear(embedding_dim, 4)) - - - # Initialize reg tokens - nn.init.trunc_normal_(self.reg_tokens, std=0.02) - - 
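        # The two register tokens are concatenated with the 8x8-pooled patch tokens
        # in buffer_forward, refined together with them by the att_temporal /
        # att_spatial stacks above, and finally decoded by scale_shift_head into a
        # per-frame scale (1 + delta) and a 3-channel shift whose first two channels
        # are zeroed out.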
self.decoder_focal=BasicLayer3d3(dim=embedding_dim, - input_resolution=(96, - 96), - depth=depths, - num_heads=8, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm, - pool_method='fc', - downsample=None, - focal_level=2, - focal_window=5, - expand_size=3, - expand_layer="all", - use_conv_embed=False, - use_shift=False, - use_pre_norm=False, - use_checkpoint=False, - use_layerscale=False, - layerscale_value=1e-4, - focal_l_clips=[7,4,2], - focal_kernel_clips=[7,5,3]) - - self.ffm2 = FFM(inchannels= 256, midchannels= 256, outchannels = 128) - self.ffm1 = FFM(inchannels= 128, midchannels= 128, outchannels = 64) - self.ffm0 = FFM(inchannels= 64, midchannels= 64, outchannels = 32,upfactor=1) - self.AO = AO(32, outchannels=3, upfactor=1) - self._c2 = None - self._c_further = None - - def buffer_forward(self, inputs, num_clips=None, imgs=None):#,infermode=1): - - # input: B T 7 H W (7 means 3 rgb + 3 pointmap + 1 uncertainty) normalized - if self.training: - assert self.num_clips==num_clips - - x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32 - c1, c2, c3, c4 = x - - ############## MLP decoder on C1-C4 ########### - n, _, h, w = c4.shape - batch_size = n // num_clips - - _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]) - _c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) - - _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]) - _c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) - - _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]) - _c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) - - _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]) - _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) - - _, _, h, w=_c.shape - _c_further=_c.reshape(batch_size, num_clips, -1, h, w) #h2w2 - - # Expand reg_tokens to match batch size - reg_tokens = self.reg_tokens.expand(batch_size*num_clips, -1, -1) # [B, 2, C] - - _c2=self.decoder_focal(_c_further, batch_size=batch_size, num_clips=num_clips, reg_tokens=reg_tokens) - - assert _c_further.shape==_c2.shape - self._c2 = _c2 - self._c_further = _c_further - - # compute the scale and shift of the global patch - global_patch = self.global_patch(_c2.view(batch_size*num_clips, -1, h, w)).view(batch_size*num_clips, _c2.shape[2], -1).permute(0,2,1) - global_patch = torch.cat([global_patch, reg_tokens], dim=1) - for i in range(8): - global_patch = self.att_temporal[i](global_patch) - global_patch = rearrange(global_patch, '(b t) n c -> (b n) t c', b=batch_size, t=num_clips, c=_c2.shape[2]) - global_patch = self.att_spatial[i](global_patch) - global_patch = rearrange(global_patch, '(b n) t c -> (b t) n c', b=batch_size, t=num_clips, c=_c2.shape[2]) - - reg_tokens = global_patch[:, -2:, :] - s_ = self.scale_shift_head(reg_tokens) - scale = 1 + s_[:, 0, :1].view(batch_size, num_clips, 1, 1, 1) - shift = s_[:, 1, 1:].view(batch_size, num_clips, 3, 1, 1) - shift[:,:,:2,...] 
= 0 - return scale, shift - - def forward(self, inputs, edge_feat, edge_feat1, tracks, tracks_uvd, num_clips=None, imgs=None, vis_track=None):#,infermode=1): - - if self._c2 is None: - scale, shift = self.buffer_forward(inputs,num_clips,imgs) - - B, T, N, _ = tracks.shape - - _c2 = self._c2 - _c_further = self._c_further - - # skip and head - _c_further = rearrange(_c_further, 'b t c h w -> (b t) c h w', b=B, t=T) - _c2 = rearrange(_c2, 'b t c h w -> (b t) c h w', b=B, t=T) - - outframe = self.ffm2(_c_further, _c2) - - tracks_uv = tracks_uvd[...,:2].clone() - track_feature = scatter_multiscale_fast(tracks_uv/2, tracks, outframe.shape[-2], outframe.shape[-1], kernel_sizes=[1, 3, 5]) - # visualize track_feature as video - # import cv2 - # import imageio - # import os - # BT, C, H, W = outframe.shape - # track_feature_vis = track_feature.view(B, T, 3, H, W).float().detach().cpu().numpy() - # track_feature_vis = track_feature_vis.transpose(0,1,3,4,2) - # track_feature_vis = (track_feature_vis - track_feature_vis.min()) / (track_feature_vis.max() - track_feature_vis.min() + 1e-6) - # track_feature_vis = (track_feature_vis * 255).astype(np.uint8) - # imgs =(imgs.detach() + 1) * 127.5 - # vis_track.visualize(video=imgs, tracks=tracks_uv, filename="test") - # for b in range(B): - # frames = [] - # for t in range(T): - # frame = track_feature_vis[b,t] - # frame = cv2.applyColorMap(frame[...,0], cv2.COLORMAP_JET) - # frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - # frames.append(frame) - # # Save as gif - # imageio.mimsave(f'track_feature_b{b}.gif', frames, duration=0.1) - # import pdb; pdb.set_trace() - track_feature = rearrange(track_feature, 'b t c h w -> (b t) c h w') - track_feature = self.proj_track(track_feature) - outframe = self.ffm1(edge_feat1 + track_feature,outframe) - outframe = self.ffm0(edge_feat,outframe) - outframe = self.AO(outframe) - - return outframe - - def reset_success(self): - self._c2 = None - self._c_further = None
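
Note on the deleted scatter_multiscale_fast helper: the hole-resistant scatter-and-pool trick it implements (scatter sparse per-track features with index_add_, then average each k x k window only over pixels that actually received a feature) can be exercised in isolation. The snippet below is a minimal standalone sketch, not part of the deleted module; the function name scatter_and_pool, the single-scale kernel, and the toy shapes are illustrative assumptions.

# Minimal sketch of hole-resistant scatter + masked average pooling.
# Not the deleted implementation; shapes and toy data are illustrative.
import torch
import torch.nn.functional as F

def scatter_and_pool(coords, feats, H, W, k=3):
    # coords: (N, 2) integer (x, y) pixel positions; feats: (N, C) features.
    N, C = feats.shape
    x, y = coords[:, 0], coords[:, 1]
    valid = (x >= 0) & (x < W) & (y >= 0) & (y < H)
    x, y, feats = x[valid], y[valid], feats[valid]

    lin_idx = y * W + x                                # flat pixel index
    feat_accum = torch.zeros(H * W, C)
    weight_accum = torch.zeros(H * W, 1)
    feat_accum.index_add_(0, lin_idx, feats)           # sum features per pixel
    weight_accum.index_add_(0, lin_idx, torch.ones(x.shape[0], 1))

    img = (feat_accum / (weight_accum + 1e-6)).view(H, W, C).permute(2, 0, 1)  # (C, H, W)
    mask = (weight_accum > 0).float().view(1, 1, H, W)

    # Average over valid neighbours only: divide masked feature sums by the
    # count of valid pixels inside each k x k window.
    ones = torch.ones(1, 1, k, k)
    neigh = F.conv2d(mask, ones, padding=k // 2)
    feat_sum = F.conv2d((img * mask.squeeze(0)).unsqueeze(0),
                        ones.expand(C, 1, k, k), padding=k // 2, groups=C)
    return (feat_sum / (neigh + 1e-6)).squeeze(0)      # (C, H, W)

# Toy usage: three points with 2-D features on an 8x8 grid.
coords = torch.tensor([[1, 1], [4, 5], [6, 2]])
feats = torch.randn(3, 2)
out = scatter_and_pool(coords, feats, H=8, W=8, k=3)
print(out.shape)  # torch.Size([2, 8, 8])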
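
Similarly, the alternating per-frame / cross-frame attention over a global patch plus register tokens in buffer_forward (the '(b t) n c' <-> '(b n) t c' rearrange pattern whose register tokens feed scale_shift_head) can be sketched with a generic attention block. TinyBlock, the two shared blocks, the two-iteration depth, and all dimensions below are illustrative stand-ins, not the project's AttnBlock or its actual configuration.

# Sketch of alternating attention with register-token readout of scale/shift.
# TinyBlock is a generic pre-norm attention block, not models.blocks.AttnBlock.
import torch
import torch.nn as nn
from einops import rearrange

class TinyBlock(nn.Module):
    def __init__(self, dim, heads=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):
        h = self.norm(x)
        return x + self.attn(h, h, h, need_weights=False)[0]

B, T, N, C = 2, 4, 36, 64                          # batch, frames, patch tokens, channels
patches = torch.randn(B * T, N, C)                 # per-frame patch tokens
reg = torch.zeros(1, 2, C).expand(B * T, -1, -1)   # two register tokens per frame
x = torch.cat([patches, reg], dim=1)               # (B*T, N+2, C)

blk_frame, blk_time = TinyBlock(C), TinyBlock(C)
for _ in range(2):
    x = blk_frame(x)                                        # attend within each frame
    x = rearrange(x, '(b t) n c -> (b n) t c', b=B, t=T)
    x = blk_time(x)                                         # attend across frames per token
    x = rearrange(x, '(b n) t c -> (b t) n c', b=B, t=T)

head = nn.Sequential(nn.Linear(C, C), nn.GELU(), nn.Linear(C, 4))
s = head(x[:, -2:, :])                             # read the two register tokens
scale = 1 + s[:, 0, :1].reshape(B, T, 1, 1, 1)     # per-frame scale
shift = s[:, 1, 1:].reshape(B, T, 3, 1, 1)         # per-frame 3-channel shift
print(scale.shape, shift.shape)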