"""Gradio front end for the Er0manga segmentation + inpainting pipeline.

Loads a segmentation model (from an Er0mangaSeg checkout) and an inpainting
model (from an Er0mangaInpaint checkout), both located under the user's home
directory, and serves two tabs: single-image processing and batch processing.
"""
import gradio as gr
import os
import sys

base_path = os.path.expanduser('~')

# The project repositories are plain checkouts, not installed packages; their
# directories must be on sys.path before the project imports below run.
sys.path.append(os.path.join(base_path, 'Er0mangaSeg/'))
sys.path.append(os.path.join(base_path, 'Er0mangaSeg/demo'))
from image_demo_tta import init_seg_model, inference_tta  # noqa: E402

sys.path.append(os.path.join(base_path, 'Er0mangaInpaint/'))
sys.path.append(os.path.join(base_path, 'Er0mangaInpaint/bin'))
from uncen import init_inpaint_model, inpaint  # noqa: E402

import time  # noqa: E402
import numpy as np  # noqa: E402
import cv2  # noqa: E402
import shutil  # noqa: E402
import tempfile  # noqa: E402
import torch  # noqa: E402

if torch.cuda.is_available():
    print('GPU found!')
    device = 'cuda:0'
else:
    print('GPU not found! Using CPU')
    device = 'cpu'

config = os.path.join(base_path, 'Er0mangaSeg/configs/convnext/convnext_h.py')
checkpoint = os.path.join(base_path, 'Er0mangaSeg/pretrained/convnext_1024_iter_400.pth')
model_seg = init_seg_model(config, checkpoint, device=device)
print('Segmentation initialized')

# NOTE(review): `device` is not passed here; confirm init_inpaint_model picks
# its own device consistently with the segmentation model.
inp_model_path = os.path.join(base_path, 'Er0mangaInpaint/pretrained/00-30-09')
model_inp = init_inpaint_model(inp_model_path)
print('Inpainting initialized')


def _segment_and_inpaint(input_img):
    """Segment one image, then inpaint the segmented region.

    Args:
        input_img: image array handed to both models (the batch path passes
            an RGB array; presumably the single-image path does too).

    Returns:
        (output_img, raw_mask): the inpainted image and the raw segmentation
        mask replicated into three channels with ``np.dstack``.
    """
    out_mask, raw_mask = inference_tta(model_seg, input_img)
    # inpaint() is given a 3-channel mask, so replicate the mask 3 times
    # along the last axis.
    out_mask = np.dstack([out_mask] * 3)
    raw_mask = np.dstack([raw_mask] * 3)
    output_img, _dbg = inpaint(model_inp, input_img, out_mask)
    return output_img, raw_mask


def proc(input_img):
    """Gradio handler for the single-image tab.

    Args:
        input_img: value from ``gr.Image()`` (presumably an RGB numpy array,
            Gradio's default), or None when no image was supplied.

    Returns:
        The inpainted image.

    Raises:
        gr.Error: when no image was given or any pipeline step fails; the
            message is shown in the UI.
    """
    if input_img is None:
        raise gr.Error("No image provided")
    try:
        start = time.perf_counter()
        output_img, _raw_mask = _segment_and_inpaint(input_img)
        print(f"proc_time: {time.perf_counter() - start:.2f}")
        return output_img
    except Exception as exc:
        # gr.Error takes a message string; keep the original as the cause.
        raise gr.Error(str(exc)) from exc


def proc_batch(batch):
    """Gradio handler for the batch tab.

    Each gallery item's first element is read as a file path. The results are
    written into fresh directories next to the first uploaded file, and both
    directories are zipped.

    Args:
        batch: value from ``gr.Gallery()`` — a sequence of items whose first
            element is an image file path; may be None or empty.

    Returns:
        (list of output image paths, path to output.zip, path to
        output_mask.zip).

    Raises:
        gr.Error: when the batch is empty or any image fails to process.
    """
    if not batch:
        raise gr.Error("No images provided")
    try:
        start = time.perf_counter()
        out_p = os.path.dirname(batch[0][0])
        # mkdtemp creates a uniquely named directory atomically, so repeated
        # runs never collide (a random numeric salt + os.mkdir could).
        out_p_d = tempfile.mkdtemp(prefix='__salt_img__', dir=out_p)
        out_p_m = tempfile.mkdtemp(prefix='__salt_mask__', dir=out_p)

        res = []
        for item in batch:
            input_path = item[0]
            inp_name = os.path.basename(input_path)

            bgr_img = cv2.imread(input_path)
            if bgr_img is None:
                # cv2.imread signals failure by returning None, not raising.
                raise ValueError(f"Could not read image: {inp_name}")
            input_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)

            output_img, raw_mask = _segment_and_inpaint(input_img)

            out_path_img = os.path.join(out_p_d, inp_name)
            out_path_mask = os.path.join(out_p_m, inp_name + '.png')
            # The pipeline works in RGB; cv2.imwrite expects BGR.
            if not cv2.imwrite(out_path_img, cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)):
                raise ValueError(f"Could not write output image for: {inp_name}")
            if not cv2.imwrite(out_path_mask, raw_mask):
                raise ValueError(f"Could not write mask for: {inp_name}")
            res.append(out_path_img)

        # make_archive returns the full path of the created archive file.
        ar_img = shutil.make_archive(os.path.join(out_p, 'output'), 'zip', out_p_d)
        ar_mask = shutil.make_archive(os.path.join(out_p, 'output_mask'), 'zip', out_p_m)
        print(f"batch proc_time: {time.perf_counter() - start:.2f}")
        return res, ar_img, ar_mask
    except Exception as exc:
        raise gr.Error(str(exc)) from exc


# Cache cleanup settings passed to Gradio (per Gradio docs: (frequency, age)
# in seconds for deleting temporary files).
CACHE_TTL = (7200, 7200)

demo1 = gr.Interface(
    proc,
    gr.Image(),
    gr.Image(format='png'),
    delete_cache=CACHE_TTL,
    allow_flagging='never',
)
demo2 = gr.Interface(
    proc_batch,
    gr.Gallery(),
    [
        # NOTE(review): value='str' looks unusual for a Gallery (normally a
        # list of images); confirm this was intended.
        gr.Gallery(value='str', format='png'),
        gr.File(),
        gr.File(),
    ],
    delete_cache=CACHE_TTL,
    allow_flagging='never',
)
demo = gr.TabbedInterface(
    [demo1, demo2],
    ["Single image processing", "Batch processing (experimental)"],
)

if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0', server_port=7860)