Spaces:
Runtime error
Runtime error
LanguageBind
committed on
Commit
•
cbfb9b8
1
Parent(s):
4cee86a
Update llava/model/builder.py
Browse files - llava/model/builder.py +16 -14
llava/model/builder.py
CHANGED
@@ -139,6 +139,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|
139 |
if 'llava' in model_name.lower():
|
140 |
mm_use_x_start_end = getattr(model.config, "mm_use_x_start_end", False)
|
141 |
mm_use_x_patch_token = getattr(model.config, "mm_use_x_patch_token", True)
|
|
|
142 |
X = model.config.X
|
143 |
if mm_use_x_patch_token:
|
144 |
for x in X:
|
@@ -146,23 +147,24 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|
146 |
if mm_use_x_start_end:
|
147 |
for x in X:
|
148 |
tokenizer.add_tokens([DEFAULT_X_START_TOKEN[x.upper()], DEFAULT_X_END_TOKEN[x.upper()]], special_tokens=True)
|
|
|
149 |
model.resize_token_embeddings(len(tokenizer))
|
150 |
print(X)
|
151 |
-
if 'Image' in X:
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
|
159 |
-
if 'Video' in X:
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
|
167 |
if hasattr(model.config, "max_sequence_length"):
|
168 |
context_len = model.config.max_sequence_length
|
|
|
139 |
if 'llava' in model_name.lower():
|
140 |
mm_use_x_start_end = getattr(model.config, "mm_use_x_start_end", False)
|
141 |
mm_use_x_patch_token = getattr(model.config, "mm_use_x_patch_token", True)
|
142 |
+
'''
|
143 |
X = model.config.X
|
144 |
if mm_use_x_patch_token:
|
145 |
for x in X:
|
|
|
147 |
if mm_use_x_start_end:
|
148 |
for x in X:
|
149 |
tokenizer.add_tokens([DEFAULT_X_START_TOKEN[x.upper()], DEFAULT_X_END_TOKEN[x.upper()]], special_tokens=True)
|
150 |
+
'''
|
151 |
model.resize_token_embeddings(len(tokenizer))
|
152 |
print(X)
|
153 |
+
#if 'Image' in X:
|
154 |
+
image_tower = model.get_image_tower()
|
155 |
+
if not image_tower.is_loaded:
|
156 |
+
image_tower.load_model()
|
157 |
+
image_tower.to(device=device, dtype=torch.float16)
|
158 |
+
image_processor = image_tower.image_processor
|
159 |
+
processor['image'] = image_processor
|
160 |
|
161 |
+
#if 'Video' in X:
|
162 |
+
video_tower = model.get_video_tower()
|
163 |
+
if not video_tower.is_loaded:
|
164 |
+
video_tower.load_model()
|
165 |
+
video_tower.to(device=device, dtype=torch.float16)
|
166 |
+
video_processor = video_tower.video_processor
|
167 |
+
processor['video'] = video_processor
|
168 |
|
169 |
if hasattr(model.config, "max_sequence_length"):
|
170 |
context_len = model.config.max_sequence_length
|