Spaces:
Running
Running
- rembg/__init__.py +1 -6
- rembg/_version.py +2 -2
- rembg/bg.py +32 -1
- rembg/cli.py +43 -3
- rembg/session_base.py +1 -1
- rembg/session_factory.py +32 -24
- rembg/session_simple.py +1 -1
rembg/__init__.py
CHANGED
@@ -1,11 +1,6 @@
|
|
1 |
-
import sys
|
2 |
-
import warnings
|
3 |
-
|
4 |
-
if not (sys.version_info.major == 3 and sys.version_info.minor == 9):
|
5 |
-
warnings.warn("This library is only for Python 3.9", RuntimeWarning)
|
6 |
-
|
7 |
from . import _version
|
8 |
|
9 |
__version__ = _version.get_versions()["version"]
|
10 |
|
11 |
from .bg import remove
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from . import _version
|
2 |
|
3 |
__version__ = _version.get_versions()["version"]
|
4 |
|
5 |
from .bg import remove
|
6 |
+
from .session_factory import new_session
|
rembg/_version.py
CHANGED
@@ -24,8 +24,8 @@ def get_keywords():
|
|
24 |
# each be defined on a line of their own. _version.py will just call
|
25 |
# get_keywords().
|
26 |
git_refnames = " (HEAD -> main)"
|
27 |
-
git_full = "
|
28 |
-
git_date = "2022-
|
29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
30 |
return keywords
|
31 |
|
|
|
24 |
# each be defined on a line of their own. _version.py will just call
|
25 |
# get_keywords().
|
26 |
git_refnames = " (HEAD -> main)"
|
27 |
+
git_full = "edc9fe27dff030cf6c2f29ef9a66c32d6e3f4658"
|
28 |
+
git_date = "2022-11-28 08:14:19 -0300"
|
29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
30 |
return keywords
|
31 |
|
rembg/bg.py
CHANGED
@@ -3,16 +3,26 @@ from enum import Enum
|
|
3 |
from typing import List, Optional, Union
|
4 |
|
5 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from PIL import Image
|
7 |
from PIL.Image import Image as PILImage
|
8 |
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
9 |
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
10 |
from pymatting.util.util import stack_images
|
11 |
-
from scipy.ndimage
|
12 |
|
13 |
from .session_base import BaseSession
|
14 |
from .session_factory import new_session
|
15 |
|
|
|
|
|
16 |
|
17 |
class ReturnType(Enum):
|
18 |
BYTES = 0
|
@@ -27,6 +37,10 @@ def alpha_matting_cutout(
|
|
27 |
background_threshold: int,
|
28 |
erode_structure_size: int,
|
29 |
) -> PILImage:
|
|
|
|
|
|
|
|
|
30 |
img = np.asarray(img)
|
31 |
mask = np.asarray(mask)
|
32 |
|
@@ -79,6 +93,19 @@ def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
|
|
79 |
return dst
|
80 |
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
def remove(
|
83 |
data: Union[bytes, PILImage, np.ndarray],
|
84 |
alpha_matting: bool = False,
|
@@ -87,6 +114,7 @@ def remove(
|
|
87 |
alpha_matting_erode_size: int = 10,
|
88 |
session: Optional[BaseSession] = None,
|
89 |
only_mask: bool = False,
|
|
|
90 |
) -> Union[bytes, PILImage, np.ndarray]:
|
91 |
|
92 |
if isinstance(data, PILImage):
|
@@ -108,6 +136,9 @@ def remove(
|
|
108 |
cutouts = []
|
109 |
|
110 |
for mask in masks:
|
|
|
|
|
|
|
111 |
if only_mask:
|
112 |
cutout = mask
|
113 |
|
|
|
3 |
from typing import List, Optional, Union
|
4 |
|
5 |
import numpy as np
|
6 |
+
from cv2 import (
|
7 |
+
BORDER_DEFAULT,
|
8 |
+
MORPH_ELLIPSE,
|
9 |
+
MORPH_OPEN,
|
10 |
+
GaussianBlur,
|
11 |
+
getStructuringElement,
|
12 |
+
morphologyEx,
|
13 |
+
)
|
14 |
from PIL import Image
|
15 |
from PIL.Image import Image as PILImage
|
16 |
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
17 |
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
18 |
from pymatting.util.util import stack_images
|
19 |
+
from scipy.ndimage import binary_erosion
|
20 |
|
21 |
from .session_base import BaseSession
|
22 |
from .session_factory import new_session
|
23 |
|
24 |
+
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
|
25 |
+
|
26 |
|
27 |
class ReturnType(Enum):
|
28 |
BYTES = 0
|
|
|
37 |
background_threshold: int,
|
38 |
erode_structure_size: int,
|
39 |
) -> PILImage:
|
40 |
+
|
41 |
+
if img.mode == "RGBA" or img.mode == "CMYK":
|
42 |
+
img = img.convert("RGB")
|
43 |
+
|
44 |
img = np.asarray(img)
|
45 |
mask = np.asarray(mask)
|
46 |
|
|
|
93 |
return dst
|
94 |
|
95 |
|
96 |
+
def post_process(mask: np.ndarray) -> np.ndarray:
|
97 |
+
"""
|
98 |
+
Post Process the mask for a smooth boundary by applying Morphological Operations
|
99 |
+
Research based on paper: https://www.sciencedirect.com/science/article/pii/S2352914821000757
|
100 |
+
args:
|
101 |
+
mask: Binary Numpy Mask
|
102 |
+
"""
|
103 |
+
mask = morphologyEx(mask, MORPH_OPEN, kernel)
|
104 |
+
mask = GaussianBlur(mask, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT)
|
105 |
+
mask = np.where(mask < 127, 0, 255).astype(np.uint8) # convert again to binary
|
106 |
+
return mask
|
107 |
+
|
108 |
+
|
109 |
def remove(
|
110 |
data: Union[bytes, PILImage, np.ndarray],
|
111 |
alpha_matting: bool = False,
|
|
|
114 |
alpha_matting_erode_size: int = 10,
|
115 |
session: Optional[BaseSession] = None,
|
116 |
only_mask: bool = False,
|
117 |
+
post_process_mask: bool = False,
|
118 |
) -> Union[bytes, PILImage, np.ndarray]:
|
119 |
|
120 |
if isinstance(data, PILImage):
|
|
|
136 |
cutouts = []
|
137 |
|
138 |
for mask in masks:
|
139 |
+
if post_process_mask:
|
140 |
+
mask = Image.fromarray(post_process(np.array(mask)))
|
141 |
+
|
142 |
if only_mask:
|
143 |
cutout = mask
|
144 |
|
rembg/cli.py
CHANGED
@@ -33,7 +33,9 @@ def main() -> None:
|
|
33 |
"-m",
|
34 |
"--model",
|
35 |
default="u2net",
|
36 |
-
type=click.Choice(
|
|
|
|
|
37 |
show_default=True,
|
38 |
show_choices=True,
|
39 |
help="model name",
|
@@ -76,6 +78,13 @@ def main() -> None:
|
|
76 |
show_default=True,
|
77 |
help="output only the mask",
|
78 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
@click.argument(
|
80 |
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
81 |
)
|
@@ -93,7 +102,9 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
|
|
93 |
"-m",
|
94 |
"--model",
|
95 |
default="u2net",
|
96 |
-
type=click.Choice(
|
|
|
|
|
97 |
show_default=True,
|
98 |
show_choices=True,
|
99 |
help="model name",
|
@@ -136,6 +147,13 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
|
|
136 |
show_default=True,
|
137 |
help="output only the mask",
|
138 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
@click.option(
|
140 |
"-w",
|
141 |
"--watch",
|
@@ -243,7 +261,15 @@ def p(
|
|
243 |
show_default=True,
|
244 |
help="log level",
|
245 |
)
|
246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
sessions: dict[str, BaseSession] = {}
|
248 |
tags_metadata = [
|
249 |
{
|
@@ -284,6 +310,7 @@ def s(port: int, log_level: str) -> None:
|
|
284 |
u2netp = "u2netp"
|
285 |
u2net_human_seg = "u2net_human_seg"
|
286 |
u2net_cloth_seg = "u2net_cloth_seg"
|
|
|
287 |
|
288 |
class CommonQueryParams:
|
289 |
def __init__(
|
@@ -309,6 +336,7 @@ def s(port: int, log_level: str) -> None:
|
|
309 |
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
310 |
),
|
311 |
om: bool = Query(default=False, description="Only Mask"),
|
|
|
312 |
):
|
313 |
self.model = model
|
314 |
self.a = a
|
@@ -316,6 +344,7 @@ def s(port: int, log_level: str) -> None:
|
|
316 |
self.ab = ab
|
317 |
self.ae = ae
|
318 |
self.om = om
|
|
|
319 |
|
320 |
class CommonQueryPostParams:
|
321 |
def __init__(
|
@@ -341,6 +370,7 @@ def s(port: int, log_level: str) -> None:
|
|
341 |
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
342 |
),
|
343 |
om: bool = Form(default=False, description="Only Mask"),
|
|
|
344 |
):
|
345 |
self.model = model
|
346 |
self.a = a
|
@@ -348,6 +378,7 @@ def s(port: int, log_level: str) -> None:
|
|
348 |
self.ab = ab
|
349 |
self.ae = ae
|
350 |
self.om = om
|
|
|
351 |
|
352 |
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
353 |
return Response(
|
@@ -361,10 +392,19 @@ def s(port: int, log_level: str) -> None:
|
|
361 |
alpha_matting_background_threshold=commons.ab,
|
362 |
alpha_matting_erode_size=commons.ae,
|
363 |
only_mask=commons.om,
|
|
|
364 |
),
|
365 |
media_type="image/png",
|
366 |
)
|
367 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
@app.get(
|
369 |
path="/",
|
370 |
tags=["Background Removal"],
|
|
|
33 |
"-m",
|
34 |
"--model",
|
35 |
default="u2net",
|
36 |
+
type=click.Choice(
|
37 |
+
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
|
38 |
+
),
|
39 |
show_default=True,
|
40 |
show_choices=True,
|
41 |
help="model name",
|
|
|
78 |
show_default=True,
|
79 |
help="output only the mask",
|
80 |
)
|
81 |
+
@click.option(
|
82 |
+
"-ppm",
|
83 |
+
"--post-process-mask",
|
84 |
+
is_flag=True,
|
85 |
+
show_default=True,
|
86 |
+
help="post process the mask",
|
87 |
+
)
|
88 |
@click.argument(
|
89 |
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
90 |
)
|
|
|
102 |
"-m",
|
103 |
"--model",
|
104 |
default="u2net",
|
105 |
+
type=click.Choice(
|
106 |
+
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
|
107 |
+
),
|
108 |
show_default=True,
|
109 |
show_choices=True,
|
110 |
help="model name",
|
|
|
147 |
show_default=True,
|
148 |
help="output only the mask",
|
149 |
)
|
150 |
+
@click.option(
|
151 |
+
"-ppm",
|
152 |
+
"--post-process-mask",
|
153 |
+
is_flag=True,
|
154 |
+
show_default=True,
|
155 |
+
help="post process the mask",
|
156 |
+
)
|
157 |
@click.option(
|
158 |
"-w",
|
159 |
"--watch",
|
|
|
261 |
show_default=True,
|
262 |
help="log level",
|
263 |
)
|
264 |
+
@click.option(
|
265 |
+
"-t",
|
266 |
+
"--threads",
|
267 |
+
default=None,
|
268 |
+
type=int,
|
269 |
+
show_default=True,
|
270 |
+
help="number of worker threads",
|
271 |
+
)
|
272 |
+
def s(port: int, log_level: str, threads: int) -> None:
|
273 |
sessions: dict[str, BaseSession] = {}
|
274 |
tags_metadata = [
|
275 |
{
|
|
|
310 |
u2netp = "u2netp"
|
311 |
u2net_human_seg = "u2net_human_seg"
|
312 |
u2net_cloth_seg = "u2net_cloth_seg"
|
313 |
+
silueta = "silueta"
|
314 |
|
315 |
class CommonQueryParams:
|
316 |
def __init__(
|
|
|
336 |
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
337 |
),
|
338 |
om: bool = Query(default=False, description="Only Mask"),
|
339 |
+
ppm: bool = Query(default=False, description="Post Process Mask"),
|
340 |
):
|
341 |
self.model = model
|
342 |
self.a = a
|
|
|
344 |
self.ab = ab
|
345 |
self.ae = ae
|
346 |
self.om = om
|
347 |
+
self.ppm = ppm
|
348 |
|
349 |
class CommonQueryPostParams:
|
350 |
def __init__(
|
|
|
370 |
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
371 |
),
|
372 |
om: bool = Form(default=False, description="Only Mask"),
|
373 |
+
ppm: bool = Form(default=False, description="Post Process Mask"),
|
374 |
):
|
375 |
self.model = model
|
376 |
self.a = a
|
|
|
378 |
self.ab = ab
|
379 |
self.ae = ae
|
380 |
self.om = om
|
381 |
+
self.ppm = ppm
|
382 |
|
383 |
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
384 |
return Response(
|
|
|
392 |
alpha_matting_background_threshold=commons.ab,
|
393 |
alpha_matting_erode_size=commons.ae,
|
394 |
only_mask=commons.om,
|
395 |
+
post_process_mask=commons.ppm,
|
396 |
),
|
397 |
media_type="image/png",
|
398 |
)
|
399 |
|
400 |
+
@app.on_event("startup")
|
401 |
+
def startup():
|
402 |
+
if threads is not None:
|
403 |
+
from anyio import CapacityLimiter
|
404 |
+
from anyio.lowlevel import RunVar
|
405 |
+
|
406 |
+
RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
|
407 |
+
|
408 |
@app.get(
|
409 |
path="/",
|
410 |
tags=["Background Removal"],
|
rembg/session_base.py
CHANGED
@@ -18,7 +18,7 @@ class BaseSession:
|
|
18 |
std: Tuple[float, float, float],
|
19 |
size: Tuple[int, int],
|
20 |
) -> Dict[str, np.ndarray]:
|
21 |
-
im = img.convert("RGB").resize(size, Image.LANCZOS)
|
22 |
|
23 |
im_ary = np.array(im)
|
24 |
im_ary = im_ary / np.max(im_ary)
|
|
|
18 |
std: Tuple[float, float, float],
|
19 |
size: Tuple[int, int],
|
20 |
) -> Dict[str, np.ndarray]:
|
21 |
+
im = img.convert("RGB").resize(size, Image.Resampling.LANCZOS)
|
22 |
|
23 |
im_ary = np.array(im)
|
24 |
im_ary = im_ary / np.max(im_ary)
|
rembg/session_factory.py
CHANGED
@@ -5,50 +5,56 @@ from contextlib import redirect_stdout
|
|
5 |
from pathlib import Path
|
6 |
from typing import Type
|
7 |
|
8 |
-
import gdown
|
9 |
import onnxruntime as ort
|
|
|
10 |
|
11 |
from .session_base import BaseSession
|
12 |
from .session_cloth import ClothSession
|
13 |
from .session_simple import SimpleSession
|
14 |
|
15 |
|
16 |
-
def new_session(model_name: str) -> BaseSession:
|
17 |
session_class: Type[BaseSession]
|
|
|
|
|
|
|
18 |
|
19 |
if model_name == "u2netp":
|
20 |
md5 = "8e83ca70e441ab06c318d82300c84806"
|
21 |
-
url =
|
22 |
-
|
23 |
-
|
24 |
-
md5 = "60024c5c889badc19c04ad937298a77b"
|
25 |
-
url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab"
|
26 |
session_class = SimpleSession
|
27 |
elif model_name == "u2net_human_seg":
|
28 |
md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
|
29 |
-
url = "https://
|
30 |
session_class = SimpleSession
|
31 |
elif model_name == "u2net_cloth_seg":
|
32 |
md5 = "2434d1f3cb744e0e49386c906e5a08bb"
|
33 |
-
url = "https://
|
34 |
session_class = ClothSession
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
38 |
)
|
|
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
gdown.download(url, str(path), use_cookies=False)
|
52 |
|
53 |
sess_opts = ort.SessionOptions()
|
54 |
|
@@ -58,6 +64,8 @@ def new_session(model_name: str) -> BaseSession:
|
|
58 |
return session_class(
|
59 |
model_name,
|
60 |
ort.InferenceSession(
|
61 |
-
str(
|
|
|
|
|
62 |
),
|
63 |
)
|
|
|
5 |
from pathlib import Path
|
6 |
from typing import Type
|
7 |
|
|
|
8 |
import onnxruntime as ort
|
9 |
+
import pooch
|
10 |
|
11 |
from .session_base import BaseSession
|
12 |
from .session_cloth import ClothSession
|
13 |
from .session_simple import SimpleSession
|
14 |
|
15 |
|
16 |
+
def new_session(model_name: str = "u2net") -> BaseSession:
|
17 |
session_class: Type[BaseSession]
|
18 |
+
md5 = "60024c5c889badc19c04ad937298a77b"
|
19 |
+
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
|
20 |
+
session_class = SimpleSession
|
21 |
|
22 |
if model_name == "u2netp":
|
23 |
md5 = "8e83ca70e441ab06c318d82300c84806"
|
24 |
+
url = (
|
25 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx"
|
26 |
+
)
|
|
|
|
|
27 |
session_class = SimpleSession
|
28 |
elif model_name == "u2net_human_seg":
|
29 |
md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
|
30 |
+
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx"
|
31 |
session_class = SimpleSession
|
32 |
elif model_name == "u2net_cloth_seg":
|
33 |
md5 = "2434d1f3cb744e0e49386c906e5a08bb"
|
34 |
+
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx"
|
35 |
session_class = ClothSession
|
36 |
+
elif model_name == "silueta":
|
37 |
+
md5 = "55e59e0d8062d2f5d013f4725ee84782"
|
38 |
+
url = (
|
39 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx"
|
40 |
)
|
41 |
+
session_class = SimpleSession
|
42 |
|
43 |
+
u2net_home = os.getenv(
|
44 |
+
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
|
45 |
+
)
|
46 |
+
|
47 |
+
fname = f"{model_name}.onnx"
|
48 |
+
path = Path(u2net_home).expanduser()
|
49 |
+
full_path = Path(u2net_home).expanduser() / fname
|
50 |
|
51 |
+
pooch.retrieve(
|
52 |
+
url,
|
53 |
+
f"md5:{md5}",
|
54 |
+
fname=fname,
|
55 |
+
path=Path(u2net_home).expanduser(),
|
56 |
+
progressbar=True,
|
57 |
+
)
|
|
|
58 |
|
59 |
sess_opts = ort.SessionOptions()
|
60 |
|
|
|
64 |
return session_class(
|
65 |
model_name,
|
66 |
ort.InferenceSession(
|
67 |
+
str(full_path),
|
68 |
+
providers=ort.get_available_providers(),
|
69 |
+
sess_options=sess_opts,
|
70 |
),
|
71 |
)
|
rembg/session_simple.py
CHANGED
@@ -25,6 +25,6 @@ class SimpleSession(BaseSession):
|
|
25 |
pred = np.squeeze(pred)
|
26 |
|
27 |
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
28 |
-
mask = mask.resize(img.size, Image.LANCZOS)
|
29 |
|
30 |
return [mask]
|
|
|
25 |
pred = np.squeeze(pred)
|
26 |
|
27 |
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
28 |
+
mask = mask.resize(img.size, Image.Resampling.LANCZOS)
|
29 |
|
30 |
return [mask]
|