online_prev

This commit is contained in:
xiaoyuxi 2025-07-13 15:11:20 +08:00
parent 3b9e9a1e8e
commit eb0ce33864

90
app.py
View File

@ -43,7 +43,9 @@ except ImportError as e:
raise raise
# Constants # Constants
MAX_FRAMES = 80 MAX_FRAMES_OFFLINE = 80
MAX_FRAMES_ONLINE = 300
COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
MARKERS = [1, 5] # Cross for negative, Star for positive MARKERS = [1, 5] # Cross for negative, Star for positive
MARKER_SIZE = 8 MARKER_SIZE = 8
@ -88,13 +90,16 @@ vggt4track_model = vggt4track_model.to("cuda")
# Global model initialization # Global model initialization
print("🚀 Initializing local models...") print("🚀 Initializing local models...")
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline") tracker_model_offline = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
tracker_model.eval() tracker_model_offline.eval()
tracker_model_online = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
tracker_model_online.eval()
predictor = get_sam_predictor() predictor = get_sam_predictor()
print("✅ Models loaded successfully!") print("✅ Models loaded successfully!")
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"]) gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
@spaces.GPU
def gpu_run_inference(predictor_arg, image, points, boxes): def gpu_run_inference(predictor_arg, image, points, boxes):
"""GPU-accelerated SAM inference""" """GPU-accelerated SAM inference"""
if predictor_arg is None: if predictor_arg is None:
@ -118,6 +123,7 @@ def gpu_run_inference(predictor_arg, image, points, boxes):
return run_inference(predictor_arg, image, points, boxes) return run_inference(predictor_arg, image, points, boxes)
@spaces.GPU
def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"): def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"):
"""GPU-accelerated tracking""" """GPU-accelerated tracking"""
import torchvision.transforms as T import torchvision.transforms as T
@ -126,9 +132,13 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
if tracker_model_arg is None or tracker_viser_arg is None: if tracker_model_arg is None or tracker_viser_arg is None:
print("Initializing tracker models inside GPU function...") print("Initializing tracker models inside GPU function...")
out_dir = os.path.join(temp_dir, "results") out_dir = os.path.join(temp_dir, "results")
os.makedirs(out_dir, exist_ok=True) os.makedirs(out_dir, exist_ok=True)
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, if mode == "offline":
tracker_model=tracker_model.cuda()) tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
tracker_model=tracker_model_offline.cuda())
else:
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
tracker_model=tracker_model_online.cuda())
# Setup paths # Setup paths
video_path = os.path.join(temp_dir, f"{video_name}.mp4") video_path = os.path.join(temp_dir, f"{video_name}.mp4")
@ -146,7 +156,10 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
if scale < 1: if scale < 1:
new_h, new_w = int(h * scale), int(w * scale) new_h, new_w = int(h * scale), int(w * scale)
video_tensor = T.Resize((new_h, new_w))(video_tensor) video_tensor = T.Resize((new_h, new_w))(video_tensor)
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES] if mode == "offline":
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_OFFLINE]
else:
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_ONLINE]
# Move to GPU # Move to GPU
video_tensor = video_tensor.cuda() video_tensor = video_tensor.cuda()
@ -524,7 +537,7 @@ def reset_points(original_img: str, sel_pix):
print(f"❌ Error in reset_points: {e}") print(f"❌ Error in reset_points: {e}")
return None, [] return None, []
def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"): def launch_viz(grid_size, vo_points, fps, original_image_state, processing_mode):
"""Launch visualization with user-specific temp directory""" """Launch visualization with user-specific temp directory"""
if original_image_state is None: if original_image_state is None:
return None, None, None return None, None, None
@ -536,7 +549,7 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
video_name = frame_data.get('video_name', 'video') video_name = frame_data.get('video_name', 'video')
print(f"🚀 Starting tracking for video: {video_name}") print(f"🚀 Starting tracking for video: {video_name}")
print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}") print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}, mode={processing_mode}")
# Check for mask files # Check for mask files
mask_files = glob.glob(os.path.join(temp_dir, "*.png")) mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
@ -550,11 +563,11 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
mask_path = mask_files[0] if mask_files else None mask_path = mask_files[0] if mask_files else None
# Run tracker # Run tracker
print("🎯 Running tracker...") print(f"🎯 Running tracker in {processing_mode} mode...")
out_dir = os.path.join(temp_dir, "results") out_dir = os.path.join(temp_dir, "results")
os.makedirs(out_dir, exist_ok=True) os.makedirs(out_dir, exist_ok=True)
gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=mode) gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=processing_mode)
# Process results # Process results
npz_path = os.path.join(out_dir, "result.npz") npz_path = os.path.join(out_dir, "result.npz")
@ -607,6 +620,7 @@ def clear_all_with_download():
gr.update(value=50), gr.update(value=50),
gr.update(value=756), gr.update(value=756),
gr.update(value=3), gr.update(value=3),
gr.update(value="offline"), # processing_mode
None, # tracking_video_download None, # tracking_video_download
None) # HTML download component None) # HTML download component
@ -639,6 +653,13 @@ def get_video_settings(video_name):
return video_settings.get(video_name, (50, 756, 3)) return video_settings.get(video_name, (50, 756, 3))
def update_status_indicator(processing_mode):
    """Return the Markdown status line for the selected processing mode.

    Args:
        processing_mode: Value of the processing-mode selector. Only the
            exact string "offline" selects the offline label; every other
            value (including "online" and None) yields the online label.

    Returns:
        A Markdown-formatted status string.
    """
    # NOTE(review): the online label says "Cloud", but the online tracker
    # appears to be loaded in-process via Predictor.from_pretrained and moved
    # to CUDA like the offline one -- confirm the wording is intended.
    offline_status = "**Status:** 🟢 Local Processing Mode (Offline)"
    online_status = "**Status:** 🔵 Cloud Processing Mode (Online)"
    return offline_status if processing_mode == "offline" else online_status
