Upload folder using huggingface_hub
Browse files- ztrain/__init__.py +0 -0
- ztrain/io.py +30 -0
- ztrain/model.py +39 -0
- ztrain/signal.py +79 -0
- ztrain/stats.py +30 -0
- ztrain/tensors.py +258 -0
- ztrain/util.py +37 -0
ztrain/__init__.py
ADDED
File without changes
|
ztrain/io.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ztrain/io.py
|
2 |
+
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
|
3 |
+
|
4 |
+
import os
|
5 |
+
from glob import glob
|
6 |
+
|
7 |
+
def flatten_index(model_paths : list[str], allow_list : list[str]):
    """
    Builds a flat, ordered listing of the allowed entries found in model_paths.

    Parameters:
    - model_paths: Paths to inspect; they are visited in sorted order.
    - allow_list: Base names (last path component) that should be kept.

    Returns:
    - A tuple (index, flat, subtype) where `index` maps each kept base name to its
      position in `flat`, `flat` is the list of kept base names, and `subtype` holds
      'base', 'instruct' or 'other' for each kept entry.
    """
    positions = {}
    names = []
    kinds = []
    for path in sorted(model_paths):
        leaf = os.path.basename(path)
        if leaf not in allow_list:
            continue
        # The position is the number of entries accepted so far.
        positions[leaf] = len(names)
        names.append(leaf)
        # The subtype is looked up in the full path, not only the base name;
        # 'base' takes precedence over 'instruct'.
        kinds.append(next((tag for tag in ('base', 'instruct') if tag in path), 'other'))
    return positions, names, kinds
|
25 |
+
|
26 |
+
def list_for_path(path: str, include_folders: list[str], search: str = "/**/*") -> tuple[list[str], list[str], list[str], dict[str, int]]:
    """
    Lists the entries under `path` and indexes the ones named in `include_folders`.

    Parameters:
    - path: Root path; `search` is appended to it to form the glob pattern.
    - include_folders: Base names to keep (passed to flatten_index as the allow list).
    - search: Glob suffix appended to `path`. glob is called without recursive=True,
      so '**' is not expanded recursively here.

    Returns:
    - (model_names, subtypes, model_list, group_idx): the kept base names, their
      subtypes, the full sorted glob result, and the name -> position mapping.
    """
    model_list = sorted(glob(path + search))
    group_idx, model_names, subtypes = flatten_index(model_list, include_folders)
    return model_names, subtypes, model_list, group_idx
|
ztrain/model.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ztrain/model.py
|
2 |
+
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
|
3 |
+
|
4 |
+
from collections import defaultdict
|
5 |
+
import re
|
6 |
+
|
7 |
+
def generate_merge_group(group_data : list, parents : list[int] = []):
    """
    Walks a nested list depth-first and yields every non-list item with its index path.

    Parameters:
    - group_data: A (possibly nested) list.
    - parents: Index path of `group_data` inside the outer structure. The default
      list is never mutated; a new list is built for every item.

    Yields:
    - (item, path) where `path` is the list of indices leading from the root to the item.
    """
    for position, item in enumerate(group_data):
        path = parents + [position]
        if not isinstance(item, list):
            yield item, path
            continue
        # Nested list: descend, carrying the path accumulated so far.
        yield from generate_merge_group(item, path)
|
14 |
+
|
15 |
+
def merge_groups(group_data : list):
    """
    Groups the leaves of a nested list by the index path of their parent list.

    Parameters:
    - group_data: A (possibly nested) list.

    Returns:
    - A defaultdict(list) mapping the parent's index path (as a tuple) to the
      leaves found directly inside that parent.
    """
    grouped = defaultdict(list)
    for leaf, path in generate_merge_group(group_data):
        # Drop the leaf's own index so siblings share one key.
        *ancestors, _own_index = path
        grouped[tuple(ancestors)].append(leaf)
    return grouped
