Update modeling_qwen.py
Browse files- modeling_qwen.py +5 -0
modeling_qwen.py
CHANGED
@@ -623,6 +623,11 @@ class QWenModel(QWenPreTrainedModel):
|
|
623 |
|
624 |
if inputs_embeds is None:
|
625 |
inputs_embeds = self.wte(input_ids)
|
|
|
|
|
|
|
|
|
|
|
626 |
|
627 |
if batch_size <= 0:
|
628 |
raise ValueError("batch_size has to be defined and > 0")
|
|
|
623 |
|
624 |
if inputs_embeds is None:
|
625 |
inputs_embeds = self.wte(input_ids)
|
626 |
+
if self.training and images == None: # Compatible with plain text data training
|
627 |
+
fake_images=torch.zeros(1,3,224,224).to(
|
628 |
+
dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device)
|
629 |
+
image_embeds = self.visual(fake_images)
|
630 |
+
inputs_embeds = inputs_embeds + image_embeds.mean()*0
|
631 |
|
632 |
if batch_size <= 0:
|
633 |
raise ValueError("batch_size has to be defined and > 0")
|