online_prev

This commit is contained in:
xiaoyuxi 2025-07-13 15:11:20 +08:00
parent 3b9e9a1e8e
commit eb0ce33864

88
app.py
View File

@ -43,7 +43,9 @@ except ImportError as e:
raise
# Constants
MAX_FRAMES = 80
MAX_FRAMES_OFFLINE = 80
MAX_FRAMES_ONLINE = 300
COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
MARKERS = [1, 5] # Cross for negative, Star for positive
MARKER_SIZE = 8
@ -88,13 +90,16 @@ vggt4track_model = vggt4track_model.to("cuda")
# Global model initialization
print("🚀 Initializing local models...")
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
tracker_model.eval()
tracker_model_offline = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
tracker_model_offline.eval()
tracker_model_online = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
tracker_model_online.eval()
predictor = get_sam_predictor()
print("✅ Models loaded successfully!")
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
@spaces.GPU
def gpu_run_inference(predictor_arg, image, points, boxes):
"""GPU-accelerated SAM inference"""
if predictor_arg is None:
@ -118,6 +123,7 @@ def gpu_run_inference(predictor_arg, image, points, boxes):
return run_inference(predictor_arg, image, points, boxes)
@spaces.GPU
def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"):
"""GPU-accelerated tracking"""
import torchvision.transforms as T
@ -127,8 +133,12 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
print("Initializing tracker models inside GPU function...")
out_dir = os.path.join(temp_dir, "results")
os.makedirs(out_dir, exist_ok=True)
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
tracker_model=tracker_model.cuda())
if mode == "offline":
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
tracker_model=tracker_model_offline.cuda())
else:
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
tracker_model=tracker_model_online.cuda())
# Setup paths
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
@ -146,7 +156,10 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
if scale < 1:
new_h, new_w = int(h * scale), int(w * scale)
video_tensor = T.Resize((new_h, new_w))(video_tensor)
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES]
if mode == "offline":
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_OFFLINE]
else:
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_ONLINE]
# Move to GPU
video_tensor = video_tensor.cuda()
@ -524,7 +537,7 @@ def reset_points(original_img: str, sel_pix):
print(f"❌ Error in reset_points: {e}")
return None, []
def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
def launch_viz(grid_size, vo_points, fps, original_image_state, processing_mode):
"""Launch visualization with user-specific temp directory"""
if original_image_state is None:
return None, None, None
@ -536,7 +549,7 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
video_name = frame_data.get('video_name', 'video')
print(f"🚀 Starting tracking for video: {video_name}")
print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}, mode={processing_mode}")
# Check for mask files
mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
@ -550,11 +563,11 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
mask_path = mask_files[0] if mask_files else None
# Run tracker
print("🎯 Running tracker...")
print(f"🎯 Running tracker in {processing_mode} mode...")
out_dir = os.path.join(temp_dir, "results")
os.makedirs(out_dir, exist_ok=True)
gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=mode)
gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=processing_mode)
# Process results
npz_path = os.path.join(out_dir, "result.npz")
@ -607,6 +620,7 @@ def clear_all_with_download():
gr.update(value=50),
gr.update(value=756),
gr.update(value=3),
gr.update(value="offline"), # processing_mode
None, # tracking_video_download
None) # HTML download component
@ -639,6 +653,13 @@ def get_video_settings(video_name):
return video_settings.get(video_name, (50, 756, 3))
def update_status_indicator(processing_mode):
"""Update status indicator based on processing mode"""
if processing_mode == "offline":
return "**Status:** 🟢 Local Processing Mode (Offline)"
else:
return "**Status:** 🔵 Cloud Processing Mode (Online)"
# Create the Gradio interface
print("🎨 Creating Gradio interface...")
@ -844,7 +865,7 @@ with gr.Blocks(
""")
# Status indicator
gr.Markdown("**Status:** 🟢 Local Processing Mode")
status_indicator = gr.Markdown("**Status:** 🟢 Local Processing Mode (Offline)")
# Main content area - video upload left, 3D visualization right
with gr.Row():
@ -943,18 +964,29 @@ with gr.Blocks(
with gr.Row():
gr.Markdown("### ⚙️ Tracking Parameters")
with gr.Row():
grid_size = gr.Slider(
minimum=10, maximum=100, step=10, value=50,
label="Grid Size", info="Tracking detail level"
)
vo_points = gr.Slider(
minimum=100, maximum=2000, step=50, value=756,
label="VO Points", info="Motion accuracy"
)
fps = gr.Slider(
minimum=1, maximum=20, step=1, value=3,
label="FPS", info="Processing speed"
)
# 添加模式选择器
with gr.Column(scale=1):
processing_mode = gr.Radio(
choices=["offline", "online"],
value="offline",
label="Processing Mode",
info="Offline: default mode | Online: Sliding Window Mode"
)
with gr.Column(scale=1):
grid_size = gr.Slider(
minimum=10, maximum=100, step=10, value=50,
label="Grid Size", info="Tracking detail level"
)
with gr.Column(scale=1):
vo_points = gr.Slider(
minimum=100, maximum=2000, step=50, value=756,
label="VO Points", info="Motion accuracy"
)
with gr.Column(scale=1):
fps = gr.Slider(
minimum=1, maximum=20, step=1, value=3,
label="FPS", info="Processing speed"
)
# Advanced Point Selection with SAM - Collapsed by default
with gr.Row():
@ -1080,6 +1112,12 @@ with gr.Blocks(
outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
)
processing_mode.change(
fn=update_status_indicator,
inputs=[processing_mode],
outputs=[status_indicator]
)
interactive_frame.select(
fn=select_point,
inputs=[original_image_state, selected_points, point_type],
@ -1094,12 +1132,12 @@ with gr.Blocks(
clear_all_btn.click(
fn=clear_all_with_download,
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download]
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, processing_mode, tracking_video_download, html_download]
)
launch_btn.click(
fn=launch_viz,
inputs=[grid_size, vo_points, fps, original_image_state],
inputs=[grid_size, vo_points, fps, original_image_state, processing_mode],
outputs=[viz_html, tracking_video_download, html_download]
)