LennyS17's picture
Update app.py
9611c3f verified
raw
history blame
6.38 kB
import os
import random
import tensorflow as tf
from keras import models
import numpy as np
import gradio as gr
import cv2
# Load the trained generator once at start-up. If loading fails the error is
# printed and ``generator`` is bound to None, so the name always exists and
# colorization requests fail with an explicit message instead of a NameError.
try:
    generator = models.load_model("generator.keras")
    print("Model loaded successfully!")
except Exception as e:  # top-level boundary: report and keep the UI up
    print("Error loading model:", e)
    generator = None
def preprocess_image(img):
    """Convert a grayscale image into the generator's input tensor.

    The image is resized to 256x256 and its intensities, expected in the
    0-255 range, are mapped to [-1, 1]. The result is the normalized L
    channel that colorize_image feeds to the model.

    Args:
        img: numpy image, either (H, W) grayscale or (H, W, C); RGB (3
            channels) and RGBA (4 channels) input is converted to gray so the
            model always receives a single channel.

    Returns:
        float32 tensor of shape (1, 256, 256, 1).

    Raises:
        ValueError: if ``img`` is None.
    """
    if img is None:
        raise ValueError("preprocess_image: no image provided")
    img = np.asarray(img)
    if img.ndim == 3:
        channels = img.shape[2]
        if channels == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
        elif channels == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        else:
            img = img[:, :, 0]
    img = cv2.resize(img, (256, 256))
    # Map 0..255 to -1..1.
    img = img.astype("float32") / 127.5 - 1.0
    tensor = tf.convert_to_tensor(img, dtype=tf.float32)
    tensor = tf.expand_dims(tensor, axis=-1)  # add the channel axis
    return tf.expand_dims(tensor, axis=0)  # add the batch axis
def postprocess_image(img):
    """Map a [-1, 1]-scaled Lab image back to an 8-bit RGB image.

    Args:
        img: tensor or array of shape (H, W, 3) holding L, a, b channels in
            [-1, 1], as assembled by colorize_image.

    Returns:
        uint8 RGB numpy array of shape (H, W, 3).
    """
    lab = (np.asarray(img, dtype=np.float32) + 1.0) * 127.5
    # Round instead of truncating. The L channel went through /127.5 in
    # preprocessing, and float error (e.g. 254.99998) would otherwise drop a
    # level. Clip so out-of-range model outputs saturate instead of wrapping
    # around when cast to uint8.
    lab = np.clip(np.rint(lab), 0, 255).astype(np.uint8)
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
def adjust_brightness(img, brightness=0.0):
    """Shift every pixel of an 8-bit image by a constant offset.

    Args:
        img: uint8 image (any number of channels), or None.
        brightness: slider value; the offset is ``int(brightness * 127 / 4)``,
            i.e. about +-31 levels for the [-1, 1] slider range.

    Returns:
        uint8 image clamped to 0..255, or None when ``img`` is None (the
        sliders can fire before anything has been colorized).
    """
    if img is None:
        return None
    offset = int(brightness * 127.0 / 4.0)
    # Widen before adding so the sum cannot wrap, then clamp. cv2.convertScaleAbs
    # is deliberately not used: it takes the absolute value of the result, which
    # makes near-black pixels brighter when darkening.
    shifted = np.asarray(img, dtype=np.int32) + offset
    return np.clip(shifted, 0, 255).astype(np.uint8)
def adjust_contrast(img, contrast=0.0):
    """Multiply pixel values of an 8-bit image by a gain derived from the slider.

    NOTE(review): this scales intensities toward/away from 0 (a gain), not
    around mid-gray as a classic contrast control would; kept as-is to
    preserve the app's current look.

    Args:
        img: uint8 image (any number of channels), or None.
        contrast: slider value; the gain is ``contrast * 0.75 + 1.0``, i.e.
            0.25..1.75 for the [-1, 1] slider range.

    Returns:
        uint8 image rounded and clamped to 0..255, or None when ``img`` is
        None (the sliders can fire before anything has been colorized).
    """
    if img is None:
        return None
    gain = contrast * 0.75 + 1.0
    scaled = np.rint(np.asarray(img, dtype=np.float32) * gain)
    return np.clip(scaled, 0, 255).astype(np.uint8)
