# (Hugging Face Spaces page header — status: Running)
import os | |
import random | |
import tensorflow as tf | |
from keras import models | |
import numpy as np | |
import gradio as gr | |
import cv2 | |
# Load the generator once at start-up. On failure the error is printed and
# ``generator`` is bound to None, so the name is always defined and code that
# uses it can detect the missing model instead of hitting a NameError.
# compile=False: the model is only used for inference, so the training
# configuration (loss/optimizer) does not need to be restored.
try:
    generator = models.load_model("generator.keras", compile=False)
    print("Model loaded successfully!")
except Exception as e:
    generator = None
    print("Error loading model:", e)
def preprocess_image(img):
    """Turn a grayscale image into the generator's input tensor.

    The gray intensity is used as the Lab L channel (OpenCV's 8-bit Lab
    encoding stores L in 0..255) and scaled to [-1, 1].

    Args:
        img: HxW uint8 grayscale array (Gradio ``image_mode="L"``). HxWx1,
            RGB (HxWx3) and RGBA (HxWx4) arrays are also accepted and are
            reduced to a single channel first.

    Returns:
        float32 tensor of shape (1, 256, 256, 1) with values in [-1, 1].
    """
    img = np.asarray(img)
    if img.ndim == 3:
        channels = img.shape[-1]
        if channels == 1:
            img = img[:, :, 0]
        elif channels == 3:
            # Gradio delivers numpy images in RGB order.
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        elif channels == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    img = cv2.resize(img, (256, 256))
    # Map 0..255 -> [-1, 1].
    img = img.astype("float32")
    img = (img / 127.5) - 1
    img = tf.convert_to_tensor(img, dtype=tf.float32)
    img = tf.expand_dims(img, axis=-1)  # add channel dimension -> (256, 256, 1)
    img = tf.expand_dims(img, axis=0)  # add batch dimension -> (1, 256, 256, 1)
    return img
def postprocess_image(img):
    """Convert a normalized Lab image back to an 8-bit RGB image.

    Args:
        img: (H, W, 3) tensor or array holding L, a, b channels scaled to
            [-1, 1] (a TensorFlow eager tensor or a numpy array).

    Returns:
        (H, W, 3) uint8 RGB array.
    """
    lab = (np.asarray(img, dtype=np.float32) + 1.0) * 127.5
    # Round rather than truncate: a value that was an exact uint8 level before
    # normalization can come back as e.g. 2.9999998 after the float32 round
    # trip, and truncation would turn it into 2. Clip so any overshoot of
    # [-1, 1] saturates instead of wrapping around in the uint8 cast.
    lab = np.clip(np.rint(lab), 0, 255).astype(np.uint8)
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
def adjust_brightness(img, brightness=0.0):
    """Shift every pixel of a uint8 image by a brightness offset.

    Args:
        img: uint8 image array (any number of channels).
        brightness: Slider value. The pixel offset is
            ``int(brightness * 127 / 4)``, i.e. about +/-31 levels for a value
            in [-1, 1].

    Returns:
        uint8 array of the same shape, saturated to [0, 255].
    """
    offset = int(brightness * 127.0 / 4.0)
    # cv2.convertScaleAbs returns |src + beta|, so a negative offset would
    # reflect dark pixels back up (10 - 31 -> 21) instead of clamping them to
    # black. Shift in a wide signed type and saturate explicitly instead.
    shifted = img.astype(np.int32) + offset
    return np.clip(shifted, 0, 255).astype(np.uint8)
def adjust_contrast(img, contrast=0.0):
    """Scale pixel intensities by a gain derived from the contrast slider.

    The slider value maps linearly to a gain of ``contrast * 0.75 + 1.0``
    (0.25 .. 1.75 for a value in [-1, 1]). With no offset term the scaling is
    about zero, not about mid-grey, and the result saturates to uint8.

    Args:
        img: Image array to scale.
        contrast: Slider value; 0.0 leaves the image unchanged.

    Returns:
        uint8 array of the same shape.
    """
    gain = contrast * 0.75 + 1.0
    scaled = cv2.convertScaleAbs(img, alpha=gain)
    return np.clip(scaled, 0, 255).astype(np.uint8)
