# Janus-Pro-7B / app.py
# Author: ginipick
# Commit bd608a9 "Update app.py" (verified)
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM, pipeline as translation_pipeline
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from PIL import Image
import numpy as np
import os
import time
from Upsample import RealESRGAN
import spaces # Import spaces for ZeroGPU compatibility
import re
# Translation pipeline (Korean -> English, Helsinki-NLP/opus-mt-ko-en).
# translate_if_korean() uses it to turn Korean prompts into English ones.
translator = translation_pipeline(
    "translation",
    model="Helsinki-NLP/opus-mt-ko-en",
)
# Hangul compatibility jamo (consonants U+3131-U+314E, vowels U+314F-U+3163)
# and precomposed Hangul syllables (U+AC00-U+D7A3), i.e. [ㄱ-ㅎㅏ-ㅣ가-힣].
# Written as escapes so the pattern survives any re-encoding of this file.
_HANGUL_RE = re.compile(r"[\u3131-\u314e\u314f-\u3163\uac00-\ud7a3]")


def contains_hangul(text: str) -> bool:
    """Return True if *text* contains at least one Hangul jamo or syllable.

    Empty strings and ``None`` return False.
    """
    return bool(text) and _HANGUL_RE.search(text) is not None


def translate_if_korean(prompt: str) -> str:
    """Translate *prompt* to English if it contains Korean (Hangul) text.

    Uses the module-level ``translator`` pipeline (Korean -> English).
    Prompts without Hangul (including empty prompts) are returned unchanged.
    Translation is best-effort: if the pipeline raises, the error is printed
    and the original prompt is returned so generation can still proceed.
    """
    if not contains_hangul(prompt):
        return prompt
    try:
        return translator(prompt)[0]['translation_text']
    except Exception as e:
        # Deliberate best-effort fallback: an untranslated prompt is better
        # than failing the whole generation request.
        print(f"Translation error: {e}")
        return prompt
# Load the Janus-Pro-7B model, its chat processor and tokenizer.
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
# Request the eager attention implementation for the language model.
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(
    model_path,
    language_config=language_config,
    trust_remote_code=True,
)
# With a GPU: bfloat16 weights moved to CUDA. Without: float16 weights,
# left on whatever device from_pretrained loaded them to.
vl_gpt = (
    vl_gpt.to(torch.bfloat16).cuda()
    if torch.cuda.is_available()
    else vl_gpt.to(torch.float16)
)
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
# Device string used by all inference functions below.
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Real-ESRGAN super-resolution model (scale=2) used to upscale generated
# images; weights come from the local file with download=False.
sr_model = RealESRGAN(torch.device(cuda_device), scale=2)
sr_model.load_weights('weights/RealESRGAN_x2.pth', download=False)
@torch.inference_mode()
@spaces.GPU(duration=120)
def multimodal_understanding(image, question, seed, top_p, temperature):
    """Answer a free-form question about an image with Janus-Pro.

    Args:
        image: The uploaded image; a numpy array (the UI's gr.Image uses
            type="numpy") or an already-built PIL image.
        question: The user's question about the image.
        seed: Seed applied to the torch (CPU and CUDA) and numpy RNGs.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.
        temperature: Sampling temperature; exactly 0 selects greedy decoding.

    Returns:
        The decoded answer text with special tokens removed (at most 512
        new tokens are generated).
    """
    torch.cuda.empty_cache()
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    user_turn = {
        "role": "<|User|>",
        "content": f"<image_placeholder>\n{question}",
        "images": [image],
    }
    assistant_turn = {"role": "<|Assistant|>", "content": ""}
    conversation = [user_turn, assistant_turn]

    # The processor receives PIL images; convert numpy input from the UI.
    if isinstance(image, np.ndarray):
        pil_images = [Image.fromarray(image)]
    else:
        pil_images = [image]

    # Match the dtype the model weights were cast to at load time.
    input_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
    batch = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(cuda_device, dtype=input_dtype)
    embeds = vl_gpt.prepare_inputs_embeds(**batch)

    token_ids = vl_gpt.language_model.generate(
        inputs_embeds=embeds,
        attention_mask=batch.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=temperature != 0,
        use_cache=True,
        temperature=temperature,
        top_p=top_p,
    )
    return tokenizer.decode(token_ids[0].cpu().tolist(), skip_special_tokens=True)