def adjust_hue(img, hue_shift=0.0):
    """Rotate the hue of an 8-bit RGB image.

    The images handled by this app are RGB (colorize_image converts with
    COLOR_LAB2RGB), so RGB conversion codes are used. Converting with BGR
    codes would mirror the hue wheel and reverse the slider's direction.

    Args:
        img: uint8 RGB image, or None.
        hue_shift: slider value; each unit shifts by 90 steps of OpenCV's
            8-bit hue scale (0-179, half-degrees), so +-1 is a 180-degree turn.

    Returns:
        The hue-shifted uint8 RGB image, or None when ``img`` is None.
    """
    if img is None:
        return None
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    # Work in floats, round, then wrap into 0..179. numpy's % returns a
    # non-negative result for a positive divisor, so negative shifts wrap too.
    shifted = np.rint(hsv[:, :, 0].astype(np.float32) + hue_shift * 90) % 180
    hsv[:, :, 0] = shifted.astype(np.uint8)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
def adjust_saturation(img, saturation_factor=0.0):
    """Scale the saturation of an 8-bit RGB image.

    Args:
        img: uint8 RGB image (colorize_image returns RGB), or None.
        saturation_factor: slider value; the S channel is multiplied by
            ``saturation_factor + 1.0`` and clamped to 0..255. -1 removes all
            saturation (grayscale), 0 keeps it, and 1 doubles it.

    Returns:
        The adjusted uint8 RGB image, or None when ``img`` is None (the
        sliders can fire before anything has been colorized).
    """
    if img is None:
        return None
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    scaled = hsv[:, :, 1].astype(np.float32) * (saturation_factor + 1.0)
    hsv[:, :, 1] = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
def colorize_image(input_image):
    """Colorize a grayscale image with the loaded generator.

    The normalized L channel from preprocess_image is fed to the model. The
    model's output is concatenated after it to form a Lab image, which is
    converted to RGB.

    Args:
        input_image: grayscale numpy image (see preprocess_image).

    Returns:
        uint8 RGB numpy array at the model resolution (256x256).

    Raises:
        RuntimeError: if the model failed to load at start-up.
    """
    if generator is None:
        raise RuntimeError("Colorization model is not loaded; check the start-up log.")
    l_channel = preprocess_image(input_image)
    # For a single batch, calling the model directly is faster than predict()
    # and does not print a progress bar on every request.
    ab_channels = generator(l_channel, training=False)
    # presumably ab_channels[0] is (256, 256, 2); cast in case the model
    # emits a different float dtype than the float32 L channel.
    lab = tf.concat([l_channel[0], tf.cast(ab_channels[0], tf.float32)], axis=-1)
    return postprocess_image(lab)
def colorize_and_store(img, bright_slider, cont_slider, sat_slider, hue_slider):
    """Run the model once and apply the current slider settings.

    Args:
        img: grayscale input image from the UI, or None if nothing is uploaded.
        bright_slider, cont_slider, sat_slider, hue_slider: current slider
            values, applied in that order.

    Returns:
        A ``(colorized, adjusted)`` pair. ``colorized`` is the raw model output,
        which the UI keeps so the sliders can re-adjust it without re-running
        the model. ``adjusted`` is that image with all four adjustments applied.

    Raises:
        gr.Error: if no input image has been provided.
    """
    if img is None:
        raise gr.Error("Please upload an image before colorizing.")
    colorized_image = colorize_image(img)
    output_image = adjust_brightness(colorized_image, bright_slider)
    output_image = adjust_contrast(output_image, cont_slider)
    output_image = adjust_saturation(output_image, sat_slider)
    output_image = adjust_hue(output_image, hue_slider)
    return colorized_image, output_image