def adjust_hue(img, hue_shift=0.0):
    """Rotate the hue of an RGB uint8 image.

    Args:
        img: RGB uint8 image (the colorizer output comes from
            ``cv2.COLOR_LAB2RGB``, and Gradio numpy images are RGB).
        hue_shift: Slider value. 1.0 is 90 OpenCV hue units (180 degrees), so a
            value in [-1, 1] spans the whole colour wheel.

    Returns:
        uint8 RGB image of the same shape.
    """
    if hue_shift == 0:
        # The 8-bit HSV round trip is lossy (OpenCV stores hue in 2-degree
        # steps), so skip it entirely when there is nothing to rotate.
        return img.astype(np.uint8)
    # Use the RGB codes: feeding RGB data through the BGR codes swaps red and
    # blue, which mirrors the hue wheel and reverses the slider direction.
    hsv_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    hue = hsv_img[:, :, 0].astype(np.float32) + hue_shift * 90
    # OpenCV's 8-bit hue lives in [0, 180); round, then wrap around the wheel.
    hsv_img[:, :, 0] = np.mod(np.rint(hue), 180).astype(np.uint8)
    return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB)
def adjust_saturation(img, saturation_factor=0.0):
    """Scale the colour saturation of an RGB uint8 image.

    Args:
        img: RGB uint8 image.
        saturation_factor: Slider value. Saturation is multiplied by
            ``saturation_factor + 1.0`` (so -1.0 fully desaturates, 0.0 is a
            no-op) and clipped to the 8-bit range.

    Returns:
        uint8 RGB image of the same shape.
    """
    if saturation_factor == 0:
        # Avoid the lossy 8-bit HSV round trip when nothing would change.
        return img.astype(np.uint8)
    hsv_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    scaled = hsv_img[:, :, 1].astype(np.float32) * (saturation_factor + 1.0)
    # Round instead of truncating, and saturate to the uint8 range.
    hsv_img[:, :, 1] = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
    return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB)
