import warnings
warnings.filterwarnings('ignore')
import subprocess, io, os, sys, time
import gradio as gr
from loguru import logger
import argparse
import copy
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont, ImageOps
import cv2
import numpy as np
import matplotlib
# Set CUDA_VISIBLE_DEVICES so that only device 0 is visible.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Run `pip list` at import time (its output goes to the process stdout).
os.system("pip list")
logger.info(f"Start app...")
# Put the local ./GroundingDINO directory first on the module search path.
sys.path.insert(0, './GroundingDINO')
try:
    # Grounding DINO
    import GroundingDINO.groundingdino.datasets.transforms as T
    from GroundingDINO.groundingdino.models import build_model
    from GroundingDINO.groundingdino.util import box_ops
    from GroundingDINO.groundingdino.util.slconfig import SLConfig
    from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
except Exception as e:
    # The failure is only logged: the names imported above stay undefined,
    # so any later use of them raises NameError.
    logger.error(f"import GroundingDINO error: {str(e)}")
# Select the non-interactive Agg backend before pyplot is loaded.
matplotlib.use('AGG')
# Import the submodule explicitly: a bare `import matplotlib` does not load
# `matplotlib.pyplot`, so reading it as an attribute only works when some
# other package happened to import it first.
import matplotlib.pyplot as plt

# Feature switches for the individual pipelines.
groundingdino_enable = True
sam_enable = True
inpainting_enable = True
ram_enable = False
lama_cleaner_enable = True
kosmos_enable = False

# Setting IS_MY_DEBUG (to any value) turns the heavier models off.
if os.environ.get('IS_MY_DEBUG') is not None:
    sam_enable = False
    ram_enable = False
    # inpainting_enable = False
    kosmos_enable = False

if lama_cleaner_enable:
    try:
        from lama_cleaner.model_manager import ModelManager
        from lama_cleaner.schema import Config as lama_Config
    except Exception as e:
        # Disable the feature, but say why instead of failing silently
        # (same style as the GroundingDINO import guard).
        logger.error(f"import lama_cleaner error: {str(e)}")
        lama_cleaner_enable = False
# segment anything
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
# diffusers
import PIL
import requests
from io import BytesIO
from diffusers import StableDiffusionInpaintPipeline
from huggingface_hub import hf_hub_download
from util_computer import computer_info
# relate anything
from ram_utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, ram_show_mask
from ram_train_eval import RamModel, RamPredictor
from mmengine.config import Config as mmengine_Config
if lama_cleaner_enable:
from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
)
# from transformers import AutoProcessor, AutoModelForVision2Seq
import ast
if kosmos_enable:
os.system("pip install transformers@git+https://github.com/huggingface/transformers.git@main")
# os.system("pip install transformers==4.32.0")
from kosmos_utils import *
from util_tencent import getTextTrans
# Hugging Face token read from the environment (None when the variable is unset).
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
# GroundingDINO config file and the Hub repo/file of its checkpoint.
config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
ckpt_repo_id = "ShilongLiu/GroundingDINO"
# (sic) the misspelled name is referenced by load_groundingdino_model().
ckpt_filenmae = "groundingdino_swint_ogc.pth"
# Local path of the SAM ViT-H checkpoint.
sam_checkpoint = './sam_vit_h_4b8939.pth'
# Directory for temporary result images; created at import time.
output_dir = "outputs"
# Default compute device; reassigned by set_device().
device = 'cpu'
os.makedirs(output_dir, exist_ok=True)
# Module-level model handles, filled in later (e.g. by the load_* functions).
groundingdino_model = None
sam_device = None
sam_model = None
sam_predictor = None
sam_mask_generator = None
sd_model = None
lama_cleaner_model= None
ram_model = None
kosmos_model = None
kosmos_processor = None
# Largest value of a signed 32-bit integer; upper bound for random seeds.
MAX_SEED = np.iinfo(np.int32).max
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
    """Build a model from a config file and load weights fetched from the Hub.

    Args:
        model_config_path: path of the SLConfig file describing the model.
        repo_id: Hugging Face Hub repository holding the checkpoint.
        filename: checkpoint file name inside that repository.
        device: ``map_location`` used when loading the checkpoint.

    Returns:
        The model, switched to eval mode.
    """
    cfg = SLConfig.fromfile(model_config_path)
    net = build_model(cfg)
    # NOTE(review): the device is written to the config only after the model
    # has been built; the order is kept as it was.
    cfg.device = device

    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    state = torch.load(weights_path, map_location=device)
    # strict=False: mismatched keys are reported in the log, not raised.
    load_report = net.load_state_dict(clean_state_dict(state['model']), strict=False)
    logger.info("Model loaded from {} \n => {}".format(weights_path, load_report))

    net.eval()
    return net
