svjack committed
Commit dd5c13c · verified · 1 Parent(s): ab7a8c7

Create video_mask_app.py

Files changed (1):
  1. video_mask_app.py +151 -0
video_mask_app.py ADDED
@@ -0,0 +1,151 @@
+ import os
+ import gradio as gr
+ from gradio_imageslider import ImageSlider
+ from loadimg import load_img
+ import spaces
+ from transformers import AutoModelForImageSegmentation
+ import torch
+ from torchvision import transforms
+ from PIL import Image, ImageChops
+ from moviepy.editor import VideoFileClip, ImageSequenceClip
+ import numpy as np
+ from tqdm import tqdm
+ from uuid import uuid1
+
+ # Check CUDA availability
+ if torch.cuda.is_available():
+     device = "cuda"
+ else:
+     device = "cpu"
+
+ torch.set_float32_matmul_precision("high")
+
+ # Load the model
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
+     "briaai/RMBG-2.0", trust_remote_code=True
+ )
+ birefnet.to(device)
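+
+ # Preprocessing pipeline: resize to 1024x1024, convert to a tensor, and
+ # normalize with ImageNet mean/std.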
+ transform_image = transforms.Compose(
+     [
+         transforms.Resize((1024, 1024)),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ]
+ )
+
+ output_folder = 'output_images'
+ if not os.path.exists(output_folder):
+     os.makedirs(output_folder)
+
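+ # fn() backs the image and URL tabs: it returns (processed, original) for the
+ # ImageSlider plus the saved PNG path for the File output.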
+ def fn(image):
+     im = load_img(image, output_type="pil")
+     im = im.convert("RGB")
+     origin = im.copy()
+     image = process(im)
+     image_path = os.path.join(output_folder, "no_bg_image.png")
+     image.save(image_path)
+     return (image, origin), image_path
+
+ @spaces.GPU
+ def process(image):
+     image_size = image.size
+     input_images = transform_image(image).unsqueeze(0).to(device)
+     # Prediction
+     with torch.no_grad():
+         preds = birefnet(input_images)[-1].sigmoid().cpu()
+     pred = preds[0].squeeze()
+     pred_pil = transforms.ToPILImage()(pred)
+     mask = pred_pil.resize(image_size)
+     image.putalpha(mask)
+     return image
+
+ def process_file(f):
+     name_path = f.rsplit(".", 1)[0] + ".png"
+     im = load_img(f, output_type="pil")
+     im = im.convert("RGB")
+     transparent = process(im)
+     transparent.save(name_path)
+     return name_path
+
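+ # Unlike process(), remove_background() also returns the resized mask so that
+ # process_video() can reuse it when compositing each frame.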
+ def remove_background(image):
+     """Remove background from a single image."""
+     input_images = transform_image(image).unsqueeze(0).to(device)
+
+     # Prediction
+     with torch.no_grad():
+         preds = birefnet(input_images)[-1].sigmoid().cpu()
+     pred = preds[0].squeeze()
+
+     # Convert the prediction to a mask
+     mask = (pred * 255).byte()  # Convert to 0-255 range
+     mask_pil = transforms.ToPILImage()(mask).convert("L")
+     mask_resized = mask_pil.resize(image.size, Image.LANCZOS)
+
+     # Apply the mask to the image
+     image.putalpha(mask_resized)
+
+     return image, mask_resized
+
+ def process_video(input_video_path):
+     """Process a video to remove the background from each frame."""
+     # Load the video
+     video_clip = VideoFileClip(input_video_path)
+
+     # Process each frame
+     frames = []
+     for frame in tqdm(video_clip.iter_frames()):
+         frame_pil = Image.fromarray(frame)
+         frame_no_bg, mask_resized = remove_background(frame_pil)
+         path = "{}.png".format(uuid1())
+         frame_no_bg.save(path)
+         frame_no_bg = Image.open(path).convert("RGBA")
+         os.remove(path)
+
+         # Convert mask_resized to RGBA mode
+         mask_resized_rgba = mask_resized.convert("RGBA")
+
+         # Apply the mask using ImageChops.multiply
+         output = ImageChops.multiply(frame_no_bg, mask_resized_rgba)
+         output_np = np.array(output)
+         frames.append(output_np)
+
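+     # Note: libx264 does not encode an alpha channel, so transparency is dropped
+     # at encode time; the ImageChops.multiply above already composites the
+     # subject over black, which is what remains in the MP4.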
+     # Save the processed frames as a new video
+     output_video_path = os.path.join(output_folder, "no_bg_video.mp4")
+     processed_clip = ImageSequenceClip(frames, fps=video_clip.fps)
+     processed_clip.write_videofile(output_video_path, codec='libx264', ffmpeg_params=['-pix_fmt', 'yuva420p'])
+
+     return output_video_path
+
+ # Gradio components
+ slider1 = ImageSlider(label="RMBG-2.0", type="pil")
+ slider2 = ImageSlider(label="RMBG-2.0", type="pil")
+ image = gr.Image(label="Upload an image")
+ image2 = gr.Image(label="Upload an image", type="filepath")
+ text = gr.Textbox(label="Paste an image URL")
+ png_file = gr.File(label="output png file")
+ video_input = gr.Video(label="Upload a video")
+ video_output = gr.Video(label="Processed video")
+
+ # Example videos
+ example_videos = [
+     "pexels-cottonbro-5319934.mp4",
+     "300_A_car_is_running_on_the_road.mp4",
+     "A_Terracotta_Warrior_is_skateboarding_9033688.mp4"
+ ]
+
+ # Gradio interfaces
+ tab1 = gr.Interface(
+     fn, inputs=image, outputs=[slider1, gr.File(label="output png file")], examples=[load_img("giraffe.jpg", output_type="pil")], api_name="image"
+ )
+
+ tab2 = gr.Interface(fn, inputs=text, outputs=[slider2, gr.File(label="output png file")], examples=["http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"], api_name="text")
+ #tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["giraffe.jpg"], api_name="png")
+ tab4 = gr.Interface(process_video, inputs=video_input, outputs=video_output, examples=example_videos, api_name="video", cache_examples=False)
+
+ # Gradio tabbed interface
+ demo = gr.TabbedInterface(
+     [tab4, tab1, tab2], ["input video", "input image", "input url"], title="RMBG-2.0 for background removal"
+ )
+
+ if __name__ == "__main__":
+     demo.launch(share=True, show_error=True)
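
A minimal usage sketch (not part of the commit; "sample.mp4" is a placeholder input): with the file above saved as video_mask_app.py, the video pipeline can be exercised without the Gradio UI:

    # Hypothetical driver script; importing video_mask_app loads RMBG-2.0 at import time.
    from video_mask_app import process_video

    out_path = process_video("sample.mp4")  # writes output_images/no_bg_video.mp4
    print(out_path)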