Hugging Face Spaces — status: Running on Zero
Commit: "fix zero gpu" (Browse files)
app.py CHANGED
@@ -44,9 +44,10 @@ tokenizer, model, data_loader = load_model_and_dataloader(
     load_4bit=load_4bit,
     load_bf16=load_bf16,
     scene_to_obj_mapping=scene_to_obj_mapping,
-)
+    device_map='cpu',
+)  # Huggingface Zero-GPU has to use .to(device) to set the device, otherwise it will fail
+model.to("cuda")  # Huggingface Zero-GPU requires explicit device placement
 
-@spaces.GPU
 def get_chatbot_response(user_chat_input, scene_id):
     # Get the response from the model
     prompt, response = get_model_response(

[NOTE(review): the removed old line 47 was garbled in extraction; it is reconstructed above as the closing `)` of the load_model_and_dataloader(...) call, consistent with the added `+ )` line in the new version — confirm against the original commit.]
model.py CHANGED
@@ -12,8 +12,7 @@ from llava.mm_utils import get_model_name_from_path
 from llava.model.builder import load_pretrained_model
 from llava.utils import disable_torch_init
 
-[removed line — content lost in extraction; presumably the @spaces.GPU decorator, to be confirmed against the original commit]
-def load_model_and_dataloader(model_path, model_base, scene_to_obj_mapping, obj_context_feature_type="text", load_8bit=False, load_4bit=False, load_bf16=False):
+def load_model_and_dataloader(model_path, model_base, scene_to_obj_mapping, obj_context_feature_type="text", load_8bit=False, load_4bit=False, load_bf16=False, device_map='auto'):
 
     model_name = get_model_name_from_path(model_path)
 
@@ -24,6 +23,7 @@ def load_model_and_dataloader(model_path, model_base, scene_to_obj_mapping, obj_
     load_8bit=load_8bit,
     load_4bit=load_4bit,
     load_bf16=load_bf16,
+    device_map=device_map,
 )
 
 dataset = ObjIdentifierDataset(
@@ -41,7 +41,6 @@ def load_model_and_dataloader(model_path, model_base, scene_to_obj_mapping, obj_
     return tokenizer, model, data_loader
 
 
-@spaces.GPU
 def get_model_response(model, tokenizer, data_loader, scene_id, user_input, max_new_tokens=50, temperature=0.2, top_p=0.9):
     input_data = [
         {