From eb0ce33864a960807343161dd02fd107f0782b59 Mon Sep 17 00:00:00 2001 From: xiaoyuxi Date: Sun, 13 Jul 2025 15:11:20 +0800 Subject: [PATCH] online_prev --- app.py | 90 +++++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 64 insertions(+), 26 deletions(-) diff --git a/app.py b/app.py index 2255da9..dca6660 100644 --- a/app.py +++ b/app.py @@ -43,7 +43,9 @@ except ImportError as e: raise # Constants -MAX_FRAMES = 80 +MAX_FRAMES_OFFLINE = 80 +MAX_FRAMES_ONLINE = 300 + COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive MARKERS = [1, 5] # Cross for negative, Star for positive MARKER_SIZE = 8 @@ -88,13 +90,16 @@ vggt4track_model = vggt4track_model.to("cuda") # Global model initialization print("๐Ÿš€ Initializing local models...") -tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline") -tracker_model.eval() +tracker_model_offline = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline") +tracker_model_offline.eval() +tracker_model_online = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online") +tracker_model_online.eval() predictor = get_sam_predictor() print("โœ… Models loaded successfully!") gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"]) +@spaces.GPU def gpu_run_inference(predictor_arg, image, points, boxes): """GPU-accelerated SAM inference""" if predictor_arg is None: @@ -118,6 +123,7 @@ def gpu_run_inference(predictor_arg, image, points, boxes): return run_inference(predictor_arg, image, points, boxes) +@spaces.GPU def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"): """GPU-accelerated tracking""" import torchvision.transforms as T @@ -126,9 +132,13 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, if tracker_model_arg is None or tracker_viser_arg is None: print("Initializing tracker models inside GPU function...") out_dir = os.path.join(temp_dir, "results") - os.makedirs(out_dir, exist_ok=True) - tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, - tracker_model=tracker_model.cuda()) + os.makedirs(out_dir, exist_ok=True) + if mode == "offline": + tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, + tracker_model=tracker_model_offline.cuda()) + else: + tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, + tracker_model=tracker_model_online.cuda()) # Setup paths video_path = os.path.join(temp_dir, f"{video_name}.mp4") @@ -146,7 +156,10 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, if scale < 1: new_h, new_w = int(h * scale), int(w * scale) video_tensor = T.Resize((new_h, new_w))(video_tensor) - video_tensor = video_tensor[::fps].float()[:MAX_FRAMES] + if mode == "offline": + video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_OFFLINE] + else: + video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_ONLINE] # Move to GPU video_tensor = video_tensor.cuda() @@ -524,7 +537,7 @@ def reset_points(original_img: str, sel_pix): print(f"โŒ Error in reset_points: {e}") return None, [] -def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"): +def launch_viz(grid_size, vo_points, fps, original_image_state, processing_mode): """Launch visualization with user-specific temp directory""" if original_image_state is None: return None, None, None @@ -536,7 +549,7 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"): video_name = frame_data.get('video_name', 'video') print(f"๐Ÿš€ Starting tracking for video: {video_name}") - print(f"๐Ÿ“Š Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}") + print(f"๐Ÿ“Š Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}, mode={processing_mode}") # Check for mask files mask_files = glob.glob(os.path.join(temp_dir, "*.png")) @@ -550,11 +563,11 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"): mask_path = mask_files[0] if mask_files else None # Run tracker - print("๐ŸŽฏ Running tracker...") + print(f"๐ŸŽฏ Running tracker in {processing_mode} mode...") out_dir = os.path.join(temp_dir, "results") os.makedirs(out_dir, exist_ok=True) - gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=mode) + gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=processing_mode) # Process results npz_path = os.path.join(out_dir, "result.npz") @@ -607,6 +620,7 @@ def clear_all_with_download(): gr.update(value=50), gr.update(value=756), gr.update(value=3), + gr.update(value="offline"), # processing_mode None, # tracking_video_download None) # HTML download component @@ -639,6 +653,13 @@ def get_video_settings(video_name): return video_settings.get(video_name, (50, 756, 3)) +def update_status_indicator(processing_mode): + """Update status indicator based on processing mode""" + if processing_mode == "offline": + return "**Status:** ๐ŸŸข Local Processing Mode (Offline)" + else: + return "**Status:** ๐Ÿ”ต Cloud Processing Mode (Online)" + # Create the Gradio interface print("๐ŸŽจ Creating Gradio interface...") @@ -844,7 +865,7 @@ with gr.Blocks( """) # Status indicator - gr.Markdown("**Status:** ๐ŸŸข Local Processing Mode") + status_indicator = gr.Markdown("**Status:** ๐ŸŸข Local Processing Mode (Offline)") # Main content area - video upload left, 3D visualization right with gr.Row(): @@ -943,18 +964,29 @@ with gr.Blocks( with gr.Row(): gr.Markdown("### โš™๏ธ Tracking Parameters") with gr.Row(): - grid_size = gr.Slider( - minimum=10, maximum=100, step=10, value=50, - label="Grid Size", info="Tracking detail level" - ) - vo_points = gr.Slider( - minimum=100, maximum=2000, step=50, value=756, - label="VO Points", info="Motion accuracy" - ) - fps = gr.Slider( - minimum=1, maximum=20, step=1, value=3, - label="FPS", info="Processing speed" - ) + # ๆทปๅŠ ๆจกๅผ้€‰ๆ‹ฉๅ™จ + with gr.Column(scale=1): + processing_mode = gr.Radio( + choices=["offline", "online"], + value="offline", + label="Processing Mode", + info="Offline: default mode | Online: Sliding Window Mode" + ) + with gr.Column(scale=1): + grid_size = gr.Slider( + minimum=10, maximum=100, step=10, value=50, + label="Grid Size", info="Tracking detail level" + ) + with gr.Column(scale=1): + vo_points = gr.Slider( + minimum=100, maximum=2000, step=50, value=756, + label="VO Points", info="Motion accuracy" + ) + with gr.Column(scale=1): + fps = gr.Slider( + minimum=1, maximum=20, step=1, value=3, + label="FPS", info="Processing speed" + ) # Advanced Point Selection with SAM - Collapsed by default with gr.Row(): @@ -1080,6 +1112,12 @@ with gr.Blocks( outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps] ) + processing_mode.change( + fn=update_status_indicator, + inputs=[processing_mode], + outputs=[status_indicator] + ) + interactive_frame.select( fn=select_point, inputs=[original_image_state, selected_points, point_type], @@ -1094,12 +1132,12 @@ with gr.Blocks( clear_all_btn.click( fn=clear_all_with_download, - outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download] + outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, processing_mode, tracking_video_download, html_download] ) launch_btn.click( fn=launch_viz, - inputs=[grid_size, vo_points, fps, original_image_state], + inputs=[grid_size, vo_points, fps, original_image_state, processing_mode], outputs=[viz_html, tracking_video_download, html_download] )