|
21 |
+
|
22 |
+
def get_layer_type(k : str) -> tuple[int, str, str, str]:
    """
    Splits a state-dict key into (layer index, module, sub-module, parameter name).

    Parameters:
    - k: The tensor key, e.g. "model.layers.3.self_attn.q_proj.weight".

    Returns:
    - (layer, module, sub_module, parameter). Keys with only two components after
      the layer number yield an empty sub_module. The three known top-level keys
      yield layer -1. Anything else is reported on stdout and returned as
      (-1, "unknown", "unknown", "unknown").
    """
    # Four components after the layer number: module.sub_module.parameter
    m = re.match(r"model\.layers\.(\d+)\.(.+)\.(.+)\.(.+)", k)
    if m is not None:
        return int(m.group(1)), m.group(2), m.group(3), m.group(4)

    # Three components: module.parameter. The pattern must actually be applied
    # here; reusing the previous (failed) match would never succeed.
    m = re.match(r"model\.layers\.(\d+)\.(.+)\.(.+)", k)
    if m is not None:
        return int(m.group(1)), m.group(2), "", m.group(3)

    top_level = {
        "model.norm.weight": "norm",
        "model.embed_tokens.weight": "embed_tokens",
        "lm_head.weight": "lm_head",
    }
    if k in top_level:
        return -1, top_level[k], "", "weight"

    print(f"Unknown key {k}")
    return -1, "unknown", "unknown", "unknown"
|
ztrain/signal.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ztrain/signal.py
|
2 |
+
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
def gaussian_kernel(size, sigma=1.0):
    """
    Builds a normalized 2D Gaussian kernel as a PyTorch tensor.

    Parameters:
    - size: Requested kernel size. The grid spans -size//2 .. size//2 on both axes,
      so the side length is 2 * (size // 2) + 1 (an even size yields size + 1).
    - sigma: The standard deviation of the Gaussian distribution.

    Returns:
    - A 2D PyTorch tensor whose entries sum to 1.
    """
    radius = int(size) // 2
    offsets = torch.arange(-radius, radius + 1)
    rows, cols = torch.meshgrid(offsets, offsets)
    squared_distance = rows**2 + cols**2
    weights = torch.exp(-squared_distance / (2 * sigma**2))
    total = weights.sum()
    return weights / total
|
22 |
+
|
23 |
+
def laplacian_kernel(size, scale=1.0):
    """
    Builds a zero-sum, Laplacian-style kernel of the given odd size.

    Parameters:
    - size: Side length of the kernel (an integer); must be odd.
    - scale: Factor applied to every entry before the center is re-balanced.

    Returns:
    - A 2D float32 PyTorch tensor whose entries sum to 0.

    Raises:
    - ValueError: If size is even.
    """
    if size % 2 == 0:
        raise ValueError("Size must be odd.")

    mid = size // 2
    kernel = torch.zeros((size, size), dtype=torch.float32)
    kernel[mid, mid] = -4.0

    # The four direct neighbors of the center (left, right, up, down).
    for d_row, d_col in ((0, -1), (0, 1), (-1, 0), (1, 0)):
        kernel[mid + d_row, mid + d_col] = 1.0

    # Kernels larger than 3x3 additionally get a frame of ones on the outer border.
    if size > 3:
        last = size - 1
        kernel[0, :] = 1.0
        kernel[last, :] = 1.0
        kernel[:, 0] = 1.0
        kernel[:, last] = 1.0

    kernel *= scale

    # Re-balance the center so that the whole kernel sums to zero.
    kernel[mid, mid] = kernel[mid, mid] - torch.sum(kernel)

    return kernel