def _sample_next_token(logits, temperature):
    """Pick one next-token id per row of *logits*, returned with shape (rows, 1).

    A positive *temperature* samples from ``softmax(logits / temperature)``.
    A temperature <= 0 (the UI slider allows 0) falls back to greedy argmax,
    because dividing by zero would give NaN probabilities and make
    ``torch.multinomial`` raise.
    """
    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def generate(input_ids, width, height, temperature: float = 1,
             parallel_size: int = 5, cfg_weight: float = 5,
             image_token_num_per_image: int = 576, patch_size: int = 16):
    """Autoregressively sample image tokens with classifier-free guidance.

    Args:
        input_ids: 1-D prompt token ids (a tensor or a sequence of ints).
        width, height: Output size in pixels; each is divided by *patch_size*
            to build the ``shape`` passed to ``decode_code``.
        temperature: Sampling temperature; <= 0 means greedy decoding.
        parallel_size: Number of images generated in one batch.
        cfg_weight: Classifier-free-guidance scale.
        image_token_num_per_image: Number of image tokens sampled per image.
        patch_size: Pixel size of one image token.

    Returns:
        ``(generated_tokens, patches)``: the int32 token ids of shape
        (parallel_size, image_token_num_per_image) and the output of
        ``vl_gpt.gen_vision_model.decode_code`` for those tokens.
    """
    torch.cuda.empty_cache()
    # Two rows per image: even rows hold the full prompt (conditional); odd
    # rows keep only the first and last prompt tokens and pad everything in
    # between (unconditional). Built directly on the target device.
    tokens = torch.as_tensor(input_ids, dtype=torch.int, device=cuda_device).repeat(
        parallel_size * 2, 1
    )
    tokens[1::2, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)

    generated_tokens = torch.zeros(
        (parallel_size, image_token_num_per_image), dtype=torch.int, device=cuda_device
    )
    pkv = None
    with torch.no_grad():
        for i in range(image_token_num_per_image):
            outputs = vl_gpt.language_model.model(
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=pkv,
            )
            # Keep the KV cache so later steps only feed the newest token.
            pkv = outputs.past_key_values
            logits = vl_gpt.gen_head(outputs.last_hidden_state[:, -1, :])
            # Classifier-free guidance: move from the unconditional logits
            # (odd rows) towards the conditional ones (even rows).
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

            next_token = _sample_next_token(logits, temperature)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
            # Feed each sampled token to both its conditional and its
            # unconditional row: [t0, t0, t1, t1, ...].
            next_token = next_token.repeat_interleave(2, dim=0).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

    patches = vl_gpt.gen_vision_model.decode_code(
        generated_tokens,
        shape=[parallel_size, 8, width // patch_size, height // patch_size],
    )
    return generated_tokens, patches
def unpack(dec, width, height, parallel_size=5):
    """Convert decoded image tensors into a uint8 array of RGB images.

    Args:
        dec: 4-D decoder output with channels on axis 1; values are mapped
            from [-1, 1] to [0, 255] and anything outside is clipped.
        width, height: Sizes of axes 1 and 2 of the result, in that order.
        parallel_size: Number of images (axis 0 of the result).

    Returns:
        ``np.uint8`` array of shape (parallel_size, width, height, 3); the
        fractional part of each scaled value is truncated.
    """
    frames = dec.to(torch.float32).cpu().numpy()
    # Channels last: (N, C, A, B) -> (N, A, B, C).
    frames = frames.transpose(0, 2, 3, 1)
    frames = np.clip((frames + 1) / 2 * 255, 0, 255)
    out = np.empty((parallel_size, width, height, 3), dtype=np.uint8)
    out[...] = frames  # float -> uint8 cast happens on assignment
    return out
@torch.inference_mode()
@spaces.GPU(duration=120)
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
    """Generate a batch of images from a text prompt and upscale each one.

    A Korean prompt is first translated to English. Five 384x384 images are
    generated, converted to PIL images and passed through ``image_upsample``.

    Args:
        prompt: Text prompt (Korean or English).
        seed: Optional seed for the torch (CPU and CUDA) and numpy RNGs;
            ``None`` leaves the RNG state untouched.
        guidance: Classifier-free-guidance weight forwarded to ``generate``.
        t2i_temperature: Sampling temperature forwarded to ``generate``.

    Returns:
        A list with one upscaled image per generated sample.
    """
    # Translate Korean prompts to English before tokenizing.
    prompt = translate_if_korean(prompt)
    torch.cuda.empty_cache()
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)

    width = height = 384
    parallel_size = 5
    # Round both sides down to a multiple of the 16-px patch size.
    gen_width = width // 16 * 16
    gen_height = height // 16 * 16

    conversation = [
        {'role': '<|User|>', 'content': prompt},
        {'role': '<|Assistant|>', 'content': ''},
    ]
    with torch.no_grad():
        sft_text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=conversation,
            sft_format=vl_chat_processor.sft_format,
            system_prompt='',
        )
        # The prompt ends with the image-start tag; generation continues from it.
        prompt_ids = torch.LongTensor(
            tokenizer.encode(sft_text + vl_chat_processor.image_start_tag)
        )
        _, patches = generate(
            prompt_ids,
            gen_width,
            gen_height,
            cfg_weight=guidance,
            parallel_size=parallel_size,
            temperature=t2i_temperature,
        )
        images = unpack(patches, gen_width, gen_height, parallel_size=parallel_size)

    stime = time.time()
    ret_images = [image_upsample(Image.fromarray(frame)) for frame in images]
    print(f'upsample time: {time.time() - stime}')
    return ret_images
