import gradio as gr
import numpy as np
import torch
import spaces

from huggingface_hub import hf_hub_download
from diffusers import FluxControlPipeline, FluxTransformer2DModel

####################################
#   Load the model(s) on GPU       #
####################################
path = "sayakpaul/FLUX.1-dev-edit-v0"
edit_transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16)
pipeline = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=edit_transformer, torch_dtype=torch.bfloat16
).to("cuda")
pipeline.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
pipeline.set_adapters(["hyper-sd"], adapter_weights=[0.125])

MAX_SEED = np.iinfo(np.int32).max


def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Return a fresh random seed when `randomize_seed` is checked; otherwise pass
    the user-provided `seed` through unchanged.
    """
    return int(np.random.randint(0, MAX_SEED)) if randomize_seed else seed


#####################################
#  The function for our Gradio app  #
#####################################
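# @spaces.GPU requests a ZeroGPU device for each call; duration=120 caps the
# GPU allocation at roughly two minutes per invocation.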
@spaces.GPU(duration=120)
def generate(prompt, input_image, seed=42, progress=gr.Progress(track_tqdm=True)):
    """
    Run the Flux Control pipeline to edit `input_image` according to `prompt`.
    `seed` defaults to 42 so the function can also be called with just a prompt
    and an image (as the cached examples below do).
    """
    output_image = pipeline(
        control_image=input_image,
        prompt=prompt,
        guidance_scale=30.,
        num_inference_steps=8,
        max_sequence_length=512,
        height=input_image.height,
        width=input_image.width,
        generator=torch.manual_seed(int(seed))
    ).images[0]

    return output_image


def launch_app():
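    # Custom CSS: widen the default Gradio container to 1100px.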
    css = '''
    .gradio-container{max-width: 1100px !important}
    '''
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # Flux Control Editing 🖌️

            Edit any image with [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) using the 
            [Flux Control edit framework](https://github.com/sayakpaul/flux-image-editing) by [Sayak Paul](https://huggingface.co/sayakpaul).
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    input_image = gr.Image(
                        label="Image you would like to edit",
                        type="pil",
                    )
                    prompt = gr.Textbox(
                        label="Your edit prompt",
                        placeholder="e.g. 'Turn the color of the mushroom to blue'"
                    )
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                    generate_button = gr.Button("Generate")
        
            output_image = gr.Image(label="Edited Image")

        # Connect button to function
        generate_button.click(
            fn=get_seed,
            inputs=[randomize_seed, seed],
            outputs=[seed],
        ).then(
            fn=generate,
            inputs=[prompt, input_image, seed],
            outputs=[output_image],
        )

        gr.Examples(
            examples=[
                ["Turn the color of the mushroom to gray", "mushroom.jpg"],
                ["Make the mushroom polka-dotted", "mushroom.jpg"],
            ],
            inputs=[prompt, input_image],
            outputs=[output_image],
            fn=generate,
            cache_examples="lazy"
        )
        gr.Markdown(
            """
            **Acknowledgements**: 
            - [Sayak Paul](https://huggingface.co/sayakpaul) for open-sourcing FLUX.1-dev-edit-v0 
            - [black-forest-labs](https://huggingface.co/black-forest-labs) for FLUX.1-dev
            """
        )
    return demo


if __name__ == "__main__":
    demo = launch_app()
    demo.launch()