BK-Lee commited on
Commit
91911bd
1 Parent(s): dd4cd4b
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +2 -4
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Phantom
3
- emoji: ⛰️
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
 
1
  ---
2
  title: Phantom
3
+ emoji: 👻
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
app.py CHANGED
@@ -9,9 +9,7 @@ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENT
9
  import torch
10
  from PIL import Image
11
  from utils.utils import *
12
- import torch.nn.functional as F
13
  from model.load_model import load_model
14
- from torchvision.transforms.functional import pil_to_tensor
15
 
16
  # Gradio Package
17
  import time
@@ -49,7 +47,7 @@ def threading_function(inputs, streamer, device, model, tokenizer, temperature,
49
  generation_kwargs.update({'use_cache': True})
50
  return model.generate(**generation_kwargs)
51
 
52
- # @spaces.GPU
53
  def bot_streaming(message, history, link, temperature, new_max_token, top_p):
54
 
55
  # model selection
@@ -63,7 +61,7 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
63
  model = model_7
64
  tokenizer = tokenizer_7
65
 
66
- # X -> float16 conversion
67
  for param in model.parameters():
68
  if 'float32' in str(param.dtype).lower() or 'float16' in str(param.dtype).lower():
69
  param.data = param.data.to(torch.bfloat16)
 
9
  import torch
10
  from PIL import Image
11
  from utils.utils import *
 
12
  from model.load_model import load_model
 
13
 
14
  # Gradio Package
15
  import time
 
47
  generation_kwargs.update({'use_cache': True})
48
  return model.generate(**generation_kwargs)
49
 
50
+ @spaces.GPU
51
  def bot_streaming(message, history, link, temperature, new_max_token, top_p):
52
 
53
  # model selection
 
61
  model = model_7
62
  tokenizer = tokenizer_7
63
 
64
+ # X -> bfloat16 conversion
65
  for param in model.parameters():
66
  if 'float32' in str(param.dtype).lower() or 'float16' in str(param.dtype).lower():
67
  param.data = param.data.to(torch.bfloat16)