""" |
|
Functions in this file are courtesty of @ashen-sensored on GitHub - thankyou so much! <3 |
|
|
|
Used to merge DreamSim LoRA weights into the base ViT models manually, so we don't need |
|
to use an ancient version of PeFT that is no longer supported (and kind of broken) |
|
""" |
|
import logging
from os import PathLike
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch import Tensor, nn

from .model import DreamsimModel

logger = logging.getLogger(__name__)


@torch.no_grad()
def calculate_merged_weight(
    lora_a: Tensor,
    lora_b: Tensor,
    base: Tensor,
    scale: float,
    qkv_switches: list[bool],
) -> Tensor:
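    """Merge a LoRA A/B pair into a fused qkv projection weight.

    The low-rank delta (lora_b @ lora_a) is computed block-wise for each enabled group,
    scaled by `scale`, and added to the rows of `base` selected by `qkv_switches`
    (one flag per q/k/v block). Rows whose switch is False are returned unchanged.
    """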
    n_switches = len(qkv_switches)
    n_groups = sum(qkv_switches)

    # Build a row mask over the fused qkv output dim: each switch flag covers one
    # contiguous block of base.shape[0] // n_switches rows. The mask is created on
    # base's device so the boolean indexing below also works when base is on GPU.
    qkv_mask = torch.tensor(qkv_switches, dtype=torch.bool, device=base.device).reshape(n_switches, -1)
    qkv_mask = qkv_mask.broadcast_to((-1, base.shape[0] // n_switches)).reshape(-1)

    lora_b = lora_b.squeeze()
    delta_w = base.new_zeros(lora_b.shape[0], base.shape[1])

    # lora_a and lora_b hold the enabled groups stacked along dim 0; compute each
    # group's low-rank delta (B @ A) separately.
    grp_in_ch = lora_a.shape[0] // n_groups
    grp_out_ch = lora_b.shape[0] // n_groups
    for i in range(n_groups):
        islice = slice(i * grp_in_ch, (i + 1) * grp_in_ch)
        oslice = slice(i * grp_out_ch, (i + 1) * grp_out_ch)
        delta_w[oslice, :] = lora_b[oslice, :] @ lora_a[islice, :]

    # Scatter the deltas back into the rows of the full qkv weight that carry LoRA.
    delta_w_full = base.new_zeros(base.shape)
    delta_w_full[qkv_mask, :] = delta_w

    merged = base + scale * delta_w_full
    return merged.to(base)


@torch.no_grad()
def merge_dreamsim_lora(
    base_model: nn.Module,
    lora_path: PathLike,
    torch_device: torch.device | str = torch.device("cpu"),
):
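    """Merge DreamSim LoRA weights from `lora_path` into `base_model`'s attn.qkv weights.

    Supports .pt/.pth/.bin (loaded with torch.load) and .safetensors checkpoints. The
    PEFT key prefix is stripped, then each qkv layer's lora_A/lora_B pair is merged into
    the corresponding base weight via calculate_merged_weight. Returns the merged model
    with gradients disabled.
    """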
    lora_path = Path(lora_path)

    base_model = base_model.eval().requires_grad_(False).to(torch_device)

    if lora_path.suffix.lower() in [".pt", ".pth", ".bin"]:
        lora_sd = torch.load(lora_path, map_location=torch_device, weights_only=True)
    elif lora_path.suffix.lower() == ".safetensors":
        lora_sd = load_file(lora_path, device=str(torch_device))
    else:
        raise ValueError(f"Unsupported file extension '{lora_path.suffix}'")

    # Strip the nested PEFT wrapper prefix so keys line up with the base model's names.
    group_prefix = "base_model.model.base_model.model.model."
    group_weights = {k.replace(group_prefix, ""): v for k, v in lora_sd.items() if k.startswith(group_prefix)}
    group_layers = {k.rsplit(".", 2)[0] for k in group_weights}

    base_weights = base_model.state_dict()
    for key in [x for x in base_weights.keys() if "attn.qkv.weight" in x]:
        param_name = key.rsplit(".", 1)[0]
        if param_name not in group_layers:
            logger.warning(f"QKV param '{param_name}' not found in lora weights")
            continue
        new_weight = calculate_merged_weight(
            group_weights[f"{param_name}.lora_A.weight"],
            group_weights[f"{param_name}.lora_B.weight"],
            base_weights[key],
            0.5 / 16,  # LoRA scale (alpha / rank)
            [True, False, True],  # the DreamSim LoRA targets the q and v projections only
        )
        base_weights[key] = new_weight

    base_model.load_state_dict(base_weights)
    return base_model.requires_grad_(False)

def remap_clip(state_dict: dict[str, Tensor], variant: str) -> dict[str, Tensor]:
    """Remap keys from the original DreamSim checkpoint to match the new model structure."""

    def prepend_extractor(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        # Single-branch variants keep their ViT under the `extractor.` attribute.
        if variant.endswith("single"):
            return {f"extractor.{k}": v for k, v in state_dict.items()}
        return state_dict

    if "clip" not in variant:
        return prepend_extractor(state_dict)

    # CLIP-style checkpoints: drop the patch-embed bias if present, rename pos_drop to
    # norm_pre, and add a zero head bias when only head.weight is present.
    if "patch_embed.proj.bias" in state_dict:
        _ = state_dict.pop("patch_embed.proj.bias", None)
    if "pos_drop.weight" in state_dict:
        state_dict["norm_pre.weight"] = state_dict.pop("pos_drop.weight")
        state_dict["norm_pre.bias"] = state_dict.pop("pos_drop.bias")
    if "head.weight" in state_dict and "head.bias" not in state_dict:
        state_dict["head.bias"] = torch.zeros(state_dict["head.weight"].shape[0])

    return prepend_extractor(state_dict)

def convert_dreamsim_single(
    ckpt_path: PathLike,
    variant: str,
    ensemble: bool = False,
) -> DreamsimModel:
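    """Build a DreamsimModel for `variant` and load its merged weights.

    If `ckpt_path` is a directory, the checkpoint is resolved to
    <ckpt_path>/<"ensemble" or variant>/<variant>_merged.safetensors.
    """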
    ckpt_path = Path(ckpt_path)
    if ckpt_path.exists():
        if ckpt_path.is_dir():
            ckpt_path = ckpt_path.joinpath("ensemble" if ensemble else variant)
            ckpt_path = ckpt_path.joinpath(f"{variant}_merged.safetensors")

    # Defaults for the DINO ViT-B/16 backbone; the CLIP variants override them below.
    patch_size = 16
    layer_norm_eps = 1e-6
    pre_norm = False
    act_layer = "gelu"

    match variant:
        case "open_clip_vitb16" | "open_clip_vitb32" | "clip_vitb16" | "clip_vitb32":
            patch_size = 32 if "b32" in variant else 16
            layer_norm_eps = 1e-5
            pre_norm = True
            img_mean = (0.48145466, 0.4578275, 0.40821073)
            img_std = (0.26862954, 0.26130258, 0.27577711)
            act_layer = "quick_gelu" if variant.startswith("clip_") else "gelu"
        case "dino_vitb16":
            img_mean = (0.485, 0.456, 0.406)
            img_std = (0.229, 0.224, 0.225)
        case _:
            raise NotImplementedError(f"Unsupported model variant '{variant}'")

    model: DreamsimModel = DreamsimModel(
        image_size=224,
        patch_size=patch_size,
        layer_norm_eps=layer_norm_eps,
        pre_norm=pre_norm,
        act_layer=act_layer,
        img_mean=img_mean,
        img_std=img_std,
    )
    state_dict = load_file(ckpt_path, device="cpu")
    state_dict = remap_clip(state_dict, variant)
    model.extractor.load_state_dict(state_dict)
    return model
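

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint directory and variant name below are
    # illustrative assumptions, not values shipped with this module; run it as
    # `python -m <package>.<this_module>` since the file uses relative imports.
    dreamsim_model = convert_dreamsim_single("./checkpoints", variant="dino_vitb16")
    print(sum(p.numel() for p in dreamsim_model.parameters()), "parameters loaded")

    # To merge a LoRA into a base ViT instead (the path is likewise an assumption):
    # merged = merge_dreamsim_lora(base_vit, "./checkpoints/dino_vitb16_lora.safetensors")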