@spaces.GPU(duration=60)
def image_upsample(img: Image.Image) -> Image.Image:
    """Upscale *img* with the module-level Real-ESRGAN model ``sr_model``.

    The image is converted to RGB before prediction.

    Raises:
        Exception: if *img* is None, or if either side is 5000 px or more.
    """
    if img is None:
        raise Exception("Image not uploaded")
    w, h = img.size
    if max(w, h) >= 5000:
        raise Exception("The image is too large.")
    return sr_model.predict(img.convert('RGB'))
# Custom CSS passed to gr.Blocks(css=...): light page background and
# Segoe UI font, centered large header, blue buttons with a darker hover
# state, rounded input widgets, and rounded, shadowed gallery items.
# NOTE(review): the .gr-* class selectors may not match the DOM of newer
# Gradio releases — confirm they still take effect.
custom_css = """
body {
background: #f0f2f5;
font-family: 'Segoe UI', sans-serif;
color: #333;
}
h1, h2, h3 {
font-weight: 600;
}
.gradio-container {
padding: 20px;
}
header {
text-align: center;
padding: 20px;
margin-bottom: 20px;
}
header h1 {
font-size: 3em;
color: #2c3e50;
}
.gr-button {
background-color: #3498db !important;
color: #fff !important;
border: none !important;
padding: 10px 20px !important;
border-radius: 5px !important;
font-size: 1em !important;
}
.gr-button:hover {
background-color: #2980b9 !important;
}
.gr-input, .gr-slider, .gr-number, .gr-textbox {
border-radius: 5px;
}
.gr-gallery-item {
border-radius: 10px;
overflow: hidden;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
"""
# Gradio interface: one tab for image question answering
# (multimodal_understanding) and one for text-to-image generation
# (generate_image), followed by a community footer.
with gr.Blocks(css=custom_css, title="Multimodal & T2I") as demo:
    with gr.Column(variant="panel"):
        gr.Markdown("<header><h1>Chat With Janus-Pro-7B</h1></header>")
        with gr.Tabs():
            # --- Tab 1: ask questions about an uploaded image ---
            with gr.TabItem("Multimodal Understanding"):
                gr.Markdown("### Chat with Images")
                with gr.Row():
                    image_input = gr.Image(label="Upload Image", type="numpy")
                    with gr.Column():
                        question_input = gr.Textbox(label="Question", placeholder="Enter your question about the image here...", lines=4)
                        und_seed_input = gr.Number(label="Seed", precision=0, value=42)
                        top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="Top_p")
                        temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
                        understanding_button = gr.Button("Chat", elem_id="understanding-button")
                understanding_output = gr.Textbox(label="Response", lines=6)
                with gr.Accordion("Examples", open=False):
                    gr.Examples(
                        label="Multimodal Understanding Examples",
                        examples=[
                            ["explain this meme", "doge.png"]
                        ],
                        inputs=[question_input, image_input],
                    )
                understanding_button.click(
                    multimodal_understanding,
                    inputs=[image_input, question_input, und_seed_input, top_p, temperature],
                    outputs=understanding_output,
                )
            # --- Tab 2: generate images from a text prompt ---
            with gr.TabItem("Text-to-Image Generation"):
                gr.Markdown("### Generate Images from Text")
                with gr.Row():
                    prompt_input = gr.Textbox(label="Prompt", placeholder="Enter detailed prompt for image generation...", lines=4)
                with gr.Row():
                    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)
                    cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
                    t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="Temperature")
                generation_button = gr.Button("Generate Images", elem_id="generation-button")
                image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
                with gr.Accordion("Examples", open=False):
                    gr.Examples(
                        label="Text-to-Image Examples",
                        examples=[
                            "Master shifu racoon wearing drip attire as a street gangster.",
                            "The face of a beautiful girl",
                            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                            "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting, immortal, fluffy, shiny mane, petals, fairyism, unreal engine 5 and Octane Render, highly detailed, photorealistic, cinematic, natural colors.",
                            # Korean: "A cat wearing a spacesuit on the moon" —
                            # exercises the Korean -> English prompt translation.
                            "고양이가 우주복을 입고 달에 있는 모습"
                        ],
                        inputs=prompt_input,
                    )
                generation_button.click(
                    fn=generate_image,
                    inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
                    outputs=image_output,
                )
        gr.Markdown("<footer style='text-align:center; padding:20px 0;'>Join our community on <a href='https://discord.gg/openfreeai' target='_blank'>Discord</a></footer>")
demo.launch(share=True)