online_prev
This commit is contained in:
parent
3b9e9a1e8e
commit
eb0ce33864
90
app.py
90
app.py
@ -43,7 +43,9 @@ except ImportError as e:
|
||||
raise
|
||||
|
||||
# Constants
|
||||
MAX_FRAMES = 80
|
||||
MAX_FRAMES_OFFLINE = 80
|
||||
MAX_FRAMES_ONLINE = 300
|
||||
|
||||
COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
|
||||
MARKERS = [1, 5] # Cross for negative, Star for positive
|
||||
MARKER_SIZE = 8
|
||||
@ -88,13 +90,16 @@ vggt4track_model = vggt4track_model.to("cuda")
|
||||
|
||||
# Global model initialization
|
||||
print("🚀 Initializing local models...")
|
||||
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
|
||||
tracker_model.eval()
|
||||
tracker_model_offline = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
|
||||
tracker_model_offline.eval()
|
||||
tracker_model_online = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
|
||||
tracker_model_online.eval()
|
||||
predictor = get_sam_predictor()
|
||||
print("✅ Models loaded successfully!")
|
||||
|
||||
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
|
||||
|
||||
@spaces.GPU
|
||||
def gpu_run_inference(predictor_arg, image, points, boxes):
|
||||
"""GPU-accelerated SAM inference"""
|
||||
if predictor_arg is None:
|
||||
@ -118,6 +123,7 @@ def gpu_run_inference(predictor_arg, image, points, boxes):
|
||||
|
||||
return run_inference(predictor_arg, image, points, boxes)
|
||||
|
||||
@spaces.GPU
|
||||
def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"):
|
||||
"""GPU-accelerated tracking"""
|
||||
import torchvision.transforms as T
|
||||
@ -126,9 +132,13 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
|
||||
if tracker_model_arg is None or tracker_viser_arg is None:
|
||||
print("Initializing tracker models inside GPU function...")
|
||||
out_dir = os.path.join(temp_dir, "results")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
|
||||
tracker_model=tracker_model.cuda())
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
if mode == "offline":
|
||||
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
|
||||
tracker_model=tracker_model_offline.cuda())
|
||||
else:
|
||||
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
|
||||
tracker_model=tracker_model_online.cuda())
|
||||
|
||||
# Setup paths
|
||||
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
|
||||
@ -146,7 +156,10 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
|
||||
if scale < 1:
|
||||
new_h, new_w = int(h * scale), int(w * scale)
|
||||
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
||||
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES]
|
||||
if mode == "offline":
|
||||
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_OFFLINE]
|
||||
else:
|
||||
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_ONLINE]
|
||||
|
||||
# Move to GPU
|
||||
video_tensor = video_tensor.cuda()
|
||||
@ -524,7 +537,7 @@ def reset_points(original_img: str, sel_pix):
|
||||
print(f"❌ Error in reset_points: {e}")
|
||||
return None, []
|
||||
|
||||
def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
|
||||
def launch_viz(grid_size, vo_points, fps, original_image_state, processing_mode):
|
||||
"""Launch visualization with user-specific temp directory"""
|
||||
if original_image_state is None:
|
||||
return None, None, None
|
||||
@ -536,7 +549,7 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
|
||||
video_name = frame_data.get('video_name', 'video')
|
||||
|
||||
print(f"🚀 Starting tracking for video: {video_name}")
|
||||
print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
|
||||
print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}, mode={processing_mode}")
|
||||
|
||||
# Check for mask files
|
||||
mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
|
||||
@ -550,11 +563,11 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
|
||||
mask_path = mask_files[0] if mask_files else None
|
||||
|
||||
# Run tracker
|
||||
print("🎯 Running tracker...")
|
||||
print(f"🎯 Running tracker in {processing_mode} mode...")
|
||||
out_dir = os.path.join(temp_dir, "results")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=mode)
|
||||
gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=processing_mode)
|
||||
|
||||
# Process results
|
||||
npz_path = os.path.join(out_dir, "result.npz")
|
||||
@ -607,6 +620,7 @@ def clear_all_with_download():
|
||||
gr.update(value=50),
|
||||
gr.update(value=756),
|
||||
gr.update(value=3),
|
||||
gr.update(value="offline"), # processing_mode
|
||||
None, # tracking_video_download
|
||||
None) # HTML download component
|
||||
|
||||
@ -639,6 +653,13 @@ def get_video_settings(video_name):
|
||||
|
||||
return video_settings.get(video_name, (50, 756, 3))
|
||||
|
||||
def update_status_indicator(processing_mode):
|
||||
"""Update status indicator based on processing mode"""
|
||||
if processing_mode == "offline":
|
||||
return "**Status:** 🟢 Local Processing Mode (Offline)"
|
||||
else:
|
||||
return "**Status:** 🔵 Cloud Processing Mode (Online)"
|
||||
|
||||
# Create the Gradio interface
|
||||
print("🎨 Creating Gradio interface...")
|
||||
|
||||
@ -844,7 +865,7 @@ with gr.Blocks(
|
||||
""")
|
||||
|
||||
# Status indicator
|
||||
gr.Markdown("**Status:** 🟢 Local Processing Mode")
|
||||
status_indicator = gr.Markdown("**Status:** 🟢 Local Processing Mode (Offline)")
|
||||
|
||||
# Main content area - video upload left, 3D visualization right
|
||||
with gr.Row():
|
||||
@ -943,18 +964,29 @@ with gr.Blocks(
|
||||
with gr.Row():
|
||||
gr.Markdown("### ⚙️ Tracking Parameters")
|
||||
with gr.Row():
|
||||
grid_size = gr.Slider(
|
||||
minimum=10, maximum=100, step=10, value=50,
|
||||
label="Grid Size", info="Tracking detail level"
|
||||
)
|
||||
vo_points = gr.Slider(
|
||||
minimum=100, maximum=2000, step=50, value=756,
|
||||
label="VO Points", info="Motion accuracy"
|
||||
)
|
||||
fps = gr.Slider(
|
||||
minimum=1, maximum=20, step=1, value=3,
|
||||
label="FPS", info="Processing speed"
|
||||
)
|
||||
# 添加模式选择器
|
||||
with gr.Column(scale=1):
|
||||
processing_mode = gr.Radio(
|
||||
choices=["offline", "online"],
|
||||
value="offline",
|
||||
label="Processing Mode",
|
||||
info="Offline: default mode | Online: Sliding Window Mode"
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
grid_size = gr.Slider(
|
||||
minimum=10, maximum=100, step=10, value=50,
|
||||
label="Grid Size", info="Tracking detail level"
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
vo_points = gr.Slider(
|
||||
minimum=100, maximum=2000, step=50, value=756,
|
||||
label="VO Points", info="Motion accuracy"
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
fps = gr.Slider(
|
||||
minimum=1, maximum=20, step=1, value=3,
|
||||
label="FPS", info="Processing speed"
|
||||
)
|
||||
|
||||
# Advanced Point Selection with SAM - Collapsed by default
|
||||
with gr.Row():
|
||||
@ -1080,6 +1112,12 @@ with gr.Blocks(
|
||||
outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
|
||||
)
|
||||
|
||||
processing_mode.change(
|
||||
fn=update_status_indicator,
|
||||
inputs=[processing_mode],
|
||||
outputs=[status_indicator]
|
||||
)
|
||||
|
||||
interactive_frame.select(
|
||||
fn=select_point,
|
||||
inputs=[original_image_state, selected_points, point_type],
|
||||
@ -1094,12 +1132,12 @@ with gr.Blocks(
|
||||
|
||||
clear_all_btn.click(
|
||||
fn=clear_all_with_download,
|
||||
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download]
|
||||
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, processing_mode, tracking_video_download, html_download]
|
||||
)
|
||||
|
||||
launch_btn.click(
|
||||
fn=launch_viz,
|
||||
inputs=[grid_size, vo_points, fps, original_image_state],
|
||||
inputs=[grid_size, vo_points, fps, original_image_state, processing_mode],
|
||||
outputs=[viz_html, tracking_video_download, html_download]
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user