|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import torch |
|
from safetensors.torch import save_file |
|
|
|
def find_checkpoints(root):
    """Return the paths of all ``.ckpt`` files found under *root*.

    The directory tree is walked recursively with ``os.walk`` and the
    extension is matched case-insensitively. Paths are returned in the
    order ``os.walk`` yields them; a missing *root* gives an empty list.
    """
    return [
        os.path.join(path, name)
        for path, _subdirs, names in os.walk(root)
        for name in names
        if name.lower().endswith('.ckpt')
    ]


def safetensors_path(ckpt_path):
    """Return *ckpt_path* with its final extension replaced by ``.safetensors``.

    Only the trailing extension is swapped, so a ``.ckpt`` that appears in a
    directory name or in the middle of the file name is left alone, and
    upper-case extensions such as ``.CKPT`` are handled as well.
    """
    stem, _extension = os.path.splitext(ckpt_path)
    return stem + '.safetensors'


def load_state_dict(ckpt_path):
    """Load *ckpt_path* on the CPU and return the mapping to be saved.

    If the loaded object is a dict with a nested dict under ``"state_dict"``,
    that inner dict is returned; otherwise the loaded object is returned
    unchanged.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only convert checkpoints from sources you trust.
    with torch.no_grad():
        weights = torch.load(ckpt_path, map_location=torch.device('cpu'))

    # save_file expects a flat {name: tensor} dict, so a checkpoint that
    # wraps its weights under "state_dict" has to be unwrapped first.
    if isinstance(weights, dict) and isinstance(weights.get('state_dict'), dict):
        weights = weights['state_dict']
    return weights


def main():
    """Convert every ``.ckpt`` file below the current directory.

    Lists the checkpoints that were found, waits for confirmation, then
    writes a ``.safetensors`` file next to each checkpoint. Checkpoints
    whose output file already exists are skipped, and a failure on one
    file is reported without stopping the remaining conversions.
    """
    root = os.path.abspath(os.getcwd())
    models = find_checkpoints(root)

    if not models:
        print('\033[91m> No .ckpt files found in this directory ({}).\033[0m'.format(root))
        input('> Press enter to exit... ')
        return

    print(f"\n\033[92m> Found {len(models)} .ckpt files to convert.\033[0m")
    for number, ckpt_path in enumerate(models, start=1):
        print(f"{number}: {os.path.basename(ckpt_path)}")

    input("> Press enter to continue... ")
    print("\n")

    for number, ckpt_path in enumerate(models, start=1):
        model_name = os.path.basename(ckpt_path)
        out_path = safetensors_path(ckpt_path)
        tensor_name = os.path.basename(out_path)

        # Checked right before converting so an existing output is never
        # overwritten.
        if os.path.exists(out_path):
            print(f"\033[33m\n> Skipping {model_name}, as {tensor_name} already exists.\033[0m")
            continue

        print(f'\n> Loading {model_name} ({number}/{len(models)})...')

        # Deliberately broad: one unreadable checkpoint should be reported
        # and skipped, not abort the whole batch.
        try:
            weights = load_state_dict(ckpt_path)
            print(f'Saving {tensor_name}...')
            save_file(weights, out_path)
        except Exception as ex:
            print(f'ERROR converting {model_name}: {ex}')

    print("\n\033[92mDone!\033[0m")
    input("> Press enter to exit... ")


if __name__ == "__main__":
    main()
|
|