def plot_boxes_to_image(image_pil, tgt):
    """Draw labelled detection boxes on an image and build a matching box mask.

    Args:
        image_pil: PIL image to draw on; it is modified in place.
        tgt: dict with ``"size"`` as ``(H, W)``, ``"boxes"`` as rows of
            normalized ``(cx, cy, w, h)`` and ``"labels"`` (one per box).

    Returns:
        ``(image_pil, mask)`` where ``mask`` is an ``"L"`` image that is 255
        inside every box and 0 elsewhere.

    Raises:
        AssertionError: if the number of boxes and labels differ.
    """
    H, W = tgt["size"]
    boxes = tgt["boxes"]
    labels = tgt["labels"]
    assert len(boxes) == len(labels), "boxes and labels must have same length"

    draw = ImageDraw.Draw(image_pil)
    mask = Image.new("L", image_pil.size, 0)
    mask_draw = ImageDraw.Draw(mask)

    # Load the label font once. Fall back to Pillow's default font so labels
    # are still rendered when the font bundled with OpenCV is missing.
    try:
        font_path = os.path.join(cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
        label_font = ImageFont.truetype(font_path, 36)
    except (OSError, AttributeError, IndexError):
        label_font = ImageFont.load_default()

    def label_extent(text):
        # Width/height of `text` in the font that is actually drawn.
        if hasattr(label_font, "getbbox"):
            _, _, right, bottom = label_font.getbbox(text)
            return right, bottom
        return label_font.getsize(text)  # older Pillow releases

    scale = torch.Tensor([W, H, W, H])
    pad = 2
    for box, label in zip(boxes, labels):
        # from 0..1 to 0..W, 0..H
        box = box * scale
        # from (cx, cy, w, h) to (x0, y0, x1, y1)
        box[:2] -= box[2:] / 2
        box[2:] += box[:2]
        # random color
        color = tuple(np.random.randint(0, 255, size=3).tolist())

        x0, y0, x1, y1 = (int(v) for v in box)
        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)

        # Size the label background from the font that renders the text so
        # the text never overflows its background.
        text = str(label)
        text_w, text_h = label_extent(text)
        draw.rectangle((x0, y0, x0 + text_w + 2 * pad, y0 + text_h + 2 * pad), fill=color)
        draw.text((x0 + pad, y0 + pad), text, font=label_font, fill="white")

        mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)

    return image_pil, mask
def load_image(image_path):
    """Return ``(pil_image, tensor)`` for a PIL image or an image path.

    A PIL image is used as given; anything else is opened with
    ``Image.open`` and converted to RGB. The tensor is produced by the
    resize / to-tensor / normalize pipeline below.
    """
    already_loaded = isinstance(image_path, PIL.Image.Image)
    image_pil = image_path if already_loaded else Image.open(image_path).convert("RGB")

    steps = [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    preprocess = T.Compose(steps)
    # The transform takes (image, target); no target is supplied here.
    tensor, _ = preprocess(image_pil, None)  # 3, h, w
    return image_pil, tensor
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    """Run the grounding model on one image and return filtered boxes and phrases.

    Args:
        model: grounding model exposing ``tokenizer``.
        image: image tensor; a batch dimension is added before the forward pass.
        caption: text prompt; lower-cased, stripped and terminated with ".".
        box_threshold: a query is kept when its maximum logit exceeds this value.
        text_threshold: per-token threshold used to extract the phrase.
        with_logits: append the (truncated) maximum logit to each phrase.
        device: device the model and image are moved to.

    Returns:
        ``(boxes_filt, pred_phrases)``: the kept boxes and one phrase per box.
    """
    prompt = caption.lower().strip()
    if not prompt.endswith("."):
        prompt += "."

    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[prompt])
    scores = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    candidates = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # Boolean-mask indexing returns new tensors, so the model outputs are
    # left untouched.
    keep = scores.max(dim=1)[0] > box_threshold
    logits_filt = scores[keep]  # num_filt, 256
    boxes_filt = candidates[keep]  # num_filt, 4

    tokenizer = model.tokenizer
    tokenized = tokenizer(prompt)

    pred_phrases = []
    for row in logits_filt:
        phrase = get_phrases_from_posmap(row > text_threshold, tokenized, tokenizer)
        if with_logits:
            phrase = phrase + f"({str(row.max().item())[:4]})"
        pred_phrases.append(phrase)
    return boxes_filt, pred_phrases
def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a semi-transparent RGBA image.

    The color is a fixed blue unless ``random_color`` is set, in which case
    a random RGB is used; alpha is 0.6 in both cases.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])

    # Only the last two dimensions are used as the image height and width.
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_box(box, ax, label):
    """Draw an ``(x0, y0, x1, y1)`` box outline on ``ax`` and write ``label`` at its origin."""
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle((left, top), width, height, edgecolor='green', facecolor=(0,0,0,0), lw=2)
    ax.add_patch(outline)
    ax.text(left, top, label)
def xywh_to_xyxy(box, sizeW, sizeH):
    """Convert a normalized ``(cx, cy, w, h)`` box to pixel ``(x0, y0, x1, y1)``.

    Args:
        box: list or tensor of four normalized values.
        sizeW: image width in pixels.
        sizeH: image height in pixels.

    Returns:
        NumPy array with the pixel-space corner coordinates.
    """
    coords = torch.Tensor(box) if isinstance(box, list) else box
    # Multiplying allocates a new tensor, so the argument is not modified.
    scaled = coords * torch.Tensor([sizeW, sizeH, sizeW, sizeH])
    center, extent = scaled[:2], scaled[2:]
    top_left = center - extent / 2
    bottom_right = extent + top_left
    return torch.cat([top_left, bottom_right]).numpy()
def mask_extend(img, box, extend_pixels=10, useRectangle=True):
    """Grow the ``box`` region of ``img`` by ``extend_pixels`` on every side.

    The region is cropped, resized to the enlarged size and pasted back
    around the original position. With ``useRectangle`` the enlarged region
    is filled solid white instead of keeping the resized content.

    Args:
        img: PIL image; it is modified in place and returned.
        box: ``(x0, y0, x1, y1)`` in pixels; values are truncated to int.
            The sequence itself is NOT modified (the previous version
            overwrote the caller's box).
        extend_pixels: margin added on each side.
        useRectangle: fill the enlarged region with white.

    Returns:
        The same ``img`` object.
    """
    left, top, right, bottom = (int(box[i]) for i in range(4))
    region = img.crop((left, top, right, bottom))

    new_width = right - left + 2 * extend_pixels
    new_height = bottom - top + 2 * extend_pixels
    enlarged = region.resize((int(new_width), int(new_height)))

    if useRectangle:
        region_draw = ImageDraw.Draw(enlarged)
        region_draw.rectangle((0, 0, new_width, new_height), fill=(255, 255, 255))

    img.paste(enlarged, (int(left - extend_pixels), int(top - extend_pixels)))
    return img
