yairschiff commited on
Commit
b948b05
1 Parent(s): 5bc94d4

Enable mambav2 compat

Browse files
Files changed (1) hide show
  1. modeling_caduceus.py +29 -16
modeling_caduceus.py CHANGED
@@ -2,21 +2,29 @@
2
 
3
  """
4
 
 
5
  import math
6
  from functools import partial
7
  from typing import Optional, Tuple, Union
8
 
9
  import torch
10
- from mamba_ssm.modules.mamba_simple import Mamba, Block
 
 
 
 
11
  from torch import nn
12
  from torch.nn import functional as F
13
  from transformers import PreTrainedModel
14
  from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
15
 
16
  try:
17
- from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
18
  except ImportError:
19
- RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
 
 
 
20
 
21
  from .configuration_caduceus import CaduceusConfig
22
  from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
@@ -54,13 +62,24 @@ def create_block(
54
  nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
55
  )
56
  block_cls = RCPSMambaBlock if rcps else Block
57
- block = block_cls(
58
- d_model,
59
- mixer_cls,
60
- norm_cls=norm_cls,
61
- fused_add_norm=fused_add_norm,
62
- residual_in_fp32=residual_in_fp32,
63
- )
 
 
 
 
 
 
 
 
 
 
 
64
  block.layer_idx = layer_idx
65
  return block
66
 
@@ -497,12 +516,6 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
497
 
498
  # Initialize weights and apply final processing
499
  self.post_init()
500
- self.init_scorer()
501
-
502
- def init_scorer(self, initializer_range=0.02):
503
- initializer_range = self.config.initializer_cfg.get("initializer_range", initializer_range) \
504
- if self.config.initializer_cfg is not None else initializer_range
505
- self.score.weight.data.normal_(std=initializer_range)
506
 
507
  def get_input_embeddings(self):
508
  return self.caduceus.backbone.embeddings.word_embeddings
 
2
 
3
  """
4
 
5
+ import inspect
6
  import math
7
  from functools import partial
8
  from typing import Optional, Tuple, Union
9
 
10
  import torch
11
+ from mamba_ssm.modules.mamba_simple import Mamba
12
+ try:
13
+ from mamba_ssm.modules.mamba_simple import Block # Legacy mambav1 file structure
14
+ except ImportError:
15
+ from mamba_ssm.modules.block import Block # mambav2 file structure
16
  from torch import nn
17
  from torch.nn import functional as F
18
  from transformers import PreTrainedModel
19
  from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
20
 
21
  try:
22
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
23
  except ImportError:
24
+ try:
25
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
26
+ except ImportError:
27
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
28
 
29
  from .configuration_caduceus import CaduceusConfig
30
  from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
 
62
  nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
63
  )
64
  block_cls = RCPSMambaBlock if rcps else Block
65
+ # mambav2 compatibility
66
+ if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
67
+ block = block_cls(
68
+ d_model,
69
+ mixer_cls,
70
+ mlp_cls=nn.Identity,
71
+ norm_cls=norm_cls,
72
+ fused_add_norm=fused_add_norm,
73
+ residual_in_fp32=residual_in_fp32,
74
+ )
75
+ else:
76
+ block = block_cls(
77
+ d_model,
78
+ mixer_cls,
79
+ norm_cls=norm_cls,
80
+ fused_add_norm=fused_add_norm,
81
+ residual_in_fp32=residual_in_fp32,
82
+ )
83
  block.layer_idx = layer_idx
84
  return block
85
 
 
516
 
517
  # Initialize weights and apply final processing
518
  self.post_init()
 
 
 
 
 
 
519
 
520
  def get_input_embeddings(self):
521
  return self.caduceus.backbone.embeddings.word_embeddings