def colorize_image(input_image):
    """Colorize a grayscale image with the loaded generator.

    Args:
        input_image: Grayscale numpy image from the Gradio input, or None when
            nothing has been uploaded yet.

    Returns:
        256x256 RGB uint8 array.

    Raises:
        gr.Error: If no image was provided or the model is unavailable, so the
            user sees a readable message instead of a stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload a grayscale image first.")
    # ``generator`` is a module-level global that may be missing (or None) if
    # loading the model at start-up failed.
    model = globals().get("generator")
    if model is None:
        raise gr.Error("The colorization model could not be loaded.")
    preprocessed_image = preprocess_image(input_image)
    # verbose=0 keeps Keras from printing a progress bar on every request.
    output_ab = model.predict(preprocessed_image, verbose=0)
    # Reassemble a Lab image: the input L channel plus the predicted a/b.
    output = tf.concat([preprocessed_image[0], output_ab[0]], axis=-1)
    return postprocess_image(output)
def colorize_and_store(img, bright_slider, cont_slider, sat_slider, hue_slider):
    """Run the model once and apply the current slider settings.

    Returns:
        A ``(raw, adjusted)`` pair. ``raw`` is the unmodified colorization,
        which the UI keeps in ``gr.State`` so later slider moves can reuse it
        without calling the model again. ``adjusted`` has brightness,
        contrast, saturation and hue applied, in that order.
    """
    raw = colorize_image(img)
    pipeline = (
        (adjust_brightness, bright_slider),
        (adjust_contrast, cont_slider),
        (adjust_saturation, sat_slider),
        (adjust_hue, hue_slider),
    )
    adjusted = raw
    for step, amount in pipeline:
        adjusted = step(adjusted, amount)
    return raw, adjusted
def make_grayscale_256(img):
    """Resize ``img`` to 256x256 pixels.

    Despite the name, no colour conversion happens: the channels are returned
    as-is, only resized. NOTE(review): this helper is not referenced anywhere
    in the visible source; confirm it is still needed.
    """
    target_size = (256, 256)
    resized = cv2.resize(img, target_size)
    return resized
# Custom stylesheet passed to gr.Blocks. It centres the <h1> title and <p>
# description injected via gr.HTML. It also renders the image inside the
# component with elem_id="input-image" in grayscale through a CSS filter
# (display only; the uploaded pixel data is not modified by this).
css = """
h1 {
text-align: center;
display:block;
font-size: 3rem;
margin: 0;
padding: 0.5rem;
line-height: 1;
overflow: hidden;
}
p {
text-align: center;
display:block;
font-size:1.5rem;
margin: 0;
padding: 0.5rem;
line-height: 1;
overflow: hidden;
}
#input-image img {
filter: grayscale(1);
}
"""
# Example images offered next to the input widget: every supported image file
# in the local "examples" folder. The list is sorted because os.listdir order
# is arbitrary. A missing folder yields an empty list instead of aborting
# start-up with FileNotFoundError.
if os.path.isdir("examples"):
    image_files = sorted(
        os.path.join("examples", file)
        for file in os.listdir("examples")
        if file.lower().endswith((".png", ".jpg", ".jpeg", ".webp"))
        and os.path.isfile(os.path.join("examples", file))
    )
else:
    image_files = []
def apply_adjustments(colorized, brightness, contrast, saturation, hue):
    """Apply all four slider adjustments to the stored colorization.

    Every slider listener goes through this function, so moving one slider
    keeps the effect of the other three. It always starts from the unadjusted
    colorization and applies brightness, contrast, saturation and hue, the
    same order ``colorize_and_store`` uses.

    Args:
        colorized: Raw colorized image held in ``gr.State``; None until the
            "Colorize" button has been clicked.
        brightness: Brightness slider value.
        contrast: Contrast slider value.
        saturation: Saturation slider value.
        hue: Hue slider value.

    Returns:
        The adjusted image, or None when nothing has been colorized yet (which
        leaves the output component empty instead of raising an error).
    """
    if colorized is None:
        return None
    adjusted = adjust_brightness(colorized, brightness)
    adjusted = adjust_contrast(adjusted, contrast)
    adjusted = adjust_saturation(adjusted, saturation)
    return adjust_hue(adjusted, hue)


# Gradio interface
with gr.Blocks(css=css, title="Portrait Colorizer") as demo:
    # Page title and short description (styled by the custom CSS).
    gr.HTML("<h1>Portrait Colorizer</h1>")
    gr.HTML("<p>Upload a grayscale image to colorize it and fine-tune the output using the sliders below.</p>")
    with gr.Row():
        input_image = gr.Image(
            type="numpy",
            label="Grayscale Image",
            image_mode="L",
            height=256,
            width=256,
            elem_id="input-image",
        )
        # Only show the examples panel when example images were found.
        examples_gallery = (
            gr.Examples(
                examples=image_files, inputs=[input_image], label="Example Images"
            )
            if image_files
            else None
        )
        output_image = gr.Image(
            type="numpy",
            label="Colorized Image",
            image_mode="RGB",
            height=256,
            width=256,
        )
    process_button = gr.Button("Colorize")
    bright_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Brightness")
    cont_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Contrast")
    sat_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Saturation")
    hue_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Hue")
    sliders = [bright_slider, cont_slider, sat_slider, hue_slider]
    # Raw (unadjusted) colorization from the last "Colorize" click. The
    # sliders re-derive the displayed image from it without re-running the
    # model.
    colorized_image = gr.State()
    # "Colorize" runs the model and applies the current slider settings.
    process_button.click(
        fn=colorize_and_store,
        inputs=[input_image, *sliders],
        outputs=[colorized_image, output_image],
    )
    # Any slider change re-applies *all* adjustments to the stored image.
    for slider in sliders:
        slider.change(
            fn=apply_adjustments,
            inputs=[colorized_image, *sliders],
            outputs=output_image,
        )

# Launch the app when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()