import cv2
import numpy as np
import onnxruntime as ort
import gradio as gr
from PIL import Image
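# Assumed dependencies for the Space's requirements.txt (names only, unpinned):
#   gradio, numpy, onnxruntime, opencv-python-headless, Pillow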
# Path to the model inside the Hugging Face Space
MODEL_PATH = "pretrained/4xGRL.onnx"  # Adjust this if the model is stored elsewhere
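# If the ONNX file is not committed to the Space itself, it could be fetched at
# startup with huggingface_hub (a sketch; the repo id below is a placeholder):
#
#   from huggingface_hub import hf_hub_download
#   MODEL_PATH = hf_hub_download(repo_id="<user>/<model-repo>", filename="4xGRL.onnx")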
# Preprocessing function for images (similar to the original script)
def preprocess_image(img, target_height=180, target_width=320, crop_for_4x=True, downsample_threshold=720):
    '''Preprocess a PIL image to match the model's input expectations.'''
    # Gradio supplies a PIL image, which is already RGB; no BGR->RGB swap is needed
    img_rgb = np.array(img.convert("RGB"))
    # Downsample if the short side exceeds the threshold
    h, w, _ = img_rgb.shape
    short_side = min(h, w)
    if short_side > downsample_threshold:
        resize_ratio = short_side / downsample_threshold
        img_rgb = cv2.resize(img_rgb, (int(w / resize_ratio), int(h / resize_ratio)), interpolation=cv2.INTER_LINEAR)
    # Crop so both dimensions are divisible by 4 (needed for clean 4x upscaling)
    if crop_for_4x:
        h, w, _ = img_rgb.shape
        if h % 4 != 0:
            img_rgb = img_rgb[:4 * (h // 4), :, :]
        if w % 4 != 0:
            img_rgb = img_rgb[:, :4 * (w // 4), :]
    # Resize to the model's fixed input size; note cv2.resize takes (width, height)
    img_resized = cv2.resize(img_rgb, (target_width, target_height))
    return img_resized
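# A minimal sanity check for the preprocessing step (a sketch; the file name is
# a placeholder). The result should always be a (180, 320, 3) uint8 array:
#
#   sample = Image.open("sample.png")
#   out = preprocess_image(sample)
#   assert out.shape == (180, 320, 3)
#   assert out.dtype == np.uint8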
# Inference function: run the image through the ONNX model
def inference(img, model_name="4xGRL"):
    try:
        # ONNX Runtime takes NumPy arrays; the model weights are float32
        weight_dtype = np.float32
        if model_name == "4xGRL":
            # Load the ONNX model (reloaded on every request for simplicity;
            # a module-level session would avoid the repeated load)
            ort_session = ort.InferenceSession(MODEL_PATH)
            # Preprocess the image (downsample, crop, resize)
            img_resized = preprocess_image(img)
            # Arrange the input in the layout the model expects: (N, C, H, W)
            input_image = np.transpose(img_resized, (2, 0, 1))  # (H, W, C) -> (C, H, W)
            input_image = np.expand_dims(input_image, axis=0)   # Add batch dimension
            input_image = input_image.astype(weight_dtype)      # Convert to float32
            # Run the model
            ort_inputs = {ort_session.get_inputs()[0].name: input_image}
            ort_outs = ort_session.run(None, ort_inputs)
            # Post-process the output
            output_image = ort_outs[0]  # The upscaled image is the first output
            output_image = np.transpose(output_image.squeeze(), (1, 2, 0))  # (C, H, W) -> (H, W, C)
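            # NOTE (assumption): the model is assumed to return values in [0, 255],
            # matching the unnormalized input above; some super-resolution models
            # emit [0, 1] floats instead, in which case rescale before clipping:
            #   if output_image.max() <= 1.0:
            #       output_image = output_image * 255.0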
            output_image = np.clip(output_image, 0, 255).astype(np.uint8)  # Clamp to the valid 8-bit range
            # Convert the output to a PIL Image for Gradio
            output_pil = Image.fromarray(output_image)
            return output_pil
        else:
            raise ValueError(f"Unsupported model: {model_name}")
    except Exception as error:
        # Surface the failure in the UI; gr.Image cannot render an error string
        raise gr.Error(f"An error occurred: {error}")
# Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Anime Super-Resolution using ONNX")
        gr.Markdown("Upload an anime image to enhance it using the 4xGRL model.")
        # File input for the image
        with gr.Row():
            input_image = gr.Image(type="pil", label="Upload Image", interactive=True)
        # Process button
        with gr.Row():
            process_button = gr.Button("Process Image")
        # Output for the result image
        with gr.Row():
            result_image = gr.Image(type="pil", label="Processed Image")
        # Wire the button to the inference function
        process_button.click(inference, inputs=input_image, outputs=result_image)
    return demo
# Launch the app (share=True only matters for local runs; Spaces ignores it)
demo = create_interface()
demo.launch(share=True)