|
63 |
+
|
64 |
+
def fftshift(input):
    """
    Reorients the FFT output so the zero-frequency component is at the center.

    Parameters:
    - input: A tensor with at least two dimensions; only dimensions 0 and 1 are shifted.

    Returns:
    - The tensor with dimensions 0 and 1 rolled so that index 0 lands on index n // 2.
    """
    # Rolling by n // 2 is correct for even and odd lengths alike: the
    # zero-frequency bin at index 0 ends up at the center index n // 2.
    # (Rolling by (n + 1) // 2 would be the inverse shift for odd lengths.)
    for dim in range(2):
        input = torch.roll(input, shifts=input.shape[dim] // 2, dims=dim)
    return input
|
ztrain/stats.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ztrain/stats.py
|
2 |
+
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
|
3 |
+
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
def gen_stats(delta : torch.Tensor, base : Optional[torch.Tensor]) -> tuple[float, float, float, float]:
    """
    Computes summary statistics for a delta tensor, optionally relative to a base.

    Parameters:
    - delta: The tensor to describe.
    - base: Optional tensor added to `delta` before the norm and cosine are computed.

    Returns:
    - (norm, cosine, min, max): the norm of base + delta (or of delta alone when
      base is None), the mean cosine similarity along dim 0 between base + delta
      and base (0 when base is None), and the minimum and maximum of delta itself.
    """
    has_base = base is not None
    rebuilt = base + delta if has_base else delta
    norm = rebuilt.norm().item()
    if has_base:
        cosine = torch.nn.functional.cosine_similarity(rebuilt, base, dim=0).mean().item()
    else:
        cosine = 0
    # The extremes are taken from the raw delta, not from the rebuilt tensor.
    lowest = delta.min().item()
    highest = delta.max().item()
    return norm, cosine, lowest, highest
|
22 |
+
|
23 |
+
def get_report(m0: torch.Tensor, stack : torch.Tensor, model_list : list[str]):
    """
    Prints statistics for the base tensor and for every entry of `stack`.

    Parameters:
    - m0: The base tensor; reported on its own first.
    - stack: Iterable of tensors, each reported relative to `m0` via gen_stats.
    - model_list: Paths whose base names label the entries of `stack` (same order).
    """
    base_norm, _unused_cosine, base_low, base_high = gen_stats(m0, None)
    print(f"Base Model {base_norm} {base_low} {base_high}")

    for position, delta in enumerate(stack):
        label = os.path.basename(model_list[position])
        d_norm, d_cosine, d_low, d_high = gen_stats(delta, m0)
        print(f"{label} {d_norm} {d_cosine} {d_low} {d_high}")
|
ztrain/tensors.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ztrain/tensors.py
|
2 |
+
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from typing import Generator, Tuple
|
6 |
+
|
7 |
+
def normalize_to(m1 : torch.Tensor, norm : torch.float32) -> tuple[torch.Tensor, torch.float32, torch.float32]:
    """
    Rescales a tensor so that its norm equals the requested value.

    Parameters:
    - m1: The tensor to rescale; it is converted to float32 first.
    - norm: The target norm. It must support `.item()` (i.e. be a tensor scalar).

    Returns:
    - (rescaled tensor, target norm as a Python number, applied scale factor).
    """
    as_float = m1.to(torch.float32)
    scale = (norm / torch.norm(as_float)).item()
    return as_float * scale, norm.item(), scale
|
13 |
+
|
14 |
+
def norm_ratio(m1 : torch.Tensor, m2 : torch.Tensor) -> float:
    """
    Returns norm(m1) / norm(m2) and prints both norms together with the ratio.

    Parameters:
    - m1: Tensor whose norm is the numerator.
    - m2: Tensor whose norm is the denominator.

    Returns:
    - The ratio of the two norms as a Python float.
    """
    numerator, denominator = torch.norm(m1), torch.norm(m2)
    ratio = (numerator / denominator).item()
    print(f"Norms {numerator} {denominator} {ratio}")
    return ratio
|
20 |
+
|
21 |
+
def merge_tensors_fft2(v0: torch.Tensor, v1: torch.Tensor, t: float) -> torch.Tensor:
    """
    Merges two tensors in the frequency domain.

    Both tensors are Fourier-transformed (1D FFT for 1-D inputs, otherwise an FFT
    over the last two dimensions). For every coefficient whose real parts have the
    same sign in both inputs, the real and imaginary parts are linearly
    interpolated; elsewhere the component with the larger magnitude is taken
    (v1's component on ties). The result is transformed back and its real part
    is returned.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor.
    - t (float): Interpolation weight of v1; the weight of v0 is (1 - t).
      The intended range 0 <= t <= 1 is not validated.

    Returns:
    - torch.Tensor: The real part of the inverse FFT of the merged spectrum. It
      lives on "cuda:0" when CUDA is available, otherwise on v0's device.
    """
    if torch.cuda.is_available():
        v0 = v0.to("cuda:0")
        v1 = v1.to("cuda:0")
    else:
        # No CUDA: stay on v0's device instead of failing on the transfer.
        v1 = v1.to(v0.device)

    def _blend(part0: torch.Tensor, part1: torch.Tensor, same_sign: torch.Tensor) -> torch.Tensor:
        # Interpolate where the signs agree, otherwise keep the larger magnitude.
        interpolated = (1 - t) * part0 + t * part1
        dominant = torch.where(part0.abs() > part1.abs(), part0, part1)
        return torch.where(same_sign, interpolated, dominant)

    one_dimensional = len(v0.shape) == 1
    if one_dimensional:
        fft_v0 = torch.fft.fft(v0)
        fft_v1 = torch.fft.fft(v1)
    else:
        fft_v0 = torch.fft.fftn(v0, dim=(-2, -1))
        fft_v1 = torch.fft.fftn(v1, dim=(-2, -1))

    # The agreement mask is derived from the real parts only and is reused for
    # the imaginary parts.
    # NOTE(review): confirm the imaginary parts should not use their own sign mask.
    same_sign = fft_v0.real.sign() == fft_v1.real.sign()

    result_fft = torch.complex(
        _blend(fft_v0.real, fft_v1.real, same_sign),
        _blend(fft_v0.imag, fft_v1.imag, same_sign),
    )
    del fft_v0, fft_v1, same_sign

    # Back to the spatial domain; only the real part is kept.
    if one_dimensional:
        return torch.fft.ifft(result_fft).real
    return torch.fft.ifftn(result_fft, dim=(-2, -1)).real
|
107 |
+
|
108 |
+
def correlate_pairs(tensors : torch.Tensor, work_device : str = "cuda:0", store_device : str = "cpu") -> torch.Tensor:
    """
    Computes the pairwise mean cosine similarity between the entries of `tensors`.

    Parameters:
    - tensors: A tensor indexed along dimension 0; every pair of entries is compared.
    - work_device: Device on which the similarities are computed.
    - store_device: Device on which the result matrix is allocated.

    Returns:
    - A symmetric (n, n) tensor on `store_device`. Entry [i, j] is the mean of the
      cosine similarity along dim 0 of entries i and j, with NaNs replaced by 0.
      The diagonal is left at 0.
    """
    n = tensors.shape[0]
    matrix = torch.zeros(n, n, device=store_device)
    for i in range(n):
        a = tensors[i].to(work_device)
        for j in range(i + 1, n):
            b = tensors[j].to(work_device)
            similarity = torch.nn.functional.cosine_similarity(a, b, dim=0).nan_to_num(0).mean().item()
            matrix[i, j] = similarity
            matrix[j, i] = similarity
            # No transfer back is needed: `.to()` returns a new tensor, and the
            # work-device copies are released when `a` / `b` are rebound.
    return matrix
|
119 |
+
|
120 |
+
def least_correlated_pairs(correlation_tensor: torch.Tensor) -> Generator[Tuple[int, int, float], None, None]:
    """
    Greedily pairs up indices, always picking the remaining pair whose correlation
    has the smallest absolute value. Once an index has been paired it is never
    used again.

    Args:
        correlation_tensor (torch.Tensor): A 2D square tensor representing the correlation matrix.
            Only the part above the diagonal is consulted.

    Yields:
        Tuple[int, int, float]: The row index, the column index, and the (signed) correlation
        coefficient of the selected pair.
    """
    size = correlation_tensor.size(0)
    # Strict upper triangle: skips the diagonal and the mirrored lower half.
    available = torch.ones(size, size, dtype=torch.bool).triu(diagonal=1)
    sentinel = float('inf')

    while available.any():
        # Positions that are no longer available are replaced by +inf.
        candidates = torch.where(available, correlation_tensor, torch.tensor(sentinel))
        magnitudes = candidates.abs()

        # Smallest absolute correlation among the remaining candidates.
        smallest = magnitudes[candidates != sentinel].min()

        rows, cols = torch.nonzero(magnitudes == smallest, as_tuple=True)
        if len(rows) == 0:
            break

        # Ties are resolved by taking the first hit in row-major order.
        x, y = rows[0].item(), cols[0].item()
        yield (x, y, correlation_tensor[x, y].item())

        # Retire both indices: clear their rows and columns.
        for used in (x, y):
            available[used, :] = False
            available[:, used] = False
|
159 |
+
|
160 |
+
|
161 |
+
def merge_tensors_fft2_autoscale(v0: torch.Tensor, v1: torch.Tensor, t: float) -> tuple[torch.Tensor, float, float]:
    """
    Merges two tensors in the frequency domain after scaling each to unit norm.

    Each input is divided by its own norm (unless that norm is 0). The scaled
    tensors are Fourier-transformed (1D FFT for 1-D inputs, otherwise an FFT over
    the last two dimensions). For every coefficient whose real parts have the same
    sign in both inputs, the real and imaginary parts are linearly interpolated;
    elsewhere the component with the larger magnitude is taken (v1's component on
    ties). The merged spectrum is transformed back and its real part is returned.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor.
    - t (float): Interpolation weight of v1; the weight of v0 is (1 - t).
      The intended range 0 <= t <= 1 is not validated.

    Returns:
    - tuple[torch.Tensor, float, float]: The merged tensor (still in the
      norm-scaled domain; it is not rescaled here), the original norm of v0, and
      the original norm of v1. The tensor lives on "cuda:0" when CUDA is
      available, otherwise on v0's device.
    """
    if torch.cuda.is_available():
        v0 = v0.to("cuda:0")
        v1 = v1.to("cuda:0")
    else:
        # No CUDA: stay on v0's device instead of failing on the transfer.
        v1 = v1.to(v0.device)

    # Remember the original norms, then scale each tensor to unit norm.
    norm_v0_t = v0.norm()
    norm_v1_t = v1.norm()
    v0 = v0 / norm_v0_t if norm_v0_t != 0 else v0
    v1 = v1 / norm_v1_t if norm_v1_t != 0 else v1
    norm_v0 = norm_v0_t.item()
    norm_v1 = norm_v1_t.item()
    del norm_v0_t, norm_v1_t

    def _blend(part0: torch.Tensor, part1: torch.Tensor, same_sign: torch.Tensor) -> torch.Tensor:
        # Interpolate where the signs agree, otherwise keep the larger magnitude.
        interpolated = (1 - t) * part0 + t * part1
        dominant = torch.where(part0.abs() > part1.abs(), part0, part1)
        return torch.where(same_sign, interpolated, dominant)

    one_dimensional = len(v0.shape) == 1
    if one_dimensional:
        fft_v0 = torch.fft.fft(v0)
        fft_v1 = torch.fft.fft(v1)
    else:
        fft_v0 = torch.fft.fftn(v0, dim=(-2, -1))
        fft_v1 = torch.fft.fftn(v1, dim=(-2, -1))

    # The agreement mask is derived from the real parts only and is reused for
    # the imaginary parts.
    # NOTE(review): confirm the imaginary parts should not use their own sign mask.
    same_sign = fft_v0.real.sign() == fft_v1.real.sign()

    result_fft = torch.complex(
        _blend(fft_v0.real, fft_v1.real, same_sign),
        _blend(fft_v0.imag, fft_v1.imag, same_sign),
    )
    del fft_v0, fft_v1, same_sign

    # Back to the spatial domain; only the real part is kept.
    if one_dimensional:
        merged_tensor = torch.fft.ifft(result_fft).real
    else:
        merged_tensor = torch.fft.ifftn(result_fft, dim=(-2, -1)).real
    del result_fft

    return merged_tensor, norm_v0, norm_v1
|
ztrain/util.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ztrain/util.py
|
2 |
+
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
|
3 |
+
|
4 |
+
import contextlib
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
@contextlib.contextmanager
def cuda_memory_profiler(display : bool = True):
    """
    A context manager for profiling CUDA memory usage in PyTorch.

    Parameters:
    - display: Pass False to disable profiling entirely; the body then runs
      without any measurement or output. The check is an identity test against
      False, so other falsy values do not disable it.

    On exit (also when the body raises) it prints the peak allocation, the
    allocation at start and end, and the net change, all in MB. When CUDA is not
    available a notice is printed and the body runs unprofiled.
    """
    if display is False:
        yield
        return

    if not torch.cuda.is_available():
        print("CUDA is not available, skipping memory profiling")
        yield
        return

    bytes_per_mb = 1024 ** 2

    # Reset the peak counter so the reported peak covers only this block.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start_memory = torch.cuda.memory_allocated()

    try:
        yield
    finally:
        # Wait for pending kernels so the readings reflect the finished work.
        torch.cuda.synchronize()
        end_memory = torch.cuda.memory_allocated()
        print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / bytes_per_mb:.2f} MB")
        print(f"Memory allocated at start: {start_memory / bytes_per_mb:.2f} MB")
        print(f"Memory allocated at end: {end_memory / bytes_per_mb:.2f} MB")
        print(f"Net memory change: {(end_memory - start_memory) / bytes_per_mb:.2f} MB")
|
35 |
+
|
36 |
+
def get_device():
    """
    Picks the preferred torch device: CUDA first, then MPS, otherwise the CPU.

    Returns:
    - A torch.device for "cuda", "mps" or "cpu". MPS availability is only
      queried when CUDA is not available.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
|