129 lines
4.2 KiB
Python
129 lines
4.2 KiB
Python
import gc
|
|
import numpy as np
|
|
import torch
|
|
from typing import Optional, Tuple, List, Union
|
|
import warnings
|
|
import cv2
|
|
try:
|
|
from transformers import SamModel, SamProcessor
|
|
from huggingface_hub import hf_hub_download
|
|
HF_AVAILABLE = True
|
|
except ImportError:
|
|
HF_AVAILABLE = False
|
|
warnings.warn("transformers or huggingface_hub not available. HF SAM models will not work.")
|
|
|
|
# Hugging Face model mapping
|
|
HF_MODELS = {
|
|
'vit_b': 'facebook/sam-vit-base',
|
|
'vit_l': 'facebook/sam-vit-large',
|
|
'vit_h': 'facebook/sam-vit-huge'
|
|
}
|
|
|
|
class HFSamPredictor:
|
|
"""
|
|
Hugging Face version of SamPredictor that wraps the transformers SAM models.
|
|
This class provides the same interface as the original SamPredictor for seamless integration.
|
|
"""
|
|
|
|
def __init__(self, model: SamModel, processor: SamProcessor, device: Optional[str] = None):
|
|
"""
|
|
Initialize the HF SAM predictor.
|
|
|
|
Args:
|
|
model: The SAM model from transformers
|
|
processor: The SAM processor from transformers
|
|
device: Device to run the model on ('cuda', 'cpu', etc.)
|
|
"""
|
|
self.model = model
|
|
self.processor = processor
|
|
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
|
|
self.model.to(self.device)
|
|
self.model.eval()
|
|
|
|
# Store the current image and its features
|
|
self.original_size = None
|
|
self.input_size = None
|
|
self.features = None
|
|
self.image = None
|
|
|
|
@classmethod
|
|
def from_pretrained(cls, model_name: str, device: Optional[str] = None) -> 'HFSamPredictor':
|
|
"""
|
|
Load a SAM model from Hugging Face Hub.
|
|
|
|
Args:
|
|
model_name: Model name from HF_MODELS or direct HF model path
|
|
device: Device to load the model on
|
|
|
|
Returns:
|
|
HFSamPredictor instance
|
|
"""
|
|
if not HF_AVAILABLE:
|
|
raise ImportError("transformers and huggingface_hub are required for HF SAM models")
|
|
|
|
# Map model type to HF model name if needed
|
|
if model_name in HF_MODELS:
|
|
model_name = HF_MODELS[model_name]
|
|
|
|
print(f"Loading SAM model from Hugging Face: {model_name}")
|
|
|
|
# Load model and processor
|
|
model = SamModel.from_pretrained(model_name)
|
|
processor = SamProcessor.from_pretrained(model_name)
|
|
return cls(model, processor, device)
|
|
|
|
def preprocess(self, image: np.ndarray,
|
|
input_points: List[List[float]], input_labels: List[int]) -> None:
|
|
"""
|
|
Set the image for prediction. This preprocesses the image and extracts features.
|
|
|
|
Args:
|
|
image: Input image as numpy array (H, W, C) in RGB format
|
|
"""
|
|
if image.dtype != np.uint8:
|
|
image = (image * 255).astype(np.uint8)
|
|
|
|
self.image = image
|
|
self.original_size = image.shape[:2]
|
|
|
|
# Use dummy point to ensure processor returns original_sizes & reshaped_input_sizes
|
|
inputs = self.processor(
|
|
images=image,
|
|
input_points=input_points,
|
|
input_labels=input_labels,
|
|
return_tensors="pt"
|
|
)
|
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
|
|
self.input_size = inputs['pixel_values'].shape[-2:]
|
|
self.features = inputs
|
|
return inputs
|
|
|
|
|
|
def get_hf_sam_predictor(model_type: str = 'vit_h', device: Optional[str] = None,
|
|
image: Optional[np.ndarray] = None) -> HFSamPredictor:
|
|
"""
|
|
Get a Hugging Face SAM predictor with the same interface as the original get_sam_predictor.
|
|
|
|
Args:
|
|
model_type: Model type ('vit_b', 'vit_l', 'vit_h')
|
|
device: Device to run the model on
|
|
image: Optional image to set immediately
|
|
|
|
Returns:
|
|
HFSamPredictor instance
|
|
"""
|
|
if not HF_AVAILABLE:
|
|
raise ImportError("transformers and huggingface_hub are required for HF SAM models")
|
|
|
|
if device is None:
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
# Load the predictor
|
|
predictor = HFSamPredictor.from_pretrained(model_type, device)
|
|
|
|
# Set image if provided
|
|
if image is not None:
|
|
predictor.set_image(image)
|
|
|
|
return predictor |