""" |
Use JoyCaption to caption images. |
""" |
import argparse |
import dataclasses |
import json |
import logging |
import os |
import random |
from pathlib import Path |
import PIL.Image |
import torch |
import torch.amp |
import torchvision.transforms.functional as TVF |
from PIL import Image |
from torch.utils.data import DataLoader, Dataset |
from tqdm import tqdm |
from transformers import ( |
AutoTokenizer, |
LlavaForConditionalGeneration, |
PreTrainedTokenizer, |
PreTrainedTokenizerFast, |
) |
from typing import Union |
def none_or_type(value, desired_type):
    """Convert a command-line string with ``desired_type``, except the literal "None".

    Intended as an argparse ``type=`` helper so an option can be disabled by
    passing ``None`` on the command line.

    Args:
        value: Raw string from the command line.
        desired_type: Callable used to convert any value other than "None".

    Returns:
        ``None`` when ``value`` is exactly "None", else ``desired_type(value)``.
    """
    return None if value == "None" else desired_type(value)
# User prompt used when neither --prompt nor --prompt-file is given.
DEFAULT_PROMPT = "Write a descriptive caption for this image in a formal tone."

# Command-line interface; parsed in main().
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, help='Input image')
parser.add_argument("--glob", type=str, help="Glob pattern to find images")
parser.add_argument("--filelist", type=str, help="File containing list of images")
parser.add_argument("--prompt", type=str, help="Prompt to use")
parser.add_argument("--prompt-file", type=str, help="JSON file containing prompts to use")
parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
parser.add_argument("--greedy", action="store_true", help="Use greedy decoding instead of sampling")
parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature")
# Passing the literal string "None" disables top-p / top-k (see none_or_type).
parser.add_argument("--top-p", type=lambda x: none_or_type(x, float), default=0.9, help="Top-p sampling")
parser.add_argument("--top-k", type=lambda x: none_or_type(x, int), default=None, help="Top-k sampling")
parser.add_argument("--max-new-tokens", type=int, default=256, help="Maximum length of the generated caption (in tokens)")
parser.add_argument("--num-workers", type=int, default=4, help="Number of workers loading images in parallel")
parser.add_argument("--model", type=str, default="John6666/llama-joycaption-alpha-two-hf-llava-nf4", help="Model to use")
# NF4 stays the default, but it must be possible to turn it off: a plain
# store_true flag with default=True is always True. BooleanOptionalAction keeps
# --nf4 working and adds --no-nf4 to load in bfloat16 instead.
parser.add_argument(
    "--nf4",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Load the model with 4-bit NF4 quantization (use --no-nf4 to load in bfloat16)",
)

# Raise Pillow's decompression-bomb pixel limit so very large images can be opened.
PIL.Image.MAX_IMAGE_PIXELS = 933120000

# NOTE(review): not referenced by the code visible in this file; model
# placement is decided by from_pretrained in main(). Kept for compatibility.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
@dataclasses.dataclass
class Prompt:
    """One candidate captioning prompt together with its sampling weight."""

    # Text placed in the user turn of the chat conversation.
    prompt: str
    # Relative weight passed to random.choices when a prompt is drawn per image.
    weight: float
def collate_skipping_failures(dataset, items: list[dict]) -> dict:
    """Collate a DataLoader batch, tolerating items whose image failed to load.

    Items with ``pixel_values`` of ``None`` are dropped before calling
    ``dataset.collate_fn``. If every item failed (e.g. one corrupt file with
    ``--batch-size 1``), ``{"paths": []}`` is returned instead of calling
    ``collate_fn`` on an empty list, which would raise inside the worker and
    abort the whole run.

    Args:
        dataset: Object providing ``collate_fn(list[dict]) -> dict``.
        items: The per-item dicts produced by the dataset for one batch.

    Returns:
        The collated dict (or ``{"paths": []}``) with an extra ``"n_requested"``
        key: how many items the batch contained before failures were dropped.
    """
    loaded = [item for item in items if item["pixel_values"] is not None]
    collated = dataset.collate_fn(loaded) if loaded else {"paths": []}
    collated["n_requested"] = len(items)
    return collated


