43 lines
1.8 KiB
Python
43 lines
1.8 KiB
Python
from typing import *
|
|
|
|
import click
|
|
import torch
|
|
|
|
|
|
class MGEBaselineInterface:
    """
    Abstract base for model wrappers, giving different models a uniform
    interface for loading and inference.

    The `load` and `infer` methods defined here only raise
    `NotImplementedError`; concrete wrappers are meant to override them.
    """

    # Device of the wrapped model. It is only declared (annotated) here and
    # never assigned in this class — presumably each concrete wrapper sets it;
    # TODO confirm against the subclasses.
    device: torch.device

    # NOTE: `click.command()` is stacked on top of `staticmethod`, so the class
    # attribute `load` is whatever `click.command()` returns (a click Command),
    # with the staticmethod object as its callback. Calling that callback relies
    # on staticmethod objects being directly callable (Python 3.10+).
    @click.command()
    @staticmethod
    def load(*args, **kwargs) -> "MGEBaselineInterface":
        """
        Customized static method that creates an instance of the model wrapper
        from command-line arguments. Decorated by `click.command()`.

        ### Raises

        `NotImplementedError`: always, in this base implementation.
        """
        # This is a static method: there is no `self` (or `cls`) in scope, so the
        # concrete class name cannot be reported. Referencing `self` here used to
        # raise NameError instead of the intended NotImplementedError.
        raise NotImplementedError("A subclass of MGEBaselineInterface has not implemented the load method.")

    def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Run the wrapped model on an image.

        ### Parameters

        `image`: [B, 3, H, W] or [3, H, W], RGB values in range [0, 1]

        `intrinsics`: [B, 3, 3] or [3, 3], camera intrinsics. Optional.

        ### Returns

        A dictionary containing:

        - `points_*`. point map output in OpenCV identity camera space.
            Supported suffixes: `metric`, `scale_invariant`, `affine_invariant`.
        - `depth_*`. depth map output
            Supported suffixes: `metric` (in meters), `scale_invariant`, `affine_invariant`.
        - `disparity_affine_invariant`. affine disparity map output

        ### Raises

        `NotImplementedError`: always, in this base implementation.
        """
        raise NotImplementedError(f"{type(self).__name__} has not implemented the infer method.")

    def infer_for_evaluation(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Inference entry point used for evaluation.

        If the model has a special evaluation mode, override this method to
        provide the evaluation-mode inference. By default, this method simply
        calls `infer()` with the same arguments and returns its result.
        """
        return self.infer(image, intrinsics)