# glass_try_on1 / app.py
import cv2
import cvzone
import numpy as np
import os
import gradio as gr
# Load the YuNet model
model_path = 'face_detection_yunet_2023mar.onnx'
face_detector = cv2.FaceDetectorYN.create(model_path, "", (320, 320))
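# Note: the YuNet weights (face_detection_yunet_2023mar.onnx) are distributed via the
# OpenCV Zoo repository (https://github.com/opencv/opencv_zoo); this script assumes the
# file sits next to app.py. The (320, 320) input size here is only an initial value;
# process_frame below overrides it per frame via setInputSize.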
# Initialize the glasses index and load the first overlay (IMREAD_UNCHANGED keeps the alpha channel)
num = 1
overlay = cv2.imread(f'glasses/glass{num}.png', cv2.IMREAD_UNCHANGED)
# Count glasses files
def count_files_in_directory(directory):
    file_count = 0
    for root, dirs, files in os.walk(directory):
        file_count += len(files)
    return file_count
directory_path = 'glasses'
total_glass_num = count_files_in_directory(directory_path)
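# Note: count_files_in_directory counts every file under 'glasses/', so a stray non-PNG
# file would throw off total_glass_num. A minimal alternative (a sketch, assuming all
# overlays are named glass<N>.png) would be:
#   total_glass_num = len([f for f in os.listdir(directory_path) if f.lower().endswith('.png')])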
# Change glasses
def change_glasses(action):
    global num, overlay
    if action == 'next':
        num += 1
        if num > total_glass_num:
            num = 1
    overlay = cv2.imread(f'glasses/glass{num}.png', cv2.IMREAD_UNCHANGED)
    return overlay
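# change_glasses only handles 'next'. A symmetric 'prev' branch (a sketch, not part of
# the original app) could be added inside the function:
#   elif action == 'prev':
#       num -= 1
#       if num < 1:
#           num = total_glass_num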
# Process frame for overlay
def process_frame(frame):
    global overlay
    height, width = frame.shape[:2]
    face_detector.setInputSize((width, height))
    _, faces = face_detector.detect(frame)
    if faces is not None:
        for face in faces:
            x, y, w, h = face[:4].astype(int)
            face_landmarks = face[4:14].reshape(5, 2).astype(int)  # 5 facial landmarks as (x, y) pairs
            # Get the nose position (YuNet landmark index 2 is the nose tip)
            nose_x, nose_y = face_landmarks[2]
            # Resize the overlay relative to the detected face box
            overlay_resize = cv2.resize(overlay, (int(w * 1.15), int(h * 0.8)))
            # Ensure the frame is a NumPy array and writable
            frame = np.array(frame)
            # Calculate the position to center the glasses on the nose
            overlay_x = nose_x - overlay_resize.shape[1] // 2
            overlay_y = y  # Adjust this if needed for better vertical alignment
            # Overlay the glasses
            try:
                frame = cvzone.overlayPNG(frame, overlay_resize, [overlay_x, overlay_y])
            except Exception as e:
                print(f"Error overlaying glasses: {e}")
    return frame
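# Note: Gradio delivers webcam frames in RGB order, while cv2.imread loads the PNG overlay
# as BGRA, so strongly coloured glasses may show red/blue swapped. One possible fix (a
# sketch, not part of the original app) is a one-time conversion after loading:
#   overlay = cv2.cvtColor(overlay, cv2.COLOR_BGRA2RGBA)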
# # Display the frame locally with OpenCV (kept for reference; unused in the Gradio app)
# def display_frame():
#     cap = cv2.VideoCapture(0)
#     while True:
#         ret, frame = cap.read()
#         if not ret:
#             break
#         frame = process_frame(frame)
#         cv2.imshow('SnapLens', frame)
#         k = cv2.waitKey(10)
#         if k == ord('q'):
#             break
#     cap.release()
#     cv2.destroyAllWindows()
# display_frame()
# Gradio webcam input
def webcam_input(frame):
    frame = process_frame(frame)
    return frame
# Gradio Interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Input", sources="webcam", streaming=True)
            input_img.stream(webcam_input, [input_img], [input_img], time_limit=15, stream_every=0.1, concurrency_limit=30)
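# change_glasses is defined above but never wired into the UI. A minimal sketch of one
# way to hook it up (the button and handler names are illustrative, not part of the
# original app):
#   with gr.Row():
#       next_btn = gr.Button("Next glasses")
#   def _next_glasses():
#       change_glasses('next')
#   next_btn.click(_next_glasses, inputs=None, outputs=None)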
if __name__ == "__main__":
    demo.launch(share=True)