APISR / test_code /inference.py
Arrcttacsrks's picture
Update test_code/inference.py
6358027 verified
raw
history blame
3.86 kB
import functools
import os

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from PIL import Image
# Path to the model in Hugging Face Space.
# This is a relative path: it is resolved against the current working
# directory of the running process, not against the folder of this script.
MODEL_PATH = "pretrained/4xGRL.onnx" # Adjust this if the model is stored in a different location
# Preprocessing function for images (similar to original script)
def preprocess_image(img, target_height=180, target_width=320, crop_for_4x=True, downsample_threshold=720):
    """Prepare an input image for the super-resolution model.

    Steps:
      1. Convert the input to an RGB array. PIL images (what the Gradio
         ``gr.Image(type="pil")`` input delivers) are already RGB and are only
         normalised to 3-channel "RGB" mode, so RGBA / grayscale / palette
         images also work. NumPy arrays are treated as OpenCV-style BGR
         (e.g. from ``cv2.imread``) and are channel-swapped to RGB.
      2. If the shorter side exceeds ``downsample_threshold``, shrink the image
         (keeping the aspect ratio) so its shorter side is about that length.
      3. If ``crop_for_4x`` is true, trim the bottom/right edges so the height
         and width are multiples of 4.
      4. Resize to ``target_height`` x ``target_width``. If either target is
         ``None`` this step is skipped and the cropped image is returned.

    Args:
        img: A ``PIL.Image.Image`` in any mode, or an ``H x W x 3`` BGR array.
        target_height: Output height in pixels, or ``None`` to skip the final resize.
        target_width: Output width in pixels, or ``None`` to skip the final resize.
        crop_for_4x: Whether to crop to multiples of 4 before the final resize.
        downsample_threshold: Largest allowed short side before downsampling.

    Returns:
        An ``(H, W, 3)`` NumPy array in RGB channel order.
    """
    if isinstance(img, Image.Image):
        # PIL data is already RGB; swapping channels here would hand the model BGR
        # and produce a red/blue-swapped result.
        img_rgb = np.array(img.convert("RGB"))
    else:
        # Raw arrays keep the previous contract: assumed to be OpenCV BGR.
        img_rgb = cv2.cvtColor(np.array(img), cv2.COLOR_BGR2RGB)

    # Downsample if the short side exceeds the threshold.
    h, w = img_rgb.shape[:2]
    short_side = min(h, w)
    if short_side > downsample_threshold:
        resize_ratio = short_side / downsample_threshold
        new_size = (int(w / resize_ratio), int(h / resize_ratio))  # cv2 takes (width, height)
        img_rgb = cv2.resize(img_rgb, new_size, interpolation=cv2.INTER_LINEAR)

    # Crop to match 4x scaling: drop the remainder rows/columns (a no-op when already aligned).
    if crop_for_4x:
        h, w = img_rgb.shape[:2]
        img_rgb = img_rgb[: h - h % 4, : w - w % 4]

    if target_height is None or target_width is None:
        return img_rgb

    # NOTE(review): this fixed resize ignores the aspect ratio and makes the
    # downsample/crop above largely moot; presumably the ONNX model was exported
    # with a fixed 180x320 input shape — confirm before changing the defaults.
    return cv2.resize(img_rgb, (target_width, target_height))
# Inference function to process image using ONNX model
@functools.lru_cache(maxsize=1)  # this file only ever loads MODEL_PATH
def _get_session(model_path):
    """Create the ONNX Runtime session for *model_path* once and reuse it.

    Building an ``InferenceSession`` reads and initialises the whole model, so
    doing it per request adds that cost to every click. Failed loads raise and
    are not cached, so a later call retries.
    """
    return ort.InferenceSession(model_path)


def inference(img, model_name="4xGRL"):
    """Super-resolve *img* with the ONNX model and return the result.

    Args:
        img: Image from the Gradio ``gr.Image(type="pil")`` input; ``None`` when
            the button is clicked without an uploaded image.
        model_name: Model identifier; only ``"4xGRL"`` is supported.

    Returns:
        A ``PIL.Image.Image`` built from the model's first output.

    Raises:
        gr.Error: if no image was given, the model name is unsupported, or
            loading/running the model fails. Gradio shows the message in the UI.
            Returning an error string, as before, would be fed to the image
            output and misread as a file path.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    try:
        if model_name != "4xGRL":
            raise ValueError("Model not supported")

        ort_session = _get_session(MODEL_PATH)

        # Preprocess the image (resize, crop, etc.) -> (H, W, C) RGB array.
        img_resized = preprocess_image(img)

        # Reorder to (N, C, H, W) float32 for the model.
        input_image = np.expand_dims(np.transpose(img_resized, (2, 0, 1)), axis=0)
        # NOTE(review): pixels are fed as raw 0..255 floats and the output is
        # clipped to 0..255 — confirm the exported model expects this range
        # rather than 0..1.
        input_image = input_image.astype(np.float32)

        ort_inputs = {ort_session.get_inputs()[0].name: input_image}
        output = ort_session.run(None, ort_inputs)[0]  # assumes the image is the first output

        # Take batch item 0 explicitly; squeeze() would also drop a size-1 H or W axis.
        output_image = np.transpose(output[0], (1, 2, 0))  # (C, H, W) -> (H, W, C)
        output_image = np.clip(output_image, 0, 255).astype(np.uint8)
        return Image.fromarray(output_image)
    except Exception as error:
        raise gr.Error(f"An error occurred: {error}") from error
# Gradio interface
def create_interface():
    """Assemble the Gradio Blocks app for the super-resolution demo.

    Layout, top to bottom:
      - a title and a one-line description;
      - an image upload widget;
      - a "Process Image" button;
      - an image widget for the result.

    Clicking the button runs ``inference`` on the uploaded image and shows its
    return value in the result widget.

    Returns:
        The ``gr.Blocks`` instance; it is not launched here.
    """
    intro_texts = (
        "# Anime Super-Resolution using ONNX",
        "Upload an anime image to enhance it using the 4xGRL model.",
    )
    with gr.Blocks() as app:
        for text in intro_texts:
            gr.Markdown(text)

        with gr.Row():  # upload area
            source_view = gr.Image(type="pil", label="Upload Image", interactive=True)

        with gr.Row():  # trigger
            run_button = gr.Button("Process Image")

        with gr.Row():  # result area
            upscaled_view = gr.Image(type="pil", label="Processed Image")

        run_button.click(fn=inference, inputs=source_view, outputs=upscaled_view)
    return app
# Launch the app.
# The Blocks object is still built at module level, so ``demo`` remains
# importable. The server only starts when this file is executed as a script,
# so importing the module no longer blocks on launch().
demo = create_interface()

if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link when run
    # locally; Gradio does not support it on Hugging Face Spaces and warns
    # there instead.
    demo.launch(share=True)