maldv committed on
Commit
b59223f
1 Parent(s): dd8fb75

Upload folder using huggingface_hub

Browse files
Files changed (7) hide show
  1. ztrain/__init__.py +0 -0
  2. ztrain/io.py +30 -0
  3. ztrain/model.py +39 -0
  4. ztrain/signal.py +79 -0
  5. ztrain/stats.py +30 -0
  6. ztrain/tensors.py +258 -0
  7. ztrain/util.py +37 -0
ztrain/__init__.py ADDED
File without changes
ztrain/io.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ztrain/io.py
2
+ # Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
3
+
4
+ import os
5
+ from glob import glob
6
+
7
def flatten_index(model_paths : list[str], allow_list : list[str]):
    """Build an ordered index over the allowed model names.

    Parameters:
    - model_paths: paths to candidate model folders/files.
    - allow_list: basenames that are allowed to participate.

    Returns:
    - index: basename -> position in the flattened list.
    - flat: allowed basenames, in sorted-path order.
    - subtype: per entry, 'base' / 'instruct' / 'other' inferred from the path.
    """
    index = {}
    flat = []
    subtype = []
    for path in sorted(model_paths):
        name = os.path.basename(path)
        if name not in allow_list:
            continue
        # Position equals the number of entries accepted so far.
        index[name] = len(flat)
        flat.append(name)
        # Classify by substring of the *full path*, not just the basename.
        if 'base' in path:
            kind = 'base'
        elif 'instruct' in path:
            kind = 'instruct'
        else:
            kind = 'other'
        subtype.append(kind)
    return index, flat, subtype
25
+
26
def list_for_path(path: str, include_folders: list[str], search: str = "/**/*") -> tuple[list[str], list[str], list[str], dict[str, int]]:
    """Glob a directory tree and index the allowed model entries.

    Parameters:
    - path: root directory to search under.
    - include_folders: basenames to keep (passed to flatten_index's allow list).
    - search: glob suffix appended to `path` (recursive by default).

    Returns:
    - model_names: allowed basenames in sorted order.
    - subtypes: 'base'/'instruct'/'other' per allowed entry.
    - model_list: every globbed path, sorted.
    - group_idx: basename -> index mapping.
    """
    # Fix: original wrapped glob() in a redundant list-comprehension-splat
    # (`sorted([*[f for f in glob(...)]])`) and built an unused `groups` list.
    model_list = sorted(glob(path + search))
    group_idx, model_names, subtypes = flatten_index(model_list, include_folders)
    return model_names, subtypes, model_list, group_idx
ztrain/model.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ztrain/model.py
2
+ # Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
3
+
4
+ from collections import defaultdict
5
+ import re
6
+
7
def generate_merge_group(group_data : list, parents : "list[int] | None" = None):
    """Depth-first walk of a nested list, yielding each leaf with its index path.

    Parameters:
    - group_data: arbitrarily nested lists whose leaves are non-list values.
    - parents: index path accumulated by the recursion (callers omit it).

    Yields:
    - (leaf, path): the leaf value and the list of indices leading to it.
    """
    # Fix: the original used a mutable default (`parents=[]`). It was never
    # mutated here, but the idiom is a well-known trap; use None instead.
    if parents is None:
        parents = []
    # drill down until we find a non-list leaf, then yield it with a parent tree index
    for i, g in enumerate(group_data):
        if isinstance(g, list):
            yield from generate_merge_group(g, parents + [i])
        else:
            yield g, parents + [i]
14
+
15
def merge_groups(group_data : list):
    """Collect leaves of a nested list, grouped by their parent index path.

    Parameters:
    - group_data: arbitrarily nested lists of leaf values.

    Returns:
    - defaultdict mapping tuple(parent indices) -> list of sibling leaves.
    """
    grouped = defaultdict(list)
    for leaf, index_path in generate_merge_group(group_data):
        # Siblings share everything but the final index.
        parent_key = tuple(index_path[:-1])
        grouped[parent_key].append(leaf)
    return grouped
