import torch
import torch.amp
import torchvision.transforms.functional as TVF
from PIL import Image
from transformers import AutoTokenizer, LlavaForConditionalGeneration


IMAGE_PATH = "C:/Users/27698/Desktop/node/12/00001.png"
PROMPT = "Write a long descriptive caption for this image in a formal tone."
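# The prompt above asks for a long, formal caption; the JoyCaption model card lists other
# prompt styles (shorter captions, tag lists, etc.) that can be substituted here.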
MODEL_NAME = "fancyfeast/llama-joycaption-alpha-two-hf-llava"


# Load JoyCaption
# bfloat16 is the native dtype of the LLM used in JoyCaption (Llama 3.1)
# device_map="cuda:0" loads the whole model onto the first GPU
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
llava_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype="bfloat16", device_map="cuda:0")
llava_model.eval()
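# Note: this script assumes a CUDA GPU with bfloat16 support. For a CPU-only run
# (much slower; untested sketch), loading with torch_dtype="float32", keeping
# pixel_values in float32, and dropping the .to('cuda') calls below may work:
#   llava_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype="float32")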
with torch.no_grad():
    # Load and preprocess image
    # Normally you would use the Processor here, but the image module's processor
    # has some buggy behavior and a simple resize in Pillow yields higher quality results
    image = Image.open(IMAGE_PATH)
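    # 384x384 is the input resolution the model's vision encoder expects, hence the
    # fixed resize below (LANCZOS resampling keeps the resize quality high).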
    if image.size != (384, 384):
        image = image.resize((384, 384), Image.LANCZOS)

    image = image.convert("RGB")
    pixel_values = TVF.pil_to_tensor(image)
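    # pil_to_tensor returns a uint8 CxHxW tensor; the lines below rescale it to [0, 1]
    # and then normalize with mean 0.5 / std 0.5, i.e. roughly into the [-1, 1] range.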
    # Normalize the image
    pixel_values = pixel_values / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    pixel_values = pixel_values.to(torch.bfloat16).unsqueeze(0)

    # Build the conversation
    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner.",
        },
        {
            "role": "user",
            "content": PROMPT,
        },
    ]

    # Format the conversation
    convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)

    # Tokenize the conversation
    convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)
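    # The chat template leaves a single image placeholder token in the prompt. Because the
    # bundled processor is bypassed, that placeholder is expanded here to image_seq_length
    # copies, one slot per image feature, which is what the model's forward pass expects
    # (at least in the transformers versions this script was written against).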
    # Repeat the image tokens
    input_tokens = []
    for token in convo_tokens:
        if token == llava_model.config.image_token_index:
            input_tokens.extend([llava_model.config.image_token_index] * llava_model.config.image_seq_length)
        else:
            input_tokens.append(token)
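    # Batch of size 1 with no padding, so the attention mask is simply all ones.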
    input_ids = torch.tensor(input_tokens, dtype=torch.long).unsqueeze(0)
    attention_mask = torch.ones_like(input_ids)

    # Generate the caption
    generate_ids = llava_model.generate(
        input_ids=input_ids.to('cuda'),
        pixel_values=pixel_values.to('cuda'),
        attention_mask=attention_mask.to('cuda'),
        max_new_tokens=300,
        do_sample=True,
        suppress_tokens=None,
        use_cache=True,
    )[0]
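    # do_sample=True samples with the model's default generation settings; temperature,
    # top_p, etc. can be passed to generate() as keyword arguments for finer control.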
    # Trim off the prompt
    generate_ids = generate_ids[input_ids.shape[1]:]

    # Decode the caption
    caption = tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    caption = caption.strip()
    print(caption)
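# With do_sample=True the caption will generally differ between runs on the same image;
# do_sample=False switches generate() to greedy decoding for repeatable output.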