first-commit

This commit is contained in:
xiaoyuxi 2025-07-08 15:56:48 +08:00
parent b930b1bbd4
commit 4bd03c72f3
6 changed files with 0 additions and 3164 deletions

View File

@ -1,472 +0,0 @@
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.layers import DropPath, to_2tuple, trunc_normal_
from timm.models import register_model
from timm.models.vision_transformer import _cfg
import math
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x
class OverlapPatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = nn.LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
class OverlapPatchEmbed43(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = nn.LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
if x.shape[1]==4:
x = self.proj_4c(x)
else:
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
class MixVisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
super().__init__()
self.num_classes = num_classes
self.depths = depths
# patch_embed 43
self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
embed_dim=embed_dims[0])
self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1])
self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2])
self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
embed_dim=embed_dims[3])
# transformer encoder
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
self.block1 = nn.ModuleList([Block(
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[0])
for i in range(depths[0])])
self.norm1 = norm_layer(embed_dims[0])
cur += depths[0]
self.block2 = nn.ModuleList([Block(
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[1])
for i in range(depths[1])])
self.norm2 = norm_layer(embed_dims[1])
cur += depths[1]
self.block3 = nn.ModuleList([Block(
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[2])
for i in range(depths[2])])
self.norm3 = norm_layer(embed_dims[2])
cur += depths[2]
self.block4 = nn.ModuleList([Block(
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[3])
for i in range(depths[3])])
self.norm4 = norm_layer(embed_dims[3])
# classification head
# self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
def reset_drop_path(self, drop_path_rate):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
cur = 0
for i in range(self.depths[0]):
self.block1[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[0]
for i in range(self.depths[1]):
self.block2[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[1]
for i in range(self.depths[2]):
self.block3[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[2]
for i in range(self.depths[3]):
self.block4[i].drop_path.drop_prob = dpr[cur + i]
def freeze_patch_emb(self):
self.patch_embed1.requires_grad = False
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0]
outs = []
# stage 1
x, H, W = self.patch_embed1(x)
for i, blk in enumerate(self.block1):
x = blk(x, H, W)
x = self.norm1(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 2
x, H, W = self.patch_embed2(x)
for i, blk in enumerate(self.block2):
x = blk(x, H, W)
x = self.norm2(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 3
x, H, W = self.patch_embed3(x)
for i, blk in enumerate(self.block3):
x = blk(x, H, W)
x = self.norm3(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 4
x, H, W = self.patch_embed4(x)
for i, blk in enumerate(self.block4):
x = blk(x, H, W)
x = self.norm4(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
return outs
def forward(self, x):
if x.dim() == 5:
x = x.reshape(x.shape[0]*x.shape[1],x.shape[2],x.shape[3],x.shape[4])
x = self.forward_features(x)
# x = self.head(x)
return x
class DWConv(nn.Module):
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
#@BACKBONES.register_module()
class mit_b0(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b0, self).__init__(
patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
#@BACKBONES.register_module()
class mit_b1(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b1, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
#@BACKBONES.register_module()
class mit_b2(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b2, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
#@BACKBONES.register_module()
class mit_b3(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b3, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
#@BACKBONES.register_module()
class mit_b4(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b4, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
#@BACKBONES.register_module()
class mit_b5(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b5, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)

View File

@ -1,619 +0,0 @@
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
# from mmcv.cnn import normal_init
# from mmcv.runner import auto_fp16, force_fp32
# from mmseg.core import build_pixel_sampler
# from mmseg.ops import resize
class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
"""Base class for BaseDecodeHead.
Args:
in_channels (int|Sequence[int]): Input channels.
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
conv_cfg (dict|None): Config of conv layers. Default: None.
norm_cfg (dict|None): Config of norm layers. Default: None.
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU')
in_index (int|Sequence[int]): Input feature index. Default: -1
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
Default: None.
loss_decode (dict): Config of decode loss.
Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255
sampler (dict|None): The config of segmentation map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
"""
def __init__(self,
in_channels,
channels,
*,
num_classes,
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
in_index=-1,
input_transform=None,
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
decoder_params=None,
ignore_index=255,
sampler=None,
align_corners=False):
super(BaseDecodeHead, self).__init__()
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.num_classes = num_classes
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.in_index = in_index
self.ignore_index = ignore_index
self.align_corners = align_corners
if sampler is not None:
self.sampler = build_pixel_sampler(sampler, context=self)
else:
self.sampler = None
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout2d(dropout_ratio)
else:
self.dropout = None
self.fp16_enabled = False
def extra_repr(self):
"""Extra repr."""
s = f'input_transform={self.input_transform}, ' \
f'ignore_index={self.ignore_index}, ' \
f'align_corners={self.align_corners}'
return s
def _init_inputs(self, in_channels, in_index, input_transform):
"""Check and initialize input transforms.
The in_channels, in_index and input_transform must match.
Specifically, when input_transform is None, only single feature map
will be selected. So in_channels and in_index must be of type int.
When input_transform
Args:
in_channels (int|Sequence[int]): Input channels.
in_index (int|Sequence[int]): Input feature index.
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
"""
if input_transform is not None:
assert input_transform in ['resize_concat', 'multiple_select']
self.input_transform = input_transform
self.in_index = in_index
if input_transform is not None:
assert isinstance(in_channels, (list, tuple))
assert isinstance(in_index, (list, tuple))
assert len(in_channels) == len(in_index)
if input_transform == 'resize_concat':
self.in_channels = sum(in_channels)
else:
self.in_channels = in_channels
else:
assert isinstance(in_channels, int)
assert isinstance(in_index, int)
self.in_channels = in_channels
def init_weights(self):
"""Initialize weights of classification layer."""
normal_init(self.conv_seg, mean=0, std=0.01)
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(
input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
# @auto_fp16()
@abstractmethod
def forward(self, inputs):
"""Placeholder of forward function."""
pass
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
"""Forward function for training.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:Collect`.
gt_semantic_seg (Tensor): Semantic segmentation masks
used if the architecture supports semantic segmentation task.
train_cfg (dict): The training config.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
seg_logits = self.forward(inputs)
losses = self.losses(seg_logits, gt_semantic_seg)
return losses
def forward_test(self, inputs, img_metas, test_cfg):
"""Forward function for testing.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:Collect`.
test_cfg (dict): The testing config.
Returns:
Tensor: Output segmentation map.
"""
return self.forward(inputs)
def cls_seg(self, feat):
"""Classify each pixel."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.conv_seg(feat)
return output
class BaseDecodeHead_clips(nn.Module, metaclass=ABCMeta):
"""Base class for BaseDecodeHead_clips.
Args:
in_channels (int|Sequence[int]): Input channels.
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
conv_cfg (dict|None): Config of conv layers. Default: None.
norm_cfg (dict|None): Config of norm layers. Default: None.
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU')
in_index (int|Sequence[int]): Input feature index. Default: -1
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
Default: None.
loss_decode (dict): Config of decode loss.
Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255
sampler (dict|None): The config of segmentation map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
"""
def __init__(self,
in_channels,
channels,
*,
num_classes,
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
in_index=-1,
input_transform=None,
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
decoder_params=None,
ignore_index=255,
sampler=None,
align_corners=False,
num_clips=5):
super(BaseDecodeHead_clips, self).__init__()
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.num_classes = num_classes
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.in_index = in_index
self.ignore_index = ignore_index
self.align_corners = align_corners
self.num_clips=num_clips
if sampler is not None:
self.sampler = build_pixel_sampler(sampler, context=self)
else:
self.sampler = None
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout2d(dropout_ratio)
else:
self.dropout = None
self.fp16_enabled = False
def extra_repr(self):
"""Extra repr."""
s = f'input_transform={self.input_transform}, ' \
f'ignore_index={self.ignore_index}, ' \
f'align_corners={self.align_corners}'
return s
def _init_inputs(self, in_channels, in_index, input_transform):
"""Check and initialize input transforms.
The in_channels, in_index and input_transform must match.
Specifically, when input_transform is None, only single feature map
will be selected. So in_channels and in_index must be of type int.
When input_transform
Args:
in_channels (int|Sequence[int]): Input channels.
in_index (int|Sequence[int]): Input feature index.
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
"""
if input_transform is not None:
assert input_transform in ['resize_concat', 'multiple_select']
self.input_transform = input_transform
self.in_index = in_index
if input_transform is not None:
assert isinstance(in_channels, (list, tuple))
assert isinstance(in_index, (list, tuple))
assert len(in_channels) == len(in_index)
if input_transform == 'resize_concat':
self.in_channels = sum(in_channels)
else:
self.in_channels = in_channels
else:
assert isinstance(in_channels, int)
assert isinstance(in_index, int)
self.in_channels = in_channels
def init_weights(self):
"""Initialize weights of classification layer."""
normal_init(self.conv_seg, mean=0, std=0.01)
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(
input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
# @auto_fp16()
@abstractmethod
def forward(self, inputs):
"""Placeholder of forward function."""
pass
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips):
"""Forward function for training.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:Collect`.
gt_semantic_seg (Tensor): Semantic segmentation masks
used if the architecture supports semantic segmentation task.
train_cfg (dict): The training config.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
seg_logits = self.forward(inputs,batch_size, num_clips)
losses = self.losses(seg_logits, gt_semantic_seg)
return losses
def forward_test(self, inputs, img_metas, test_cfg, batch_size, num_clips):
"""Forward function for testing.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:Collect`.
test_cfg (dict): The testing config.
Returns:
Tensor: Output segmentation map.
"""
return self.forward(inputs, batch_size, num_clips)
def cls_seg(self, feat):
"""Classify each pixel."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.conv_seg(feat)
return output
class BaseDecodeHead_clips_flow(nn.Module, metaclass=ABCMeta):
"""Base class for BaseDecodeHead_clips_flow.
Args:
in_channels (int|Sequence[int]): Input channels.
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
conv_cfg (dict|None): Config of conv layers. Default: None.
norm_cfg (dict|None): Config of norm layers. Default: None.
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU')
in_index (int|Sequence[int]): Input feature index. Default: -1
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
Default: None.
loss_decode (dict): Config of decode loss.
Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255
sampler (dict|None): The config of segmentation map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
"""
def __init__(self,
in_channels,
channels,
*,
num_classes,
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
in_index=-1,
input_transform=None,
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
decoder_params=None,
ignore_index=255,
sampler=None,
align_corners=False,
num_clips=5):
super(BaseDecodeHead_clips_flow, self).__init__()
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.num_classes = num_classes
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.in_index = in_index
self.ignore_index = ignore_index
self.align_corners = align_corners
self.num_clips=num_clips
if sampler is not None:
self.sampler = build_pixel_sampler(sampler, context=self)
else:
self.sampler = None
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout2d(dropout_ratio)
else:
self.dropout = None
self.fp16_enabled = False
def extra_repr(self):
"""Extra repr."""
s = f'input_transform={self.input_transform}, ' \
f'ignore_index={self.ignore_index}, ' \
f'align_corners={self.align_corners}'
return s
def _init_inputs(self, in_channels, in_index, input_transform):
"""Check and initialize input transforms.
The in_channels, in_index and input_transform must match.
Specifically, when input_transform is None, only single feature map
will be selected. So in_channels and in_index must be of type int.
When input_transform
Args:
in_channels (int|Sequence[int]): Input channels.
in_index (int|Sequence[int]): Input feature index.
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
"""
if input_transform is not None:
assert input_transform in ['resize_concat', 'multiple_select']
self.input_transform = input_transform
self.in_index = in_index
if input_transform is not None:
assert isinstance(in_channels, (list, tuple))
assert isinstance(in_index, (list, tuple))
assert len(in_channels) == len(in_index)
if input_transform == 'resize_concat':
self.in_channels = sum(in_channels)
else:
self.in_channels = in_channels
else:
assert isinstance(in_channels, int)
assert isinstance(in_index, int)
self.in_channels = in_channels
def init_weights(self):
"""Initialize weights of classification layer."""
normal_init(self.conv_seg, mean=0, std=0.01)
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(
input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
# @auto_fp16()
@abstractmethod
def forward(self, inputs):
"""Placeholder of forward function."""
pass
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips,img=None):
"""Forward function for training.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:Collect`.
gt_semantic_seg (Tensor): Semantic segmentation masks
used if the architecture supports semantic segmentation task.
train_cfg (dict): The training config.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
seg_logits = self.forward(inputs,batch_size, num_clips,img)
losses = self.losses(seg_logits, gt_semantic_seg)
return losses
def forward_test(self, inputs, img_metas, test_cfg, batch_size=None, num_clips=None, img=None):
"""Forward function for testing.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:Collect`.
test_cfg (dict): The testing config.
Returns:
Tensor: Output segmentation map.
"""
return self.forward(inputs, batch_size, num_clips,img)
def cls_seg(self, feat):
"""Classify each pixel."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.conv_seg(feat)
return output

View File

@ -1,115 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.monoD.depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed
from models.SpaTrackV2.models.depth_refiner.backbone import mit_b3
from models.SpaTrackV2.models.depth_refiner.stablizer import Stabilization_Network_Cross_Attention
from einops import rearrange
class TrackStablizer(nn.Module):
def __init__(self):
super().__init__()
self.backbone = mit_b3()
old_conv = self.backbone.patch_embed1.proj
new_conv = nn.Conv2d(old_conv.in_channels + 4, old_conv.out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding)
new_conv.weight[:, :3, :, :].data.copy_(old_conv.weight.clone())
self.backbone.patch_embed1.proj = new_conv
self.Track_Stabilizer = Stabilization_Network_Cross_Attention(in_channels=[64, 128, 320, 512],
in_index=[0, 1, 2, 3],
feature_strides=[4, 8, 16, 32],
channels=128,
dropout_ratio=0.1,
num_classes=1,
align_corners=False,
decoder_params=dict(embed_dim=256, depths=4),
num_clips=16,
norm_cfg = dict(type='SyncBN', requires_grad=True))
self.edge_conv = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=64, kernel_size=3, padding=1, stride=1, bias=True),\
nn.ReLU(inplace=True))
self.edge_conv1 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2, bias=True),\
nn.ReLU(inplace=True))
self.success = False
self.x = None
def buffer_forward(self, inputs, num_clips=16):
"""
buffer forward for getting the pointmap and image features
"""
B, T, C, H, W = inputs.shape
self.x = self.backbone(inputs)
scale, shift = self.Track_Stabilizer.buffer_forward(self.x, num_clips=num_clips)
self.success = True
return scale, shift
def forward(self, inputs, tracks, tracks_uvd, num_clips=16, imgs=None, vis_track=None):
"""
Args:
inputs: [B, T, C, H, W], RGB + PointMap + Mask
tracks: [B, T, N, 4], 3D tracks in camera coordinate + visibility
num_clips: int, number of clips to use
"""
B, T, C, H, W = inputs.shape
edge_feat = self.edge_conv(inputs.view(B*T,4,H,W))
edge_feat1 = self.edge_conv1(edge_feat)
if not self.success:
scale, shift = self.Track_Stabilizer.buffer_forward(self.x,num_clips=num_clips)
self.success = True
update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
else:
update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
return update
def reset_success(self):
self.success = False
self.x = None
self.Track_Stabilizer.reset_success()
if __name__ == "__main__":
# Create test input tensors
batch_size = 1
seq_len = 16
channels = 7 # 3 for RGB + 3 for PointMap + 1 for Mask
height = 384
width = 512
# Create random input tensor with shape [B, T, C, H, W]
inputs = torch.randn(batch_size, seq_len, channels, height, width)
# Create random tracks
tracks = torch.randn(batch_size, seq_len, 1024, 4)
# Create random test images
test_imgs = torch.randn(batch_size, seq_len, 3, height, width)
# Initialize model and move to GPU
model = TrackStablizer().cuda()
# Move inputs to GPU and run forward pass
inputs = inputs.cuda()
tracks = tracks.cuda()
outputs = model.buffer_forward(inputs, num_clips=seq_len)
import time
start_time = time.time()
outputs = model(inputs, tracks, num_clips=seq_len)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
import pdb; pdb.set_trace()
# # Print shapes for verification
# print(f"Input shape: {inputs.shape}")
# print(f"Output shape: {outputs.shape}")
# # Basic tests
# assert outputs.shape[0] == batch_size, "Batch size mismatch"
# assert len(outputs.shape) == 4, "Output should be 4D: [B,C,H,W]"
# assert torch.all(outputs >= 0), "Output should be non-negative after ReLU"
# print("All tests passed!")

View File

@ -1,429 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
Author: Ke Xian
Email: kexian@hust.edu.cn
Date: 2020/07/20
'''
import torch
import torch.nn as nn
import torch.nn.init as init
# ==============================================================================================================
class FTB(nn.Module):
def __init__(self, inchannels, midchannels=512):
super(FTB, self).__init__()
self.in1 = inchannels
self.mid = midchannels
self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True)
self.conv_branch = nn.Sequential(nn.ReLU(inplace=True),\
nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),\
#nn.BatchNorm2d(num_features=self.mid),\
nn.ReLU(inplace=True),\
nn.Conv2d(in_channels=self.mid, out_channels= self.mid, kernel_size=3, padding=1, stride=1, bias=True))
self.relu = nn.ReLU(inplace=True)
self.init_params()
def forward(self, x):
x = self.conv1(x)
x = x + self.conv_branch(x)
x = self.relu(x)
return x
def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
init.normal_(m.weight, std=0.01)
# init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.ConvTranspose2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
init.normal_(m.weight, std=0.01)
# init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.01)
if m.bias is not None:
init.constant_(m.bias, 0)
class ATA(nn.Module):
def __init__(self, inchannels, reduction = 8):
super(ATA, self).__init__()
self.inchannels = inchannels
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(nn.Linear(self.inchannels*2, self.inchannels // reduction),
nn.ReLU(inplace=True),
nn.Linear(self.inchannels // reduction, self.inchannels),
nn.Sigmoid())
self.init_params()
def forward(self, low_x, high_x):
n, c, _, _ = low_x.size()
x = torch.cat([low_x, high_x], 1)
x = self.avg_pool(x)
x = x.view(n, -1)
x = self.fc(x).view(n,c,1,1)
x = low_x * x + high_x
return x
def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
#init.normal(m.weight, std=0.01)
init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.ConvTranspose2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
#init.normal_(m.weight, std=0.01)
init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.01)
if m.bias is not None:
init.constant_(m.bias, 0)
class FFM(nn.Module):
def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
super(FFM, self).__init__()
self.inchannels = inchannels
self.midchannels = midchannels
self.outchannels = outchannels
self.upfactor = upfactor
self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
self.init_params()
#self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
#self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
#self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
def forward(self, low_x, high_x):
x = self.ftb1(low_x)
'''
x = torch.cat((x,high_x),1)
if x.shape[2] == 12:
x = self.p1(x)
elif x.shape[2] == 24:
x = self.p2(x)
elif x.shape[2] == 48:
x = self.p3(x)
'''
x = x + high_x ###high_x
x = self.ftb2(x)
x = self.upsample(x)
return x
def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
init.normal_(m.weight, std=0.01)
#init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.ConvTranspose2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
init.normal_(m.weight, std=0.01)
#init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.01)
if m.bias is not None:
init.constant_(m.bias, 0)
class noFFM(nn.Module):
def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
super(noFFM, self).__init__()
self.inchannels = inchannels
self.midchannels = midchannels
self.outchannels = outchannels
self.upfactor = upfactor
self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
self.init_params()
#self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
#self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
#self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
def forward(self, low_x, high_x):
#x = self.ftb1(low_x)
x = high_x ###high_x
x = self.ftb2(x)
x = self.upsample(x)
return x
def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
init.normal_(m.weight, std=0.01)
#init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.ConvTranspose2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
init.normal_(m.weight, std=0.01)
#init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.01)
if m.bias is not None:
init.constant_(m.bias, 0)
class AO(nn.Module):
# Adaptive output module
def __init__(self, inchannels, outchannels, upfactor=2):
super(AO, self).__init__()
self.inchannels = inchannels
self.outchannels = outchannels
self.upfactor = upfactor
"""
self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
nn.BatchNorm2d(num_features=self.inchannels//2),\
nn.ReLU(inplace=True),\
nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True),\
nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) )#,\
#nn.ReLU(inplace=True)) ## get positive values
"""
self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
#nn.BatchNorm2d(num_features=self.inchannels//2),\
nn.ReLU(inplace=True),\
nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True), \
nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=1, padding=0, stride=1))
#nn.ReLU(inplace=True)) ## get positive values
self.init_params()
def forward(self, x):
x = self.adapt_conv(x)
return x
def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
init.normal_(m.weight, std=0.01)
#init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.ConvTranspose2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
init.normal_(m.weight, std=0.01)
#init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.01)
if m.bias is not None:
init.constant_(m.bias, 0)
class ASPP(nn.Module):
def __init__(self, inchannels=256, planes=128, rates = [1, 6, 12, 18]):
super(ASPP, self).__init__()
self.inchannels = inchannels
self.planes = planes
self.rates = rates
self.kernel_sizes = []
self.paddings = []
for rate in self.rates:
if rate == 1:
self.kernel_sizes.append(1)
self.paddings.append(0)
else:
self.kernel_sizes.append(3)
self.paddings.append(rate)
self.atrous_0 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[0],
stride=1, padding=self.paddings[0], dilation=self.rates[0], bias=True),
nn.ReLU(inplace=True),
nn.BatchNorm2d(num_features=self.planes)
)
self.atrous_1 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[1],
stride=1, padding=self.paddings[1], dilation=self.rates[1], bias=True),
nn.ReLU(inplace=True),
nn.BatchNorm2d(num_features=self.planes),
)
self.atrous_2 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[2],
stride=1, padding=self.paddings[2], dilation=self.rates[2], bias=True),
nn.ReLU(inplace=True),
nn.BatchNorm2d(num_features=self.planes),
)
self.atrous_3 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[3],
stride=1, padding=self.paddings[3], dilation=self.rates[3], bias=True),
nn.ReLU(inplace=True),
nn.BatchNorm2d(num_features=self.planes),
)
#self.conv = nn.Conv2d(in_channels=self.planes * 4, out_channels=self.inchannels, kernel_size=3, padding=1, stride=1, bias=True)
def forward(self, x):
x = torch.cat([self.atrous_0(x), self.atrous_1(x), self.atrous_2(x), self.atrous_3(x)],1)
#x = self.conv(x)
return x
# ==============================================================================================================
class ResidualConv(nn.Module):
def __init__(self, inchannels):
super(ResidualConv, self).__init__()
#nn.BatchNorm2d
self.conv = nn.Sequential(
#nn.BatchNorm2d(num_features=inchannels),
nn.ReLU(inplace=False),
#nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True),
#nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True)
nn.Conv2d(in_channels=inchannels, out_channels=inchannels//2, kernel_size=3, padding=1, stride=1, bias=False),
nn.BatchNorm2d(num_features=inchannels//2),
nn.ReLU(inplace=False),
nn.Conv2d(in_channels=inchannels//2, out_channels=inchannels, kernel_size=3, padding=1, stride=1, bias=False)
)
self.init_params()
def forward(self, x):
x = self.conv(x)+x
return x
def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
init.normal_(m.weight, std=0.01)
#init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.ConvTranspose2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
init.normal_(m.weight, std=0.01)
#init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.01)
if m.bias is not None:
init.constant_(m.bias, 0)
class FeatureFusion(nn.Module):
def __init__(self, inchannels, outchannels):
super(FeatureFusion, self).__init__()
self.conv = ResidualConv(inchannels=inchannels)
#nn.BatchNorm2d
self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3,stride=2, padding=1, output_padding=1),
nn.BatchNorm2d(num_features=outchannels),
nn.ReLU(inplace=True))
def forward(self, lowfeat, highfeat):
return self.up(highfeat + self.conv(lowfeat))
def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
init.normal_(m.weight, std=0.01)
#init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.ConvTranspose2d):
#init.kaiming_normal_(m.weight, mode='fan_out')
init.normal_(m.weight, std=0.01)
#init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.01)
if m.bias is not None:
init.constant_(m.bias, 0)
class SenceUnderstand(nn.Module):
def __init__(self, channels):
super(SenceUnderstand, self).__init__()
self.channels = channels
self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
nn.ReLU(inplace = True))
self.pool = nn.AdaptiveAvgPool2d(8)
self.fc = nn.Sequential(nn.Linear(512*8*8, self.channels),
nn.ReLU(inplace = True))
self.conv2 = nn.Sequential(nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
nn.ReLU(inplace=True))
self.initial_params()
def forward(self, x):
n,c,h,w = x.size()
x = self.conv1(x)
x = self.pool(x)
x = x.view(n,-1)
x = self.fc(x)
x = x.view(n, self.channels, 1, 1)
x = self.conv2(x)
x = x.repeat(1,1,h,w)
return x
def initial_params(self, dev=0.01):
for m in self.modules():
if isinstance(m, nn.Conv2d):
#print torch.sum(m.weight)
m.weight.data.normal_(0, dev)
if m.bias is not None:
m.bias.data.fill_(0)
elif isinstance(m, nn.ConvTranspose2d):
#print torch.sum(m.weight)
m.weight.data.normal_(0, dev)
if m.bias is not None:
m.bias.data.fill_(0)
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, dev)

View File

@ -1,342 +0,0 @@
import numpy as np
import torch.nn as nn
import torch
# from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from collections import OrderedDict
# from mmseg.ops import resize
from torch.nn.functional import interpolate as resize
# from builder import HEADS
from models.SpaTrackV2.models.depth_refiner.decode_head import BaseDecodeHead, BaseDecodeHead_clips, BaseDecodeHead_clips_flow
# from mmseg.models.utils import *
import attr
from IPython import embed
from models.SpaTrackV2.models.depth_refiner.stablilization_attention import BasicLayer3d3
import cv2
from models.SpaTrackV2.models.depth_refiner.network import *
import warnings
# from mmcv.utils import Registry, build_from_cfg
from torch import nn
from einops import rearrange
import torch.nn.functional as F
from models.SpaTrackV2.models.blocks import (
AttnBlock, CrossAttnBlock, Mlp
)
class MLP(nn.Module):
"""
Linear Embedding
"""
def __init__(self, input_dim=2048, embed_dim=768):
super().__init__()
self.proj = nn.Linear(input_dim, embed_dim)
def forward(self, x):
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
def scatter_multiscale_fast(
track2d: torch.Tensor,
trackfeature: torch.Tensor,
H: int,
W: int,
kernel_sizes = [1]
) -> torch.Tensor:
"""
Scatter sparse track features onto a dense image grid with weighted multi-scale pooling to handle zero-value gaps.
This function scatters sparse track features into a dense image grid and applies multi-scale average pooling
while excluding zero-value holes. The weight mask ensures that only valid feature regions contribute to the pooling,
avoiding dilution by empty pixels.
Args:
track2d (torch.Tensor): Float tensor of shape (B, T, N, 2) containing (x, y) pixel coordinates
for each track point across batches, frames, and points.
trackfeature (torch.Tensor): Float tensor of shape (B, T, N, C) with C-dimensional features
for each track point.
H (int): Height of the target output image.
W (int): Width of the target output image.
kernel_sizes (List[int]): List of odd integers for average pooling kernel sizes. Default: [3, 5, 7].
Returns:
torch.Tensor: Multi-scale fused feature map of shape (B, T, C, H, W) with hole-resistant pooling.
"""
B, T, N, C = trackfeature.shape
device = trackfeature.device
# 1. Flatten coordinates and filter valid points within image bounds
coords_flat = track2d.round().long().reshape(-1, 2) # (B*T*N, 2)
x = coords_flat[:, 0] # x coordinates
y = coords_flat[:, 1] # y coordinates
feat_flat = trackfeature.reshape(-1, C) # Flatten features
valid_mask = (x >= 0) & (x < W) & (y >= 0) & (y < H)
x = x[valid_mask]
y = y[valid_mask]
feat_flat = feat_flat[valid_mask]
valid_count = x.shape[0]
if valid_count == 0:
return torch.zeros(B, T, C, H, W, device=device) # Handle no-valid-point case
# 2. Calculate linear indices and batch-frame indices for scattering
lin_idx = y * W + x # Linear index within a single frame (H*W range)
# Generate batch-frame indices (e.g., 0~B*T-1 for each frame in batch)
bt_idx_raw = (
torch.arange(B * T, device=device)
.view(B, T, 1)
.expand(B, T, N)
.reshape(-1)
)
bt_idx = bt_idx_raw[valid_mask] # Indices for valid points across batch and frames
# 3. Create accumulation buffers for features and weights
total_space = B * T * H * W
img_accum_flat = torch.zeros(total_space, C, device=device) # Feature accumulator
weight_accum_flat = torch.zeros(total_space, 1, device=device) # Weight accumulator (counts)
# 4. Scatter features and weights into accumulation buffers
idx_in_accum = bt_idx * (H * W) + lin_idx # Global index: batch_frame * H*W + pixel_index
# Add features to corresponding indices (index_add_ is efficient for sparse updates)
img_accum_flat.index_add_(0, idx_in_accum, feat_flat)
weight_accum_flat.index_add_(0, idx_in_accum, torch.ones((valid_count, 1), device=device))
# 5. Normalize features by valid weights, keep zeros for invalid regions
valid_mask_flat = weight_accum_flat > 0 # Binary mask for valid pixels
img_accum_flat = img_accum_flat / (weight_accum_flat + 1e-6) # Avoid division by zero
img_accum_flat = img_accum_flat * valid_mask_flat.float() # Mask out invalid regions
# 6. Reshape to (B, T, C, H, W) for further processing
img = (
img_accum_flat.view(B, T, H, W, C)
.permute(0, 1, 4, 2, 3)
.contiguous()
) # Shape: (B, T, C, H, W)
# 7. Multi-scale pooling with weight masking to exclude zero holes
blurred_outputs = []
for k in kernel_sizes:
pad = k // 2
img_bt = img.view(B*T, C, H, W) # Flatten batch and time for pooling
# Create weight mask for valid regions (1 where features exist, 0 otherwise)
weight_mask = (
weight_accum_flat.view(B, T, 1, H, W) > 0
).float().view(B*T, 1, H, W) # Shape: (B*T, 1, H, W)
# Calculate number of valid neighbors in each pooling window
weight_sum = F.conv2d(
weight_mask,
torch.ones((1, 1, k, k), device=device),
stride=1,
padding=pad
) # Shape: (B*T, 1, H, W)
# Sum features only in valid regions
feat_sum = F.conv2d(
img_bt * weight_mask, # Mask out invalid regions before summing
torch.ones((1, 1, k, k), device=device).expand(C, 1, k, k),
stride=1,
padding=pad,
groups=C
) # Shape: (B*T, C, H, W)
# Compute average only over valid neighbors
feat_avg = feat_sum / (weight_sum + 1e-6)
blurred_outputs.append(feat_avg)
# 8. Fuse multi-scale results by averaging across kernel sizes
fused = torch.stack(blurred_outputs).mean(dim=0) # Average over kernel sizes
return fused.view(B, T, C, H, W) # Restore original shape
#@HEADS.register_module()
class Stabilization_Network_Cross_Attention(BaseDecodeHead_clips_flow):
def __init__(self, feature_strides, **kwargs):
super(Stabilization_Network_Cross_Attention, self).__init__(input_transform='multiple_select', **kwargs)
self.training = False
assert len(feature_strides) == len(self.in_channels)
assert min(feature_strides) == feature_strides[0]
self.feature_strides = feature_strides
c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
decoder_params = kwargs['decoder_params']
embedding_dim = decoder_params['embed_dim']
self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
self.linear_fuse = nn.Sequential(nn.Conv2d(embedding_dim*4, embedding_dim, kernel_size=(1, 1), stride=(1, 1), bias=False),\
nn.ReLU(inplace=True))
self.proj_track = nn.Conv2d(100, 128, kernel_size=(1, 1), stride=(1, 1), bias=True)
depths = decoder_params['depths']
self.reg_tokens = nn.Parameter(torch.zeros(1, 2, embedding_dim))
self.global_patch = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=(8, 8), stride=(8, 8), bias=True)
self.att_temporal = nn.ModuleList(
[
AttnBlock(embedding_dim, 8,
mlp_ratio=4, flash=True, ckpt_fwd=True)
for _ in range(8)
]
)
self.att_spatial = nn.ModuleList(
[
AttnBlock(embedding_dim, 8,
mlp_ratio=4, flash=True, ckpt_fwd=True)
for _ in range(8)
]
)
self.scale_shift_head = nn.Sequential(nn.Linear(embedding_dim, embedding_dim), nn.GELU(), nn.Linear(embedding_dim, 4))
# Initialize reg tokens
nn.init.trunc_normal_(self.reg_tokens, std=0.02)
self.decoder_focal=BasicLayer3d3(dim=embedding_dim,
input_resolution=(96,
96),
depth=depths,
num_heads=8,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
pool_method='fc',
downsample=None,
focal_level=2,
focal_window=5,
expand_size=3,
expand_layer="all",
use_conv_embed=False,
use_shift=False,
use_pre_norm=False,
use_checkpoint=False,
use_layerscale=False,
layerscale_value=1e-4,
focal_l_clips=[7,4,2],
focal_kernel_clips=[7,5,3])
self.ffm2 = FFM(inchannels= 256, midchannels= 256, outchannels = 128)
self.ffm1 = FFM(inchannels= 128, midchannels= 128, outchannels = 64)
self.ffm0 = FFM(inchannels= 64, midchannels= 64, outchannels = 32,upfactor=1)
self.AO = AO(32, outchannels=3, upfactor=1)
self._c2 = None
self._c_further = None
def buffer_forward(self, inputs, num_clips=None, imgs=None):#,infermode=1):
# input: B T 7 H W (7 means 3 rgb + 3 pointmap + 1 uncertainty) normalized
if self.training:
assert self.num_clips==num_clips
x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32
c1, c2, c3, c4 = x
############## MLP decoder on C1-C4 ###########
n, _, h, w = c4.shape
batch_size = n // num_clips
_c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
_c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)
_c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
_c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)
_c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
_c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)
_c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
_c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
_, _, h, w=_c.shape
_c_further=_c.reshape(batch_size, num_clips, -1, h, w) #h2w2
# Expand reg_tokens to match batch size
reg_tokens = self.reg_tokens.expand(batch_size*num_clips, -1, -1) # [B, 2, C]
_c2=self.decoder_focal(_c_further, batch_size=batch_size, num_clips=num_clips, reg_tokens=reg_tokens)
assert _c_further.shape==_c2.shape
self._c2 = _c2
self._c_further = _c_further
# compute the scale and shift of the global patch
global_patch = self.global_patch(_c2.view(batch_size*num_clips, -1, h, w)).view(batch_size*num_clips, _c2.shape[2], -1).permute(0,2,1)
global_patch = torch.cat([global_patch, reg_tokens], dim=1)
for i in range(8):
global_patch = self.att_temporal[i](global_patch)
global_patch = rearrange(global_patch, '(b t) n c -> (b n) t c', b=batch_size, t=num_clips, c=_c2.shape[2])
global_patch = self.att_spatial[i](global_patch)
global_patch = rearrange(global_patch, '(b n) t c -> (b t) n c', b=batch_size, t=num_clips, c=_c2.shape[2])
reg_tokens = global_patch[:, -2:, :]
s_ = self.scale_shift_head(reg_tokens)
scale = 1 + s_[:, 0, :1].view(batch_size, num_clips, 1, 1, 1)
shift = s_[:, 1, 1:].view(batch_size, num_clips, 3, 1, 1)
shift[:,:,:2,...] = 0
return scale, shift
def forward(self, inputs, edge_feat, edge_feat1, tracks, tracks_uvd, num_clips=None, imgs=None, vis_track=None):#,infermode=1):
if self._c2 is None:
scale, shift = self.buffer_forward(inputs,num_clips,imgs)
B, T, N, _ = tracks.shape
_c2 = self._c2
_c_further = self._c_further
# skip and head
_c_further = rearrange(_c_further, 'b t c h w -> (b t) c h w', b=B, t=T)
_c2 = rearrange(_c2, 'b t c h w -> (b t) c h w', b=B, t=T)
outframe = self.ffm2(_c_further, _c2)
tracks_uv = tracks_uvd[...,:2].clone()
track_feature = scatter_multiscale_fast(tracks_uv/2, tracks, outframe.shape[-2], outframe.shape[-1], kernel_sizes=[1, 3, 5])
# visualize track_feature as video
# import cv2
# import imageio
# import os
# BT, C, H, W = outframe.shape
# track_feature_vis = track_feature.view(B, T, 3, H, W).float().detach().cpu().numpy()
# track_feature_vis = track_feature_vis.transpose(0,1,3,4,2)
# track_feature_vis = (track_feature_vis - track_feature_vis.min()) / (track_feature_vis.max() - track_feature_vis.min() + 1e-6)
# track_feature_vis = (track_feature_vis * 255).astype(np.uint8)
# imgs =(imgs.detach() + 1) * 127.5
# vis_track.visualize(video=imgs, tracks=tracks_uv, filename="test")
# for b in range(B):
# frames = []
# for t in range(T):
# frame = track_feature_vis[b,t]
# frame = cv2.applyColorMap(frame[...,0], cv2.COLORMAP_JET)
# frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# frames.append(frame)
# # Save as gif
# imageio.mimsave(f'track_feature_b{b}.gif', frames, duration=0.1)
# import pdb; pdb.set_trace()
track_feature = rearrange(track_feature, 'b t c h w -> (b t) c h w')
track_feature = self.proj_track(track_feature)
outframe = self.ffm1(edge_feat1 + track_feature,outframe)
outframe = self.ffm0(edge_feat,outframe)
outframe = self.AO(outframe)
return outframe
def reset_success(self):
self._c2 = None
self._c_further = None