21
+
22
def get_layer_type(k : str) -> tuple[int, str, str, str]:
    """Parse a transformer state-dict key into (layer, module, submodule, param).

    Parameters:
    - k: a key such as "model.layers.0.self_attn.q_proj.weight".

    Returns:
    - (layer index, module name, submodule name or "", parameter name).
      Non-layer keys (norm, embed_tokens, lm_head) use layer index -1.
      Unrecognized keys return (-1, "unknown", "unknown", "unknown").
    """
    # Dots are escaped so "." matches literally instead of any character.
    m = re.match(r"model\.layers\.(\d+)\.(.+)\.(.+)\.(.+)", k)
    if m is not None:
        return int(m.group(1)), m.group(2), m.group(3), m.group(4)

    # Bug fix: the original compiled this pattern but never re-matched it,
    # so `m` was still None and the 3-part branch was unreachable
    # (e.g. "model.layers.N.input_layernorm.weight" fell through to "unknown").
    m = re.match(r"model\.layers\.(\d+)\.(.+)\.(.+)", k)
    if m is not None:
        return int(m.group(1)), m.group(2), "", m.group(3)

    if "model.norm.weight" == k:
        return -1, "norm", "", "weight"
    if "model.embed_tokens.weight" == k:
        return -1, "embed_tokens", "", "weight"
    if "lm_head.weight" == k:
        return -1, "lm_head", "", "weight"
    print(f"Unknown key {k}")
    return -1, "unknown", "unknown", "unknown"
ztrain/signal.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ztrain/signal.py
2
+ # Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
3
+
4
+ import torch
5
+
6
def gaussian_kernel(size, sigma=1.0):
    """
    Generates a 2D Gaussian kernel using PyTorch.

    Parameters:
    - size: The size of the kernel (an integer). It's recommended to use an odd number
            to have a central pixel.
    - sigma: The standard deviation of the Gaussian distribution.

    Returns:
    - A 2D PyTorch tensor representing the Gaussian kernel, normalized to sum to 1.
    """
    half = int(size) // 2
    coords = torch.arange(-half, half + 1)
    # Fix: pass indexing="ij" explicitly — calling torch.meshgrid without it
    # emits a deprecation warning. The kernel is symmetric in x/y, so the
    # numeric result is identical to the original.
    x, y = torch.meshgrid(coords, coords, indexing="ij")
    g = torch.exp(-(x**2 + y**2) / (2 * sigma**2))
    return g / g.sum()
22
+
23
def laplacian_kernel(size, scale=1.0):
    """
    Creates a Laplacian kernel for edge detection with an adjustable size and scale factor.

    Parameters:
    - size: The size of the kernel (an integer). Must be odd so there is a
            central pixel.
    - scale: A float that adjusts the intensity of the edge detection effect.

    Returns:
    - A 2D PyTorch tensor representing the scaled Laplacian kernel, adjusted
      so its elements sum to zero.
    """
    if size % 2 == 0:
        raise ValueError("Size must be odd.")

    c = size // 2
    kernel = torch.zeros((size, size), dtype=torch.float32)

    # Center pixel and its four immediate neighbors.
    kernel[c, c] = -4.0
    for dr, dc in ((0, -1), (0, 1), (-1, 0), (1, 0)):
        kernel[c + dr, c + dc] = 1.0

    # For larger kernels, set the whole outer ring to 1 (simplistic scheme;
    # may need refinement for large sizes).
    if size > 3:
        kernel[0, :] = 1.0
        kernel[-1, :] = 1.0
        kernel[:, 0] = 1.0
        kernel[:, -1] = 1.0

    kernel *= scale

    # Re-balance the center so the kernel sums to zero overall.
    kernel[c, c] = kernel[c, c] - torch.sum(kernel)

    return kernel
