# Copyright (C) 2024 Apple Inc. All Rights Reserved.

try:
    from timm.layers import resample_abs_pos_embed
except ImportError as err:
    print("ImportError: {0}".format(err))

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def make_vit_b16_backbone(
    model,
    encoder_feature_dims,
    encoder_feature_layer_ids,
    vit_features,
    start_index=1,
    use_grad_checkpointing=False,
) -> nn.Module:
    """Make a ViTb16 backbone for the DPT model."""
    if use_grad_checkpointing:
        model.set_grad_checkpointing()

    vit_model = nn.Module()
    vit_model.hooks = encoder_feature_layer_ids
    vit_model.model = model
    vit_model.features = encoder_feature_dims
    vit_model.vit_features = vit_features
    vit_model.model.start_index = start_index
    vit_model.model.patch_size = vit_model.model.patch_embed.patch_size
    vit_model.model.is_vit = True
    vit_model.model.forward = vit_model.model.forward_features

    return vit_model
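

# Usage sketch for make_vit_b16_backbone: wrap a timm ViT so the returned module
# carries the hook ids and feature dims, and its forward runs forward_features.
# The model name, feature dims, and hook layer ids below are illustrative
# assumptions, not values prescribed by this file:
#
#   import timm
#   vit = timm.create_model("vit_base_patch16_384", pretrained=False)
#   backbone = make_vit_b16_backbone(
#       vit,
#       encoder_feature_dims=(96, 192, 384, 768),
#       encoder_feature_layer_ids=(2, 5, 8, 11),
#       vit_features=768,
#   )
#   tokens = backbone.model(torch.randn(1, 3, 384, 384))  # prefix + patch tokens

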
def forward_features_eva_fixed(self, x):
    """Encode features."""
    x = self.patch_embed(x)
    x, rot_pos_embed = self._pos_embed(x)
    for blk in self.blocks:
        if self.grad_checkpointing:
            x = checkpoint(blk, x, rot_pos_embed)
        else:
            x = blk(x, rot_pos_embed)
    x = self.norm(x)
    return x
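

# forward_features_eva_fixed takes self explicitly, so it can be bound onto a timm
# Eva instance to route the rotary position embedding through the checkpointed
# path as well. A minimal sketch, assuming a timm EVA02 model (the model name is
# an illustrative assumption):
#
#   import types
#   import timm
#   eva = timm.create_model("eva02_base_patch14_224", pretrained=False)
#   eva.forward_features = types.MethodType(forward_features_eva_fixed, eva)
#   tokens = eva.forward_features(torch.randn(1, 3, 224, 224))

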
def resize_vit(model: nn.Module, img_size) -> nn.Module:
    """Resample the ViT module to the given size."""
    patch_size = model.patch_embed.patch_size
    model.patch_embed.img_size = img_size
    grid_size = tuple([s // p for s, p in zip(img_size, patch_size)])
    model.patch_embed.grid_size = grid_size

    pos_embed = resample_abs_pos_embed(
        model.pos_embed,
        grid_size,  # img_size
        num_prefix_tokens=(
            0 if getattr(model, "no_embed_class", False) else model.num_prefix_tokens
        ),
    )
    model.pos_embed = torch.nn.Parameter(pos_embed)

    return model
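

# Usage sketch for resize_vit: resample the absolute position embedding so a
# 384x384 ViT accepts 512x512 inputs (token grid 24x24 -> 32x32). The model name
# is an illustrative assumption:
#
#   import timm
#   vit = timm.create_model("vit_base_patch16_384", pretrained=False)
#   vit = resize_vit(vit, img_size=(512, 512))
#   tokens = vit.forward_features(torch.randn(1, 3, 512, 512))

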
def resize_patch_embed(model: nn.Module, new_patch_size=(16, 16)) -> nn.Module:
    """Resample the ViT patch size to the given one."""
    # interpolate patch embedding
    if hasattr(model, "patch_embed"):
        old_patch_size = model.patch_embed.patch_size

        if (
            new_patch_size[0] != old_patch_size[0]
            or new_patch_size[1] != old_patch_size[1]
        ):
            patch_embed_proj = model.patch_embed.proj.weight
            patch_embed_proj_bias = model.patch_embed.proj.bias
            use_bias = True if patch_embed_proj_bias is not None else False
            _, _, h, w = patch_embed_proj.shape

            new_patch_embed_proj = torch.nn.functional.interpolate(
                patch_embed_proj,
                size=[new_patch_size[0], new_patch_size[1]],
                mode="bicubic",
                align_corners=False,
            )
            new_patch_embed_proj = (
                new_patch_embed_proj * (h / new_patch_size[0]) * (w / new_patch_size[1])
            )

            model.patch_embed.proj = nn.Conv2d(
                in_channels=model.patch_embed.proj.in_channels,
                out_channels=model.patch_embed.proj.out_channels,
                kernel_size=new_patch_size,
                stride=new_patch_size,
                bias=use_bias,
            )

            if use_bias:
                model.patch_embed.proj.bias = patch_embed_proj_bias

            model.patch_embed.proj.weight = torch.nn.Parameter(new_patch_embed_proj)

            model.patch_size = new_patch_size
            model.patch_embed.patch_size = new_patch_size
            model.patch_embed.img_size = (
                int(
                    model.patch_embed.img_size[0]
                    * new_patch_size[0]
                    / old_patch_size[0]
                ),
                int(
                    model.patch_embed.img_size[1]
                    * new_patch_size[1]
                    / old_patch_size[1]
                ),
            )

    return model
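

# Usage sketch for resize_patch_embed: interpolate the 16x16 patch projection down
# to 8x8; img_size is rescaled by the same factor, so the token grid (and hence the
# position embedding) is unchanged. The model name is an illustrative assumption:
#
#   import timm
#   vit = timm.create_model("vit_base_patch16_384", pretrained=False)
#   vit = resize_patch_embed(vit, new_patch_size=(8, 8))
#   tokens = vit.forward_features(torch.randn(1, 3, 192, 192))  # 24x24 patch grid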