Fix small typo in JambaDecoder
#24
by
mber
- opened
- modeling_jamba.py +1 -1
modeling_jamba.py
CHANGED
@@ -1053,7 +1053,7 @@ class JambaMambaMixer(nn.Module):
|
|
1053 |
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state]
|
1054 |
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1]
|
1055 |
scan_outputs.append(scan_output[:, :, 0])
|
1056 |
-
scan_output = torch.stack(scan_outputs, dim=-1) # [batch,
|
1057 |
scan_output = scan_output + (hidden_states * self.D[None, :, None])
|
1058 |
scan_output = (scan_output * self.act(gate))
|
1059 |
|
|
|
1053 |
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state]
|
1054 |
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1]
|
1055 |
scan_outputs.append(scan_output[:, :, 0])
|
1056 |
+
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediade_size, seq_len]
|
1057 |
scan_output = scan_output + (hidden_states * self.D[None, :, None])
|
1058 |
scan_output = (scan_output * self.act(gate))
|
1059 |
|