"""Gradio demo: remove image backgrounds with briaai/RMBG-2.0."""
import os
import uuid

import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
# import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import numpy as np
from PIL import Image

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# "high" lets PyTorch use faster, slightly lower-precision float32 matmul
# kernels; switch to "highest" for full float32 precision.
torch.set_float32_matmul_precision("high")

birefnet = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0", trust_remote_code=True
)
birefnet.to(device)
birefnet.eval()  # inference only; make evaluation mode explicit

# Preprocessing: the model is fed a fixed 1024x1024, ImageNet-normalized tensor.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

output_folder = 'output_images'
os.makedirs(output_folder, exist_ok=True)

# Color list; each entry after the first is a selectable mask color.
colors = [
    '#000000',  # background color
    '#2692F3',  # blue
    '#F89E12',  # orange
    '#16C232',  # green
    '#F92F6C',  # pink
    '#AC6AEB',  # purple
]


def _hex_to_rgb(hex_color):
    """Convert a ``'#RRGGBB'`` (or ``'RRGGBB'``) string to an ``(r, g, b)`` int tuple."""
    digits = hex_color.lstrip("#")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


# RGB values of the mask colors, skipping the background color. Shape (N, 3).
# NOTE(review): not referenced anywhere in this file.
palette = np.array([_hex_to_rgb(c) for c in colors[1:]])


def fn(image, mask_color, url=None):
    """Gradio handler: remove the background from an uploaded image or URL.

    Args:
        image: Value of the image component (anything ``load_img`` accepts:
            array, PIL image, file path or URL), or ``None`` if nothing was
            uploaded.
        mask_color: ``'#RRGGBB'`` color for the mask preview; falls back to
            ``colors[1]`` when empty.
        url: Optional image URL from the textbox, used only when *image* is
            ``None``.

    Returns:
        ``((no_bg_image, original_image), no_bg_png_path, colored_mask)`` in
        the order expected by the slider, file and mask output components.

    Raises:
        gr.Error: if neither an image nor a URL was provided.
    """
    if image is None:
        image = (url or "").strip() or None
    if image is None:
        raise gr.Error("Please upload an image or paste an image URL.")
    mask_color = mask_color or colors[1]

    im = load_img(image, output_type="pil").convert("RGB")
    origin = im.copy()
    image, mask = process(im, mask_color)

    # Give each request its own directory. Fixed file names would let
    # concurrent requests overwrite each other's results before they are
    # served. The download name stays "no_bg_image.png".
    # NOTE(review): these directories are never cleaned up.
    request_dir = os.path.join(output_folder, uuid.uuid4().hex)
    os.makedirs(request_dir, exist_ok=True)
    image_path = os.path.join(request_dir, "no_bg_image.png")
    mask_path = os.path.join(request_dir, "mask_image.png")
    image.save(image_path)
    mask.save(mask_path)

    return (image, origin), image_path, mask


# @spaces.GPU
def process(image, mask_color):
    """Run RMBG-2.0 on *image* and build the cut-out and colored-mask images.

    Args:
        image: RGB PIL image (``fn`` converts to RGB before calling).
        mask_color: ``'#RRGGBB'`` color used to fill the mask preview.

    Returns:
        ``(transparent_image, colored_mask)``: two RGBA images the size of
        *image*. The first has the predicted background made transparent.
        The second is a solid *mask_color* layer whose alpha is the
        predicted mask.
    """
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to(device)

    # Prediction
    with torch.no_grad():
        # The last element of the model output is used as the mask logits
        # (presumably the final-stage prediction; confirm against the model card).
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # The model works at 1024x1024; scale the mask back to the input size.
    mask = transforms.ToPILImage()(pred).resize(image_size)

    # Paste the original onto a fully transparent canvas through the mask.
    transparent_image = Image.new("RGBA", image_size, (0, 0, 0, 0))
    transparent_image.paste(image, (0, 0), mask)

    # A solid color layer whose alpha channel is the predicted mask.
    colored_mask = Image.new("RGBA", image_size, _hex_to_rgb(mask_color) + (255,))
    colored_mask.putalpha(mask)

    return transparent_image, colored_mask


# Example data
example_image = "giraffe.jpg"  # must exist in the current working directory
example_url = "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🖼️ RMBG-2.0 for Background Removal")
    with gr.Row():
        # Left column: inputs
        with gr.Column():
            gr.Markdown("## Input")
            image_input = gr.Image(label="Upload an image")
            text_input = gr.Textbox(label="Paste an image URL")
            color_input = gr.Dropdown(label="Mask Color", choices=colors[1:], value=colors[1])
            run_button = gr.Button("Run")
        # Right column: outputs
        with gr.Column():
            gr.Markdown("## Output")
            slider_output = ImageSlider(label="RMBG-2.0", type="pil")
            file_output = gr.File(label="Output PNG File")
            mask_output = gr.Image(label="Mask Image")

    result_outputs = [slider_output, file_output, mask_output]

    # Examples are run through fn (and cached) at startup.
    gr.Examples(
        examples=[[example_image, colors[1]], [example_url, colors[1]]],
        inputs=[image_input, color_input],
        outputs=result_outputs,
        fn=fn,
        cache_examples=True,
    )

    # Wire the Run button; the URL textbox is used when no image is uploaded.
    run_button.click(
        fn=fn,
        inputs=[image_input, color_input, text_input],
        outputs=result_outputs,
    )

if __name__ == "__main__":
    demo.launch(share=True, show_error=True)