import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from PIL import Image
from argparse import Namespace
import gradio as gr
import copy  # added
import numpy as np
import mediapipe as mp
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
from diffusers import (
    FlaxControlNetModel,
    FlaxStableDiffusionControlNetPipeline,
)

right_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
left_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
right_style_lm[0].color = (251, 206, 177)
left_style_lm[0].color = (255, 255, 225)


def draw_landmarks_on_image(rgb_image, detection_result, overlap=False, hand_encoding=False):
    hand_landmarks_list = detection_result.hand_landmarks
    handedness_list = detection_result.handedness
    if overlap:
        annotated_image = np.copy(rgb_image)
    else:
        annotated_image = np.zeros_like(rgb_image)

    # Loop through the detected hands to visualize.
    for idx in range(len(hand_landmarks_list)):
        hand_landmarks = hand_landmarks_list[idx]
        handedness = handedness_list[idx]

        # Draw the hand landmarks.
        hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        hand_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z)
            for landmark in hand_landmarks
        ])
        if hand_encoding:
            if handedness[0].category_name == "Left":
                solutions.drawing_utils.draw_landmarks(
                    annotated_image,
                    hand_landmarks_proto,
                    solutions.hands.HAND_CONNECTIONS,
                    left_style_lm,
                    solutions.drawing_styles.get_default_hand_connections_style())
            if handedness[0].category_name == "Right":
                solutions.drawing_utils.draw_landmarks(
                    annotated_image,
                    hand_landmarks_proto,
                    solutions.hands.HAND_CONNECTIONS,
                    right_style_lm,
                    solutions.drawing_styles.get_default_hand_connections_style())
        else:
            solutions.drawing_utils.draw_landmarks(
                annotated_image,
                hand_landmarks_proto,
                solutions.hands.HAND_CONNECTIONS,
                solutions.drawing_styles.get_default_hand_landmarks_style(),
                solutions.drawing_styles.get_default_hand_connections_style())
    return annotated_image


def generate_annotation(img, overlap=False, hand_encoding=False):
    """img (input): numpy array
    annotated_image (output): numpy array
    """
    # STEP 2: Create a HandLandmarker object.
    base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
    options = vision.HandLandmarkerOptions(base_options=base_options, num_hands=2)
    detector = vision.HandLandmarker.create_from_options(options)

    # STEP 3: Load the input image.
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)

    # STEP 4: Detect hand landmarks from the input image.
    detection_result = detector.detect(image)

    # STEP 5: Process the classification result. In this case, visualize it.
    annotated_image = draw_landmarks_on_image(
        image.numpy_view(), detection_result, overlap=overlap, hand_encoding=hand_encoding)
    return annotated_image


std_args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
    controlnet_revision=None,
    controlnet_from_pt=False,
)

enc_args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="MakiPan/controlnet-encoded-hands-130k",
    controlnet_revision=None,
    controlnet_from_pt=False,
)

std_controlnet, std_controlnet_params = FlaxControlNetModel.from_pretrained(
    std_args.controlnet_model_name_or_path,
    revision=std_args.controlnet_revision,
    from_pt=std_args.controlnet_from_pt,
    dtype=jnp.float32,  # jnp.bfloat16
)

enc_controlnet, enc_controlnet_params = FlaxControlNetModel.from_pretrained(
    enc_args.controlnet_model_name_or_path,
    revision=enc_args.controlnet_revision,
    from_pt=enc_args.controlnet_from_pt,
    dtype=jnp.float32,  # jnp.bfloat16
)

std_pipeline, std_pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    std_args.pretrained_model_name_or_path,
    # tokenizer=tokenizer,
    controlnet=std_controlnet,
    safety_checker=None,
    dtype=jnp.float32,  # jnp.bfloat16
    revision=std_args.revision,
    from_pt=std_args.from_pt,
)

enc_pipeline, enc_pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    enc_args.pretrained_model_name_or_path,
    # tokenizer=tokenizer,
    controlnet=enc_controlnet,
    safety_checker=None,
    dtype=jnp.float32,  # jnp.bfloat16
    revision=enc_args.revision,
    from_pt=enc_args.from_pt,
)

std_pipeline_params["controlnet"] = std_controlnet_params
std_pipeline_params = jax_utils.replicate(std_pipeline_params)

enc_pipeline_params["controlnet"] = enc_controlnet_params
enc_pipeline_params = jax_utils.replicate(enc_pipeline_params)

rng = jax.random.PRNGKey(0)
num_samples = jax.device_count()
prng_seed = jax.random.split(rng, jax.device_count())


def infer(prompt, negative_prompt, image, model_type="Standard"):
    prompts = num_samples * [prompt]
    if model_type == "Standard":
        prompt_ids = std_pipeline.prepare_text_inputs(prompts)
    elif model_type == "Hand Encoding":
        prompt_ids = enc_pipeline.prepare_text_inputs(prompts)
    else:
        pass
    prompt_ids = shard(prompt_ids)

    if model_type == "Standard":
        annotated_image = generate_annotation(image, overlap=False, hand_encoding=False)
        overlap_image = generate_annotation(image, overlap=True, hand_encoding=False)
    elif model_type == "Hand Encoding":
        annotated_image = generate_annotation(image, overlap=False, hand_encoding=True)
        overlap_image = generate_annotation(image, overlap=True, hand_encoding=True)
    else:
        pass
    validation_image = Image.fromarray(annotated_image).convert("RGB")

    if model_type == "Standard":
        processed_image = std_pipeline.prepare_image_inputs(num_samples * [validation_image])
        processed_image = shard(processed_image)

        negative_prompt_ids = std_pipeline.prepare_text_inputs([negative_prompt] * num_samples)
        negative_prompt_ids = shard(negative_prompt_ids)

        images = std_pipeline(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=std_pipeline_params,
            prng_seed=prng_seed,
            num_inference_steps=50,
            neg_prompt_ids=negative_prompt_ids,
            jit=True,
        ).images
    elif model_type == "Hand Encoding":
        processed_image = enc_pipeline.prepare_image_inputs(num_samples * [validation_image])
        processed_image = shard(processed_image)

        negative_prompt_ids = enc_pipeline.prepare_text_inputs([negative_prompt] * num_samples)
        negative_prompt_ids = shard(negative_prompt_ids)
        images = enc_pipeline(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=enc_pipeline_params,
            prng_seed=prng_seed,
            num_inference_steps=50,
            neg_prompt_ids=negative_prompt_ids,
            jit=True,
        ).images
    else:
        pass

    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    results = [i for i in images]
    return [overlap_image, annotated_image] + results


with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    gr.Markdown("This demo uses a ControlNet model conditioned on MediaPipe hand landmarks.")

    gr.Markdown("""

## Summary

Because Stable Diffusion and other diffusion models are notoriously poor at generating realistic hands, we decided to train a ControlNet model on MediaPipe hand landmarks in order to generate more realistic hands and avoid common issues such as unrealistic positions and irregular digits.
We opted to use the [HAnd Gesture Recognition Image Dataset](https://github.com/hukenovs/hagrid) (HaGRID) and [MediaPipe's Hand Landmarker](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker) to train a ControlNet that could potentially be used independently or as an in-painting tool.

To preprocess the data, we considered three options:

- The first was to use MediaPipe's built-in `draw_landmarks` function. This was an obvious first choice; however, at low training step counts we noticed that the model could not easily distinguish handedness and would often generate the wrong hand for the conditioning image (a sketch of this approach follows the example below).

*(Example pair: original image and its conditioning image drawn with the default landmark style.)*
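
As a rough sketch of this first option (mirroring the `generate_annotation` helper used by this demo and assuming the same `hand_landmarker.task` model file; the `standard_conditioning_image` name is only illustrative):

```python
# Option 1: detect hands with the MediaPipe Tasks HandLandmarker and draw
# them with the default, handedness-agnostic style on a black canvas.
import mediapipe as mp
import numpy as np
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

base_options = python.BaseOptions(model_asset_path="hand_landmarker.task")
detector = vision.HandLandmarker.create_from_options(
    vision.HandLandmarkerOptions(base_options=base_options, num_hands=2))

def standard_conditioning_image(img: np.ndarray) -> np.ndarray:
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
    result = detector.detect(image)
    canvas = np.zeros_like(img)  # landmarks only, no photo
    for hand_landmarks in result.hand_landmarks:
        proto = landmark_pb2.NormalizedLandmarkList()
        proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=lm.x, y=lm.y, z=lm.z)
            for lm in hand_landmarks])
        solutions.drawing_utils.draw_landmarks(
            canvas, proto, solutions.hands.HAND_CONNECTIONS,
            solutions.drawing_styles.get_default_hand_landmarks_style(),
            solutions.drawing_styles.get_default_hand_connections_style())
    return canvas
```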

- To counter this issue, we changed the palm landmark colours for each hand: similar enough that the model learns they carry the same kind of information, but distinct enough that it can tell left hands from right (a sketch follows the example below).

*(Example pair: original image and its hand-encoded conditioning image with per-hand palm colours.)*
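
A sketch of this hand-encoding variant, reusing the per-hand styles defined in this demo's code; `encoded_style_for` is an illustrative helper name, and the override works because the default palm landmarks share a single `DrawingSpec`:

```python
# Option 2: give left and right palms similar but distinct colours so the
# conditioning image itself encodes handedness.
import copy
from mediapipe import solutions

right_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
left_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
right_style_lm[0].color = (251, 206, 177)  # apricot palm for right hands
left_style_lm[0].color = (255, 255, 225)   # pale-yellow palm for left hands

def encoded_style_for(handedness):
    # handedness is the per-hand classification from the HandLandmarker;
    # pass the returned style as the landmark style in draw_landmarks().
    if handedness[0].category_name == "Left":
        return left_style_lm
    return right_style_lm
```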

- The last option was to use [MediaPipe Holistic](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) to provide pose, face, and hand landmarks to the ControlNet (a sketch follows below). This method was promising in theory, but the HaGRID dataset was not suitable for it, as the Holistic model performs poorly on partial-body and oddly cropped images.

We anecdotally determined that, when trained for fewer steps, the hand-encoded model performed better than the standard MediaPipe model because handedness is implied by the conditioning image. We theorize that, with a larger dataset containing more full-body hand and pose examples, Holistic landmarks will eventually provide the best images; for the moment, however, the hand-encoded model performs best.
""")

    gr.Markdown("""
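For completeness, a minimal sketch of the Holistic option, assuming MediaPipe's legacy `mp.solutions.holistic` API (this path is not used by the demo; `holistic_conditioning_image` is an illustrative name):

```python
# Option 3 (not used here): MediaPipe Holistic returns pose, face, and hand
# landmarks in a single pass; pose and hands are drawn below, the face mesh
# is omitted for brevity.
import mediapipe as mp
import numpy as np

mp_holistic = mp.solutions.holistic
mp_drawing = mp.solutions.drawing_utils

def holistic_conditioning_image(img: np.ndarray) -> np.ndarray:
    canvas = np.zeros_like(img)
    with mp_holistic.Holistic(static_image_mode=True) as holistic:
        results = holistic.process(img)  # expects an RGB image array
    if results.pose_landmarks:
        mp_drawing.draw_landmarks(canvas, results.pose_landmarks,
                                  mp_holistic.POSE_CONNECTIONS)
    if results.left_hand_landmarks:
        mp_drawing.draw_landmarks(canvas, results.left_hand_landmarks,
                                  mp_holistic.HAND_CONNECTIONS)
    if results.right_hand_landmarks:
        mp_drawing.draw_landmarks(canvas, results.right_hand_landmarks,
                                  mp_holistic.HAND_CONNECTIONS)
    return canvas
```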

## Links 🔗

- [Standard Model](https://huggingface.co/Vincent-luo/controlnet-hands)
- [Model Using Hand Encoding](https://huggingface.co/MakiPan/controlnet-encoded-hands-130k)
- Dataset Used To Train the Standard Model
- Dataset Used To Train the Hand Encoding Model
- Standard Data Preprocessing Script
- Hand Encoding Data Preprocessing Script

""") model_type = gr.Radio(["Standard", "Hand Encoding"], label="Model preprocessing", info="We developed two models, one with standard MediaPipe landmarks, and one with different (but similar) coloring on palm landmarks to distinguish left and right") with gr.Row(): with gr.Column(): prompt_input = gr.Textbox(label="Prompt") negative_prompt = gr.Textbox(label="Negative Prompt") input_image = gr.Image(label="Input Image") # output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=3, height='auto') submit_btn = gr.Button(value = "Submit") # inputs = [prompt_input, negative_prompt, input_image] # submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image]) with gr.Column(): output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=2, height='auto') gr.Examples( examples=[ [ "a woman is making an ok sign in front of a painting", "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", "example.png" ], [ "a man with his hands up in the air making a rock sign", "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", "example1.png" ], [ "a man is making a thumbs up gesture", "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", "example2.png" ], [ "a woman is holding up her hand in front of a window", "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", "example3.png" ], [ "a man with his finger on his lips", "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", "example4.png" ], ], inputs=[prompt_input, negative_prompt, input_image, model_type], outputs=[output_image], fn=infer, cache_examples=False, #cache_examples=True, ) inputs = [prompt_input, negative_prompt, input_image, model_type] submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image]) demo.launch()