# ztrain/tensors.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted

import torch
from typing import Generator, Optional, Tuple, Union


def normalize_to(m1: torch.Tensor, norm: Union[torch.Tensor, float]) -> tuple[torch.Tensor, float, float]:
    """Rescale ``m1`` so that its L2 norm (over all elements) equals ``norm``.

    ``m1`` is cast to float32 before scaling.

    Args:
        m1: Tensor to rescale.
        norm: Target norm, either a 0-dim tensor or a plain Python number.

    Returns:
        ``(rescaled, target_norm, ratio)``. ``ratio`` is the scalar that the
        float32 copy of ``m1`` was multiplied by.

    Note:
        If ``m1`` has zero norm, the ratio is non-finite and the result contains
        NaN/inf values. No check is made here.
    """
    m1 = m1.to(torch.float32)
    # as_tensor lets callers pass a plain float; an existing tensor is used as-is.
    target = torch.as_tensor(norm)
    ratio = (target / torch.norm(m1)).item()
    return m1 * ratio, target.item(), ratio


def norm_ratio(m1: torch.Tensor, m2: torch.Tensor) -> float:
    """Return ``||m1|| / ||m2||`` using L2 norms over all elements.

    As a debugging aid, this also prints both norms and the ratio to stdout.
    """
    m1_norm = torch.norm(m1)
    m2_norm = torch.norm(m2)
    ratio = (m1_norm / m2_norm).item()
    print(f"Norms {m1_norm} {m2_norm} {ratio}")
    return ratio


def _resolve_device(device: Optional[str]) -> str:
    """Return ``device``, or ``"cuda:0"`` if CUDA is available and ``"cpu"`` otherwise."""
    if device is not None:
        return device
    return "cuda:0" if torch.cuda.is_available() else "cpu"


def _blend_component(c0: torch.Tensor, c1: torch.Tensor, sign_mask: torch.Tensor, t: float) -> torch.Tensor:
    """Blend one real-valued spectral component (a real or imaginary part).

    Where ``sign_mask`` is True, the values are linearly interpolated as
    ``(1 - t) * c0 + t * c1``. Elsewhere, the value with the larger magnitude is
    kept; on ties ``c1`` wins, because the comparison is strict.
    """
    interpolated = (1 - t) * c0 + t * c1
    larger = torch.where(c0.abs() > c1.abs(), c0, c1)
    return torch.where(sign_mask, interpolated, larger)


def _spectral_merge(v0: torch.Tensor, v1: torch.Tensor, t: float) -> torch.Tensor:
    """FFT both tensors, blend the spectra, and return the real part of the inverse FFT.

    A 1-D input uses a 1-D FFT. Any other input uses an FFT over the last two
    dimensions.
    """
    one_dim = len(v0.shape) == 1
    if one_dim:
        fft_v0 = torch.fft.fft(v0)
        fft_v1 = torch.fft.fft(v1)
    else:
        fft_v0 = torch.fft.fftn(v0, dim=(-2, -1))
        fft_v1 = torch.fft.fftn(v1, dim=(-2, -1))

    # Note: torch.sign(0) == 0, so a zero and a non-zero part count as a sign mismatch.
    sign_mask = fft_v0.real.sign() == fft_v1.real.sign()

    result_fft = torch.zeros_like(fft_v0)
    result_fft.real.copy_(_blend_component(fft_v0.real, fft_v1.real, sign_mask, t))
    # NOTE(review): the imaginary parts reuse the sign mask computed from the REAL
    # parts, as the original implementation did. Confirm this is intended rather
    # than comparing the signs of the imaginary parts.
    result_fft.imag.copy_(_blend_component(fft_v0.imag, fft_v1.imag, sign_mask, t))
    # Release the input spectra before the inverse FFT to lower peak memory.
    del fft_v0, fft_v1, sign_mask

    if one_dim:
        return torch.fft.ifft(result_fft).real
    return torch.fft.ifftn(result_fft, dim=(-2, -1)).real


def merge_tensors_fft2(v0: torch.Tensor, v1: torch.Tensor, t: float,
                       device: Optional[str] = None) -> torch.Tensor:
    """
    Merges two tensors by blending their Fourier spectra.

    1-D inputs use a 1-D FFT. Inputs with two or more dimensions use an FFT over
    the last two dimensions. The real and imaginary parts of each coefficient
    are blended separately:

    - where the real parts of the two coefficients have the same sign, the part
      is interpolated as ``(1 - t) * a + t * b``;
    - otherwise, the part with the larger magnitude is kept (``b`` on ties).

    The real part of the inverse FFT is returned.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor; it should have the same shape as ``v0``.
    - t (float): Interpolation parameter (0 <= t <= 1).
    - device (str | None): Device to compute on. If None, uses ``"cuda:0"`` when
      CUDA is available, otherwise ``"cpu"``.

    Returns:
    - torch.Tensor: The merged tensor, located on ``device``.
    """
    device = _resolve_device(device)
    return _spectral_merge(v0.to(device), v1.to(device), t)