def make_grayscale_256(img):
    """Resize an image to 256x256 and reduce it to a single gray channel.

    Args:
        img: numpy image, (H, W) grayscale or (H, W, C). RGB (3 channels) and
            RGBA (4 channels) are converted with the RGB codes, since numpy
            images from Gradio are RGB-ordered.

    Returns:
        (256, 256) array with the same dtype as ``img``.
    """
    img = cv2.resize(img, (256, 256))
    if img.ndim == 3:
        if img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
        elif img.shape[2] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        else:
            img = img[:, :, 0]
    return img
# Page styling: centered heading (h1) and description (p) text, plus a CSS
# grayscale filter on images inside the element with id "input-image", so the
# uploaded picture is previewed without color.
css = """
h1 {
text-align: center;
display:block;
font-size: 3rem;
margin: 0;
padding: 0.5rem;
line-height: 1;
overflow: hidden;
}
p {
text-align: center;
display:block;
font-size:1.5rem;
margin: 0;
padding: 0.5rem;
line-height: 1;
overflow: hidden;
}
#input-image img {
filter: grayscale(1);
}
"""
# Example images shown under the input widget. The list is sorted so the
# gallery order is stable across platforms (os.listdir order is arbitrary).
_EXAMPLE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")
try:
    image_files = sorted(
        os.path.join("examples", file)
        for file in os.listdir("examples")
        if file.lower().endswith(_EXAMPLE_EXTENSIONS)
    )
except FileNotFoundError:
    # A missing examples folder should not stop the app from starting.
    image_files = []
# Gradio Interface
def _apply_all_adjustments(img, brightness, contrast, saturation, hue):
    """Re-apply every slider setting to the stored, unadjusted colorized image.

    Uses the same order as colorize_and_store (brightness, contrast,
    saturation, hue), so moving one slider keeps the effect of the others.

    Returns:
        The adjusted image, or None if nothing has been colorized yet.
    """
    if img is None:
        return None
    img = adjust_brightness(img, brightness)
    img = adjust_contrast(img, contrast)
    img = adjust_saturation(img, saturation)
    return adjust_hue(img, hue)


with gr.Blocks(css=css, title="Portrait Colorizer") as demo:
    # title
    gr.HTML("<h1>Portrait Colorizer</h1>")
    # description
    gr.HTML("<p>Upload a grayscale image to colorize it and fine-tune the output using the sliders below.</p>")
    with gr.Row():
        input_image = gr.Image(
            type="numpy",
            label="Grayscale Image",
            image_mode="L",
            height=256,
            width=256,
            elem_id="input-image",  # targeted by the grayscale CSS filter
        )
        # Skip the examples widget when no example images were found.
        if image_files:
            examples_gallery = gr.Examples(
                examples=image_files, inputs=[input_image], label="Example Images"
            )
        output_image = gr.Image(
            type="numpy",
            label="Colorized Image",
            image_mode="RGB",
            height=256,
            width=256,
        )
    process_button = gr.Button("Colorize")
    bright_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Brightness")
    cont_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Contrast")
    sat_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Saturation")
    hue_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Hue")

    # Holds the unadjusted model output so slider changes never re-run the model.
    colorized_image = gr.State()

    # The button runs the model and applies the current slider settings.
    process_button.click(
        fn=colorize_and_store,
        inputs=[input_image, bright_slider, cont_slider, sat_slider, hue_slider],
        outputs=[colorized_image, output_image],
    )

    # Every slider re-applies all four adjustments to the stored image.
    # Applying only the moved slider's adjustment would discard the others.
    adjustment_inputs = [colorized_image, bright_slider, cont_slider, sat_slider, hue_slider]
    for slider in (bright_slider, cont_slider, sat_slider, hue_slider):
        slider.change(
            fn=_apply_all_adjustments,
            inputs=adjustment_inputs,
            outputs=output_image,
        )
def _main():
    """Start the Gradio server that hosts the colorizer UI."""
    demo.launch()


# Only serve the UI when run as a script, not when imported.
if __name__ == "__main__":
    _main()