{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "### モデルの形式 (.ckpt/.safetensors) を相互変換するスクリプトです\n", "#### SD2.x系付属の.yamlも併せて変換します\n", "#### オプションでfp16として保存できます" ], "metadata": { "id": "fAIY_GORNEYa" } }, { "cell_type": "markdown", "source": [ "以下のコードを上から順番に両方とも実行" ], "metadata": { "id": "OnuCk_wNLM_D" } }, { "cell_type": "code", "source": [ "from google.colab import drive \n", "drive.mount(\"/content/drive\")" ], "metadata": { "id": "liEiK8Iioscq" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "!pip install torch safetensors\n", "!pip install pytorch-lightning\n", "!pip install wget" ], "metadata": { "id": "pXr7oNJzwwgU" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "以下のリンク等を任意のものに差し替えてから、以下のコードを上から順番に両方とも実行" ], "metadata": { "id": "7Ils-K70k15Y" } }, { "cell_type": "code", "source": [ "#@title モデルをダウンロード\n", "#@markdown {Google Drive上のモデル名 or モデルのダウンロードリンク} をカンマ区切りで任意個指定\n", "#@markdown - Drive上のモデル名の場合...My Driveに対する相対パスで指定\n", "#@markdown - ダウンロードリンクの場合...Hugging Face等のダウンロードリンクを右クリック & リンクのアドレスをコピー & 下のリンクの代わりに貼り付け\n", "import shutil\n", "import urllib.parse\n", "import urllib.request\n", "import wget\n", "\n", "models = \"Specify_the_model_in_this_way_if_the_model_is_on_My_Drive.safetensors, https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt, https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n", "models = [m.strip() for m in models.split(\",\")]\n", "for model in models:\n", " if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n", " wget.download(model)\n", " elif model.endswith((\".ckpt\", \".safetensors\", \".yaml\")):\n", " shutil.copy(\"/content/drive/MyDrive/\" + model, \"/content/\" + model) # get the model from mydrive\n", " else:\n", " print(f\"\\\"{model}\\\"はURLではなく、正しい形式のファイルでもありません\")" ], "metadata": { "cellView": "form", "id": "4vd3A09AxJE0" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title モデルを変換\n", "#@markdown 変換するモデルをカンマ区切りで任意個指定
\n", "#@markdown 何も入力されていない場合は、読み込まれている全てのモデルが変換される\n", "import os\n", "import glob\n", "import torch\n", "import safetensors.torch\n", "from functools import partial\n", "\n", "models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n", "as_fp16 = True #@param {type:\"boolean\"}\n", "save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n", "save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n", "\n", "def convert_yaml(file_name):\n", " with open(file_name) as f:\n", " yaml = f.read()\n", " if save_directly_to_Google_Drive:\n", " os.chdir(\"/content/drive/MyDrive\")\n", " is_safe = save_type == \".safetensors\"\n", " yaml = yaml.replace(f\"use_checkpoint: {is_safe}\", f\"use_checkpoint: {not is_safe}\")\n", " if as_fp16:\n", " yaml = yaml.replace(\"use_fp16: False\", \"use_fp16: True\")\n", " file_name = os.path.splitext(file_name)[0] + \"-fp16.yaml\"\n", " with open(file_name, mode=\"w\") as f:\n", " f.write(yaml)\n", " os.chdir(\"/content\")\n", "\n", "if models == \"\":\n", " models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\")]\n", "else:\n", " models = [m.strip() for m in models.split(\",\")]\n", "\n", "for model in models:\n", " model_name, model_ext = os.path.splitext(model)\n", " if model_ext == \".yaml\":\n", " convert_yaml(model)\n", " elif (model_ext != \".safetensors\") & (model_ext != \".ckpt\"):\n", " print(\"対応形式は.ckpt及び.safetensors並びに.yamlのみです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n", " else:\n", " load_model = partial(safetensors.torch.load_file, device=\"cpu\") if model_ext == \".safetensors\" else partial(torch.load, map_location=torch.device(\"cpu\"))\n", " save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n", " # convert model\n", " with torch.no_grad():\n", " weights = load_model(model)\n", " if \"state_dict\" in weights:\n", " weights = weights[\"state_dict\"]\n", " if as_fp16:\n", " model_name = model_name + \"-fp16\"\n", " for key in weights.keys():\n", " weights[key] = weights[key].half()\n", " if save_directly_to_Google_Drive:\n", " os.chdir(\"/content/drive/MyDrive\")\n", " save_model(weights, model_name + save_type)\n", " os.chdir(\"/content\")\n", " del weights\n", "\n", "!reset" ], "metadata": { "id": "9OmSG98HxJg2", "cellView": "form" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "SD2.x系モデル等を変換する場合は、付属の設定ファイル (モデルと同名の.yamlファイル) も同時にダウンロード/変換しましょう\n", "\n", "指定方法はモデルと同じです" ], "metadata": { "id": "SWTFKmGFLec6" } }, { "cell_type": "markdown", "source": [ "メモリ不足でクラッシュする場合は、より小さいモデルを利用するか、有料のハイメモリランタイムを使用すること\n", "\n", "標準では10GBまでのモデルを変換できます" ], "metadata": { "id": "0SUK6Alv2ItS" } }, { "cell_type": "markdown", "source": [ "モデルのリンク集: https://huggingface.co/models?other=stable-diffusion 等から好きなモデルを選ぼう" ], "metadata": { "id": "yaLq5Nqe6an6" } } ] }