"""Gradio app that colorizes grayscale portraits with a Keras LAB generator.

The grayscale input is used as the L (lightness) channel, scaled to [-1, 1].
The generator's prediction is concatenated with it as the remaining channels,
and the result is decoded with OpenCV's 8-bit LAB -> RGB conversion.
Brightness, contrast, saturation and hue sliders post-process the stored
colorized result without re-running the model.
"""

import os
import random

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
from keras import models

# Side length (pixels) of the square image fed to the generator. The model
# output is concatenated with this input, so the colorized image has this size.
IMAGE_SIZE = 256

EXAMPLES_DIR = "examples"
EXAMPLE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")

# Load the generator once at startup. If loading fails, keep the UI running and
# report the problem on each colorize request instead of failing with NameError.
generator = None
try:
    generator = models.load_model("generator.keras")
    print("Model loaded successfully!")
except Exception as e:  # deliberately broad: any load failure is reported, not fatal
    print("Error loading model:", e)


def preprocess_image(img):
    """Turn a uint8 image into a normalized single-channel batch for the generator.

    Args:
        img: NumPy image array. A 2-D array is used as-is (grayscale). A 3-D
            array with 1, 3 or 4 channels is reduced to one channel first; 3/4
            channels are assumed to be RGB/RGBA (Gradio's ordering).

    Returns:
        float32 tensor of shape (1, IMAGE_SIZE, IMAGE_SIZE, 1), values in [-1, 1].
    """
    if img.ndim == 3:
        if img.shape[-1] == 1:
            img = img[..., 0]
        elif img.shape[-1] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
    # Map [0, 255] to [-1, 1]; this is used as the L channel.
    img = img.astype("float32")
    img = (img / 127.5) - 1
    img = tf.convert_to_tensor(img, dtype=tf.float32)
    img = tf.expand_dims(img, axis=-1)  # channel dimension
    img = tf.expand_dims(img, axis=0)  # batch dimension
    return img


def postprocess_image(img):
    """Convert a [-1, 1]-scaled LAB tensor back into an 8-bit RGB image.

    Args:
        img: Tensor of shape (H, W, 3) holding L, a, b channels in [-1, 1].

    Returns:
        uint8 RGB NumPy array of shape (H, W, 3).
    """
    lab = (img.numpy() + 1) * 127.5
    # Clip before the cast: values outside [-1, 1] would otherwise wrap in uint8.
    lab = np.clip(lab, 0, 255).astype(np.uint8)
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)


def adjust_brightness(img, brightness=0.0):
    """Add a constant offset to every pixel.

    Args:
        img: uint8 image array.
        brightness: Slider value; [-1, 1] maps to an offset of about +/-31 levels.

    Returns:
        uint8 image of the same shape, saturated to [0, 255].
    """
    offset = int(brightness * 127.0 / 4.0)
    # Plain saturating add. cv2.convertScaleAbs computes |x + beta|, which made
    # dark pixels brighter (instead of clamping to 0) for negative offsets.
    shifted = img.astype(np.int16) + offset
    return np.clip(shifted, 0, 255).astype(np.uint8)


def adjust_contrast(img, contrast=0.0):
    """Scale pixel intensities by a gain of ``contrast * 0.75 + 1``.

    Args:
        img: uint8 image array.
        contrast: Slider value; [-1, 1] maps to a gain in [0.25, 1.75].

    Returns:
        uint8 image of the same shape, saturated to [0, 255].
    """
    gain = contrast * 0.75 + 1.0
    # The gain is positive over the slider range, so convertScaleAbs's abs() is
    # a no-op here; it rounds and saturates the result to uint8.
    img = cv2.convertScaleAbs(img, alpha=gain)
    return np.uint8(np.clip(img, 0, 255))


def adjust_hue(img, hue_shift=0.0):
    """Rotate the hue of an RGB image.

    Args:
        img: uint8 RGB image array.
        hue_shift: Slider value; [-1, 1] rotates by up to half a turn either way
            (OpenCV stores 8-bit hue in [0, 180)).

    Returns:
        uint8 RGB image of the same shape.
    """
    # The colorized image is RGB (produced by COLOR_LAB2RGB). Treating it as
    # BGR would reverse the rotation direction.
    hsv_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    hue = hsv_img[:, :, 0].astype(np.float32) + hue_shift * 90
    hsv_img[:, :, 0] = np.mod(hue, 180.0).astype(np.uint8)  # wrap around the hue circle
    return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB)


def adjust_saturation(img, saturation_factor=0.0):
    """Scale the saturation of an RGB image by ``saturation_factor + 1``.

    Args:
        img: uint8 RGB image array.
        saturation_factor: Slider value; -1 removes all color, 1 doubles saturation.

    Returns:
        uint8 RGB image of the same shape.
    """
    hsv_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    saturation = hsv_img[:, :, 1].astype(np.float32) * (saturation_factor + 1.0)
    hsv_img[:, :, 1] = np.clip(saturation, 0, 255).astype(np.uint8)
    return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB)