@torch.inference_mode()
def main():
    """Caption every discovered image that does not already have a ``.txt`` file.

    Parses the CLI arguments, loads the LLaVA model (NF4-quantized or bfloat16
    according to ``--nf4``), generates captions batch by batch and writes each
    one next to its image with ``write_caption``.
    """
    import functools

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
    args = parser.parse_args()
    logging.info(f"Arguments: {args}")

    prompts = parse_prompts(args.prompt, args.prompt_file)
    image_paths = find_images(args.glob, args.filelist, args.input)
    if len(image_paths) == 0:
        logging.warning("No images found")
        return
    logging.info(f"Found {len(image_paths)} images")

    # Skip images that already have a caption file next to them.
    image_paths = [path for path in image_paths if not Path(path).with_suffix(".txt").exists()]
    if len(image_paths) == 0:
        # Nothing to caption: don't pay for loading the model.
        logging.info("All images already have captions")
        return
    logging.info(f"{len(image_paths)} images need captions")

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"

    if args.nf4:
        from transformers import BitsAndBytesConfig

        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_quant_storage=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        llava_model = LlavaForConditionalGeneration.from_pretrained(args.model, torch_dtype="bfloat16", quantization_config=nf4_config).eval()
        # NOTE(review): replaces the vision tower head attention's output
        # projection with a freshly initialised (untrained) bf16 nn.Linear.
        # Presumably a workaround for the quantized layer being incompatible
        # with that attention module, relying on the head's output not feeding
        # the LLaVA image features -- confirm before touching.
        attention = llava_model.vision_tower.vision_model.head.attention
        attention.out_proj = torch.nn.Linear(attention.embed_dim, attention.embed_dim, device=llava_model.device, dtype=torch.bfloat16)
    else:
        llava_model = LlavaForConditionalGeneration.from_pretrained(args.model, torch_dtype="bfloat16", device_map="auto").eval()
    assert isinstance(llava_model, LlavaForConditionalGeneration)

    dataset = ImageDataset(prompts, image_paths, tokenizer, llava_model.config.image_token_index, llava_model.config.image_seq_length)
    dataloader = DataLoader(
        dataset,
        # A partial over a module-level function stays picklable for worker processes.
        collate_fn=functools.partial(collate_skipping_failures, dataset),
        num_workers=args.num_workers,
        shuffle=False,
        drop_last=False,
        batch_size=args.batch_size,
    )

    end_of_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
    end_of_turn_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    assert isinstance(end_of_header_id, int) and isinstance(end_of_turn_id, int)

    # Where (and in which dtype) inputs must be placed; invariant across batches.
    patch_weight = llava_model.vision_tower.vision_model.embeddings.patch_embedding.weight
    vision_dtype = patch_weight.dtype
    vision_device = patch_weight.device
    language_device = llava_model.language_model.get_input_embeddings().weight.device
    logging.debug(f"vision device={vision_device} dtype={vision_dtype}, language device={language_device}")

    with tqdm(total=len(image_paths), desc="Captioning images...", dynamic_ncols=True) as pbar:
        for batch in dataloader:
            if not batch["paths"]:
                # Every image in this batch failed to load; errors were already logged.
                pbar.update(batch["n_requested"])
                continue

            pixel_values = batch["pixel_values"].to(vision_device, non_blocking=True)
            input_ids = batch["input_ids"].to(language_device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(language_device, non_blocking=True)

            # 0-255 pixel values (from TVF.pil_to_tensor in the dataset) -> [-1, 1].
            pixel_values = pixel_values / 255.0
            pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
            pixel_values = pixel_values.to(vision_dtype)

            generate_ids = llava_model.generate(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                max_new_tokens=args.max_new_tokens,
                do_sample=not args.greedy,
                suppress_tokens=None,
                use_cache=True,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
            )
            assert isinstance(generate_ids, torch.Tensor)
            generate_ids = generate_ids.tolist()
            # Keep only the assistant reply: tokens after the last header, up to end-of-turn.
            generate_ids = [trim_off_prompt(ids, end_of_header_id, end_of_turn_id) for ids in generate_ids]
            captions = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
            captions = [c.strip() for c in captions]

            for path, caption in zip(batch["paths"], captions):
                write_caption(Path(path), caption)

            # Count images that failed to load too, so the bar reaches its total.
            pbar.update(batch["n_requested"])
def trim_off_prompt(input_ids: list[int], eoh_id: int, eot_id: int) -> list[int]:
    """Cut a prompt+generation token sequence down to the generated reply.

    Everything up to and including the *last* ``eoh_id`` is dropped. The
    remainder is then truncated just before its first ``eot_id``. Without an
    ``eoh_id`` nothing is dropped from the front; without an ``eot_id``
    nothing is dropped from the end.

    Args:
        input_ids: Token ids for one sequence.
        eoh_id: End-of-header token id.
        eot_id: End-of-turn token id.

    Returns:
        The remaining token ids.
    """
    reply_start = 0
    for position, token in enumerate(input_ids):
        if token == eoh_id:
            reply_start = position + 1
    reply = input_ids[reply_start:]
    if eot_id in reply:
        reply = reply[:reply.index(eot_id)]
    return reply
def write_caption(image_path: Path, caption: str):
    """Write ``caption`` as UTF-8 to the ``.txt`` file next to ``image_path``.

    Never overwrites: if the caption file already exists, a warning is logged
    and nothing is written. Failures are logged rather than raised, so one bad
    file does not stop a batch run. If writing fails after the file was created,
    the partial file is removed. Otherwise an empty or truncated caption would
    make later runs (which skip images that already have a ``.txt``) ignore
    this image forever.

    Args:
        image_path: Path of the image being captioned.
        caption: Caption text to store.
    """
    caption_path = image_path.with_suffix(".txt")
    try:
        # "x": exclusive create, raising FileExistsError instead of overwriting.
        # Binary mode writes the UTF-8 bytes verbatim (no newline translation).
        # Unlike a bare os.open, open() creates the file without executable bits.
        f = open(caption_path, "xb")
    except FileExistsError:
        logging.warning(f"Caption file '{caption_path}' already exists")
        return
    except Exception as e:
        logging.error(f"Failed to open caption file '{caption_path}': {e}")
        return

    try:
        # The context manager closes the file even if the write fails.
        with f:
            f.write(caption.encode("utf-8"))
    except Exception as e:
        logging.error(f"Failed to write caption to '{caption_path}': {e}")
        try:
            caption_path.unlink()
        except OSError as unlink_error:
            logging.error(f"Failed to remove partial caption file '{caption_path}': {unlink_error}")
        return
class ImageDataset(Dataset):
    """Dataset pairing each image path with a freshly tokenized captioning prompt.

    ``__getitem__`` returns a dict with:
      - ``path``: the image path.
      - ``pixel_values``: tensor from ``TVF.pil_to_tensor`` of the image converted
        to RGB and resized to 384x384, or ``None`` if loading failed.
      - ``input_ids`` and ``attention_mask``.

    ``collate_fn`` drops failed items and left-pads the rest into batch tensors.
    """

    def __init__(self, prompts: list[Prompt], paths: list[Path], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], image_token_id: int, image_seq_length: int):
        """
        Args:
            prompts: Candidate prompts; one is sampled per item by ``weight``.
            paths: Image files, one dataset item each.
            tokenizer: Renders the chat template and encodes it to token ids.
            image_token_id: Placeholder token id; every occurrence in the
                encoded conversation is expanded to ``image_seq_length`` copies.
            image_seq_length: Number of copies per placeholder.
        """
        self.prompts = prompts
        self.paths = paths
        self.tokenizer = tokenizer
        self.image_token_id = image_token_id
        self.image_seq_length = image_seq_length
        # NOTE(review): may be None if the tokenizer defines no pad token; the
        # padded positions are masked out by attention_mask in collate_fn.
        self.pad_token_id = tokenizer.pad_token_id

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx: int) -> dict:
        path = self.paths[idx]
        # Draw one prompt for this image, proportionally to its weight.
        prompt_str = random.choices(self.prompts, weights=[p.weight for p in self.prompts])[0].prompt

        try:
            # The with-block closes the underlying file once the pixels are decoded.
            with Image.open(path) as raw:
                # Convert before resizing: Pillow ignores the requested filter for
                # "1"/"P" mode images and uses NEAREST, so palette images would be
                # downscaled with heavy aliasing if resized first.
                image = raw.convert("RGB")
            if image.size != (384, 384):
                image = image.resize((384, 384), Image.LANCZOS)
            pixel_values = TVF.pil_to_tensor(image)
        except Exception as e:
            # Unreadable images are logged and dropped in collate_fn rather than aborting the run.
            logging.error(f"Failed to load image '{path}': {e}")
            pixel_values = None

        convo = [
            {
                "role": "system",
                "content": "You are a helpful image captioner.",
            },
            {
                "role": "user",
                "content": prompt_str,
            },
        ]
        convo_string = self.tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
        assert isinstance(convo_string, str)
        convo_tokens = self.tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)

        # Expand each image placeholder token to image_seq_length copies.
        input_tokens = []
        for token in convo_tokens:
            if token == self.image_token_id:
                input_tokens.extend([self.image_token_id] * self.image_seq_length)
            else:
                input_tokens.append(token)

        input_ids = torch.tensor(input_tokens, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        return {
            'path': path,
            'pixel_values': pixel_values,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
        }

    def collate_fn(self, batch: list[dict]) -> dict:
        """Drop items whose image failed to load and left-pad the rest into tensors.

        Padding goes on the left so every prompt ends at the right edge, where
        generation continues.

        Raises:
            ValueError: If no item in ``batch`` has a loaded image.
        """
        batch = [item for item in batch if item['pixel_values'] is not None]
        if not batch:
            raise ValueError("None of the images in this batch could be loaded")
        max_length = max(item['input_ids'].shape[0] for item in batch)
        n_pad = [max_length - item['input_ids'].shape[0] for item in batch]
        input_ids = torch.stack([torch.nn.functional.pad(item['input_ids'], (n, 0), value=self.pad_token_id) for item, n in zip(batch, n_pad)])
        attention_mask = torch.stack([torch.nn.functional.pad(item['attention_mask'], (n, 0), value=0) for item, n in zip(batch, n_pad)])
        pixel_values = torch.stack([item['pixel_values'] for item in batch])
        paths = [item['path'] for item in batch]
        return {
            'paths': paths,
            'pixel_values': pixel_values,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
        }
def parse_prompts(prompt_str: Union[str, None], prompt_file: Union[str, None]) -> list[Prompt]:
    """Build the list of candidate prompts from ``--prompt`` / ``--prompt-file``.

    Args:
        prompt_str: A single prompt. It gets weight 1.0.
        prompt_file: Path to a JSON file with a list of prompts. Each entry is
            either a string (weight 1.0) or an object with a string ``"prompt"``
            and a numeric ``"weight"``.

    Returns:
        The prompts. When neither argument is given, this is ``DEFAULT_PROMPT``
        with weight 1.0.

    Raises:
        ValueError: If both arguments are given, or the file is not a non-empty
            list of valid entries. Also raised for a negative or non-finite
            weight, or weights that do not sum to a positive number.
    """
    import math

    if prompt_str is not None and prompt_file is not None:
        raise ValueError("Cannot specify both --prompt and --prompt-file")
    if prompt_str is not None:
        return [Prompt(prompt=prompt_str, weight=1.0)]
    if prompt_file is None:
        return [Prompt(prompt=DEFAULT_PROMPT, weight=1.0)]

    # Pass bytes so json detects UTF-8/16/32 itself instead of relying on the
    # platform's locale encoding (which breaks non-ASCII prompts on e.g. Windows).
    data = json.loads(Path(prompt_file).read_bytes())
    if not isinstance(data, list):
        raise ValueError("Expected JSON file to contain a list of prompts")

    prompts = []
    for item in data:
        if isinstance(item, str):
            prompts.append(Prompt(prompt=item, weight=1.0))
        elif isinstance(item, dict) and "prompt" in item and "weight" in item and isinstance(item["prompt"], str) and isinstance(item["weight"], (int, float)):
            weight = item["weight"]
            # random.choices silently mis-samples with negative weights and
            # rejects NaN/inf totals, so validate each weight up front.
            if weight < 0 or (isinstance(weight, float) and not math.isfinite(weight)):
                raise ValueError(f"Prompt weight must be a finite, non-negative number: {item}")
            prompts.append(Prompt(prompt=item["prompt"], weight=weight))
        else:
            raise ValueError(f"Invalid prompt in JSON file. Should be either a string or an object with 'prompt' and 'weight' fields: {item}")

    if len(prompts) == 0:
        raise ValueError("No prompts found in JSON file")
    if sum(p.weight for p in prompts) <= 0.0:
        raise ValueError("Prompt weights must sum to a positive number")
    return prompts
def find_images(glob: Union[str, None], filelist: Union[str, Path, None], input: Union[str, None]) -> list[Path]:
    """Collect image paths from a glob pattern, a file list and/or a single path.

    Results are concatenated in that order; duplicates are kept.

    Args:
        glob: Pattern evaluated with ``Path(".").glob``, i.e. relative to the
            current working directory.
        filelist: Text file with one path per line. Blank lines are skipped and
            surrounding whitespace is stripped.
        input: A single image path.

    Returns:
        All collected paths, every element a ``Path``.

    Raises:
        ValueError: If none of ``glob``, ``filelist`` and ``input`` is given.
    """
    if glob is None and filelist is None and input is None:
        raise ValueError("Must specify either --glob or --filelist or --input")

    paths = []
    if glob is not None:
        paths.extend(Path(".").glob(glob))
    if filelist is not None:
        lines = Path(filelist).read_text().strip().splitlines()
        paths.extend(Path(line.strip()) for line in lines if line.strip() != "")
    if input is not None:
        # Normalize to Path so every element matches the declared return type.
        paths.append(Path(input))
    return paths
if __name__ == "__main__":
    # Run the captioning pipeline only when executed as a script, not on import.
    main()