Add print statements
Browse files- modeling_cogvlm.py +7 -7
modeling_cogvlm.py
CHANGED
@@ -241,7 +241,7 @@ class VisionExpertAttention(nn.Module):
|
|
241 |
key_states = self._transpose_for_scores(key_states) # B, H, L, HD
|
242 |
value_states = self._transpose_for_scores(value_states) # B, H, L, HD
|
243 |
|
244 |
-
if print_values:
|
245 |
|
246 |
# torch.save(query_states, "query_states.pt")
|
247 |
# torch.save(key_states, "key_states.pt")
|
@@ -325,13 +325,13 @@ class CogVLMDecoderLayer(nn.Module):
|
|
325 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
326 |
residual = hidden_states
|
327 |
|
328 |
-
if print_values:
|
329 |
-
|
330 |
|
331 |
hidden_states = self.input_layernorm(hidden_states)
|
332 |
|
333 |
-
if print_values:
|
334 |
-
|
335 |
|
336 |
# Self Attention
|
337 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
@@ -345,8 +345,8 @@ class CogVLMDecoderLayer(nn.Module):
|
|
345 |
print_values=print_values,
|
346 |
)
|
347 |
|
348 |
-
if print_values:
|
349 |
-
|
350 |
|
351 |
hidden_states = residual + hidden_states
|
352 |
|
|
|
241 |
key_states = self._transpose_for_scores(key_states) # B, H, L, HD
|
242 |
value_states = self._transpose_for_scores(value_states) # B, H, L, HD
|
243 |
|
244 |
+
# if print_values:
|
245 |
|
246 |
# torch.save(query_states, "query_states.pt")
|
247 |
# torch.save(key_states, "key_states.pt")
|
|
|
325 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
326 |
residual = hidden_states
|
327 |
|
328 |
+
# if print_values:
|
329 |
+
# print("Hidden states before RMS norm:", hidden_states[0, :3, :3])
|
330 |
|
331 |
hidden_states = self.input_layernorm(hidden_states)
|
332 |
|
333 |
+
# if print_values:
|
334 |
+
# print("Hidden states after RMS norm, before self attention:", hidden_states[0,:3,:3])
|
335 |
|
336 |
# Self Attention
|
337 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
|
|
345 |
print_values=print_values,
|
346 |
)
|
347 |
|
348 |
+
# if print_values:
|
349 |
+
# print("Hidden states after self attention:", hidden_states[0,:3,:3])
|
350 |
|
351 |
hidden_states = residual + hidden_states
|
352 |
|