import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.monoD.depth_anything.dpt import DPT_DINOv2
from models.monoD.depth_anything.util.transform import Resize


def build(config):
    """
    Build the model from the config.

    NOTE: the config should contain the following:
        - encoder: the encoder type of the model ('vits', 'vitb' or 'vitl')
        - load_from: the path to the pretrained checkpoint
        - localhub: whether to load the DINOv2 backbone from a local hub
    """
    args = config
    assert args.encoder in ['vits', 'vitb', 'vitl']
    if args.encoder == 'vits':
        depth_anything = DPT_DINOv2(encoder='vits', features=64,
                                    out_channels=[48, 96, 192, 384],
                                    localhub=args.localhub).cuda()
    elif args.encoder == 'vitb':
        depth_anything = DPT_DINOv2(encoder='vitb', features=128,
                                    out_channels=[96, 192, 384, 768],
                                    localhub=args.localhub).cuda()
    else:
        depth_anything = DPT_DINOv2(encoder='vitl', features=256,
                                    out_channels=[256, 512, 1024, 1024],
                                    localhub=args.localhub).cuda()
    depth_anything.load_state_dict(torch.load(args.load_from,
                                              map_location='cpu'), strict=True)
    total_params = sum(param.numel() for param in depth_anything.parameters())
    print('Total parameters: {:.2f}M'.format(total_params / 1e6))
    depth_anything.eval()

    return depth_anything
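
# Usage sketch (hypothetical values): build() accepts any object exposing
# 'encoder', 'load_from' and 'localhub' attributes, e.g. an argparse.Namespace.
# The checkpoint path below is an assumption, not one shipped with the repo.
#
#   import argparse
#   cfg = argparse.Namespace(encoder='vits',
#                            load_from='checkpoints/depth_anything_vits14.pth',
#                            localhub=True)
#   model = build(cfg)  # eval-mode DPT_DINOv2 on the GPU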


class DepthAnything(nn.Module):
    def __init__(self, args):
        super().__init__()
        # build the chosen model
        self.dpAny = build(args)

    def infer(self, rgbs):
        """
        Infer the depth (disparity) map from the input RGB images.

        Args:
            rgbs: the input RGB images, B x 3 x H x W (CUDA tensor)

        Asserts:
            the input must be a 4-D CUDA tensor
        """
        assert rgbs.is_cuda and rgbs.dim() == 4
        B, C, H, W = rgbs.shape
        # prepare the input: replicate the Depth Anything eval transform
        Resizer = Resize(
            width=518,
            height=518,
            resize_target=False,
            keep_aspect_ratio=True,
            ensure_multiple_of=14,
            resize_method='lower_bound',
            image_interpolation_method=cv2.INTER_CUBIC,
        )
        # NOTE: step 1 -- Resize. get_size expects (width, height), so pass
        # W before H; 'lower_bound' scales the shorter side up to 518 and
        # rounds both sides to multiples of 14.
        width, height = Resizer.get_size(W, H)
        rgbs = F.interpolate(
            rgbs, (int(height), int(width)), mode='bicubic', align_corners=False
        )
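        # Worked example (assuming the standard Depth Anything Resize
        # behavior): a 480x640 (HxW) input gives scale = max(518/480, 518/640)
        # ~= 1.079, so height -> 518 and width -> round(690.7/14)*14 = 686,
        # i.e. the frames enter the network at 518x686.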
        # NOTE: step 2 -- NormalizeImage (ImageNet mean/std, expects [0, 1] input)
        mean_ = torch.tensor([0.485, 0.456, 0.406],
                             device=rgbs.device).view(1, 3, 1, 1)
        std_ = torch.tensor([0.229, 0.224, 0.225],
                            device=rgbs.device).view(1, 3, 1, 1)
        rgbs = (rgbs - mean_) / std_
        # NOTE: step 3 -- PrepareForNet is a no-op here: the input is already
        # a contiguous B x C x H x W float tensor

        # get the disparity prediction and resize it back to the input resolution
        disp = self.dpAny(rgbs)
        disp = F.interpolate(
            disp[:, None], (H, W),
            mode='bilinear', align_corners=False
        )
        # the network predicts relative disparity (inverse depth) and it is
        # returned without any clamping of the depth range
        depth_map = disp
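        # Hedged sketch of one possible post-processing step (an assumption,
        # not performed by this module): cap the farthest depth at 100x the
        # nearest by flooring the disparity at 1/100 of its maximum.
        #   depth_map = depth_map.clamp(min=depth_map.max() / 100.0)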

        return depth_map
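
# Minimal end-to-end sketch (hypothetical shapes; assumes a CUDA device and a
# valid config as in the build() example above):
#
#   net = DepthAnything(cfg)
#   rgbs = torch.rand(2, 3, 480, 640).cuda()  # B x 3 x H x W, values in [0, 1]
#   with torch.no_grad():
#       disp = net.infer(rgbs)                # B x 1 x H x W relative disparity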