Spaces:
Runtime error
Runtime error
nikunjkdtechnoland
commited on
Commit
•
063372b
1
Parent(s):
526f250
init commit some files
Browse files- .gitignore +110 -0
- app.py +14 -0
- data/__init__.py +0 -0
- iopaint/__init__.py +23 -0
- iopaint/__main__.py +4 -0
- iopaint/api.py +396 -0
- iopaint/batch_processing.py +127 -0
- iopaint/benchmark.py +109 -0
- iopaint/file_manager/__init__.py +1 -0
- iopaint/model/__init__.py +37 -0
- iopaint/model/anytext/__init__.py +0 -0
- iopaint/model/anytext/anytext_model.py +73 -0
- iopaint/model/anytext/anytext_pipeline.py +403 -0
- iopaint/model/anytext/anytext_sd15.yaml +99 -0
- iopaint/model/anytext/cldm/__init__.py +0 -0
- iopaint/model/anytext/ldm/__init__.py +0 -0
- iopaint/model/anytext/ldm/models/__init__.py +0 -0
- iopaint/model/anytext/ldm/models/autoencoder.py +218 -0
- iopaint/model/anytext/ldm/models/diffusion/__init__.py +0 -0
- iopaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py +1 -0
- iopaint/model/anytext/ldm/modules/__init__.py +0 -0
- iopaint/model/anytext/ldm/modules/attention.py +360 -0
- iopaint/model/anytext/ldm/modules/diffusionmodules/__init__.py +0 -0
- iopaint/model/anytext/ldm/modules/distributions/__init__.py +0 -0
- iopaint/model/anytext/ldm/modules/encoders/__init__.py +0 -0
- iopaint/model/anytext/ocr_recog/__init__.py +0 -0
- iopaint/model/base.py +418 -0
- iopaint/model/helper/__init__.py +0 -0
- iopaint/model/original_sd_configs/__init__.py +19 -0
- iopaint/model/power_paint/__init__.py +0 -0
- iopaint/plugins/__init__.py +74 -0
- iopaint/plugins/anime_seg.py +462 -0
- iopaint/plugins/base_plugin.py +30 -0
- iopaint/plugins/segment_anything/__init__.py +14 -0
- iopaint/plugins/segment_anything/modeling/__init__.py +11 -0
- iopaint/plugins/segment_anything/utils/__init__.py +5 -0
- iopaint/tests/.gitignore +2 -0
- iopaint/tests/__init__.py +0 -0
- model/__init__.py +0 -0
- utils/__init__.py +0 -0
.gitignore
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Project ###
|
2 |
+
checkpoints/
|
3 |
+
pretrained-model/yolov8m-seg.pt
|
4 |
+
|
5 |
+
|
6 |
+
### Python ###
|
7 |
+
# Byte-compiled / optimized / DLL files
|
8 |
+
__pycache__/
|
9 |
+
*.py[cod]
|
10 |
+
*$py.class
|
11 |
+
|
12 |
+
# C extensions
|
13 |
+
*.so
|
14 |
+
|
15 |
+
# Distribution / packaging
|
16 |
+
.Python
|
17 |
+
build/
|
18 |
+
develop-eggs/
|
19 |
+
dist/
|
20 |
+
downloads/
|
21 |
+
eggs/
|
22 |
+
.eggs/
|
23 |
+
lib/
|
24 |
+
lib64/
|
25 |
+
parts/
|
26 |
+
sdist/
|
27 |
+
var/
|
28 |
+
wheels/
|
29 |
+
*.egg-info/
|
30 |
+
.installed.cfg
|
31 |
+
*.egg
|
32 |
+
MANIFEST
|
33 |
+
|
34 |
+
# PyInstaller
|
35 |
+
# Usually these files are written by a python script from a template
|
36 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
37 |
+
*.manifest
|
38 |
+
*.spec
|
39 |
+
|
40 |
+
# Installer logs
|
41 |
+
pip-log.txt
|
42 |
+
pip-delete-this-directory.txt
|
43 |
+
|
44 |
+
# Unit test / coverage reports
|
45 |
+
htmlcov/
|
46 |
+
.tox/
|
47 |
+
.coverage
|
48 |
+
.coverage.*
|
49 |
+
.cache
|
50 |
+
nosetests.xml
|
51 |
+
coverage.xml
|
52 |
+
*.cover
|
53 |
+
.hypothesis/
|
54 |
+
.pytest_cache/
|
55 |
+
|
56 |
+
# Translations
|
57 |
+
*.mo
|
58 |
+
*.pot
|
59 |
+
|
60 |
+
# Django stuff:
|
61 |
+
*.log
|
62 |
+
local_settings.py
|
63 |
+
db.sqlite3
|
64 |
+
|
65 |
+
# Flask stuff:
|
66 |
+
instance/
|
67 |
+
.webassets-cache
|
68 |
+
|
69 |
+
# Scrapy stuff:
|
70 |
+
.scrapy
|
71 |
+
|
72 |
+
# Sphinx documentation
|
73 |
+
docs/_build/
|
74 |
+
|
75 |
+
# PyBuilder
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# pyenv
|
82 |
+
.python-version
|
83 |
+
|
84 |
+
# celery beat schedule file
|
85 |
+
celerybeat-schedule
|
86 |
+
|
87 |
+
# SageMath parsed files
|
88 |
+
*.sage.py
|
89 |
+
|
90 |
+
# Environments
|
91 |
+
.env
|
92 |
+
.venv
|
93 |
+
env/
|
94 |
+
venv/
|
95 |
+
ENV/
|
96 |
+
env.bak/
|
97 |
+
venv.bak/
|
98 |
+
|
99 |
+
# Spyder project settings
|
100 |
+
.spyderproject
|
101 |
+
.spyproject
|
102 |
+
|
103 |
+
# Rope project settings
|
104 |
+
.ropeproject
|
105 |
+
|
106 |
+
# mkdocs documentation
|
107 |
+
/site
|
108 |
+
|
109 |
+
# mypy
|
110 |
+
.mypy_cache/
|
app.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from only_gradio_server import process_images
|
3 |
+
|
4 |
+
# Create Gradio interface
|
5 |
+
iface = gr.Interface(fn=process_images,
|
6 |
+
inputs=[gr.Image(type='filepath', label='Input Image 1'),
|
7 |
+
gr.Image(type='filepath', label='Input Image 2', image_mode="RGBA"),
|
8 |
+
gr.Textbox(label='Replace Object Name')],
|
9 |
+
outputs='image',
|
10 |
+
title="Image Processing",
|
11 |
+
description="Object to Object Replacement")
|
12 |
+
|
13 |
+
# Launch Gradio interface
|
14 |
+
iface.launch()
|
data/__init__.py
ADDED
File without changes
|
iopaint/__init__.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
4 |
+
# https://github.com/pytorch/pytorch/issues/27971#issuecomment-1768868068
|
5 |
+
os.environ["ONEDNN_PRIMITIVE_CACHE_CAPACITY"] = "1"
|
6 |
+
os.environ["LRU_CACHE_CAPACITY"] = "1"
|
7 |
+
# prevent CPU memory leak when run model on GPU
|
8 |
+
# https://github.com/pytorch/pytorch/issues/98688#issuecomment-1869288431
|
9 |
+
# https://github.com/pytorch/pytorch/issues/108334#issuecomment-1752763633
|
10 |
+
os.environ["TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT"] = "1"
|
11 |
+
|
12 |
+
|
13 |
+
import warnings
|
14 |
+
|
15 |
+
warnings.simplefilter("ignore", UserWarning)
|
16 |
+
|
17 |
+
|
18 |
+
def entry_point():
|
19 |
+
# To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
|
20 |
+
# https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
|
21 |
+
from iopaint.cli import typer_app
|
22 |
+
|
23 |
+
typer_app()
|
iopaint/__main__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from iopaint import entry_point
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
entry_point()
|
iopaint/api.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import os
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
import traceback
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Optional, Dict, List
|
8 |
+
|
9 |
+
import cv2
|
10 |
+
import numpy as np
|
11 |
+
import socketio
|
12 |
+
import torch
|
13 |
+
|
14 |
+
try:
|
15 |
+
torch._C._jit_override_can_fuse_on_cpu(False)
|
16 |
+
torch._C._jit_override_can_fuse_on_gpu(False)
|
17 |
+
torch._C._jit_set_texpr_fuser_enabled(False)
|
18 |
+
torch._C._jit_set_nvfuser_enabled(False)
|
19 |
+
except:
|
20 |
+
pass
|
21 |
+
|
22 |
+
|
23 |
+
import uvicorn
|
24 |
+
from PIL import Image
|
25 |
+
from fastapi import APIRouter, FastAPI, Request, UploadFile
|
26 |
+
from fastapi.encoders import jsonable_encoder
|
27 |
+
from fastapi.exceptions import HTTPException
|
28 |
+
from fastapi.middleware.cors import CORSMiddleware
|
29 |
+
from fastapi.responses import JSONResponse, FileResponse, Response
|
30 |
+
from fastapi.staticfiles import StaticFiles
|
31 |
+
from loguru import logger
|
32 |
+
from socketio import AsyncServer
|
33 |
+
|
34 |
+
from iopaint.file_manager import FileManager
|
35 |
+
from iopaint.helper import (
|
36 |
+
load_img,
|
37 |
+
decode_base64_to_image,
|
38 |
+
pil_to_bytes,
|
39 |
+
numpy_to_bytes,
|
40 |
+
concat_alpha_channel,
|
41 |
+
gen_frontend_mask,
|
42 |
+
adjust_mask,
|
43 |
+
)
|
44 |
+
from iopaint.model.utils import torch_gc
|
45 |
+
from iopaint.model_manager import ModelManager
|
46 |
+
from iopaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg
|
47 |
+
from iopaint.plugins.base_plugin import BasePlugin
|
48 |
+
from iopaint.plugins.remove_bg import RemoveBG
|
49 |
+
from iopaint.schema import (
|
50 |
+
GenInfoResponse,
|
51 |
+
ApiConfig,
|
52 |
+
ServerConfigResponse,
|
53 |
+
SwitchModelRequest,
|
54 |
+
InpaintRequest,
|
55 |
+
RunPluginRequest,
|
56 |
+
SDSampler,
|
57 |
+
PluginInfo,
|
58 |
+
AdjustMaskRequest,
|
59 |
+
RemoveBGModel,
|
60 |
+
SwitchPluginModelRequest,
|
61 |
+
ModelInfo,
|
62 |
+
InteractiveSegModel,
|
63 |
+
RealESRGANModel,
|
64 |
+
)
|
65 |
+
|
66 |
+
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
|
67 |
+
WEB_APP_DIR = CURRENT_DIR / "web_app"
|
68 |
+
|
69 |
+
|
70 |
+
def api_middleware(app: FastAPI):
|
71 |
+
rich_available = False
|
72 |
+
try:
|
73 |
+
if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
|
74 |
+
import anyio # importing just so it can be placed on silent list
|
75 |
+
import starlette # importing just so it can be placed on silent list
|
76 |
+
from rich.console import Console
|
77 |
+
|
78 |
+
console = Console()
|
79 |
+
rich_available = True
|
80 |
+
except Exception:
|
81 |
+
pass
|
82 |
+
|
83 |
+
def handle_exception(request: Request, e: Exception):
|
84 |
+
err = {
|
85 |
+
"error": type(e).__name__,
|
86 |
+
"detail": vars(e).get("detail", ""),
|
87 |
+
"body": vars(e).get("body", ""),
|
88 |
+
"errors": str(e),
|
89 |
+
}
|
90 |
+
if not isinstance(
|
91 |
+
e, HTTPException
|
92 |
+
): # do not print backtrace on known httpexceptions
|
93 |
+
message = f"API error: {request.method}: {request.url} {err}"
|
94 |
+
if rich_available:
|
95 |
+
print(message)
|
96 |
+
console.print_exception(
|
97 |
+
show_locals=True,
|
98 |
+
max_frames=2,
|
99 |
+
extra_lines=1,
|
100 |
+
suppress=[anyio, starlette],
|
101 |
+
word_wrap=False,
|
102 |
+
width=min([console.width, 200]),
|
103 |
+
)
|
104 |
+
else:
|
105 |
+
traceback.print_exc()
|
106 |
+
return JSONResponse(
|
107 |
+
status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
|
108 |
+
)
|
109 |
+
|
110 |
+
@app.middleware("http")
|
111 |
+
async def exception_handling(request: Request, call_next):
|
112 |
+
try:
|
113 |
+
return await call_next(request)
|
114 |
+
except Exception as e:
|
115 |
+
return handle_exception(request, e)
|
116 |
+
|
117 |
+
@app.exception_handler(Exception)
|
118 |
+
async def fastapi_exception_handler(request: Request, e: Exception):
|
119 |
+
return handle_exception(request, e)
|
120 |
+
|
121 |
+
@app.exception_handler(HTTPException)
|
122 |
+
async def http_exception_handler(request: Request, e: HTTPException):
|
123 |
+
return handle_exception(request, e)
|
124 |
+
|
125 |
+
cors_options = {
|
126 |
+
"allow_methods": ["*"],
|
127 |
+
"allow_headers": ["*"],
|
128 |
+
"allow_origins": ["*"],
|
129 |
+
"allow_credentials": True,
|
130 |
+
}
|
131 |
+
app.add_middleware(CORSMiddleware, **cors_options)
|
132 |
+
|
133 |
+
|
134 |
+
global_sio: AsyncServer = None
|
135 |
+
|
136 |
+
|
137 |
+
def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict = {}):
|
138 |
+
# self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict
|
139 |
+
# logger.info(f"diffusion callback: step={step}, timestep={timestep}")
|
140 |
+
|
141 |
+
# We use asyncio loos for task processing. Perhaps in the future, we can add a processing queue similar to InvokeAI,
|
142 |
+
# but for now let's just start a separate event loop. It shouldn't make a difference for single person use
|
143 |
+
asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
|
144 |
+
return {}
|
145 |
+
|
146 |
+
|
147 |
+
class Api:
|
148 |
+
def __init__(self, app: FastAPI, config: ApiConfig):
|
149 |
+
self.app = app
|
150 |
+
self.config = config
|
151 |
+
self.router = APIRouter()
|
152 |
+
self.queue_lock = threading.Lock()
|
153 |
+
api_middleware(self.app)
|
154 |
+
|
155 |
+
self.file_manager = self._build_file_manager()
|
156 |
+
self.plugins = self._build_plugins()
|
157 |
+
self.model_manager = self._build_model_manager()
|
158 |
+
|
159 |
+
# fmt: off
|
160 |
+
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
|
161 |
+
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
|
162 |
+
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
|
163 |
+
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
164 |
+
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
165 |
+
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
|
166 |
+
self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
|
167 |
+
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
168 |
+
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
169 |
+
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
170 |
+
self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
|
171 |
+
self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
|
172 |
+
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
173 |
+
# fmt: on
|
174 |
+
|
175 |
+
global global_sio
|
176 |
+
self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
177 |
+
self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
|
178 |
+
self.app.mount("/ws", self.combined_asgi_app)
|
179 |
+
global_sio = self.sio
|
180 |
+
|
181 |
+
def add_api_route(self, path: str, endpoint, **kwargs):
|
182 |
+
return self.app.add_api_route(path, endpoint, **kwargs)
|
183 |
+
|
184 |
+
def api_save_image(self, file: UploadFile):
|
185 |
+
filename = file.filename
|
186 |
+
origin_image_bytes = file.file.read()
|
187 |
+
with open(self.config.output_dir / filename, "wb") as fw:
|
188 |
+
fw.write(origin_image_bytes)
|
189 |
+
|
190 |
+
def api_current_model(self) -> ModelInfo:
|
191 |
+
return self.model_manager.current_model
|
192 |
+
|
193 |
+
def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
|
194 |
+
if req.name == self.model_manager.name:
|
195 |
+
return self.model_manager.current_model
|
196 |
+
self.model_manager.switch(req.name)
|
197 |
+
return self.model_manager.current_model
|
198 |
+
|
199 |
+
def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
|
200 |
+
if req.plugin_name in self.plugins:
|
201 |
+
self.plugins[req.plugin_name].switch_model(req.model_name)
|
202 |
+
if req.plugin_name == RemoveBG.name:
|
203 |
+
self.config.remove_bg_model = req.model_name
|
204 |
+
if req.plugin_name == RealESRGANUpscaler.name:
|
205 |
+
self.config.realesrgan_model = req.model_name
|
206 |
+
if req.plugin_name == InteractiveSeg.name:
|
207 |
+
self.config.interactive_seg_model = req.model_name
|
208 |
+
torch_gc()
|
209 |
+
|
210 |
+
def api_server_config(self) -> ServerConfigResponse:
|
211 |
+
plugins = []
|
212 |
+
for it in self.plugins.values():
|
213 |
+
plugins.append(
|
214 |
+
PluginInfo(
|
215 |
+
name=it.name,
|
216 |
+
support_gen_image=it.support_gen_image,
|
217 |
+
support_gen_mask=it.support_gen_mask,
|
218 |
+
)
|
219 |
+
)
|
220 |
+
|
221 |
+
return ServerConfigResponse(
|
222 |
+
plugins=plugins,
|
223 |
+
modelInfos=self.model_manager.scan_models(),
|
224 |
+
removeBGModel=self.config.remove_bg_model,
|
225 |
+
removeBGModels=RemoveBGModel.values(),
|
226 |
+
realesrganModel=self.config.realesrgan_model,
|
227 |
+
realesrganModels=RealESRGANModel.values(),
|
228 |
+
interactiveSegModel=self.config.interactive_seg_model,
|
229 |
+
interactiveSegModels=InteractiveSegModel.values(),
|
230 |
+
enableFileManager=self.file_manager is not None,
|
231 |
+
enableAutoSaving=self.config.output_dir is not None,
|
232 |
+
enableControlnet=self.model_manager.enable_controlnet,
|
233 |
+
controlnetMethod=self.model_manager.controlnet_method,
|
234 |
+
disableModelSwitch=False,
|
235 |
+
isDesktop=False,
|
236 |
+
samplers=self.api_samplers(),
|
237 |
+
)
|
238 |
+
|
239 |
+
def api_input_image(self) -> FileResponse:
|
240 |
+
if self.config.input and self.config.input.is_file():
|
241 |
+
return FileResponse(self.config.input)
|
242 |
+
raise HTTPException(status_code=404, detail="Input image not found")
|
243 |
+
|
244 |
+
def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
|
245 |
+
_, _, info = load_img(file.file.read(), return_info=True)
|
246 |
+
parts = info.get("parameters", "").split("Negative prompt: ")
|
247 |
+
prompt = parts[0].strip()
|
248 |
+
negative_prompt = ""
|
249 |
+
if len(parts) > 1:
|
250 |
+
negative_prompt = parts[1].split("\n")[0].strip()
|
251 |
+
return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
|
252 |
+
|
253 |
+
def api_inpaint(self, req: InpaintRequest):
|
254 |
+
image, alpha_channel, infos = decode_base64_to_image(req.image)
|
255 |
+
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
256 |
+
|
257 |
+
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
258 |
+
if image.shape[:2] != mask.shape[:2]:
|
259 |
+
raise HTTPException(
|
260 |
+
400,
|
261 |
+
detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
|
262 |
+
)
|
263 |
+
|
264 |
+
if req.paint_by_example_example_image:
|
265 |
+
paint_by_example_image, _, _ = decode_base64_to_image(
|
266 |
+
req.paint_by_example_example_image
|
267 |
+
)
|
268 |
+
|
269 |
+
start = time.time()
|
270 |
+
rgb_np_img = self.model_manager(image, mask, req)
|
271 |
+
logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
|
272 |
+
torch_gc()
|
273 |
+
|
274 |
+
rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
275 |
+
rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
|
276 |
+
|
277 |
+
ext = "png"
|
278 |
+
res_img_bytes = pil_to_bytes(
|
279 |
+
Image.fromarray(rgb_res),
|
280 |
+
ext=ext,
|
281 |
+
quality=self.config.quality,
|
282 |
+
infos=infos,
|
283 |
+
)
|
284 |
+
|
285 |
+
asyncio.run(self.sio.emit("diffusion_finish"))
|
286 |
+
|
287 |
+
return Response(
|
288 |
+
content=res_img_bytes,
|
289 |
+
media_type=f"image/{ext}",
|
290 |
+
headers={"X-Seed": str(req.sd_seed)},
|
291 |
+
)
|
292 |
+
|
293 |
+
def api_run_plugin_gen_image(self, req: RunPluginRequest):
|
294 |
+
ext = "png"
|
295 |
+
if req.name not in self.plugins:
|
296 |
+
raise HTTPException(status_code=422, detail="Plugin not found")
|
297 |
+
if not self.plugins[req.name].support_gen_image:
|
298 |
+
raise HTTPException(
|
299 |
+
status_code=422, detail="Plugin does not support output image"
|
300 |
+
)
|
301 |
+
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
302 |
+
bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
|
303 |
+
torch_gc()
|
304 |
+
|
305 |
+
if bgr_or_rgba_np_img.shape[2] == 4:
|
306 |
+
rgba_np_img = bgr_or_rgba_np_img
|
307 |
+
else:
|
308 |
+
rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
|
309 |
+
rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
|
310 |
+
|
311 |
+
return Response(
|
312 |
+
content=pil_to_bytes(
|
313 |
+
Image.fromarray(rgba_np_img),
|
314 |
+
ext=ext,
|
315 |
+
quality=self.config.quality,
|
316 |
+
infos=infos,
|
317 |
+
),
|
318 |
+
media_type=f"image/{ext}",
|
319 |
+
)
|
320 |
+
|
321 |
+
def api_run_plugin_gen_mask(self, req: RunPluginRequest):
|
322 |
+
if req.name not in self.plugins:
|
323 |
+
raise HTTPException(status_code=422, detail="Plugin not found")
|
324 |
+
if not self.plugins[req.name].support_gen_mask:
|
325 |
+
raise HTTPException(
|
326 |
+
status_code=422, detail="Plugin does not support output image"
|
327 |
+
)
|
328 |
+
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
329 |
+
bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
|
330 |
+
torch_gc()
|
331 |
+
res_mask = gen_frontend_mask(bgr_or_gray_mask)
|
332 |
+
return Response(
|
333 |
+
content=numpy_to_bytes(res_mask, "png"),
|
334 |
+
media_type="image/png",
|
335 |
+
)
|
336 |
+
|
337 |
+
def api_samplers(self) -> List[str]:
|
338 |
+
return [member.value for member in SDSampler.__members__.values()]
|
339 |
+
|
340 |
+
def api_adjust_mask(self, req: AdjustMaskRequest):
|
341 |
+
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
342 |
+
mask = adjust_mask(mask, req.kernel_size, req.operate)
|
343 |
+
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
|
344 |
+
|
345 |
+
def launch(self):
|
346 |
+
self.app.include_router(self.router)
|
347 |
+
uvicorn.run(
|
348 |
+
self.combined_asgi_app,
|
349 |
+
host=self.config.host,
|
350 |
+
port=self.config.port,
|
351 |
+
timeout_keep_alive=999999999,
|
352 |
+
)
|
353 |
+
|
354 |
+
def _build_file_manager(self) -> Optional[FileManager]:
|
355 |
+
if self.config.input and self.config.input.is_dir():
|
356 |
+
logger.info(
|
357 |
+
f"Input is directory, initialize file manager {self.config.input}"
|
358 |
+
)
|
359 |
+
|
360 |
+
return FileManager(
|
361 |
+
app=self.app,
|
362 |
+
input_dir=self.config.input,
|
363 |
+
output_dir=self.config.output_dir,
|
364 |
+
)
|
365 |
+
return None
|
366 |
+
|
367 |
+
def _build_plugins(self) -> Dict[str, BasePlugin]:
|
368 |
+
return build_plugins(
|
369 |
+
self.config.enable_interactive_seg,
|
370 |
+
self.config.interactive_seg_model,
|
371 |
+
self.config.interactive_seg_device,
|
372 |
+
self.config.enable_remove_bg,
|
373 |
+
self.config.remove_bg_model,
|
374 |
+
self.config.enable_anime_seg,
|
375 |
+
self.config.enable_realesrgan,
|
376 |
+
self.config.realesrgan_device,
|
377 |
+
self.config.realesrgan_model,
|
378 |
+
self.config.enable_gfpgan,
|
379 |
+
self.config.gfpgan_device,
|
380 |
+
self.config.enable_restoreformer,
|
381 |
+
self.config.restoreformer_device,
|
382 |
+
self.config.no_half,
|
383 |
+
)
|
384 |
+
|
385 |
+
def _build_model_manager(self):
|
386 |
+
return ModelManager(
|
387 |
+
name=self.config.model,
|
388 |
+
device=torch.device(self.config.device),
|
389 |
+
no_half=self.config.no_half,
|
390 |
+
low_mem=self.config.low_mem,
|
391 |
+
disable_nsfw=self.config.disable_nsfw_checker,
|
392 |
+
sd_cpu_textencoder=self.config.cpu_textencoder,
|
393 |
+
local_files_only=self.config.local_files_only,
|
394 |
+
cpu_offload=self.config.cpu_offload,
|
395 |
+
callback=diffuser_callback,
|
396 |
+
)
|
iopaint/batch_processing.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Dict, Optional
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import psutil
|
7 |
+
from PIL import Image
|
8 |
+
from loguru import logger
|
9 |
+
from rich.console import Console
|
10 |
+
from rich.progress import (
|
11 |
+
Progress,
|
12 |
+
SpinnerColumn,
|
13 |
+
TimeElapsedColumn,
|
14 |
+
MofNCompleteColumn,
|
15 |
+
TextColumn,
|
16 |
+
BarColumn,
|
17 |
+
TaskProgressColumn,
|
18 |
+
)
|
19 |
+
|
20 |
+
from iopaint.helper import pil_to_bytes
|
21 |
+
from iopaint.model.utils import torch_gc
|
22 |
+
from iopaint.model_manager import ModelManager
|
23 |
+
from iopaint.schema import InpaintRequest
|
24 |
+
|
25 |
+
|
26 |
+
def glob_images(path: Path) -> Dict[str, Path]:
|
27 |
+
# png/jpg/jpeg
|
28 |
+
if path.is_file():
|
29 |
+
return {path.stem: path}
|
30 |
+
elif path.is_dir():
|
31 |
+
res = {}
|
32 |
+
for it in path.glob("*.*"):
|
33 |
+
if it.suffix.lower() in [".png", ".jpg", ".jpeg"]:
|
34 |
+
res[it.stem] = it
|
35 |
+
return res
|
36 |
+
|
37 |
+
|
38 |
+
def batch_inpaint(
|
39 |
+
model: str,
|
40 |
+
device,
|
41 |
+
image: Path,
|
42 |
+
mask: Path,
|
43 |
+
output: Path,
|
44 |
+
config: Optional[Path] = None,
|
45 |
+
concat: bool = False,
|
46 |
+
):
|
47 |
+
if image.is_dir() and output.is_file():
|
48 |
+
logger.error(
|
49 |
+
f"invalid --output: when image is a directory, output should be a directory"
|
50 |
+
)
|
51 |
+
exit(-1)
|
52 |
+
output.mkdir(parents=True, exist_ok=True)
|
53 |
+
|
54 |
+
image_paths = glob_images(image)
|
55 |
+
mask_paths = glob_images(mask)
|
56 |
+
if len(image_paths) == 0:
|
57 |
+
logger.error(f"invalid --image: empty image folder")
|
58 |
+
exit(-1)
|
59 |
+
if len(mask_paths) == 0:
|
60 |
+
logger.error(f"invalid --mask: empty mask folder")
|
61 |
+
exit(-1)
|
62 |
+
|
63 |
+
if config is None:
|
64 |
+
inpaint_request = InpaintRequest()
|
65 |
+
logger.info(f"Using default config: {inpaint_request}")
|
66 |
+
else:
|
67 |
+
with open(config, "r", encoding="utf-8") as f:
|
68 |
+
inpaint_request = InpaintRequest(**json.load(f))
|
69 |
+
|
70 |
+
model_manager = ModelManager(name=model, device=device)
|
71 |
+
first_mask = list(mask_paths.values())[0]
|
72 |
+
|
73 |
+
console = Console()
|
74 |
+
|
75 |
+
with Progress(
|
76 |
+
SpinnerColumn(),
|
77 |
+
TextColumn("[progress.description]{task.description}"),
|
78 |
+
BarColumn(),
|
79 |
+
TaskProgressColumn(),
|
80 |
+
MofNCompleteColumn(),
|
81 |
+
TimeElapsedColumn(),
|
82 |
+
console=console,
|
83 |
+
transient=False,
|
84 |
+
) as progress:
|
85 |
+
task = progress.add_task("Batch processing...", total=len(image_paths))
|
86 |
+
for stem, image_p in image_paths.items():
|
87 |
+
if stem not in mask_paths and mask.is_dir():
|
88 |
+
progress.log(f"mask for {image_p} not found")
|
89 |
+
progress.update(task, advance=1)
|
90 |
+
continue
|
91 |
+
mask_p = mask_paths.get(stem, first_mask)
|
92 |
+
|
93 |
+
infos = Image.open(image_p).info
|
94 |
+
|
95 |
+
img = cv2.imread(str(image_p))
|
96 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
97 |
+
mask_img = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
98 |
+
if mask_img.shape[:2] != img.shape[:2]:
|
99 |
+
progress.log(
|
100 |
+
f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
|
101 |
+
)
|
102 |
+
mask_img = cv2.resize(
|
103 |
+
mask_img,
|
104 |
+
(img.shape[1], img.shape[0]),
|
105 |
+
interpolation=cv2.INTER_NEAREST,
|
106 |
+
)
|
107 |
+
mask_img[mask_img >= 127] = 255
|
108 |
+
mask_img[mask_img < 127] = 0
|
109 |
+
|
110 |
+
# bgr
|
111 |
+
inpaint_result = model_manager(img, mask_img, inpaint_request)
|
112 |
+
inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
|
113 |
+
if concat:
|
114 |
+
mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
|
115 |
+
inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])
|
116 |
+
|
117 |
+
img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
|
118 |
+
save_p = output / f"{stem}.png"
|
119 |
+
with open(save_p, "wb") as fw:
|
120 |
+
fw.write(img_bytes)
|
121 |
+
|
122 |
+
progress.update(task, advance=1)
|
123 |
+
torch_gc()
|
124 |
+
# pid = psutil.Process().pid
|
125 |
+
# memory_info = psutil.Process(pid).memory_info()
|
126 |
+
# memory_in_mb = memory_info.rss / (1024 * 1024)
|
127 |
+
# print(f"原图大小:{img.shape},当前进程的内存占用:{memory_in_mb}MB")
|
iopaint/benchmark.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import nvidia_smi
|
9 |
+
import psutil
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from iopaint.model_manager import ModelManager
|
13 |
+
from iopaint.schema import InpaintRequest, HDStrategy, SDSampler
|
14 |
+
|
15 |
+
try:
|
16 |
+
torch._C._jit_override_can_fuse_on_cpu(False)
|
17 |
+
torch._C._jit_override_can_fuse_on_gpu(False)
|
18 |
+
torch._C._jit_set_texpr_fuser_enabled(False)
|
19 |
+
torch._C._jit_set_nvfuser_enabled(False)
|
20 |
+
except:
|
21 |
+
pass
|
22 |
+
|
23 |
+
NUM_THREADS = str(4)
|
24 |
+
|
25 |
+
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
26 |
+
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
|
27 |
+
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
|
28 |
+
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
|
29 |
+
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
|
30 |
+
if os.environ.get("CACHE_DIR"):
|
31 |
+
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
32 |
+
|
33 |
+
|
34 |
+
def run_model(model, size):
|
35 |
+
# RGB
|
36 |
+
image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
|
37 |
+
mask = np.random.randint(0, 255, size).astype(np.uint8)
|
38 |
+
|
39 |
+
config = InpaintRequest(
|
40 |
+
ldm_steps=2,
|
41 |
+
hd_strategy=HDStrategy.ORIGINAL,
|
42 |
+
hd_strategy_crop_margin=128,
|
43 |
+
hd_strategy_crop_trigger_size=128,
|
44 |
+
hd_strategy_resize_limit=128,
|
45 |
+
prompt="a fox is sitting on a bench",
|
46 |
+
sd_steps=5,
|
47 |
+
sd_sampler=SDSampler.ddim,
|
48 |
+
)
|
49 |
+
model(image, mask, config)
|
50 |
+
|
51 |
+
|
52 |
+
def benchmark(model, times: int, empty_cache: bool):
|
53 |
+
sizes = [(512, 512)]
|
54 |
+
|
55 |
+
nvidia_smi.nvmlInit()
|
56 |
+
device_id = 0
|
57 |
+
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id)
|
58 |
+
|
59 |
+
def format(metrics):
|
60 |
+
return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}"
|
61 |
+
|
62 |
+
process = psutil.Process(os.getpid())
|
63 |
+
# 每个 size 给出显存和内存占用的指标
|
64 |
+
for size in sizes:
|
65 |
+
torch.cuda.empty_cache()
|
66 |
+
time_metrics = []
|
67 |
+
cpu_metrics = []
|
68 |
+
memory_metrics = []
|
69 |
+
gpu_memory_metrics = []
|
70 |
+
for _ in range(times):
|
71 |
+
start = time.time()
|
72 |
+
run_model(model, size)
|
73 |
+
torch.cuda.synchronize()
|
74 |
+
|
75 |
+
# cpu_metrics.append(process.cpu_percent())
|
76 |
+
time_metrics.append((time.time() - start) * 1000)
|
77 |
+
memory_metrics.append(process.memory_info().rss / 1024 / 1024)
|
78 |
+
gpu_memory_metrics.append(
|
79 |
+
nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024
|
80 |
+
)
|
81 |
+
|
82 |
+
print(f"size: {size}".center(80, "-"))
|
83 |
+
# print(f"cpu: {format(cpu_metrics)}")
|
84 |
+
print(f"latency: {format(time_metrics)}ms")
|
85 |
+
print(f"memory: {format(memory_metrics)} MB")
|
86 |
+
print(f"gpu memory: {format(gpu_memory_metrics)} MB")
|
87 |
+
|
88 |
+
nvidia_smi.nvmlShutdown()
|
89 |
+
|
90 |
+
|
91 |
+
def get_args_parser():
|
92 |
+
parser = argparse.ArgumentParser()
|
93 |
+
parser.add_argument("--name")
|
94 |
+
parser.add_argument("--device", default="cuda", type=str)
|
95 |
+
parser.add_argument("--times", default=10, type=int)
|
96 |
+
parser.add_argument("--empty-cache", action="store_true")
|
97 |
+
return parser.parse_args()
|
98 |
+
|
99 |
+
|
100 |
+
if __name__ == "__main__":
|
101 |
+
args = get_args_parser()
|
102 |
+
device = torch.device(args.device)
|
103 |
+
model = ModelManager(
|
104 |
+
name=args.name,
|
105 |
+
device=device,
|
106 |
+
disable_nsfw=True,
|
107 |
+
sd_cpu_textencoder=True,
|
108 |
+
)
|
109 |
+
benchmark(model, args.times, args.empty_cache)
|
iopaint/file_manager/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .file_manager import FileManager
|
iopaint/model/__init__.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .anytext.anytext_model import AnyText
|
2 |
+
from .controlnet import ControlNet
|
3 |
+
from .fcf import FcF
|
4 |
+
from .instruct_pix2pix import InstructPix2Pix
|
5 |
+
from .kandinsky import Kandinsky22
|
6 |
+
from .lama import LaMa
|
7 |
+
from .ldm import LDM
|
8 |
+
from .manga import Manga
|
9 |
+
from .mat import MAT
|
10 |
+
from .mi_gan import MIGAN
|
11 |
+
from .opencv2 import OpenCV2
|
12 |
+
from .paint_by_example import PaintByExample
|
13 |
+
from .power_paint.power_paint import PowerPaint
|
14 |
+
from .sd import SD15, SD2, Anything4, RealisticVision14, SD
|
15 |
+
from .sdxl import SDXL
|
16 |
+
from .zits import ZITS
|
17 |
+
|
18 |
+
models = {
|
19 |
+
LaMa.name: LaMa,
|
20 |
+
LDM.name: LDM,
|
21 |
+
ZITS.name: ZITS,
|
22 |
+
MAT.name: MAT,
|
23 |
+
FcF.name: FcF,
|
24 |
+
OpenCV2.name: OpenCV2,
|
25 |
+
Manga.name: Manga,
|
26 |
+
MIGAN.name: MIGAN,
|
27 |
+
SD15.name: SD15,
|
28 |
+
Anything4.name: Anything4,
|
29 |
+
RealisticVision14.name: RealisticVision14,
|
30 |
+
SD2.name: SD2,
|
31 |
+
PaintByExample.name: PaintByExample,
|
32 |
+
InstructPix2Pix.name: InstructPix2Pix,
|
33 |
+
Kandinsky22.name: Kandinsky22,
|
34 |
+
SDXL.name: SDXL,
|
35 |
+
PowerPaint.name: PowerPaint,
|
36 |
+
AnyText.name: AnyText,
|
37 |
+
}
|
iopaint/model/anytext/__init__.py
ADDED
File without changes
|
iopaint/model/anytext/anytext_model.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from huggingface_hub import hf_hub_download
|
3 |
+
|
4 |
+
from iopaint.const import ANYTEXT_NAME
|
5 |
+
from iopaint.model.anytext.anytext_pipeline import AnyTextPipeline
|
6 |
+
from iopaint.model.base import DiffusionInpaintModel
|
7 |
+
from iopaint.model.utils import get_torch_dtype, is_local_files_only
|
8 |
+
from iopaint.schema import InpaintRequest
|
9 |
+
|
10 |
+
|
11 |
+
class AnyText(DiffusionInpaintModel):
|
12 |
+
name = ANYTEXT_NAME
|
13 |
+
pad_mod = 64
|
14 |
+
is_erase_model = False
|
15 |
+
|
16 |
+
@staticmethod
|
17 |
+
def download(local_files_only=False):
|
18 |
+
hf_hub_download(
|
19 |
+
repo_id=ANYTEXT_NAME,
|
20 |
+
filename="model_index.json",
|
21 |
+
local_files_only=local_files_only,
|
22 |
+
)
|
23 |
+
ckpt_path = hf_hub_download(
|
24 |
+
repo_id=ANYTEXT_NAME,
|
25 |
+
filename="pytorch_model.fp16.safetensors",
|
26 |
+
local_files_only=local_files_only,
|
27 |
+
)
|
28 |
+
font_path = hf_hub_download(
|
29 |
+
repo_id=ANYTEXT_NAME,
|
30 |
+
filename="SourceHanSansSC-Medium.otf",
|
31 |
+
local_files_only=local_files_only,
|
32 |
+
)
|
33 |
+
return ckpt_path, font_path
|
34 |
+
|
35 |
+
def init_model(self, device, **kwargs):
|
36 |
+
local_files_only = is_local_files_only(**kwargs)
|
37 |
+
ckpt_path, font_path = self.download(local_files_only)
|
38 |
+
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
39 |
+
self.model = AnyTextPipeline(
|
40 |
+
ckpt_path=ckpt_path,
|
41 |
+
font_path=font_path,
|
42 |
+
device=device,
|
43 |
+
use_fp16=torch_dtype == torch.float16,
|
44 |
+
)
|
45 |
+
self.callback = kwargs.pop("callback", None)
|
46 |
+
|
47 |
+
def forward(self, image, mask, config: InpaintRequest):
|
48 |
+
"""Input image and output image have same size
|
49 |
+
image: [H, W, C] RGB
|
50 |
+
mask: [H, W, 1] 255 means area to inpainting
|
51 |
+
return: BGR IMAGE
|
52 |
+
"""
|
53 |
+
height, width = image.shape[:2]
|
54 |
+
mask = mask.astype("float32") / 255.0
|
55 |
+
masked_image = image * (1 - mask)
|
56 |
+
|
57 |
+
# list of rgb ndarray
|
58 |
+
results, rtn_code, rtn_warning = self.model(
|
59 |
+
image=image,
|
60 |
+
masked_image=masked_image,
|
61 |
+
prompt=config.prompt,
|
62 |
+
negative_prompt=config.negative_prompt,
|
63 |
+
num_inference_steps=config.sd_steps,
|
64 |
+
strength=config.sd_strength,
|
65 |
+
guidance_scale=config.sd_guidance_scale,
|
66 |
+
height=height,
|
67 |
+
width=width,
|
68 |
+
seed=config.sd_seed,
|
69 |
+
sort_priority="y",
|
70 |
+
callback=self.callback
|
71 |
+
)
|
72 |
+
inpainted_rgb_image = results[0][..., ::-1]
|
73 |
+
return inpainted_rgb_image
|
iopaint/model/anytext/anytext_pipeline.py
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
AnyText: Multilingual Visual Text Generation And Editing
|
3 |
+
Paper: https://arxiv.org/abs/2311.03054
|
4 |
+
Code: https://github.com/tyxsspa/AnyText
|
5 |
+
Copyright (c) Alibaba, Inc. and its affiliates.
|
6 |
+
"""
|
7 |
+
import os
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
from iopaint.model.utils import set_seed
|
11 |
+
from safetensors.torch import load_file
|
12 |
+
|
13 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
14 |
+
import torch
|
15 |
+
import re
|
16 |
+
import numpy as np
|
17 |
+
import cv2
|
18 |
+
import einops
|
19 |
+
from PIL import ImageFont
|
20 |
+
from iopaint.model.anytext.cldm.model import create_model, load_state_dict
|
21 |
+
from iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
|
22 |
+
from iopaint.model.anytext.utils import (
|
23 |
+
check_channels,
|
24 |
+
draw_glyph,
|
25 |
+
draw_glyph2,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
BBOX_MAX_NUM = 8
|
30 |
+
PLACE_HOLDER = "*"
|
31 |
+
max_chars = 20
|
32 |
+
|
33 |
+
ANYTEXT_CFG = os.path.join(
|
34 |
+
os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
def check_limits(tensor):
|
39 |
+
float16_min = torch.finfo(torch.float16).min
|
40 |
+
float16_max = torch.finfo(torch.float16).max
|
41 |
+
|
42 |
+
# 检查张量中是否有值小于float16的最小值或大于float16的最大值
|
43 |
+
is_below_min = (tensor < float16_min).any()
|
44 |
+
is_above_max = (tensor > float16_max).any()
|
45 |
+
|
46 |
+
return is_below_min or is_above_max
|
47 |
+
|
48 |
+
|
49 |
+
class AnyTextPipeline:
|
50 |
+
def __init__(self, ckpt_path, font_path, device, use_fp16=True):
|
51 |
+
self.cfg_path = ANYTEXT_CFG
|
52 |
+
self.font_path = font_path
|
53 |
+
self.use_fp16 = use_fp16
|
54 |
+
self.device = device
|
55 |
+
|
56 |
+
self.font = ImageFont.truetype(font_path, size=60)
|
57 |
+
self.model = create_model(
|
58 |
+
self.cfg_path,
|
59 |
+
device=self.device,
|
60 |
+
use_fp16=self.use_fp16,
|
61 |
+
)
|
62 |
+
if self.use_fp16:
|
63 |
+
self.model = self.model.half()
|
64 |
+
if Path(ckpt_path).suffix == ".safetensors":
|
65 |
+
state_dict = load_file(ckpt_path, device="cpu")
|
66 |
+
else:
|
67 |
+
state_dict = load_state_dict(ckpt_path, location="cpu")
|
68 |
+
self.model.load_state_dict(state_dict, strict=False)
|
69 |
+
self.model = self.model.eval().to(self.device)
|
70 |
+
self.ddim_sampler = DDIMSampler(self.model, device=self.device)
|
71 |
+
|
72 |
+
def __call__(
|
73 |
+
self,
|
74 |
+
prompt: str,
|
75 |
+
negative_prompt: str,
|
76 |
+
image: np.ndarray,
|
77 |
+
masked_image: np.ndarray,
|
78 |
+
num_inference_steps: int,
|
79 |
+
strength: float,
|
80 |
+
guidance_scale: float,
|
81 |
+
height: int,
|
82 |
+
width: int,
|
83 |
+
seed: int,
|
84 |
+
sort_priority: str = "y",
|
85 |
+
callback=None,
|
86 |
+
):
|
87 |
+
"""
|
88 |
+
|
89 |
+
Args:
|
90 |
+
prompt:
|
91 |
+
negative_prompt:
|
92 |
+
image:
|
93 |
+
masked_image:
|
94 |
+
num_inference_steps:
|
95 |
+
strength:
|
96 |
+
guidance_scale:
|
97 |
+
height:
|
98 |
+
width:
|
99 |
+
seed:
|
100 |
+
sort_priority: x: left-right, y: top-down
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
result: list of images in numpy.ndarray format
|
104 |
+
rst_code: 0: normal -1: error 1:warning
|
105 |
+
rst_info: string of error or warning
|
106 |
+
|
107 |
+
"""
|
108 |
+
set_seed(seed)
|
109 |
+
str_warning = ""
|
110 |
+
|
111 |
+
mode = "text-editing"
|
112 |
+
revise_pos = False
|
113 |
+
img_count = 1
|
114 |
+
ddim_steps = num_inference_steps
|
115 |
+
w = width
|
116 |
+
h = height
|
117 |
+
strength = strength
|
118 |
+
cfg_scale = guidance_scale
|
119 |
+
eta = 0.0
|
120 |
+
|
121 |
+
prompt, texts = self.modify_prompt(prompt)
|
122 |
+
if prompt is None and texts is None:
|
123 |
+
return (
|
124 |
+
None,
|
125 |
+
-1,
|
126 |
+
"You have input Chinese prompt but the translator is not loaded!",
|
127 |
+
"",
|
128 |
+
)
|
129 |
+
n_lines = len(texts)
|
130 |
+
if mode in ["text-generation", "gen"]:
|
131 |
+
edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image
|
132 |
+
elif mode in ["text-editing", "edit"]:
|
133 |
+
if masked_image is None or image is None:
|
134 |
+
return (
|
135 |
+
None,
|
136 |
+
-1,
|
137 |
+
"Reference image and position image are needed for text editing!",
|
138 |
+
"",
|
139 |
+
)
|
140 |
+
if isinstance(image, str):
|
141 |
+
image = cv2.imread(image)[..., ::-1]
|
142 |
+
assert image is not None, f"Can't read ori_image image from{image}!"
|
143 |
+
elif isinstance(image, torch.Tensor):
|
144 |
+
image = image.cpu().numpy()
|
145 |
+
else:
|
146 |
+
assert isinstance(
|
147 |
+
image, np.ndarray
|
148 |
+
), f"Unknown format of ori_image: {type(image)}"
|
149 |
+
edit_image = image.clip(1, 255) # for mask reason
|
150 |
+
edit_image = check_channels(edit_image)
|
151 |
+
# edit_image = resize_image(
|
152 |
+
# edit_image, max_length=768
|
153 |
+
# ) # make w h multiple of 64, resize if w or h > max_length
|
154 |
+
h, w = edit_image.shape[:2] # change h, w by input ref_img
|
155 |
+
# preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
|
156 |
+
if masked_image is None:
|
157 |
+
pos_imgs = np.zeros((w, h, 1))
|
158 |
+
if isinstance(masked_image, str):
|
159 |
+
masked_image = cv2.imread(masked_image)[..., ::-1]
|
160 |
+
assert (
|
161 |
+
masked_image is not None
|
162 |
+
), f"Can't read draw_pos image from{masked_image}!"
|
163 |
+
pos_imgs = 255 - masked_image
|
164 |
+
elif isinstance(masked_image, torch.Tensor):
|
165 |
+
pos_imgs = masked_image.cpu().numpy()
|
166 |
+
else:
|
167 |
+
assert isinstance(
|
168 |
+
masked_image, np.ndarray
|
169 |
+
), f"Unknown format of draw_pos: {type(masked_image)}"
|
170 |
+
pos_imgs = 255 - masked_image
|
171 |
+
pos_imgs = pos_imgs[..., 0:1]
|
172 |
+
pos_imgs = cv2.convertScaleAbs(pos_imgs)
|
173 |
+
_, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
|
174 |
+
# seprate pos_imgs
|
175 |
+
pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
|
176 |
+
if len(pos_imgs) == 0:
|
177 |
+
pos_imgs = [np.zeros((h, w, 1))]
|
178 |
+
if len(pos_imgs) < n_lines:
|
179 |
+
if n_lines == 1 and texts[0] == " ":
|
180 |
+
pass # text-to-image without text
|
181 |
+
else:
|
182 |
+
raise RuntimeError(
|
183 |
+
f"{n_lines} text line to draw from prompt, not enough mask area({len(pos_imgs)}) on images"
|
184 |
+
)
|
185 |
+
elif len(pos_imgs) > n_lines:
|
186 |
+
str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
|
187 |
+
# get pre_pos, poly_list, hint that needed for anytext
|
188 |
+
pre_pos = []
|
189 |
+
poly_list = []
|
190 |
+
for input_pos in pos_imgs:
|
191 |
+
if input_pos.mean() != 0:
|
192 |
+
input_pos = (
|
193 |
+
input_pos[..., np.newaxis]
|
194 |
+
if len(input_pos.shape) == 2
|
195 |
+
else input_pos
|
196 |
+
)
|
197 |
+
poly, pos_img = self.find_polygon(input_pos)
|
198 |
+
pre_pos += [pos_img / 255.0]
|
199 |
+
poly_list += [poly]
|
200 |
+
else:
|
201 |
+
pre_pos += [np.zeros((h, w, 1))]
|
202 |
+
poly_list += [None]
|
203 |
+
np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
|
204 |
+
# prepare info dict
|
205 |
+
info = {}
|
206 |
+
info["glyphs"] = []
|
207 |
+
info["gly_line"] = []
|
208 |
+
info["positions"] = []
|
209 |
+
info["n_lines"] = [len(texts)] * img_count
|
210 |
+
gly_pos_imgs = []
|
211 |
+
for i in range(len(texts)):
|
212 |
+
text = texts[i]
|
213 |
+
if len(text) > max_chars:
|
214 |
+
str_warning = (
|
215 |
+
f'"{text}" length > max_chars: {max_chars}, will be cut off...'
|
216 |
+
)
|
217 |
+
text = text[:max_chars]
|
218 |
+
gly_scale = 2
|
219 |
+
if pre_pos[i].mean() != 0:
|
220 |
+
gly_line = draw_glyph(self.font, text)
|
221 |
+
glyphs = draw_glyph2(
|
222 |
+
self.font,
|
223 |
+
text,
|
224 |
+
poly_list[i],
|
225 |
+
scale=gly_scale,
|
226 |
+
width=w,
|
227 |
+
height=h,
|
228 |
+
add_space=False,
|
229 |
+
)
|
230 |
+
gly_pos_img = cv2.drawContours(
|
231 |
+
glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1
|
232 |
+
)
|
233 |
+
if revise_pos:
|
234 |
+
resize_gly = cv2.resize(
|
235 |
+
glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])
|
236 |
+
)
|
237 |
+
new_pos = cv2.morphologyEx(
|
238 |
+
(resize_gly * 255).astype(np.uint8),
|
239 |
+
cv2.MORPH_CLOSE,
|
240 |
+
kernel=np.ones(
|
241 |
+
(resize_gly.shape[0] // 10, resize_gly.shape[1] // 10),
|
242 |
+
dtype=np.uint8,
|
243 |
+
),
|
244 |
+
iterations=1,
|
245 |
+
)
|
246 |
+
new_pos = (
|
247 |
+
new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
|
248 |
+
)
|
249 |
+
contours, _ = cv2.findContours(
|
250 |
+
new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
|
251 |
+
)
|
252 |
+
if len(contours) != 1:
|
253 |
+
str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..."
|
254 |
+
else:
|
255 |
+
rect = cv2.minAreaRect(contours[0])
|
256 |
+
poly = np.int0(cv2.boxPoints(rect))
|
257 |
+
pre_pos[i] = (
|
258 |
+
cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
|
259 |
+
)
|
260 |
+
gly_pos_img = cv2.drawContours(
|
261 |
+
glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1
|
262 |
+
)
|
263 |
+
gly_pos_imgs += [gly_pos_img] # for show
|
264 |
+
else:
|
265 |
+
glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
|
266 |
+
gly_line = np.zeros((80, 512, 1))
|
267 |
+
gly_pos_imgs += [
|
268 |
+
np.zeros((h * gly_scale, w * gly_scale, 1))
|
269 |
+
] # for show
|
270 |
+
pos = pre_pos[i]
|
271 |
+
info["glyphs"] += [self.arr2tensor(glyphs, img_count)]
|
272 |
+
info["gly_line"] += [self.arr2tensor(gly_line, img_count)]
|
273 |
+
info["positions"] += [self.arr2tensor(pos, img_count)]
|
274 |
+
# get masked_x
|
275 |
+
masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
|
276 |
+
masked_img = np.transpose(masked_img, (2, 0, 1))
|
277 |
+
masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
|
278 |
+
if self.use_fp16:
|
279 |
+
masked_img = masked_img.half()
|
280 |
+
encoder_posterior = self.model.encode_first_stage(masked_img[None, ...])
|
281 |
+
masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach()
|
282 |
+
if self.use_fp16:
|
283 |
+
masked_x = masked_x.half()
|
284 |
+
info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0)
|
285 |
+
|
286 |
+
hint = self.arr2tensor(np_hint, img_count)
|
287 |
+
cond = self.model.get_learned_conditioning(
|
288 |
+
dict(
|
289 |
+
c_concat=[hint],
|
290 |
+
c_crossattn=[[prompt] * img_count],
|
291 |
+
text_info=info,
|
292 |
+
)
|
293 |
+
)
|
294 |
+
un_cond = self.model.get_learned_conditioning(
|
295 |
+
dict(
|
296 |
+
c_concat=[hint],
|
297 |
+
c_crossattn=[[negative_prompt] * img_count],
|
298 |
+
text_info=info,
|
299 |
+
)
|
300 |
+
)
|
301 |
+
shape = (4, h // 8, w // 8)
|
302 |
+
self.model.control_scales = [strength] * 13
|
303 |
+
samples, intermediates = self.ddim_sampler.sample(
|
304 |
+
ddim_steps,
|
305 |
+
img_count,
|
306 |
+
shape,
|
307 |
+
cond,
|
308 |
+
verbose=False,
|
309 |
+
eta=eta,
|
310 |
+
unconditional_guidance_scale=cfg_scale,
|
311 |
+
unconditional_conditioning=un_cond,
|
312 |
+
callback=callback
|
313 |
+
)
|
314 |
+
if self.use_fp16:
|
315 |
+
samples = samples.half()
|
316 |
+
x_samples = self.model.decode_first_stage(samples)
|
317 |
+
x_samples = (
|
318 |
+
(einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
|
319 |
+
.cpu()
|
320 |
+
.numpy()
|
321 |
+
.clip(0, 255)
|
322 |
+
.astype(np.uint8)
|
323 |
+
)
|
324 |
+
results = [x_samples[i] for i in range(img_count)]
|
325 |
+
# if (
|
326 |
+
# mode == "edit" and False
|
327 |
+
# ): # replace backgound in text editing but not ideal yet
|
328 |
+
# results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
|
329 |
+
# results = [r.clip(0, 255).astype(np.uint8) for r in results]
|
330 |
+
# if len(gly_pos_imgs) > 0 and show_debug:
|
331 |
+
# glyph_bs = np.stack(gly_pos_imgs, axis=2)
|
332 |
+
# glyph_img = np.sum(glyph_bs, axis=2) * 255
|
333 |
+
# glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
|
334 |
+
# results += [np.repeat(glyph_img, 3, axis=2)]
|
335 |
+
rst_code = 1 if str_warning else 0
|
336 |
+
return results, rst_code, str_warning
|
337 |
+
|
338 |
+
def modify_prompt(self, prompt):
|
339 |
+
prompt = prompt.replace("“", '"')
|
340 |
+
prompt = prompt.replace("”", '"')
|
341 |
+
p = '"(.*?)"'
|
342 |
+
strs = re.findall(p, prompt)
|
343 |
+
if len(strs) == 0:
|
344 |
+
strs = [" "]
|
345 |
+
else:
|
346 |
+
for s in strs:
|
347 |
+
prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
|
348 |
+
# if self.is_chinese(prompt):
|
349 |
+
# if self.trans_pipe is None:
|
350 |
+
# return None, None
|
351 |
+
# old_prompt = prompt
|
352 |
+
# prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1]
|
353 |
+
# print(f"Translate: {old_prompt} --> {prompt}")
|
354 |
+
return prompt, strs
|
355 |
+
|
356 |
+
# def is_chinese(self, text):
|
357 |
+
# text = checker._clean_text(text)
|
358 |
+
# for char in text:
|
359 |
+
# cp = ord(char)
|
360 |
+
# if checker._is_chinese_char(cp):
|
361 |
+
# return True
|
362 |
+
# return False
|
363 |
+
|
364 |
+
def separate_pos_imgs(self, img, sort_priority, gap=102):
|
365 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img)
|
366 |
+
components = []
|
367 |
+
for label in range(1, num_labels):
|
368 |
+
component = np.zeros_like(img)
|
369 |
+
component[labels == label] = 255
|
370 |
+
components.append((component, centroids[label]))
|
371 |
+
if sort_priority == "y":
|
372 |
+
fir, sec = 1, 0 # top-down first
|
373 |
+
elif sort_priority == "x":
|
374 |
+
fir, sec = 0, 1 # left-right first
|
375 |
+
components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
|
376 |
+
sorted_components = [c[0] for c in components]
|
377 |
+
return sorted_components
|
378 |
+
|
379 |
+
def find_polygon(self, image, min_rect=False):
|
380 |
+
contours, hierarchy = cv2.findContours(
|
381 |
+
image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
|
382 |
+
)
|
383 |
+
max_contour = max(contours, key=cv2.contourArea) # get contour with max area
|
384 |
+
if min_rect:
|
385 |
+
# get minimum enclosing rectangle
|
386 |
+
rect = cv2.minAreaRect(max_contour)
|
387 |
+
poly = np.int0(cv2.boxPoints(rect))
|
388 |
+
else:
|
389 |
+
# get approximate polygon
|
390 |
+
epsilon = 0.01 * cv2.arcLength(max_contour, True)
|
391 |
+
poly = cv2.approxPolyDP(max_contour, epsilon, True)
|
392 |
+
n, _, xy = poly.shape
|
393 |
+
poly = poly.reshape(n, xy)
|
394 |
+
cv2.drawContours(image, [poly], -1, 255, -1)
|
395 |
+
return poly, image
|
396 |
+
|
397 |
+
def arr2tensor(self, arr, bs):
|
398 |
+
arr = np.transpose(arr, (2, 0, 1))
|
399 |
+
_arr = torch.from_numpy(arr.copy()).float().to(self.device)
|
400 |
+
if self.use_fp16:
|
401 |
+
_arr = _arr.half()
|
402 |
+
_arr = torch.stack([_arr for _ in range(bs)], dim=0)
|
403 |
+
return _arr
|
iopaint/model/anytext/anytext_sd15.yaml
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
target: iopaint.model.anytext.cldm.cldm.ControlLDM
|
3 |
+
params:
|
4 |
+
linear_start: 0.00085
|
5 |
+
linear_end: 0.0120
|
6 |
+
num_timesteps_cond: 1
|
7 |
+
log_every_t: 200
|
8 |
+
timesteps: 1000
|
9 |
+
first_stage_key: "img"
|
10 |
+
cond_stage_key: "caption"
|
11 |
+
control_key: "hint"
|
12 |
+
glyph_key: "glyphs"
|
13 |
+
position_key: "positions"
|
14 |
+
image_size: 64
|
15 |
+
channels: 4
|
16 |
+
cond_stage_trainable: true # need be true when embedding_manager is valid
|
17 |
+
conditioning_key: crossattn
|
18 |
+
monitor: val/loss_simple_ema
|
19 |
+
scale_factor: 0.18215
|
20 |
+
use_ema: False
|
21 |
+
only_mid_control: False
|
22 |
+
loss_alpha: 0 # perceptual loss, 0.003
|
23 |
+
loss_beta: 0 # ctc loss
|
24 |
+
latin_weight: 1.0 # latin text line may need smaller weigth
|
25 |
+
with_step_weight: true
|
26 |
+
use_vae_upsample: true
|
27 |
+
embedding_manager_config:
|
28 |
+
target: iopaint.model.anytext.cldm.embedding_manager.EmbeddingManager
|
29 |
+
params:
|
30 |
+
valid: true # v6
|
31 |
+
emb_type: ocr # ocr, vit, conv
|
32 |
+
glyph_channels: 1
|
33 |
+
position_channels: 1
|
34 |
+
add_pos: false
|
35 |
+
placeholder_string: '*'
|
36 |
+
|
37 |
+
control_stage_config:
|
38 |
+
target: iopaint.model.anytext.cldm.cldm.ControlNet
|
39 |
+
params:
|
40 |
+
image_size: 32 # unused
|
41 |
+
in_channels: 4
|
42 |
+
model_channels: 320
|
43 |
+
glyph_channels: 1
|
44 |
+
position_channels: 1
|
45 |
+
attention_resolutions: [ 4, 2, 1 ]
|
46 |
+
num_res_blocks: 2
|
47 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
48 |
+
num_heads: 8
|
49 |
+
use_spatial_transformer: True
|
50 |
+
transformer_depth: 1
|
51 |
+
context_dim: 768
|
52 |
+
use_checkpoint: True
|
53 |
+
legacy: False
|
54 |
+
|
55 |
+
unet_config:
|
56 |
+
target: iopaint.model.anytext.cldm.cldm.ControlledUnetModel
|
57 |
+
params:
|
58 |
+
image_size: 32 # unused
|
59 |
+
in_channels: 4
|
60 |
+
out_channels: 4
|
61 |
+
model_channels: 320
|
62 |
+
attention_resolutions: [ 4, 2, 1 ]
|
63 |
+
num_res_blocks: 2
|
64 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
65 |
+
num_heads: 8
|
66 |
+
use_spatial_transformer: True
|
67 |
+
transformer_depth: 1
|
68 |
+
context_dim: 768
|
69 |
+
use_checkpoint: True
|
70 |
+
legacy: False
|
71 |
+
|
72 |
+
first_stage_config:
|
73 |
+
target: iopaint.model.anytext.ldm.models.autoencoder.AutoencoderKL
|
74 |
+
params:
|
75 |
+
embed_dim: 4
|
76 |
+
monitor: val/rec_loss
|
77 |
+
ddconfig:
|
78 |
+
double_z: true
|
79 |
+
z_channels: 4
|
80 |
+
resolution: 256
|
81 |
+
in_channels: 3
|
82 |
+
out_ch: 3
|
83 |
+
ch: 128
|
84 |
+
ch_mult:
|
85 |
+
- 1
|
86 |
+
- 2
|
87 |
+
- 4
|
88 |
+
- 4
|
89 |
+
num_res_blocks: 2
|
90 |
+
attn_resolutions: []
|
91 |
+
dropout: 0.0
|
92 |
+
lossconfig:
|
93 |
+
target: torch.nn.Identity
|
94 |
+
|
95 |
+
cond_stage_config:
|
96 |
+
target: iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
|
97 |
+
params:
|
98 |
+
version: openai/clip-vit-large-patch14
|
99 |
+
use_vision: false # v6
|
iopaint/model/anytext/cldm/__init__.py
ADDED
File without changes
|
iopaint/model/anytext/ldm/__init__.py
ADDED
File without changes
|
iopaint/model/anytext/ldm/models/__init__.py
ADDED
File without changes
|
iopaint/model/anytext/ldm/models/autoencoder.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from contextlib import contextmanager
|
4 |
+
|
5 |
+
from iopaint.model.anytext.ldm.modules.diffusionmodules.model import Encoder, Decoder
|
6 |
+
from iopaint.model.anytext.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
7 |
+
|
8 |
+
from iopaint.model.anytext.ldm.util import instantiate_from_config
|
9 |
+
from iopaint.model.anytext.ldm.modules.ema import LitEma
|
10 |
+
|
11 |
+
|
12 |
+
class AutoencoderKL(torch.nn.Module):
|
13 |
+
def __init__(self,
|
14 |
+
ddconfig,
|
15 |
+
lossconfig,
|
16 |
+
embed_dim,
|
17 |
+
ckpt_path=None,
|
18 |
+
ignore_keys=[],
|
19 |
+
image_key="image",
|
20 |
+
colorize_nlabels=None,
|
21 |
+
monitor=None,
|
22 |
+
ema_decay=None,
|
23 |
+
learn_logvar=False
|
24 |
+
):
|
25 |
+
super().__init__()
|
26 |
+
self.learn_logvar = learn_logvar
|
27 |
+
self.image_key = image_key
|
28 |
+
self.encoder = Encoder(**ddconfig)
|
29 |
+
self.decoder = Decoder(**ddconfig)
|
30 |
+
self.loss = instantiate_from_config(lossconfig)
|
31 |
+
assert ddconfig["double_z"]
|
32 |
+
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
33 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
34 |
+
self.embed_dim = embed_dim
|
35 |
+
if colorize_nlabels is not None:
|
36 |
+
assert type(colorize_nlabels)==int
|
37 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
38 |
+
if monitor is not None:
|
39 |
+
self.monitor = monitor
|
40 |
+
|
41 |
+
self.use_ema = ema_decay is not None
|
42 |
+
if self.use_ema:
|
43 |
+
self.ema_decay = ema_decay
|
44 |
+
assert 0. < ema_decay < 1.
|
45 |
+
self.model_ema = LitEma(self, decay=ema_decay)
|
46 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
47 |
+
|
48 |
+
if ckpt_path is not None:
|
49 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
50 |
+
|
51 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
52 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
53 |
+
keys = list(sd.keys())
|
54 |
+
for k in keys:
|
55 |
+
for ik in ignore_keys:
|
56 |
+
if k.startswith(ik):
|
57 |
+
print("Deleting key {} from state_dict.".format(k))
|
58 |
+
del sd[k]
|
59 |
+
self.load_state_dict(sd, strict=False)
|
60 |
+
print(f"Restored from {path}")
|
61 |
+
|
62 |
+
@contextmanager
|
63 |
+
def ema_scope(self, context=None):
|
64 |
+
if self.use_ema:
|
65 |
+
self.model_ema.store(self.parameters())
|
66 |
+
self.model_ema.copy_to(self)
|
67 |
+
if context is not None:
|
68 |
+
print(f"{context}: Switched to EMA weights")
|
69 |
+
try:
|
70 |
+
yield None
|
71 |
+
finally:
|
72 |
+
if self.use_ema:
|
73 |
+
self.model_ema.restore(self.parameters())
|
74 |
+
if context is not None:
|
75 |
+
print(f"{context}: Restored training weights")
|
76 |
+
|
77 |
+
def on_train_batch_end(self, *args, **kwargs):
|
78 |
+
if self.use_ema:
|
79 |
+
self.model_ema(self)
|
80 |
+
|
81 |
+
def encode(self, x):
|
82 |
+
h = self.encoder(x)
|
83 |
+
moments = self.quant_conv(h)
|
84 |
+
posterior = DiagonalGaussianDistribution(moments)
|
85 |
+
return posterior
|
86 |
+
|
87 |
+
def decode(self, z):
|
88 |
+
z = self.post_quant_conv(z)
|
89 |
+
dec = self.decoder(z)
|
90 |
+
return dec
|
91 |
+
|
92 |
+
def forward(self, input, sample_posterior=True):
|
93 |
+
posterior = self.encode(input)
|
94 |
+
if sample_posterior:
|
95 |
+
z = posterior.sample()
|
96 |
+
else:
|
97 |
+
z = posterior.mode()
|
98 |
+
dec = self.decode(z)
|
99 |
+
return dec, posterior
|
100 |
+
|
101 |
+
def get_input(self, batch, k):
|
102 |
+
x = batch[k]
|
103 |
+
if len(x.shape) == 3:
|
104 |
+
x = x[..., None]
|
105 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
106 |
+
return x
|
107 |
+
|
108 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
109 |
+
inputs = self.get_input(batch, self.image_key)
|
110 |
+
reconstructions, posterior = self(inputs)
|
111 |
+
|
112 |
+
if optimizer_idx == 0:
|
113 |
+
# train encoder+decoder+logvar
|
114 |
+
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
115 |
+
last_layer=self.get_last_layer(), split="train")
|
116 |
+
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
117 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
118 |
+
return aeloss
|
119 |
+
|
120 |
+
if optimizer_idx == 1:
|
121 |
+
# train the discriminator
|
122 |
+
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
123 |
+
last_layer=self.get_last_layer(), split="train")
|
124 |
+
|
125 |
+
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
126 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
127 |
+
return discloss
|
128 |
+
|
129 |
+
def validation_step(self, batch, batch_idx):
|
130 |
+
log_dict = self._validation_step(batch, batch_idx)
|
131 |
+
with self.ema_scope():
|
132 |
+
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
133 |
+
return log_dict
|
134 |
+
|
135 |
+
def _validation_step(self, batch, batch_idx, postfix=""):
|
136 |
+
inputs = self.get_input(batch, self.image_key)
|
137 |
+
reconstructions, posterior = self(inputs)
|
138 |
+
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
139 |
+
last_layer=self.get_last_layer(), split="val"+postfix)
|
140 |
+
|
141 |
+
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
142 |
+
last_layer=self.get_last_layer(), split="val"+postfix)
|
143 |
+
|
144 |
+
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
|
145 |
+
self.log_dict(log_dict_ae)
|
146 |
+
self.log_dict(log_dict_disc)
|
147 |
+
return self.log_dict
|
148 |
+
|
149 |
+
def configure_optimizers(self):
|
150 |
+
lr = self.learning_rate
|
151 |
+
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
|
152 |
+
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
|
153 |
+
if self.learn_logvar:
|
154 |
+
print(f"{self.__class__.__name__}: Learning logvar")
|
155 |
+
ae_params_list.append(self.loss.logvar)
|
156 |
+
opt_ae = torch.optim.Adam(ae_params_list,
|
157 |
+
lr=lr, betas=(0.5, 0.9))
|
158 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
159 |
+
lr=lr, betas=(0.5, 0.9))
|
160 |
+
return [opt_ae, opt_disc], []
|
161 |
+
|
162 |
+
def get_last_layer(self):
|
163 |
+
return self.decoder.conv_out.weight
|
164 |
+
|
165 |
+
@torch.no_grad()
|
166 |
+
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
|
167 |
+
log = dict()
|
168 |
+
x = self.get_input(batch, self.image_key)
|
169 |
+
x = x.to(self.device)
|
170 |
+
if not only_inputs:
|
171 |
+
xrec, posterior = self(x)
|
172 |
+
if x.shape[1] > 3:
|
173 |
+
# colorize with random projection
|
174 |
+
assert xrec.shape[1] > 3
|
175 |
+
x = self.to_rgb(x)
|
176 |
+
xrec = self.to_rgb(xrec)
|
177 |
+
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
178 |
+
log["reconstructions"] = xrec
|
179 |
+
if log_ema or self.use_ema:
|
180 |
+
with self.ema_scope():
|
181 |
+
xrec_ema, posterior_ema = self(x)
|
182 |
+
if x.shape[1] > 3:
|
183 |
+
# colorize with random projection
|
184 |
+
assert xrec_ema.shape[1] > 3
|
185 |
+
xrec_ema = self.to_rgb(xrec_ema)
|
186 |
+
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
|
187 |
+
log["reconstructions_ema"] = xrec_ema
|
188 |
+
log["inputs"] = x
|
189 |
+
return log
|
190 |
+
|
191 |
+
def to_rgb(self, x):
|
192 |
+
assert self.image_key == "segmentation"
|
193 |
+
if not hasattr(self, "colorize"):
|
194 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
195 |
+
x = F.conv2d(x, weight=self.colorize)
|
196 |
+
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
197 |
+
return x
|
198 |
+
|
199 |
+
|
200 |
+
class IdentityFirstStage(torch.nn.Module):
|
201 |
+
def __init__(self, *args, vq_interface=False, **kwargs):
|
202 |
+
self.vq_interface = vq_interface
|
203 |
+
super().__init__()
|
204 |
+
|
205 |
+
def encode(self, x, *args, **kwargs):
|
206 |
+
return x
|
207 |
+
|
208 |
+
def decode(self, x, *args, **kwargs):
|
209 |
+
return x
|
210 |
+
|
211 |
+
def quantize(self, x, *args, **kwargs):
|
212 |
+
if self.vq_interface:
|
213 |
+
return x, None, [None, None, None]
|
214 |
+
return x
|
215 |
+
|
216 |
+
def forward(self, x, *args, **kwargs):
|
217 |
+
return x
|
218 |
+
|
iopaint/model/anytext/ldm/models/diffusion/__init__.py
ADDED
File without changes
|
iopaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .sampler import DPMSolverSampler
|
iopaint/model/anytext/ldm/modules/__init__.py
ADDED
File without changes
|
iopaint/model/anytext/ldm/modules/attention.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from inspect import isfunction
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn, einsum
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from typing import Optional, Any
|
8 |
+
|
9 |
+
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import checkpoint
|
10 |
+
|
11 |
+
|
12 |
+
# CrossAttn precision handling
|
13 |
+
import os
|
14 |
+
|
15 |
+
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
16 |
+
|
17 |
+
|
18 |
+
def exists(val):
|
19 |
+
return val is not None
|
20 |
+
|
21 |
+
|
22 |
+
def uniq(arr):
|
23 |
+
return {el: True for el in arr}.keys()
|
24 |
+
|
25 |
+
|
26 |
+
def default(val, d):
|
27 |
+
if exists(val):
|
28 |
+
return val
|
29 |
+
return d() if isfunction(d) else d
|
30 |
+
|
31 |
+
|
32 |
+
def max_neg_value(t):
|
33 |
+
return -torch.finfo(t.dtype).max
|
34 |
+
|
35 |
+
|
36 |
+
def init_(tensor):
|
37 |
+
dim = tensor.shape[-1]
|
38 |
+
std = 1 / math.sqrt(dim)
|
39 |
+
tensor.uniform_(-std, std)
|
40 |
+
return tensor
|
41 |
+
|
42 |
+
|
43 |
+
# feedforward
|
44 |
+
class GEGLU(nn.Module):
|
45 |
+
def __init__(self, dim_in, dim_out):
|
46 |
+
super().__init__()
|
47 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
51 |
+
return x * F.gelu(gate)
|
52 |
+
|
53 |
+
|
54 |
+
class FeedForward(nn.Module):
|
55 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
56 |
+
super().__init__()
|
57 |
+
inner_dim = int(dim * mult)
|
58 |
+
dim_out = default(dim_out, dim)
|
59 |
+
project_in = (
|
60 |
+
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
61 |
+
if not glu
|
62 |
+
else GEGLU(dim, inner_dim)
|
63 |
+
)
|
64 |
+
|
65 |
+
self.net = nn.Sequential(
|
66 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
67 |
+
)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
return self.net(x)
|
71 |
+
|
72 |
+
|
73 |
+
def zero_module(module):
|
74 |
+
"""
|
75 |
+
Zero out the parameters of a module and return it.
|
76 |
+
"""
|
77 |
+
for p in module.parameters():
|
78 |
+
p.detach().zero_()
|
79 |
+
return module
|
80 |
+
|
81 |
+
|
82 |
+
def Normalize(in_channels):
|
83 |
+
return torch.nn.GroupNorm(
|
84 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
85 |
+
)
|
86 |
+
|
87 |
+
|
88 |
+
class SpatialSelfAttention(nn.Module):
|
89 |
+
def __init__(self, in_channels):
|
90 |
+
super().__init__()
|
91 |
+
self.in_channels = in_channels
|
92 |
+
|
93 |
+
self.norm = Normalize(in_channels)
|
94 |
+
self.q = torch.nn.Conv2d(
|
95 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
96 |
+
)
|
97 |
+
self.k = torch.nn.Conv2d(
|
98 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
99 |
+
)
|
100 |
+
self.v = torch.nn.Conv2d(
|
101 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
102 |
+
)
|
103 |
+
self.proj_out = torch.nn.Conv2d(
|
104 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
105 |
+
)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
h_ = x
|
109 |
+
h_ = self.norm(h_)
|
110 |
+
q = self.q(h_)
|
111 |
+
k = self.k(h_)
|
112 |
+
v = self.v(h_)
|
113 |
+
|
114 |
+
# compute attention
|
115 |
+
b, c, h, w = q.shape
|
116 |
+
q = rearrange(q, "b c h w -> b (h w) c")
|
117 |
+
k = rearrange(k, "b c h w -> b c (h w)")
|
118 |
+
w_ = torch.einsum("bij,bjk->bik", q, k)
|
119 |
+
|
120 |
+
w_ = w_ * (int(c) ** (-0.5))
|
121 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
122 |
+
|
123 |
+
# attend to values
|
124 |
+
v = rearrange(v, "b c h w -> b c (h w)")
|
125 |
+
w_ = rearrange(w_, "b i j -> b j i")
|
126 |
+
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
127 |
+
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
128 |
+
h_ = self.proj_out(h_)
|
129 |
+
|
130 |
+
return x + h_
|
131 |
+
|
132 |
+
|
133 |
+
class CrossAttention(nn.Module):
|
134 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
135 |
+
super().__init__()
|
136 |
+
inner_dim = dim_head * heads
|
137 |
+
context_dim = default(context_dim, query_dim)
|
138 |
+
|
139 |
+
self.scale = dim_head**-0.5
|
140 |
+
self.heads = heads
|
141 |
+
|
142 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
143 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
144 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
145 |
+
|
146 |
+
self.to_out = nn.Sequential(
|
147 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
148 |
+
)
|
149 |
+
|
150 |
+
def forward(self, x, context=None, mask=None):
|
151 |
+
h = self.heads
|
152 |
+
|
153 |
+
q = self.to_q(x)
|
154 |
+
context = default(context, x)
|
155 |
+
k = self.to_k(context)
|
156 |
+
v = self.to_v(context)
|
157 |
+
|
158 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
|
159 |
+
|
160 |
+
# force cast to fp32 to avoid overflowing
|
161 |
+
if _ATTN_PRECISION == "fp32":
|
162 |
+
with torch.autocast(enabled=False, device_type="cuda"):
|
163 |
+
q, k = q.float(), k.float()
|
164 |
+
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
|
165 |
+
else:
|
166 |
+
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
|
167 |
+
|
168 |
+
del q, k
|
169 |
+
|
170 |
+
if exists(mask):
|
171 |
+
mask = rearrange(mask, "b ... -> b (...)")
|
172 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
173 |
+
mask = repeat(mask, "b j -> (b h) () j", h=h)
|
174 |
+
sim.masked_fill_(~mask, max_neg_value)
|
175 |
+
|
176 |
+
# attention, what we cannot get enough of
|
177 |
+
sim = sim.softmax(dim=-1)
|
178 |
+
|
179 |
+
out = einsum("b i j, b j d -> b i d", sim, v)
|
180 |
+
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
|
181 |
+
return self.to_out(out)
|
182 |
+
|
183 |
+
|
184 |
+
class SDPACrossAttention(CrossAttention):
|
185 |
+
def forward(self, x, context=None, mask=None):
|
186 |
+
batch_size, sequence_length, inner_dim = x.shape
|
187 |
+
|
188 |
+
if mask is not None:
|
189 |
+
mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
|
190 |
+
mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
|
191 |
+
|
192 |
+
h = self.heads
|
193 |
+
q_in = self.to_q(x)
|
194 |
+
context = default(context, x)
|
195 |
+
|
196 |
+
k_in = self.to_k(context)
|
197 |
+
v_in = self.to_v(context)
|
198 |
+
|
199 |
+
head_dim = inner_dim // h
|
200 |
+
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
201 |
+
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
202 |
+
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
203 |
+
|
204 |
+
del q_in, k_in, v_in
|
205 |
+
|
206 |
+
dtype = q.dtype
|
207 |
+
if _ATTN_PRECISION == "fp32":
|
208 |
+
q, k, v = q.float(), k.float(), v.float()
|
209 |
+
|
210 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
211 |
+
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
212 |
+
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
|
213 |
+
)
|
214 |
+
|
215 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
216 |
+
batch_size, -1, h * head_dim
|
217 |
+
)
|
218 |
+
hidden_states = hidden_states.to(dtype)
|
219 |
+
|
220 |
+
# linear proj
|
221 |
+
hidden_states = self.to_out[0](hidden_states)
|
222 |
+
# dropout
|
223 |
+
hidden_states = self.to_out[1](hidden_states)
|
224 |
+
return hidden_states
|
225 |
+
|
226 |
+
|
227 |
+
class BasicTransformerBlock(nn.Module):
|
228 |
+
def __init__(
|
229 |
+
self,
|
230 |
+
dim,
|
231 |
+
n_heads,
|
232 |
+
d_head,
|
233 |
+
dropout=0.0,
|
234 |
+
context_dim=None,
|
235 |
+
gated_ff=True,
|
236 |
+
checkpoint=True,
|
237 |
+
disable_self_attn=False,
|
238 |
+
):
|
239 |
+
super().__init__()
|
240 |
+
|
241 |
+
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
242 |
+
attn_cls = SDPACrossAttention
|
243 |
+
else:
|
244 |
+
attn_cls = CrossAttention
|
245 |
+
|
246 |
+
self.disable_self_attn = disable_self_attn
|
247 |
+
self.attn1 = attn_cls(
|
248 |
+
query_dim=dim,
|
249 |
+
heads=n_heads,
|
250 |
+
dim_head=d_head,
|
251 |
+
dropout=dropout,
|
252 |
+
context_dim=context_dim if self.disable_self_attn else None,
|
253 |
+
) # is a self-attention if not self.disable_self_attn
|
254 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
255 |
+
self.attn2 = attn_cls(
|
256 |
+
query_dim=dim,
|
257 |
+
context_dim=context_dim,
|
258 |
+
heads=n_heads,
|
259 |
+
dim_head=d_head,
|
260 |
+
dropout=dropout,
|
261 |
+
) # is self-attn if context is none
|
262 |
+
self.norm1 = nn.LayerNorm(dim)
|
263 |
+
self.norm2 = nn.LayerNorm(dim)
|
264 |
+
self.norm3 = nn.LayerNorm(dim)
|
265 |
+
self.checkpoint = checkpoint
|
266 |
+
|
267 |
+
def forward(self, x, context=None):
|
268 |
+
return checkpoint(
|
269 |
+
self._forward, (x, context), self.parameters(), self.checkpoint
|
270 |
+
)
|
271 |
+
|
272 |
+
def _forward(self, x, context=None):
|
273 |
+
x = (
|
274 |
+
self.attn1(
|
275 |
+
self.norm1(x), context=context if self.disable_self_attn else None
|
276 |
+
)
|
277 |
+
+ x
|
278 |
+
)
|
279 |
+
x = self.attn2(self.norm2(x), context=context) + x
|
280 |
+
x = self.ff(self.norm3(x)) + x
|
281 |
+
return x
|
282 |
+
|
283 |
+
|
284 |
+
class SpatialTransformer(nn.Module):
|
285 |
+
"""
|
286 |
+
Transformer block for image-like data.
|
287 |
+
First, project the input (aka embedding)
|
288 |
+
and reshape to b, t, d.
|
289 |
+
Then apply standard transformer action.
|
290 |
+
Finally, reshape to image
|
291 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
292 |
+
"""
|
293 |
+
|
294 |
+
def __init__(
|
295 |
+
self,
|
296 |
+
in_channels,
|
297 |
+
n_heads,
|
298 |
+
d_head,
|
299 |
+
depth=1,
|
300 |
+
dropout=0.0,
|
301 |
+
context_dim=None,
|
302 |
+
disable_self_attn=False,
|
303 |
+
use_linear=False,
|
304 |
+
use_checkpoint=True,
|
305 |
+
):
|
306 |
+
super().__init__()
|
307 |
+
if exists(context_dim) and not isinstance(context_dim, list):
|
308 |
+
context_dim = [context_dim]
|
309 |
+
self.in_channels = in_channels
|
310 |
+
inner_dim = n_heads * d_head
|
311 |
+
self.norm = Normalize(in_channels)
|
312 |
+
if not use_linear:
|
313 |
+
self.proj_in = nn.Conv2d(
|
314 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
315 |
+
)
|
316 |
+
else:
|
317 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
318 |
+
|
319 |
+
self.transformer_blocks = nn.ModuleList(
|
320 |
+
[
|
321 |
+
BasicTransformerBlock(
|
322 |
+
inner_dim,
|
323 |
+
n_heads,
|
324 |
+
d_head,
|
325 |
+
dropout=dropout,
|
326 |
+
context_dim=context_dim[d],
|
327 |
+
disable_self_attn=disable_self_attn,
|
328 |
+
checkpoint=use_checkpoint,
|
329 |
+
)
|
330 |
+
for d in range(depth)
|
331 |
+
]
|
332 |
+
)
|
333 |
+
if not use_linear:
|
334 |
+
self.proj_out = zero_module(
|
335 |
+
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
336 |
+
)
|
337 |
+
else:
|
338 |
+
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
339 |
+
self.use_linear = use_linear
|
340 |
+
|
341 |
+
def forward(self, x, context=None):
|
342 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
343 |
+
if not isinstance(context, list):
|
344 |
+
context = [context]
|
345 |
+
b, c, h, w = x.shape
|
346 |
+
x_in = x
|
347 |
+
x = self.norm(x)
|
348 |
+
if not self.use_linear:
|
349 |
+
x = self.proj_in(x)
|
350 |
+
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
351 |
+
if self.use_linear:
|
352 |
+
x = self.proj_in(x)
|
353 |
+
for i, block in enumerate(self.transformer_blocks):
|
354 |
+
x = block(x, context=context[i])
|
355 |
+
if self.use_linear:
|
356 |
+
x = self.proj_out(x)
|
357 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
358 |
+
if not self.use_linear:
|
359 |
+
x = self.proj_out(x)
|
360 |
+
return x + x_in
|
iopaint/model/anytext/ldm/modules/diffusionmodules/__init__.py
ADDED
File without changes
|
iopaint/model/anytext/ldm/modules/distributions/__init__.py
ADDED
File without changes
|
iopaint/model/anytext/ldm/modules/encoders/__init__.py
ADDED
File without changes
|
iopaint/model/anytext/ocr_recog/__init__.py
ADDED
File without changes
|
iopaint/model/base.py
ADDED
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
from iopaint.helper import (
|
10 |
+
boxes_from_mask,
|
11 |
+
resize_max_size,
|
12 |
+
pad_img_to_modulo,
|
13 |
+
switch_mps_device,
|
14 |
+
)
|
15 |
+
from iopaint.schema import InpaintRequest, HDStrategy, SDSampler
|
16 |
+
from .helper.g_diffuser_bot import expand_image
|
17 |
+
from .utils import get_scheduler
|
18 |
+
|
19 |
+
|
20 |
+
class InpaintModel:
|
21 |
+
name = "base"
|
22 |
+
min_size: Optional[int] = None
|
23 |
+
pad_mod = 8
|
24 |
+
pad_to_square = False
|
25 |
+
is_erase_model = False
|
26 |
+
|
27 |
+
def __init__(self, device, **kwargs):
|
28 |
+
"""
|
29 |
+
|
30 |
+
Args:
|
31 |
+
device:
|
32 |
+
"""
|
33 |
+
device = switch_mps_device(self.name, device)
|
34 |
+
self.device = device
|
35 |
+
self.init_model(device, **kwargs)
|
36 |
+
|
37 |
+
@abc.abstractmethod
|
38 |
+
def init_model(self, device, **kwargs):
|
39 |
+
...
|
40 |
+
|
41 |
+
@staticmethod
|
42 |
+
@abc.abstractmethod
|
43 |
+
def is_downloaded() -> bool:
|
44 |
+
return False
|
45 |
+
|
46 |
+
@abc.abstractmethod
|
47 |
+
def forward(self, image, mask, config: InpaintRequest):
|
48 |
+
"""Input images and output images have same size
|
49 |
+
images: [H, W, C] RGB
|
50 |
+
masks: [H, W, 1] 255 为 masks 区域
|
51 |
+
return: BGR IMAGE
|
52 |
+
"""
|
53 |
+
...
|
54 |
+
|
55 |
+
@staticmethod
|
56 |
+
def download():
|
57 |
+
...
|
58 |
+
|
59 |
+
def _pad_forward(self, image, mask, config: InpaintRequest):
|
60 |
+
origin_height, origin_width = image.shape[:2]
|
61 |
+
pad_image = pad_img_to_modulo(
|
62 |
+
image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
63 |
+
)
|
64 |
+
pad_mask = pad_img_to_modulo(
|
65 |
+
mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
66 |
+
)
|
67 |
+
|
68 |
+
# logger.info(f"final forward pad size: {pad_image.shape}")
|
69 |
+
|
70 |
+
image, mask = self.forward_pre_process(image, mask, config)
|
71 |
+
|
72 |
+
result = self.forward(pad_image, pad_mask, config)
|
73 |
+
result = result[0:origin_height, 0:origin_width, :]
|
74 |
+
|
75 |
+
result, image, mask = self.forward_post_process(result, image, mask, config)
|
76 |
+
|
77 |
+
if config.sd_keep_unmasked_area:
|
78 |
+
mask = mask[:, :, np.newaxis]
|
79 |
+
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
|
80 |
+
return result
|
81 |
+
|
82 |
+
def forward_pre_process(self, image, mask, config):
|
83 |
+
return image, mask
|
84 |
+
|
85 |
+
def forward_post_process(self, result, image, mask, config):
|
86 |
+
return result, image, mask
|
87 |
+
|
88 |
+
@torch.no_grad()
|
89 |
+
def __call__(self, image, mask, config: InpaintRequest):
|
90 |
+
"""
|
91 |
+
images: [H, W, C] RGB, not normalized
|
92 |
+
masks: [H, W]
|
93 |
+
return: BGR IMAGE
|
94 |
+
"""
|
95 |
+
inpaint_result = None
|
96 |
+
# logger.info(f"hd_strategy: {config.hd_strategy}")
|
97 |
+
if config.hd_strategy == HDStrategy.CROP:
|
98 |
+
if max(image.shape) > config.hd_strategy_crop_trigger_size:
|
99 |
+
logger.info(f"Run crop strategy")
|
100 |
+
boxes = boxes_from_mask(mask)
|
101 |
+
crop_result = []
|
102 |
+
for box in boxes:
|
103 |
+
crop_image, crop_box = self._run_box(image, mask, box, config)
|
104 |
+
crop_result.append((crop_image, crop_box))
|
105 |
+
|
106 |
+
inpaint_result = image[:, :, ::-1]
|
107 |
+
for crop_image, crop_box in crop_result:
|
108 |
+
x1, y1, x2, y2 = crop_box
|
109 |
+
inpaint_result[y1:y2, x1:x2, :] = crop_image
|
110 |
+
|
111 |
+
elif config.hd_strategy == HDStrategy.RESIZE:
|
112 |
+
if max(image.shape) > config.hd_strategy_resize_limit:
|
113 |
+
origin_size = image.shape[:2]
|
114 |
+
downsize_image = resize_max_size(
|
115 |
+
image, size_limit=config.hd_strategy_resize_limit
|
116 |
+
)
|
117 |
+
downsize_mask = resize_max_size(
|
118 |
+
mask, size_limit=config.hd_strategy_resize_limit
|
119 |
+
)
|
120 |
+
|
121 |
+
logger.info(
|
122 |
+
f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
|
123 |
+
)
|
124 |
+
inpaint_result = self._pad_forward(
|
125 |
+
downsize_image, downsize_mask, config
|
126 |
+
)
|
127 |
+
|
128 |
+
# only paste masked area result
|
129 |
+
inpaint_result = cv2.resize(
|
130 |
+
inpaint_result,
|
131 |
+
(origin_size[1], origin_size[0]),
|
132 |
+
interpolation=cv2.INTER_CUBIC,
|
133 |
+
)
|
134 |
+
original_pixel_indices = mask < 127
|
135 |
+
inpaint_result[original_pixel_indices] = image[:, :, ::-1][
|
136 |
+
original_pixel_indices
|
137 |
+
]
|
138 |
+
|
139 |
+
if inpaint_result is None:
|
140 |
+
inpaint_result = self._pad_forward(image, mask, config)
|
141 |
+
|
142 |
+
return inpaint_result
|
143 |
+
|
144 |
+
def _crop_box(self, image, mask, box, config: InpaintRequest):
|
145 |
+
"""
|
146 |
+
|
147 |
+
Args:
|
148 |
+
image: [H, W, C] RGB
|
149 |
+
mask: [H, W, 1]
|
150 |
+
box: [left,top,right,bottom]
|
151 |
+
|
152 |
+
Returns:
|
153 |
+
BGR IMAGE, (l, r, r, b)
|
154 |
+
"""
|
155 |
+
box_h = box[3] - box[1]
|
156 |
+
box_w = box[2] - box[0]
|
157 |
+
cx = (box[0] + box[2]) // 2
|
158 |
+
cy = (box[1] + box[3]) // 2
|
159 |
+
img_h, img_w = image.shape[:2]
|
160 |
+
|
161 |
+
w = box_w + config.hd_strategy_crop_margin * 2
|
162 |
+
h = box_h + config.hd_strategy_crop_margin * 2
|
163 |
+
|
164 |
+
_l = cx - w // 2
|
165 |
+
_r = cx + w // 2
|
166 |
+
_t = cy - h // 2
|
167 |
+
_b = cy + h // 2
|
168 |
+
|
169 |
+
l = max(_l, 0)
|
170 |
+
r = min(_r, img_w)
|
171 |
+
t = max(_t, 0)
|
172 |
+
b = min(_b, img_h)
|
173 |
+
|
174 |
+
# try to get more context when crop around image edge
|
175 |
+
if _l < 0:
|
176 |
+
r += abs(_l)
|
177 |
+
if _r > img_w:
|
178 |
+
l -= _r - img_w
|
179 |
+
if _t < 0:
|
180 |
+
b += abs(_t)
|
181 |
+
if _b > img_h:
|
182 |
+
t -= _b - img_h
|
183 |
+
|
184 |
+
l = max(l, 0)
|
185 |
+
r = min(r, img_w)
|
186 |
+
t = max(t, 0)
|
187 |
+
b = min(b, img_h)
|
188 |
+
|
189 |
+
crop_img = image[t:b, l:r, :]
|
190 |
+
crop_mask = mask[t:b, l:r]
|
191 |
+
|
192 |
+
# logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
|
193 |
+
|
194 |
+
return crop_img, crop_mask, [l, t, r, b]
|
195 |
+
|
196 |
+
def _calculate_cdf(self, histogram):
|
197 |
+
cdf = histogram.cumsum()
|
198 |
+
normalized_cdf = cdf / float(cdf.max())
|
199 |
+
return normalized_cdf
|
200 |
+
|
201 |
+
def _calculate_lookup(self, source_cdf, reference_cdf):
|
202 |
+
lookup_table = np.zeros(256)
|
203 |
+
lookup_val = 0
|
204 |
+
for source_index, source_val in enumerate(source_cdf):
|
205 |
+
for reference_index, reference_val in enumerate(reference_cdf):
|
206 |
+
if reference_val >= source_val:
|
207 |
+
lookup_val = reference_index
|
208 |
+
break
|
209 |
+
lookup_table[source_index] = lookup_val
|
210 |
+
return lookup_table
|
211 |
+
|
212 |
+
def _match_histograms(self, source, reference, mask):
|
213 |
+
transformed_channels = []
|
214 |
+
if len(mask.shape) == 3:
|
215 |
+
mask = mask[:, :, -1]
|
216 |
+
|
217 |
+
for channel in range(source.shape[-1]):
|
218 |
+
source_channel = source[:, :, channel]
|
219 |
+
reference_channel = reference[:, :, channel]
|
220 |
+
|
221 |
+
# only calculate histograms for non-masked parts
|
222 |
+
source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
|
223 |
+
reference_histogram, _ = np.histogram(
|
224 |
+
reference_channel[mask == 0], 256, [0, 256]
|
225 |
+
)
|
226 |
+
|
227 |
+
source_cdf = self._calculate_cdf(source_histogram)
|
228 |
+
reference_cdf = self._calculate_cdf(reference_histogram)
|
229 |
+
|
230 |
+
lookup = self._calculate_lookup(source_cdf, reference_cdf)
|
231 |
+
|
232 |
+
transformed_channels.append(cv2.LUT(source_channel, lookup))
|
233 |
+
|
234 |
+
result = cv2.merge(transformed_channels)
|
235 |
+
result = cv2.convertScaleAbs(result)
|
236 |
+
|
237 |
+
return result
|
238 |
+
|
239 |
+
def _apply_cropper(self, image, mask, config: InpaintRequest):
|
240 |
+
img_h, img_w = image.shape[:2]
|
241 |
+
l, t, w, h = (
|
242 |
+
config.croper_x,
|
243 |
+
config.croper_y,
|
244 |
+
config.croper_width,
|
245 |
+
config.croper_height,
|
246 |
+
)
|
247 |
+
r = l + w
|
248 |
+
b = t + h
|
249 |
+
|
250 |
+
l = max(l, 0)
|
251 |
+
r = min(r, img_w)
|
252 |
+
t = max(t, 0)
|
253 |
+
b = min(b, img_h)
|
254 |
+
|
255 |
+
crop_img = image[t:b, l:r, :]
|
256 |
+
crop_mask = mask[t:b, l:r]
|
257 |
+
return crop_img, crop_mask, (l, t, r, b)
|
258 |
+
|
259 |
+
def _run_box(self, image, mask, box, config: InpaintRequest):
|
260 |
+
"""
|
261 |
+
|
262 |
+
Args:
|
263 |
+
image: [H, W, C] RGB
|
264 |
+
mask: [H, W, 1]
|
265 |
+
box: [left,top,right,bottom]
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
BGR IMAGE
|
269 |
+
"""
|
270 |
+
crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
|
271 |
+
|
272 |
+
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
|
273 |
+
|
274 |
+
|
275 |
+
class DiffusionInpaintModel(InpaintModel):
|
276 |
+
def __init__(self, device, **kwargs):
|
277 |
+
self.model_info = kwargs["model_info"]
|
278 |
+
self.model_id_or_path = self.model_info.path
|
279 |
+
super().__init__(device, **kwargs)
|
280 |
+
|
281 |
+
@torch.no_grad()
|
282 |
+
def __call__(self, image, mask, config: InpaintRequest):
|
283 |
+
"""
|
284 |
+
images: [H, W, C] RGB, not normalized
|
285 |
+
masks: [H, W]
|
286 |
+
return: BGR IMAGE
|
287 |
+
"""
|
288 |
+
# boxes = boxes_from_mask(mask)
|
289 |
+
if config.use_croper:
|
290 |
+
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
291 |
+
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
|
292 |
+
inpaint_result = image[:, :, ::-1]
|
293 |
+
inpaint_result[t:b, l:r, :] = crop_image
|
294 |
+
elif config.use_extender:
|
295 |
+
inpaint_result = self._do_outpainting(image, config)
|
296 |
+
else:
|
297 |
+
inpaint_result = self._scaled_pad_forward(image, mask, config)
|
298 |
+
|
299 |
+
return inpaint_result
|
300 |
+
|
301 |
+
def _do_outpainting(self, image, config: InpaintRequest):
|
302 |
+
# cropper 和 image 在同一个坐标系下,croper_x/y 可能为负数
|
303 |
+
# 从 image 中 crop 出 outpainting 区域
|
304 |
+
image_h, image_w = image.shape[:2]
|
305 |
+
cropper_l = config.extender_x
|
306 |
+
cropper_t = config.extender_y
|
307 |
+
cropper_r = config.extender_x + config.extender_width
|
308 |
+
cropper_b = config.extender_y + config.extender_height
|
309 |
+
image_l = 0
|
310 |
+
image_t = 0
|
311 |
+
image_r = image_w
|
312 |
+
image_b = image_h
|
313 |
+
|
314 |
+
# 类似求 IOU
|
315 |
+
l = max(cropper_l, image_l)
|
316 |
+
t = max(cropper_t, image_t)
|
317 |
+
r = min(cropper_r, image_r)
|
318 |
+
b = min(cropper_b, image_b)
|
319 |
+
|
320 |
+
assert (
|
321 |
+
0 <= l < r and 0 <= t < b
|
322 |
+
), f"cropper and image not overlap, {l},{t},{r},{b}"
|
323 |
+
|
324 |
+
cropped_image = image[t:b, l:r, :]
|
325 |
+
padding_l = max(0, image_l - cropper_l)
|
326 |
+
padding_t = max(0, image_t - cropper_t)
|
327 |
+
padding_r = max(0, cropper_r - image_r)
|
328 |
+
padding_b = max(0, cropper_b - image_b)
|
329 |
+
|
330 |
+
expanded_image, mask_image = expand_image(
|
331 |
+
cropped_image,
|
332 |
+
left=padding_l,
|
333 |
+
top=padding_t,
|
334 |
+
right=padding_r,
|
335 |
+
bottom=padding_b,
|
336 |
+
softness=config.sd_outpainting_softness,
|
337 |
+
space=config.sd_outpainting_space,
|
338 |
+
)
|
339 |
+
|
340 |
+
# 最终扩大了的 image, BGR
|
341 |
+
expanded_cropped_result_image = self._scaled_pad_forward(
|
342 |
+
expanded_image, mask_image, config
|
343 |
+
)
|
344 |
+
|
345 |
+
# RGB -> BGR
|
346 |
+
outpainting_image = cv2.copyMakeBorder(
|
347 |
+
image,
|
348 |
+
left=padding_l,
|
349 |
+
top=padding_t,
|
350 |
+
right=padding_r,
|
351 |
+
bottom=padding_b,
|
352 |
+
borderType=cv2.BORDER_CONSTANT,
|
353 |
+
value=0,
|
354 |
+
)[:, :, ::-1]
|
355 |
+
|
356 |
+
# 把 cropped_result_image 贴到 outpainting_image 上,这一步不需要 blend
|
357 |
+
paste_t = 0 if config.extender_y < 0 else config.extender_y
|
358 |
+
paste_l = 0 if config.extender_x < 0 else config.extender_x
|
359 |
+
|
360 |
+
outpainting_image[
|
361 |
+
paste_t : paste_t + expanded_cropped_result_image.shape[0],
|
362 |
+
paste_l : paste_l + expanded_cropped_result_image.shape[1],
|
363 |
+
:,
|
364 |
+
] = expanded_cropped_result_image
|
365 |
+
return outpainting_image
|
366 |
+
|
367 |
+
def _scaled_pad_forward(self, image, mask, config: InpaintRequest):
|
368 |
+
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
|
369 |
+
origin_size = image.shape[:2]
|
370 |
+
downsize_image = resize_max_size(image, size_limit=longer_side_length)
|
371 |
+
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
|
372 |
+
if config.sd_scale != 1:
|
373 |
+
logger.info(
|
374 |
+
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
|
375 |
+
)
|
376 |
+
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
|
377 |
+
# only paste masked area result
|
378 |
+
inpaint_result = cv2.resize(
|
379 |
+
inpaint_result,
|
380 |
+
(origin_size[1], origin_size[0]),
|
381 |
+
interpolation=cv2.INTER_CUBIC,
|
382 |
+
)
|
383 |
+
|
384 |
+
# blend result, copy from g_diffuser_bot
|
385 |
+
# mask_rgb = 1.0 - np_img_grey_to_rgb(mask / 255.0)
|
386 |
+
# inpaint_result = np.clip(
|
387 |
+
# inpaint_result * (1.0 - mask_rgb) + image * mask_rgb, 0.0, 255.0
|
388 |
+
# )
|
389 |
+
# original_pixel_indices = mask < 127
|
390 |
+
# inpaint_result[original_pixel_indices] = image[:, :, ::-1][
|
391 |
+
# original_pixel_indices
|
392 |
+
# ]
|
393 |
+
return inpaint_result
|
394 |
+
|
395 |
+
def set_scheduler(self, config: InpaintRequest):
|
396 |
+
scheduler_config = self.model.scheduler.config
|
397 |
+
sd_sampler = config.sd_sampler
|
398 |
+
if config.sd_lcm_lora and self.model_info.support_lcm_lora:
|
399 |
+
sd_sampler = SDSampler.lcm
|
400 |
+
logger.info(f"LCM Lora enabled, use {sd_sampler} sampler")
|
401 |
+
scheduler = get_scheduler(sd_sampler, scheduler_config)
|
402 |
+
self.model.scheduler = scheduler
|
403 |
+
|
404 |
+
def forward_pre_process(self, image, mask, config):
|
405 |
+
if config.sd_mask_blur != 0:
|
406 |
+
k = 2 * config.sd_mask_blur + 1
|
407 |
+
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
|
408 |
+
|
409 |
+
return image, mask
|
410 |
+
|
411 |
+
def forward_post_process(self, result, image, mask, config):
|
412 |
+
if config.sd_match_histograms:
|
413 |
+
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
414 |
+
|
415 |
+
if config.sd_mask_blur != 0:
|
416 |
+
k = 2 * config.sd_mask_blur + 1
|
417 |
+
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
418 |
+
return result, image, mask
|
iopaint/model/helper/__init__.py
ADDED
File without changes
|
iopaint/model/original_sd_configs/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
CURRENT_DIR = Path(__file__).parent.absolute()
|
5 |
+
|
6 |
+
|
7 |
+
def get_config_files() -> Dict[str, Path]:
|
8 |
+
"""
|
9 |
+
- `v1`: Config file for Stable Diffusion v1
|
10 |
+
- `v2`: Config file for Stable Diffusion v2
|
11 |
+
- `xl`: Config file for Stable Diffusion XL
|
12 |
+
- `xl_refiner`: Config file for Stable Diffusion XL Refiner
|
13 |
+
"""
|
14 |
+
return {
|
15 |
+
"v1": CURRENT_DIR / "v1-inference.yaml",
|
16 |
+
"v2": CURRENT_DIR / "v2-inference-v.yaml",
|
17 |
+
"xl": CURRENT_DIR / "sd_xl_base.yaml",
|
18 |
+
"xl_refiner": CURRENT_DIR / "sd_xl_refiner.yaml",
|
19 |
+
}
|
iopaint/model/power_paint/__init__.py
ADDED
File without changes
|
iopaint/plugins/__init__.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
|
3 |
+
from loguru import logger
|
4 |
+
|
5 |
+
from .anime_seg import AnimeSeg
|
6 |
+
from .gfpgan_plugin import GFPGANPlugin
|
7 |
+
from .interactive_seg import InteractiveSeg
|
8 |
+
from .realesrgan import RealESRGANUpscaler
|
9 |
+
from .remove_bg import RemoveBG
|
10 |
+
from .restoreformer import RestoreFormerPlugin
|
11 |
+
from ..schema import InteractiveSegModel, Device, RealESRGANModel
|
12 |
+
|
13 |
+
|
14 |
+
def build_plugins(
|
15 |
+
enable_interactive_seg: bool,
|
16 |
+
interactive_seg_model: InteractiveSegModel,
|
17 |
+
interactive_seg_device: Device,
|
18 |
+
enable_remove_bg: bool,
|
19 |
+
remove_bg_model: str,
|
20 |
+
enable_anime_seg: bool,
|
21 |
+
enable_realesrgan: bool,
|
22 |
+
realesrgan_device: Device,
|
23 |
+
realesrgan_model: RealESRGANModel,
|
24 |
+
enable_gfpgan: bool,
|
25 |
+
gfpgan_device: Device,
|
26 |
+
enable_restoreformer: bool,
|
27 |
+
restoreformer_device: Device,
|
28 |
+
no_half: bool,
|
29 |
+
) -> Dict:
|
30 |
+
plugins = {}
|
31 |
+
if enable_interactive_seg:
|
32 |
+
logger.info(f"Initialize {InteractiveSeg.name} plugin")
|
33 |
+
plugins[InteractiveSeg.name] = InteractiveSeg(
|
34 |
+
interactive_seg_model, interactive_seg_device
|
35 |
+
)
|
36 |
+
|
37 |
+
if enable_remove_bg:
|
38 |
+
logger.info(f"Initialize {RemoveBG.name} plugin")
|
39 |
+
plugins[RemoveBG.name] = RemoveBG(remove_bg_model)
|
40 |
+
|
41 |
+
if enable_anime_seg:
|
42 |
+
logger.info(f"Initialize {AnimeSeg.name} plugin")
|
43 |
+
plugins[AnimeSeg.name] = AnimeSeg()
|
44 |
+
|
45 |
+
if enable_realesrgan:
|
46 |
+
logger.info(
|
47 |
+
f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
|
48 |
+
)
|
49 |
+
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
|
50 |
+
realesrgan_model,
|
51 |
+
realesrgan_device,
|
52 |
+
no_half=no_half,
|
53 |
+
)
|
54 |
+
|
55 |
+
if enable_gfpgan:
|
56 |
+
logger.info(f"Initialize {GFPGANPlugin.name} plugin")
|
57 |
+
if enable_realesrgan:
|
58 |
+
logger.info("Use realesrgan as GFPGAN background upscaler")
|
59 |
+
else:
|
60 |
+
logger.info(
|
61 |
+
f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
|
62 |
+
)
|
63 |
+
plugins[GFPGANPlugin.name] = GFPGANPlugin(
|
64 |
+
gfpgan_device,
|
65 |
+
upscaler=plugins.get(RealESRGANUpscaler.name, None),
|
66 |
+
)
|
67 |
+
|
68 |
+
if enable_restoreformer:
|
69 |
+
logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
|
70 |
+
plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
|
71 |
+
restoreformer_device,
|
72 |
+
upscaler=plugins.get(RealESRGANUpscaler.name, None),
|
73 |
+
)
|
74 |
+
return plugins
|
iopaint/plugins/anime_seg.py
ADDED
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from iopaint.helper import load_model
|
9 |
+
from iopaint.plugins.base_plugin import BasePlugin
|
10 |
+
from iopaint.schema import RunPluginRequest
|
11 |
+
|
12 |
+
|
13 |
+
class REBNCONV(nn.Module):
|
14 |
+
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
|
15 |
+
super(REBNCONV, self).__init__()
|
16 |
+
|
17 |
+
self.conv_s1 = nn.Conv2d(
|
18 |
+
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
|
19 |
+
)
|
20 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
21 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
hx = x
|
25 |
+
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
26 |
+
|
27 |
+
return xout
|
28 |
+
|
29 |
+
|
30 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
31 |
+
def _upsample_like(src, tar):
|
32 |
+
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
|
33 |
+
|
34 |
+
return src
|
35 |
+
|
36 |
+
|
37 |
+
### RSU-7 ###
|
38 |
+
class RSU7(nn.Module):
|
39 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
|
40 |
+
super(RSU7, self).__init__()
|
41 |
+
|
42 |
+
self.in_ch = in_ch
|
43 |
+
self.mid_ch = mid_ch
|
44 |
+
self.out_ch = out_ch
|
45 |
+
|
46 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
|
47 |
+
|
48 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
49 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
50 |
+
|
51 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
52 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
53 |
+
|
54 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
55 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
56 |
+
|
57 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
58 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
59 |
+
|
60 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
61 |
+
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
62 |
+
|
63 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
64 |
+
|
65 |
+
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
66 |
+
|
67 |
+
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
68 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
69 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
70 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
71 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
72 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
b, c, h, w = x.shape
|
76 |
+
|
77 |
+
hx = x
|
78 |
+
hxin = self.rebnconvin(hx)
|
79 |
+
|
80 |
+
hx1 = self.rebnconv1(hxin)
|
81 |
+
hx = self.pool1(hx1)
|
82 |
+
|
83 |
+
hx2 = self.rebnconv2(hx)
|
84 |
+
hx = self.pool2(hx2)
|
85 |
+
|
86 |
+
hx3 = self.rebnconv3(hx)
|
87 |
+
hx = self.pool3(hx3)
|
88 |
+
|
89 |
+
hx4 = self.rebnconv4(hx)
|
90 |
+
hx = self.pool4(hx4)
|
91 |
+
|
92 |
+
hx5 = self.rebnconv5(hx)
|
93 |
+
hx = self.pool5(hx5)
|
94 |
+
|
95 |
+
hx6 = self.rebnconv6(hx)
|
96 |
+
|
97 |
+
hx7 = self.rebnconv7(hx6)
|
98 |
+
|
99 |
+
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
100 |
+
hx6dup = _upsample_like(hx6d, hx5)
|
101 |
+
|
102 |
+
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
103 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
104 |
+
|
105 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
106 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
107 |
+
|
108 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
109 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
110 |
+
|
111 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
112 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
113 |
+
|
114 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
115 |
+
|
116 |
+
return hx1d + hxin
|
117 |
+
|
118 |
+
|
119 |
+
### RSU-6 ###
|
120 |
+
class RSU6(nn.Module):
|
121 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
122 |
+
super(RSU6, self).__init__()
|
123 |
+
|
124 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
125 |
+
|
126 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
127 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
128 |
+
|
129 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
130 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
131 |
+
|
132 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
133 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
134 |
+
|
135 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
136 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
137 |
+
|
138 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
139 |
+
|
140 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
141 |
+
|
142 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
143 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
144 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
145 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
146 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
hx = x
|
150 |
+
|
151 |
+
hxin = self.rebnconvin(hx)
|
152 |
+
|
153 |
+
hx1 = self.rebnconv1(hxin)
|
154 |
+
hx = self.pool1(hx1)
|
155 |
+
|
156 |
+
hx2 = self.rebnconv2(hx)
|
157 |
+
hx = self.pool2(hx2)
|
158 |
+
|
159 |
+
hx3 = self.rebnconv3(hx)
|
160 |
+
hx = self.pool3(hx3)
|
161 |
+
|
162 |
+
hx4 = self.rebnconv4(hx)
|
163 |
+
hx = self.pool4(hx4)
|
164 |
+
|
165 |
+
hx5 = self.rebnconv5(hx)
|
166 |
+
|
167 |
+
hx6 = self.rebnconv6(hx5)
|
168 |
+
|
169 |
+
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
|
170 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
171 |
+
|
172 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
173 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
174 |
+
|
175 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
176 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
177 |
+
|
178 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
179 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
180 |
+
|
181 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
182 |
+
|
183 |
+
return hx1d + hxin
|
184 |
+
|
185 |
+
|
186 |
+
### RSU-5 ###
|
187 |
+
class RSU5(nn.Module):
|
188 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
189 |
+
super(RSU5, self).__init__()
|
190 |
+
|
191 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
192 |
+
|
193 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
194 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
195 |
+
|
196 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
197 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
198 |
+
|
199 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
200 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
201 |
+
|
202 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
203 |
+
|
204 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
205 |
+
|
206 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
207 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
208 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
209 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
210 |
+
|
211 |
+
def forward(self, x):
|
212 |
+
hx = x
|
213 |
+
|
214 |
+
hxin = self.rebnconvin(hx)
|
215 |
+
|
216 |
+
hx1 = self.rebnconv1(hxin)
|
217 |
+
hx = self.pool1(hx1)
|
218 |
+
|
219 |
+
hx2 = self.rebnconv2(hx)
|
220 |
+
hx = self.pool2(hx2)
|
221 |
+
|
222 |
+
hx3 = self.rebnconv3(hx)
|
223 |
+
hx = self.pool3(hx3)
|
224 |
+
|
225 |
+
hx4 = self.rebnconv4(hx)
|
226 |
+
|
227 |
+
hx5 = self.rebnconv5(hx4)
|
228 |
+
|
229 |
+
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
|
230 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
231 |
+
|
232 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
233 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
234 |
+
|
235 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
236 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
237 |
+
|
238 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
239 |
+
|
240 |
+
return hx1d + hxin
|
241 |
+
|
242 |
+
|
243 |
+
### RSU-4 ###
|
244 |
+
class RSU4(nn.Module):
|
245 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
246 |
+
super(RSU4, self).__init__()
|
247 |
+
|
248 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
249 |
+
|
250 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
251 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
252 |
+
|
253 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
254 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
255 |
+
|
256 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
257 |
+
|
258 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
259 |
+
|
260 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
261 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
262 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
263 |
+
|
264 |
+
def forward(self, x):
|
265 |
+
hx = x
|
266 |
+
|
267 |
+
hxin = self.rebnconvin(hx)
|
268 |
+
|
269 |
+
hx1 = self.rebnconv1(hxin)
|
270 |
+
hx = self.pool1(hx1)
|
271 |
+
|
272 |
+
hx2 = self.rebnconv2(hx)
|
273 |
+
hx = self.pool2(hx2)
|
274 |
+
|
275 |
+
hx3 = self.rebnconv3(hx)
|
276 |
+
|
277 |
+
hx4 = self.rebnconv4(hx3)
|
278 |
+
|
279 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
280 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
281 |
+
|
282 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
283 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
284 |
+
|
285 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
286 |
+
|
287 |
+
return hx1d + hxin
|
288 |
+
|
289 |
+
|
290 |
+
### RSU-4F ###
|
291 |
+
class RSU4F(nn.Module):
|
292 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
293 |
+
super(RSU4F, self).__init__()
|
294 |
+
|
295 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
296 |
+
|
297 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
298 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
299 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
|
300 |
+
|
301 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
|
302 |
+
|
303 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
|
304 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
|
305 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
306 |
+
|
307 |
+
def forward(self, x):
|
308 |
+
hx = x
|
309 |
+
|
310 |
+
hxin = self.rebnconvin(hx)
|
311 |
+
|
312 |
+
hx1 = self.rebnconv1(hxin)
|
313 |
+
hx2 = self.rebnconv2(hx1)
|
314 |
+
hx3 = self.rebnconv3(hx2)
|
315 |
+
|
316 |
+
hx4 = self.rebnconv4(hx3)
|
317 |
+
|
318 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
319 |
+
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
|
320 |
+
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
|
321 |
+
|
322 |
+
return hx1d + hxin
|
323 |
+
|
324 |
+
|
325 |
+
class ISNetDIS(nn.Module):
|
326 |
+
def __init__(self, in_ch=3, out_ch=1):
|
327 |
+
super(ISNetDIS, self).__init__()
|
328 |
+
|
329 |
+
self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
|
330 |
+
self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
331 |
+
|
332 |
+
self.stage1 = RSU7(64, 32, 64)
|
333 |
+
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
334 |
+
|
335 |
+
self.stage2 = RSU6(64, 32, 128)
|
336 |
+
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
337 |
+
|
338 |
+
self.stage3 = RSU5(128, 64, 256)
|
339 |
+
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
340 |
+
|
341 |
+
self.stage4 = RSU4(256, 128, 512)
|
342 |
+
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
343 |
+
|
344 |
+
self.stage5 = RSU4F(512, 256, 512)
|
345 |
+
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
346 |
+
|
347 |
+
self.stage6 = RSU4F(512, 256, 512)
|
348 |
+
|
349 |
+
# decoder
|
350 |
+
self.stage5d = RSU4F(1024, 256, 512)
|
351 |
+
self.stage4d = RSU4(1024, 128, 256)
|
352 |
+
self.stage3d = RSU5(512, 64, 128)
|
353 |
+
self.stage2d = RSU6(256, 32, 64)
|
354 |
+
self.stage1d = RSU7(128, 16, 64)
|
355 |
+
|
356 |
+
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
357 |
+
|
358 |
+
def forward(self, x):
|
359 |
+
hx = x
|
360 |
+
|
361 |
+
hxin = self.conv_in(hx)
|
362 |
+
hx = self.pool_in(hxin)
|
363 |
+
|
364 |
+
# stage 1
|
365 |
+
hx1 = self.stage1(hxin)
|
366 |
+
hx = self.pool12(hx1)
|
367 |
+
|
368 |
+
# stage 2
|
369 |
+
hx2 = self.stage2(hx)
|
370 |
+
hx = self.pool23(hx2)
|
371 |
+
|
372 |
+
# stage 3
|
373 |
+
hx3 = self.stage3(hx)
|
374 |
+
hx = self.pool34(hx3)
|
375 |
+
|
376 |
+
# stage 4
|
377 |
+
hx4 = self.stage4(hx)
|
378 |
+
hx = self.pool45(hx4)
|
379 |
+
|
380 |
+
# stage 5
|
381 |
+
hx5 = self.stage5(hx)
|
382 |
+
hx = self.pool56(hx5)
|
383 |
+
|
384 |
+
# stage 6
|
385 |
+
hx6 = self.stage6(hx)
|
386 |
+
hx6up = _upsample_like(hx6, hx5)
|
387 |
+
|
388 |
+
# -------------------- decoder --------------------
|
389 |
+
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
390 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
391 |
+
|
392 |
+
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
393 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
394 |
+
|
395 |
+
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
396 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
397 |
+
|
398 |
+
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
399 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
400 |
+
|
401 |
+
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
402 |
+
|
403 |
+
# side output
|
404 |
+
d1 = self.side1(hx1d)
|
405 |
+
d1 = _upsample_like(d1, x)
|
406 |
+
return d1.sigmoid()
|
407 |
+
|
408 |
+
|
409 |
+
# 从小到大
|
410 |
+
ANIME_SEG_MODELS = {
|
411 |
+
"url": "https://github.com/Sanster/models/releases/download/isnetis/isnetis.pth",
|
412 |
+
"md5": "5f25479076b73074730ab8de9e8f2051",
|
413 |
+
}
|
414 |
+
|
415 |
+
|
416 |
+
class AnimeSeg(BasePlugin):
|
417 |
+
# Model from: https://github.com/SkyTNT/anime-segmentation
|
418 |
+
name = "AnimeSeg"
|
419 |
+
support_gen_image = True
|
420 |
+
support_gen_mask = True
|
421 |
+
|
422 |
+
def __init__(self):
|
423 |
+
super().__init__()
|
424 |
+
self.model = load_model(
|
425 |
+
ISNetDIS(),
|
426 |
+
ANIME_SEG_MODELS["url"],
|
427 |
+
"cpu",
|
428 |
+
ANIME_SEG_MODELS["md5"],
|
429 |
+
)
|
430 |
+
|
431 |
+
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
432 |
+
mask = self.forward(rgb_np_img)
|
433 |
+
mask = Image.fromarray(mask, mode="L")
|
434 |
+
h0, w0 = rgb_np_img.shape[0], rgb_np_img.shape[1]
|
435 |
+
empty = Image.new("RGBA", (w0, h0), 0)
|
436 |
+
img = Image.fromarray(rgb_np_img)
|
437 |
+
cutout = Image.composite(img, empty, mask)
|
438 |
+
return np.asarray(cutout)
|
439 |
+
|
440 |
+
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
441 |
+
return self.forward(rgb_np_img)
|
442 |
+
|
443 |
+
@torch.inference_mode()
|
444 |
+
def forward(self, rgb_np_img):
|
445 |
+
s = 1024
|
446 |
+
|
447 |
+
h0, w0 = h, w = rgb_np_img.shape[0], rgb_np_img.shape[1]
|
448 |
+
if h > w:
|
449 |
+
h, w = s, int(s * w / h)
|
450 |
+
else:
|
451 |
+
h, w = int(s * h / w), s
|
452 |
+
ph, pw = s - h, s - w
|
453 |
+
tmpImg = np.zeros([s, s, 3], dtype=np.float32)
|
454 |
+
tmpImg[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] = (
|
455 |
+
cv2.resize(rgb_np_img, (w, h)) / 255
|
456 |
+
)
|
457 |
+
tmpImg = tmpImg.transpose((2, 0, 1))
|
458 |
+
tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor)
|
459 |
+
mask = self.model(tmpImg)
|
460 |
+
mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
|
461 |
+
mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0))
|
462 |
+
return (mask * 255).astype("uint8")
|
iopaint/plugins/base_plugin.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from loguru import logger
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from iopaint.schema import RunPluginRequest
|
5 |
+
|
6 |
+
|
7 |
+
class BasePlugin:
|
8 |
+
name: str
|
9 |
+
support_gen_image: bool = False
|
10 |
+
support_gen_mask: bool = False
|
11 |
+
|
12 |
+
def __init__(self):
|
13 |
+
err_msg = self.check_dep()
|
14 |
+
if err_msg:
|
15 |
+
logger.error(err_msg)
|
16 |
+
exit(-1)
|
17 |
+
|
18 |
+
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
19 |
+
# return RGBA np image or BGR np image
|
20 |
+
...
|
21 |
+
|
22 |
+
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
23 |
+
# return GRAY or BGR np image, 255 means foreground, 0 means background
|
24 |
+
...
|
25 |
+
|
26 |
+
def check_dep(self):
|
27 |
+
...
|
28 |
+
|
29 |
+
def switch_model(self, new_model_name: str):
|
30 |
+
...
|
iopaint/plugins/segment_anything/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from .build_sam import (
|
8 |
+
build_sam,
|
9 |
+
build_sam_vit_h,
|
10 |
+
build_sam_vit_l,
|
11 |
+
build_sam_vit_b,
|
12 |
+
sam_model_registry,
|
13 |
+
)
|
14 |
+
from .predictor import SamPredictor
|
iopaint/plugins/segment_anything/modeling/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from .sam import Sam
|
8 |
+
from .image_encoder import ImageEncoderViT
|
9 |
+
from .mask_decoder import MaskDecoder
|
10 |
+
from .prompt_encoder import PromptEncoder
|
11 |
+
from .transformer import TwoWayTransformer
|
iopaint/plugins/segment_anything/utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
iopaint/tests/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*_result.png
|
2 |
+
result/
|
iopaint/tests/__init__.py
ADDED
File without changes
|
model/__init__.py
ADDED
File without changes
|
utils/__init__.py
ADDED
File without changes
|