63
+
64
def fftshift(input):
    """
    Reorients the FFT output so the zero-frequency component is at the center.

    Parameters:
    - input: A 2D tensor representing the FFT output.

    Returns:
    - A 2D tensor with the zero-frequency component shifted to the center.
    """
    # Bug fix: the shift must be n // 2, not (n + 1) // 2. For odd n the old
    # ceil-half shift placed the zero-frequency bin one slot past the center —
    # that is the *inverse* shift (ifftshift), contradicting the docstring.
    # Even sizes are unaffected (both expressions equal n // 2).
    for dim in range(2):  # assuming input is 2D
        n = input.shape[dim]
        input = torch.roll(input, shifts=n // 2, dims=dim)
    return input
ztrain/stats.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ztrain/stats.py
2
+ # Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
3
+
4
+ import os
5
+ import torch
6
+ from typing import Optional
7
+
8
def gen_stats(delta : torch.Tensor, base : Optional[torch.Tensor]) -> tuple[float, float, float, float]:
    """Summarize a delta tensor relative to an optional base tensor.

    Parameters:
    - delta: the difference (or raw) tensor to analyze.
    - base: the reference tensor; None means `delta` is already the full value.

    Returns:
    - (norm, cosine, min, max): norm of base+delta (or delta alone),
      mean cosine similarity of the rebuilt tensor vs. base (0.0 when base is
      None), and the min/max elements of delta.
    """
    if base is None:
        rebuilt = delta
    else:
        rebuilt = base + delta
    norm = rebuilt.norm().item()
    if base is None:
        # Fix: 0.0 instead of int 0 to match the declared float return type.
        cosine = 0.0
    else:
        cosine = torch.nn.functional.cosine_similarity(rebuilt, base, dim=0).mean().item()
    # Fix: renamed locals — the original shadowed the builtins `min`/`max`.
    delta_min = delta.min().item()
    delta_max = delta.max().item()
    del rebuilt
    return norm, cosine, delta_min, delta_max
22
+
23
def get_report(m0: torch.Tensor, stack : torch.Tensor, model_list : list[str]):
    """Print norm/cosine/min/max statistics for the base tensor and each delta.

    Parameters:
    - m0: the base tensor.
    - stack: stacked delta tensors, one per model.
    - model_list: paths aligned with `stack`; basenames are used as labels.
    """
    base_norm, _, base_min, base_max = gen_stats(m0, None)
    print(f"Base Model {base_norm} {base_min} {base_max}")

    for idx, delta in enumerate(stack):
        name = os.path.basename(model_list[idx])
        d_norm, d_cos, d_min, d_max = gen_stats(delta, m0)
        print(f"{name} {d_norm} {d_cos} {d_min} {d_max}")
ztrain/tensors.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ztrain/tensors.py
2
+ # Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
3
+
4
+ import torch
5
+ from typing import Generator, Tuple
6
+
7
def normalize_to(m1 : torch.Tensor, norm : torch.Tensor) -> tuple[torch.Tensor, float, float]:
    """Rescale `m1` (as float32) so its Frobenius norm equals `norm`.

    Parameters:
    - m1: tensor to rescale.
    - norm: target norm; a 0-dim tensor (as before) or, now, a plain float.

    Returns:
    - (scaled tensor, target norm as float, scaling ratio applied).

    NOTE(review): a zero-norm `m1` still yields inf/nan ratios, as in the
    original — callers appear to rely on raw behavior, so no guard was added.
    """
    m1 = m1.to(torch.float32)
    m1_norm = torch.norm(m1)
    ratio = (norm / m1_norm).item()
    m1 = m1 * ratio
    # Generalization: float(norm) accepts both a 0-dim tensor (original
    # behavior via .item()) and a plain Python float.
    return m1, float(norm), ratio
13
+
14
def norm_ratio(m1 : torch.Tensor, m2 : torch.Tensor) -> float:
    """Print both tensor norms and return their ratio ||m1|| / ||m2||."""
    n1, n2 = torch.norm(m1), torch.norm(m2)
    value = (n1 / n2).item()
    print(f"Norms {n1} {n2} {value}")
    return value
20
+
21
def _interp_fft_component(p0: torch.Tensor, p1: torch.Tensor, sign_mask: torch.Tensor, t: float) -> torch.Tensor:
    """Blend one FFT component (real or imaginary).

    Where `sign_mask` is True the coefficients are linearly interpolated by t;
    elsewhere the larger-magnitude coefficient of this component is kept.
    `sign_mask` is always derived from the real parts, matching the original.
    """
    out = torch.zeros_like(p0)
    larger = p0.abs() > p1.abs()
    out[sign_mask] = (1 - t) * p0[sign_mask] + t * p1[sign_mask]
    out[~sign_mask] = torch.where(larger[~sign_mask], p0[~sign_mask], p1[~sign_mask])
    return out


def merge_tensors_fft2(v0: torch.Tensor, v1: torch.Tensor, t: float) -> torch.Tensor:
    """
    Merges two tensors using Fourier transform interpolation (1D FFT for
    vectors, 2D FFT over the last two dims otherwise).

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor.
    - t (float): Interpolation parameter (0 <= t <= 1).

    Returns:
    - torch.Tensor: The tensor resulting from the interpolated inverse FFT.
    """
    # Bug fix: the original unconditionally moved inputs to "cuda:0", which
    # crashes on CPU-only hosts. Behavior on CUDA machines is unchanged.
    device = "cuda:0" if torch.cuda.is_available() else v0.device
    v0 = v0.to(device)
    v1 = v1.to(device)

    one_dim = len(v0.shape) == 1
    if one_dim:
        fft_v0 = torch.fft.fft(v0)
        fft_v1 = torch.fft.fft(v1)
    else:
        fft_v0 = torch.fft.fftn(v0, dim=(-2, -1))
        fft_v1 = torch.fft.fftn(v1, dim=(-2, -1))

    # Sign agreement is decided on the real parts and reused for both
    # components, exactly as the original duplicated code did.
    sign_mask = fft_v0.real.sign() == fft_v1.real.sign()
    result_fft = torch.complex(
        _interp_fft_component(fft_v0.real, fft_v1.real, sign_mask, t),
        _interp_fft_component(fft_v0.imag, fft_v1.imag, sign_mask, t),
    )

    # Inverse transform back to the spatial domain; keep only the real part.
    if one_dim:
        merged_tensor = torch.fft.ifft(result_fft).real
    else:
        merged_tensor = torch.fft.ifftn(result_fft, dim=(-2, -1)).real

    del v0, v1, fft_v0, fft_v1, result_fft, sign_mask
    return merged_tensor
107
+
108
def correlate_pairs(tensors : torch.Tensor, work_device : str = "cuda:0", store_device : str = "cpu") -> torch.Tensor:
    """Build a symmetric matrix of pairwise mean cosine similarities.

    Parameters:
    - tensors: stacked tensors; entry i is compared with every entry j > i.
    - work_device: device the similarity math runs on (default "cuda:0").
    - store_device: device the result matrix lives on (default "cpu").

    Returns:
    - An (n, n) tensor on `store_device`; diagonal stays 0, NaNs map to 0.
    """
    n = tensors.shape[0]
    matrix = torch.zeros(n, n).to(store_device)
    for i in range(n):
        a = tensors[i].to(work_device)
        for j in range(i + 1, n):
            b = tensors[j].to(work_device)
            matrix[i, j] = matrix[j, i] = torch.nn.functional.cosine_similarity(a, b, dim=0).nan_to_num(0).mean().item()
            # Bug fix: the original called `b.to(store_device)` and discarded
            # the result — Tensor.to returns a copy and frees nothing.
            # Dropping the reference is what actually releases work memory.
            del b
        del a
    return matrix
119
+
120
def least_correlated_pairs(correlation_tensor: torch.Tensor) -> Generator[Tuple[int, int, float], None, None]:
    """
    Generates tuples of indices and their corresponding least correlation coefficient
    from a given correlation matrix, ensuring that once an index is used, it is no longer
    considered in future tuples.

    Args:
        correlation_tensor (torch.Tensor): A 2D square tensor representing the correlation matrix.

    Yields:
        Tuple[int, int, float]: A tuple containing the x-index, y-index, and the correlation coefficient
        of the least correlated pairs in the matrix.
    """
    n = correlation_tensor.size(0)
    # Create a mask to exclude diagonal and already processed elements.
    # Upper triangle only, so each unordered pair is considered exactly once.
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)

    while torch.any(mask):
        # Apply mask to get relevant correlations; masked-out entries become
        # +inf so they can never win the minimum search below.
        valid_correlation = torch.where(mask, correlation_tensor, torch.tensor(float('inf')))

        # Find the minimum non-zero absolute correlation
        # (the while-condition guarantees at least one finite entry exists).
        min_val = torch.min(torch.abs(valid_correlation[valid_correlation != float('inf')]))

        # Locate the indices with the minimum correlation
        min_indices = torch.nonzero(torch.abs(valid_correlation) == min_val, as_tuple=True)
        if len(min_indices[0]) == 0:
            break

        # Yield the first index pair (greedy approach) along with the correlation coefficient.
        # NOTE(review): ties are broken by torch.nonzero's row-major scan order.
        x, y = min_indices[0][0].item(), min_indices[1][0].item()
        coefficient = correlation_tensor[x, y].item()  # Extract the actual correlation value
        yield (x, y, coefficient)

        # Mask out the entire row and column for both indices so neither
        # index can ever appear in a later pair.
        mask[x, :] = False
        mask[:, x] = False
        mask[y, :] = False
        mask[:, y] = False