def mix_masks(imgs):
    """Return the union of several mask images as one 8-bit image.

    Each image is converted to 1-bit; a pixel is 255 in the result when it
    is set in at least one input, otherwise 0.
    """
    # Work on the inverted masks: the product stays 1 only where every
    # input is unset.
    unset_everywhere = 1 - np.asarray(imgs[0].convert("1"))
    for extra in imgs[1:]:
        unset_everywhere = np.multiply(unset_everywhere, 1 - np.asarray(extra.convert("1")))
    union = 1 - unset_everywhere
    return Image.fromarray(np.uint8(255*union))
def set_device(args):
    """Set the module-level ``device``.

    ``args.cuda`` is used when CUDA is available and IS_MY_DEBUG is not set;
    otherwise the device is ``'cpu'``.
    """
    global device
    debug_mode = os.environ.get('IS_MY_DEBUG') is not None
    if debug_mode or not torch.cuda.is_available():
        device = 'cpu'
    else:
        device = args.cuda
    logger.info(f'device={device}')
def load_groundingdino_model(device):
    """Load GroundingDINO into the module-level ``groundingdino_model``."""
    global groundingdino_model
    logger.info(f"initialize groundingdino model...")
    groundingdino_model = load_model_hf(
        config_file,
        ckpt_repo_id,
        ckpt_filenmae,
        device=device,
    )
    logger.info(f"initialize groundingdino model...{type(groundingdino_model)}")
