| |
| |
|
|
| |
| |
|
|
| import copy |
| import os |
| from datetime import datetime |
|
|
| import gradio as gr |
|
|
| os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1" |
| import tempfile |
|
|
| import cv2 |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import spaces |
|
|
| import torch |
|
|
| from moviepy.editor import ImageSequenceClip |
| from PIL import Image |
| from sam2.build_sam import build_sam2_video_predictor |
|
|
| |
# Page header rendered via gr.Markdown. The first <font> element is now
# properly closed with </font> (it was previously "closed" with another <font>).
title = "<center><strong><font size='8'>EdgeTAM</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"


# Usage instructions shown above the input video.
description_p = """# Instructions
<ol>
<li> Upload one video or click one example video</li>
<li> Click 'include' point type, select the object to segment and track</li>
<li> Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking</li>
<li> Click the 'Track' button to obtain the masked video </li>
</ol>
"""


# Example videos fed to gr.Examples; each row holds the single `video_in` input.
examples = [
    ["examples/trimmed/01_dog.mp4"],
    ["examples/trimmed/02_cups.mp4"],
    ["examples/trimmed/03_blocks.mp4"],
    ["examples/trimmed/04_coffee.mp4"],
    ["examples/trimmed/05_default_juggle.mp4"],
    ["examples/trimmed/01_breakdancer.mp4"],
    ["examples/trimmed/02_hummingbird.mp4"],
    ["examples/trimmed/03_skateboarder.mp4"],
    ["examples/trimmed/04_octopus.mp4"],
    ["examples/trimmed/05_landing_dog_soccer.mp4"],
    ["examples/trimmed/06_pingpong.mp4"],
    ["examples/trimmed/07_snowboarder.mp4"],
    ["examples/trimmed/08_driving.mp4"],
    ["examples/trimmed/09_birdcartoon.mp4"],
    ["examples/trimmed/10_cloth_magic.mp4"],
    ["examples/trimmed/11_polevault.mp4"],
    ["examples/trimmed/12_hideandseek.mp4"],
    ["examples/trimmed/13_butterfly.mp4"],
    ["examples/trimmed/14_social_dog_training.mp4"],
    ["examples/trimmed/15_cricket.mp4"],
    ["examples/trimmed/16_robotarm.mp4"],
    ["examples/trimmed/17_childrendancing.mp4"],
    ["examples/trimmed/18_threedogs.mp4"],
    ["examples/trimmed/19_cyclist.mp4"],
    ["examples/trimmed/20_doughkneading.mp4"],
    ["examples/trimmed/21_biker.mp4"],
    ["examples/trimmed/22_dogskateboarder.mp4"],
    ["examples/trimmed/23_racecar.mp4"],
    ["examples/trimmed/24_clownfish.mp4"],
]
|
|
# The demo tracks a single object; every prompt and rendered mask uses this id.
OBJ_ID = 0

# EdgeTAM config name and weights handed to the SAM2 video-predictor builder.
model_cfg = "edgetam.yaml"
sam2_checkpoint = "checkpoints/edgetam.pt"

