# ztrain/stats.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
import os | |
import torch | |
from typing import Optional | |
def gen_stats(delta: torch.Tensor, base: Optional[torch.Tensor]) -> tuple[float, float, float, float]:
    """Summarize a delta tensor, optionally relative to a base tensor.

    Args:
        delta: Tensor to summarize. When ``base`` is given, ``delta`` is added
            to it before the norm and cosine are computed.
        base: Reference tensor, or ``None`` to treat ``delta`` as the complete
            tensor.

    Returns:
        ``(norm, cosine, delta_min, delta_max)`` where:

        * ``norm`` is ``Tensor.norm()`` (default arguments) of ``base + delta``,
          or of ``delta`` alone when ``base`` is ``None``;
        * ``cosine`` is the mean of the cosine similarity taken along dim 0
          between ``base + delta`` and ``base``, or ``0.0`` when ``base`` is
          ``None``;
        * ``delta_min`` / ``delta_max`` are the smallest and largest elements
          of ``delta`` itself (not of the rebuilt tensor).
    """
    if base is None:
        rebuilt = delta
        # No reference to compare against; report a float so the return type
        # matches the annotation.
        cosine = 0.0
    else:
        rebuilt = base + delta
        cosine = torch.nn.functional.cosine_similarity(rebuilt, base, dim=0).mean().item()

    norm = rebuilt.norm().item()
    delta_min = delta.min().item()
    delta_max = delta.max().item()
    return norm, cosine, delta_min, delta_max
def get_report(m0: torch.Tensor, stack: torch.Tensor, model_list: list[str]) -> None:
    """Print summary statistics for a base tensor and each entry of a stack.

    The first line printed describes ``m0`` on its own (norm, min, max).
    One further line is printed per entry of ``stack``, treating that entry
    as a delta against ``m0`` and labelling it with the basename of the
    corresponding path in ``model_list``.

    Args:
        m0: Base tensor.
        stack: Iterable of tensors (iterated along its first dimension),
            each passed to ``gen_stats`` as a delta against ``m0``.
        model_list: Paths/names indexed in step with ``stack``; must have at
            least as many entries as ``stack`` yields.
    """
    # Cosine is meaningless without a reference tensor, so it is discarded here.
    base_norm, _, base_min, base_max = gen_stats(m0, None)
    print(f"Base Model {base_norm} {base_min} {base_max}")

    for i, entry in enumerate(stack):
        model_name = os.path.basename(model_list[i])
        norm, cosine, delta_min, delta_max = gen_stats(entry, m0)
        print(f"{model_name} {norm} {cosine} {delta_min} {delta_max}")