Fix small typo in JambaDecoder

#24
Files changed (1) hide show
  1. modeling_jamba.py +1 -1
modeling_jamba.py CHANGED
@@ -1053,7 +1053,7 @@ class JambaMambaMixer(nn.Module):
1053
  ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state]
1054
  scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1]
1055
  scan_outputs.append(scan_output[:, :, 0])
1056
- scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size]
1057
  scan_output = scan_output + (hidden_states * self.D[None, :, None])
1058
  scan_output = (scan_output * self.act(gate))
1059
 
 
1053
  ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state]
1054
  scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1]
1055
  scan_outputs.append(scan_output[:, :, 0])
1056
+ scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediade_size, seq_len]
1057
  scan_output = scan_output + (hidden_states * self.D[None, :, None])
1058
  scan_output = (scan_output * self.act(gate))
1059