def correlate_pairs(tensors: torch.Tensor, work_device: str = "cuda:0", store_device: str = "cpu") -> torch.Tensor:
    """Build a symmetric matrix of pairwise cosine similarities.

    For every pair ``i < j``, ``cosine_similarity(tensors[i], tensors[j], dim=0)``
    is computed on ``work_device``. NaNs are replaced by 0, and the result is
    averaged to a single float. That float is stored at ``[i, j]`` and ``[j, i]``.
    The diagonal is left at 0.

    Args:
        tensors: Stack of items indexed along the first dimension. A list of
            tensors also works, because only ``len`` and indexing are used.
        work_device: Device on which each similarity is computed.
        store_device: Device holding the returned ``(n, n)`` matrix.

    Returns:
        The ``(n, n)`` similarity matrix on ``store_device``.
    """
    n = len(tensors)
    matrix = torch.zeros(n, n, device=store_device)
    for i in range(n):
        a = tensors[i].to(work_device)
        for j in range(i + 1, n):
            b = tensors[j].to(work_device)
            similarity = torch.nn.functional.cosine_similarity(a, b, dim=0).nan_to_num(0).mean().item()
            matrix[i, j] = matrix[j, i] = similarity
    return matrix


def least_correlated_pairs(correlation_tensor: torch.Tensor) -> Generator[Tuple[int, int, float], None, None]:
    """
    Greedily yield disjoint index pairs with the smallest absolute correlation.

    Only the strict upper triangle is considered. At each step, the remaining
    entry with the smallest ``|correlation|`` is chosen; the first in row-major
    order wins ties. Its ``(x, y, correlation)`` is yielded, and then every entry
    in rows and columns ``x`` and ``y`` is excluded from later steps. Iteration
    stops when no entries remain, or when the minimum is NaN.

    Args:
        correlation_tensor (torch.Tensor): A 2D square correlation matrix.

    Yields:
        Tuple[int, int, float]: ``(x, y, coefficient)`` with ``x < y``, where
        ``coefficient`` is the signed value ``correlation_tensor[x, y]``.
    """
    n = correlation_tensor.size(0)
    device = correlation_tensor.device
    # Build the mask on the matrix's own device, so CUDA inputs work as well.
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=device), diagonal=1)
    abs_corr = correlation_tensor.abs()
    if not abs_corr.is_floating_point():
        abs_corr = abs_corr.float()  # masked_fill with inf needs a float dtype

    while torch.any(mask):
        candidates = abs_corr.masked_fill(~mask, float("inf"))
        min_val = candidates.min()
        # Restrict to unmasked entries, so infinite correlations can never
        # re-select the diagonal or an already-used row/column.
        hits = torch.nonzero(mask & (candidates == min_val), as_tuple=True)
        if len(hits[0]) == 0:
            break  # e.g. the minimum is NaN

        x, y = hits[0][0].item(), hits[1][0].item()
        yield (x, y, correlation_tensor[x, y].item())

        mask[x, :] = False
        mask[:, x] = False
        mask[y, :] = False
        mask[:, y] = False


def merge_tensors_fft2_autoscale(v0: torch.Tensor, v1: torch.Tensor, t: float,
                                 device: Optional[str] = None) -> tuple[torch.Tensor, float, float]:
    """
    Merges two unit-normalised tensors via Fourier-spectrum blending.

    Each input is first divided by its L2 norm; an all-zero input is left
    unscaled. The tensors are then merged exactly as in ``merge_tensors_fft2``.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor; it should have the same shape as ``v0``.
    - t (float): Interpolation parameter (0 <= t <= 1).
    - device (str | None): Device to compute on. If None, uses ``"cuda:0"`` when
      CUDA is available, otherwise ``"cpu"``.

    Returns:
    - tuple[torch.Tensor, float, float]: ``(merged, norm_v0, norm_v1)``. ``merged``
      is built from the normalised inputs; ``norm_v0`` and ``norm_v1`` are the
      original norms of the inputs.
    """
    device = _resolve_device(device)
    v0 = v0.to(device)
    v1 = v1.to(device)

    norm_v0_t = v0.norm()
    norm_v1_t = v1.norm()
    # Skip the division for all-zero inputs, to avoid producing NaNs.
    if norm_v0_t != 0:
        v0 = v0 / norm_v0_t
    if norm_v1_t != 0:
        v1 = v1 / norm_v1_t

    return _spectral_merge(v0, v1, t), norm_v0_t.item(), norm_v1_t.item()