159
+
160
+
161
def merge_tensors_fft2_autoscale(v0: torch.Tensor, v1: torch.Tensor, t: float) -> tuple[torch.Tensor, float, float]:
    """
    Merges two tensors using Fourier transform interpolation, after scaling
    each input to unit norm (1D FFT for vectors, 2D FFT over the last two
    dims otherwise).

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor.
    - t (float): Interpolation parameter (0 <= t <= 1).

    Returns:
    - (merged, norm_v0, norm_v1): the merged tensor built from the
      unit-normalized inputs, plus each input's original norm so callers can
      rescale the result.
    """
    # Bug fix: the original unconditionally moved inputs to "cuda:0", which
    # crashes on CPU-only hosts. Behavior on CUDA machines is unchanged.
    device = "cuda:0" if torch.cuda.is_available() else v0.device
    v0 = v0.to(device)
    v1 = v1.to(device)

    # Scale each tensor to unit norm (zero-norm tensors pass through as-is).
    norm_v0_t = v0.norm()
    norm_v1_t = v1.norm()
    v0 = v0 / norm_v0_t if norm_v0_t != 0 else v0
    v1 = v1 / norm_v1_t if norm_v1_t != 0 else v1
    norm_v0 = norm_v0_t.item()
    norm_v1 = norm_v1_t.item()
    del norm_v0_t, norm_v1_t

    one_dim = len(v0.shape) == 1
    if one_dim:
        fft_v0 = torch.fft.fft(v0)
        fft_v1 = torch.fft.fft(v1)
    else:
        fft_v0 = torch.fft.fftn(v0, dim=(-2, -1))
        fft_v1 = torch.fft.fftn(v1, dim=(-2, -1))

    # Where real-part signs agree, interpolate by t; where they disagree,
    # keep the larger-magnitude coefficient of each component. The sign mask
    # comes from the real parts and is reused for the imaginary parts, exactly
    # as in the original's duplicated branches.
    sign_mask = fft_v0.real.sign() == fft_v1.real.sign()
    real_part = torch.where(
        sign_mask,
        (1 - t) * fft_v0.real + t * fft_v1.real,
        torch.where(fft_v0.real.abs() > fft_v1.real.abs(), fft_v0.real, fft_v1.real),
    )
    imag_part = torch.where(
        sign_mask,
        (1 - t) * fft_v0.imag + t * fft_v1.imag,
        torch.where(fft_v0.imag.abs() > fft_v1.imag.abs(), fft_v0.imag, fft_v1.imag),
    )
    result_fft = torch.complex(real_part, imag_part)

    # Inverse transform back to the spatial domain; keep only the real part.
    if one_dim:
        merged_tensor = torch.fft.ifft(result_fft).real
    else:
        merged_tensor = torch.fft.ifftn(result_fft, dim=(-2, -1)).real

    del v0, v1, fft_v0, fft_v1, result_fft, sign_mask, real_part, imag_part
    return merged_tensor, norm_v0, norm_v1
