online_prev
This commit is contained in:
parent
3b9e9a1e8e
commit
eb0ce33864
90
app.py
90
app.py
@ -43,7 +43,9 @@ except ImportError as e:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
MAX_FRAMES = 80
|
MAX_FRAMES_OFFLINE = 80
|
||||||
|
MAX_FRAMES_ONLINE = 300
|
||||||
|
|
||||||
COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
|
COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
|
||||||
MARKERS = [1, 5] # Cross for negative, Star for positive
|
MARKERS = [1, 5] # Cross for negative, Star for positive
|
||||||
MARKER_SIZE = 8
|
MARKER_SIZE = 8
|
||||||
@ -88,13 +90,16 @@ vggt4track_model = vggt4track_model.to("cuda")
|
|||||||
|
|
||||||
# Global model initialization
|
# Global model initialization
|
||||||
print("🚀 Initializing local models...")
|
print("🚀 Initializing local models...")
|
||||||
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
|
tracker_model_offline = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
|
||||||
tracker_model.eval()
|
tracker_model_offline.eval()
|
||||||
|
tracker_model_online = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
|
||||||
|
tracker_model_online.eval()
|
||||||
predictor = get_sam_predictor()
|
predictor = get_sam_predictor()
|
||||||
print("✅ Models loaded successfully!")
|
print("✅ Models loaded successfully!")
|
||||||
|
|
||||||
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
|
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
|
||||||
|
|
||||||
|
@spaces.GPU
|
||||||
def gpu_run_inference(predictor_arg, image, points, boxes):
|
def gpu_run_inference(predictor_arg, image, points, boxes):
|
||||||
"""GPU-accelerated SAM inference"""
|
"""GPU-accelerated SAM inference"""
|
||||||
if predictor_arg is None:
|
if predictor_arg is None:
|
||||||
@ -118,6 +123,7 @@ def gpu_run_inference(predictor_arg, image, points, boxes):
|
|||||||
|
|
||||||
return run_inference(predictor_arg, image, points, boxes)
|
return run_inference(predictor_arg, image, points, boxes)
|
||||||
|
|
||||||
|
@spaces.GPU
|
||||||
def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"):
|
def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"):
|
||||||
"""GPU-accelerated tracking"""
|
"""GPU-accelerated tracking"""
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
@ -126,9 +132,13 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
|
|||||||
if tracker_model_arg is None or tracker_viser_arg is None:
|
if tracker_model_arg is None or tracker_viser_arg is None:
|
||||||
print("Initializing tracker models inside GPU function...")
|
print("Initializing tracker models inside GPU function...")
|
||||||
out_dir = os.path.join(temp_dir, "results")
|
out_dir = os.path.join(temp_dir, "results")
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
|
if mode == "offline":
|
||||||
tracker_model=tracker_model.cuda())
|
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
|
||||||
|
tracker_model=tracker_model_offline.cuda())
|
||||||
|
else:
|
||||||
|
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
|
||||||
|
tracker_model=tracker_model_online.cuda())
|
||||||
|
|
||||||
# Setup paths
|
# Setup paths
|
||||||
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
|
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
|
||||||
@ -146,7 +156,10 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
|
|||||||
if scale < 1:
|
if scale < 1:
|
||||||
new_h, new_w = int(h * scale), int(w * scale)
|
new_h, new_w = int(h * scale), int(w * scale)
|
||||||
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
||||||
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES]
|
if mode == "offline":
|
||||||
|
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_OFFLINE]
|
||||||
|
else:
|
||||||
|
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_ONLINE]
|
||||||
|
|
||||||
# Move to GPU
|
# Move to GPU
|
||||||
video_tensor = video_tensor.cuda()
|
video_tensor = video_tensor.cuda()
|
||||||
@ -524,7 +537,7 @@ def reset_points(original_img: str, sel_pix):
|
|||||||
print(f"❌ Error in reset_points: {e}")
|
print(f"❌ Error in reset_points: {e}")
|
||||||
return None, []
|
return None, []
|
||||||
|
|
||||||
def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
|
def launch_viz(grid_size, vo_points, fps, original_image_state, processing_mode):
|
||||||
"""Launch visualization with user-specific temp directory"""
|
"""Launch visualization with user-specific temp directory"""
|
||||||
if original_image_state is None:
|
if original_image_state is None:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
@ -536,7 +549,7 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
|
|||||||
video_name = frame_data.get('video_name', 'video')
|
video_name = frame_data.get('video_name', 'video')
|
||||||
|
|
||||||
print(f"🚀 Starting tracking for video: {video_name}")
|
print(f"🚀 Starting tracking for video: {video_name}")
|
||||||
print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
|
print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}, mode={processing_mode}")
|
||||||
|
|
||||||
# Check for mask files
|
# Check for mask files
|
||||||
mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
|
mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
|
||||||
@ -550,11 +563,11 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
|
|||||||
mask_path = mask_files[0] if mask_files else None
|
mask_path = mask_files[0] if mask_files else None
|
||||||
|
|
||||||
# Run tracker
|
# Run tracker
|
||||||
print("🎯 Running tracker...")
|
print(f"🎯 Running tracker in {processing_mode} mode...")
|
||||||
out_dir = os.path.join(temp_dir, "results")
|
out_dir = os.path.join(temp_dir, "results")
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
|
||||||
gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=mode)
|
gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=processing_mode)
|
||||||
|
|
||||||
# Process results
|
# Process results
|
||||||
npz_path = os.path.join(out_dir, "result.npz")
|
npz_path = os.path.join(out_dir, "result.npz")
|
||||||
@ -607,6 +620,7 @@ def clear_all_with_download():
|
|||||||
gr.update(value=50),
|
gr.update(value=50),
|
||||||
gr.update(value=756),
|
gr.update(value=756),
|
||||||
gr.update(value=3),
|
gr.update(value=3),
|
||||||
|
gr.update(value="offline"), # processing_mode
|
||||||
None, # tracking_video_download
|
None, # tracking_video_download
|
||||||
None) # HTML download component
|
None) # HTML download component
|
||||||
|
|
||||||
@ -639,6 +653,13 @@ def get_video_settings(video_name):
|
|||||||
|
|
||||||
return video_settings.get(video_name, (50, 756, 3))
|
return video_settings.get(video_name, (50, 756, 3))
|
||||||
|
|
||||||
|
def update_status_indicator(processing_mode):
    """Return the Markdown status line for the selected processing mode.

    Args:
        processing_mode: Value of the "Processing Mode" radio, expected to be
            "offline" or "online".

    Returns:
        A Markdown string for the status indicator component. Any value other
        than "offline" is reported as the online mode, mirroring how the
        tracker picks its model.
    """
    if processing_mode == "offline":
        # Must stay identical to the initial text of the status Markdown.
        return "**Status:** 🟢 Local Processing Mode (Offline)"
    # The online tracker is also loaded and run in-process (it is the
    # sliding-window variant), so it must not be advertised as "Cloud".
    return "**Status:** 🔵 Sliding Window Mode (Online)"