# Built once at import time on CPU; the event handlers call predictor.to(...)
# to move it to the device they need.
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
|
|
|
|
def get_video_fps(video_path):
    """Return the frame rate reported by the video at ``video_path``.

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        The ``cv2.CAP_PROP_FPS`` value (a float), or ``None`` if the file
        could not be opened. OpenCV reports ``0`` when the backend cannot
        provide the property, so callers should treat a falsy result as
        "unknown".
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print("Error: Could not open video.")
            return None
        return cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Release the decoder handle on every path, including the error path.
        cap.release()
|
|
|
|
def reset_state(inference_state):
    """Wipe prompts, tracking outputs and object registrations in place.

    Inner per-object containers are emptied first, then the outer maps, so
    any other reference still holding an inner container also sees it
    cleared. The same dict object is returned.
    """
    frame_buckets = ("cond_frame_outputs", "non_cond_frame_outputs")

    # Per-object prompt inputs (points, then masks).
    for prompt_key in ("point_inputs_per_obj", "mask_inputs_per_obj"):
        for prompts in inference_state[prompt_key].values():
            prompts.clear()

    # Per-object outputs, both committed and temporary.
    for output_key in ("output_dict_per_obj", "temp_output_dict_per_obj"):
        for per_obj in inference_state[output_key].values():
            for bucket in frame_buckets:
                per_obj[bucket].clear()

    # Consolidated, cross-object outputs and their frame indices.
    for merged_key in ("output_dict", "consolidated_frame_inds"):
        for bucket in frame_buckets:
            inference_state[merged_key][bucket].clear()

    inference_state["tracking_has_started"] = False

    # Finally drop the object registry and the now-empty per-object maps.
    for registry_key in (
        "frames_already_tracked",
        "obj_id_to_idx",
        "obj_idx_to_id",
        "obj_ids",
        "point_inputs_per_obj",
        "mask_inputs_per_obj",
        "output_dict_per_obj",
        "temp_output_dict_per_obj",
    ):
        inference_state[registry_key].clear()

    return inference_state
|
|
|
|
def reset(
    first_frame,
    all_frames,
    input_points,
    input_labels,
    inference_state,
):
    """Discard the loaded video, all prompts and the tracking state.

    The incoming session values are ignored; everything is replaced with
    empty defaults. The return order matches the Reset button's outputs:
    video_in, video_in_drawer, points_map, output_image, output_video,
    followed by the five session-state values.
    """
    ui_outputs = (
        None,  # clear the input video
        gr.update(open=True),  # reopen the input drawer
        None,  # clear the point-prompt frame
        None,  # clear the reference mask
        gr.update(value=None, visible=False),  # hide the tracked video
    )
    # first_frame, all_frames, input_points, input_labels, inference_state
    fresh_session = (None, None, [], [], None)
    return ui_outputs + fresh_session
|
|
|
|
def clear_points(
    first_frame,
    all_frames,
    input_points,
    input_labels,
    inference_state,
):
    """Forget every clicked prompt point while keeping the loaded video.

    If tracking has already been run on ``inference_state``, it is reset via
    ``reset_state``; otherwise it is passed through unchanged. The points
    preview falls back to the clean first frame.
    """
    tracking_ran = bool(inference_state) and inference_state["tracking_has_started"]
    if tracking_ran:
        inference_state = reset_state(inference_state)

    return (
        first_frame,  # points_map: the first frame without point markers
        None,  # output_image: no reference mask
        gr.update(value=None, visible=False),  # output_video: hidden
        first_frame,
        all_frames,
        [],  # input_points
        [],  # input_labels
        inference_state,
    )
|
|
|
|
def _unloaded_video_outputs(
    first_frame, all_frames, input_points, input_labels, inference_state
):
    """Handler outputs when no usable video is available: reopen the drawer,
    blank both previews, hide the result video, pass session state through."""
    return (
        gr.update(open=True),
        None,
        None,
        gr.update(value=None, visible=False),
        first_frame,
        all_frames,
        input_points,
        input_labels,
        inference_state,
    )


def _read_rgb_frames(video_path):
    """Decode every frame of ``video_path`` as an RGB array.

    Returns ``None`` if OpenCV cannot open the file, otherwise a (possibly
    empty) list of frames. The capture is always released.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        frames = []
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # OpenCV decodes to BGR; PIL and the UI expect RGB. cvtColor
            # already returns a fresh ndarray, so no extra copy is needed.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return frames
    finally:
        cap.release()


def preprocess_video_in(
    video_path,
    first_frame,
    all_frames,
    input_points,
    input_labels,
    inference_state,
):
    """Load a newly selected video and prepare a fresh tracking session.

    Decodes all frames for later rendering, keeps a copy of the first frame
    as the prompt canvas, clears any previous prompts, and builds a new
    predictor inference state for ``video_path``.

    Returns:
        Outputs in the order: video_in_drawer, points_map, output_image,
        output_video, first_frame, all_frames, input_points, input_labels,
        inference_state. When there is no video, it cannot be opened, or it
        has no readable frames, the drawer is reopened and the incoming
        session state is returned unchanged.
    """
    session = (first_frame, all_frames, input_points, input_labels, inference_state)
    if video_path is None:
        return _unloaded_video_outputs(*session)

    frames = _read_rgb_frames(video_path)
    if frames is None:
        print("Error: Could not open video.")
        return _unloaded_video_outputs(*session)
    if not frames:
        # Previously this fell through with first_frame=None and still built
        # an inference state, leaving the UI in a broken state.
        print("Error: Video contains no readable frames.")
        return _unloaded_video_outputs(*session)

    # Independent copy so the prompt canvas is not aliased with all_frames[0].
    first_frame = copy.deepcopy(frames[0])

    predictor.to("cpu")
    inference_state = predictor.init_state(
        offload_video_to_cpu=True,
        offload_state_to_cpu=True,
        video_path=video_path,
    )

    return (
        gr.update(open=False),
        first_frame,
        None,
        gr.update(value=None, visible=False),
        first_frame,
        frames,
        [],
        [],
        inference_state,
    )