ztrain/util.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ztrain/util.py
2
+ # Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
3
+
4
+ import contextlib
5
+ import torch
6
+
7
+
8
@contextlib.contextmanager
def cuda_memory_profiler(display : bool = True):
    """
    A context manager for profiling CUDA memory usage in PyTorch.

    Parameters:
    - display: when False (or falsy), profiling is skipped entirely and the
      body runs unchanged. (Fix: the original annotated this bool as `str`.)

    On exit it prints peak, start, end, and net allocated CUDA memory in MB.
    If CUDA is unavailable, it prints a notice and does nothing.
    """
    # Fix: use truthiness instead of the original `display is False` identity
    # test, per convention for boolean flags.
    if not display:
        yield
        return

    if not torch.cuda.is_available():
        print("CUDA is not available, skipping memory profiling")
        yield
        return

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start_memory = torch.cuda.memory_allocated()

    try:
        yield
    finally:
        # Report even if the body raised, so partial runs are still measured.
        torch.cuda.synchronize()
        end_memory = torch.cuda.memory_allocated()
        print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / (1024 ** 2):.2f} MB")
        print(f"Memory allocated at start: {start_memory / (1024 ** 2):.2f} MB")
        print(f"Memory allocated at end: {end_memory / (1024 ** 2):.2f} MB")
        print(f"Net memory change: {(end_memory - start_memory) / (1024 ** 2):.2f} MB")
35
+
36
def get_device():
    """Return the best available torch device: cuda, then mps, then cpu."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")