support stream and system prompt (#15)
Browse files- support stream and sysyem prompt (9a7eb1ec0769ca938d7120f42edb9b942a1d431c)
- README.md +19 -2
- modeling_minicpmv.py +44 -4
README.md
CHANGED
@@ -377,10 +377,27 @@ res = model.chat(
|
|
377 |
image=image,
|
378 |
msgs=msgs,
|
379 |
tokenizer=tokenizer,
|
380 |
-
sampling=True,
|
381 |
-
temperature=0.7
|
|
|
382 |
)
|
383 |
print(res)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
```
|
385 |
|
386 |
Please look at [GitHub](https://github.com/OpenBMB/MiniCPM-V) for more detail about usage.
|
|
|
377 |
image=image,
|
378 |
msgs=msgs,
|
379 |
tokenizer=tokenizer,
|
380 |
+
sampling=True, # if sampling=False, beam_search will be used by default
|
381 |
+
temperature=0.7,
|
382 |
+
# system_prompt='' # pass system_prompt if needed
|
383 |
)
|
384 |
print(res)
|
385 |
+
|
386 |
+
## if you want to use streaming, please make sure sampling=True and stream=True
|
387 |
+
## the model.chat will return a generator
|
388 |
+
res = model.chat(
|
389 |
+
image=image,
|
390 |
+
msgs=msgs,
|
391 |
+
tokenizer=tokenizer,
|
392 |
+
sampling=True,
|
393 |
+
temperature=0.7,
|
394 |
+
stream=True
|
395 |
+
)
|
396 |
+
|
397 |
+
generated_text = ""
|
398 |
+
for new_text in res:
|
399 |
+
generated_text += new_text
|
400 |
+
print(new_text, flush=True, end='')
|
401 |
```
|
402 |
|
403 |
Please look at [GitHub](https://github.com/OpenBMB/MiniCPM-V) for more detail about usage.
|
modeling_minicpmv.py
CHANGED
@@ -3,10 +3,11 @@ from typing import List, Optional
|
|
3 |
import json
|
4 |
import torch
|
5 |
import torchvision
|
|
|
6 |
from copy import deepcopy
|
7 |
from PIL import Image
|
8 |
from torchvision import transforms
|
9 |
-
from transformers import LlamaTokenizer, LlamaPreTrainedModel, LlamaForCausalLM, AutoModel, PreTrainedTokenizerFast
|
10 |
from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
|
11 |
|
12 |
from .configuration_minicpm import MiniCPMVConfig
|
@@ -218,6 +219,25 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
218 |
**kwargs
|
219 |
)
|
220 |
return self._decode_text(output, tokenizer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
222 |
def _decode_text(self, result_ids, tokenizer):
|
223 |
result_text = []
|
@@ -294,6 +314,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
294 |
max_inp_length: Optional[int] = None,
|
295 |
vision_hidden_states=None,
|
296 |
return_vision_hidden_states=False,
|
|
|
297 |
**kwargs
|
298 |
):
|
299 |
|
@@ -326,7 +347,10 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
326 |
vision_hidden_states,
|
327 |
) = self.get_vllm_embedding(model_inputs)
|
328 |
|
329 |
-
|
|
|
|
|
|
|
330 |
|
331 |
if return_vision_hidden_states:
|
332 |
return result, vision_hidden_states
|
@@ -342,6 +366,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
342 |
max_new_tokens=1024,
|
343 |
sampling=True,
|
344 |
max_inp_length=2048,
|
|
|
|
|
345 |
**kwargs
|
346 |
):
|
347 |
if isinstance(msgs, str):
|
@@ -349,6 +375,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
349 |
|
350 |
copy_msgs = deepcopy(msgs)
|
351 |
assert len(copy_msgs) > 0, 'msgs is empty'
|
|
|
352 |
|
353 |
if image is not None and isinstance(copy_msgs[0]['content'], str):
|
354 |
copy_msgs[0]['content'] = [image, copy_msgs[0]['content']]
|
@@ -393,6 +420,10 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
393 |
if tgt_sizes:
|
394 |
tgt_sizes = torch.vstack(tgt_sizes)
|
395 |
|
|
|
|
|
|
|
|
|
396 |
input_ids = tokenizer.apply_chat_template(copy_msgs, tokenize=True, add_generation_prompt=False)
|
397 |
|
398 |
if sampling:
|
@@ -423,11 +454,20 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
423 |
max_new_tokens=max_new_tokens,
|
424 |
vision_hidden_states=vision_hidden_states,
|
425 |
return_vision_hidden_states=True,
|
|
|
426 |
**generation_config
|
427 |
)
|
428 |
-
answer = res[0]
|
429 |
|
430 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
431 |
|
432 |
|
433 |
class PreTrainedTokenizerFastWrapper(PreTrainedTokenizerFast):
|
|
|
3 |
import json
|
4 |
import torch
|
5 |
import torchvision
|
6 |
+
from threading import Thread
|
7 |
from copy import deepcopy
|
8 |
from PIL import Image
|
9 |
from torchvision import transforms
|
10 |
+
from transformers import LlamaTokenizer, LlamaPreTrainedModel, LlamaForCausalLM, AutoModel, PreTrainedTokenizerFast, TextIteratorStreamer
|
11 |
from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
|
12 |
|
13 |
from .configuration_minicpm import MiniCPMVConfig
|
|
|
219 |
**kwargs
|
220 |
)
|
221 |
return self._decode_text(output, tokenizer)
|
222 |
+
|
223 |
+
def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
|
224 |
+
terminators = [
|
225 |
+
tokenizer.eos_token_id,
|
226 |
+
tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
227 |
+
]
|
228 |
+
streamer = TextIteratorStreamer(tokenizer=tokenizer)
|
229 |
+
generation_kwargs = {
|
230 |
+
'inputs_embeds': inputs_embeds,
|
231 |
+
'pad_token_id': 0,
|
232 |
+
'eos_token_id': terminators,
|
233 |
+
'streamer': streamer
|
234 |
+
}
|
235 |
+
generation_kwargs.update(kwargs)
|
236 |
+
|
237 |
+
thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
|
238 |
+
thread.start()
|
239 |
+
|
240 |
+
return streamer
|
241 |
|
242 |
def _decode_text(self, result_ids, tokenizer):
|
243 |
result_text = []
|
|
|
314 |
max_inp_length: Optional[int] = None,
|
315 |
vision_hidden_states=None,
|
316 |
return_vision_hidden_states=False,
|
317 |
+
stream=False,
|
318 |
**kwargs
|
319 |
):
|
320 |
|
|
|
347 |
vision_hidden_states,
|
348 |
) = self.get_vllm_embedding(model_inputs)
|
349 |
|
350 |
+
if stream:
|
351 |
+
result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
|
352 |
+
else:
|
353 |
+
result = self._decode(model_inputs["inputs_embeds"], tokenizer, **kwargs)
|
354 |
|
355 |
if return_vision_hidden_states:
|
356 |
return result, vision_hidden_states
|
|
|
366 |
max_new_tokens=1024,
|
367 |
sampling=True,
|
368 |
max_inp_length=2048,
|
369 |
+
system_prompt='',
|
370 |
+
stream=False,
|
371 |
**kwargs
|
372 |
):
|
373 |
if isinstance(msgs, str):
|
|
|
375 |
|
376 |
copy_msgs = deepcopy(msgs)
|
377 |
assert len(copy_msgs) > 0, 'msgs is empty'
|
378 |
+
assert sampling or not stream, 'if use stream mode, make sure sampling=True'
|
379 |
|
380 |
if image is not None and isinstance(copy_msgs[0]['content'], str):
|
381 |
copy_msgs[0]['content'] = [image, copy_msgs[0]['content']]
|
|
|
420 |
if tgt_sizes:
|
421 |
tgt_sizes = torch.vstack(tgt_sizes)
|
422 |
|
423 |
+
if system_prompt:
|
424 |
+
sys_msg = {'role': 'system', 'content': system_prompt}
|
425 |
+
copy_msgs = [sys_msg] + copy_msgs
|
426 |
+
|
427 |
input_ids = tokenizer.apply_chat_template(copy_msgs, tokenize=True, add_generation_prompt=False)
|
428 |
|
429 |
if sampling:
|
|
|
454 |
max_new_tokens=max_new_tokens,
|
455 |
vision_hidden_states=vision_hidden_states,
|
456 |
return_vision_hidden_states=True,
|
457 |
+
stream=stream,
|
458 |
**generation_config
|
459 |
)
|
|
|
460 |
|
461 |
+
if stream:
|
462 |
+
def stream_gen():
|
463 |
+
for text in res:
|
464 |
+
text = text.replace(tokenizer.eot_token, '').replace(tokenizer.eos_token, '')
|
465 |
+
yield text
|
466 |
+
return stream_gen()
|
467 |
+
|
468 |
+
else:
|
469 |
+
answer = res[0]
|
470 |
+
return answer
|
471 |
|
472 |
|
473 |
class PreTrainedTokenizerFastWrapper(PreTrainedTokenizerFast):
|