""" ViTamin

Paper: Designing Scalable Vision Models in the Vision-Language Era

@misc{chen2023designing,
      title={Designing Scalable Vision Models in the Vision-Language Era},
      author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Chieh Chen},
      year={2023},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin

Modifications and timm support by Jieneng Chen 2023

Adapted from timm codebase, thanks!
"""

from functools import partial
from typing import List, Tuple
from dataclasses import dataclass, field, replace
from typing import Callable, Optional, Union, Tuple, List, Sequence
import math
import time
import logging
from collections import OrderedDict

from torch.jit import Final
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

import timm
from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
from timm.layers import to_2tuple, DropPath, Format, trunc_normal_
from timm.layers import AttentionPoolLatent, PatchDropout
from timm.layers.norm_act import _create_act
from timm.models._registry import register_model
from timm.models._manipulate import named_apply, checkpoint_seq
from timm.models._builder import build_model_with_cfg
from timm.models.vision_transformer import get_act_layer, Type, LayerType, Mlp, Block, PatchEmbed, VisionTransformer, \
    checkpoint_filter_fn, get_init_weights_vit, init_weights_vit_timm, _load_weights


@dataclass
class VitConvCfg:
    """Configuration record for the convolutional stages.

    ``__post_init__`` fills in defaults for ``norm_layer``, ``norm_eps`` and
    ``downsample_pool_type`` when they are left empty / ``None``.
    """
    expand_ratio: float = 4.0
    expand_output: bool = True  # calculate expansion channels from output (vs input chs)
    kernel_size: int = 3
    group_size: int = 1  # 1 == depthwise
    pre_norm_act: bool = False  # activation after pre-norm
    stride_mode: str = 'dw'  # stride done via one of 'pool', '1x1', 'dw'
    pool_type: str = 'avg2'
    downsample_pool_type: str = 'avg2'
    act_layer: str = 'gelu'  # stem & stage 1234
    norm_layer: str = ''
    norm_layer_cl: str = ''
    norm_eps: Optional[float] = None
    down_shortcut: Optional[bool] = True
    mlp: str = 'mlp'

    def __post_init__(self):
        # The MBConv defaults are always selected; the non-MBConv branches are kept for parity
        # with the original code.
        use_mbconv = True
        if not self.norm_layer:
            self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
        if not self.norm_layer_cl and not use_mbconv:
            self.norm_layer_cl = 'layernorm'
        if self.norm_eps is None:
            self.norm_eps = 1e-5 if use_mbconv else 1e-6
        self.downsample_pool_type = self.downsample_pool_type or self.pool_type


@dataclass
class VitCfg:
    """Top-level configuration consumed by :class:`MbConvStages`."""
    embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
    depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
    stem_width: int = 64
    # A dataclass instance is a mutable default; Python 3.11+ rejects it as a plain default
    # value, so a fresh VitConvCfg is built per VitCfg instance instead.
    conv_cfg: VitConvCfg = field(default_factory=VitConvCfg)
    weight_init: str = 'vit_eff'
    head_type: str = ""
    stem_type: str = "stem"


def _init_conv(module, name, scheme=''):
    """Initialise ``nn.Conv2d`` weights from N(0, sqrt(2 / fan_out)) and zero the bias.

    ``name`` and ``scheme`` are accepted for ``named_apply`` compatibility and are not used.
    Modules that are not ``nn.Conv2d`` are left untouched.
    """
    if isinstance(module, nn.Conv2d):
        fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
        fan_out //= module.groups
        nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
        if module.bias is not None:
            nn.init.zeros_(module.bias)


class Stem(nn.Module):
    """Convolutional stem: 3x3 stride-2 conv -> norm+act -> 3x3 stride-1 conv."""

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            act_layer: str = 'gelu',
            norm_layer: str = 'layernorm2d',
            norm_eps: float = 1e-6,
            bias: bool = True,
    ):
        super().__init__()
        self.grad_checkpointing = False
        norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
        self.out_chs = out_chs
        self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias)
        self.norm1 = norm_act_layer(out_chs)
        self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias)

        named_apply(_init_conv, self)

    def forward(self, x):
        if self.grad_checkpointing:
            # Only the two convolutions are checkpointed; the norm runs normally.
            x = checkpoint(self.conv1, x)
            x = self.norm1(x)
            x = checkpoint(self.conv2, x)
        else:
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.conv2(x)
        return x


class Downsample2d(nn.Module):
    """Shortcut downsample: 3x3 stride-2 average pool, then a 1x1 conv if channels change."""

    def __init__(
            self,
            dim: int,
            dim_out: int,
            bias: bool = True,
    ):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)

        if dim != dim_out:
            self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias)  # 1x1 conv
        else:
            self.expand = nn.Identity()

    def forward(self, x):
        x = self.pool(x)
        x = self.expand(x)
        return x


