How do I get streaming token generation from mistral_common? Example needed
#146
by
narai
- opened
I tried to do the following, but it appears that no tokens were generated. (I'm only using bitsandbytes at the moment because I can't get EETQ to work outside of huggingface TGI)
from threading import Thread

import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, TextStreamer, TextIteratorStreamer, GenerationConfig, BitsAndBytesConfig #, EetqConfig

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
tokenizer = MistralTokenizer.v1()
streamer = TextIteratorStreamer(tokenizer)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
checkpoint,
device_map="auto",
quantization_config=quantization_config,
token=HF_TOKEN
)
completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
input_ids = tokenizer.encode_chat_completion(completion_request).tokens
generation_kwargs = {'input_ids':input_ids,
'streamer':streamer,
'max_new_tokens':max_new_tokens,
'do_sample':True,
'temperature':TEMPERATURE,
'top_k':TOP_K,
'top_p':TOP_P,
}
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
return streamer