def get_sam_vit_h_4b8939():
    """Download the SAM ViT-H checkpoint with ``wget`` unless it already exists.

    Raises:
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero status.
    """
    if os.path.exists('./sam_vit_h_4b8939.pth'):
        return

    logger.info(f"get sam_vit_h_4b8939.pth...")
    download_cmd = ['wget', '-nv', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth']
    result = subprocess.run(download_cmd, check=True)
    logger.info(f'wget sam_vit_h_4b8939.pth result = {result}')
def load_sam_model(device):
    """Fetch the SAM checkpoint if needed and build the SAM globals.

    Sets ``sam_device``, ``sam_model``, ``sam_predictor`` and
    ``sam_mask_generator`` at module level.
    """
    global sam_model, sam_predictor, sam_mask_generator, sam_device

    get_sam_vit_h_4b8939()
    logger.info(f"initialize SAM model...")

    sam_device = device
    sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
    # The predictor and the automatic mask generator share one model.
    sam_predictor = SamPredictor(sam_model)
    sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
def load_sd_model(device):
    """Reset the module-level ``sd_model`` to ``None``.

    No Stable Diffusion pipeline is loaded here; ``device`` is unused.
    """
    global sd_model
    logger.info(f"initialize stable-diffusion-inpainting...")
    sd_model = None
    # Disabled local pipeline, kept for reference:
    # if os.environ.get('IS_MY_DEBUG') is None:
    #     sd_model = StableDiffusionInpaintPipeline.from_pretrained(
    #         "runwayml/stable-diffusion-inpainting",
    #         revision="fp16",
    #         # "stabilityai/stable-diffusion-2-inpainting",
    #         torch_dtype=torch.float16,
    #     )
    #     sd_model = sd_model.to(device)
def load_lama_cleaner_model(device):
    """Create the lama-cleaner ``ModelManager`` ('lama' model) on ``device``.

    The manager is stored in the module-level ``lama_cleaner_model``.
    """
    global lama_cleaner_model
    logger.info(f"initialize lama_cleaner...")
    manager_options = {'name': 'lama', 'device': device}
    lama_cleaner_model = ModelManager(**manager_options)
def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
    """Erase the masked area of ``image`` with the lama-cleaner model.

    Args:
        image: image as a NumPy array (indexed as height, width, channel).
        mask: mask as a NumPy array; when its first two dimensions are the
            image's swapped, the image is rotated to match first.
        cleaner_size_limit: maximum edge passed to ``resize_max_size``;
            ``-1`` means "use the largest image dimension".

    Returns:
        The cleaned result as a PIL image, or ``None`` when any step fails
        (the failure is logged with its traceback).
    """
    # `random` is not imported at module level; without this the random-seed
    # branch below would raise NameError.
    import random

    def trace(step):
        # Numbered progress markers, unchanged from the original log output.
        logger.info(f'_______lama_cleaner_process_______{step}____')

    try:
        trace(1)
        ori_image = image
        if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
            # rotate image
            trace(2)
            ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
            trace(3)
            image = ori_image
        trace(4)
        trace(5)
        interpolation = cv2.INTER_CUBIC
        size_limit = cleaner_size_limit
        if size_limit == -1:
            trace(6)
            size_limit = max(image.shape)
        else:
            trace(7)
            size_limit = int(size_limit)
        trace(8)
        config = lama_Config(
            ldm_steps=25,
            ldm_sampler='plms',
            zits_wireframe=True,
            hd_strategy='Original',
            hd_strategy_crop_margin=196,
            hd_strategy_crop_trigger_size=1280,
            hd_strategy_resize_limit=2048,
            prompt='',
            use_croper=False,
            croper_x=0,
            croper_y=0,
            croper_height=512,
            croper_width=512,
            sd_mask_blur=5,
            sd_strength=0.75,
            sd_steps=50,
            sd_guidance_scale=7.5,
            sd_sampler='ddim',
            sd_seed=42,
            cv2_flag='INPAINT_NS',
            cv2_radius=5,
        )
        trace(9)
        # Only relevant if sd_seed above is changed to -1.
        if config.sd_seed == -1:
            config.sd_seed = random.randint(1, MAX_SEED)
        trace(10)
        image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
        trace(11)
        mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
        trace(12)
        res_np_img = lama_cleaner_model(image, mask, config)
        trace(13)
        torch.cuda.empty_cache()
        trace(14)
        image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
        trace(15)
    except Exception as e:
        # Best effort: report the failure (with traceback) and signal it to
        # the caller through a None result.
        logger.exception(f'lama_cleaner_process[Error]:' + str(e))
        image = None
    return image
class Ram_Predictor(RamPredictor):
    """``RamPredictor`` variant whose constructor only stores the config and builds the model."""

    def __init__(self, config, device='cpu'):
        # The parent constructor is not called here.
        self.config = config
        self.device = torch.device(device)
        self._build_model()

    def _build_model(self):
        """Instantiate ``RamModel`` on ``self.device`` and load its weights, if configured."""
        self.model = RamModel(**self.config.model).to(self.device)
        weights_path = self.config.load_from
        if weights_path is not None:
            state = torch.load(weights_path, map_location=self.device)
            self.model.load_state_dict(state)
        # NOTE(review): the model is left in train mode; confirm that the
        # predict path switches to eval mode before inference.
        self.model.train()
def load_ram_model(device):
    """Build the relation predictor and store it in the module-level ``ram_model``.

    Does nothing when the IS_MY_DEBUG environment variable is set.
    """
    global ram_model
    if os.environ.get('IS_MY_DEBUG') is not None:
        return

    model_path = "./checkpoints/ram_epoch12.pth"
    model_settings = {
        'pretrained_model_name_or_path': 'bert-base-uncased',
        'load_pretrained_weights': False,
        'num_transformer_layer': 2,
        'input_feature_size': 256,
        'output_feature_size': 768,
        'cls_feature_size': 512,
        'num_relation_classes': 56,
        'pred_type': 'attention',
        'loss_type': 'multi_label_ce',
    }
    ram_config = mmengine_Config({'model': model_settings, 'load_from': model_path})
    ram_model = Ram_Predictor(ram_config, device)
# visualization
def draw_selected_mask(mask, draw):
    """Paint every non-zero pixel of ``mask`` in semi-transparent red via ``draw.point``."""
    red = (255, 0, 0, 153)
    # argwhere yields (row, col); the drawing call takes (x, y), so reverse.
    for row_col in np.argwhere(mask):
        draw.point(row_col[::-1], fill=red)
def draw_object_mask(mask, draw):
    """Paint every non-zero pixel of ``mask`` in semi-transparent blue via ``draw.point``."""
    blue = (0, 0, 255, 153)
    # argwhere yields (row, col); the drawing call takes (x, y), so reverse.
    for row_col in np.argwhere(mask):
        draw.point(row_col[::-1], fill=blue)
def create_title_image(word1, word2, word3, width, font_path='./assets/OpenSans-Bold.ttf'):
    """Render three words (red, black, blue) on a white strip of the given width.

    The font starts at 40pt and shrinks one point at a time until the three
    words plus spacing fit into ``width``. The text is measured first and
    drawn once, using Pillow's current measuring API (``ImageFont.getsize``
    was removed in Pillow 10, which used to leave the title half drawn).

    Args:
        word1, word2, word3: the words, drawn left to right.
        width: width of the returned image in pixels.
        font_path: TrueType font used for the text.

    Returns:
        An RGB image, 60 px high (50 px when the font had to be shrunk).
        When the font cannot be loaded the strip is returned blank.
    """
    words = (word1, word2, word3)
    colors = ((255, 0, 0), (0, 0, 0), (0, 0, 255))  # red, black, blue
    white = (255, 255, 255)

    def text_width(font, text):
        # Prefer the current API and fall back for old Pillow releases.
        if hasattr(font, "getlength"):
            return font.getlength(text)
        if hasattr(font, "getbbox"):
            return font.getbbox(text)[2]
        return font.getsize(text)[0]

    font_size = 40
    height = 60
    word_spacing = font_size / 2
    font = None
    try:
        font = ImageFont.truetype(font_path, font_size)
        # Shrink the font until the title fits (or the size bottoms out).
        while True:
            word_spacing = font_size / 2
            total_width = sum(text_width(font, word) for word in words) + word_spacing * 3
            if total_width <= width or font_size <= 1:
                break
            font_size -= 1
            # Shrunken titles use a 50 px strip, as before.
            height = 50
            font = ImageFont.truetype(font_path, font_size)
    except OSError as e:
        # Best effort: a missing or unreadable font yields a blank strip.
        logger.warning(f"create_title_image: cannot load font '{font_path}': {e}")
        font = None

    image = Image.new('RGB', (width, height), white)
    if font is None:
        return image

    draw = ImageDraw.Draw(image)
    x_offset = word_spacing
    for word, color in zip(words, colors):
        draw.text((x_offset, 0), word, color, font=font)
        x_offset += text_width(font, word) + word_spacing
    return image
def concatenate_images_vertical(image1, image2):
    """Stack ``image1`` above ``image2`` on a new RGBA canvas.

    The canvas is as wide as the wider image and as tall as both together;
    both images are pasted at the left edge.
    """
    top_w, top_h = image1.size
    bottom_w, bottom_h = image2.size

    canvas_size = (max(top_w, bottom_w), top_h + bottom_h)
    stacked = Image.new('RGBA', canvas_size)
    for picture, y_offset in ((image1, 0), (image2, top_h)):
        stacked.paste(picture, (0, y_offset))
    return stacked
def relate_anything(input_image, k):
    """Visualize the first ``k`` predicted relations between SAM segments.

    Args:
        input_image: PIL image; shrunk IN PLACE when an edge exceeds 1500 px.
        k: number of relation triplets to render.

    Returns:
        A list of RGBA images, each showing the subject mask in red and the
        object mask in blue above a title strip naming the relation.
    """
    logger.info(f'relate_anything_1_{input_image.size}_')
    width, height = input_image.size
    max_edge = 1500
    if max(width, height) > max_edge:
        ratio = max(width, height) / max_edge
        input_image.thumbnail((int(width / ratio), int(height / ratio)))

    logger.info(f'relate_anything_2_')
    # load image
    pil_image = input_image.convert('RGBA')
    pixels = np.array(input_image)
    filtered_masks = sort_and_deduplicate(sam_mask_generator.generate(pixels))

    logger.info(f'relate_anything_3_')
    # One (1, 1, feat) tensor per mask, concatenated along dim 1.
    per_mask_feats = [
        torch.Tensor(entry['feat']).unsqueeze(0).unsqueeze(0).to(device)
        for entry in filtered_masks
    ]
    feat = torch.cat(per_mask_feats, dim=1).to(device)
    matrix_output, rel_triplets = ram_model.predict(feat)

    logger.info(f'relate_anything_4_')
    pil_image_list = []
    for triplet in rel_triplets[:k]:
        subject_idx, object_idx, relation_idx = int(triplet[0]), int(triplet[1]), int(triplet[2])
        relation = relation_classes[relation_idx]

        overlay = Image.new('RGBA', pil_image.size, color=(0, 0, 0, 0))
        overlay_draw = ImageDraw.Draw(overlay)
        draw_selected_mask(filtered_masks[subject_idx]['segmentation'], overlay_draw)
        draw_object_mask(filtered_masks[object_idx]['segmentation'], overlay_draw)

        annotated = pil_image.copy()
        annotated.alpha_composite(overlay)

        title_image = create_title_image('Red', relation, 'Blue', annotated.size[0])
        pil_image_list.append(concatenate_images_vertical(annotated, title_image))

    logger.info(f'relate_anything_5_{len(pil_image_list)}')
    return pil_image_list
# Labels of the two mask-source options; run_anything_task() and
# change_radio_display() compare `mask_source_radio` against these strings.
mask_source_draw = "draw a mask on input image"
mask_source_segment = "type what to detect below"
def get_time_cost(run_task_time, time_cost_str):
    """Advance a step timer and extend its textual trace.

    Args:
        run_task_time: timestamp (ms) of the previous step, or 0 to start.
        time_cost_str: trace built so far.

    Returns:
        ``(now_ms, trace)``. On start the trace is ``'start'``; otherwise
        the elapsed milliseconds are appended, separated by ``'-->'`` when
        the trace is not empty.
    """
    now_ms = int(time.time()*1000)
    if run_task_time == 0:
        return now_ms, 'start'

    separator = f'-->' if time_cost_str != '' else ''
    elapsed = f'{now_ms - run_task_time}'
    return now_ms, time_cost_str + separator + elapsed
def processs_inpainting(inpaint_prompt, input_image, mask_image, image_input_composite, debug=False):
    """Inpaint ``input_image`` through a remote Hugging Face Space.

    The job is submitted to the "ameerazam08/FLUX.1-dev-Inpainting-Model-Beta-GPU"
    Space (endpoint ``/process``) and polled until it is done.

    Args:
        inpaint_prompt: text prompt for the inpainting model.
        input_image: editor payload forwarded as ``input_image_editor``.
        mask_image: unused; kept for interface compatibility.
        image_input_composite: unused; kept for interface compatibility.
        debug: currently overridden, see the note in the body.

    Returns:
        The first output as a PIL image, or ``None`` when the job returns
        nothing or any step fails. RGBA results are flattened onto a white
        background and returned as RGB.
    """
    from gradio_client import Client

    try:
        # An earlier variant of this function uploaded resized temp files to
        # the "Kwai-Kolors/Kolors-Inpainting" Space; only the FLUX Space is
        # used now.
        logger.info(f'processs_inpainting_HF = ameerazam08/FLUX.1-dev-Inpainting-Model-Beta-GPU')
        client = Client("ameerazam08/FLUX.1-dev-Inpainting-Model-Beta-GPU")
        job = client.submit(
            input_image_editor=input_image,
            prompt=inpaint_prompt,
            negative_prompt="",
            controlnet_conditioning_scale=0.9,
            guidance_scale=3.5,
            seed=124,
            num_inference_steps=24,
            true_guidance_scale=3.5,
            api_name="/process"
        )

        # NOTE(review): status logging is forced on here, so the `debug`
        # argument has no effect; kept as in the original.
        debug = True
        count = 0
        if debug:
            logger.info(f'{count}___{job.status()}')
        while not job.done():
            if debug:
                count += 1
                logger.info(f'{count}___{job.status()}')
            time.sleep(0.1)

        result = job.outputs()
        logger.info(f'processs_inpainting_result={result}')
        if not result:
            return None

        im = Image.open(result[0])
        if im.mode == "RGBA":
            im.load()
            # Flatten transparency onto white and return the flattened image
            # (it used to be built and then discarded in favor of `im`).
            background = Image.new("RGB", im.size, (255, 255, 255))
            background.paste(im, mask=im.split()[3])
            return background
        return im
    except Exception as e:
        logger.info(f'processs_inpainting_[Error]:' + str(e))
        return None
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
                      iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, remove_use_segment, num_relation, kosmos_input, cleaner_size_limit=1080):
    """Run the selected task on the uploaded image.

    Depending on ``task_type`` this runs Kosmos-2, relation prediction,
    detection (GroundingDINO), segmentation (SAM), inpainting/outpainting
    (remote service) or object removal (lama-cleaner).

    Args:
        input_image: PIL image, or an editor dict ('background', 'layers',
            'composite'); 'image' and 'mask' entries are added to the dict.
        text_prompt: detection prompt (passed through ``getTextTrans``).
        task_type: one of 'detection', 'segment', 'inpainting',
            'outpainting', 'remove', 'relate anything', 'Kosmos-2'.
        inpaint_prompt: prompt for inpainting (passed through ``getTextTrans``).
        box_threshold, text_threshold: GroundingDINO filtering thresholds.
        iou_threshold: unused in this function.
        inpaint_mode: 'merge' combines all masks into one.
        mask_source_radio: ``mask_source_draw`` or ``mask_source_segment``.
        remove_mode: 'segment' keeps the mask shape; anything else uses the
            enlarged rectangle.
        remove_mask_extend: margin in pixels for the removal mask (10 when
            it cannot be converted to int).
        remove_use_segment: when False, 'remove' uses box masks instead of SAM.
        num_relation: number of relations for 'relate anything'.
        kosmos_input: text input for Kosmos-2.
        cleaner_size_limit: size limit forwarded to ``lama_cleaner_process``.

    Returns:
        A 7-tuple. For most paths: (images, gallery update, time-cost string,
        time-cost visibility update, None, None, None).
    """
    text_prompt = getTextTrans(text_prompt, source='zh', target='en')
    inpaint_prompt = getTextTrans(inpaint_prompt, source='zh', target='en')

    run_task_time = 0
    time_cost_str = ''
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    # Validate before the first dereference: the editor-dict handling below
    # used to run ahead of this check and crash on a missing upload.
    if input_image is None:
        return [], gr.update(label='Please upload a image!😂😂😂😂'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None

    # Normalize the editor payload into the 'image' / 'mask' entries that the
    # rest of the function expects. Plain PIL input is left untouched.
    image_input_composite = None
    if isinstance(input_image, dict) and 'background' in input_image:
        input_image['image'] = input_image['background'].convert("RGB")
        layers = input_image.get('layers') or []
        if len(layers) > 0:
            img_arr = np.array(layers[0].convert("L"))
            img_arr = np.where(img_arr > 0, 1, img_arr)
            input_image['mask'] = Image.fromarray(255*img_arr.astype('uint8'))
        if 'composite' in input_image:
            image_input_composite = input_image['composite']

    if (task_type == 'Kosmos-2'):
        global kosmos_model, kosmos_processor
        if isinstance(input_image, dict):
            image_pil, image = load_image(input_image['image'].convert("RGB"))
            input_img = input_image['image']
        else:
            image_pil, image = load_image(input_image.convert("RGB"))
            input_img = input_image

        kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil, kosmos_input, kosmos_model, kosmos_processor)
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
        # NOTE(review): positions 4 and 5 are swapped relative to every other
        # return in this function; confirm against the UI output wiring.
        return None, None, time_cost_str, kosmos_image, gr.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities

    if (task_type == 'relate anything'):
        output_images = relate_anything(input_image['image'], num_relation)
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
        return output_images, gr.update(label='relate images'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None

    text_prompt = text_prompt.strip()
    if not ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw):
        if text_prompt == '':
            return [], gr.update(label='Detection prompt is not found!😂😂😂😂'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None

    file_temp = int(time.time())
    logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}/{remove_use_segment}_[{text_prompt}]/[{inpaint_prompt}]___1_')

    output_images = []

    # load image
    if mask_source_radio == mask_source_draw:
        input_mask_pil = input_image['mask']
        input_mask = np.array(input_mask_pil.convert("L"))

    if isinstance(input_image, dict):
        image_pil, image = load_image(input_image['image'].convert("RGB"))
        input_img = input_image['image']
        output_images.append(input_image['image'])
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
    else:
        image_pil, image = load_image(input_image.convert("RGB"))
        input_img = input_image
        output_images.append(input_image)
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    size = image_pil.size
    H, W = size[1], size[0]

    # run grounding dino model
    if (task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw:
        # The mask is hand-drawn: no detection needed.
        pass
    else:
        groundingdino_device = 'cpu'
        if device != 'cpu':
            try:
                from groundingdino import _C
                groundingdino_device = 'cuda:0'
            except Exception:
                warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!")

        boxes_filt, pred_phrases = get_grounding_output(
            groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
        )
        if boxes_filt.size(0) == 0:
            logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
            return [], gr.update(label='No objects detected, please try others.😂😂😂😂'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
        # Keep the normalized boxes: boxes_filt is converted to pixels in place below.
        boxes_filt_ori = copy.deepcopy(boxes_filt)

        pred_dict = {
            "boxes": boxes_filt,
            "size": [size[1], size[0]],  # H,W
            "labels": pred_phrases,
        }

        image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
        output_images.append(image_with_box)
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
    use_sam_predictor = True
    if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment):
        image = np.array(input_img)
        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_1_')
        if task_type == 'remove' and remove_use_segment == False:
            logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_2_')
            use_sam_predictor = False
        if sam_predictor and use_sam_predictor:
            logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_3_')
            sam_predictor.set_image(image)

        # normalized (cx, cy, w, h) -> pixel (x0, y0, x1, y1), in place
        for i in range(boxes_filt.size(0)):
            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
            boxes_filt[i][2:] += boxes_filt[i][:2]

        if sam_predictor and use_sam_predictor:
            logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_4_')
            boxes_filt = boxes_filt.to(sam_device)
            transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])

            masks, _, _, _ = sam_predictor.predict_torch(
                point_coords = None,
                point_labels = None,
                boxes = transformed_boxes,
                multimask_output = False,
            )
            # masks: [9, 1, 512, 512]
            assert sam_checkpoint, 'sam_checkpoint is not found!'
        else:
            # No SAM: use the filled boxes themselves as masks.
            logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_5_')
            masks = torch.zeros(len(boxes_filt), 1, H, W)
            mask_count = 0
            for box in boxes_filt:
                masks[mask_count, 0, int(box[1]):int(box[3]), int(box[0]):int(box[2])] = 1
                mask_count += 1
            masks = torch.where(masks > 0, True, False)

        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_6_')
        # draw output image (rendered to a temp file, read back, then deleted)
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        for mask in masks:
            show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
        for box, label in zip(boxes_filt, pred_phrases):
            show_box(box.cpu().numpy(), plt.gca(), label)
        plt.axis('off')
        image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
        plt.savefig(image_path, bbox_inches="tight")
        plt.clf()
        plt.close('all')
        segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
        os.remove(image_path)
        output_images.append(Image.fromarray(segment_image_result))
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
    if task_type == 'detection' or task_type == 'segment':
        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
        return output_images, gr.update(label='result images'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
    elif task_type in ['inpainting', 'outpainting'] or task_type == 'remove':
        # Without an inpaint prompt there is nothing to paint: remove instead.
        if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
            task_type = 'remove'

        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
        if mask_source_radio == mask_source_draw:
            mask_pil = input_mask_pil
            mask = input_mask
        else:
            masks_ori = copy.deepcopy(masks)
            if inpaint_mode == 'merge':
                masks = torch.sum(masks, dim=0).unsqueeze(0)
                masks = torch.where(masks > 0, True, False)
            mask = masks[0][0].cpu().numpy()
            mask_pil = Image.fromarray(mask)
        output_images.append(mask_pil.convert("RGB"))
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

        if task_type in ['inpainting', 'outpainting']:
            # inpainting pipeline
            image_source_for_inpaint = image_pil  # .resize((512, 512))
            image_mask_for_inpaint = mask_pil  # .resize((512, 512))
            if task_type in ['outpainting']:
                # reverse mask
                img_arr = np.array(image_mask_for_inpaint)
                img_arr = np.where(img_arr > 0, 1, img_arr)
                img_arr = 1 - img_arr
                image_mask_for_inpaint = Image.fromarray(255*img_arr.astype('uint8'))
                output_images.append(image_mask_for_inpaint.convert("RGB"))
                run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
            # image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
            image_inpainting = processs_inpainting(inpaint_prompt, input_image, image_mask_for_inpaint, image_input_composite)
            if image_inpainting is None:
                logger.info(f'processs_inpainting_failed_')
                time_cost_str = f"processs_inpainting_task__failed!"
                return None, None, time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
        else:
            # remove from mask
            if mask_source_radio == mask_source_segment:
                mask_imgs = []
                masks_shape = masks_ori.shape
                boxes_filt_ori_array = boxes_filt_ori.numpy()
                if inpaint_mode == 'merge':
                    extend_shape_0 = masks_shape[0]
                    extend_shape_1 = masks_shape[1]
                else:
                    extend_shape_0 = 1
                    extend_shape_1 = 1

                # Loop-invariant settings, resolved once.
                useRectangle = remove_mode != 'segment'
                try:
                    remove_mask_extend = int(remove_mask_extend)
                except (TypeError, ValueError):
                    remove_mask_extend = 10

                for i in range(extend_shape_0):
                    for j in range(extend_shape_1):
                        mask = masks_ori[i][j].cpu().numpy()
                        mask_pil = Image.fromarray(mask)
                        mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
                                                   xywh_to_xyxy(torch.tensor(boxes_filt_ori_array[i]), W, H),
                                                   extend_pixels=remove_mask_extend, useRectangle=useRectangle)
                        mask_imgs.append(mask_pil_exp)
                mask_pil = mix_masks(mask_imgs)
                output_images.append(mask_pil.convert("RGB"))
                run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
            logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
            image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
            if image_inpainting is None:
                logger.info(f'run_anything_task_failed_')
                time_cost_str = f"run_anything_task[{task_type}]__failed!"
                return None, None, time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None

        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
        # Bring the result back to the input resolution.
        image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
        output_images.append(image_inpainting)
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
        return output_images, gr.update(label='result images'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
    else:
        logger.info(f"task_type:{task_type} error!")
    logger.info(f'run_anything_task_[{file_temp}]_9_9_')
    return output_images, gr.update(label='result images'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
def change_radio_display(task_type, mask_source_radio):
    """Compute visibility updates for the task-dependent widgets.

    Args:
        task_type: Currently selected task name.
        mask_source_radio: Currently selected mask source.

    Returns:
        A tuple of eight ``gr.update(visible=...)`` objects, in this order:
        text prompt, inpaint prompt, mask-source radio, relation slider,
        image gallery, Kosmos input, Kosmos image output, Kosmos text output.
    """
    # Kosmos widgets replace the prompt and gallery only when Kosmos is enabled.
    kosmos_active = bool(task_type == "Kosmos-2" and kosmos_enable)
    wants_inpaint_prompt = task_type in ('inpainting', 'outpainting')
    wants_mask_source = wants_inpaint_prompt or task_type == "remove"
    # A hand-drawn mask makes the detection prompt unnecessary.
    drawing_mask = bool(wants_mask_source and mask_source_radio == mask_source_draw)
    relating = task_type == "relate anything"

    visibility = (
        not (kosmos_active or drawing_mask or relating),  # text prompt
        wants_inpaint_prompt,                             # inpaint prompt
        wants_mask_source,                                # mask-source radio
        relating,                                         # relation slider
        not kosmos_active,                                # image gallery
        kosmos_active,                                    # Kosmos input
        kosmos_active,                                    # Kosmos image output
        kosmos_active,                                    # Kosmos text output
    )
    return tuple(gr.update(visible=flag) for flag in visibility)
def get_model_device(module):
    """Report the device of the first direct child that owns a real ``weight``.

    Args:
        module: A ``torch.nn.Module`` (optionally wrapped in
            ``torch.nn.DataParallel``) or ``None``.

    Returns:
        The ``torch.device`` of the first direct child's ``weight`` parameter,
        or one of the marker strings ``'None'`` (no module given),
        ``'UnKnown'`` (no direct child has a weight) or ``'Error'``
        (inspection raised).
    """
    if module is None:
        return 'None'
    try:
        # DataParallel wraps the real network in ``.module``.
        if isinstance(module, torch.nn.DataParallel):
            module = module.module
        for submodule in module.children():
            weight = getattr(submodule, "_parameters", {}).get("weight")
            # Layers may register ``weight`` as None (e.g. non-affine norms);
            # skip those instead of failing on ``None.device``.
            if weight is not None:
                return weight.device
        return 'UnKnown'
    except Exception:
        # Best-effort diagnostic helper: never propagate inspection failures.
        return 'Error'
def main_gradio(args):
    """Build the Gradio UI, wire its event handlers, and launch the web server.

    Args:
        args: Parsed command-line namespace; ``port``, ``debug`` and ``share``
            are read here.
    """
    # NOTE(review): the layout nesting below was reconstructed from statement
    # order; confirm widget placement against the rendered page.
    block = gr.Blocks(
        title="SAM and others",
    )
    with block:
        with gr.Row():
            with gr.Column():
                # Only offer tasks whose backing models are enabled.
                task_types = ["detection"]
                if sam_enable:
                    task_types.append("segment")
                if inpainting_enable:
                    task_types.append("inpainting")
                    # task_types.append("outpainting")
                if lama_cleaner_enable:
                    task_types.append("remove")
                if ram_enable:
                    task_types.append("relate anything")
                if kosmos_enable:
                    task_types.append("Kosmos-2")

                brush_color = "#FFFFFF"
                color_mode = "fixed"
                input_image = gr.ImageEditor(sources=["upload", "webcam"],
                                             image_mode='RGB',
                                             elem_id="image_upload",
                                             type='pil',
                                             label="Upload",
                                             layers=False,
                                             brush=gr.Brush(colors=[brush_color], color_mode=color_mode))
                task_type = gr.Radio(task_types, value="detection",
                                     label='Task type', visible=True)
                mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
                                             value=mask_source_segment, label="Mask from",
                                             visible=False)
                text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
                inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
                num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
                kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
                run_button = gr.Button(value="Run", visible=True)

                with gr.Accordion("Advanced options", open=False):
                    box_threshold = gr.Slider(
                        label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
                    )
                    text_threshold = gr.Slider(
                        label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
                    )
                    iou_threshold = gr.Slider(
                        label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
                    )
                    inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode")
                    with gr.Row():
                        with gr.Column(scale=1):
                            remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
                        with gr.Column(scale=1):
                            remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
                        with gr.Column(scale=1, visible=False):
                            remove_use_segment = gr.Checkbox(value=True, elem_id='remove_use_segment', label="use segment for removing?", info="")

            with gr.Column():
                image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True)
                time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)

                kosmos_output = gr.Image(type="pil", label="result images", visible=False)
                kosmos_text_output = gr.HighlightedText(
                    label="Generated Description",
                    combine_adjacent=False,
                    show_legend=True,
                    visible=False,
                )
                # Hidden state: which text span (label) is selected.
                selected = gr.Number(-1, show_label=False, visible=False)
                # Hidden state: the current `entities`, passed around as text.
                entity_output = gr.Textbox(visible=False)

                def get_text_span_label(evt: gr.SelectData):
                    """Return the selected span's label as an int, or -1 if it has none."""
                    if evt.value[-1] is None:
                        return -1
                    return int(evt.value[-1])

                # Store the selected span label in `selected`.
                kosmos_text_output.select(get_text_span_label, None, selected)

                def update_output_image(img_input, image_output, entities, idx):
                    """Redraw the entity boxes for the span index chosen in the text output."""
                    # NOTE(review): `ast` does not appear in the module's
                    # top-of-file imports; imported here so this callback
                    # cannot fail with a NameError.
                    import ast

                    entities = ast.literal_eval(entities)
                    updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
                    return updated_image

                # Refresh the output image whenever the span (entity) selection changes.
                selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])

        run_button.click(fn=run_anything_task, inputs=[
            input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
            iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, remove_use_segment, num_relation, kosmos_input],
            outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)

        # change_radio_display always returns eight updates, so both listeners
        # must bind the same eight components, in the same order.
        display_outputs = [text_prompt, inpaint_prompt, mask_source_radio, num_relation,
                           image_gallery, kosmos_input, kosmos_output, kosmos_text_output]
        mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
                                 outputs=display_outputs)
        task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
                         outputs=display_outputs)

        # Credits shown under the UI; one entry per enabled component.
        description_parts = [
            '### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything).\n'
        ]
        if lama_cleaner_enable:
            description_parts.append('Remove(cleaner) from [lama-cleaner](https://github.com/Sanster/lama-cleaner).\n')
        if kosmos_enable:
            description_parts.append('Kosmos-2 from [Kosmos-2](https://github.com/microsoft/unilm/tree/master/kosmos-2).\n')
        if ram_enable:
            description_parts.append('RAM from [RelateAnything](https://github.com/Luodian/RelateAnything).\n')
        if inpainting_enable:
            description_parts.append('Inpainting from [FLUX.1-dev-Inpainting-Model-Beta-GPU](https://huggingface.co/spaces/ameerazam08/FLUX.1-dev-Inpainting-Model-Beta-GPU).\n')
        description_parts.append('Thanks for their excellent work.')
        description_parts.append('\nFor faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. ')
        DESCRIPTION = "".join(description_parts)
        gr.Markdown(DESCRIPTION)

    logger.info(f'device = {device}')
    logger.info(f'torch.cuda.is_available = {torch.cuda.is_available()}')

    computer_info()
    block.queue(max_size=10, api_open=False)
    logger.info(f"Start a gradio server[{os.getpid()}]: http://0.0.0.0:{args.port}")
    block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share, show_api=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="using debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    parser.add_argument("--port", "-p", type=int, default=7860, help="port")
    parser.add_argument("--cuda", "-c", type=str, default='cuda:0', help="cuda")
    args, _ = parser.parse_known_args()
    logger.info(f'args = {args}')

    if os.environ.get('IS_MY_DEBUG') is None:
        os.system("pip list")

    set_device(args)
    if device == 'cpu':
        kosmos_enable = False

    # Load only the models whose feature flags are on.
    if kosmos_enable:
        kosmos_model, kosmos_processor = load_kosmos_model(device)

    if groundingdino_enable:
        load_groundingdino_model('cpu')

    if sam_enable:
        load_sam_model(device)

    if inpainting_enable:
        load_sd_model(device)

    if lama_cleaner_enable:
        load_lama_cleaner_model(device)

    if ram_enable:
        load_ram_model(device)

    if os.environ.get('IS_MY_DEBUG') is None:
        os.system("pip list")

    main_gradio(args)