|
||||||
|
|
||||||
# Create the Gradio interface
|
# Create the Gradio interface
|
||||||
print("🎨 Creating Gradio interface...")
|
print("🎨 Creating Gradio interface...")
|
||||||
|
|
||||||
@ -844,7 +865,7 @@ with gr.Blocks(
|
|||||||
""")
|
""")
|
||||||
|
|
||||||
# Status indicator
|
# Status indicator
|
||||||
gr.Markdown("**Status:** 🟢 Local Processing Mode")
|
status_indicator = gr.Markdown("**Status:** 🟢 Local Processing Mode (Offline)")
|
||||||
|
|
||||||
# Main content area - video upload left, 3D visualization right
|
# Main content area - video upload left, 3D visualization right
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -943,18 +964,29 @@ with gr.Blocks(
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
gr.Markdown("### ⚙️ Tracking Parameters")
|
gr.Markdown("### ⚙️ Tracking Parameters")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
grid_size = gr.Slider(
|
# Add the processing-mode selector
|
||||||
minimum=10, maximum=100, step=10, value=50,
|
with gr.Column(scale=1):
|
||||||
label="Grid Size", info="Tracking detail level"
|
processing_mode = gr.Radio(
|
||||||
)
|
choices=["offline", "online"],
|
||||||
vo_points = gr.Slider(
|
value="offline",
|
||||||
minimum=100, maximum=2000, step=50, value=756,
|
label="Processing Mode",
|
||||||
label="VO Points", info="Motion accuracy"
|
info="Offline: default mode | Online: Sliding Window Mode"
|
||||||
)
|
)
|
||||||
fps = gr.Slider(
|
with gr.Column(scale=1):
|
||||||
minimum=1, maximum=20, step=1, value=3,
|
grid_size = gr.Slider(
|
||||||
label="FPS", info="Processing speed"
|
minimum=10, maximum=100, step=10, value=50,
|
||||||
)
|
label="Grid Size", info="Tracking detail level"
|
||||||
|
)
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
vo_points = gr.Slider(
|
||||||
|
minimum=100, maximum=2000, step=50, value=756,
|
||||||
|
label="VO Points", info="Motion accuracy"
|
||||||
|
)
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
fps = gr.Slider(
|
||||||
|
minimum=1, maximum=20, step=1, value=3,
|
||||||
|
label="FPS", info="Processing speed"
|
||||||
|
)
|
||||||
|
|
||||||
# Advanced Point Selection with SAM - Collapsed by default
|
# Advanced Point Selection with SAM - Collapsed by default
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -1080,6 +1112,12 @@ with gr.Blocks(
|
|||||||
outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
|
outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
processing_mode.change(
|
||||||
|
fn=update_status_indicator,
|
||||||
|
inputs=[processing_mode],
|
||||||
|
outputs=[status_indicator]
|
||||||
|
)
|
||||||
|
|
||||||
interactive_frame.select(
|
interactive_frame.select(
|
||||||
fn=select_point,
|
fn=select_point,
|
||||||
inputs=[original_image_state, selected_points, point_type],
|
inputs=[original_image_state, selected_points, point_type],
|
||||||
@ -1094,12 +1132,12 @@ with gr.Blocks(
|
|||||||
|
|
||||||
clear_all_btn.click(
|
clear_all_btn.click(
|
||||||
fn=clear_all_with_download,
|
fn=clear_all_with_download,
|
||||||
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download]
|
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, processing_mode, tracking_video_download, html_download]
|
||||||
)
|
)
|
||||||
|
|
||||||
launch_btn.click(
|
launch_btn.click(
|
||||||
fn=launch_viz,
|
fn=launch_viz,
|
||||||
inputs=[grid_size, vo_points, fps, original_image_state],
|
inputs=[grid_size, vo_points, fps, original_image_state, processing_mode],
|
||||||
outputs=[viz_html, tracking_video_download, html_download]
|
outputs=[viz_html, tracking_video_download, html_download]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user