svjack commited on
Commit
a0e45b9
·
verified ·
1 Parent(s): dd5c13c

Update video_mask_app.py

Browse files
Files changed (1) hide show
  1. video_mask_app.py +15 -2
video_mask_app.py CHANGED
@@ -93,6 +93,7 @@ def process_video(input_video_path):
93
 
94
  # Process each frame
95
  frames = []
 
96
  for frame in tqdm(video_clip.iter_frames()):
97
  frame_pil = Image.fromarray(frame)
98
  frame_no_bg, mask_resized = remove_background(frame_pil)
@@ -109,12 +110,23 @@ def process_video(input_video_path):
109
  output_np = np.array(output)
110
  frames.append(output_np)
111
 
 
 
 
 
 
 
112
  # Save the processed frames as a new video
113
  output_video_path = os.path.join(output_folder, "no_bg_video.mp4")
114
  processed_clip = ImageSequenceClip(frames, fps=video_clip.fps)
115
  processed_clip.write_videofile(output_video_path, codec='libx264', ffmpeg_params=['-pix_fmt', 'yuva420p'])
116
 
117
- return output_video_path
 
 
 
 
 
118
 
119
  # Gradio components
120
  slider1 = ImageSlider(label="RMBG-2.0", type="pil")
@@ -125,6 +137,7 @@ text = gr.Textbox(label="Paste an image URL")
125
  png_file = gr.File(label="output png file")
126
  video_input = gr.Video(label="Upload a video")
127
  video_output = gr.Video(label="Processed video")
 
128
 
129
  # Example videos
130
  example_videos = [
@@ -140,7 +153,7 @@ tab1 = gr.Interface(
140
 
141
  tab2 = gr.Interface(fn, inputs=text, outputs=[slider2, gr.File(label="output png file")], examples=["http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"], api_name="text")
142
  #tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["giraffe.jpg"], api_name="png")
143
- tab4 = gr.Interface(process_video, inputs=video_input, outputs=video_output, examples=example_videos, api_name="video", cache_examples = False)
144
 
145
  # Gradio tabbed interface
146
  demo = gr.TabbedInterface(
 
93
 
94
  # Process each frame
95
  frames = []
96
+ mask_frames = []
97
  for frame in tqdm(video_clip.iter_frames()):
98
  frame_pil = Image.fromarray(frame)
99
  frame_no_bg, mask_resized = remove_background(frame_pil)
 
110
  output_np = np.array(output)
111
  frames.append(output_np)
112
 
113
+ # Create a mask frame with white foreground and black background
114
+ mask_frame = np.array(mask_resized)
115
+ mask_frame = np.stack([mask_frame, mask_frame, mask_frame], axis=-1) # Convert to 3 channels
116
+ mask_frame[mask_frame > 0] = 255 # Set foreground to white
117
+ mask_frames.append(mask_frame)
118
+
119
  # Save the processed frames as a new video
120
  output_video_path = os.path.join(output_folder, "no_bg_video.mp4")
121
  processed_clip = ImageSequenceClip(frames, fps=video_clip.fps)
122
  processed_clip.write_videofile(output_video_path, codec='libx264', ffmpeg_params=['-pix_fmt', 'yuva420p'])
123
 
124
+ # Save the mask frames as a new video
125
+ mask_video_path = os.path.join(output_folder, "mask_video.mp4")
126
+ mask_clip = ImageSequenceClip(mask_frames, fps=video_clip.fps)
127
+ mask_clip.write_videofile(mask_video_path, codec='libx264')
128
+
129
+ return output_video_path, mask_video_path
130
 
131
  # Gradio components
132
  slider1 = ImageSlider(label="RMBG-2.0", type="pil")
 
137
  png_file = gr.File(label="output png file")
138
  video_input = gr.Video(label="Upload a video")
139
  video_output = gr.Video(label="Processed video")
140
+ mask_video_output = gr.Video(label="Mask video")
141
 
142
  # Example videos
143
  example_videos = [
 
153
 
154
  tab2 = gr.Interface(fn, inputs=text, outputs=[slider2, gr.File(label="output png file")], examples=["http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"], api_name="text")
155
  #tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["giraffe.jpg"], api_name="png")
156
+ tab4 = gr.Interface(process_video, inputs=video_input, outputs=[video_output, mask_video_output], examples=example_videos, api_name="video", cache_examples = False)
157
 
158
  # Gradio tabbed interface
159
  demo = gr.TabbedInterface(