# Create the Gradio interface # Create the Gradio interface
print("🎨 Creating Gradio interface...") print("🎨 Creating Gradio interface...")
@ -844,7 +865,7 @@ with gr.Blocks(
""") """)
# Status indicator # Status indicator
gr.Markdown("**Status:** 🟢 Local Processing Mode") status_indicator = gr.Markdown("**Status:** 🟢 Local Processing Mode (Offline)")
# Main content area - video upload left, 3D visualization right # Main content area - video upload left, 3D visualization right
with gr.Row(): with gr.Row():
@ -943,18 +964,29 @@ with gr.Blocks(
with gr.Row(): with gr.Row():
gr.Markdown("### ⚙️ Tracking Parameters") gr.Markdown("### ⚙️ Tracking Parameters")
with gr.Row(): with gr.Row():
grid_size = gr.Slider( # Add the processing-mode selector
minimum=10, maximum=100, step=10, value=50, with gr.Column(scale=1):
label="Grid Size", info="Tracking detail level" processing_mode = gr.Radio(
) choices=["offline", "online"],
vo_points = gr.Slider( value="offline",
minimum=100, maximum=2000, step=50, value=756, label="Processing Mode",
label="VO Points", info="Motion accuracy" info="Offline: default mode | Online: Sliding Window Mode"
) )
fps = gr.Slider( with gr.Column(scale=1):
minimum=1, maximum=20, step=1, value=3, grid_size = gr.Slider(
label="FPS", info="Processing speed" minimum=10, maximum=100, step=10, value=50,
) label="Grid Size", info="Tracking detail level"
)
with gr.Column(scale=1):
vo_points = gr.Slider(
minimum=100, maximum=2000, step=50, value=756,
label="VO Points", info="Motion accuracy"
)
with gr.Column(scale=1):
fps = gr.Slider(
minimum=1, maximum=20, step=1, value=3,
label="FPS", info="Processing speed"
)
# Advanced Point Selection with SAM - Collapsed by default # Advanced Point Selection with SAM - Collapsed by default
with gr.Row(): with gr.Row():
@ -1080,6 +1112,12 @@ with gr.Blocks(
outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps] outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
) )
processing_mode.change(
fn=update_status_indicator,
inputs=[processing_mode],
outputs=[status_indicator]
)
interactive_frame.select( interactive_frame.select(
fn=select_point, fn=select_point,
inputs=[original_image_state, selected_points, point_type], inputs=[original_image_state, selected_points, point_type],
@ -1094,12 +1132,12 @@ with gr.Blocks(
clear_all_btn.click( clear_all_btn.click(
fn=clear_all_with_download, fn=clear_all_with_download,
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download] outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, processing_mode, tracking_video_download, html_download]
) )
launch_btn.click( launch_btn.click(
fn=launch_viz, fn=launch_viz,
inputs=[grid_size, vo_points, fps, original_image_state], inputs=[grid_size, vo_points, fps, original_image_state, processing_mode],
outputs=[viz_html, tracking_video_download, html_download] outputs=[viz_html, tracking_video_download, html_download]
) )