first-commit
This commit is contained in:
parent
b930b1bbd4
commit
4bd03c72f3
@ -1,472 +0,0 @@
|
||||
# ---------------------------------------------------------------
|
||||
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
|
||||
#
|
||||
# This work is licensed under the NVIDIA Source Code License
|
||||
# ---------------------------------------------------------------
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
|
||||
from timm.layers import DropPath, to_2tuple, trunc_normal_
|
||||
from timm.models import register_model
|
||||
from timm.models.vision_transformer import _cfg
|
||||
import math
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.dwconv = DWConv(hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x, H, W):
|
||||
x = self.fc1(x)
|
||||
x = self.dwconv(x, H, W)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
self.sr_ratio = sr_ratio
|
||||
if sr_ratio > 1:
|
||||
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x, H, W):
|
||||
B, N, C = x.shape
|
||||
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
||||
|
||||
if self.sr_ratio > 1:
|
||||
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
|
||||
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
|
||||
x_ = self.norm(x_)
|
||||
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
else:
|
||||
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
k, v = kv[0], kv[1]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x, H, W):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class OverlapPatchEmbed(nn.Module):
|
||||
""" Image to Patch Embedding
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
||||
self.num_patches = self.H * self.W
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
|
||||
padding=(patch_size[0] // 2, patch_size[1] // 2))
|
||||
self.norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
_, _, H, W = x.shape
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
|
||||
return x, H, W
|
||||
|
||||
|
||||
|
||||
|
||||
class OverlapPatchEmbed43(nn.Module):
|
||||
""" Image to Patch Embedding
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
||||
self.num_patches = self.H * self.W
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
|
||||
padding=(patch_size[0] // 2, patch_size[1] // 2))
|
||||
self.norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x):
|
||||
if x.shape[1]==4:
|
||||
x = self.proj_4c(x)
|
||||
else:
|
||||
x = self.proj(x)
|
||||
_, _, H, W = x.shape
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
|
||||
return x, H, W
|
||||
|
||||
class MixVisionTransformer(nn.Module):
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
|
||||
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
|
||||
attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
|
||||
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.depths = depths
|
||||
|
||||
# patch_embed 43
|
||||
self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
|
||||
embed_dim=embed_dims[0])
|
||||
self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
|
||||
embed_dim=embed_dims[1])
|
||||
self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
|
||||
embed_dim=embed_dims[2])
|
||||
self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
|
||||
embed_dim=embed_dims[3])
|
||||
|
||||
# transformer encoder
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
cur = 0
|
||||
self.block1 = nn.ModuleList([Block(
|
||||
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
||||
sr_ratio=sr_ratios[0])
|
||||
for i in range(depths[0])])
|
||||
self.norm1 = norm_layer(embed_dims[0])
|
||||
|
||||
cur += depths[0]
|
||||
self.block2 = nn.ModuleList([Block(
|
||||
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
||||
sr_ratio=sr_ratios[1])
|
||||
for i in range(depths[1])])
|
||||
self.norm2 = norm_layer(embed_dims[1])
|
||||
|
||||
cur += depths[1]
|
||||
self.block3 = nn.ModuleList([Block(
|
||||
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
||||
sr_ratio=sr_ratios[2])
|
||||
for i in range(depths[2])])
|
||||
self.norm3 = norm_layer(embed_dims[2])
|
||||
|
||||
cur += depths[2]
|
||||
self.block4 = nn.ModuleList([Block(
|
||||
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
||||
sr_ratio=sr_ratios[3])
|
||||
for i in range(depths[3])])
|
||||
self.norm4 = norm_layer(embed_dims[3])
|
||||
|
||||
# classification head
|
||||
# self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
if isinstance(pretrained, str):
|
||||
logger = get_root_logger()
|
||||
load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
|
||||
|
||||
def reset_drop_path(self, drop_path_rate):
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
|
||||
cur = 0
|
||||
for i in range(self.depths[0]):
|
||||
self.block1[i].drop_path.drop_prob = dpr[cur + i]
|
||||
|
||||
cur += self.depths[0]
|
||||
for i in range(self.depths[1]):
|
||||
self.block2[i].drop_path.drop_prob = dpr[cur + i]
|
||||
|
||||
cur += self.depths[1]
|
||||
for i in range(self.depths[2]):
|
||||
self.block3[i].drop_path.drop_prob = dpr[cur + i]
|
||||
|
||||
cur += self.depths[2]
|
||||
for i in range(self.depths[3]):
|
||||
self.block4[i].drop_path.drop_prob = dpr[cur + i]
|
||||
|
||||
def freeze_patch_emb(self):
|
||||
self.patch_embed1.requires_grad = False
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
B = x.shape[0]
|
||||
outs = []
|
||||
|
||||
# stage 1
|
||||
x, H, W = self.patch_embed1(x)
|
||||
for i, blk in enumerate(self.block1):
|
||||
x = blk(x, H, W)
|
||||
x = self.norm1(x)
|
||||
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(x)
|
||||
|
||||
# stage 2
|
||||
x, H, W = self.patch_embed2(x)
|
||||
for i, blk in enumerate(self.block2):
|
||||
x = blk(x, H, W)
|
||||
x = self.norm2(x)
|
||||
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(x)
|
||||
|
||||
# stage 3
|
||||
x, H, W = self.patch_embed3(x)
|
||||
for i, blk in enumerate(self.block3):
|
||||
x = blk(x, H, W)
|
||||
x = self.norm3(x)
|
||||
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(x)
|
||||
|
||||
# stage 4
|
||||
x, H, W = self.patch_embed4(x)
|
||||
for i, blk in enumerate(self.block4):
|
||||
x = blk(x, H, W)
|
||||
x = self.norm4(x)
|
||||
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(x)
|
||||
|
||||
return outs
|
||||
|
||||
def forward(self, x):
|
||||
if x.dim() == 5:
|
||||
x = x.reshape(x.shape[0]*x.shape[1],x.shape[2],x.shape[3],x.shape[4])
|
||||
x = self.forward_features(x)
|
||||
# x = self.head(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DWConv(nn.Module):
|
||||
def __init__(self, dim=768):
|
||||
super(DWConv, self).__init__()
|
||||
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
B, N, C = x.shape
|
||||
x = x.transpose(1, 2).view(B, C, H, W)
|
||||
x = self.dwconv(x)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
#@BACKBONES.register_module()
|
||||
class mit_b0(MixVisionTransformer):
|
||||
def __init__(self, **kwargs):
|
||||
super(mit_b0, self).__init__(
|
||||
patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
|
||||
drop_rate=0.0, drop_path_rate=0.1)
|
||||
|
||||
|
||||
#@BACKBONES.register_module()
|
||||
class mit_b1(MixVisionTransformer):
|
||||
def __init__(self, **kwargs):
|
||||
super(mit_b1, self).__init__(
|
||||
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
|
||||
drop_rate=0.0, drop_path_rate=0.1)
|
||||
|
||||
|
||||
#@BACKBONES.register_module()
|
||||
class mit_b2(MixVisionTransformer):
|
||||
def __init__(self, **kwargs):
|
||||
super(mit_b2, self).__init__(
|
||||
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
|
||||
drop_rate=0.0, drop_path_rate=0.1)
|
||||
|
||||
|
||||
#@BACKBONES.register_module()
|
||||
class mit_b3(MixVisionTransformer):
|
||||
def __init__(self, **kwargs):
|
||||
super(mit_b3, self).__init__(
|
||||
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
|
||||
drop_rate=0.0, drop_path_rate=0.1)
|
||||
|
||||
|
||||
#@BACKBONES.register_module()
|
||||
class mit_b4(MixVisionTransformer):
|
||||
def __init__(self, **kwargs):
|
||||
super(mit_b4, self).__init__(
|
||||
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
|
||||
drop_rate=0.0, drop_path_rate=0.1)
|
||||
|
||||
|
||||
#@BACKBONES.register_module()
|
||||
class mit_b5(MixVisionTransformer):
|
||||
def __init__(self, **kwargs):
|
||||
super(mit_b5, self).__init__(
|
||||
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
|
||||
drop_rate=0.0, drop_path_rate=0.1)
|
||||
|
||||
|
||||
@ -1,619 +0,0 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# from mmcv.cnn import normal_init
|
||||
# from mmcv.runner import auto_fp16, force_fp32
|
||||
|
||||
# from mmseg.core import build_pixel_sampler
|
||||
# from mmseg.ops import resize
|
||||
|
||||
|
||||
class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
|
||||
"""Base class for BaseDecodeHead.
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
num_classes (int): Number of classes.
|
||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU')
|
||||
in_index (int|Sequence[int]): Input feature index. Default: -1
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resize to the
|
||||
same size as first one and than concat together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundle into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
Default: None.
|
||||
loss_decode (dict): Config of decode loss.
|
||||
Default: dict(type='CrossEntropyLoss').
|
||||
ignore_index (int | None): The label index to be ignored. When using
|
||||
masked BCE loss, ignore_index should be set to None. Default: 255
|
||||
sampler (dict|None): The config of segmentation map sampler.
|
||||
Default: None.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
channels,
|
||||
*,
|
||||
num_classes,
|
||||
dropout_ratio=0.1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
in_index=-1,
|
||||
input_transform=None,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
loss_weight=1.0),
|
||||
decoder_params=None,
|
||||
ignore_index=255,
|
||||
sampler=None,
|
||||
align_corners=False):
|
||||
super(BaseDecodeHead, self).__init__()
|
||||
self._init_inputs(in_channels, in_index, input_transform)
|
||||
self.channels = channels
|
||||
self.num_classes = num_classes
|
||||
self.dropout_ratio = dropout_ratio
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.in_index = in_index
|
||||
self.ignore_index = ignore_index
|
||||
self.align_corners = align_corners
|
||||
|
||||
if sampler is not None:
|
||||
self.sampler = build_pixel_sampler(sampler, context=self)
|
||||
else:
|
||||
self.sampler = None
|
||||
|
||||
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
||||
if dropout_ratio > 0:
|
||||
self.dropout = nn.Dropout2d(dropout_ratio)
|
||||
else:
|
||||
self.dropout = None
|
||||
self.fp16_enabled = False
|
||||
|
||||
def extra_repr(self):
|
||||
"""Extra repr."""
|
||||
s = f'input_transform={self.input_transform}, ' \
|
||||
f'ignore_index={self.ignore_index}, ' \
|
||||
f'align_corners={self.align_corners}'
|
||||
return s
|
||||
|
||||
def _init_inputs(self, in_channels, in_index, input_transform):
|
||||
"""Check and initialize input transforms.
|
||||
|
||||
The in_channels, in_index and input_transform must match.
|
||||
Specifically, when input_transform is None, only single feature map
|
||||
will be selected. So in_channels and in_index must be of type int.
|
||||
When input_transform
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
in_index (int|Sequence[int]): Input feature index.
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resize to the
|
||||
same size as first one and than concat together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundle into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
"""
|
||||
|
||||
if input_transform is not None:
|
||||
assert input_transform in ['resize_concat', 'multiple_select']
|
||||
self.input_transform = input_transform
|
||||
self.in_index = in_index
|
||||
if input_transform is not None:
|
||||
assert isinstance(in_channels, (list, tuple))
|
||||
assert isinstance(in_index, (list, tuple))
|
||||
assert len(in_channels) == len(in_index)
|
||||
if input_transform == 'resize_concat':
|
||||
self.in_channels = sum(in_channels)
|
||||
else:
|
||||
self.in_channels = in_channels
|
||||
else:
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(in_index, int)
|
||||
self.in_channels = in_channels
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize weights of classification layer."""
|
||||
normal_init(self.conv_seg, mean=0, std=0.01)
|
||||
|
||||
def _transform_inputs(self, inputs):
|
||||
"""Transform inputs for decoder.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
Tensor: The transformed inputs
|
||||
"""
|
||||
|
||||
if self.input_transform == 'resize_concat':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
upsampled_inputs = [
|
||||
resize(
|
||||
input=x,
|
||||
size=inputs[0].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners) for x in inputs
|
||||
]
|
||||
inputs = torch.cat(upsampled_inputs, dim=1)
|
||||
elif self.input_transform == 'multiple_select':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
else:
|
||||
inputs = inputs[self.in_index]
|
||||
|
||||
return inputs
|
||||
|
||||
# @auto_fp16()
|
||||
@abstractmethod
|
||||
def forward(self, inputs):
|
||||
"""Placeholder of forward function."""
|
||||
pass
|
||||
|
||||
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
|
||||
"""Forward function for training.
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
img_metas (list[dict]): List of image info dict where each dict
|
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
||||
gt_semantic_seg (Tensor): Semantic segmentation masks
|
||||
used if the architecture supports semantic segmentation task.
|
||||
train_cfg (dict): The training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
seg_logits = self.forward(inputs)
|
||||
losses = self.losses(seg_logits, gt_semantic_seg)
|
||||
return losses
|
||||
|
||||
def forward_test(self, inputs, img_metas, test_cfg):
|
||||
"""Forward function for testing.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
img_metas (list[dict]): List of image info dict where each dict
|
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
||||
test_cfg (dict): The testing config.
|
||||
|
||||
Returns:
|
||||
Tensor: Output segmentation map.
|
||||
"""
|
||||
return self.forward(inputs)
|
||||
|
||||
def cls_seg(self, feat):
|
||||
"""Classify each pixel."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.conv_seg(feat)
|
||||
return output
|
||||
|
||||
|
||||
class BaseDecodeHead_clips(nn.Module, metaclass=ABCMeta):
|
||||
"""Base class for BaseDecodeHead_clips.
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
num_classes (int): Number of classes.
|
||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU')
|
||||
in_index (int|Sequence[int]): Input feature index. Default: -1
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resize to the
|
||||
same size as first one and than concat together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundle into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
Default: None.
|
||||
loss_decode (dict): Config of decode loss.
|
||||
Default: dict(type='CrossEntropyLoss').
|
||||
ignore_index (int | None): The label index to be ignored. When using
|
||||
masked BCE loss, ignore_index should be set to None. Default: 255
|
||||
sampler (dict|None): The config of segmentation map sampler.
|
||||
Default: None.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
channels,
|
||||
*,
|
||||
num_classes,
|
||||
dropout_ratio=0.1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
in_index=-1,
|
||||
input_transform=None,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
loss_weight=1.0),
|
||||
decoder_params=None,
|
||||
ignore_index=255,
|
||||
sampler=None,
|
||||
align_corners=False,
|
||||
num_clips=5):
|
||||
super(BaseDecodeHead_clips, self).__init__()
|
||||
self._init_inputs(in_channels, in_index, input_transform)
|
||||
self.channels = channels
|
||||
self.num_classes = num_classes
|
||||
self.dropout_ratio = dropout_ratio
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.in_index = in_index
|
||||
self.ignore_index = ignore_index
|
||||
self.align_corners = align_corners
|
||||
self.num_clips=num_clips
|
||||
|
||||
if sampler is not None:
|
||||
self.sampler = build_pixel_sampler(sampler, context=self)
|
||||
else:
|
||||
self.sampler = None
|
||||
|
||||
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
||||
if dropout_ratio > 0:
|
||||
self.dropout = nn.Dropout2d(dropout_ratio)
|
||||
else:
|
||||
self.dropout = None
|
||||
self.fp16_enabled = False
|
||||
|
||||
def extra_repr(self):
|
||||
"""Extra repr."""
|
||||
s = f'input_transform={self.input_transform}, ' \
|
||||
f'ignore_index={self.ignore_index}, ' \
|
||||
f'align_corners={self.align_corners}'
|
||||
return s
|
||||
|
||||
def _init_inputs(self, in_channels, in_index, input_transform):
|
||||
"""Check and initialize input transforms.
|
||||
|
||||
The in_channels, in_index and input_transform must match.
|
||||
Specifically, when input_transform is None, only single feature map
|
||||
will be selected. So in_channels and in_index must be of type int.
|
||||
When input_transform
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
in_index (int|Sequence[int]): Input feature index.
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resize to the
|
||||
same size as first one and than concat together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundle into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
"""
|
||||
|
||||
if input_transform is not None:
|
||||
assert input_transform in ['resize_concat', 'multiple_select']
|
||||
self.input_transform = input_transform
|
||||
self.in_index = in_index
|
||||
if input_transform is not None:
|
||||
assert isinstance(in_channels, (list, tuple))
|
||||
assert isinstance(in_index, (list, tuple))
|
||||
assert len(in_channels) == len(in_index)
|
||||
if input_transform == 'resize_concat':
|
||||
self.in_channels = sum(in_channels)
|
||||
else:
|
||||
self.in_channels = in_channels
|
||||
else:
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(in_index, int)
|
||||
self.in_channels = in_channels
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize weights of classification layer."""
|
||||
normal_init(self.conv_seg, mean=0, std=0.01)
|
||||
|
||||
def _transform_inputs(self, inputs):
|
||||
"""Transform inputs for decoder.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
Tensor: The transformed inputs
|
||||
"""
|
||||
|
||||
if self.input_transform == 'resize_concat':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
upsampled_inputs = [
|
||||
resize(
|
||||
input=x,
|
||||
size=inputs[0].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners) for x in inputs
|
||||
]
|
||||
inputs = torch.cat(upsampled_inputs, dim=1)
|
||||
elif self.input_transform == 'multiple_select':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
else:
|
||||
inputs = inputs[self.in_index]
|
||||
|
||||
return inputs
|
||||
|
||||
# @auto_fp16()
|
||||
@abstractmethod
|
||||
def forward(self, inputs):
|
||||
"""Placeholder of forward function."""
|
||||
pass
|
||||
|
||||
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips):
|
||||
"""Forward function for training.
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
img_metas (list[dict]): List of image info dict where each dict
|
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
||||
gt_semantic_seg (Tensor): Semantic segmentation masks
|
||||
used if the architecture supports semantic segmentation task.
|
||||
train_cfg (dict): The training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
seg_logits = self.forward(inputs,batch_size, num_clips)
|
||||
losses = self.losses(seg_logits, gt_semantic_seg)
|
||||
return losses
|
||||
|
||||
def forward_test(self, inputs, img_metas, test_cfg, batch_size, num_clips):
|
||||
"""Forward function for testing.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
img_metas (list[dict]): List of image info dict where each dict
|
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
||||
test_cfg (dict): The testing config.
|
||||
|
||||
Returns:
|
||||
Tensor: Output segmentation map.
|
||||
"""
|
||||
return self.forward(inputs, batch_size, num_clips)
|
||||
|
||||
def cls_seg(self, feat):
|
||||
"""Classify each pixel."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.conv_seg(feat)
|
||||
return output
|
||||
|
||||
class BaseDecodeHead_clips_flow(nn.Module, metaclass=ABCMeta):
|
||||
"""Base class for BaseDecodeHead_clips_flow.
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
num_classes (int): Number of classes.
|
||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU')
|
||||
in_index (int|Sequence[int]): Input feature index. Default: -1
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resize to the
|
||||
same size as first one and than concat together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundle into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
Default: None.
|
||||
loss_decode (dict): Config of decode loss.
|
||||
Default: dict(type='CrossEntropyLoss').
|
||||
ignore_index (int | None): The label index to be ignored. When using
|
||||
masked BCE loss, ignore_index should be set to None. Default: 255
|
||||
sampler (dict|None): The config of segmentation map sampler.
|
||||
Default: None.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
channels,
|
||||
*,
|
||||
num_classes,
|
||||
dropout_ratio=0.1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
in_index=-1,
|
||||
input_transform=None,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
loss_weight=1.0),
|
||||
decoder_params=None,
|
||||
ignore_index=255,
|
||||
sampler=None,
|
||||
align_corners=False,
|
||||
num_clips=5):
|
||||
super(BaseDecodeHead_clips_flow, self).__init__()
|
||||
self._init_inputs(in_channels, in_index, input_transform)
|
||||
self.channels = channels
|
||||
self.num_classes = num_classes
|
||||
self.dropout_ratio = dropout_ratio
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.in_index = in_index
|
||||
self.ignore_index = ignore_index
|
||||
self.align_corners = align_corners
|
||||
self.num_clips=num_clips
|
||||
|
||||
if sampler is not None:
|
||||
self.sampler = build_pixel_sampler(sampler, context=self)
|
||||
else:
|
||||
self.sampler = None
|
||||
|
||||
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
||||
if dropout_ratio > 0:
|
||||
self.dropout = nn.Dropout2d(dropout_ratio)
|
||||
else:
|
||||
self.dropout = None
|
||||
self.fp16_enabled = False
|
||||
|
||||
def extra_repr(self):
|
||||
"""Extra repr."""
|
||||
s = f'input_transform={self.input_transform}, ' \
|
||||
f'ignore_index={self.ignore_index}, ' \
|
||||
f'align_corners={self.align_corners}'
|
||||
return s
|
||||
|
||||
def _init_inputs(self, in_channels, in_index, input_transform):
|
||||
"""Check and initialize input transforms.
|
||||
|
||||
The in_channels, in_index and input_transform must match.
|
||||
Specifically, when input_transform is None, only single feature map
|
||||
will be selected. So in_channels and in_index must be of type int.
|
||||
When input_transform
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
in_index (int|Sequence[int]): Input feature index.
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resize to the
|
||||
same size as first one and than concat together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundle into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
"""
|
||||
|
||||
if input_transform is not None:
|
||||
assert input_transform in ['resize_concat', 'multiple_select']
|
||||
self.input_transform = input_transform
|
||||
self.in_index = in_index
|
||||
if input_transform is not None:
|
||||
assert isinstance(in_channels, (list, tuple))
|
||||
assert isinstance(in_index, (list, tuple))
|
||||
assert len(in_channels) == len(in_index)
|
||||
if input_transform == 'resize_concat':
|
||||
self.in_channels = sum(in_channels)
|
||||
else:
|
||||
self.in_channels = in_channels
|
||||
else:
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(in_index, int)
|
||||
self.in_channels = in_channels
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize weights of classification layer."""
|
||||
normal_init(self.conv_seg, mean=0, std=0.01)
|
||||
|
||||
def _transform_inputs(self, inputs):
|
||||
"""Transform inputs for decoder.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
Tensor: The transformed inputs
|
||||
"""
|
||||
|
||||
if self.input_transform == 'resize_concat':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
upsampled_inputs = [
|
||||
resize(
|
||||
input=x,
|
||||
size=inputs[0].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners) for x in inputs
|
||||
]
|
||||
inputs = torch.cat(upsampled_inputs, dim=1)
|
||||
elif self.input_transform == 'multiple_select':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
else:
|
||||
inputs = inputs[self.in_index]
|
||||
|
||||
return inputs
|
||||
|
||||
# @auto_fp16()
|
||||
@abstractmethod
|
||||
def forward(self, inputs):
|
||||
"""Placeholder of forward function."""
|
||||
pass
|
||||
|
||||
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips,img=None):
|
||||
"""Forward function for training.
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
img_metas (list[dict]): List of image info dict where each dict
|
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
||||
gt_semantic_seg (Tensor): Semantic segmentation masks
|
||||
used if the architecture supports semantic segmentation task.
|
||||
train_cfg (dict): The training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
seg_logits = self.forward(inputs,batch_size, num_clips,img)
|
||||
losses = self.losses(seg_logits, gt_semantic_seg)
|
||||
return losses
|
||||
|
||||
def forward_test(self, inputs, img_metas, test_cfg, batch_size=None, num_clips=None, img=None):
|
||||
"""Forward function for testing.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
img_metas (list[dict]): List of image info dict where each dict
|
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
||||
test_cfg (dict): The testing config.
|
||||
|
||||
Returns:
|
||||
Tensor: Output segmentation map.
|
||||
"""
|
||||
return self.forward(inputs, batch_size, num_clips,img)
|
||||
|
||||
def cls_seg(self, feat):
|
||||
"""Classify each pixel."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.conv_seg(feat)
|
||||
return output
|
||||
@ -1,115 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from models.monoD.depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed
|
||||
from models.SpaTrackV2.models.depth_refiner.backbone import mit_b3
|
||||
from models.SpaTrackV2.models.depth_refiner.stablizer import Stabilization_Network_Cross_Attention
|
||||
from einops import rearrange
|
||||
class TrackStablizer(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.backbone = mit_b3()
|
||||
|
||||
old_conv = self.backbone.patch_embed1.proj
|
||||
new_conv = nn.Conv2d(old_conv.in_channels + 4, old_conv.out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding)
|
||||
|
||||
new_conv.weight[:, :3, :, :].data.copy_(old_conv.weight.clone())
|
||||
self.backbone.patch_embed1.proj = new_conv
|
||||
|
||||
self.Track_Stabilizer = Stabilization_Network_Cross_Attention(in_channels=[64, 128, 320, 512],
|
||||
in_index=[0, 1, 2, 3],
|
||||
feature_strides=[4, 8, 16, 32],
|
||||
channels=128,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=1,
|
||||
align_corners=False,
|
||||
decoder_params=dict(embed_dim=256, depths=4),
|
||||
num_clips=16,
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True))
|
||||
|
||||
self.edge_conv = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=64, kernel_size=3, padding=1, stride=1, bias=True),\
|
||||
nn.ReLU(inplace=True))
|
||||
self.edge_conv1 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2, bias=True),\
|
||||
nn.ReLU(inplace=True))
|
||||
self.success = False
|
||||
self.x = None
|
||||
|
||||
def buffer_forward(self, inputs, num_clips=16):
|
||||
"""
|
||||
buffer forward for getting the pointmap and image features
|
||||
"""
|
||||
B, T, C, H, W = inputs.shape
|
||||
self.x = self.backbone(inputs)
|
||||
scale, shift = self.Track_Stabilizer.buffer_forward(self.x, num_clips=num_clips)
|
||||
self.success = True
|
||||
return scale, shift
|
||||
|
||||
def forward(self, inputs, tracks, tracks_uvd, num_clips=16, imgs=None, vis_track=None):
|
||||
|
||||
"""
|
||||
Args:
|
||||
inputs: [B, T, C, H, W], RGB + PointMap + Mask
|
||||
tracks: [B, T, N, 4], 3D tracks in camera coordinate + visibility
|
||||
num_clips: int, number of clips to use
|
||||
"""
|
||||
B, T, C, H, W = inputs.shape
|
||||
edge_feat = self.edge_conv(inputs.view(B*T,4,H,W))
|
||||
edge_feat1 = self.edge_conv1(edge_feat)
|
||||
|
||||
if not self.success:
|
||||
scale, shift = self.Track_Stabilizer.buffer_forward(self.x,num_clips=num_clips)
|
||||
self.success = True
|
||||
update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
|
||||
else:
|
||||
update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
|
||||
|
||||
return update
|
||||
|
||||
def reset_success(self):
|
||||
self.success = False
|
||||
self.x = None
|
||||
self.Track_Stabilizer.reset_success()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create test input tensors
|
||||
batch_size = 1
|
||||
seq_len = 16
|
||||
channels = 7 # 3 for RGB + 3 for PointMap + 1 for Mask
|
||||
height = 384
|
||||
width = 512
|
||||
|
||||
# Create random input tensor with shape [B, T, C, H, W]
|
||||
inputs = torch.randn(batch_size, seq_len, channels, height, width)
|
||||
|
||||
# Create random tracks
|
||||
tracks = torch.randn(batch_size, seq_len, 1024, 4)
|
||||
|
||||
# Create random test images
|
||||
test_imgs = torch.randn(batch_size, seq_len, 3, height, width)
|
||||
|
||||
# Initialize model and move to GPU
|
||||
model = TrackStablizer().cuda()
|
||||
|
||||
# Move inputs to GPU and run forward pass
|
||||
inputs = inputs.cuda()
|
||||
tracks = tracks.cuda()
|
||||
outputs = model.buffer_forward(inputs, num_clips=seq_len)
|
||||
import time
|
||||
start_time = time.time()
|
||||
outputs = model(inputs, tracks, num_clips=seq_len)
|
||||
end_time = time.time()
|
||||
print(f"Time taken: {end_time - start_time} seconds")
|
||||
import pdb; pdb.set_trace()
|
||||
# # Print shapes for verification
|
||||
# print(f"Input shape: {inputs.shape}")
|
||||
# print(f"Output shape: {outputs.shape}")
|
||||
|
||||
# # Basic tests
|
||||
# assert outputs.shape[0] == batch_size, "Batch size mismatch"
|
||||
# assert len(outputs.shape) == 4, "Output should be 4D: [B,C,H,W]"
|
||||
# assert torch.all(outputs >= 0), "Output should be non-negative after ReLU"
|
||||
|
||||
# print("All tests passed!")
|
||||
|
||||
@ -1,429 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Author: Ke Xian
|
||||
Email: kexian@hust.edu.cn
|
||||
Date: 2020/07/20
|
||||
'''
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
|
||||
# ==============================================================================================================
|
||||
|
||||
class FTB(nn.Module):
|
||||
def __init__(self, inchannels, midchannels=512):
|
||||
super(FTB, self).__init__()
|
||||
self.in1 = inchannels
|
||||
self.mid = midchannels
|
||||
|
||||
self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True)
|
||||
self.conv_branch = nn.Sequential(nn.ReLU(inplace=True),\
|
||||
nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),\
|
||||
#nn.BatchNorm2d(num_features=self.mid),\
|
||||
nn.ReLU(inplace=True),\
|
||||
nn.Conv2d(in_channels=self.mid, out_channels= self.mid, kernel_size=3, padding=1, stride=1, bias=True))
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.init_params()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = x + self.conv_branch(x)
|
||||
x = self.relu(x)
|
||||
|
||||
return x
|
||||
|
||||
def init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
init.normal_(m.weight, std=0.01)
|
||||
# init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
init.normal_(m.weight, std=0.01)
|
||||
# init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.01)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
class ATA(nn.Module):
|
||||
def __init__(self, inchannels, reduction = 8):
|
||||
super(ATA, self).__init__()
|
||||
self.inchannels = inchannels
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(nn.Linear(self.inchannels*2, self.inchannels // reduction),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(self.inchannels // reduction, self.inchannels),
|
||||
nn.Sigmoid())
|
||||
self.init_params()
|
||||
|
||||
def forward(self, low_x, high_x):
|
||||
n, c, _, _ = low_x.size()
|
||||
x = torch.cat([low_x, high_x], 1)
|
||||
x = self.avg_pool(x)
|
||||
x = x.view(n, -1)
|
||||
x = self.fc(x).view(n,c,1,1)
|
||||
x = low_x * x + high_x
|
||||
|
||||
return x
|
||||
|
||||
def init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
#init.normal(m.weight, std=0.01)
|
||||
init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
#init.normal_(m.weight, std=0.01)
|
||||
init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.01)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class FFM(nn.Module):
|
||||
def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
|
||||
super(FFM, self).__init__()
|
||||
self.inchannels = inchannels
|
||||
self.midchannels = midchannels
|
||||
self.outchannels = outchannels
|
||||
self.upfactor = upfactor
|
||||
|
||||
self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
|
||||
self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
|
||||
|
||||
self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
|
||||
|
||||
self.init_params()
|
||||
#self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
||||
#self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
||||
#self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, low_x, high_x):
|
||||
x = self.ftb1(low_x)
|
||||
|
||||
'''
|
||||
x = torch.cat((x,high_x),1)
|
||||
if x.shape[2] == 12:
|
||||
x = self.p1(x)
|
||||
elif x.shape[2] == 24:
|
||||
x = self.p2(x)
|
||||
elif x.shape[2] == 48:
|
||||
x = self.p3(x)
|
||||
'''
|
||||
x = x + high_x ###high_x
|
||||
x = self.ftb2(x)
|
||||
x = self.upsample(x)
|
||||
|
||||
return x
|
||||
|
||||
def init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
init.normal_(m.weight, std=0.01)
|
||||
#init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
init.normal_(m.weight, std=0.01)
|
||||
#init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.01)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
|
||||
class noFFM(nn.Module):
|
||||
def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
|
||||
super(noFFM, self).__init__()
|
||||
self.inchannels = inchannels
|
||||
self.midchannels = midchannels
|
||||
self.outchannels = outchannels
|
||||
self.upfactor = upfactor
|
||||
|
||||
self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
|
||||
|
||||
self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
|
||||
|
||||
self.init_params()
|
||||
#self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
||||
#self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
||||
#self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, low_x, high_x):
|
||||
|
||||
#x = self.ftb1(low_x)
|
||||
x = high_x ###high_x
|
||||
x = self.ftb2(x)
|
||||
x = self.upsample(x)
|
||||
|
||||
return x
|
||||
|
||||
def init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
init.normal_(m.weight, std=0.01)
|
||||
#init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
init.normal_(m.weight, std=0.01)
|
||||
#init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.01)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
|
||||
|
||||
class AO(nn.Module):
|
||||
# Adaptive output module
|
||||
def __init__(self, inchannels, outchannels, upfactor=2):
|
||||
super(AO, self).__init__()
|
||||
self.inchannels = inchannels
|
||||
self.outchannels = outchannels
|
||||
self.upfactor = upfactor
|
||||
|
||||
"""
|
||||
self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
|
||||
nn.BatchNorm2d(num_features=self.inchannels//2),\
|
||||
nn.ReLU(inplace=True),\
|
||||
nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True),\
|
||||
nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) )#,\
|
||||
#nn.ReLU(inplace=True)) ## get positive values
|
||||
"""
|
||||
self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
|
||||
#nn.BatchNorm2d(num_features=self.inchannels//2),\
|
||||
nn.ReLU(inplace=True),\
|
||||
nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True), \
|
||||
nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=1, padding=0, stride=1))
|
||||
|
||||
#nn.ReLU(inplace=True)) ## get positive values
|
||||
|
||||
self.init_params()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.adapt_conv(x)
|
||||
return x
|
||||
|
||||
def init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
init.normal_(m.weight, std=0.01)
|
||||
#init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
init.normal_(m.weight, std=0.01)
|
||||
#init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.01)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
class ASPP(nn.Module):
|
||||
def __init__(self, inchannels=256, planes=128, rates = [1, 6, 12, 18]):
|
||||
super(ASPP, self).__init__()
|
||||
self.inchannels = inchannels
|
||||
self.planes = planes
|
||||
self.rates = rates
|
||||
self.kernel_sizes = []
|
||||
self.paddings = []
|
||||
for rate in self.rates:
|
||||
if rate == 1:
|
||||
self.kernel_sizes.append(1)
|
||||
self.paddings.append(0)
|
||||
else:
|
||||
self.kernel_sizes.append(3)
|
||||
self.paddings.append(rate)
|
||||
self.atrous_0 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[0],
|
||||
stride=1, padding=self.paddings[0], dilation=self.rates[0], bias=True),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm2d(num_features=self.planes)
|
||||
)
|
||||
self.atrous_1 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[1],
|
||||
stride=1, padding=self.paddings[1], dilation=self.rates[1], bias=True),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm2d(num_features=self.planes),
|
||||
)
|
||||
self.atrous_2 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[2],
|
||||
stride=1, padding=self.paddings[2], dilation=self.rates[2], bias=True),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm2d(num_features=self.planes),
|
||||
)
|
||||
self.atrous_3 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[3],
|
||||
stride=1, padding=self.paddings[3], dilation=self.rates[3], bias=True),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm2d(num_features=self.planes),
|
||||
)
|
||||
|
||||
#self.conv = nn.Conv2d(in_channels=self.planes * 4, out_channels=self.inchannels, kernel_size=3, padding=1, stride=1, bias=True)
|
||||
def forward(self, x):
|
||||
x = torch.cat([self.atrous_0(x), self.atrous_1(x), self.atrous_2(x), self.atrous_3(x)],1)
|
||||
#x = self.conv(x)
|
||||
|
||||
return x
|
||||
|
||||
# ==============================================================================================================
|
||||
|
||||
|
||||
class ResidualConv(nn.Module):
|
||||
def __init__(self, inchannels):
|
||||
super(ResidualConv, self).__init__()
|
||||
#nn.BatchNorm2d
|
||||
self.conv = nn.Sequential(
|
||||
#nn.BatchNorm2d(num_features=inchannels),
|
||||
nn.ReLU(inplace=False),
|
||||
#nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True),
|
||||
#nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True)
|
||||
nn.Conv2d(in_channels=inchannels, out_channels=inchannels//2, kernel_size=3, padding=1, stride=1, bias=False),
|
||||
nn.BatchNorm2d(num_features=inchannels//2),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(in_channels=inchannels//2, out_channels=inchannels, kernel_size=3, padding=1, stride=1, bias=False)
|
||||
)
|
||||
self.init_params()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)+x
|
||||
return x
|
||||
|
||||
def init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
init.normal_(m.weight, std=0.01)
|
||||
#init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
init.normal_(m.weight, std=0.01)
|
||||
#init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.01)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class FeatureFusion(nn.Module):
|
||||
def __init__(self, inchannels, outchannels):
|
||||
super(FeatureFusion, self).__init__()
|
||||
self.conv = ResidualConv(inchannels=inchannels)
|
||||
#nn.BatchNorm2d
|
||||
self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
|
||||
nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3,stride=2, padding=1, output_padding=1),
|
||||
nn.BatchNorm2d(num_features=outchannels),
|
||||
nn.ReLU(inplace=True))
|
||||
|
||||
def forward(self, lowfeat, highfeat):
|
||||
return self.up(highfeat + self.conv(lowfeat))
|
||||
|
||||
def init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
init.normal_(m.weight, std=0.01)
|
||||
#init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
init.normal_(m.weight, std=0.01)
|
||||
#init.xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.01)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class SenceUnderstand(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super(SenceUnderstand, self).__init__()
|
||||
self.channels = channels
|
||||
self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace = True))
|
||||
self.pool = nn.AdaptiveAvgPool2d(8)
|
||||
self.fc = nn.Sequential(nn.Linear(512*8*8, self.channels),
|
||||
nn.ReLU(inplace = True))
|
||||
self.conv2 = nn.Sequential(nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
|
||||
nn.ReLU(inplace=True))
|
||||
self.initial_params()
|
||||
|
||||
def forward(self, x):
|
||||
n,c,h,w = x.size()
|
||||
x = self.conv1(x)
|
||||
x = self.pool(x)
|
||||
x = x.view(n,-1)
|
||||
x = self.fc(x)
|
||||
x = x.view(n, self.channels, 1, 1)
|
||||
x = self.conv2(x)
|
||||
x = x.repeat(1,1,h,w)
|
||||
return x
|
||||
|
||||
def initial_params(self, dev=0.01):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
#print torch.sum(m.weight)
|
||||
m.weight.data.normal_(0, dev)
|
||||
if m.bias is not None:
|
||||
m.bias.data.fill_(0)
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
#print torch.sum(m.weight)
|
||||
m.weight.data.normal_(0, dev)
|
||||
if m.bias is not None:
|
||||
m.bias.data.fill_(0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
m.weight.data.normal_(0, dev)
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,342 +0,0 @@
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
# from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
from collections import OrderedDict
|
||||
# from mmseg.ops import resize
|
||||
from torch.nn.functional import interpolate as resize
|
||||
# from builder import HEADS
|
||||
from models.SpaTrackV2.models.depth_refiner.decode_head import BaseDecodeHead, BaseDecodeHead_clips, BaseDecodeHead_clips_flow
|
||||
# from mmseg.models.utils import *
|
||||
import attr
|
||||
from IPython import embed
|
||||
from models.SpaTrackV2.models.depth_refiner.stablilization_attention import BasicLayer3d3
|
||||
import cv2
|
||||
from models.SpaTrackV2.models.depth_refiner.network import *
|
||||
import warnings
|
||||
# from mmcv.utils import Registry, build_from_cfg
|
||||
from torch import nn
|
||||
from einops import rearrange
|
||||
import torch.nn.functional as F
|
||||
from models.SpaTrackV2.models.blocks import (
|
||||
AttnBlock, CrossAttnBlock, Mlp
|
||||
)
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""
|
||||
Linear Embedding
|
||||
"""
|
||||
def __init__(self, input_dim=2048, embed_dim=768):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(input_dim, embed_dim)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
def scatter_multiscale_fast(
|
||||
track2d: torch.Tensor,
|
||||
trackfeature: torch.Tensor,
|
||||
H: int,
|
||||
W: int,
|
||||
kernel_sizes = [1]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Scatter sparse track features onto a dense image grid with weighted multi-scale pooling to handle zero-value gaps.
|
||||
|
||||
This function scatters sparse track features into a dense image grid and applies multi-scale average pooling
|
||||
while excluding zero-value holes. The weight mask ensures that only valid feature regions contribute to the pooling,
|
||||
avoiding dilution by empty pixels.
|
||||
|
||||
Args:
|
||||
track2d (torch.Tensor): Float tensor of shape (B, T, N, 2) containing (x, y) pixel coordinates
|
||||
for each track point across batches, frames, and points.
|
||||
trackfeature (torch.Tensor): Float tensor of shape (B, T, N, C) with C-dimensional features
|
||||
for each track point.
|
||||
H (int): Height of the target output image.
|
||||
W (int): Width of the target output image.
|
||||
kernel_sizes (List[int]): List of odd integers for average pooling kernel sizes. Default: [3, 5, 7].
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Multi-scale fused feature map of shape (B, T, C, H, W) with hole-resistant pooling.
|
||||
"""
|
||||
B, T, N, C = trackfeature.shape
|
||||
device = trackfeature.device
|
||||
|
||||
# 1. Flatten coordinates and filter valid points within image bounds
|
||||
coords_flat = track2d.round().long().reshape(-1, 2) # (B*T*N, 2)
|
||||
x = coords_flat[:, 0] # x coordinates
|
||||
y = coords_flat[:, 1] # y coordinates
|
||||
feat_flat = trackfeature.reshape(-1, C) # Flatten features
|
||||
|
||||
valid_mask = (x >= 0) & (x < W) & (y >= 0) & (y < H)
|
||||
x = x[valid_mask]
|
||||
y = y[valid_mask]
|
||||
feat_flat = feat_flat[valid_mask]
|
||||
valid_count = x.shape[0]
|
||||
|
||||
if valid_count == 0:
|
||||
return torch.zeros(B, T, C, H, W, device=device) # Handle no-valid-point case
|
||||
|
||||
# 2. Calculate linear indices and batch-frame indices for scattering
|
||||
lin_idx = y * W + x # Linear index within a single frame (H*W range)
|
||||
|
||||
# Generate batch-frame indices (e.g., 0~B*T-1 for each frame in batch)
|
||||
bt_idx_raw = (
|
||||
torch.arange(B * T, device=device)
|
||||
.view(B, T, 1)
|
||||
.expand(B, T, N)
|
||||
.reshape(-1)
|
||||
)
|
||||
bt_idx = bt_idx_raw[valid_mask] # Indices for valid points across batch and frames
|
||||
|
||||
# 3. Create accumulation buffers for features and weights
|
||||
total_space = B * T * H * W
|
||||
img_accum_flat = torch.zeros(total_space, C, device=device) # Feature accumulator
|
||||
weight_accum_flat = torch.zeros(total_space, 1, device=device) # Weight accumulator (counts)
|
||||
|
||||
# 4. Scatter features and weights into accumulation buffers
|
||||
idx_in_accum = bt_idx * (H * W) + lin_idx # Global index: batch_frame * H*W + pixel_index
|
||||
|
||||
# Add features to corresponding indices (index_add_ is efficient for sparse updates)
|
||||
img_accum_flat.index_add_(0, idx_in_accum, feat_flat)
|
||||
weight_accum_flat.index_add_(0, idx_in_accum, torch.ones((valid_count, 1), device=device))
|
||||
|
||||
# 5. Normalize features by valid weights, keep zeros for invalid regions
|
||||
valid_mask_flat = weight_accum_flat > 0 # Binary mask for valid pixels
|
||||
img_accum_flat = img_accum_flat / (weight_accum_flat + 1e-6) # Avoid division by zero
|
||||
img_accum_flat = img_accum_flat * valid_mask_flat.float() # Mask out invalid regions
|
||||
|
||||
# 6. Reshape to (B, T, C, H, W) for further processing
|
||||
img = (
|
||||
img_accum_flat.view(B, T, H, W, C)
|
||||
.permute(0, 1, 4, 2, 3)
|
||||
.contiguous()
|
||||
) # Shape: (B, T, C, H, W)
|
||||
|
||||
# 7. Multi-scale pooling with weight masking to exclude zero holes
|
||||
blurred_outputs = []
|
||||
for k in kernel_sizes:
|
||||
pad = k // 2
|
||||
img_bt = img.view(B*T, C, H, W) # Flatten batch and time for pooling
|
||||
|
||||
# Create weight mask for valid regions (1 where features exist, 0 otherwise)
|
||||
weight_mask = (
|
||||
weight_accum_flat.view(B, T, 1, H, W) > 0
|
||||
).float().view(B*T, 1, H, W) # Shape: (B*T, 1, H, W)
|
||||
|
||||
# Calculate number of valid neighbors in each pooling window
|
||||
weight_sum = F.conv2d(
|
||||
weight_mask,
|
||||
torch.ones((1, 1, k, k), device=device),
|
||||
stride=1,
|
||||
padding=pad
|
||||
) # Shape: (B*T, 1, H, W)
|
||||
|
||||
# Sum features only in valid regions
|
||||
feat_sum = F.conv2d(
|
||||
img_bt * weight_mask, # Mask out invalid regions before summing
|
||||
torch.ones((1, 1, k, k), device=device).expand(C, 1, k, k),
|
||||
stride=1,
|
||||
padding=pad,
|
||||
groups=C
|
||||
) # Shape: (B*T, C, H, W)
|
||||
|
||||
# Compute average only over valid neighbors
|
||||
feat_avg = feat_sum / (weight_sum + 1e-6)
|
||||
blurred_outputs.append(feat_avg)
|
||||
|
||||
# 8. Fuse multi-scale results by averaging across kernel sizes
|
||||
fused = torch.stack(blurred_outputs).mean(dim=0) # Average over kernel sizes
|
||||
return fused.view(B, T, C, H, W) # Restore original shape
|
||||
|
||||
#@HEADS.register_module()
|
||||
class Stabilization_Network_Cross_Attention(BaseDecodeHead_clips_flow):
|
||||
|
||||
def __init__(self, feature_strides, **kwargs):
|
||||
super(Stabilization_Network_Cross_Attention, self).__init__(input_transform='multiple_select', **kwargs)
|
||||
self.training = False
|
||||
assert len(feature_strides) == len(self.in_channels)
|
||||
assert min(feature_strides) == feature_strides[0]
|
||||
self.feature_strides = feature_strides
|
||||
|
||||
c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
|
||||
|
||||
decoder_params = kwargs['decoder_params']
|
||||
embedding_dim = decoder_params['embed_dim']
|
||||
|
||||
self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
|
||||
self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
|
||||
self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
|
||||
self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
|
||||
|
||||
self.linear_fuse = nn.Sequential(nn.Conv2d(embedding_dim*4, embedding_dim, kernel_size=(1, 1), stride=(1, 1), bias=False),\
|
||||
nn.ReLU(inplace=True))
|
||||
|
||||
self.proj_track = nn.Conv2d(100, 128, kernel_size=(1, 1), stride=(1, 1), bias=True)
|
||||
|
||||
depths = decoder_params['depths']
|
||||
|
||||
self.reg_tokens = nn.Parameter(torch.zeros(1, 2, embedding_dim))
|
||||
self.global_patch = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=(8, 8), stride=(8, 8), bias=True)
|
||||
|
||||
self.att_temporal = nn.ModuleList(
|
||||
[
|
||||
AttnBlock(embedding_dim, 8,
|
||||
mlp_ratio=4, flash=True, ckpt_fwd=True)
|
||||
for _ in range(8)
|
||||
]
|
||||
)
|
||||
self.att_spatial = nn.ModuleList(
|
||||
[
|
||||
AttnBlock(embedding_dim, 8,
|
||||
mlp_ratio=4, flash=True, ckpt_fwd=True)
|
||||
for _ in range(8)
|
||||
]
|
||||
)
|
||||
self.scale_shift_head = nn.Sequential(nn.Linear(embedding_dim, embedding_dim), nn.GELU(), nn.Linear(embedding_dim, 4))
|
||||
|
||||
|
||||
# Initialize reg tokens
|
||||
nn.init.trunc_normal_(self.reg_tokens, std=0.02)
|
||||
|
||||
self.decoder_focal=BasicLayer3d3(dim=embedding_dim,
|
||||
input_resolution=(96,
|
||||
96),
|
||||
depth=depths,
|
||||
num_heads=8,
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
pool_method='fc',
|
||||
downsample=None,
|
||||
focal_level=2,
|
||||
focal_window=5,
|
||||
expand_size=3,
|
||||
expand_layer="all",
|
||||
use_conv_embed=False,
|
||||
use_shift=False,
|
||||
use_pre_norm=False,
|
||||
use_checkpoint=False,
|
||||
use_layerscale=False,
|
||||
layerscale_value=1e-4,
|
||||
focal_l_clips=[7,4,2],
|
||||
focal_kernel_clips=[7,5,3])
|
||||
|
||||
self.ffm2 = FFM(inchannels= 256, midchannels= 256, outchannels = 128)
|
||||
self.ffm1 = FFM(inchannels= 128, midchannels= 128, outchannels = 64)
|
||||
self.ffm0 = FFM(inchannels= 64, midchannels= 64, outchannels = 32,upfactor=1)
|
||||
self.AO = AO(32, outchannels=3, upfactor=1)
|
||||
self._c2 = None
|
||||
self._c_further = None
|
||||
|
||||
def buffer_forward(self, inputs, num_clips=None, imgs=None):#,infermode=1):
|
||||
|
||||
# input: B T 7 H W (7 means 3 rgb + 3 pointmap + 1 uncertainty) normalized
|
||||
if self.training:
|
||||
assert self.num_clips==num_clips
|
||||
|
||||
x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32
|
||||
c1, c2, c3, c4 = x
|
||||
|
||||
############## MLP decoder on C1-C4 ###########
|
||||
n, _, h, w = c4.shape
|
||||
batch_size = n // num_clips
|
||||
|
||||
_c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
|
||||
_c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)
|
||||
|
||||
_c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
|
||||
_c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)
|
||||
|
||||
_c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
|
||||
_c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)
|
||||
|
||||
_c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
|
||||
_c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
|
||||
|
||||
_, _, h, w=_c.shape
|
||||
_c_further=_c.reshape(batch_size, num_clips, -1, h, w) #h2w2
|
||||
|
||||
# Expand reg_tokens to match batch size
|
||||
reg_tokens = self.reg_tokens.expand(batch_size*num_clips, -1, -1) # [B, 2, C]
|
||||
|
||||
_c2=self.decoder_focal(_c_further, batch_size=batch_size, num_clips=num_clips, reg_tokens=reg_tokens)
|
||||
|
||||
assert _c_further.shape==_c2.shape
|
||||
self._c2 = _c2
|
||||
self._c_further = _c_further
|
||||
|
||||
# compute the scale and shift of the global patch
|
||||
global_patch = self.global_patch(_c2.view(batch_size*num_clips, -1, h, w)).view(batch_size*num_clips, _c2.shape[2], -1).permute(0,2,1)
|
||||
global_patch = torch.cat([global_patch, reg_tokens], dim=1)
|
||||
for i in range(8):
|
||||
global_patch = self.att_temporal[i](global_patch)
|
||||
global_patch = rearrange(global_patch, '(b t) n c -> (b n) t c', b=batch_size, t=num_clips, c=_c2.shape[2])
|
||||
global_patch = self.att_spatial[i](global_patch)
|
||||
global_patch = rearrange(global_patch, '(b n) t c -> (b t) n c', b=batch_size, t=num_clips, c=_c2.shape[2])
|
||||
|
||||
reg_tokens = global_patch[:, -2:, :]
|
||||
s_ = self.scale_shift_head(reg_tokens)
|
||||
scale = 1 + s_[:, 0, :1].view(batch_size, num_clips, 1, 1, 1)
|
||||
shift = s_[:, 1, 1:].view(batch_size, num_clips, 3, 1, 1)
|
||||
shift[:,:,:2,...] = 0
|
||||
return scale, shift
|
||||
|
||||
def forward(self, inputs, edge_feat, edge_feat1, tracks, tracks_uvd, num_clips=None, imgs=None, vis_track=None):#,infermode=1):
|
||||
|
||||
if self._c2 is None:
|
||||
scale, shift = self.buffer_forward(inputs,num_clips,imgs)
|
||||
|
||||
B, T, N, _ = tracks.shape
|
||||
|
||||
_c2 = self._c2
|
||||
_c_further = self._c_further
|
||||
|
||||
# skip and head
|
||||
_c_further = rearrange(_c_further, 'b t c h w -> (b t) c h w', b=B, t=T)
|
||||
_c2 = rearrange(_c2, 'b t c h w -> (b t) c h w', b=B, t=T)
|
||||
|
||||
outframe = self.ffm2(_c_further, _c2)
|
||||
|
||||
tracks_uv = tracks_uvd[...,:2].clone()
|
||||
track_feature = scatter_multiscale_fast(tracks_uv/2, tracks, outframe.shape[-2], outframe.shape[-1], kernel_sizes=[1, 3, 5])
|
||||
# visualize track_feature as video
|
||||
# import cv2
|
||||
# import imageio
|
||||
# import os
|
||||
# BT, C, H, W = outframe.shape
|
||||
# track_feature_vis = track_feature.view(B, T, 3, H, W).float().detach().cpu().numpy()
|
||||
# track_feature_vis = track_feature_vis.transpose(0,1,3,4,2)
|
||||
# track_feature_vis = (track_feature_vis - track_feature_vis.min()) / (track_feature_vis.max() - track_feature_vis.min() + 1e-6)
|
||||
# track_feature_vis = (track_feature_vis * 255).astype(np.uint8)
|
||||
# imgs =(imgs.detach() + 1) * 127.5
|
||||
# vis_track.visualize(video=imgs, tracks=tracks_uv, filename="test")
|
||||
# for b in range(B):
|
||||
# frames = []
|
||||
# for t in range(T):
|
||||
# frame = track_feature_vis[b,t]
|
||||
# frame = cv2.applyColorMap(frame[...,0], cv2.COLORMAP_JET)
|
||||
# frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
# frames.append(frame)
|
||||
# # Save as gif
|
||||
# imageio.mimsave(f'track_feature_b{b}.gif', frames, duration=0.1)
|
||||
# import pdb; pdb.set_trace()
|
||||
track_feature = rearrange(track_feature, 'b t c h w -> (b t) c h w')
|
||||
track_feature = self.proj_track(track_feature)
|
||||
outframe = self.ffm1(edge_feat1 + track_feature,outframe)
|
||||
outframe = self.ffm0(edge_feat,outframe)
|
||||
outframe = self.AO(outframe)
|
||||
|
||||
return outframe
|
||||
|
||||
def reset_success(self):
|
||||
self._c2 = None
|
||||
self._c_further = None
|
||||
Loading…
x
Reference in New Issue
Block a user