BUAADreamer committed on
Commit
030f3b8
1 Parent(s): 5318b89

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +19 -0
README.md CHANGED
@@ -28,6 +28,25 @@ from PIL import Image
28
  import torch
29
  from transformers import AutoProcessor, AutoModelForVision2Seq
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  model_id = "BUAADreamer/Yi-VL-6B-hf"
32
 
33
  messages = [
 
28
  import torch
29
  from transformers import AutoProcessor, AutoModelForVision2Seq
30
 
31
+ class LlavaMultiModalProjectorYiVL(nn.Module):
32
+ def __init__(self, config: "LlavaConfig"):
33
+ super().__init__()
34
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
35
+ self.linear_2 = nn.LayerNorm(config.text_config.hidden_size, bias=True)
36
+ self.linear_3 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
37
+ self.linear_4 = nn.LayerNorm(config.text_config.hidden_size, bias=True)
38
+ self.act = nn.GELU()
39
+
40
+ def forward(self, image_features):
41
+ hidden_states = self.linear_1(image_features)
42
+ hidden_states = self.linear_2(hidden_states)
43
+ hidden_states = self.act(hidden_states)
44
+ hidden_states = self.linear_3(hidden_states)
45
+ hidden_states = self.linear_4(hidden_states)
46
+ return hidden_states
47
+
48
+ transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorYiVL
49
+
50
  model_id = "BUAADreamer/Yi-VL-6B-hf"
51
 
52
  messages = [