class StridedConv(nn.Module):
    """ downsample 2d as well

    Applies a 2d layer norm over the input channels, then a strided convolution that
    projects ``in_chans`` to ``embed_dim``.
    """

    def __init__(
            self,
            kernel_size=3,
            stride=2,
            padding=1,
            in_chans=3,
            embed_dim=768,
    ):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
        norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
        self.norm = norm_layer(in_chans)

    def forward(self, x):
        x = self.norm(x)
        x = self.proj(x)
        return x


class MbConvLNBlock(nn.Module):
    """Pre-norm MBConv block: norm -> 1x1 expand -> depthwise kxk -> 1x1 project, plus shortcut."""

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 1,
            drop_path: float = 0.,
            kernel_size: int = 3,
            norm_layer: str = 'layernorm2d',
            norm_eps: float = 1e-6,
            act_layer: str = 'gelu',
            expand_ratio: float = 4.0,
    ):
        super(MbConvLNBlock, self).__init__()
        self.stride, self.in_chs, self.out_chs = stride, in_chs, out_chs
        mid_chs = make_divisible(out_chs * expand_ratio)
        prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)

        # The shortcut must match the main branch in both resolution and channel count.
        if stride == 2:
            self.shortcut = Downsample2d(in_chs, out_chs, bias=True)
        elif in_chs != out_chs:
            self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True)
        else:
            self.shortcut = nn.Identity()

        self.pre_norm = prenorm_act_layer(in_chs, apply_act=False)
        self.down = nn.Identity()
        self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True)
        self.act1 = _create_act(act_layer, inplace=True)
        self.act2 = _create_act(act_layer, inplace=True)

        self.conv2_kxk = create_conv2d(
            mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
        self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def init_weights(self, scheme=''):
        named_apply(partial(_init_conv, scheme=scheme), self)

    def forward(self, x):
        shortcut = self.shortcut(x)

        x = self.pre_norm(x)
        x = self.down(x)  # nn.Identity()

        # 1x1 expansion conv & act
        x = self.conv1_1x1(x)
        x = self.act1(x)

        # (strided) depthwise 3x3 conv & act
        x = self.conv2_kxk(x)
        x = self.act2(x)

        # 1x1 linear projection to output width
        x = self.conv3_1x1(x)
        x = self.drop_path(x) + shortcut

        return x


class MbConvStages(nn.Module):
    """ stage 1 and stage 2 of ViTamin: MBConv-LN blocks

    Builds a stem, two stages of ``MbConvLNBlock`` (the first block of each stage has
    stride 2) from the first two entries of ``cfg.embed_dim`` / ``cfg.depths``, and a final
    ``StridedConv`` that projects ``cfg.embed_dim[1]`` to ``cfg.embed_dim[2]`` with stride 2.
    """

    def __init__(
            self,
            cfg: VitCfg,
            img_size: Union[int, Tuple[int, int]] = 224,  # place holder
            in_chans: int = 3,
    ):
        super().__init__()
        self.grad_checkpointing = False
        self.stem = Stem(
            in_chs=in_chans,
            out_chs=cfg.stem_width,
        )
        stages = []
        self.num_stages = len(cfg.embed_dim)
        for s, dim in enumerate(cfg.embed_dim[:2]):
            stage_in_chs = cfg.embed_dim[s - 1] if s > 0 else cfg.stem_width
            blocks = [
                MbConvLNBlock(
                    in_chs=stage_in_chs if d == 0 else dim,
                    out_chs=dim,
                    stride=2 if d == 0 else 1,
                )
                for d in range(cfg.depths[s])
            ]
            stages += [nn.Sequential(*blocks)]
        self.stages = nn.ModuleList(stages)
        self.pool = StridedConv(
            stride=2,
            in_chans=cfg.embed_dim[1],
            embed_dim=cfg.embed_dim[2],
        )

    def forward(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for stage in self.stages:
                x = checkpoint_seq(stage, x)
            x = checkpoint(self.pool, x)
        else:
            for stage in self.stages:
                x = stage(x)
            x = self.pool(x)
        return x


class GeGluMlp(nn.Module):
    """Gated MLP: ``w2(GELU(w0(norm(x))) * w1(norm(x)))``.

    ``act_layer`` and ``drop`` are accepted so the class can be passed as ``mlp_layer`` to
    the transformer block; they are not used.
    """

    def __init__(
            self,
            in_features,
            hidden_features,
            act_layer=None,
            drop=0.0,
    ):
        super().__init__()
        norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)
        self.norm = norm_layer(in_features)
        self.act = nn.GELU()
        self.w0 = nn.Linear(in_features, hidden_features)
        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(hidden_features, in_features)

    def forward(self, x):
        x = self.norm(x)
        x = self.act(self.w0(x)) * self.w1(x)
        x = self.w2(x)
        return x


class HybridEmbed(nn.Module):
    """ Extract feature map from stage 1-2, flatten, project to embedding dim.

    The projection is an identity: the backbone is expected to already emit ``embed_dim``
    channels. ``in_chans``, ``embed_dim``, ``bias`` and ``dynamic_img_pad`` are accepted for
    signature compatibility with other embed layers and are not used here.
    """

    def __init__(
            self,
            backbone,
            img_size=224,
            patch_size=1,
            feature_size=None,
            in_chans=3,
            embed_dim=1024,
            bias=True,
            dynamic_img_pad=False,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            # Total stride of 16 is assumed; compute each axis separately so that
            # non-square image sizes yield the correct grid.
            feature_size = (img_size[0] // 16, img_size[1] // 16)
        feature_size = to_2tuple(feature_size)

        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
        self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Identity()

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class ViTamin(nn.Module):
    """ hack timm VisionTransformer

    A transformer trunk whose patch embedding is supplied by ``embed_layer`` (for the
    registered ``vitamin_*`` models this is a :class:`HybridEmbed` wrapping
    :class:`MbConvStages`).
    """
    dynamic_img_size: Final[bool]

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool='token',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            init_values: Optional[float] = None,
            class_token: bool = True,
            no_embed_class: bool = False,
            reg_tokens: int = 0,
            pre_norm: bool = False,
            fc_norm: Optional[bool] = None,
            dynamic_img_size: bool = False,
            dynamic_img_pad: bool = False,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            patch_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            weight_init='',
            fix_init: bool = False,
            embed_layer: Callable = PatchEmbed,
            norm_layer: Optional[LayerType] = None,
            act_layer: Optional[LayerType] = None,
            block_fn: Type[nn.Module] = Block,
            mlp_layer: Type[nn.Module] = Mlp,
            is_pos_embed: bool = True
    ) -> None:
        """
        Args:
            img_size: Input image size.
            patch_size: Patch size.
            in_chans: Number of image input channels.
            num_classes: Number of classes for classification head.
            global_pool: Type of global pooling for final sequence (default: 'token').
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            init_values: Layer-scale init values (layer-scale enabled if not None).
            class_token: Use class token.
            no_embed_class: Don't include position embeddings for class (or reg) tokens.
            reg_tokens: Number of register tokens.
            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate.
            weight_init: Weight initialization scheme.
            fix_init: Apply weight initialization fix (scaling w/ layer index).
            embed_layer: Patch embedding layer.
            norm_layer: Normalization layer.
            act_layer: MLP activation layer.
            block_fn: Transformer block layer.
            mlp_layer: MLP layer class handed to each block.
            is_pos_embed: Create and add a learned position embedding if True.
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'token', 'map')
        assert class_token or global_pool != 'token'
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
        act_layer = get_act_layer(act_layer) or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.num_prefix_tokens += reg_tokens
        self.num_reg_tokens = reg_tokens
        self.has_class_token = class_token
        self.no_embed_class = no_embed_class  # don't embed prefix positions (includes reg)
        # Stored for reference only: no dynamic-size arguments are forwarded to ``embed_layer``.
        self.dynamic_img_size = dynamic_img_size
        self.grad_checkpointing = False
        self.is_pos_embed = is_pos_embed

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
        if self.is_pos_embed:
            embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
            self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                init_values=init_values,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                mlp_layer=mlp_layer,
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        # Classifier Head
        if global_pool == 'map':
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if weight_init != 'skip':
            self.init_weights(weight_init)
        if fix_init:
            # NOTE(review): fix_init_weight is not defined on this class in the code shown;
            # fix_init=True will raise AttributeError unless a subclass provides it.
            self.fix_init_weight()

    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        if self.is_pos_embed:
            trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(get_init_weights_vit(mode, head_bias), self)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        init_weights_vit_timm(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self):
        if self.is_pos_embed:
            return {'pos_embed', 'cls_token', 'dist_token'}
        else:
            return {'cls_token', 'dist_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on the transformer blocks and, if present, the conv backbone."""
        self.grad_checkpointing = enable
        # A plain PatchEmbed has no ``backbone``; only hybrid embeds carry conv stages.
        backbone = getattr(self.patch_embed, 'backbone', None)
        if backbone is None:
            return
        stem = getattr(backbone, 'stem', None)
        if stem is not None:
            # NOTE(review): the original comment here read "disable" and linked to
            # https://blog.csdn.net/lhx526080338/article/details/127894671 ; the code
            # nevertheless enables stem checkpointing. Confirm which is intended.
            stem.grad_checkpointing = enable
        backbone.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'token')
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _pos_embed(self, x):
        """Prepend class / register tokens and add the position embedding when one exists."""
        prefix = []
        if self.cls_token is not None:
            prefix.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            # Register tokens are counted in num_prefix_tokens (and in pos_embed's length
            # unless no_embed_class), so they have to be part of the sequence.
            prefix.append(self.reg_token.expand(x.shape[0], -1, -1))

        if self.pos_embed is None:
            # No position embedding: only the prefix tokens are attached, and pos_drop is
            # skipped, matching the previous behaviour of bypassing this method.
            return torch.cat(prefix + [x], dim=1) if prefix else x

        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.pos_embed
            if prefix:
                x = torch.cat(prefix + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if prefix:
                x = torch.cat(prefix + [x], dim=1)
            x = x + self.pos_embed
        return self.pos_drop(x)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        if self.attn_pool is not None:
            x = self.attn_pool(x)
        elif self.global_pool == 'avg':
            x = x[:, self.num_prefix_tokens:].mean(dim=1)
        elif self.global_pool:
            x = x[:, 0]  # class token
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def _create_vision_transformer(variant, pretrained=False, **kwargs):
    """Build a :class:`ViTamin` through timm's ``build_model_with_cfg``."""
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    return build_model_with_cfg(
        ViTamin,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs,
    )


def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
    """Build a :class:`ViTamin` whose patch embedding is a :class:`HybridEmbed` over ``backbone``."""
    embed_layer = partial(HybridEmbed, backbone=backbone)
    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
    return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)


def _build_conv_stages(embed_dim, stem_width, in_chans=3):
    """Build the MBConv stages shared by every registered variant (depths 2/4/1, LayerNorm2d)."""
    return MbConvStages(
        cfg=VitCfg(
            embed_dim=embed_dim,
            depths=(2, 4, 1),
            stem_width=stem_width,
            conv_cfg=VitConvCfg(
                norm_layer='layernorm2d',
                norm_eps=1e-6,
            ),
            head_type='1d',
        ),
        in_chans=in_chans,
    )


@register_model
def vitamin_small(pretrained=False, **kwargs) -> ViTamin:
    stage_1_2 = _build_conv_stages((64, 128, 384), 64, in_chans=kwargs.get('in_chans', 3))
    stage3_args = dict(
        embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_small', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
    return model


@register_model
def vitamin_base(pretrained=False, **kwargs) -> ViTamin:
    stage_1_2 = _build_conv_stages((128, 256, 768), 128, in_chans=kwargs.get('in_chans', 3))
    stage3_args = dict(
        embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_base', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
    return model


@register_model
def vitamin_large(pretrained=False, **kwargs) -> ViTamin:
    stage_1_2 = _build_conv_stages((160, 320, 1024), 160, in_chans=kwargs.get('in_chans', 3))
    stage3_args = dict(
        embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
    return model


@register_model
def vitamin_large_256(pretrained=False, **kwargs) -> ViTamin:
    backbone = _build_conv_stages((160, 320, 1024), 160, in_chans=kwargs.get('in_chans', 3))
    model_args = dict(
        img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vitamin_large_336(pretrained=False, **kwargs) -> ViTamin:
    backbone = _build_conv_stages((160, 320, 1024), 160, in_chans=kwargs.get('in_chans', 3))
    model_args = dict(
        img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vitamin_large_384(pretrained=False, **kwargs) -> ViTamin:
    backbone = _build_conv_stages((160, 320, 1024), 160, in_chans=kwargs.get('in_chans', 3))
    model_args = dict(
        img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, is_pos_embed=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vitamin_xlarge_256(pretrained=False, **kwargs) -> ViTamin:
    backbone = _build_conv_stages((192, 384, 1152), 192, in_chans=kwargs.get('in_chans', 3))
    model_args = dict(
        img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, is_pos_embed=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vitamin_xlarge_384(pretrained=False, **kwargs) -> ViTamin:
    backbone = _build_conv_stages((192, 384, 1152), 192, in_chans=kwargs.get('in_chans', 3))
    model_args = dict(
        img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, is_pos_embed=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
    return model


def count_params(model: nn.Module):
    """Return the total number of elements across all parameters of ``model``."""
    return sum(p.numel() for p in model.parameters())


def count_stage_params(model: nn.Module, prefix='none'):
    """Return the number of parameter elements whose qualified name starts with ``prefix``."""
    return sum(p.numel() for name, p in model.named_parameters() if name.startswith(prefix))


if __name__ == "__main__":
    model = timm.create_model('vitamin_large', num_classes=10)
    if torch.cuda.is_available():
        model = model.cuda()
    print(f'vitamin_large parameters: {count_params(model):,}')