|
|
|
|
def segment_with_points(
    point_type,
    first_frame,
    all_frames,
    input_points,
    input_labels,
    inference_state,
    evt: gr.SelectData,
):
    """Add the clicked point as a prompt and re-segment the first frame.

    ``evt.index`` is appended to ``input_points`` (assumed to be the clicked
    ``[x, y]`` pixel; it is used directly as a cv2.circle centre). Label 1 is
    recorded for "include" and 0 for "exclude". The predictor runs on CPU
    against frame 0 with all points gathered so far.

    Returns:
        (points preview, first frame with mask overlay, first_frame,
        all_frames, input_points, input_labels, inference_state).
    """
    predictor.to("cpu")
    if inference_state:
        inference_state["device"] = predictor.device

    input_points.append(evt.index)
    print(f"TRACKING INPUT POINT: {input_points}")

    label_by_type = {"include": 1, "exclude": 0}
    if point_type in label_by_type:
        input_labels.append(label_by_type[point_type])
    print(f"TRACKING INPUT LABEL: {input_labels}")

    base_rgba = Image.fromarray(first_frame).convert("RGBA")
    width, height = base_rgba.size
    # Marker radius: 1% of the shorter image side.
    dot_radius = int(0.01 * min(width, height))

    # Draw green dots for include points and red dots for exclude points.
    overlay = np.zeros((height, width, 4), dtype=np.uint8)
    for idx, point in enumerate(input_points):
        dot_color = (0, 255, 0, 255) if input_labels[idx] == 1 else (255, 0, 0, 255)
        cv2.circle(overlay, point, dot_radius, dot_color, -1)
    selected_point_map = Image.alpha_composite(
        base_rgba, Image.fromarray(overlay, "RGBA")
    )

    _, _, out_mask_logits = predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=OBJ_ID,
        points=np.array(input_points, dtype=np.float32),
        labels=np.array(input_labels, dtype=np.int32),
    )

    binary_mask = (out_mask_logits[0] > 0.0).cpu().numpy()
    first_frame_output = Image.alpha_composite(base_rgba, show_mask(binary_mask))

    torch.cuda.empty_cache()
    return (
        selected_point_map,
        first_frame_output,
        first_frame,
        all_frames,
        input_points,
        input_labels,
        inference_state,
    )
|
|
|
|
def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    """Turn a binary mask into a semi-transparent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width); it must
            hold exactly height*width elements, since it is reshaped to
            (height, width, 1).
        obj_id: Index into matplotlib's "tab10" colormap; 0 when None.
        random_color: Use a random RGB colour instead of the colormap.
        convert_to_image: Return a PIL "RGBA" image instead of the raw array.

    Returns:
        A uint8 (height, width, 4) overlay with alpha 0.6 (153) on masked
        pixels and fully transparent elsewhere, as an array or PIL image.
    """
    if not random_color:
        palette_index = obj_id if obj_id is not None else 0
        rgb = plt.get_cmap("tab10")(palette_index)[:3]
        color = np.array([*rgb, 0.6])
    else:
        color = np.append(np.random.random(3), 0.6)

    height, width = mask.shape[-2:]
    rgba = mask.reshape(height, width, 1) * color.reshape(1, 1, -1)
    overlay = (rgba * 255).astype(np.uint8)
    return Image.fromarray(overlay, "RGBA") if convert_to_image else overlay
|
|
|
|
def _render_tracked_frames(all_frames, video_segments, vis_frame_stride=1):
    """Composite the OBJ_ID mask over the matching frames of ``all_frames``.

    Frame indices are visited in ascending order (rather than assuming the
    keys are exactly 0..N-1), keeping every ``vis_frame_stride``-th one.
    Returns a list of RGBA uint8 arrays.
    """
    output_frames = []
    for out_frame_idx in sorted(video_segments)[::vis_frame_stride]:
        transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
            "RGBA"
        )
        mask_image = show_mask(video_segments[out_frame_idx][OBJ_ID])
        output_frame = Image.alpha_composite(transparent_background, mask_image)
        output_frames.append(np.array(output_frame))
    return output_frames


