first-commit
This commit is contained in:
parent
b930b1bbd4
commit
4bd03c72f3
@ -1,472 +0,0 @@
|
|||||||
# ---------------------------------------------------------------
|
|
||||||
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
|
|
||||||
#
|
|
||||||
# This work is licensed under the NVIDIA Source Code License
|
|
||||||
# ---------------------------------------------------------------
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from timm.layers import DropPath, to_2tuple, trunc_normal_
|
|
||||||
from timm.models import register_model
|
|
||||||
from timm.models.vision_transformer import _cfg
|
|
||||||
import math
|
|
||||||
|
|
||||||
|
|
||||||
class Mlp(nn.Module):
|
|
||||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
|
||||||
super().__init__()
|
|
||||||
out_features = out_features or in_features
|
|
||||||
hidden_features = hidden_features or in_features
|
|
||||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
|
||||||
self.dwconv = DWConv(hidden_features)
|
|
||||||
self.act = act_layer()
|
|
||||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
|
||||||
self.drop = nn.Dropout(drop)
|
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
def _init_weights(self, m):
|
|
||||||
if isinstance(m, nn.Linear):
|
|
||||||
trunc_normal_(m.weight, std=.02)
|
|
||||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.LayerNorm):
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
nn.init.constant_(m.weight, 1.0)
|
|
||||||
elif isinstance(m, nn.Conv2d):
|
|
||||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
|
||||||
fan_out //= m.groups
|
|
||||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.zero_()
|
|
||||||
|
|
||||||
def forward(self, x, H, W):
|
|
||||||
x = self.fc1(x)
|
|
||||||
x = self.dwconv(x, H, W)
|
|
||||||
x = self.act(x)
|
|
||||||
x = self.drop(x)
|
|
||||||
x = self.fc2(x)
|
|
||||||
x = self.drop(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
|
||||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
|
|
||||||
super().__init__()
|
|
||||||
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
|
|
||||||
|
|
||||||
self.dim = dim
|
|
||||||
self.num_heads = num_heads
|
|
||||||
head_dim = dim // num_heads
|
|
||||||
self.scale = qk_scale or head_dim ** -0.5
|
|
||||||
|
|
||||||
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
|
||||||
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
|
||||||
self.attn_drop = nn.Dropout(attn_drop)
|
|
||||||
self.proj = nn.Linear(dim, dim)
|
|
||||||
self.proj_drop = nn.Dropout(proj_drop)
|
|
||||||
|
|
||||||
self.sr_ratio = sr_ratio
|
|
||||||
if sr_ratio > 1:
|
|
||||||
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
|
|
||||||
self.norm = nn.LayerNorm(dim)
|
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
def _init_weights(self, m):
|
|
||||||
if isinstance(m, nn.Linear):
|
|
||||||
trunc_normal_(m.weight, std=.02)
|
|
||||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.LayerNorm):
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
nn.init.constant_(m.weight, 1.0)
|
|
||||||
elif isinstance(m, nn.Conv2d):
|
|
||||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
|
||||||
fan_out //= m.groups
|
|
||||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.zero_()
|
|
||||||
|
|
||||||
def forward(self, x, H, W):
|
|
||||||
B, N, C = x.shape
|
|
||||||
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
|
||||||
|
|
||||||
if self.sr_ratio > 1:
|
|
||||||
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
|
|
||||||
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
|
|
||||||
x_ = self.norm(x_)
|
|
||||||
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
||||||
else:
|
|
||||||
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
||||||
k, v = kv[0], kv[1]
|
|
||||||
|
|
||||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
|
||||||
attn = attn.softmax(dim=-1)
|
|
||||||
attn = self.attn_drop(attn)
|
|
||||||
|
|
||||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
|
||||||
x = self.proj(x)
|
|
||||||
x = self.proj_drop(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Block(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
|
||||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
|
|
||||||
super().__init__()
|
|
||||||
self.norm1 = norm_layer(dim)
|
|
||||||
self.attn = Attention(
|
|
||||||
dim,
|
|
||||||
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
||||||
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
|
|
||||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
|
||||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
||||||
self.norm2 = norm_layer(dim)
|
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
||||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
def _init_weights(self, m):
|
|
||||||
if isinstance(m, nn.Linear):
|
|
||||||
trunc_normal_(m.weight, std=.02)
|
|
||||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.LayerNorm):
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
nn.init.constant_(m.weight, 1.0)
|
|
||||||
elif isinstance(m, nn.Conv2d):
|
|
||||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
|
||||||
fan_out //= m.groups
|
|
||||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.zero_()
|
|
||||||
|
|
||||||
def forward(self, x, H, W):
|
|
||||||
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
|
|
||||||
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class OverlapPatchEmbed(nn.Module):
|
|
||||||
""" Image to Patch Embedding
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
|
|
||||||
super().__init__()
|
|
||||||
img_size = to_2tuple(img_size)
|
|
||||||
patch_size = to_2tuple(patch_size)
|
|
||||||
|
|
||||||
self.img_size = img_size
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
|
||||||
self.num_patches = self.H * self.W
|
|
||||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
|
|
||||||
padding=(patch_size[0] // 2, patch_size[1] // 2))
|
|
||||||
self.norm = nn.LayerNorm(embed_dim)
|
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
def _init_weights(self, m):
|
|
||||||
if isinstance(m, nn.Linear):
|
|
||||||
trunc_normal_(m.weight, std=.02)
|
|
||||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.LayerNorm):
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
nn.init.constant_(m.weight, 1.0)
|
|
||||||
elif isinstance(m, nn.Conv2d):
|
|
||||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
|
||||||
fan_out //= m.groups
|
|
||||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.zero_()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.proj(x)
|
|
||||||
_, _, H, W = x.shape
|
|
||||||
x = x.flatten(2).transpose(1, 2)
|
|
||||||
x = self.norm(x)
|
|
||||||
|
|
||||||
return x, H, W
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OverlapPatchEmbed43(nn.Module):
|
|
||||||
""" Image to Patch Embedding
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
|
|
||||||
super().__init__()
|
|
||||||
img_size = to_2tuple(img_size)
|
|
||||||
patch_size = to_2tuple(patch_size)
|
|
||||||
|
|
||||||
self.img_size = img_size
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
|
||||||
self.num_patches = self.H * self.W
|
|
||||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
|
|
||||||
padding=(patch_size[0] // 2, patch_size[1] // 2))
|
|
||||||
self.norm = nn.LayerNorm(embed_dim)
|
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
def _init_weights(self, m):
|
|
||||||
if isinstance(m, nn.Linear):
|
|
||||||
trunc_normal_(m.weight, std=.02)
|
|
||||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.LayerNorm):
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
nn.init.constant_(m.weight, 1.0)
|
|
||||||
elif isinstance(m, nn.Conv2d):
|
|
||||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
|
||||||
fan_out //= m.groups
|
|
||||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.zero_()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if x.shape[1]==4:
|
|
||||||
x = self.proj_4c(x)
|
|
||||||
else:
|
|
||||||
x = self.proj(x)
|
|
||||||
_, _, H, W = x.shape
|
|
||||||
x = x.flatten(2).transpose(1, 2)
|
|
||||||
x = self.norm(x)
|
|
||||||
|
|
||||||
return x, H, W
|
|
||||||
|
|
||||||
class MixVisionTransformer(nn.Module):
|
|
||||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
|
|
||||||
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
|
|
||||||
attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
|
|
||||||
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
|
|
||||||
super().__init__()
|
|
||||||
self.num_classes = num_classes
|
|
||||||
self.depths = depths
|
|
||||||
|
|
||||||
# patch_embed 43
|
|
||||||
self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
|
|
||||||
embed_dim=embed_dims[0])
|
|
||||||
self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
|
|
||||||
embed_dim=embed_dims[1])
|
|
||||||
self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
|
|
||||||
embed_dim=embed_dims[2])
|
|
||||||
self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
|
|
||||||
embed_dim=embed_dims[3])
|
|
||||||
|
|
||||||
# transformer encoder
|
|
||||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
|
||||||
cur = 0
|
|
||||||
self.block1 = nn.ModuleList([Block(
|
|
||||||
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
|
||||||
sr_ratio=sr_ratios[0])
|
|
||||||
for i in range(depths[0])])
|
|
||||||
self.norm1 = norm_layer(embed_dims[0])
|
|
||||||
|
|
||||||
cur += depths[0]
|
|
||||||
self.block2 = nn.ModuleList([Block(
|
|
||||||
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
|
||||||
sr_ratio=sr_ratios[1])
|
|
||||||
for i in range(depths[1])])
|
|
||||||
self.norm2 = norm_layer(embed_dims[1])
|
|
||||||
|
|
||||||
cur += depths[1]
|
|
||||||
self.block3 = nn.ModuleList([Block(
|
|
||||||
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
|
||||||
sr_ratio=sr_ratios[2])
|
|
||||||
for i in range(depths[2])])
|
|
||||||
self.norm3 = norm_layer(embed_dims[2])
|
|
||||||
|
|
||||||
cur += depths[2]
|
|
||||||
self.block4 = nn.ModuleList([Block(
|
|
||||||
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
|
||||||
sr_ratio=sr_ratios[3])
|
|
||||||
for i in range(depths[3])])
|
|
||||||
self.norm4 = norm_layer(embed_dims[3])
|
|
||||||
|
|
||||||
# classification head
|
|
||||||
# self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
|
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
def _init_weights(self, m):
|
|
||||||
if isinstance(m, nn.Linear):
|
|
||||||
trunc_normal_(m.weight, std=.02)
|
|
||||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.LayerNorm):
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
nn.init.constant_(m.weight, 1.0)
|
|
||||||
elif isinstance(m, nn.Conv2d):
|
|
||||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
|
||||||
fan_out //= m.groups
|
|
||||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.zero_()
|
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
|
||||||
if isinstance(pretrained, str):
|
|
||||||
logger = get_root_logger()
|
|
||||||
load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
|
|
||||||
|
|
||||||
def reset_drop_path(self, drop_path_rate):
|
|
||||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
|
|
||||||
cur = 0
|
|
||||||
for i in range(self.depths[0]):
|
|
||||||
self.block1[i].drop_path.drop_prob = dpr[cur + i]
|
|
||||||
|
|
||||||
cur += self.depths[0]
|
|
||||||
for i in range(self.depths[1]):
|
|
||||||
self.block2[i].drop_path.drop_prob = dpr[cur + i]
|
|
||||||
|
|
||||||
cur += self.depths[1]
|
|
||||||
for i in range(self.depths[2]):
|
|
||||||
self.block3[i].drop_path.drop_prob = dpr[cur + i]
|
|
||||||
|
|
||||||
cur += self.depths[2]
|
|
||||||
for i in range(self.depths[3]):
|
|
||||||
self.block4[i].drop_path.drop_prob = dpr[cur + i]
|
|
||||||
|
|
||||||
def freeze_patch_emb(self):
|
|
||||||
self.patch_embed1.requires_grad = False
|
|
||||||
|
|
||||||
@torch.jit.ignore
|
|
||||||
def no_weight_decay(self):
|
|
||||||
return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
|
|
||||||
|
|
||||||
def get_classifier(self):
|
|
||||||
return self.head
|
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool=''):
|
|
||||||
self.num_classes = num_classes
|
|
||||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
|
||||||
|
|
||||||
def forward_features(self, x):
|
|
||||||
B = x.shape[0]
|
|
||||||
outs = []
|
|
||||||
|
|
||||||
# stage 1
|
|
||||||
x, H, W = self.patch_embed1(x)
|
|
||||||
for i, blk in enumerate(self.block1):
|
|
||||||
x = blk(x, H, W)
|
|
||||||
x = self.norm1(x)
|
|
||||||
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
|
||||||
outs.append(x)
|
|
||||||
|
|
||||||
# stage 2
|
|
||||||
x, H, W = self.patch_embed2(x)
|
|
||||||
for i, blk in enumerate(self.block2):
|
|
||||||
x = blk(x, H, W)
|
|
||||||
x = self.norm2(x)
|
|
||||||
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
|
||||||
outs.append(x)
|
|
||||||
|
|
||||||
# stage 3
|
|
||||||
x, H, W = self.patch_embed3(x)
|
|
||||||
for i, blk in enumerate(self.block3):
|
|
||||||
x = blk(x, H, W)
|
|
||||||
x = self.norm3(x)
|
|
||||||
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
|
||||||
outs.append(x)
|
|
||||||
|
|
||||||
# stage 4
|
|
||||||
x, H, W = self.patch_embed4(x)
|
|
||||||
for i, blk in enumerate(self.block4):
|
|
||||||
x = blk(x, H, W)
|
|
||||||
x = self.norm4(x)
|
|
||||||
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
|
||||||
outs.append(x)
|
|
||||||
|
|
||||||
return outs
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if x.dim() == 5:
|
|
||||||
x = x.reshape(x.shape[0]*x.shape[1],x.shape[2],x.shape[3],x.shape[4])
|
|
||||||
x = self.forward_features(x)
|
|
||||||
# x = self.head(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class DWConv(nn.Module):
|
|
||||||
def __init__(self, dim=768):
|
|
||||||
super(DWConv, self).__init__()
|
|
||||||
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
|
|
||||||
|
|
||||||
def forward(self, x, H, W):
|
|
||||||
B, N, C = x.shape
|
|
||||||
x = x.transpose(1, 2).view(B, C, H, W)
|
|
||||||
x = self.dwconv(x)
|
|
||||||
x = x.flatten(2).transpose(1, 2)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#@BACKBONES.register_module()
|
|
||||||
class mit_b0(MixVisionTransformer):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super(mit_b0, self).__init__(
|
|
||||||
patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
|
||||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
|
|
||||||
drop_rate=0.0, drop_path_rate=0.1)
|
|
||||||
|
|
||||||
|
|
||||||
#@BACKBONES.register_module()
|
|
||||||
class mit_b1(MixVisionTransformer):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super(mit_b1, self).__init__(
|
|
||||||
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
|
||||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
|
|
||||||
drop_rate=0.0, drop_path_rate=0.1)
|
|
||||||
|
|
||||||
|
|
||||||
#@BACKBONES.register_module()
|
|
||||||
class mit_b2(MixVisionTransformer):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super(mit_b2, self).__init__(
|
|
||||||
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
|
||||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
|
|
||||||
drop_rate=0.0, drop_path_rate=0.1)
|
|
||||||
|
|
||||||
|
|
||||||
#@BACKBONES.register_module()
|
|
||||||
class mit_b3(MixVisionTransformer):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super(mit_b3, self).__init__(
|
|
||||||
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
|
||||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
|
|
||||||
drop_rate=0.0, drop_path_rate=0.1)
|
|
||||||
|
|
||||||
|
|
||||||
#@BACKBONES.register_module()
|
|
||||||
class mit_b4(MixVisionTransformer):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super(mit_b4, self).__init__(
|
|
||||||
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
|
||||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
|
|
||||||
drop_rate=0.0, drop_path_rate=0.1)
|
|
||||||
|
|
||||||
|
|
||||||
#@BACKBONES.register_module()
|
|
||||||
class mit_b5(MixVisionTransformer):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super(mit_b5, self).__init__(
|
|
||||||
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
|
||||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
|
|
||||||
drop_rate=0.0, drop_path_rate=0.1)
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,619 +0,0 @@
|
|||||||
from abc import ABCMeta, abstractmethod
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
# from mmcv.cnn import normal_init
|
|
||||||
# from mmcv.runner import auto_fp16, force_fp32
|
|
||||||
|
|
||||||
# from mmseg.core import build_pixel_sampler
|
|
||||||
# from mmseg.ops import resize
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
|
|
||||||
"""Base class for BaseDecodeHead.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int|Sequence[int]): Input channels.
|
|
||||||
channels (int): Channels after modules, before conv_seg.
|
|
||||||
num_classes (int): Number of classes.
|
|
||||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
|
||||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
|
||||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
|
||||||
act_cfg (dict): Config of activation layers.
|
|
||||||
Default: dict(type='ReLU')
|
|
||||||
in_index (int|Sequence[int]): Input feature index. Default: -1
|
|
||||||
input_transform (str|None): Transformation type of input features.
|
|
||||||
Options: 'resize_concat', 'multiple_select', None.
|
|
||||||
'resize_concat': Multiple feature maps will be resize to the
|
|
||||||
same size as first one and than concat together.
|
|
||||||
Usually used in FCN head of HRNet.
|
|
||||||
'multiple_select': Multiple feature maps will be bundle into
|
|
||||||
a list and passed into decode head.
|
|
||||||
None: Only one select feature map is allowed.
|
|
||||||
Default: None.
|
|
||||||
loss_decode (dict): Config of decode loss.
|
|
||||||
Default: dict(type='CrossEntropyLoss').
|
|
||||||
ignore_index (int | None): The label index to be ignored. When using
|
|
||||||
masked BCE loss, ignore_index should be set to None. Default: 255
|
|
||||||
sampler (dict|None): The config of segmentation map sampler.
|
|
||||||
Default: None.
|
|
||||||
align_corners (bool): align_corners argument of F.interpolate.
|
|
||||||
Default: False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
in_channels,
|
|
||||||
channels,
|
|
||||||
*,
|
|
||||||
num_classes,
|
|
||||||
dropout_ratio=0.1,
|
|
||||||
conv_cfg=None,
|
|
||||||
norm_cfg=None,
|
|
||||||
act_cfg=dict(type='ReLU'),
|
|
||||||
in_index=-1,
|
|
||||||
input_transform=None,
|
|
||||||
loss_decode=dict(
|
|
||||||
type='CrossEntropyLoss',
|
|
||||||
use_sigmoid=False,
|
|
||||||
loss_weight=1.0),
|
|
||||||
decoder_params=None,
|
|
||||||
ignore_index=255,
|
|
||||||
sampler=None,
|
|
||||||
align_corners=False):
|
|
||||||
super(BaseDecodeHead, self).__init__()
|
|
||||||
self._init_inputs(in_channels, in_index, input_transform)
|
|
||||||
self.channels = channels
|
|
||||||
self.num_classes = num_classes
|
|
||||||
self.dropout_ratio = dropout_ratio
|
|
||||||
self.conv_cfg = conv_cfg
|
|
||||||
self.norm_cfg = norm_cfg
|
|
||||||
self.act_cfg = act_cfg
|
|
||||||
self.in_index = in_index
|
|
||||||
self.ignore_index = ignore_index
|
|
||||||
self.align_corners = align_corners
|
|
||||||
|
|
||||||
if sampler is not None:
|
|
||||||
self.sampler = build_pixel_sampler(sampler, context=self)
|
|
||||||
else:
|
|
||||||
self.sampler = None
|
|
||||||
|
|
||||||
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
|
||||||
if dropout_ratio > 0:
|
|
||||||
self.dropout = nn.Dropout2d(dropout_ratio)
|
|
||||||
else:
|
|
||||||
self.dropout = None
|
|
||||||
self.fp16_enabled = False
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
"""Extra repr."""
|
|
||||||
s = f'input_transform={self.input_transform}, ' \
|
|
||||||
f'ignore_index={self.ignore_index}, ' \
|
|
||||||
f'align_corners={self.align_corners}'
|
|
||||||
return s
|
|
||||||
|
|
||||||
def _init_inputs(self, in_channels, in_index, input_transform):
|
|
||||||
"""Check and initialize input transforms.
|
|
||||||
|
|
||||||
The in_channels, in_index and input_transform must match.
|
|
||||||
Specifically, when input_transform is None, only single feature map
|
|
||||||
will be selected. So in_channels and in_index must be of type int.
|
|
||||||
When input_transform
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int|Sequence[int]): Input channels.
|
|
||||||
in_index (int|Sequence[int]): Input feature index.
|
|
||||||
input_transform (str|None): Transformation type of input features.
|
|
||||||
Options: 'resize_concat', 'multiple_select', None.
|
|
||||||
'resize_concat': Multiple feature maps will be resize to the
|
|
||||||
same size as first one and than concat together.
|
|
||||||
Usually used in FCN head of HRNet.
|
|
||||||
'multiple_select': Multiple feature maps will be bundle into
|
|
||||||
a list and passed into decode head.
|
|
||||||
None: Only one select feature map is allowed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if input_transform is not None:
|
|
||||||
assert input_transform in ['resize_concat', 'multiple_select']
|
|
||||||
self.input_transform = input_transform
|
|
||||||
self.in_index = in_index
|
|
||||||
if input_transform is not None:
|
|
||||||
assert isinstance(in_channels, (list, tuple))
|
|
||||||
assert isinstance(in_index, (list, tuple))
|
|
||||||
assert len(in_channels) == len(in_index)
|
|
||||||
if input_transform == 'resize_concat':
|
|
||||||
self.in_channels = sum(in_channels)
|
|
||||||
else:
|
|
||||||
self.in_channels = in_channels
|
|
||||||
else:
|
|
||||||
assert isinstance(in_channels, int)
|
|
||||||
assert isinstance(in_index, int)
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
def init_weights(self):
|
|
||||||
"""Initialize weights of classification layer."""
|
|
||||||
normal_init(self.conv_seg, mean=0, std=0.01)
|
|
||||||
|
|
||||||
def _transform_inputs(self, inputs):
|
|
||||||
"""Transform inputs for decoder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs (list[Tensor]): List of multi-level img features.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: The transformed inputs
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.input_transform == 'resize_concat':
|
|
||||||
inputs = [inputs[i] for i in self.in_index]
|
|
||||||
upsampled_inputs = [
|
|
||||||
resize(
|
|
||||||
input=x,
|
|
||||||
size=inputs[0].shape[2:],
|
|
||||||
mode='bilinear',
|
|
||||||
align_corners=self.align_corners) for x in inputs
|
|
||||||
]
|
|
||||||
inputs = torch.cat(upsampled_inputs, dim=1)
|
|
||||||
elif self.input_transform == 'multiple_select':
|
|
||||||
inputs = [inputs[i] for i in self.in_index]
|
|
||||||
else:
|
|
||||||
inputs = inputs[self.in_index]
|
|
||||||
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
# @auto_fp16()
|
|
||||||
@abstractmethod
|
|
||||||
def forward(self, inputs):
|
|
||||||
"""Placeholder of forward function."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
|
|
||||||
"""Forward function for training.
|
|
||||||
Args:
|
|
||||||
inputs (list[Tensor]): List of multi-level img features.
|
|
||||||
img_metas (list[dict]): List of image info dict where each dict
|
|
||||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
|
||||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
|
||||||
For details on the values of these keys see
|
|
||||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
|
||||||
gt_semantic_seg (Tensor): Semantic segmentation masks
|
|
||||||
used if the architecture supports semantic segmentation task.
|
|
||||||
train_cfg (dict): The training config.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[str, Tensor]: a dictionary of loss components
|
|
||||||
"""
|
|
||||||
seg_logits = self.forward(inputs)
|
|
||||||
losses = self.losses(seg_logits, gt_semantic_seg)
|
|
||||||
return losses
|
|
||||||
|
|
||||||
def forward_test(self, inputs, img_metas, test_cfg):
|
|
||||||
"""Forward function for testing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs (list[Tensor]): List of multi-level img features.
|
|
||||||
img_metas (list[dict]): List of image info dict where each dict
|
|
||||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
|
||||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
|
||||||
For details on the values of these keys see
|
|
||||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
|
||||||
test_cfg (dict): The testing config.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Output segmentation map.
|
|
||||||
"""
|
|
||||||
return self.forward(inputs)
|
|
||||||
|
|
||||||
def cls_seg(self, feat):
|
|
||||||
"""Classify each pixel."""
|
|
||||||
if self.dropout is not None:
|
|
||||||
feat = self.dropout(feat)
|
|
||||||
output = self.conv_seg(feat)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDecodeHead_clips(nn.Module, metaclass=ABCMeta):
|
|
||||||
"""Base class for BaseDecodeHead_clips.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int|Sequence[int]): Input channels.
|
|
||||||
channels (int): Channels after modules, before conv_seg.
|
|
||||||
num_classes (int): Number of classes.
|
|
||||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
|
||||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
|
||||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
|
||||||
act_cfg (dict): Config of activation layers.
|
|
||||||
Default: dict(type='ReLU')
|
|
||||||
in_index (int|Sequence[int]): Input feature index. Default: -1
|
|
||||||
input_transform (str|None): Transformation type of input features.
|
|
||||||
Options: 'resize_concat', 'multiple_select', None.
|
|
||||||
'resize_concat': Multiple feature maps will be resize to the
|
|
||||||
same size as first one and than concat together.
|
|
||||||
Usually used in FCN head of HRNet.
|
|
||||||
'multiple_select': Multiple feature maps will be bundle into
|
|
||||||
a list and passed into decode head.
|
|
||||||
None: Only one select feature map is allowed.
|
|
||||||
Default: None.
|
|
||||||
loss_decode (dict): Config of decode loss.
|
|
||||||
Default: dict(type='CrossEntropyLoss').
|
|
||||||
ignore_index (int | None): The label index to be ignored. When using
|
|
||||||
masked BCE loss, ignore_index should be set to None. Default: 255
|
|
||||||
sampler (dict|None): The config of segmentation map sampler.
|
|
||||||
Default: None.
|
|
||||||
align_corners (bool): align_corners argument of F.interpolate.
|
|
||||||
Default: False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
in_channels,
|
|
||||||
channels,
|
|
||||||
*,
|
|
||||||
num_classes,
|
|
||||||
dropout_ratio=0.1,
|
|
||||||
conv_cfg=None,
|
|
||||||
norm_cfg=None,
|
|
||||||
act_cfg=dict(type='ReLU'),
|
|
||||||
in_index=-1,
|
|
||||||
input_transform=None,
|
|
||||||
loss_decode=dict(
|
|
||||||
type='CrossEntropyLoss',
|
|
||||||
use_sigmoid=False,
|
|
||||||
loss_weight=1.0),
|
|
||||||
decoder_params=None,
|
|
||||||
ignore_index=255,
|
|
||||||
sampler=None,
|
|
||||||
align_corners=False,
|
|
||||||
num_clips=5):
|
|
||||||
super(BaseDecodeHead_clips, self).__init__()
|
|
||||||
self._init_inputs(in_channels, in_index, input_transform)
|
|
||||||
self.channels = channels
|
|
||||||
self.num_classes = num_classes
|
|
||||||
self.dropout_ratio = dropout_ratio
|
|
||||||
self.conv_cfg = conv_cfg
|
|
||||||
self.norm_cfg = norm_cfg
|
|
||||||
self.act_cfg = act_cfg
|
|
||||||
self.in_index = in_index
|
|
||||||
self.ignore_index = ignore_index
|
|
||||||
self.align_corners = align_corners
|
|
||||||
self.num_clips=num_clips
|
|
||||||
|
|
||||||
if sampler is not None:
|
|
||||||
self.sampler = build_pixel_sampler(sampler, context=self)
|
|
||||||
else:
|
|
||||||
self.sampler = None
|
|
||||||
|
|
||||||
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
|
||||||
if dropout_ratio > 0:
|
|
||||||
self.dropout = nn.Dropout2d(dropout_ratio)
|
|
||||||
else:
|
|
||||||
self.dropout = None
|
|
||||||
self.fp16_enabled = False
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
"""Extra repr."""
|
|
||||||
s = f'input_transform={self.input_transform}, ' \
|
|
||||||
f'ignore_index={self.ignore_index}, ' \
|
|
||||||
f'align_corners={self.align_corners}'
|
|
||||||
return s
|
|
||||||
|
|
||||||
def _init_inputs(self, in_channels, in_index, input_transform):
|
|
||||||
"""Check and initialize input transforms.
|
|
||||||
|
|
||||||
The in_channels, in_index and input_transform must match.
|
|
||||||
Specifically, when input_transform is None, only single feature map
|
|
||||||
will be selected. So in_channels and in_index must be of type int.
|
|
||||||
When input_transform
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int|Sequence[int]): Input channels.
|
|
||||||
in_index (int|Sequence[int]): Input feature index.
|
|
||||||
input_transform (str|None): Transformation type of input features.
|
|
||||||
Options: 'resize_concat', 'multiple_select', None.
|
|
||||||
'resize_concat': Multiple feature maps will be resize to the
|
|
||||||
same size as first one and than concat together.
|
|
||||||
Usually used in FCN head of HRNet.
|
|
||||||
'multiple_select': Multiple feature maps will be bundle into
|
|
||||||
a list and passed into decode head.
|
|
||||||
None: Only one select feature map is allowed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if input_transform is not None:
|
|
||||||
assert input_transform in ['resize_concat', 'multiple_select']
|
|
||||||
self.input_transform = input_transform
|
|
||||||
self.in_index = in_index
|
|
||||||
if input_transform is not None:
|
|
||||||
assert isinstance(in_channels, (list, tuple))
|
|
||||||
assert isinstance(in_index, (list, tuple))
|
|
||||||
assert len(in_channels) == len(in_index)
|
|
||||||
if input_transform == 'resize_concat':
|
|
||||||
self.in_channels = sum(in_channels)
|
|
||||||
else:
|
|
||||||
self.in_channels = in_channels
|
|
||||||
else:
|
|
||||||
assert isinstance(in_channels, int)
|
|
||||||
assert isinstance(in_index, int)
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
def init_weights(self):
|
|
||||||
"""Initialize weights of classification layer."""
|
|
||||||
normal_init(self.conv_seg, mean=0, std=0.01)
|
|
||||||
|
|
||||||
def _transform_inputs(self, inputs):
|
|
||||||
"""Transform inputs for decoder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs (list[Tensor]): List of multi-level img features.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: The transformed inputs
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.input_transform == 'resize_concat':
|
|
||||||
inputs = [inputs[i] for i in self.in_index]
|
|
||||||
upsampled_inputs = [
|
|
||||||
resize(
|
|
||||||
input=x,
|
|
||||||
size=inputs[0].shape[2:],
|
|
||||||
mode='bilinear',
|
|
||||||
align_corners=self.align_corners) for x in inputs
|
|
||||||
]
|
|
||||||
inputs = torch.cat(upsampled_inputs, dim=1)
|
|
||||||
elif self.input_transform == 'multiple_select':
|
|
||||||
inputs = [inputs[i] for i in self.in_index]
|
|
||||||
else:
|
|
||||||
inputs = inputs[self.in_index]
|
|
||||||
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
# @auto_fp16()
|
|
||||||
@abstractmethod
|
|
||||||
def forward(self, inputs):
|
|
||||||
"""Placeholder of forward function."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips):
|
|
||||||
"""Forward function for training.
|
|
||||||
Args:
|
|
||||||
inputs (list[Tensor]): List of multi-level img features.
|
|
||||||
img_metas (list[dict]): List of image info dict where each dict
|
|
||||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
|
||||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
|
||||||
For details on the values of these keys see
|
|
||||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
|
||||||
gt_semantic_seg (Tensor): Semantic segmentation masks
|
|
||||||
used if the architecture supports semantic segmentation task.
|
|
||||||
train_cfg (dict): The training config.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[str, Tensor]: a dictionary of loss components
|
|
||||||
"""
|
|
||||||
seg_logits = self.forward(inputs,batch_size, num_clips)
|
|
||||||
losses = self.losses(seg_logits, gt_semantic_seg)
|
|
||||||
return losses
|
|
||||||
|
|
||||||
def forward_test(self, inputs, img_metas, test_cfg, batch_size, num_clips):
|
|
||||||
"""Forward function for testing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs (list[Tensor]): List of multi-level img features.
|
|
||||||
img_metas (list[dict]): List of image info dict where each dict
|
|
||||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
|
||||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
|
||||||
For details on the values of these keys see
|
|
||||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
|
||||||
test_cfg (dict): The testing config.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Output segmentation map.
|
|
||||||
"""
|
|
||||||
return self.forward(inputs, batch_size, num_clips)
|
|
||||||
|
|
||||||
def cls_seg(self, feat):
|
|
||||||
"""Classify each pixel."""
|
|
||||||
if self.dropout is not None:
|
|
||||||
feat = self.dropout(feat)
|
|
||||||
output = self.conv_seg(feat)
|
|
||||||
return output
|
|
||||||
|
|
||||||
class BaseDecodeHead_clips_flow(nn.Module, metaclass=ABCMeta):
|
|
||||||
"""Base class for BaseDecodeHead_clips_flow.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int|Sequence[int]): Input channels.
|
|
||||||
channels (int): Channels after modules, before conv_seg.
|
|
||||||
num_classes (int): Number of classes.
|
|
||||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
|
||||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
|
||||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
|
||||||
act_cfg (dict): Config of activation layers.
|
|
||||||
Default: dict(type='ReLU')
|
|
||||||
in_index (int|Sequence[int]): Input feature index. Default: -1
|
|
||||||
input_transform (str|None): Transformation type of input features.
|
|
||||||
Options: 'resize_concat', 'multiple_select', None.
|
|
||||||
'resize_concat': Multiple feature maps will be resize to the
|
|
||||||
same size as first one and than concat together.
|
|
||||||
Usually used in FCN head of HRNet.
|
|
||||||
'multiple_select': Multiple feature maps will be bundle into
|
|
||||||
a list and passed into decode head.
|
|
||||||
None: Only one select feature map is allowed.
|
|
||||||
Default: None.
|
|
||||||
loss_decode (dict): Config of decode loss.
|
|
||||||
Default: dict(type='CrossEntropyLoss').
|
|
||||||
ignore_index (int | None): The label index to be ignored. When using
|
|
||||||
masked BCE loss, ignore_index should be set to None. Default: 255
|
|
||||||
sampler (dict|None): The config of segmentation map sampler.
|
|
||||||
Default: None.
|
|
||||||
align_corners (bool): align_corners argument of F.interpolate.
|
|
||||||
Default: False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
in_channels,
|
|
||||||
channels,
|
|
||||||
*,
|
|
||||||
num_classes,
|
|
||||||
dropout_ratio=0.1,
|
|
||||||
conv_cfg=None,
|
|
||||||
norm_cfg=None,
|
|
||||||
act_cfg=dict(type='ReLU'),
|
|
||||||
in_index=-1,
|
|
||||||
input_transform=None,
|
|
||||||
loss_decode=dict(
|
|
||||||
type='CrossEntropyLoss',
|
|
||||||
use_sigmoid=False,
|
|
||||||
loss_weight=1.0),
|
|
||||||
decoder_params=None,
|
|
||||||
ignore_index=255,
|
|
||||||
sampler=None,
|
|
||||||
align_corners=False,
|
|
||||||
num_clips=5):
|
|
||||||
super(BaseDecodeHead_clips_flow, self).__init__()
|
|
||||||
self._init_inputs(in_channels, in_index, input_transform)
|
|
||||||
self.channels = channels
|
|
||||||
self.num_classes = num_classes
|
|
||||||
self.dropout_ratio = dropout_ratio
|
|
||||||
self.conv_cfg = conv_cfg
|
|
||||||
self.norm_cfg = norm_cfg
|
|
||||||
self.act_cfg = act_cfg
|
|
||||||
self.in_index = in_index
|
|
||||||
self.ignore_index = ignore_index
|
|
||||||
self.align_corners = align_corners
|
|
||||||
self.num_clips=num_clips
|
|
||||||
|
|
||||||
if sampler is not None:
|
|
||||||
self.sampler = build_pixel_sampler(sampler, context=self)
|
|
||||||
else:
|
|
||||||
self.sampler = None
|
|
||||||
|
|
||||||
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
|
||||||
if dropout_ratio > 0:
|
|
||||||
self.dropout = nn.Dropout2d(dropout_ratio)
|
|
||||||
else:
|
|
||||||
self.dropout = None
|
|
||||||
self.fp16_enabled = False
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
"""Extra repr."""
|
|
||||||
s = f'input_transform={self.input_transform}, ' \
|
|
||||||
f'ignore_index={self.ignore_index}, ' \
|
|
||||||
f'align_corners={self.align_corners}'
|
|
||||||
return s
|
|
||||||
|
|
||||||
def _init_inputs(self, in_channels, in_index, input_transform):
|
|
||||||
"""Check and initialize input transforms.
|
|
||||||
|
|
||||||
The in_channels, in_index and input_transform must match.
|
|
||||||
Specifically, when input_transform is None, only single feature map
|
|
||||||
will be selected. So in_channels and in_index must be of type int.
|
|
||||||
When input_transform
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int|Sequence[int]): Input channels.
|
|
||||||
in_index (int|Sequence[int]): Input feature index.
|
|
||||||
input_transform (str|None): Transformation type of input features.
|
|
||||||
Options: 'resize_concat', 'multiple_select', None.
|
|
||||||
'resize_concat': Multiple feature maps will be resize to the
|
|
||||||
same size as first one and than concat together.
|
|
||||||
Usually used in FCN head of HRNet.
|
|
||||||
'multiple_select': Multiple feature maps will be bundle into
|
|
||||||
a list and passed into decode head.
|
|
||||||
None: Only one select feature map is allowed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if input_transform is not None:
|
|
||||||
assert input_transform in ['resize_concat', 'multiple_select']
|
|
||||||
self.input_transform = input_transform
|
|
||||||
self.in_index = in_index
|
|
||||||
if input_transform is not None:
|
|
||||||
assert isinstance(in_channels, (list, tuple))
|
|
||||||
assert isinstance(in_index, (list, tuple))
|
|
||||||
assert len(in_channels) == len(in_index)
|
|
||||||
if input_transform == 'resize_concat':
|
|
||||||
self.in_channels = sum(in_channels)
|
|
||||||
else:
|
|
||||||
self.in_channels = in_channels
|
|
||||||
else:
|
|
||||||
assert isinstance(in_channels, int)
|
|
||||||
assert isinstance(in_index, int)
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
def init_weights(self):
|
|
||||||
"""Initialize weights of classification layer."""
|
|
||||||
normal_init(self.conv_seg, mean=0, std=0.01)
|
|
||||||
|
|
||||||
def _transform_inputs(self, inputs):
|
|
||||||
"""Transform inputs for decoder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs (list[Tensor]): List of multi-level img features.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: The transformed inputs
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.input_transform == 'resize_concat':
|
|
||||||
inputs = [inputs[i] for i in self.in_index]
|
|
||||||
upsampled_inputs = [
|
|
||||||
resize(
|
|
||||||
input=x,
|
|
||||||
size=inputs[0].shape[2:],
|
|
||||||
mode='bilinear',
|
|
||||||
align_corners=self.align_corners) for x in inputs
|
|
||||||
]
|
|
||||||
inputs = torch.cat(upsampled_inputs, dim=1)
|
|
||||||
elif self.input_transform == 'multiple_select':
|
|
||||||
inputs = [inputs[i] for i in self.in_index]
|
|
||||||
else:
|
|
||||||
inputs = inputs[self.in_index]
|
|
||||||
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
# @auto_fp16()
|
|
||||||
@abstractmethod
|
|
||||||
def forward(self, inputs):
|
|
||||||
"""Placeholder of forward function."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips,img=None):
|
|
||||||
"""Forward function for training.
|
|
||||||
Args:
|
|
||||||
inputs (list[Tensor]): List of multi-level img features.
|
|
||||||
img_metas (list[dict]): List of image info dict where each dict
|
|
||||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
|
||||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
|
||||||
For details on the values of these keys see
|
|
||||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
|
||||||
gt_semantic_seg (Tensor): Semantic segmentation masks
|
|
||||||
used if the architecture supports semantic segmentation task.
|
|
||||||
train_cfg (dict): The training config.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[str, Tensor]: a dictionary of loss components
|
|
||||||
"""
|
|
||||||
seg_logits = self.forward(inputs,batch_size, num_clips,img)
|
|
||||||
losses = self.losses(seg_logits, gt_semantic_seg)
|
|
||||||
return losses
|
|
||||||
|
|
||||||
def forward_test(self, inputs, img_metas, test_cfg, batch_size=None, num_clips=None, img=None):
|
|
||||||
"""Forward function for testing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs (list[Tensor]): List of multi-level img features.
|
|
||||||
img_metas (list[dict]): List of image info dict where each dict
|
|
||||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
|
||||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
|
||||||
For details on the values of these keys see
|
|
||||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
|
||||||
test_cfg (dict): The testing config.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Output segmentation map.
|
|
||||||
"""
|
|
||||||
return self.forward(inputs, batch_size, num_clips,img)
|
|
||||||
|
|
||||||
def cls_seg(self, feat):
|
|
||||||
"""Classify each pixel."""
|
|
||||||
if self.dropout is not None:
|
|
||||||
feat = self.dropout(feat)
|
|
||||||
output = self.conv_seg(feat)
|
|
||||||
return output
|
|
||||||
@ -1,115 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from models.monoD.depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed
|
|
||||||
from models.SpaTrackV2.models.depth_refiner.backbone import mit_b3
|
|
||||||
from models.SpaTrackV2.models.depth_refiner.stablizer import Stabilization_Network_Cross_Attention
|
|
||||||
from einops import rearrange
|
|
||||||
class TrackStablizer(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.backbone = mit_b3()
|
|
||||||
|
|
||||||
old_conv = self.backbone.patch_embed1.proj
|
|
||||||
new_conv = nn.Conv2d(old_conv.in_channels + 4, old_conv.out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding)
|
|
||||||
|
|
||||||
new_conv.weight[:, :3, :, :].data.copy_(old_conv.weight.clone())
|
|
||||||
self.backbone.patch_embed1.proj = new_conv
|
|
||||||
|
|
||||||
self.Track_Stabilizer = Stabilization_Network_Cross_Attention(in_channels=[64, 128, 320, 512],
|
|
||||||
in_index=[0, 1, 2, 3],
|
|
||||||
feature_strides=[4, 8, 16, 32],
|
|
||||||
channels=128,
|
|
||||||
dropout_ratio=0.1,
|
|
||||||
num_classes=1,
|
|
||||||
align_corners=False,
|
|
||||||
decoder_params=dict(embed_dim=256, depths=4),
|
|
||||||
num_clips=16,
|
|
||||||
norm_cfg = dict(type='SyncBN', requires_grad=True))
|
|
||||||
|
|
||||||
self.edge_conv = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=64, kernel_size=3, padding=1, stride=1, bias=True),\
|
|
||||||
nn.ReLU(inplace=True))
|
|
||||||
self.edge_conv1 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2, bias=True),\
|
|
||||||
nn.ReLU(inplace=True))
|
|
||||||
self.success = False
|
|
||||||
self.x = None
|
|
||||||
|
|
||||||
def buffer_forward(self, inputs, num_clips=16):
|
|
||||||
"""
|
|
||||||
buffer forward for getting the pointmap and image features
|
|
||||||
"""
|
|
||||||
B, T, C, H, W = inputs.shape
|
|
||||||
self.x = self.backbone(inputs)
|
|
||||||
scale, shift = self.Track_Stabilizer.buffer_forward(self.x, num_clips=num_clips)
|
|
||||||
self.success = True
|
|
||||||
return scale, shift
|
|
||||||
|
|
||||||
def forward(self, inputs, tracks, tracks_uvd, num_clips=16, imgs=None, vis_track=None):
|
|
||||||
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
inputs: [B, T, C, H, W], RGB + PointMap + Mask
|
|
||||||
tracks: [B, T, N, 4], 3D tracks in camera coordinate + visibility
|
|
||||||
num_clips: int, number of clips to use
|
|
||||||
"""
|
|
||||||
B, T, C, H, W = inputs.shape
|
|
||||||
edge_feat = self.edge_conv(inputs.view(B*T,4,H,W))
|
|
||||||
edge_feat1 = self.edge_conv1(edge_feat)
|
|
||||||
|
|
||||||
if not self.success:
|
|
||||||
scale, shift = self.Track_Stabilizer.buffer_forward(self.x,num_clips=num_clips)
|
|
||||||
self.success = True
|
|
||||||
update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
|
|
||||||
else:
|
|
||||||
update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
|
|
||||||
|
|
||||||
return update
|
|
||||||
|
|
||||||
def reset_success(self):
|
|
||||||
self.success = False
|
|
||||||
self.x = None
|
|
||||||
self.Track_Stabilizer.reset_success()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Create test input tensors
|
|
||||||
batch_size = 1
|
|
||||||
seq_len = 16
|
|
||||||
channels = 7 # 3 for RGB + 3 for PointMap + 1 for Mask
|
|
||||||
height = 384
|
|
||||||
width = 512
|
|
||||||
|
|
||||||
# Create random input tensor with shape [B, T, C, H, W]
|
|
||||||
inputs = torch.randn(batch_size, seq_len, channels, height, width)
|
|
||||||
|
|
||||||
# Create random tracks
|
|
||||||
tracks = torch.randn(batch_size, seq_len, 1024, 4)
|
|
||||||
|
|
||||||
# Create random test images
|
|
||||||
test_imgs = torch.randn(batch_size, seq_len, 3, height, width)
|
|
||||||
|
|
||||||
# Initialize model and move to GPU
|
|
||||||
model = TrackStablizer().cuda()
|
|
||||||
|
|
||||||
# Move inputs to GPU and run forward pass
|
|
||||||
inputs = inputs.cuda()
|
|
||||||
tracks = tracks.cuda()
|
|
||||||
outputs = model.buffer_forward(inputs, num_clips=seq_len)
|
|
||||||
import time
|
|
||||||
start_time = time.time()
|
|
||||||
outputs = model(inputs, tracks, num_clips=seq_len)
|
|
||||||
end_time = time.time()
|
|
||||||
print(f"Time taken: {end_time - start_time} seconds")
|
|
||||||
import pdb; pdb.set_trace()
|
|
||||||
# # Print shapes for verification
|
|
||||||
# print(f"Input shape: {inputs.shape}")
|
|
||||||
# print(f"Output shape: {outputs.shape}")
|
|
||||||
|
|
||||||
# # Basic tests
|
|
||||||
# assert outputs.shape[0] == batch_size, "Batch size mismatch"
|
|
||||||
# assert len(outputs.shape) == 4, "Output should be 4D: [B,C,H,W]"
|
|
||||||
# assert torch.all(outputs >= 0), "Output should be non-negative after ReLU"
|
|
||||||
|
|
||||||
# print("All tests passed!")
|
|
||||||
|
|
||||||
@ -1,429 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
Author: Ke Xian
|
|
||||||
Email: kexian@hust.edu.cn
|
|
||||||
Date: 2020/07/20
|
|
||||||
'''
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.init as init
|
|
||||||
|
|
||||||
# ==============================================================================================================
|
|
||||||
|
|
||||||
class FTB(nn.Module):
|
|
||||||
def __init__(self, inchannels, midchannels=512):
|
|
||||||
super(FTB, self).__init__()
|
|
||||||
self.in1 = inchannels
|
|
||||||
self.mid = midchannels
|
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True)
|
|
||||||
self.conv_branch = nn.Sequential(nn.ReLU(inplace=True),\
|
|
||||||
nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),\
|
|
||||||
#nn.BatchNorm2d(num_features=self.mid),\
|
|
||||||
nn.ReLU(inplace=True),\
|
|
||||||
nn.Conv2d(in_channels=self.mid, out_channels= self.mid, kernel_size=3, padding=1, stride=1, bias=True))
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
|
|
||||||
self.init_params()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = x + self.conv_branch(x)
|
|
||||||
x = self.relu(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def init_params(self):
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
# init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.ConvTranspose2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
# init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
|
|
||||||
init.constant_(m.weight, 1)
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.Linear):
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
class ATA(nn.Module):
|
|
||||||
def __init__(self, inchannels, reduction = 8):
|
|
||||||
super(ATA, self).__init__()
|
|
||||||
self.inchannels = inchannels
|
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.fc = nn.Sequential(nn.Linear(self.inchannels*2, self.inchannels // reduction),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Linear(self.inchannels // reduction, self.inchannels),
|
|
||||||
nn.Sigmoid())
|
|
||||||
self.init_params()
|
|
||||||
|
|
||||||
def forward(self, low_x, high_x):
|
|
||||||
n, c, _, _ = low_x.size()
|
|
||||||
x = torch.cat([low_x, high_x], 1)
|
|
||||||
x = self.avg_pool(x)
|
|
||||||
x = x.view(n, -1)
|
|
||||||
x = self.fc(x).view(n,c,1,1)
|
|
||||||
x = low_x * x + high_x
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def init_params(self):
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
#init.normal(m.weight, std=0.01)
|
|
||||||
init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.ConvTranspose2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
#init.normal_(m.weight, std=0.01)
|
|
||||||
init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
|
|
||||||
init.constant_(m.weight, 1)
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.Linear):
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
|
|
||||||
class FFM(nn.Module):
|
|
||||||
def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
|
|
||||||
super(FFM, self).__init__()
|
|
||||||
self.inchannels = inchannels
|
|
||||||
self.midchannels = midchannels
|
|
||||||
self.outchannels = outchannels
|
|
||||||
self.upfactor = upfactor
|
|
||||||
|
|
||||||
self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
|
|
||||||
self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
|
|
||||||
|
|
||||||
self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
|
|
||||||
|
|
||||||
self.init_params()
|
|
||||||
#self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
|
||||||
#self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
|
||||||
#self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
|
||||||
|
|
||||||
def forward(self, low_x, high_x):
|
|
||||||
x = self.ftb1(low_x)
|
|
||||||
|
|
||||||
'''
|
|
||||||
x = torch.cat((x,high_x),1)
|
|
||||||
if x.shape[2] == 12:
|
|
||||||
x = self.p1(x)
|
|
||||||
elif x.shape[2] == 24:
|
|
||||||
x = self.p2(x)
|
|
||||||
elif x.shape[2] == 48:
|
|
||||||
x = self.p3(x)
|
|
||||||
'''
|
|
||||||
x = x + high_x ###high_x
|
|
||||||
x = self.ftb2(x)
|
|
||||||
x = self.upsample(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def init_params(self):
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
#init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.ConvTranspose2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
#init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
|
|
||||||
init.constant_(m.weight, 1)
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.Linear):
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class noFFM(nn.Module):
|
|
||||||
def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
|
|
||||||
super(noFFM, self).__init__()
|
|
||||||
self.inchannels = inchannels
|
|
||||||
self.midchannels = midchannels
|
|
||||||
self.outchannels = outchannels
|
|
||||||
self.upfactor = upfactor
|
|
||||||
|
|
||||||
self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
|
|
||||||
|
|
||||||
self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
|
|
||||||
|
|
||||||
self.init_params()
|
|
||||||
#self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
|
||||||
#self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
|
||||||
#self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
|
||||||
|
|
||||||
def forward(self, low_x, high_x):
|
|
||||||
|
|
||||||
#x = self.ftb1(low_x)
|
|
||||||
x = high_x ###high_x
|
|
||||||
x = self.ftb2(x)
|
|
||||||
x = self.upsample(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def init_params(self):
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
#init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.ConvTranspose2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
#init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
|
|
||||||
init.constant_(m.weight, 1)
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.Linear):
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AO(nn.Module):
|
|
||||||
# Adaptive output module
|
|
||||||
def __init__(self, inchannels, outchannels, upfactor=2):
|
|
||||||
super(AO, self).__init__()
|
|
||||||
self.inchannels = inchannels
|
|
||||||
self.outchannels = outchannels
|
|
||||||
self.upfactor = upfactor
|
|
||||||
|
|
||||||
"""
|
|
||||||
self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
|
|
||||||
nn.BatchNorm2d(num_features=self.inchannels//2),\
|
|
||||||
nn.ReLU(inplace=True),\
|
|
||||||
nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True),\
|
|
||||||
nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) )#,\
|
|
||||||
#nn.ReLU(inplace=True)) ## get positive values
|
|
||||||
"""
|
|
||||||
self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
|
|
||||||
#nn.BatchNorm2d(num_features=self.inchannels//2),\
|
|
||||||
nn.ReLU(inplace=True),\
|
|
||||||
nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True), \
|
|
||||||
nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=1, padding=0, stride=1))
|
|
||||||
|
|
||||||
#nn.ReLU(inplace=True)) ## get positive values
|
|
||||||
|
|
||||||
self.init_params()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.adapt_conv(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def init_params(self):
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
#init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.ConvTranspose2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
#init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
|
|
||||||
init.constant_(m.weight, 1)
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.Linear):
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
class ASPP(nn.Module):
|
|
||||||
def __init__(self, inchannels=256, planes=128, rates = [1, 6, 12, 18]):
|
|
||||||
super(ASPP, self).__init__()
|
|
||||||
self.inchannels = inchannels
|
|
||||||
self.planes = planes
|
|
||||||
self.rates = rates
|
|
||||||
self.kernel_sizes = []
|
|
||||||
self.paddings = []
|
|
||||||
for rate in self.rates:
|
|
||||||
if rate == 1:
|
|
||||||
self.kernel_sizes.append(1)
|
|
||||||
self.paddings.append(0)
|
|
||||||
else:
|
|
||||||
self.kernel_sizes.append(3)
|
|
||||||
self.paddings.append(rate)
|
|
||||||
self.atrous_0 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[0],
|
|
||||||
stride=1, padding=self.paddings[0], dilation=self.rates[0], bias=True),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.BatchNorm2d(num_features=self.planes)
|
|
||||||
)
|
|
||||||
self.atrous_1 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[1],
|
|
||||||
stride=1, padding=self.paddings[1], dilation=self.rates[1], bias=True),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.BatchNorm2d(num_features=self.planes),
|
|
||||||
)
|
|
||||||
self.atrous_2 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[2],
|
|
||||||
stride=1, padding=self.paddings[2], dilation=self.rates[2], bias=True),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.BatchNorm2d(num_features=self.planes),
|
|
||||||
)
|
|
||||||
self.atrous_3 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[3],
|
|
||||||
stride=1, padding=self.paddings[3], dilation=self.rates[3], bias=True),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.BatchNorm2d(num_features=self.planes),
|
|
||||||
)
|
|
||||||
|
|
||||||
#self.conv = nn.Conv2d(in_channels=self.planes * 4, out_channels=self.inchannels, kernel_size=3, padding=1, stride=1, bias=True)
|
|
||||||
def forward(self, x):
|
|
||||||
x = torch.cat([self.atrous_0(x), self.atrous_1(x), self.atrous_2(x), self.atrous_3(x)],1)
|
|
||||||
#x = self.conv(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
# ==============================================================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualConv(nn.Module):
|
|
||||||
def __init__(self, inchannels):
|
|
||||||
super(ResidualConv, self).__init__()
|
|
||||||
#nn.BatchNorm2d
|
|
||||||
self.conv = nn.Sequential(
|
|
||||||
#nn.BatchNorm2d(num_features=inchannels),
|
|
||||||
nn.ReLU(inplace=False),
|
|
||||||
#nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True),
|
|
||||||
#nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True)
|
|
||||||
nn.Conv2d(in_channels=inchannels, out_channels=inchannels//2, kernel_size=3, padding=1, stride=1, bias=False),
|
|
||||||
nn.BatchNorm2d(num_features=inchannels//2),
|
|
||||||
nn.ReLU(inplace=False),
|
|
||||||
nn.Conv2d(in_channels=inchannels//2, out_channels=inchannels, kernel_size=3, padding=1, stride=1, bias=False)
|
|
||||||
)
|
|
||||||
self.init_params()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv(x)+x
|
|
||||||
return x
|
|
||||||
|
|
||||||
def init_params(self):
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
#init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.ConvTranspose2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
#init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
|
|
||||||
init.constant_(m.weight, 1)
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.Linear):
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureFusion(nn.Module):
|
|
||||||
def __init__(self, inchannels, outchannels):
|
|
||||||
super(FeatureFusion, self).__init__()
|
|
||||||
self.conv = ResidualConv(inchannels=inchannels)
|
|
||||||
#nn.BatchNorm2d
|
|
||||||
self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
|
|
||||||
nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3,stride=2, padding=1, output_padding=1),
|
|
||||||
nn.BatchNorm2d(num_features=outchannels),
|
|
||||||
nn.ReLU(inplace=True))
|
|
||||||
|
|
||||||
def forward(self, lowfeat, highfeat):
|
|
||||||
return self.up(highfeat + self.conv(lowfeat))
|
|
||||||
|
|
||||||
def init_params(self):
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
#init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.ConvTranspose2d):
|
|
||||||
#init.kaiming_normal_(m.weight, mode='fan_out')
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
#init.xavier_normal_(m.weight)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
|
|
||||||
init.constant_(m.weight, 1)
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.Linear):
|
|
||||||
init.normal_(m.weight, std=0.01)
|
|
||||||
if m.bias is not None:
|
|
||||||
init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
|
|
||||||
class SenceUnderstand(nn.Module):
|
|
||||||
def __init__(self, channels):
|
|
||||||
super(SenceUnderstand, self).__init__()
|
|
||||||
self.channels = channels
|
|
||||||
self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(inplace = True))
|
|
||||||
self.pool = nn.AdaptiveAvgPool2d(8)
|
|
||||||
self.fc = nn.Sequential(nn.Linear(512*8*8, self.channels),
|
|
||||||
nn.ReLU(inplace = True))
|
|
||||||
self.conv2 = nn.Sequential(nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
|
|
||||||
nn.ReLU(inplace=True))
|
|
||||||
self.initial_params()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
n,c,h,w = x.size()
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = self.pool(x)
|
|
||||||
x = x.view(n,-1)
|
|
||||||
x = self.fc(x)
|
|
||||||
x = x.view(n, self.channels, 1, 1)
|
|
||||||
x = self.conv2(x)
|
|
||||||
x = x.repeat(1,1,h,w)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def initial_params(self, dev=0.01):
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
#print torch.sum(m.weight)
|
|
||||||
m.weight.data.normal_(0, dev)
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.fill_(0)
|
|
||||||
elif isinstance(m, nn.ConvTranspose2d):
|
|
||||||
#print torch.sum(m.weight)
|
|
||||||
m.weight.data.normal_(0, dev)
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.fill_(0)
|
|
||||||
elif isinstance(m, nn.Linear):
|
|
||||||
m.weight.data.normal_(0, dev)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@ -1,342 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch
|
|
||||||
# from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
|
||||||
from collections import OrderedDict
|
|
||||||
# from mmseg.ops import resize
|
|
||||||
from torch.nn.functional import interpolate as resize
|
|
||||||
# from builder import HEADS
|
|
||||||
from models.SpaTrackV2.models.depth_refiner.decode_head import BaseDecodeHead, BaseDecodeHead_clips, BaseDecodeHead_clips_flow
|
|
||||||
# from mmseg.models.utils import *
|
|
||||||
import attr
|
|
||||||
from IPython import embed
|
|
||||||
from models.SpaTrackV2.models.depth_refiner.stablilization_attention import BasicLayer3d3
|
|
||||||
import cv2
|
|
||||||
from models.SpaTrackV2.models.depth_refiner.network import *
|
|
||||||
import warnings
|
|
||||||
# from mmcv.utils import Registry, build_from_cfg
|
|
||||||
from torch import nn
|
|
||||||
from einops import rearrange
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from models.SpaTrackV2.models.blocks import (
|
|
||||||
AttnBlock, CrossAttnBlock, Mlp
|
|
||||||
)
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
|
||||||
"""
|
|
||||||
Linear Embedding
|
|
||||||
"""
|
|
||||||
def __init__(self, input_dim=2048, embed_dim=768):
|
|
||||||
super().__init__()
|
|
||||||
self.proj = nn.Linear(input_dim, embed_dim)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.flatten(2).transpose(1, 2)
|
|
||||||
x = self.proj(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def scatter_multiscale_fast(
|
|
||||||
track2d: torch.Tensor,
|
|
||||||
trackfeature: torch.Tensor,
|
|
||||||
H: int,
|
|
||||||
W: int,
|
|
||||||
kernel_sizes = [1]
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Scatter sparse track features onto a dense image grid with weighted multi-scale pooling to handle zero-value gaps.
|
|
||||||
|
|
||||||
This function scatters sparse track features into a dense image grid and applies multi-scale average pooling
|
|
||||||
while excluding zero-value holes. The weight mask ensures that only valid feature regions contribute to the pooling,
|
|
||||||
avoiding dilution by empty pixels.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
track2d (torch.Tensor): Float tensor of shape (B, T, N, 2) containing (x, y) pixel coordinates
|
|
||||||
for each track point across batches, frames, and points.
|
|
||||||
trackfeature (torch.Tensor): Float tensor of shape (B, T, N, C) with C-dimensional features
|
|
||||||
for each track point.
|
|
||||||
H (int): Height of the target output image.
|
|
||||||
W (int): Width of the target output image.
|
|
||||||
kernel_sizes (List[int]): List of odd integers for average pooling kernel sizes. Default: [3, 5, 7].
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Multi-scale fused feature map of shape (B, T, C, H, W) with hole-resistant pooling.
|
|
||||||
"""
|
|
||||||
B, T, N, C = trackfeature.shape
|
|
||||||
device = trackfeature.device
|
|
||||||
|
|
||||||
# 1. Flatten coordinates and filter valid points within image bounds
|
|
||||||
coords_flat = track2d.round().long().reshape(-1, 2) # (B*T*N, 2)
|
|
||||||
x = coords_flat[:, 0] # x coordinates
|
|
||||||
y = coords_flat[:, 1] # y coordinates
|
|
||||||
feat_flat = trackfeature.reshape(-1, C) # Flatten features
|
|
||||||
|
|
||||||
valid_mask = (x >= 0) & (x < W) & (y >= 0) & (y < H)
|
|
||||||
x = x[valid_mask]
|
|
||||||
y = y[valid_mask]
|
|
||||||
feat_flat = feat_flat[valid_mask]
|
|
||||||
valid_count = x.shape[0]
|
|
||||||
|
|
||||||
if valid_count == 0:
|
|
||||||
return torch.zeros(B, T, C, H, W, device=device) # Handle no-valid-point case
|
|
||||||
|
|
||||||
# 2. Calculate linear indices and batch-frame indices for scattering
|
|
||||||
lin_idx = y * W + x # Linear index within a single frame (H*W range)
|
|
||||||
|
|
||||||
# Generate batch-frame indices (e.g., 0~B*T-1 for each frame in batch)
|
|
||||||
bt_idx_raw = (
|
|
||||||
torch.arange(B * T, device=device)
|
|
||||||
.view(B, T, 1)
|
|
||||||
.expand(B, T, N)
|
|
||||||
.reshape(-1)
|
|
||||||
)
|
|
||||||
bt_idx = bt_idx_raw[valid_mask] # Indices for valid points across batch and frames
|
|
||||||
|
|
||||||
# 3. Create accumulation buffers for features and weights
|
|
||||||
total_space = B * T * H * W
|
|
||||||
img_accum_flat = torch.zeros(total_space, C, device=device) # Feature accumulator
|
|
||||||
weight_accum_flat = torch.zeros(total_space, 1, device=device) # Weight accumulator (counts)
|
|
||||||
|
|
||||||
# 4. Scatter features and weights into accumulation buffers
|
|
||||||
idx_in_accum = bt_idx * (H * W) + lin_idx # Global index: batch_frame * H*W + pixel_index
|
|
||||||
|
|
||||||
# Add features to corresponding indices (index_add_ is efficient for sparse updates)
|
|
||||||
img_accum_flat.index_add_(0, idx_in_accum, feat_flat)
|
|
||||||
weight_accum_flat.index_add_(0, idx_in_accum, torch.ones((valid_count, 1), device=device))
|
|
||||||
|
|
||||||
# 5. Normalize features by valid weights, keep zeros for invalid regions
|
|
||||||
valid_mask_flat = weight_accum_flat > 0 # Binary mask for valid pixels
|
|
||||||
img_accum_flat = img_accum_flat / (weight_accum_flat + 1e-6) # Avoid division by zero
|
|
||||||
img_accum_flat = img_accum_flat * valid_mask_flat.float() # Mask out invalid regions
|
|
||||||
|
|
||||||
# 6. Reshape to (B, T, C, H, W) for further processing
|
|
||||||
img = (
|
|
||||||
img_accum_flat.view(B, T, H, W, C)
|
|
||||||
.permute(0, 1, 4, 2, 3)
|
|
||||||
.contiguous()
|
|
||||||
) # Shape: (B, T, C, H, W)
|
|
||||||
|
|
||||||
# 7. Multi-scale pooling with weight masking to exclude zero holes
|
|
||||||
blurred_outputs = []
|
|
||||||
for k in kernel_sizes:
|
|
||||||
pad = k // 2
|
|
||||||
img_bt = img.view(B*T, C, H, W) # Flatten batch and time for pooling
|
|
||||||
|
|
||||||
# Create weight mask for valid regions (1 where features exist, 0 otherwise)
|
|
||||||
weight_mask = (
|
|
||||||
weight_accum_flat.view(B, T, 1, H, W) > 0
|
|
||||||
).float().view(B*T, 1, H, W) # Shape: (B*T, 1, H, W)
|
|
||||||
|
|
||||||
# Calculate number of valid neighbors in each pooling window
|
|
||||||
weight_sum = F.conv2d(
|
|
||||||
weight_mask,
|
|
||||||
torch.ones((1, 1, k, k), device=device),
|
|
||||||
stride=1,
|
|
||||||
padding=pad
|
|
||||||
) # Shape: (B*T, 1, H, W)
|
|
||||||
|
|
||||||
# Sum features only in valid regions
|
|
||||||
feat_sum = F.conv2d(
|
|
||||||
img_bt * weight_mask, # Mask out invalid regions before summing
|
|
||||||
torch.ones((1, 1, k, k), device=device).expand(C, 1, k, k),
|
|
||||||
stride=1,
|
|
||||||
padding=pad,
|
|
||||||
groups=C
|
|
||||||
) # Shape: (B*T, C, H, W)
|
|
||||||
|
|
||||||
# Compute average only over valid neighbors
|
|
||||||
feat_avg = feat_sum / (weight_sum + 1e-6)
|
|
||||||
blurred_outputs.append(feat_avg)
|
|
||||||
|
|
||||||
# 8. Fuse multi-scale results by averaging across kernel sizes
|
|
||||||
fused = torch.stack(blurred_outputs).mean(dim=0) # Average over kernel sizes
|
|
||||||
return fused.view(B, T, C, H, W) # Restore original shape
|
|
||||||
|
|
||||||
#@HEADS.register_module()
|
|
||||||
class Stabilization_Network_Cross_Attention(BaseDecodeHead_clips_flow):
|
|
||||||
|
|
||||||
def __init__(self, feature_strides, **kwargs):
|
|
||||||
super(Stabilization_Network_Cross_Attention, self).__init__(input_transform='multiple_select', **kwargs)
|
|
||||||
self.training = False
|
|
||||||
assert len(feature_strides) == len(self.in_channels)
|
|
||||||
assert min(feature_strides) == feature_strides[0]
|
|
||||||
self.feature_strides = feature_strides
|
|
||||||
|
|
||||||
c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
|
|
||||||
|
|
||||||
decoder_params = kwargs['decoder_params']
|
|
||||||
embedding_dim = decoder_params['embed_dim']
|
|
||||||
|
|
||||||
self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
|
|
||||||
self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
|
|
||||||
self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
|
|
||||||
self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
|
|
||||||
|
|
||||||
self.linear_fuse = nn.Sequential(nn.Conv2d(embedding_dim*4, embedding_dim, kernel_size=(1, 1), stride=(1, 1), bias=False),\
|
|
||||||
nn.ReLU(inplace=True))
|
|
||||||
|
|
||||||
self.proj_track = nn.Conv2d(100, 128, kernel_size=(1, 1), stride=(1, 1), bias=True)
|
|
||||||
|
|
||||||
depths = decoder_params['depths']
|
|
||||||
|
|
||||||
self.reg_tokens = nn.Parameter(torch.zeros(1, 2, embedding_dim))
|
|
||||||
self.global_patch = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=(8, 8), stride=(8, 8), bias=True)
|
|
||||||
|
|
||||||
self.att_temporal = nn.ModuleList(
|
|
||||||
[
|
|
||||||
AttnBlock(embedding_dim, 8,
|
|
||||||
mlp_ratio=4, flash=True, ckpt_fwd=True)
|
|
||||||
for _ in range(8)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.att_spatial = nn.ModuleList(
|
|
||||||
[
|
|
||||||
AttnBlock(embedding_dim, 8,
|
|
||||||
mlp_ratio=4, flash=True, ckpt_fwd=True)
|
|
||||||
for _ in range(8)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.scale_shift_head = nn.Sequential(nn.Linear(embedding_dim, embedding_dim), nn.GELU(), nn.Linear(embedding_dim, 4))
|
|
||||||
|
|
||||||
|
|
||||||
# Initialize reg tokens
|
|
||||||
nn.init.trunc_normal_(self.reg_tokens, std=0.02)
|
|
||||||
|
|
||||||
self.decoder_focal=BasicLayer3d3(dim=embedding_dim,
|
|
||||||
input_resolution=(96,
|
|
||||||
96),
|
|
||||||
depth=depths,
|
|
||||||
num_heads=8,
|
|
||||||
window_size=7,
|
|
||||||
mlp_ratio=4.,
|
|
||||||
qkv_bias=True,
|
|
||||||
qk_scale=None,
|
|
||||||
drop=0.,
|
|
||||||
attn_drop=0.,
|
|
||||||
drop_path=0.,
|
|
||||||
norm_layer=nn.LayerNorm,
|
|
||||||
pool_method='fc',
|
|
||||||
downsample=None,
|
|
||||||
focal_level=2,
|
|
||||||
focal_window=5,
|
|
||||||
expand_size=3,
|
|
||||||
expand_layer="all",
|
|
||||||
use_conv_embed=False,
|
|
||||||
use_shift=False,
|
|
||||||
use_pre_norm=False,
|
|
||||||
use_checkpoint=False,
|
|
||||||
use_layerscale=False,
|
|
||||||
layerscale_value=1e-4,
|
|
||||||
focal_l_clips=[7,4,2],
|
|
||||||
focal_kernel_clips=[7,5,3])
|
|
||||||
|
|
||||||
self.ffm2 = FFM(inchannels= 256, midchannels= 256, outchannels = 128)
|
|
||||||
self.ffm1 = FFM(inchannels= 128, midchannels= 128, outchannels = 64)
|
|
||||||
self.ffm0 = FFM(inchannels= 64, midchannels= 64, outchannels = 32,upfactor=1)
|
|
||||||
self.AO = AO(32, outchannels=3, upfactor=1)
|
|
||||||
self._c2 = None
|
|
||||||
self._c_further = None
|
|
||||||
|
|
||||||
def buffer_forward(self, inputs, num_clips=None, imgs=None):#,infermode=1):
|
|
||||||
|
|
||||||
# input: B T 7 H W (7 means 3 rgb + 3 pointmap + 1 uncertainty) normalized
|
|
||||||
if self.training:
|
|
||||||
assert self.num_clips==num_clips
|
|
||||||
|
|
||||||
x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32
|
|
||||||
c1, c2, c3, c4 = x
|
|
||||||
|
|
||||||
############## MLP decoder on C1-C4 ###########
|
|
||||||
n, _, h, w = c4.shape
|
|
||||||
batch_size = n // num_clips
|
|
||||||
|
|
||||||
_c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
|
|
||||||
_c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)
|
|
||||||
|
|
||||||
_c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
|
|
||||||
_c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)
|
|
||||||
|
|
||||||
_c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
|
|
||||||
_c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)
|
|
||||||
|
|
||||||
_c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
|
|
||||||
_c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
|
|
||||||
|
|
||||||
_, _, h, w=_c.shape
|
|
||||||
_c_further=_c.reshape(batch_size, num_clips, -1, h, w) #h2w2
|
|
||||||
|
|
||||||
# Expand reg_tokens to match batch size
|
|
||||||
reg_tokens = self.reg_tokens.expand(batch_size*num_clips, -1, -1) # [B, 2, C]
|
|
||||||
|
|
||||||
_c2=self.decoder_focal(_c_further, batch_size=batch_size, num_clips=num_clips, reg_tokens=reg_tokens)
|
|
||||||
|
|
||||||
assert _c_further.shape==_c2.shape
|
|
||||||
self._c2 = _c2
|
|
||||||
self._c_further = _c_further
|
|
||||||
|
|
||||||
# compute the scale and shift of the global patch
|
|
||||||
global_patch = self.global_patch(_c2.view(batch_size*num_clips, -1, h, w)).view(batch_size*num_clips, _c2.shape[2], -1).permute(0,2,1)
|
|
||||||
global_patch = torch.cat([global_patch, reg_tokens], dim=1)
|
|
||||||
for i in range(8):
|
|
||||||
global_patch = self.att_temporal[i](global_patch)
|
|
||||||
global_patch = rearrange(global_patch, '(b t) n c -> (b n) t c', b=batch_size, t=num_clips, c=_c2.shape[2])
|
|
||||||
global_patch = self.att_spatial[i](global_patch)
|
|
||||||
global_patch = rearrange(global_patch, '(b n) t c -> (b t) n c', b=batch_size, t=num_clips, c=_c2.shape[2])
|
|
||||||
|
|
||||||
reg_tokens = global_patch[:, -2:, :]
|
|
||||||
s_ = self.scale_shift_head(reg_tokens)
|
|
||||||
scale = 1 + s_[:, 0, :1].view(batch_size, num_clips, 1, 1, 1)
|
|
||||||
shift = s_[:, 1, 1:].view(batch_size, num_clips, 3, 1, 1)
|
|
||||||
shift[:,:,:2,...] = 0
|
|
||||||
return scale, shift
|
|
||||||
|
|
||||||
def forward(self, inputs, edge_feat, edge_feat1, tracks, tracks_uvd, num_clips=None, imgs=None, vis_track=None):#,infermode=1):
|
|
||||||
|
|
||||||
if self._c2 is None:
|
|
||||||
scale, shift = self.buffer_forward(inputs,num_clips,imgs)
|
|
||||||
|
|
||||||
B, T, N, _ = tracks.shape
|
|
||||||
|
|
||||||
_c2 = self._c2
|
|
||||||
_c_further = self._c_further
|
|
||||||
|
|
||||||
# skip and head
|
|
||||||
_c_further = rearrange(_c_further, 'b t c h w -> (b t) c h w', b=B, t=T)
|
|
||||||
_c2 = rearrange(_c2, 'b t c h w -> (b t) c h w', b=B, t=T)
|
|
||||||
|
|
||||||
outframe = self.ffm2(_c_further, _c2)
|
|
||||||
|
|
||||||
tracks_uv = tracks_uvd[...,:2].clone()
|
|
||||||
track_feature = scatter_multiscale_fast(tracks_uv/2, tracks, outframe.shape[-2], outframe.shape[-1], kernel_sizes=[1, 3, 5])
|
|
||||||
# visualize track_feature as video
|
|
||||||
# import cv2
|
|
||||||
# import imageio
|
|
||||||
# import os
|
|
||||||
# BT, C, H, W = outframe.shape
|
|
||||||
# track_feature_vis = track_feature.view(B, T, 3, H, W).float().detach().cpu().numpy()
|
|
||||||
# track_feature_vis = track_feature_vis.transpose(0,1,3,4,2)
|
|
||||||
# track_feature_vis = (track_feature_vis - track_feature_vis.min()) / (track_feature_vis.max() - track_feature_vis.min() + 1e-6)
|
|
||||||
# track_feature_vis = (track_feature_vis * 255).astype(np.uint8)
|
|
||||||
# imgs =(imgs.detach() + 1) * 127.5
|
|
||||||
# vis_track.visualize(video=imgs, tracks=tracks_uv, filename="test")
|
|
||||||
# for b in range(B):
|
|
||||||
# frames = []
|
|
||||||
# for t in range(T):
|
|
||||||
# frame = track_feature_vis[b,t]
|
|
||||||
# frame = cv2.applyColorMap(frame[...,0], cv2.COLORMAP_JET)
|
|
||||||
# frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
||||||
# frames.append(frame)
|
|
||||||
# # Save as gif
|
|
||||||
# imageio.mimsave(f'track_feature_b{b}.gif', frames, duration=0.1)
|
|
||||||
# import pdb; pdb.set_trace()
|
|
||||||
track_feature = rearrange(track_feature, 'b t c h w -> (b t) c h w')
|
|
||||||
track_feature = self.proj_track(track_feature)
|
|
||||||
outframe = self.ffm1(edge_feat1 + track_feature,outframe)
|
|
||||||
outframe = self.ffm0(edge_feat,outframe)
|
|
||||||
outframe = self.AO(outframe)
|
|
||||||
|
|
||||||
return outframe
|
|
||||||
|
|
||||||
def reset_success(self):
|
|
||||||
self._c2 = None
|
|
||||||
self._c_further = None
|
|
||||||
Loading…
x
Reference in New Issue
Block a user