MaziyarPanahi committed on
Commit
b6fa3b6
1 Parent(s): f4ce971

update app

Files changed (1)
  1. app.py +49 -40
app.py CHANGED
@@ -1,82 +1,91 @@
-import gradio as gr
+import time
+from threading import Thread
 
+import gradio as gr
+import torch
+from PIL import Image
 from transformers import AutoProcessor, LlavaForConditionalGeneration
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, TextIteratorStreamer
+from transformers import TextIteratorStreamer
 
-from threading import Thread
-import re
-import time
-from PIL import Image
-import torch
 import spaces
-import requests
-
-CSS ="""
-#component-3 {
-  height: 500px !important;
-}"""
 
 model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"
 
 processor = AutoProcessor.from_pretrained(model_id)
 
 model = LlavaForConditionalGeneration.from_pretrained(
-    model_id, 
-    torch_dtype=torch.float16, 
-    low_cpu_mem_usage=True, 
+    model_id,
+    torch_dtype=torch.float16,
+    low_cpu_mem_usage=True,
 )
 
 model.to("cuda:0")
 model.generation_config.eos_token_id = 128009
 
+
 @spaces.GPU
 def bot_streaming(message, history):
     print(message)
     if message["files"]:
-        image = message["files"][-1]["path"]
+        # message["files"][-1] is a Dict or just a string
+        if type(message["files"][-1]) == dict:
+            image = message["files"][-1]["path"]
+        else:
+            image = message["files"][-1]
     else:
         # if there's no image uploaded for this turn, look for images in the past turns
         # kept inside tuples, take the last one
         for hist in history:
-            if type(hist[0])==tuple:
-                image = hist[0][0]
+            if type(hist[0]) == tuple:
+                image = hist[0][0]
     try:
         if image is None:
             # Handle the case where image is None
-            gr.Error("You need to upload an image for LLaVA to work.")
+            gr.Error("You need to upload an image for LLaVA to work.")
     except NameError:
         # Handle the case where 'image' is not defined at all
-        gr.Error("You need to upload an image for LLaVA to work.")
-
-    prompt=f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-    print(f"prompt: {prompt}")
+        gr.Error("You need to upload an image for LLaVA to work.")
+
+    prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+    # print(f"prompt: {prompt}")
    image = Image.open(image)
    inputs = processor(prompt, image, return_tensors='pt').to(0, torch.float16)
-
-    streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": True})
-    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-    generated_text = ""
+
+    streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": False, "skip_prompt": True})
+    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
 
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
-
-    text_prompt =f"<|start_header_id|>user<|end_header_id|>\n\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-    print(f"text_prompt: {text_prompt}")
+
+    text_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+    # print(f"text_prompt: {text_prompt}")
 
     buffer = ""
+    time.sleep(0.5)
     for new_text in streamer:
-
+        # find <|eot_id|> and remove it from the new_text
+        if "<|eot_id|>" in new_text:
+            new_text = new_text.split("<|eot_id|>")[0]
         buffer += new_text
-
-        generated_text_without_prompt = buffer[len(text_prompt):]
-        time.sleep(0.08)
+
+        # generated_text_without_prompt = buffer[len(text_prompt):]
+        generated_text_without_prompt = buffer
+        # print(generated_text_without_prompt)
+        time.sleep(0.06)
+        # print(f"new_text: {generated_text_without_prompt}")
         yield generated_text_without_prompt
 
 
-demo = gr.ChatInterface(fn=bot_streaming, css=CSS, fill_height=True, title="LLaVA Llama-3-8B", examples=[{"text": "What is on the flower?", "files":["./bee.jpg"]},
-                        {"text": "How to make this pastry?", "files":["./baklava.png"]}],
-                        description="Try [LLaVA Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
-                        stop_btn="Stop Generation", multimodal=True)
+demo = gr.ChatInterface(
+    fn=bot_streaming,
+    fill_height=False,
+    title="LLaVA Llama-3-8B",
+    examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
+              {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
+    description="Try [LLaVA Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
+    stop_btn="Stop Generation",
+    multimodal=True
+)
 
 demo.queue(api_open=False)
-demo.launch(show_api=False, share=False)
+demo.launch(show_api=False, share=False)
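
Note on the streaming change above: the new TextIteratorStreamer is created with skip_prompt=True, so the prompt is no longer echoed back and the old buffer[len(text_prompt):] slicing becomes unnecessary, while skip_special_tokens=False means the <|eot_id|> end-of-turn marker has to be trimmed by hand. Because model.generate blocks, it runs in a background thread while the handler drains the streamer. A minimal sketch of this pattern, assuming the model and processor objects defined in app.py above and a CUDA device (stream_reply is an illustrative name, not part of app.py):

from threading import Thread

import torch
from PIL import Image
from transformers import TextIteratorStreamer


def stream_reply(model, processor, image_path, user_text):
    # Llama-3 chat template with an <image> placeholder, as used in app.py
    prompt = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"<image>\n{user_text}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    inputs = processor(prompt, Image.open(image_path), return_tensors="pt").to(0, torch.float16)

    # skip_prompt=True drops the echoed prompt; skip_special_tokens=False keeps
    # <|eot_id|> visible so it can be stripped manually below
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)

    # model.generate blocks, so it runs in a background thread while this
    # generator drains the streamer queue chunk by chunk
    Thread(target=model.generate,
           kwargs=dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)).start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text.split("<|eot_id|>")[0]  # trim the end-of-turn marker
        yield buffer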
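
The upload handling also changed because, depending on the Gradio version, each entry in message["files"] arrives either as a dict with a "path" key or as a plain path string. A tiny sketch of that normalization (file_to_path is a hypothetical helper, not part of app.py):

def file_to_path(file_entry):
    # dict form: {"path": "/tmp/gradio/...", ...}; string form: "/tmp/gradio/..."
    if isinstance(file_entry, dict):
        return file_entry["path"]
    return file_entry


assert file_to_path({"path": "./bee.jpg"}) == "./bee.jpg"
assert file_to_path("./bee.jpg") == "./bee.jpg"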