def _write_output_video(frames, fps):
    """Encode ``frames`` as an H.264 MP4 in the temp dir and return its path."""
    unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
    # mkstemp guarantees a unique file: a seconds-resolution timestamp alone
    # collides when concurrent sessions finish within the same second.
    fd, final_vid_output_path = tempfile.mkstemp(
        prefix=f"output_video_{unique_id}_", suffix=".mp4"
    )
    os.close(fd)  # only the reserved name is needed; the encoder writes the file
    clip = ImageSequenceClip(frames, fps=fps)
    try:
        clip.write_videofile(final_vid_output_path, codec="libx264")
    finally:
        clip.close()
    return final_vid_output_path


@spaces.GPU(duration=60)
def propagate_to_all(
    video_in,
    all_frames,
    input_points,
    inference_state,
):
    """Track the prompted object through the whole video and render the result.

    Moves the predictor to CUDA and runs ``propagate_in_video`` under bfloat16
    autocast. It then overlays the thresholded OBJ_ID mask on every frame
    and writes an MP4 at the source video's frame rate.

    Returns:
        ``gr.update`` pointing the output video at the rendered file, or
        ``None`` when there are no prompt points, no video, or no state.
    """
    # Fallback when OpenCV cannot report the source frame rate (None or 0),
    # which would otherwise make ImageSequenceClip fail. Arbitrary default.
    fallback_fps = 30

    if torch.cuda.get_device_properties(0).major >= 8:
        # Compute capability >= 8 (Ampere+): allow TF32 for matmul and cuDNN.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        predictor.to("cuda")
        if inference_state:
            inference_state["device"] = predictor.device

        if not input_points or video_in is None or inference_state is None:
            return None

        # Per-frame {obj_id: boolean mask} collected from the propagation.
        video_segments = {}
        print("starting propagate_in_video")
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            inference_state
        ):
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }

        output_frames = _render_tracked_frames(all_frames, video_segments)
        torch.cuda.empty_cache()

        fps = get_video_fps(video_in) or fallback_fps
        final_vid_output_path = _write_output_video(output_frames, fps)

        return gr.update(value=final_vid_output_path)
|
|
|
|
def update_ui():
    """Return an update that makes the output video component visible."""
    reveal = gr.update(visible=True)
    return reveal
|
|
|
|
with gr.Blocks() as demo:
    # Per-session state shared by all handlers.
    first_frame = gr.State()
    all_frames = gr.State()
    input_points = gr.State([])
    input_labels = gr.State([])
    inference_state = gr.State()

    with gr.Column():
        gr.Markdown(title)
        with gr.Row():
            with gr.Column():
                gr.Markdown(description_p)

                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                    )
                    propagate_btn = gr.Button("Track", scale=1, variant="primary")
                    clear_points_btn = gr.Button("Clear Points", scale=1)
                    reset_btn = gr.Button("Reset", scale=1)

                points_map = gr.Image(
                    label="Frame with Point Prompt", type="numpy", interactive=False
                )

            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[
                        video_in,
                    ],
                    examples_per_page=8,
                )
                # Vertical spacers before the reference mask.
                gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
                gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
                gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
                output_image = gr.Image(label="Reference Mask")

    output_video = gr.Video(visible=False)

    # Order matches the trailing session values every handler takes/returns.
    session_state = [first_frame, all_frames, input_points, input_labels, inference_state]

    # `change` fires on uploads, example clicks and clears alike. A separate
    # `upload` listener would run preprocess_video_in twice per upload, with
    # the two runs racing to overwrite the session state.
    video_in.change(
        fn=preprocess_video_in,
        inputs=[video_in, *session_state],
        outputs=[video_in_drawer, points_map, output_image, output_video, *session_state],
        queue=False,
    )

    # Clicking the first frame adds a prompt point and re-segments it.
    points_map.select(
        fn=segment_with_points,
        inputs=[point_type, *session_state],
        outputs=[points_map, output_image, *session_state],
        queue=False,
    )

    clear_points_btn.click(
        fn=clear_points,
        inputs=[*session_state],
        outputs=[points_map, output_image, output_video, *session_state],
        queue=False,
    )

    reset_btn.click(
        fn=reset,
        inputs=[*session_state],
        outputs=[
            video_in,
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            *session_state,
        ],
        queue=False,
    )

    # Show the output player first, then run the (GPU) tracking job.
    propagate_btn.click(
        fn=update_ui,
        inputs=[],
        outputs=output_video,
        queue=False,
    ).then(
        fn=propagate_to_all,
        inputs=[
            video_in,
            all_frames,
            input_points,
            inference_state,
        ],
        outputs=[
            output_video,
        ],
        # NOTE(review): concurrency_limit presumably only applies to queued
        # events; confirm it has any effect together with queue=False.
        concurrency_limit=10,
        queue=False,
    )

demo.launch()
|
|