Spaces:
Runtime error
Runtime error
Update
Browse files- anime-girl.jpg +0 -0
- app.py +4 -5
- rembg/_version.py +3 -3
- rembg/bg.py +14 -1
- rembg/commands/b_command.py +161 -0
- rembg/commands/s_command.py +58 -11
- rembg/session_factory.py +4 -2
- rembg/sessions/base.py +24 -2
- rembg/sessions/dis_anime.py +49 -0
- rembg/sessions/dis_general_use.py +49 -0
- rembg/sessions/sam.py +8 -4
- rembg/sessions/silueta.py +4 -2
- rembg/sessions/u2net.py +4 -2
- rembg/sessions/u2net_cloth_seg.py +4 -2
- rembg/sessions/u2net_human_seg.py +4 -2
- rembg/sessions/u2netp.py +4 -2
anime-girl.jpg
ADDED
app.py
CHANGED
@@ -5,7 +5,7 @@ import gradio as gr
|
|
5 |
import os
|
6 |
import cv2
|
7 |
|
8 |
-
def inference(file,
|
9 |
im = cv2.imread(file, cv2.IMREAD_COLOR)
|
10 |
cv2.imwrite(os.path.join("input.png"), im)
|
11 |
|
@@ -20,7 +20,6 @@ def inference(file, af, mask, model):
|
|
20 |
output = remove(
|
21 |
input,
|
22 |
session = new_session(model),
|
23 |
-
alpha_matting_erode_size = af,
|
24 |
only_mask = (True if mask == "Mask only" else False)
|
25 |
)
|
26 |
|
@@ -38,7 +37,6 @@ gr.Interface(
|
|
38 |
inference,
|
39 |
[
|
40 |
gr.inputs.Image(type="filepath", label="Input"),
|
41 |
-
gr.inputs.Slider(10, 25, default=10, label="Alpha matting erode size"),
|
42 |
gr.inputs.Radio(
|
43 |
[
|
44 |
"Default",
|
@@ -55,10 +53,11 @@ gr.Interface(
|
|
55 |
"u2net_cloth_seg",
|
56 |
"silueta",
|
57 |
"isnet-general-use",
|
|
|
58 |
"sam",
|
59 |
],
|
60 |
type="value",
|
61 |
-
default="
|
62 |
label="Models"
|
63 |
),
|
64 |
],
|
@@ -66,6 +65,6 @@ gr.Interface(
|
|
66 |
title=title,
|
67 |
description=description,
|
68 |
article=article,
|
69 |
-
examples=[["lion.png",
|
70 |
enable_queue=True
|
71 |
).launch()
|
|
|
5 |
import os
|
6 |
import cv2
|
7 |
|
8 |
+
def inference(file, mask, model):
|
9 |
im = cv2.imread(file, cv2.IMREAD_COLOR)
|
10 |
cv2.imwrite(os.path.join("input.png"), im)
|
11 |
|
|
|
20 |
output = remove(
|
21 |
input,
|
22 |
session = new_session(model),
|
|
|
23 |
only_mask = (True if mask == "Mask only" else False)
|
24 |
)
|
25 |
|
|
|
37 |
inference,
|
38 |
[
|
39 |
gr.inputs.Image(type="filepath", label="Input"),
|
|
|
40 |
gr.inputs.Radio(
|
41 |
[
|
42 |
"Default",
|
|
|
53 |
"u2net_cloth_seg",
|
54 |
"silueta",
|
55 |
"isnet-general-use",
|
56 |
+
"isnet-anime",
|
57 |
"sam",
|
58 |
],
|
59 |
type="value",
|
60 |
+
default="isnet-general-use",
|
61 |
label="Models"
|
62 |
),
|
63 |
],
|
|
|
65 |
title=title,
|
66 |
description=description,
|
67 |
article=article,
|
68 |
+
examples=[["lion.png", "Default", "u2net"], ["girl.jpg", "Default", "u2net"], ["anime-girl.jpg", "Default", "isnet-anime"]],
|
69 |
enable_queue=True
|
70 |
).launch()
|
rembg/_version.py
CHANGED
@@ -23,9 +23,9 @@ def get_keywords():
|
|
23 |
# setup.py/versioneer.py will grep for the variable names, so they must
|
24 |
# each be defined on a line of their own. _version.py will just call
|
25 |
# get_keywords().
|
26 |
-
git_refnames = " (HEAD -> main)"
|
27 |
-
git_full = "
|
28 |
-
git_date = "2023-
|
29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
30 |
return keywords
|
31 |
|
|
|
23 |
# setup.py/versioneer.py will grep for the variable names, so they must
|
24 |
# each be defined on a line of their own. _version.py will just call
|
25 |
# get_keywords().
|
26 |
+
git_refnames = " (HEAD -> main, tag: v2.0.43)"
|
27 |
+
git_full = "848a38e4cc5cf41522974dea00848596105b1dfa"
|
28 |
+
git_date = "2023-06-02 09:20:57 -0300"
|
29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
30 |
return keywords
|
31 |
|
rembg/bg.py
CHANGED
@@ -11,7 +11,7 @@ from cv2 import (
|
|
11 |
getStructuringElement,
|
12 |
morphologyEx,
|
13 |
)
|
14 |
-
from PIL import Image
|
15 |
from PIL.Image import Image as PILImage
|
16 |
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
17 |
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
@@ -19,6 +19,7 @@ from pymatting.util.util import stack_images
|
|
19 |
from scipy.ndimage import binary_erosion
|
20 |
|
21 |
from .session_factory import new_session
|
|
|
22 |
from .sessions.base import BaseSession
|
23 |
|
24 |
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
|
@@ -113,6 +114,15 @@ def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> P
|
|
113 |
return colored_image
|
114 |
|
115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
def remove(
|
117 |
data: Union[bytes, PILImage, np.ndarray],
|
118 |
alpha_matting: bool = False,
|
@@ -138,6 +148,9 @@ def remove(
|
|
138 |
else:
|
139 |
raise ValueError("Input type {} is not supported.".format(type(data)))
|
140 |
|
|
|
|
|
|
|
141 |
if session is None:
|
142 |
session = new_session("u2net", *args, **kwargs)
|
143 |
|
|
|
11 |
getStructuringElement,
|
12 |
morphologyEx,
|
13 |
)
|
14 |
+
from PIL import Image, ImageOps
|
15 |
from PIL.Image import Image as PILImage
|
16 |
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
17 |
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
|
|
19 |
from scipy.ndimage import binary_erosion
|
20 |
|
21 |
from .session_factory import new_session
|
22 |
+
from .sessions import sessions_class
|
23 |
from .sessions.base import BaseSession
|
24 |
|
25 |
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
|
|
|
114 |
return colored_image
|
115 |
|
116 |
|
117 |
+
def fix_image_orientation(img: PILImage) -> PILImage:
|
118 |
+
return ImageOps.exif_transpose(img)
|
119 |
+
|
120 |
+
|
121 |
+
def download_models() -> None:
|
122 |
+
for session in sessions_class:
|
123 |
+
session.download_models()
|
124 |
+
|
125 |
+
|
126 |
def remove(
|
127 |
data: Union[bytes, PILImage, np.ndarray],
|
128 |
alpha_matting: bool = False,
|
|
|
148 |
else:
|
149 |
raise ValueError("Input type {} is not supported.".format(type(data)))
|
150 |
|
151 |
+
# Fix image orientation
|
152 |
+
img = fix_image_orientation(img)
|
153 |
+
|
154 |
if session is None:
|
155 |
session = new_session("u2net", *args, **kwargs)
|
156 |
|
rembg/commands/b_command.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import io
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
from typing import IO
|
7 |
+
|
8 |
+
import click
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
from ..bg import remove
|
12 |
+
from ..session_factory import new_session
|
13 |
+
from ..sessions import sessions_names
|
14 |
+
|
15 |
+
|
16 |
+
@click.command(
|
17 |
+
name="b",
|
18 |
+
help="for a byte stream as input",
|
19 |
+
)
|
20 |
+
@click.option(
|
21 |
+
"-m",
|
22 |
+
"--model",
|
23 |
+
default="u2net",
|
24 |
+
type=click.Choice(sessions_names),
|
25 |
+
show_default=True,
|
26 |
+
show_choices=True,
|
27 |
+
help="model name",
|
28 |
+
)
|
29 |
+
@click.option(
|
30 |
+
"-a",
|
31 |
+
"--alpha-matting",
|
32 |
+
is_flag=True,
|
33 |
+
show_default=True,
|
34 |
+
help="use alpha matting",
|
35 |
+
)
|
36 |
+
@click.option(
|
37 |
+
"-af",
|
38 |
+
"--alpha-matting-foreground-threshold",
|
39 |
+
default=240,
|
40 |
+
type=int,
|
41 |
+
show_default=True,
|
42 |
+
help="trimap fg threshold",
|
43 |
+
)
|
44 |
+
@click.option(
|
45 |
+
"-ab",
|
46 |
+
"--alpha-matting-background-threshold",
|
47 |
+
default=10,
|
48 |
+
type=int,
|
49 |
+
show_default=True,
|
50 |
+
help="trimap bg threshold",
|
51 |
+
)
|
52 |
+
@click.option(
|
53 |
+
"-ae",
|
54 |
+
"--alpha-matting-erode-size",
|
55 |
+
default=10,
|
56 |
+
type=int,
|
57 |
+
show_default=True,
|
58 |
+
help="erode size",
|
59 |
+
)
|
60 |
+
@click.option(
|
61 |
+
"-om",
|
62 |
+
"--only-mask",
|
63 |
+
is_flag=True,
|
64 |
+
show_default=True,
|
65 |
+
help="output only the mask",
|
66 |
+
)
|
67 |
+
@click.option(
|
68 |
+
"-ppm",
|
69 |
+
"--post-process-mask",
|
70 |
+
is_flag=True,
|
71 |
+
show_default=True,
|
72 |
+
help="post process the mask",
|
73 |
+
)
|
74 |
+
@click.option(
|
75 |
+
"-bgc",
|
76 |
+
"--bgcolor",
|
77 |
+
default=None,
|
78 |
+
type=(int, int, int, int),
|
79 |
+
nargs=4,
|
80 |
+
help="Background color (R G B A) to replace the removed background with",
|
81 |
+
)
|
82 |
+
@click.option("-x", "--extras", type=str)
|
83 |
+
@click.option(
|
84 |
+
"-o",
|
85 |
+
"--output_specifier",
|
86 |
+
type=str,
|
87 |
+
help="printf-style specifier for output filenames (e.g. 'output-%d.png'))",
|
88 |
+
)
|
89 |
+
@click.argument(
|
90 |
+
"image_width",
|
91 |
+
type=int,
|
92 |
+
)
|
93 |
+
@click.argument(
|
94 |
+
"image_height",
|
95 |
+
type=int,
|
96 |
+
)
|
97 |
+
def rs_command(
|
98 |
+
model: str,
|
99 |
+
extras: str,
|
100 |
+
image_width: int,
|
101 |
+
image_height: int,
|
102 |
+
output_specifier: str,
|
103 |
+
**kwargs
|
104 |
+
) -> None:
|
105 |
+
try:
|
106 |
+
kwargs.update(json.loads(extras))
|
107 |
+
except Exception:
|
108 |
+
pass
|
109 |
+
|
110 |
+
session = new_session(model)
|
111 |
+
bytes_per_img = image_width * image_height * 3
|
112 |
+
|
113 |
+
if output_specifier:
|
114 |
+
output_dir = os.path.dirname(
|
115 |
+
os.path.abspath(os.path.expanduser(output_specifier))
|
116 |
+
)
|
117 |
+
|
118 |
+
if not os.path.isdir(output_dir):
|
119 |
+
os.makedirs(output_dir, exist_ok=True)
|
120 |
+
|
121 |
+
def img_to_byte_array(img: Image) -> bytes:
|
122 |
+
buff = io.BytesIO()
|
123 |
+
img.save(buff, format="PNG")
|
124 |
+
return buff.getvalue()
|
125 |
+
|
126 |
+
async def connect_stdin_stdout():
|
127 |
+
loop = asyncio.get_event_loop()
|
128 |
+
reader = asyncio.StreamReader()
|
129 |
+
protocol = asyncio.StreamReaderProtocol(reader)
|
130 |
+
|
131 |
+
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
|
132 |
+
w_transport, w_protocol = await loop.connect_write_pipe(
|
133 |
+
asyncio.streams.FlowControlMixin, sys.stdout
|
134 |
+
)
|
135 |
+
|
136 |
+
writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
|
137 |
+
return reader, writer
|
138 |
+
|
139 |
+
async def main():
|
140 |
+
reader, writer = await connect_stdin_stdout()
|
141 |
+
|
142 |
+
idx = 0
|
143 |
+
while True:
|
144 |
+
try:
|
145 |
+
img_bytes = await reader.readexactly(bytes_per_img)
|
146 |
+
if not img_bytes:
|
147 |
+
break
|
148 |
+
|
149 |
+
img = Image.frombytes("RGB", (image_width, image_height), img_bytes)
|
150 |
+
output = remove(img, session=session, **kwargs)
|
151 |
+
|
152 |
+
if output_specifier:
|
153 |
+
output.save((output_specifier % idx), format="PNG")
|
154 |
+
else:
|
155 |
+
writer.write(img_to_byte_array(output))
|
156 |
+
|
157 |
+
idx += 1
|
158 |
+
except asyncio.IncompleteReadError:
|
159 |
+
break
|
160 |
+
|
161 |
+
asyncio.run(main())
|
rembg/commands/s_command.py
CHANGED
@@ -1,8 +1,11 @@
|
|
1 |
import json
|
2 |
-
|
|
|
|
|
3 |
|
4 |
import aiohttp
|
5 |
import click
|
|
|
6 |
import uvicorn
|
7 |
from asyncer import asyncify
|
8 |
from fastapi import Depends, FastAPI, File, Form, Query
|
@@ -70,6 +73,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
70 |
"url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
|
71 |
},
|
72 |
openapi_tags=tags_metadata,
|
|
|
73 |
)
|
74 |
|
75 |
app.add_middleware(
|
@@ -83,10 +87,10 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
83 |
class CommonQueryParams:
|
84 |
def __init__(
|
85 |
self,
|
86 |
-
model:
|
87 |
-
str, Query(regex=r"(" + "|".join(sessions_names) + ")")
|
88 |
-
] = Query(
|
89 |
description="Model to use when processing image",
|
|
|
|
|
90 |
),
|
91 |
a: bool = Query(default=False, description="Enable Alpha Matting"),
|
92 |
af: int = Query(
|
@@ -128,10 +132,10 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
128 |
class CommonQueryPostParams:
|
129 |
def __init__(
|
130 |
self,
|
131 |
-
model:
|
132 |
-
str, Form(regex=r"(" + "|".join(sessions_names) + ")")
|
133 |
-
] = Form(
|
134 |
description="Model to use when processing image",
|
|
|
|
|
135 |
),
|
136 |
a: bool = Form(default=False, description="Enable Alpha Matting"),
|
137 |
af: int = Form(
|
@@ -190,13 +194,18 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
190 |
only_mask=commons.om,
|
191 |
post_process_mask=commons.ppm,
|
192 |
bgcolor=commons.bgc,
|
193 |
-
**kwargs
|
194 |
),
|
195 |
media_type="image/png",
|
196 |
)
|
197 |
|
198 |
@app.on_event("startup")
|
199 |
def startup():
|
|
|
|
|
|
|
|
|
|
|
200 |
if threads is not None:
|
201 |
from anyio import CapacityLimiter
|
202 |
from anyio.lowlevel import RunVar
|
@@ -204,7 +213,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
204 |
RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
|
205 |
|
206 |
@app.get(
|
207 |
-
path="/",
|
208 |
tags=["Background Removal"],
|
209 |
summary="Remove from URL",
|
210 |
description="Removes the background from an image obtained by retrieving an URL.",
|
@@ -221,7 +230,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
221 |
return await asyncify(im_without_bg)(file, commons)
|
222 |
|
223 |
@app.post(
|
224 |
-
path="/",
|
225 |
tags=["Background Removal"],
|
226 |
summary="Remove from Stream",
|
227 |
description="Removes the background from an image sent within the request itself.",
|
@@ -235,4 +244,42 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
235 |
):
|
236 |
return await asyncify(im_without_bg)(file, commons) # type: ignore
|
237 |
|
238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import json
|
2 |
+
import os
|
3 |
+
import webbrowser
|
4 |
+
from typing import Optional, Tuple, cast
|
5 |
|
6 |
import aiohttp
|
7 |
import click
|
8 |
+
import gradio as gr
|
9 |
import uvicorn
|
10 |
from asyncer import asyncify
|
11 |
from fastapi import Depends, FastAPI, File, Form, Query
|
|
|
73 |
"url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
|
74 |
},
|
75 |
openapi_tags=tags_metadata,
|
76 |
+
docs_url="/api",
|
77 |
)
|
78 |
|
79 |
app.add_middleware(
|
|
|
87 |
class CommonQueryParams:
|
88 |
def __init__(
|
89 |
self,
|
90 |
+
model: str = Query(
|
|
|
|
|
91 |
description="Model to use when processing image",
|
92 |
+
regex=r"(" + "|".join(sessions_names) + ")",
|
93 |
+
default="u2net",
|
94 |
),
|
95 |
a: bool = Query(default=False, description="Enable Alpha Matting"),
|
96 |
af: int = Query(
|
|
|
132 |
class CommonQueryPostParams:
|
133 |
def __init__(
|
134 |
self,
|
135 |
+
model: str = Form(
|
|
|
|
|
136 |
description="Model to use when processing image",
|
137 |
+
regex=r"(" + "|".join(sessions_names) + ")",
|
138 |
+
default="u2net",
|
139 |
),
|
140 |
a: bool = Form(default=False, description="Enable Alpha Matting"),
|
141 |
af: int = Form(
|
|
|
194 |
only_mask=commons.om,
|
195 |
post_process_mask=commons.ppm,
|
196 |
bgcolor=commons.bgc,
|
197 |
+
**kwargs,
|
198 |
),
|
199 |
media_type="image/png",
|
200 |
)
|
201 |
|
202 |
@app.on_event("startup")
|
203 |
def startup():
|
204 |
+
try:
|
205 |
+
webbrowser.open(f"http://localhost:{port}")
|
206 |
+
except Exception:
|
207 |
+
pass
|
208 |
+
|
209 |
if threads is not None:
|
210 |
from anyio import CapacityLimiter
|
211 |
from anyio.lowlevel import RunVar
|
|
|
213 |
RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
|
214 |
|
215 |
@app.get(
|
216 |
+
path="/api/remove",
|
217 |
tags=["Background Removal"],
|
218 |
summary="Remove from URL",
|
219 |
description="Removes the background from an image obtained by retrieving an URL.",
|
|
|
230 |
return await asyncify(im_without_bg)(file, commons)
|
231 |
|
232 |
@app.post(
|
233 |
+
path="/api/remove",
|
234 |
tags=["Background Removal"],
|
235 |
summary="Remove from Stream",
|
236 |
description="Removes the background from an image sent within the request itself.",
|
|
|
244 |
):
|
245 |
return await asyncify(im_without_bg)(file, commons) # type: ignore
|
246 |
|
247 |
+
def gr_app(app):
|
248 |
+
def inference(input_path, model):
|
249 |
+
output_path = "output.png"
|
250 |
+
with open(input_path, "rb") as i:
|
251 |
+
with open(output_path, "wb") as o:
|
252 |
+
input = i.read()
|
253 |
+
output = remove(input, session=new_session(model))
|
254 |
+
o.write(output)
|
255 |
+
return os.path.join(output_path)
|
256 |
+
|
257 |
+
interface = gr.Interface(
|
258 |
+
inference,
|
259 |
+
[
|
260 |
+
gr.components.Image(type="filepath", label="Input"),
|
261 |
+
gr.components.Dropdown(
|
262 |
+
[
|
263 |
+
"u2net",
|
264 |
+
"u2netp",
|
265 |
+
"u2net_human_seg",
|
266 |
+
"u2net_cloth_seg",
|
267 |
+
"silueta",
|
268 |
+
"isnet-general-use",
|
269 |
+
"isnet-anime",
|
270 |
+
],
|
271 |
+
value="u2net",
|
272 |
+
label="Models",
|
273 |
+
),
|
274 |
+
],
|
275 |
+
gr.components.Image(type="filepath", label="Output"),
|
276 |
+
)
|
277 |
+
|
278 |
+
interface.queue(concurrency_count=3)
|
279 |
+
app = gr.mount_gradio_app(app, interface, path="/")
|
280 |
+
return app
|
281 |
+
|
282 |
+
print(f"To access the API documentation, go to http://localhost:{port}/api")
|
283 |
+
print(f"To access the UI, go to http://localhost:{port}")
|
284 |
+
|
285 |
+
uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level)
|
rembg/session_factory.py
CHANGED
@@ -8,7 +8,9 @@ from .sessions.base import BaseSession
|
|
8 |
from .sessions.u2net import U2netSession
|
9 |
|
10 |
|
11 |
-
def new_session(
|
|
|
|
|
12 |
session_class: Type[BaseSession] = U2netSession
|
13 |
|
14 |
for sc in sessions_class:
|
@@ -21,4 +23,4 @@ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
|
|
21 |
if "OMP_NUM_THREADS" in os.environ:
|
22 |
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
23 |
|
24 |
-
return session_class(model_name, sess_opts, *args, **kwargs)
|
|
|
8 |
from .sessions.u2net import U2netSession
|
9 |
|
10 |
|
11 |
+
def new_session(
|
12 |
+
model_name: str = "u2net", providers=None, *args, **kwargs
|
13 |
+
) -> BaseSession:
|
14 |
session_class: Type[BaseSession] = U2netSession
|
15 |
|
16 |
for sc in sessions_class:
|
|
|
23 |
if "OMP_NUM_THREADS" in os.environ:
|
24 |
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
25 |
|
26 |
+
return session_class(model_name, sess_opts, providers, *args, **kwargs)
|
rembg/sessions/base.py
CHANGED
@@ -8,11 +8,29 @@ from PIL.Image import Image as PILImage
|
|
8 |
|
9 |
|
10 |
class BaseSession:
|
11 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
self.model_name = model_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
self.inner_session = ort.InferenceSession(
|
14 |
str(self.__class__.download_models()),
|
15 |
-
providers=
|
16 |
sess_options=sess_opts,
|
17 |
)
|
18 |
|
@@ -46,6 +64,10 @@ class BaseSession:
|
|
46 |
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
47 |
raise NotImplementedError
|
48 |
|
|
|
|
|
|
|
|
|
49 |
@classmethod
|
50 |
def u2net_home(cls, *args, **kwargs):
|
51 |
return os.path.expanduser(
|
|
|
8 |
|
9 |
|
10 |
class BaseSession:
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
model_name: str,
|
14 |
+
sess_opts: ort.SessionOptions,
|
15 |
+
providers=None,
|
16 |
+
*args,
|
17 |
+
**kwargs
|
18 |
+
):
|
19 |
self.model_name = model_name
|
20 |
+
|
21 |
+
self.providers = []
|
22 |
+
|
23 |
+
_providers = ort.get_available_providers()
|
24 |
+
if providers:
|
25 |
+
for provider in providers:
|
26 |
+
if provider in _providers:
|
27 |
+
self.providers.append(provider)
|
28 |
+
else:
|
29 |
+
self.providers.extend(_providers)
|
30 |
+
|
31 |
self.inner_session = ort.InferenceSession(
|
32 |
str(self.__class__.download_models()),
|
33 |
+
providers=self.providers,
|
34 |
sess_options=sess_opts,
|
35 |
)
|
36 |
|
|
|
64 |
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
65 |
raise NotImplementedError
|
66 |
|
67 |
+
@classmethod
|
68 |
+
def checksum_disabled(cls, *args, **kwargs):
|
69 |
+
return os.getenv("MODEL_CHECKSUM_DISABLED", None) is not None
|
70 |
+
|
71 |
@classmethod
|
72 |
def u2net_home(cls, *args, **kwargs):
|
73 |
return os.path.expanduser(
|
rembg/sessions/dis_anime.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import pooch
|
6 |
+
from PIL import Image
|
7 |
+
from PIL.Image import Image as PILImage
|
8 |
+
|
9 |
+
from .base import BaseSession
|
10 |
+
|
11 |
+
|
12 |
+
class DisSession(BaseSession):
|
13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
14 |
+
ort_outs = self.inner_session.run(
|
15 |
+
None,
|
16 |
+
self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
|
17 |
+
)
|
18 |
+
|
19 |
+
pred = ort_outs[0][:, 0, :, :]
|
20 |
+
|
21 |
+
ma = np.max(pred)
|
22 |
+
mi = np.min(pred)
|
23 |
+
|
24 |
+
pred = (pred - mi) / (ma - mi)
|
25 |
+
pred = np.squeeze(pred)
|
26 |
+
|
27 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
28 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
29 |
+
|
30 |
+
return [mask]
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def download_models(cls, *args, **kwargs):
|
34 |
+
fname = f"{cls.name()}.onnx"
|
35 |
+
pooch.retrieve(
|
36 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
|
37 |
+
None
|
38 |
+
if cls.checksum_disabled(*args, **kwargs)
|
39 |
+
else "md5:6f184e756bb3bd901c8849220a83e38e",
|
40 |
+
fname=fname,
|
41 |
+
path=cls.u2net_home(*args, **kwargs),
|
42 |
+
progressbar=True,
|
43 |
+
)
|
44 |
+
|
45 |
+
return os.path.join(cls.u2net_home(), fname)
|
46 |
+
|
47 |
+
@classmethod
|
48 |
+
def name(cls, *args, **kwargs):
|
49 |
+
return "isnet-anime"
|
rembg/sessions/dis_general_use.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import pooch
|
6 |
+
from PIL import Image
|
7 |
+
from PIL.Image import Image as PILImage
|
8 |
+
|
9 |
+
from .base import BaseSession
|
10 |
+
|
11 |
+
|
12 |
+
class DisSession(BaseSession):
|
13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
14 |
+
ort_outs = self.inner_session.run(
|
15 |
+
None,
|
16 |
+
self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
|
17 |
+
)
|
18 |
+
|
19 |
+
pred = ort_outs[0][:, 0, :, :]
|
20 |
+
|
21 |
+
ma = np.max(pred)
|
22 |
+
mi = np.min(pred)
|
23 |
+
|
24 |
+
pred = (pred - mi) / (ma - mi)
|
25 |
+
pred = np.squeeze(pred)
|
26 |
+
|
27 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
28 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
29 |
+
|
30 |
+
return [mask]
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def download_models(cls, *args, **kwargs):
|
34 |
+
fname = f"{cls.name()}.onnx"
|
35 |
+
pooch.retrieve(
|
36 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
|
37 |
+
None
|
38 |
+
if cls.checksum_disabled(*args, **kwargs)
|
39 |
+
else "md5:fc16ebd8b0c10d971d3513d564d01e29",
|
40 |
+
fname=fname,
|
41 |
+
path=cls.u2net_home(*args, **kwargs),
|
42 |
+
progressbar=True,
|
43 |
+
)
|
44 |
+
|
45 |
+
return os.path.join(cls.u2net_home(), fname)
|
46 |
+
|
47 |
+
@classmethod
|
48 |
+
def name(cls, *args, **kwargs):
|
49 |
+
return "isnet-general-use"
|
rembg/sessions/sam.py
CHANGED
@@ -141,17 +141,21 @@ class SamSession(BaseSession):
|
|
141 |
|
142 |
pooch.retrieve(
|
143 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
|
144 |
-
|
|
|
|
|
145 |
fname=fname_encoder,
|
146 |
-
path=cls.u2net_home(),
|
147 |
progressbar=True,
|
148 |
)
|
149 |
|
150 |
pooch.retrieve(
|
151 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
|
152 |
-
|
|
|
|
|
153 |
fname=fname_decoder,
|
154 |
-
path=cls.u2net_home(),
|
155 |
progressbar=True,
|
156 |
)
|
157 |
|
|
|
141 |
|
142 |
pooch.retrieve(
|
143 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
|
144 |
+
None
|
145 |
+
if cls.checksum_disabled(*args, **kwargs)
|
146 |
+
else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
|
147 |
fname=fname_encoder,
|
148 |
+
path=cls.u2net_home(*args, **kwargs),
|
149 |
progressbar=True,
|
150 |
)
|
151 |
|
152 |
pooch.retrieve(
|
153 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
|
154 |
+
None
|
155 |
+
if cls.checksum_disabled(*args, **kwargs)
|
156 |
+
else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
|
157 |
fname=fname_decoder,
|
158 |
+
path=cls.u2net_home(*args, **kwargs),
|
159 |
progressbar=True,
|
160 |
)
|
161 |
|
rembg/sessions/silueta.py
CHANGED
@@ -36,9 +36,11 @@ class SiluetaSession(BaseSession):
|
|
36 |
fname = f"{cls.name()}.onnx"
|
37 |
pooch.retrieve(
|
38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
|
39 |
-
|
|
|
|
|
40 |
fname=fname,
|
41 |
-
path=cls.u2net_home(),
|
42 |
progressbar=True,
|
43 |
)
|
44 |
|
|
|
36 |
fname = f"{cls.name()}.onnx"
|
37 |
pooch.retrieve(
|
38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
|
39 |
+
None
|
40 |
+
if cls.checksum_disabled(*args, **kwargs)
|
41 |
+
else "md5:55e59e0d8062d2f5d013f4725ee84782",
|
42 |
fname=fname,
|
43 |
+
path=cls.u2net_home(*args, **kwargs),
|
44 |
progressbar=True,
|
45 |
)
|
46 |
|
rembg/sessions/u2net.py
CHANGED
@@ -36,9 +36,11 @@ class U2netSession(BaseSession):
|
|
36 |
fname = f"{cls.name()}.onnx"
|
37 |
pooch.retrieve(
|
38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
|
39 |
-
|
|
|
|
|
40 |
fname=fname,
|
41 |
-
path=cls.u2net_home(),
|
42 |
progressbar=True,
|
43 |
)
|
44 |
|
|
|
36 |
fname = f"{cls.name()}.onnx"
|
37 |
pooch.retrieve(
|
38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
|
39 |
+
None
|
40 |
+
if cls.checksum_disabled(*args, **kwargs)
|
41 |
+
else "md5:60024c5c889badc19c04ad937298a77b",
|
42 |
fname=fname,
|
43 |
+
path=cls.u2net_home(*args, **kwargs),
|
44 |
progressbar=True,
|
45 |
)
|
46 |
|
rembg/sessions/u2net_cloth_seg.py
CHANGED
@@ -97,9 +97,11 @@ class Unet2ClothSession(BaseSession):
|
|
97 |
fname = f"{cls.name()}.onnx"
|
98 |
pooch.retrieve(
|
99 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
|
100 |
-
|
|
|
|
|
101 |
fname=fname,
|
102 |
-
path=cls.u2net_home(),
|
103 |
progressbar=True,
|
104 |
)
|
105 |
|
|
|
97 |
fname = f"{cls.name()}.onnx"
|
98 |
pooch.retrieve(
|
99 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
|
100 |
+
None
|
101 |
+
if cls.checksum_disabled(*args, **kwargs)
|
102 |
+
else "md5:2434d1f3cb744e0e49386c906e5a08bb",
|
103 |
fname=fname,
|
104 |
+
path=cls.u2net_home(*args, **kwargs),
|
105 |
progressbar=True,
|
106 |
)
|
107 |
|
rembg/sessions/u2net_human_seg.py
CHANGED
@@ -36,9 +36,11 @@ class U2netHumanSegSession(BaseSession):
|
|
36 |
fname = f"{cls.name()}.onnx"
|
37 |
pooch.retrieve(
|
38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
|
39 |
-
|
|
|
|
|
40 |
fname=fname,
|
41 |
-
path=cls.u2net_home(),
|
42 |
progressbar=True,
|
43 |
)
|
44 |
|
|
|
36 |
fname = f"{cls.name()}.onnx"
|
37 |
pooch.retrieve(
|
38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
|
39 |
+
None
|
40 |
+
if cls.checksum_disabled(*args, **kwargs)
|
41 |
+
else "md5:c09ddc2e0104f800e3e1bb4652583d1f",
|
42 |
fname=fname,
|
43 |
+
path=cls.u2net_home(*args, **kwargs),
|
44 |
progressbar=True,
|
45 |
)
|
46 |
|
rembg/sessions/u2netp.py
CHANGED
@@ -36,9 +36,11 @@ class U2netpSession(BaseSession):
|
|
36 |
fname = f"{cls.name()}.onnx"
|
37 |
pooch.retrieve(
|
38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
|
39 |
-
|
|
|
|
|
40 |
fname=fname,
|
41 |
-
path=cls.u2net_home(),
|
42 |
progressbar=True,
|
43 |
)
|
44 |
|
|
|
36 |
fname = f"{cls.name()}.onnx"
|
37 |
pooch.retrieve(
|
38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
|
39 |
+
None
|
40 |
+
if cls.checksum_disabled(*args, **kwargs)
|
41 |
+
else "md5:8e83ca70e441ab06c318d82300c84806",
|
42 |
fname=fname,
|
43 |
+
path=cls.u2net_home(*args, **kwargs),
|
44 |
progressbar=True,
|
45 |
)
|
46 |
|