def apply_adjustments(colorized_image, brightness, contrast, saturation, hue):
    """Apply all slider adjustments to the stored, unadjusted colorized image.

    Every slider change routes through here, so moving one slider keeps the
    other three settings instead of discarding them.

    Returns:
        Adjusted uint8 RGB image, or None if nothing has been colorized yet.
    """
    if colorized_image is None:
        return None
    adjusted = adjust_brightness(colorized_image, brightness)
    adjusted = adjust_contrast(adjusted, contrast)
    adjusted = adjust_saturation(adjusted, saturation)
    adjusted = adjust_hue(adjusted, hue)
    return adjusted


def colorize_image(input_image):
    """Run the generator on a grayscale image and return the colorized RGB result.

    Raises:
        gr.Error: If the model failed to load at startup.
    """
    if generator is None:
        raise gr.Error("Model is not loaded; cannot colorize.")
    preprocessed_image = preprocess_image(input_image)
    # NOTE(review): assumes the generator returns a (1, H, W, 2) batch of a/b
    # channels in [-1, 1]; confirm against the trained model.
    output_ab = generator.predict(preprocessed_image, verbose=0)
    output = tf.concat([preprocessed_image[0], output_ab[0]], axis=-1)
    return postprocess_image(output)


def colorize_and_store(img, bright_slider, cont_slider, sat_slider, hue_slider):
    """Colorize ``img`` and apply the current slider settings.

    Returns:
        (colorized_image, adjusted_image). The unadjusted image is kept in
        gr.State so later slider moves don't need another model call.

    Raises:
        gr.Error: If no image has been provided.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    colorized = colorize_image(img)
    adjusted = apply_adjustments(colorized, bright_slider, cont_slider, sat_slider, hue_slider)
    return colorized, adjusted


def make_grayscale_256(img):
    """Resize ``img`` to 256x256.

    Despite the name, no grayscale conversion is performed.
    """
    return cv2.resize(img, (256, 256))


css = """
h1 {
    text-align: center;
    display: block;
    font-size: 3rem;
    margin: 0;
    padding: 0.5rem;
    line-height: 1;
    overflow: hidden;
}
p {
    text-align: center;
    display: block;
    font-size: 1.5rem;
    margin: 0;
    padding: 0.5rem;
    line-height: 1;
    overflow: hidden;
}
#input-image img {
    filter: grayscale(1);
}
"""

# Example images (sorted for a stable gallery order). If the folder is missing,
# the list is empty and the Examples component is skipped.
image_files = (
    sorted(
        os.path.join(EXAMPLES_DIR, file)
        for file in os.listdir(EXAMPLES_DIR)
        if file.lower().endswith(EXAMPLE_EXTENSIONS)
    )
    if os.path.isdir(EXAMPLES_DIR)
    else []
)

# Gradio Interface
with gr.Blocks(css=css) as demo:
    demo.title = "Portrait Colorizer"
    # NOTE(review): the HTML markup around this text was lost from the source;
    # confirm whether a heading (the css styles h1) belongs here as well.
    gr.HTML(
        "<p>Upload a grayscale image to colorize it and fine-tune the output "
        "using the sliders below.</p>"
    )
    with gr.Row():
        input_image = gr.Image(
            type="numpy",
            label="Grayscale Image",
            image_mode="L",
            height=256,
            width=256,
            elem_id="input-image",
        )
        if image_files:
            gr.Examples(
                examples=image_files, inputs=[input_image], label="Example Images"
            )
        output_image = gr.Image(
            type="numpy",
            label="Colorized Image",
            image_mode="RGB",
            height=256,
            width=256,
        )
    process_button = gr.Button("Colorize")
    bright_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Brightness")
    cont_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Contrast")
    sat_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Saturation")
    hue_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Hue")

    # Unadjusted model output. The sliders re-derive the displayed image from it,
    # so adjusting never re-runs the model.
    colorized_image = gr.State()

    # Order must match apply_adjustments / colorize_and_store parameters.
    adjustment_sliders = [bright_slider, cont_slider, sat_slider, hue_slider]

    # Button click runs the model once and applies the current slider settings.
    process_button.click(
        fn=colorize_and_store,
        inputs=[input_image, *adjustment_sliders],
        outputs=[colorized_image, output_image],
    )

    # Any slider change re-applies ALL adjustments to the stored colorized image.
    for slider in adjustment_sliders:
        slider.change(
            fn=apply_adjustments,
            inputs=[colorized_image, *adjustment_sliders],
            outputs=output_image,
        )

if __name__ == "__main__":
    demo.launch(share=True, ssr_mode=False)