Several issues loading and using the model with transformers==4.39.2
#7 · opened by csegalin
```python
import torch


class LlavaMistralCaptioner:
    def __init__(self,
                 device='cuda',
                 hf_model="llava-hf/llava-v1.6-mistral-7b-hf",
                 bf16=False,
                 quant_force=True,
                 ):
        from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig, AutoProcessor

        self.device = device
        self.torch_type = torch.bfloat16 if bf16 else torch.float16

        # Quantize automatically when the GPU has less than 40 GB of memory.
        with torch.cuda.device(self.device):
            _, total_bytes = torch.cuda.mem_get_info()
        total_gb = total_bytes / (1 << 30)
        quant = total_gb < 40

        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=self.torch_type,
        )

        print("========Use torch type as:{} with device:{}========\n".format(self.torch_type, self.device))

        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path=hf_model,
            torch_dtype=self.torch_type,
            low_cpu_mem_usage=True,
            attn_implementation="flash_attention_2",
            quantization_config=self.quantization_config if quant or quant_force else None,
            # device_map='auto'
        ).eval()
        self.model.tie_weights()

        # self.processor = AutoProcessor.from_pretrained(hf_model)
        self.processor = LlavaNextProcessor.from_pretrained(hf_model)

    def caption(self, image,
                prompt,
                max_tokens=225,
                top_k=1,
                top_p=0.1,
                num_beams=1,
                do_sample=True,
                temperature=0.1,
                use_cache=True):
        import re

        prompt = f'''[INST] <image>\n {prompt} [/INST]'''
        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device, self.torch_type)
        outputs = self.model.generate(**inputs,
                                      max_new_tokens=max_tokens,
                                      top_k=top_k,
                                      top_p=top_p,
                                      num_beams=num_beams,
                                      do_sample=True if temperature > 0 else do_sample,
                                      temperature=temperature,
                                      use_cache=use_cache,
                                      # pad_token_id=2,
                                      # num_return_sequences=1
                                      )
        response = self.processor.decode(outputs[0],
                                         skip_special_tokens=True,
                                         clean_up_tokenization_spaces=False)
        # Keep only the text after the instruction and clean up leftover tokens.
        response = response.split('[/INST]')[-1].strip()
        response = re.sub(r'\n+', ' ', response)
        response = response.strip().replace("</s>", "").replace("<s>", "").replace("*", " ")
        return response
```
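
For context, the class is used roughly like this (the image path and prompt below are just placeholders):

```python
from PIL import Image

captioner = LlavaMistralCaptioner(device="cuda", bf16=False, quant_force=True)
image = Image.open("example.jpg")  # placeholder image path
print(captioner.caption(image, prompt="Describe this image in one sentence."))
```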
1. When loading the model I get:

   You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
   The model weights are not tied. Please use the tie_weights method before using the infer_auto_device function.
   Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

2. When generating a caption, I get the same caption repeated 3 times.
Any help on this?
Having the same issues with this model: llava-hf/llama3-llava-next-8b-hf
Hey!
I am not 100% sure which arguments are used to run the script, but here is some common advice on LLaVa and FA2:
- FA2 should be loaded in half precision, which I am not sure is happening in your script. Also, for LLaVa specifically the recommended precision is fp16, which is what the original LLaVa implementation uses (see the sketch below).
- Mixing FA2 with quantization might give weird/unexpected results; try using only one of the two.
- Hmm, the tie_weights message actually shouldn't be raised, and usually you don't have to tie weights manually as you do in the example script.

Let me know if any of the above advice helps. If not, can you share fully reproducible code that doesn't rely on external args/hardware limits?
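
For reference, here is a minimal sketch of what I mean by fp16 plus FA2 without quantization (model id and device are taken from your script; adapt as needed):

```python
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"

processor = LlavaNextProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,                # half precision, required for FA2
    attn_implementation="flash_attention_2",  # no quantization_config here
    low_cpu_mem_usage=True,
).to("cuda").eval()

# If memory is tight, drop flash_attention_2 and use 4-bit quantization instead
# (one or the other, not both):
# from transformers import BitsAndBytesConfig
# model = LlavaNextForConditionalGeneration.from_pretrained(
#     model_id,
#     quantization_config=BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_compute_dtype=torch.float16,
#     ),
# ).eval()
```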