schirrmacher
commited on
Commit
•
04566b4
1
Parent(s):
08aed96
Upload folder using huggingface_hub
Browse files- utils/inference.py +33 -7
utils/inference.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import torch
|
|
|
2 |
import numpy as np
|
3 |
from PIL import Image
|
4 |
from skimage import io
|
@@ -6,6 +7,31 @@ from ormbg import ORMBG
|
|
6 |
import torch.nn.functional as F
|
7 |
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
|
10 |
if len(im.shape) < 3:
|
11 |
im = im[:, :, np.newaxis]
|
@@ -27,19 +53,19 @@ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
|
|
27 |
return im_array
|
28 |
|
29 |
|
30 |
-
def
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
|
35 |
net = ORMBG()
|
36 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
37 |
|
38 |
if torch.cuda.is_available():
|
39 |
-
net.load_state_dict(torch.load(
|
40 |
net = net.cuda()
|
41 |
else:
|
42 |
-
net.load_state_dict(torch.load(
|
43 |
net.eval()
|
44 |
|
45 |
model_input_size = [1024, 1024]
|
@@ -61,4 +87,4 @@ def example_inference():
|
|
61 |
|
62 |
|
63 |
if __name__ == "__main__":
|
64 |
-
|
|
|
1 |
import torch
|
2 |
+
import argparse
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
from skimage import io
|
|
|
7 |
import torch.nn.functional as F
|
8 |
|
9 |
|
10 |
+
def parse_args():
|
11 |
+
parser = argparse.ArgumentParser(
|
12 |
+
description="Remove background from images using ORMBG model."
|
13 |
+
)
|
14 |
+
parser.add_argument(
|
15 |
+
"--input",
|
16 |
+
type=str,
|
17 |
+
default="example.png",
|
18 |
+
help="Path to the input image file.",
|
19 |
+
)
|
20 |
+
parser.add_argument(
|
21 |
+
"--output",
|
22 |
+
type=str,
|
23 |
+
default="no-background.png",
|
24 |
+
help="Path to the output image file.",
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
"--model-path",
|
28 |
+
type=str,
|
29 |
+
default="models/ormbg.pth",
|
30 |
+
help="Path to the model file.",
|
31 |
+
)
|
32 |
+
return parser.parse_args()
|
33 |
+
|
34 |
+
|
35 |
def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
|
36 |
if len(im.shape) < 3:
|
37 |
im = im[:, :, np.newaxis]
|
|
|
53 |
return im_array
|
54 |
|
55 |
|
56 |
+
def inference(args):
|
57 |
+
image_path = args.input
|
58 |
+
result_name = args.output
|
59 |
+
model_path = args.model_path
|
60 |
|
61 |
net = ORMBG()
|
62 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
63 |
|
64 |
if torch.cuda.is_available():
|
65 |
+
net.load_state_dict(torch.load(model_path))
|
66 |
net = net.cuda()
|
67 |
else:
|
68 |
+
net.load_state_dict(torch.load(model_path, map_location="cpu"))
|
69 |
net.eval()
|
70 |
|
71 |
model_input_size = [1024, 1024]
|
|
|
87 |
|
88 |
|
89 |
if __name__ == "__main__":
|
90 |
+
inference(parse_args())
|