125 lines
3.5 KiB
Python
125 lines
3.5 KiB
Python
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
# Factory functions to build and load ViT models.
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import types
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List, Literal, Optional
|
|
|
|
import timm
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from .vit import (
|
|
forward_features_eva_fixed,
|
|
make_vit_b16_backbone,
|
|
resize_patch_embed,
|
|
resize_vit,
|
|
)
|
|
|
|
# Logger for this module; create_vit logs the constructed backbone through it.
LOGGER = logging.getLogger(__name__)

# Names of the ViT presets this factory knows how to build; each one is a key
# of VIT_CONFIG_DICT.
ViTPreset = Literal["dinov2l16_384"]
|
|
|
|
|
|
@dataclass
class ViTConfig:
    """Configuration for ViT.

    Describes the target backbone geometry (``img_size`` / ``patch_size``)
    and, when the backbone is derived from a timm model with a different
    geometry, the timm preset together with its native image and patch sizes.
    The optional ``encoder_feature_*`` fields are only consumed by DPT.
    """

    # Number of input image channels.
    in_chans: int
    # Width of the transformer embeddings.
    embed_dim: int

    # Target input resolution and patch size of the backbone.
    img_size: int = 384
    patch_size: int = 16

    # In case we need to rescale the backbone when loading from timm.
    timm_preset: Optional[str] = None
    timm_img_size: int = 384
    timm_patch_size: int = 16

    # The following 2 parameters are only used by DPT. See dpt_factory.py.
    # The layers in the Beit/ViT used to construct encoder features for DPT.
    encoder_feature_layer_ids: Optional[List[int]] = None
    # The dimensions of the encoder-layer features from Beit/ViT used by DPT.
    encoder_feature_dims: Optional[List[int]] = None
|
|
|
|
|
|
# Registry mapping each ViT preset name to its configuration. When the target
# geometry (img_size / patch_size) differs from the timm source model's
# (timm_img_size / timm_patch_size), create_vit resizes the timm model.
VIT_CONFIG_DICT: Dict[ViTPreset, ViTConfig] = {
    "dinov2l16_384": ViTConfig(
        # Input channels and embedding width of the backbone.
        in_chans=3,
        embed_dim=1024,
        # Target geometry of the backbone.
        img_size=384,
        patch_size=16,
        # Source timm model and its native geometry.
        timm_preset="vit_large_patch14_dinov2",
        timm_img_size=518,
        timm_patch_size=14,
        # DPT-only settings: which layers feed the encoder and their widths.
        encoder_feature_layer_ids=[5, 11, 17, 23],
        encoder_feature_dims=[256, 512, 1024, 1024],
    ),
}
|
|
|
|
|
|
def create_vit(
    preset: ViTPreset,
    use_pretrained: bool = False,
    checkpoint_uri: str | None = None,
    use_grad_checkpointing: bool = False,
) -> nn.Module:
    """Create and load a ViT backbone module.

    Args:
    ----
        preset: The ViT preset whose config is looked up in VIT_CONFIG_DICT.
        use_pretrained: If True, ask timm to load pretrained weights for the
            preset's ``timm_preset``. Default is False.
        checkpoint_uri: Optional checkpoint path to load the weights from
            (read with ``torch.load`` onto the CPU). Its keys must match the
            state dict of the backbone wrapper built by
            ``make_vit_b16_backbone`` exactly.
        use_grad_checkpointing: Use gradient checkpointing.

    Returns:
    -------
        The ``.model`` attribute of the backbone wrapper, i.e. the (possibly
        resized) timm ViT module.

    Raises:
    ------
        KeyError: If ``preset`` is not a known preset, or if the checkpoint
            has unexpected or missing keys.

    """
    try:
        config = VIT_CONFIG_DICT[preset]
    except KeyError:
        # Keep the KeyError type callers may rely on, but say what is valid.
        raise KeyError(
            f"Unknown ViT preset {preset!r}; "
            f"available presets: {sorted(VIT_CONFIG_DICT)}"
        ) from None

    img_size = (config.img_size, config.img_size)
    patch_size = (config.patch_size, config.patch_size)

    if "eva02" in preset:
        model = timm.create_model(config.timm_preset, pretrained=use_pretrained)
        # Override forward_features on this instance only by binding the
        # replacement function as a method of `model`.
        model.forward_features = types.MethodType(forward_features_eva_fixed, model)
    else:
        model = timm.create_model(
            config.timm_preset, pretrained=use_pretrained, dynamic_img_size=True
        )
    model = make_vit_b16_backbone(
        model,
        encoder_feature_dims=config.encoder_feature_dims,
        encoder_feature_layer_ids=config.encoder_feature_layer_ids,
        vit_features=config.embed_dim,
        use_grad_checkpointing=use_grad_checkpointing,
    )

    # Adapt the timm model to the target geometry when it differs.
    if config.patch_size != config.timm_patch_size:
        model.model = resize_patch_embed(model.model, new_patch_size=patch_size)
    if config.img_size != config.timm_img_size:
        model.model = resize_vit(model.model, img_size=img_size)

    if checkpoint_uri is not None:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only pass trusted checkpoints. Consider
        # weights_only=True once the minimum supported torch version allows it.
        state_dict = torch.load(checkpoint_uri, map_location="cpu")
        # strict=False plus the explicit checks below gives the same strictness
        # as strict=True, but raises KeyError with a readable message.
        missing_keys, unexpected_keys = model.load_state_dict(
            state_dict=state_dict, strict=False
        )

        if unexpected_keys:
            raise KeyError(f"Found unexpected keys when loading vit: {unexpected_keys}")
        if missing_keys:
            raise KeyError(f"Keys are missing when loading vit: {missing_keys}")

    LOGGER.info(model)
    return model.model
|