Spaces:
Runtime error
Runtime error
Update files
Browse files- PyPatchMatch/.gitignore +4 -4
- PyPatchMatch/LICENSE +21 -21
- PyPatchMatch/Makefile +54 -54
- PyPatchMatch/README.md +64 -64
- PyPatchMatch/csrc/inpaint.cpp +234 -234
- PyPatchMatch/csrc/inpaint.h +27 -27
- PyPatchMatch/csrc/masked_image.cpp +138 -138
- PyPatchMatch/csrc/masked_image.h +112 -112
- PyPatchMatch/csrc/nnf.cpp +268 -268
- PyPatchMatch/csrc/nnf.h +133 -133
- PyPatchMatch/csrc/pyinterface.cpp +107 -107
- PyPatchMatch/csrc/pyinterface.h +38 -38
- PyPatchMatch/examples/.gitignore +2 -2
- PyPatchMatch/examples/cpp_example.cpp +31 -31
- PyPatchMatch/examples/cpp_example_run.sh +18 -18
- PyPatchMatch/examples/py_example.py +21 -21
- PyPatchMatch/examples/py_example_global_mask.py +27 -27
- PyPatchMatch/patch_match.py +263 -201
- PyPatchMatch/travis.sh +9 -9
- config.yaml +18 -0
- convert_checkpoint.py +706 -0
- css/w2ui.min.css +0 -0
- js/fabric.min.js +0 -0
- js/keyboard.js +37 -0
- js/mode.js +5 -5
- js/outpaint.js +22 -30
- js/proceed.js +41 -21
- js/setup.js +27 -21
- js/toolbar.js +581 -0
- js/upload.js +18 -19
- js/w2ui.min.js +0 -0
- js/xss.js +31 -0
- models/v1-inference.yaml +70 -0
- models/v1-inpainting-inference.yaml +70 -0
PyPatchMatch/.gitignore
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
/build/
|
2 |
-
/*.so
|
3 |
-
__pycache__
|
4 |
-
*.py[cod]
|
|
|
1 |
+
/build/
|
2 |
+
/*.so
|
3 |
+
__pycache__
|
4 |
+
*.py[cod]
|
PyPatchMatch/LICENSE
CHANGED
@@ -1,21 +1,21 @@
|
|
1 |
-
MIT License
|
2 |
-
|
3 |
-
Copyright (c) 2020 Jiayuan Mao
|
4 |
-
|
5 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
-
of this software and associated documentation files (the "Software"), to deal
|
7 |
-
in the Software without restriction, including without limitation the rights
|
8 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
-
copies of the Software, and to permit persons to whom the Software is
|
10 |
-
furnished to do so, subject to the following conditions:
|
11 |
-
|
12 |
-
The above copyright notice and this permission notice shall be included in all
|
13 |
-
copies or substantial portions of the Software.
|
14 |
-
|
15 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
-
SOFTWARE.
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2020 Jiayuan Mao
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
PyPatchMatch/Makefile
CHANGED
@@ -1,54 +1,54 @@
|
|
1 |
-
#
|
2 |
-
# Makefile
|
3 |
-
# Jiayuan Mao, 2019-01-09 13:59
|
4 |
-
#
|
5 |
-
|
6 |
-
SRC_DIR = csrc
|
7 |
-
INC_DIR = csrc
|
8 |
-
OBJ_DIR = build/obj
|
9 |
-
TARGET = libpatchmatch.so
|
10 |
-
|
11 |
-
LIB_TARGET = $(TARGET)
|
12 |
-
INCLUDE_DIR = -I $(SRC_DIR) -I $(INC_DIR)
|
13 |
-
|
14 |
-
CXX = $(ENVIRONMENT_OPTIONS) g++
|
15 |
-
CXXFLAGS = -std=c++14
|
16 |
-
CXXFLAGS += -Ofast -ffast-math -w
|
17 |
-
# CXXFLAGS += -g
|
18 |
-
CXXFLAGS += $(shell pkg-config --cflags opencv) -fPIC
|
19 |
-
CXXFLAGS += $(INCLUDE_DIR)
|
20 |
-
LDFLAGS = $(shell pkg-config --cflags --libs opencv) -shared -fPIC
|
21 |
-
|
22 |
-
|
23 |
-
CXXSOURCES = $(shell find $(SRC_DIR)/ -name "*.cpp")
|
24 |
-
OBJS = $(addprefix $(OBJ_DIR)/,$(CXXSOURCES:.cpp=.o))
|
25 |
-
DEPFILES = $(OBJS:.o=.d)
|
26 |
-
|
27 |
-
.PHONY: all clean rebuild test
|
28 |
-
|
29 |
-
all: $(LIB_TARGET)
|
30 |
-
|
31 |
-
$(OBJ_DIR)/%.o: %.cpp
|
32 |
-
@echo "[CC] $< ..."
|
33 |
-
@$(CXX) -c $< $(CXXFLAGS) -o $@
|
34 |
-
|
35 |
-
$(OBJ_DIR)/%.d: %.cpp
|
36 |
-
@mkdir -pv $(dir $@)
|
37 |
-
@echo "[dep] $< ..."
|
38 |
-
@$(CXX) $(INCLUDE_DIR) $(CXXFLAGS) -MM -MT "$(OBJ_DIR)/$(<:.cpp=.o) $(OBJ_DIR)/$(<:.cpp=.d)" "$<" > "$@"
|
39 |
-
|
40 |
-
sinclude $(DEPFILES)
|
41 |
-
|
42 |
-
$(LIB_TARGET): $(OBJS)
|
43 |
-
@echo "[link] $(LIB_TARGET) ..."
|
44 |
-
@$(CXX) $(OBJS) -o $@ $(CXXFLAGS) $(LDFLAGS)
|
45 |
-
|
46 |
-
clean:
|
47 |
-
rm -rf $(OBJ_DIR) $(LIB_TARGET)
|
48 |
-
|
49 |
-
rebuild:
|
50 |
-
+@make clean
|
51 |
-
+@make
|
52 |
-
|
53 |
-
# vim:ft=make
|
54 |
-
#
|
|
|
1 |
+
#
|
2 |
+
# Makefile
|
3 |
+
# Jiayuan Mao, 2019-01-09 13:59
|
4 |
+
#
|
5 |
+
|
6 |
+
SRC_DIR = csrc
|
7 |
+
INC_DIR = csrc
|
8 |
+
OBJ_DIR = build/obj
|
9 |
+
TARGET = libpatchmatch.so
|
10 |
+
|
11 |
+
LIB_TARGET = $(TARGET)
|
12 |
+
INCLUDE_DIR = -I $(SRC_DIR) -I $(INC_DIR)
|
13 |
+
|
14 |
+
CXX = $(ENVIRONMENT_OPTIONS) g++
|
15 |
+
CXXFLAGS = -std=c++14
|
16 |
+
CXXFLAGS += -Ofast -ffast-math -w
|
17 |
+
# CXXFLAGS += -g
|
18 |
+
CXXFLAGS += $(shell pkg-config --cflags opencv) -fPIC
|
19 |
+
CXXFLAGS += $(INCLUDE_DIR)
|
20 |
+
LDFLAGS = $(shell pkg-config --cflags --libs opencv) -shared -fPIC
|
21 |
+
|
22 |
+
|
23 |
+
CXXSOURCES = $(shell find $(SRC_DIR)/ -name "*.cpp")
|
24 |
+
OBJS = $(addprefix $(OBJ_DIR)/,$(CXXSOURCES:.cpp=.o))
|
25 |
+
DEPFILES = $(OBJS:.o=.d)
|
26 |
+
|
27 |
+
.PHONY: all clean rebuild test
|
28 |
+
|
29 |
+
all: $(LIB_TARGET)
|
30 |
+
|
31 |
+
$(OBJ_DIR)/%.o: %.cpp
|
32 |
+
@echo "[CC] $< ..."
|
33 |
+
@$(CXX) -c $< $(CXXFLAGS) -o $@
|
34 |
+
|
35 |
+
$(OBJ_DIR)/%.d: %.cpp
|
36 |
+
@mkdir -pv $(dir $@)
|
37 |
+
@echo "[dep] $< ..."
|
38 |
+
@$(CXX) $(INCLUDE_DIR) $(CXXFLAGS) -MM -MT "$(OBJ_DIR)/$(<:.cpp=.o) $(OBJ_DIR)/$(<:.cpp=.d)" "$<" > "$@"
|
39 |
+
|
40 |
+
sinclude $(DEPFILES)
|
41 |
+
|
42 |
+
$(LIB_TARGET): $(OBJS)
|
43 |
+
@echo "[link] $(LIB_TARGET) ..."
|
44 |
+
@$(CXX) $(OBJS) -o $@ $(CXXFLAGS) $(LDFLAGS)
|
45 |
+
|
46 |
+
clean:
|
47 |
+
rm -rf $(OBJ_DIR) $(LIB_TARGET)
|
48 |
+
|
49 |
+
rebuild:
|
50 |
+
+@make clean
|
51 |
+
+@make
|
52 |
+
|
53 |
+
# vim:ft=make
|
54 |
+
#
|
PyPatchMatch/README.md
CHANGED
@@ -1,64 +1,64 @@
|
|
1 |
-
PatchMatch based Inpainting
|
2 |
-
=====================================
|
3 |
-
This library implements the PatchMatch based inpainting algorithm. It provides both C++ and Python interfaces.
|
4 |
-
This implementation is heavily based on the implementation by Younesse ANDAM:
|
5 |
-
(younesse-cv/PatchMatch)[https://github.com/younesse-cv/PatchMatch], with some bugs fix.
|
6 |
-
|
7 |
-
Usage
|
8 |
-
-------------------------------------
|
9 |
-
|
10 |
-
You need to first install OpenCV to compile the C++ libraries. Then, run `make` to compile the
|
11 |
-
shared library `libpatchmatch.so`.
|
12 |
-
|
13 |
-
For Python users (example available at `examples/py_example.py`)
|
14 |
-
|
15 |
-
```python
|
16 |
-
import patch_match
|
17 |
-
|
18 |
-
image = ... # either a numpy ndarray or a PIL Image object.
|
19 |
-
mask = ... # either a numpy ndarray or a PIL Image object.
|
20 |
-
result = patch_match.inpaint(image, mask, patch_size=5)
|
21 |
-
```
|
22 |
-
|
23 |
-
For C++ users (examples available at `examples/cpp_example.cpp`)
|
24 |
-
|
25 |
-
```cpp
|
26 |
-
#include "inpaint.h"
|
27 |
-
|
28 |
-
int main() {
|
29 |
-
cv::Mat image = ...
|
30 |
-
cv::Mat mask = ...
|
31 |
-
|
32 |
-
cv::Mat result = Inpainting(image, mask, 5).run();
|
33 |
-
|
34 |
-
return 0;
|
35 |
-
}
|
36 |
-
```
|
37 |
-
|
38 |
-
|
39 |
-
README and COPYRIGHT by Younesse ANDAM
|
40 |
-
-------------------------------------
|
41 |
-
@Author: Younesse ANDAM
|
42 |
-
|
43 |
-
@Contact: [email protected]
|
44 |
-
|
45 |
-
Description: This project is a personal implementation of an algorithm called PATCHMATCH that restores missing areas in an image.
|
46 |
-
The algorithm is presented in the following paper
|
47 |
-
PatchMatch A Randomized Correspondence Algorithm
|
48 |
-
for Structural Image Editing
|
49 |
-
by C.Barnes,E.Shechtman,A.Finkelstein and Dan B.Goldman
|
50 |
-
ACM Transactions on Graphics (Proc. SIGGRAPH), vol.28, aug-2009
|
51 |
-
|
52 |
-
For more information please refer to
|
53 |
-
http://www.cs.princeton.edu/gfx/pubs/Barnes_2009_PAR/index.php
|
54 |
-
|
55 |
-
Copyright (c) 2010-2011
|
56 |
-
|
57 |
-
|
58 |
-
Requirements
|
59 |
-
-------------------------------------
|
60 |
-
|
61 |
-
To run the project you need to install Opencv library and link it to your project.
|
62 |
-
Opencv can be download it here
|
63 |
-
http://opencv.org/downloads.html
|
64 |
-
|
|
|
1 |
+
PatchMatch based Inpainting
|
2 |
+
=====================================
|
3 |
+
This library implements the PatchMatch based inpainting algorithm. It provides both C++ and Python interfaces.
|
4 |
+
This implementation is heavily based on the implementation by Younesse ANDAM:
|
5 |
+
(younesse-cv/PatchMatch)[https://github.com/younesse-cv/PatchMatch], with some bugs fix.
|
6 |
+
|
7 |
+
Usage
|
8 |
+
-------------------------------------
|
9 |
+
|
10 |
+
You need to first install OpenCV to compile the C++ libraries. Then, run `make` to compile the
|
11 |
+
shared library `libpatchmatch.so`.
|
12 |
+
|
13 |
+
For Python users (example available at `examples/py_example.py`)
|
14 |
+
|
15 |
+
```python
|
16 |
+
import patch_match
|
17 |
+
|
18 |
+
image = ... # either a numpy ndarray or a PIL Image object.
|
19 |
+
mask = ... # either a numpy ndarray or a PIL Image object.
|
20 |
+
result = patch_match.inpaint(image, mask, patch_size=5)
|
21 |
+
```
|
22 |
+
|
23 |
+
For C++ users (examples available at `examples/cpp_example.cpp`)
|
24 |
+
|
25 |
+
```cpp
|
26 |
+
#include "inpaint.h"
|
27 |
+
|
28 |
+
int main() {
|
29 |
+
cv::Mat image = ...
|
30 |
+
cv::Mat mask = ...
|
31 |
+
|
32 |
+
cv::Mat result = Inpainting(image, mask, 5).run();
|
33 |
+
|
34 |
+
return 0;
|
35 |
+
}
|
36 |
+
```
|
37 |
+
|
38 |
+
|
39 |
+
README and COPYRIGHT by Younesse ANDAM
|
40 |
+
-------------------------------------
|
41 |
+
@Author: Younesse ANDAM
|
42 |
+
|
43 |
+
@Contact: [email protected]
|
44 |
+
|
45 |
+
Description: This project is a personal implementation of an algorithm called PATCHMATCH that restores missing areas in an image.
|
46 |
+
The algorithm is presented in the following paper
|
47 |
+
PatchMatch A Randomized Correspondence Algorithm
|
48 |
+
for Structural Image Editing
|
49 |
+
by C.Barnes,E.Shechtman,A.Finkelstein and Dan B.Goldman
|
50 |
+
ACM Transactions on Graphics (Proc. SIGGRAPH), vol.28, aug-2009
|
51 |
+
|
52 |
+
For more information please refer to
|
53 |
+
http://www.cs.princeton.edu/gfx/pubs/Barnes_2009_PAR/index.php
|
54 |
+
|
55 |
+
Copyright (c) 2010-2011
|
56 |
+
|
57 |
+
|
58 |
+
Requirements
|
59 |
+
-------------------------------------
|
60 |
+
|
61 |
+
To run the project you need to install Opencv library and link it to your project.
|
62 |
+
Opencv can be download it here
|
63 |
+
http://opencv.org/downloads.html
|
64 |
+
|
PyPatchMatch/csrc/inpaint.cpp
CHANGED
@@ -1,234 +1,234 @@
|
|
1 |
-
#include <algorithm>
|
2 |
-
#include <iostream>
|
3 |
-
#include <opencv2/imgcodecs.hpp>
|
4 |
-
#include <opencv2/imgproc.hpp>
|
5 |
-
#include <opencv2/highgui.hpp>
|
6 |
-
|
7 |
-
#include "inpaint.h"
|
8 |
-
|
9 |
-
namespace {
|
10 |
-
static std::vector<double> kDistance2Similarity;
|
11 |
-
|
12 |
-
void init_kDistance2Similarity() {
|
13 |
-
double base[11] = {1.0, 0.99, 0.96, 0.83, 0.38, 0.11, 0.02, 0.005, 0.0006, 0.0001, 0};
|
14 |
-
int length = (PatchDistanceMetric::kDistanceScale + 1);
|
15 |
-
kDistance2Similarity.resize(length);
|
16 |
-
for (int i = 0; i < length; ++i) {
|
17 |
-
double t = (double) i / length;
|
18 |
-
int j = (int) (100 * t);
|
19 |
-
int k = j + 1;
|
20 |
-
double vj = (j < 11) ? base[j] : 0;
|
21 |
-
double vk = (k < 11) ? base[k] : 0;
|
22 |
-
kDistance2Similarity[i] = vj + (100 * t - j) * (vk - vj);
|
23 |
-
}
|
24 |
-
}
|
25 |
-
|
26 |
-
|
27 |
-
inline void _weighted_copy(const MaskedImage &source, int ys, int xs, cv::Mat &target, int yt, int xt, double weight) {
|
28 |
-
if (source.is_masked(ys, xs)) return;
|
29 |
-
if (source.is_globally_masked(ys, xs)) return;
|
30 |
-
|
31 |
-
auto source_ptr = source.get_image(ys, xs);
|
32 |
-
auto target_ptr = target.ptr<double>(yt, xt);
|
33 |
-
|
34 |
-
#pragma unroll
|
35 |
-
for (int c = 0; c < 3; ++c)
|
36 |
-
target_ptr[c] += static_cast<double>(source_ptr[c]) * weight;
|
37 |
-
target_ptr[3] += weight;
|
38 |
-
}
|
39 |
-
}
|
40 |
-
|
41 |
-
/**
|
42 |
-
* This algorithme uses a version proposed by Xavier Philippeau.
|
43 |
-
*/
|
44 |
-
|
45 |
-
Inpainting::Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric)
|
46 |
-
: m_initial(image, mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
|
47 |
-
_initialize_pyramid();
|
48 |
-
}
|
49 |
-
|
50 |
-
Inpainting::Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric)
|
51 |
-
: m_initial(image, mask, global_mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
|
52 |
-
_initialize_pyramid();
|
53 |
-
}
|
54 |
-
|
55 |
-
void Inpainting::_initialize_pyramid() {
|
56 |
-
auto source = m_initial;
|
57 |
-
m_pyramid.push_back(source);
|
58 |
-
while (source.size().height > m_distance_metric->patch_size() && source.size().width > m_distance_metric->patch_size()) {
|
59 |
-
source = source.downsample();
|
60 |
-
m_pyramid.push_back(source);
|
61 |
-
}
|
62 |
-
|
63 |
-
if (kDistance2Similarity.size() == 0) {
|
64 |
-
init_kDistance2Similarity();
|
65 |
-
}
|
66 |
-
}
|
67 |
-
|
68 |
-
cv::Mat Inpainting::run(bool verbose, bool verbose_visualize, unsigned int random_seed) {
|
69 |
-
srand(random_seed);
|
70 |
-
const int nr_levels = m_pyramid.size();
|
71 |
-
|
72 |
-
MaskedImage source, target;
|
73 |
-
for (int level = nr_levels - 1; level >= 0; --level) {
|
74 |
-
if (verbose) std::cerr << "Inpainting level: " << level << std::endl;
|
75 |
-
|
76 |
-
source = m_pyramid[level];
|
77 |
-
|
78 |
-
if (level == nr_levels - 1) {
|
79 |
-
target = source.clone();
|
80 |
-
target.clear_mask();
|
81 |
-
m_source2target = NearestNeighborField(source, target, m_distance_metric);
|
82 |
-
m_target2source = NearestNeighborField(target, source, m_distance_metric);
|
83 |
-
} else {
|
84 |
-
m_source2target = NearestNeighborField(source, target, m_distance_metric, m_source2target);
|
85 |
-
m_target2source = NearestNeighborField(target, source, m_distance_metric, m_target2source);
|
86 |
-
}
|
87 |
-
|
88 |
-
if (verbose) std::cerr << "Initialization done." << std::endl;
|
89 |
-
|
90 |
-
if (verbose_visualize) {
|
91 |
-
auto visualize_size = m_initial.size();
|
92 |
-
cv::Mat source_visualize(visualize_size, m_initial.image().type());
|
93 |
-
cv::resize(source.image(), source_visualize, visualize_size);
|
94 |
-
cv::imshow("Source", source_visualize);
|
95 |
-
cv::Mat target_visualize(visualize_size, m_initial.image().type());
|
96 |
-
cv::resize(target.image(), target_visualize, visualize_size);
|
97 |
-
cv::imshow("Target", target_visualize);
|
98 |
-
cv::waitKey(0);
|
99 |
-
}
|
100 |
-
|
101 |
-
target = _expectation_maximization(source, target, level, verbose);
|
102 |
-
}
|
103 |
-
|
104 |
-
return target.image();
|
105 |
-
}
|
106 |
-
|
107 |
-
// EM-Like algorithm (see "PatchMatch" - page 6).
|
108 |
-
// Returns a double sized target image (unless level = 0).
|
109 |
-
MaskedImage Inpainting::_expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose) {
|
110 |
-
const int nr_iters_em = 1 + 2 * level;
|
111 |
-
const int nr_iters_nnf = static_cast<int>(std::min(7, 1 + level));
|
112 |
-
const int patch_size = m_distance_metric->patch_size();
|
113 |
-
|
114 |
-
MaskedImage new_source, new_target;
|
115 |
-
|
116 |
-
for (int iter_em = 0; iter_em < nr_iters_em; ++iter_em) {
|
117 |
-
if (iter_em != 0) {
|
118 |
-
m_source2target.set_target(new_target);
|
119 |
-
m_target2source.set_source(new_target);
|
120 |
-
target = new_target;
|
121 |
-
}
|
122 |
-
|
123 |
-
if (verbose) std::cerr << "EM Iteration: " << iter_em << std::endl;
|
124 |
-
|
125 |
-
auto size = source.size();
|
126 |
-
for (int i = 0; i < size.height; ++i) {
|
127 |
-
for (int j = 0; j < size.width; ++j) {
|
128 |
-
if (!source.contains_mask(i, j, patch_size)) {
|
129 |
-
m_source2target.set_identity(i, j);
|
130 |
-
m_target2source.set_identity(i, j);
|
131 |
-
}
|
132 |
-
}
|
133 |
-
}
|
134 |
-
if (verbose) std::cerr << " NNF minimization started." << std::endl;
|
135 |
-
m_source2target.minimize(nr_iters_nnf);
|
136 |
-
m_target2source.minimize(nr_iters_nnf);
|
137 |
-
if (verbose) std::cerr << " NNF minimization finished." << std::endl;
|
138 |
-
|
139 |
-
// Instead of upsizing the final target, we build the last target from the next level source image.
|
140 |
-
// Thus, the final target is less blurry (see "Space-Time Video Completion" - page 5).
|
141 |
-
bool upscaled = false;
|
142 |
-
if (level >= 1 && iter_em == nr_iters_em - 1) {
|
143 |
-
new_source = m_pyramid[level - 1];
|
144 |
-
new_target = target.upsample(new_source.size().width, new_source.size().height, m_pyramid[level - 1].global_mask());
|
145 |
-
upscaled = true;
|
146 |
-
} else {
|
147 |
-
new_source = m_pyramid[level];
|
148 |
-
new_target = target.clone();
|
149 |
-
}
|
150 |
-
|
151 |
-
auto vote = cv::Mat(new_target.size(), CV_64FC4);
|
152 |
-
vote.setTo(cv::Scalar::all(0));
|
153 |
-
|
154 |
-
// Votes for best patch from NNF Source->Target (completeness) and Target->Source (coherence).
|
155 |
-
_expectation_step(m_source2target, 1, vote, new_source, upscaled);
|
156 |
-
if (verbose) std::cerr << " Expectation source to target finished." << std::endl;
|
157 |
-
_expectation_step(m_target2source, 0, vote, new_source, upscaled);
|
158 |
-
if (verbose) std::cerr << " Expectation target to source finished." << std::endl;
|
159 |
-
|
160 |
-
// Compile votes and update pixel values.
|
161 |
-
_maximization_step(new_target, vote);
|
162 |
-
if (verbose) std::cerr << " Minimization step finished." << std::endl;
|
163 |
-
}
|
164 |
-
|
165 |
-
return new_target;
|
166 |
-
}
|
167 |
-
|
168 |
-
// Expectation step: vote for best estimations of each pixel.
|
169 |
-
void Inpainting::_expectation_step(
|
170 |
-
const NearestNeighborField &nnf, bool source2target,
|
171 |
-
cv::Mat &vote, const MaskedImage &source, bool upscaled
|
172 |
-
) {
|
173 |
-
auto source_size = nnf.source_size();
|
174 |
-
auto target_size = nnf.target_size();
|
175 |
-
const int patch_size = m_distance_metric->patch_size();
|
176 |
-
|
177 |
-
for (int i = 0; i < source_size.height; ++i) {
|
178 |
-
for (int j = 0; j < source_size.width; ++j) {
|
179 |
-
if (nnf.source().is_globally_masked(i, j)) continue;
|
180 |
-
int yp = nnf.at(i, j, 0), xp = nnf.at(i, j, 1), dp = nnf.at(i, j, 2);
|
181 |
-
double w = kDistance2Similarity[dp];
|
182 |
-
|
183 |
-
for (int di = -patch_size; di <= patch_size; ++di) {
|
184 |
-
for (int dj = -patch_size; dj <= patch_size; ++dj) {
|
185 |
-
int ys = i + di, xs = j + dj, yt = yp + di, xt = xp + dj;
|
186 |
-
if (!(ys >= 0 && ys < source_size.height && xs >= 0 && xs < source_size.width)) continue;
|
187 |
-
if (nnf.source().is_globally_masked(ys, xs)) continue;
|
188 |
-
if (!(yt >= 0 && yt < target_size.height && xt >= 0 && xt < target_size.width)) continue;
|
189 |
-
if (nnf.target().is_globally_masked(yt, xt)) continue;
|
190 |
-
|
191 |
-
if (!source2target) {
|
192 |
-
std::swap(ys, yt);
|
193 |
-
std::swap(xs, xt);
|
194 |
-
}
|
195 |
-
|
196 |
-
if (upscaled) {
|
197 |
-
for (int uy = 0; uy < 2; ++uy) {
|
198 |
-
for (int ux = 0; ux < 2; ++ux) {
|
199 |
-
_weighted_copy(source, 2 * ys + uy, 2 * xs + ux, vote, 2 * yt + uy, 2 * xt + ux, w);
|
200 |
-
}
|
201 |
-
}
|
202 |
-
} else {
|
203 |
-
_weighted_copy(source, ys, xs, vote, yt, xt, w);
|
204 |
-
}
|
205 |
-
}
|
206 |
-
}
|
207 |
-
}
|
208 |
-
}
|
209 |
-
}
|
210 |
-
|
211 |
-
// Maximization Step: maximum likelihood of target pixel.
|
212 |
-
void Inpainting::_maximization_step(MaskedImage &target, const cv::Mat &vote) {
|
213 |
-
auto target_size = target.size();
|
214 |
-
for (int i = 0; i < target_size.height; ++i) {
|
215 |
-
for (int j = 0; j < target_size.width; ++j) {
|
216 |
-
const double *source_ptr = vote.ptr<double>(i, j);
|
217 |
-
unsigned char *target_ptr = target.get_mutable_image(i, j);
|
218 |
-
|
219 |
-
if (target.is_globally_masked(i, j)) {
|
220 |
-
continue;
|
221 |
-
}
|
222 |
-
|
223 |
-
if (source_ptr[3] > 0) {
|
224 |
-
unsigned char r = cv::saturate_cast<unsigned char>(source_ptr[0] / source_ptr[3]);
|
225 |
-
unsigned char g = cv::saturate_cast<unsigned char>(source_ptr[1] / source_ptr[3]);
|
226 |
-
unsigned char b = cv::saturate_cast<unsigned char>(source_ptr[2] / source_ptr[3]);
|
227 |
-
target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
|
228 |
-
} else {
|
229 |
-
target.set_mask(i, j, 0);
|
230 |
-
}
|
231 |
-
}
|
232 |
-
}
|
233 |
-
}
|
234 |
-
|
|
|
1 |
+
#include <algorithm>
|
2 |
+
#include <iostream>
|
3 |
+
#include <opencv2/imgcodecs.hpp>
|
4 |
+
#include <opencv2/imgproc.hpp>
|
5 |
+
#include <opencv2/highgui.hpp>
|
6 |
+
|
7 |
+
#include "inpaint.h"
|
8 |
+
|
9 |
+
namespace {
|
10 |
+
static std::vector<double> kDistance2Similarity;
|
11 |
+
|
12 |
+
void init_kDistance2Similarity() {
|
13 |
+
double base[11] = {1.0, 0.99, 0.96, 0.83, 0.38, 0.11, 0.02, 0.005, 0.0006, 0.0001, 0};
|
14 |
+
int length = (PatchDistanceMetric::kDistanceScale + 1);
|
15 |
+
kDistance2Similarity.resize(length);
|
16 |
+
for (int i = 0; i < length; ++i) {
|
17 |
+
double t = (double) i / length;
|
18 |
+
int j = (int) (100 * t);
|
19 |
+
int k = j + 1;
|
20 |
+
double vj = (j < 11) ? base[j] : 0;
|
21 |
+
double vk = (k < 11) ? base[k] : 0;
|
22 |
+
kDistance2Similarity[i] = vj + (100 * t - j) * (vk - vj);
|
23 |
+
}
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
inline void _weighted_copy(const MaskedImage &source, int ys, int xs, cv::Mat &target, int yt, int xt, double weight) {
|
28 |
+
if (source.is_masked(ys, xs)) return;
|
29 |
+
if (source.is_globally_masked(ys, xs)) return;
|
30 |
+
|
31 |
+
auto source_ptr = source.get_image(ys, xs);
|
32 |
+
auto target_ptr = target.ptr<double>(yt, xt);
|
33 |
+
|
34 |
+
#pragma unroll
|
35 |
+
for (int c = 0; c < 3; ++c)
|
36 |
+
target_ptr[c] += static_cast<double>(source_ptr[c]) * weight;
|
37 |
+
target_ptr[3] += weight;
|
38 |
+
}
|
39 |
+
}
|
40 |
+
|
41 |
+
/**
|
42 |
+
* This algorithme uses a version proposed by Xavier Philippeau.
|
43 |
+
*/
|
44 |
+
|
45 |
+
Inpainting::Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric)
|
46 |
+
: m_initial(image, mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
|
47 |
+
_initialize_pyramid();
|
48 |
+
}
|
49 |
+
|
50 |
+
Inpainting::Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric)
|
51 |
+
: m_initial(image, mask, global_mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
|
52 |
+
_initialize_pyramid();
|
53 |
+
}
|
54 |
+
|
55 |
+
void Inpainting::_initialize_pyramid() {
|
56 |
+
auto source = m_initial;
|
57 |
+
m_pyramid.push_back(source);
|
58 |
+
while (source.size().height > m_distance_metric->patch_size() && source.size().width > m_distance_metric->patch_size()) {
|
59 |
+
source = source.downsample();
|
60 |
+
m_pyramid.push_back(source);
|
61 |
+
}
|
62 |
+
|
63 |
+
if (kDistance2Similarity.size() == 0) {
|
64 |
+
init_kDistance2Similarity();
|
65 |
+
}
|
66 |
+
}
|
67 |
+
|
68 |
+
cv::Mat Inpainting::run(bool verbose, bool verbose_visualize, unsigned int random_seed) {
|
69 |
+
srand(random_seed);
|
70 |
+
const int nr_levels = m_pyramid.size();
|
71 |
+
|
72 |
+
MaskedImage source, target;
|
73 |
+
for (int level = nr_levels - 1; level >= 0; --level) {
|
74 |
+
if (verbose) std::cerr << "Inpainting level: " << level << std::endl;
|
75 |
+
|
76 |
+
source = m_pyramid[level];
|
77 |
+
|
78 |
+
if (level == nr_levels - 1) {
|
79 |
+
target = source.clone();
|
80 |
+
target.clear_mask();
|
81 |
+
m_source2target = NearestNeighborField(source, target, m_distance_metric);
|
82 |
+
m_target2source = NearestNeighborField(target, source, m_distance_metric);
|
83 |
+
} else {
|
84 |
+
m_source2target = NearestNeighborField(source, target, m_distance_metric, m_source2target);
|
85 |
+
m_target2source = NearestNeighborField(target, source, m_distance_metric, m_target2source);
|
86 |
+
}
|
87 |
+
|
88 |
+
if (verbose) std::cerr << "Initialization done." << std::endl;
|
89 |
+
|
90 |
+
if (verbose_visualize) {
|
91 |
+
auto visualize_size = m_initial.size();
|
92 |
+
cv::Mat source_visualize(visualize_size, m_initial.image().type());
|
93 |
+
cv::resize(source.image(), source_visualize, visualize_size);
|
94 |
+
cv::imshow("Source", source_visualize);
|
95 |
+
cv::Mat target_visualize(visualize_size, m_initial.image().type());
|
96 |
+
cv::resize(target.image(), target_visualize, visualize_size);
|
97 |
+
cv::imshow("Target", target_visualize);
|
98 |
+
cv::waitKey(0);
|
99 |
+
}
|
100 |
+
|
101 |
+
target = _expectation_maximization(source, target, level, verbose);
|
102 |
+
}
|
103 |
+
|
104 |
+
return target.image();
|
105 |
+
}
|
106 |
+
|
107 |
+
// EM-Like algorithm (see "PatchMatch" - page 6).
|
108 |
+
// Returns a double sized target image (unless level = 0).
|
109 |
+
MaskedImage Inpainting::_expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose) {
|
110 |
+
const int nr_iters_em = 1 + 2 * level;
|
111 |
+
const int nr_iters_nnf = static_cast<int>(std::min(7, 1 + level));
|
112 |
+
const int patch_size = m_distance_metric->patch_size();
|
113 |
+
|
114 |
+
MaskedImage new_source, new_target;
|
115 |
+
|
116 |
+
for (int iter_em = 0; iter_em < nr_iters_em; ++iter_em) {
|
117 |
+
if (iter_em != 0) {
|
118 |
+
m_source2target.set_target(new_target);
|
119 |
+
m_target2source.set_source(new_target);
|
120 |
+
target = new_target;
|
121 |
+
}
|
122 |
+
|
123 |
+
if (verbose) std::cerr << "EM Iteration: " << iter_em << std::endl;
|
124 |
+
|
125 |
+
auto size = source.size();
|
126 |
+
for (int i = 0; i < size.height; ++i) {
|
127 |
+
for (int j = 0; j < size.width; ++j) {
|
128 |
+
if (!source.contains_mask(i, j, patch_size)) {
|
129 |
+
m_source2target.set_identity(i, j);
|
130 |
+
m_target2source.set_identity(i, j);
|
131 |
+
}
|
132 |
+
}
|
133 |
+
}
|
134 |
+
if (verbose) std::cerr << " NNF minimization started." << std::endl;
|
135 |
+
m_source2target.minimize(nr_iters_nnf);
|
136 |
+
m_target2source.minimize(nr_iters_nnf);
|
137 |
+
if (verbose) std::cerr << " NNF minimization finished." << std::endl;
|
138 |
+
|
139 |
+
// Instead of upsizing the final target, we build the last target from the next level source image.
|
140 |
+
// Thus, the final target is less blurry (see "Space-Time Video Completion" - page 5).
|
141 |
+
bool upscaled = false;
|
142 |
+
if (level >= 1 && iter_em == nr_iters_em - 1) {
|
143 |
+
new_source = m_pyramid[level - 1];
|
144 |
+
new_target = target.upsample(new_source.size().width, new_source.size().height, m_pyramid[level - 1].global_mask());
|
145 |
+
upscaled = true;
|
146 |
+
} else {
|
147 |
+
new_source = m_pyramid[level];
|
148 |
+
new_target = target.clone();
|
149 |
+
}
|
150 |
+
|
151 |
+
auto vote = cv::Mat(new_target.size(), CV_64FC4);
|
152 |
+
vote.setTo(cv::Scalar::all(0));
|
153 |
+
|
154 |
+
// Votes for best patch from NNF Source->Target (completeness) and Target->Source (coherence).
|
155 |
+
_expectation_step(m_source2target, 1, vote, new_source, upscaled);
|
156 |
+
if (verbose) std::cerr << " Expectation source to target finished." << std::endl;
|
157 |
+
_expectation_step(m_target2source, 0, vote, new_source, upscaled);
|
158 |
+
if (verbose) std::cerr << " Expectation target to source finished." << std::endl;
|
159 |
+
|
160 |
+
// Compile votes and update pixel values.
|
161 |
+
_maximization_step(new_target, vote);
|
162 |
+
if (verbose) std::cerr << " Minimization step finished." << std::endl;
|
163 |
+
}
|
164 |
+
|
165 |
+
return new_target;
|
166 |
+
}
|
167 |
+
|
168 |
+
// Expectation step: vote for best estimations of each pixel.
|
169 |
+
void Inpainting::_expectation_step(
|
170 |
+
const NearestNeighborField &nnf, bool source2target,
|
171 |
+
cv::Mat &vote, const MaskedImage &source, bool upscaled
|
172 |
+
) {
|
173 |
+
auto source_size = nnf.source_size();
|
174 |
+
auto target_size = nnf.target_size();
|
175 |
+
const int patch_size = m_distance_metric->patch_size();
|
176 |
+
|
177 |
+
for (int i = 0; i < source_size.height; ++i) {
|
178 |
+
for (int j = 0; j < source_size.width; ++j) {
|
179 |
+
if (nnf.source().is_globally_masked(i, j)) continue;
|
180 |
+
int yp = nnf.at(i, j, 0), xp = nnf.at(i, j, 1), dp = nnf.at(i, j, 2);
|
181 |
+
double w = kDistance2Similarity[dp];
|
182 |
+
|
183 |
+
for (int di = -patch_size; di <= patch_size; ++di) {
|
184 |
+
for (int dj = -patch_size; dj <= patch_size; ++dj) {
|
185 |
+
int ys = i + di, xs = j + dj, yt = yp + di, xt = xp + dj;
|
186 |
+
if (!(ys >= 0 && ys < source_size.height && xs >= 0 && xs < source_size.width)) continue;
|
187 |
+
if (nnf.source().is_globally_masked(ys, xs)) continue;
|
188 |
+
if (!(yt >= 0 && yt < target_size.height && xt >= 0 && xt < target_size.width)) continue;
|
189 |
+
if (nnf.target().is_globally_masked(yt, xt)) continue;
|
190 |
+
|
191 |
+
if (!source2target) {
|
192 |
+
std::swap(ys, yt);
|
193 |
+
std::swap(xs, xt);
|
194 |
+
}
|
195 |
+
|
196 |
+
if (upscaled) {
|
197 |
+
for (int uy = 0; uy < 2; ++uy) {
|
198 |
+
for (int ux = 0; ux < 2; ++ux) {
|
199 |
+
_weighted_copy(source, 2 * ys + uy, 2 * xs + ux, vote, 2 * yt + uy, 2 * xt + ux, w);
|
200 |
+
}
|
201 |
+
}
|
202 |
+
} else {
|
203 |
+
_weighted_copy(source, ys, xs, vote, yt, xt, w);
|
204 |
+
}
|
205 |
+
}
|
206 |
+
}
|
207 |
+
}
|
208 |
+
}
|
209 |
+
}
|
210 |
+
|
211 |
+
// Maximization Step: maximum likelihood of target pixel.
|
212 |
+
void Inpainting::_maximization_step(MaskedImage &target, const cv::Mat &vote) {
|
213 |
+
auto target_size = target.size();
|
214 |
+
for (int i = 0; i < target_size.height; ++i) {
|
215 |
+
for (int j = 0; j < target_size.width; ++j) {
|
216 |
+
const double *source_ptr = vote.ptr<double>(i, j);
|
217 |
+
unsigned char *target_ptr = target.get_mutable_image(i, j);
|
218 |
+
|
219 |
+
if (target.is_globally_masked(i, j)) {
|
220 |
+
continue;
|
221 |
+
}
|
222 |
+
|
223 |
+
if (source_ptr[3] > 0) {
|
224 |
+
unsigned char r = cv::saturate_cast<unsigned char>(source_ptr[0] / source_ptr[3]);
|
225 |
+
unsigned char g = cv::saturate_cast<unsigned char>(source_ptr[1] / source_ptr[3]);
|
226 |
+
unsigned char b = cv::saturate_cast<unsigned char>(source_ptr[2] / source_ptr[3]);
|
227 |
+
target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
|
228 |
+
} else {
|
229 |
+
target.set_mask(i, j, 0);
|
230 |
+
}
|
231 |
+
}
|
232 |
+
}
|
233 |
+
}
|
234 |
+
|
PyPatchMatch/csrc/inpaint.h
CHANGED
@@ -1,27 +1,27 @@
|
|
1 |
-
#pragma once
|
2 |
-
|
3 |
-
#include <vector>
|
4 |
-
|
5 |
-
#include "masked_image.h"
|
6 |
-
#include "nnf.h"
|
7 |
-
|
8 |
-
class Inpainting {
|
9 |
-
public:
|
10 |
-
Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric);
|
11 |
-
Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric);
|
12 |
-
cv::Mat run(bool verbose = false, bool verbose_visualize = false, unsigned int random_seed = 1212);
|
13 |
-
|
14 |
-
private:
|
15 |
-
void _initialize_pyramid(void);
|
16 |
-
MaskedImage _expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose);
|
17 |
-
void _expectation_step(const NearestNeighborField &nnf, bool source2target, cv::Mat &vote, const MaskedImage &source, bool upscaled);
|
18 |
-
void _maximization_step(MaskedImage &target, const cv::Mat &vote);
|
19 |
-
|
20 |
-
MaskedImage m_initial;
|
21 |
-
std::vector<MaskedImage> m_pyramid;
|
22 |
-
|
23 |
-
NearestNeighborField m_source2target;
|
24 |
-
NearestNeighborField m_target2source;
|
25 |
-
const PatchDistanceMetric *m_distance_metric;
|
26 |
-
};
|
27 |
-
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <vector>
|
4 |
+
|
5 |
+
#include "masked_image.h"
|
6 |
+
#include "nnf.h"
|
7 |
+
|
8 |
+
class Inpainting {
|
9 |
+
public:
|
10 |
+
Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric);
|
11 |
+
Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric);
|
12 |
+
cv::Mat run(bool verbose = false, bool verbose_visualize = false, unsigned int random_seed = 1212);
|
13 |
+
|
14 |
+
private:
|
15 |
+
void _initialize_pyramid(void);
|
16 |
+
MaskedImage _expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose);
|
17 |
+
void _expectation_step(const NearestNeighborField &nnf, bool source2target, cv::Mat &vote, const MaskedImage &source, bool upscaled);
|
18 |
+
void _maximization_step(MaskedImage &target, const cv::Mat &vote);
|
19 |
+
|
20 |
+
MaskedImage m_initial;
|
21 |
+
std::vector<MaskedImage> m_pyramid;
|
22 |
+
|
23 |
+
NearestNeighborField m_source2target;
|
24 |
+
NearestNeighborField m_target2source;
|
25 |
+
const PatchDistanceMetric *m_distance_metric;
|
26 |
+
};
|
27 |
+
|
PyPatchMatch/csrc/masked_image.cpp
CHANGED
@@ -1,138 +1,138 @@
|
|
1 |
-
#include "masked_image.h"
|
2 |
-
#include <algorithm>
|
3 |
-
#include <iostream>
|
4 |
-
|
5 |
-
const cv::Size MaskedImage::kDownsampleKernelSize = cv::Size(6, 6);
|
6 |
-
const int MaskedImage::kDownsampleKernel[6] = {1, 5, 10, 10, 5, 1};
|
7 |
-
|
8 |
-
bool MaskedImage::contains_mask(int y, int x, int patch_size) const {
|
9 |
-
auto mask_size = size();
|
10 |
-
for (int dy = -patch_size; dy <= patch_size; ++dy) {
|
11 |
-
for (int dx = -patch_size; dx <= patch_size; ++dx) {
|
12 |
-
int yy = y + dy, xx = x + dx;
|
13 |
-
if (yy >= 0 && yy < mask_size.height && xx >= 0 && xx < mask_size.width) {
|
14 |
-
if (is_masked(yy, xx) && !is_globally_masked(yy, xx)) return true;
|
15 |
-
}
|
16 |
-
}
|
17 |
-
}
|
18 |
-
return false;
|
19 |
-
}
|
20 |
-
|
21 |
-
MaskedImage MaskedImage::downsample() const {
|
22 |
-
const auto &kernel_size = MaskedImage::kDownsampleKernelSize;
|
23 |
-
const auto &kernel = MaskedImage::kDownsampleKernel;
|
24 |
-
|
25 |
-
const auto size = this->size();
|
26 |
-
const auto new_size = cv::Size(size.width / 2, size.height / 2);
|
27 |
-
|
28 |
-
auto ret = MaskedImage(new_size.width, new_size.height);
|
29 |
-
if (!m_global_mask.empty()) ret.init_global_mask_mat();
|
30 |
-
for (int y = 0; y < size.height - 1; y += 2) {
|
31 |
-
for (int x = 0; x < size.width - 1; x += 2) {
|
32 |
-
int r = 0, g = 0, b = 0, ksum = 0;
|
33 |
-
bool is_gmasked = true;
|
34 |
-
|
35 |
-
for (int dy = -kernel_size.height / 2 + 1; dy <= kernel_size.height / 2; ++dy) {
|
36 |
-
for (int dx = -kernel_size.width / 2 + 1; dx <= kernel_size.width / 2; ++dx) {
|
37 |
-
int yy = y + dy, xx = x + dx;
|
38 |
-
if (yy >= 0 && yy < size.height && xx >= 0 && xx < size.width) {
|
39 |
-
if (!is_globally_masked(yy, xx)) {
|
40 |
-
is_gmasked = false;
|
41 |
-
}
|
42 |
-
if (!is_masked(yy, xx)) {
|
43 |
-
auto source_ptr = get_image(yy, xx);
|
44 |
-
int k = kernel[kernel_size.height / 2 - 1 + dy] * kernel[kernel_size.width / 2 - 1 + dx];
|
45 |
-
r += source_ptr[0] * k, g += source_ptr[1] * k, b += source_ptr[2] * k;
|
46 |
-
ksum += k;
|
47 |
-
}
|
48 |
-
}
|
49 |
-
}
|
50 |
-
}
|
51 |
-
|
52 |
-
if (ksum > 0) r /= ksum, g /= ksum, b /= ksum;
|
53 |
-
|
54 |
-
if (!m_global_mask.empty()) {
|
55 |
-
ret.set_global_mask(y / 2, x / 2, is_gmasked);
|
56 |
-
}
|
57 |
-
if (ksum > 0) {
|
58 |
-
auto target_ptr = ret.get_mutable_image(y / 2, x / 2);
|
59 |
-
target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
|
60 |
-
ret.set_mask(y / 2, x / 2, 0);
|
61 |
-
} else {
|
62 |
-
ret.set_mask(y / 2, x / 2, 1);
|
63 |
-
}
|
64 |
-
}
|
65 |
-
}
|
66 |
-
|
67 |
-
return ret;
|
68 |
-
}
|
69 |
-
|
70 |
-
MaskedImage MaskedImage::upsample(int new_w, int new_h) const {
|
71 |
-
const auto size = this->size();
|
72 |
-
auto ret = MaskedImage(new_w, new_h);
|
73 |
-
if (!m_global_mask.empty()) ret.init_global_mask_mat();
|
74 |
-
for (int y = 0; y < new_h; ++y) {
|
75 |
-
for (int x = 0; x < new_w; ++x) {
|
76 |
-
int yy = y * size.height / new_h;
|
77 |
-
int xx = x * size.width / new_w;
|
78 |
-
|
79 |
-
if (is_globally_masked(yy, xx)) {
|
80 |
-
ret.set_global_mask(y, x, 1);
|
81 |
-
ret.set_mask(y, x, 1);
|
82 |
-
} else {
|
83 |
-
if (!m_global_mask.empty()) ret.set_global_mask(y, x, 0);
|
84 |
-
|
85 |
-
if (is_masked(yy, xx)) {
|
86 |
-
ret.set_mask(y, x, 1);
|
87 |
-
} else {
|
88 |
-
auto source_ptr = get_image(yy, xx);
|
89 |
-
auto target_ptr = ret.get_mutable_image(y, x);
|
90 |
-
for (int c = 0; c < 3; ++c)
|
91 |
-
target_ptr[c] = source_ptr[c];
|
92 |
-
ret.set_mask(y, x, 0);
|
93 |
-
}
|
94 |
-
}
|
95 |
-
}
|
96 |
-
}
|
97 |
-
|
98 |
-
return ret;
|
99 |
-
}
|
100 |
-
|
101 |
-
MaskedImage MaskedImage::upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const {
|
102 |
-
auto ret = upsample(new_w, new_h);
|
103 |
-
ret.set_global_mask_mat(new_global_mask);
|
104 |
-
return ret;
|
105 |
-
}
|
106 |
-
|
107 |
-
void MaskedImage::compute_image_gradients() {
|
108 |
-
if (m_image_grad_computed) {
|
109 |
-
return;
|
110 |
-
}
|
111 |
-
|
112 |
-
const auto size = m_image.size();
|
113 |
-
m_image_grady = cv::Mat(size, CV_8UC3);
|
114 |
-
m_image_gradx = cv::Mat(size, CV_8UC3);
|
115 |
-
m_image_grady = cv::Scalar::all(0);
|
116 |
-
m_image_gradx = cv::Scalar::all(0);
|
117 |
-
|
118 |
-
for (int i = 1; i < size.height - 1; ++i) {
|
119 |
-
const auto *ptr = m_image.ptr<unsigned char>(i, 0);
|
120 |
-
const auto *ptry1 = m_image.ptr<unsigned char>(i + 1, 0);
|
121 |
-
const auto *ptry2 = m_image.ptr<unsigned char>(i - 1, 0);
|
122 |
-
const auto *ptrx1 = m_image.ptr<unsigned char>(i, 0) + 3;
|
123 |
-
const auto *ptrx2 = m_image.ptr<unsigned char>(i, 0) - 3;
|
124 |
-
auto *mptry = m_image_grady.ptr<unsigned char>(i, 0);
|
125 |
-
auto *mptrx = m_image_gradx.ptr<unsigned char>(i, 0);
|
126 |
-
for (int j = 3; j < size.width * 3 - 3; ++j) {
|
127 |
-
mptry[j] = (ptry1[j] / 2 - ptry2[j] / 2) + 128;
|
128 |
-
mptrx[j] = (ptrx1[j] / 2 - ptrx2[j] / 2) + 128;
|
129 |
-
}
|
130 |
-
}
|
131 |
-
|
132 |
-
m_image_grad_computed = true;
|
133 |
-
}
|
134 |
-
|
135 |
-
void MaskedImage::compute_image_gradients() const {
|
136 |
-
const_cast<MaskedImage *>(this)->compute_image_gradients();
|
137 |
-
}
|
138 |
-
|
|
|
1 |
+
#include "masked_image.h"
|
2 |
+
#include <algorithm>
|
3 |
+
#include <iostream>
|
4 |
+
|
5 |
+
const cv::Size MaskedImage::kDownsampleKernelSize = cv::Size(6, 6);
|
6 |
+
const int MaskedImage::kDownsampleKernel[6] = {1, 5, 10, 10, 5, 1};
|
7 |
+
|
8 |
+
bool MaskedImage::contains_mask(int y, int x, int patch_size) const {
|
9 |
+
auto mask_size = size();
|
10 |
+
for (int dy = -patch_size; dy <= patch_size; ++dy) {
|
11 |
+
for (int dx = -patch_size; dx <= patch_size; ++dx) {
|
12 |
+
int yy = y + dy, xx = x + dx;
|
13 |
+
if (yy >= 0 && yy < mask_size.height && xx >= 0 && xx < mask_size.width) {
|
14 |
+
if (is_masked(yy, xx) && !is_globally_masked(yy, xx)) return true;
|
15 |
+
}
|
16 |
+
}
|
17 |
+
}
|
18 |
+
return false;
|
19 |
+
}
|
20 |
+
|
21 |
+
MaskedImage MaskedImage::downsample() const {
|
22 |
+
const auto &kernel_size = MaskedImage::kDownsampleKernelSize;
|
23 |
+
const auto &kernel = MaskedImage::kDownsampleKernel;
|
24 |
+
|
25 |
+
const auto size = this->size();
|
26 |
+
const auto new_size = cv::Size(size.width / 2, size.height / 2);
|
27 |
+
|
28 |
+
auto ret = MaskedImage(new_size.width, new_size.height);
|
29 |
+
if (!m_global_mask.empty()) ret.init_global_mask_mat();
|
30 |
+
for (int y = 0; y < size.height - 1; y += 2) {
|
31 |
+
for (int x = 0; x < size.width - 1; x += 2) {
|
32 |
+
int r = 0, g = 0, b = 0, ksum = 0;
|
33 |
+
bool is_gmasked = true;
|
34 |
+
|
35 |
+
for (int dy = -kernel_size.height / 2 + 1; dy <= kernel_size.height / 2; ++dy) {
|
36 |
+
for (int dx = -kernel_size.width / 2 + 1; dx <= kernel_size.width / 2; ++dx) {
|
37 |
+
int yy = y + dy, xx = x + dx;
|
38 |
+
if (yy >= 0 && yy < size.height && xx >= 0 && xx < size.width) {
|
39 |
+
if (!is_globally_masked(yy, xx)) {
|
40 |
+
is_gmasked = false;
|
41 |
+
}
|
42 |
+
if (!is_masked(yy, xx)) {
|
43 |
+
auto source_ptr = get_image(yy, xx);
|
44 |
+
int k = kernel[kernel_size.height / 2 - 1 + dy] * kernel[kernel_size.width / 2 - 1 + dx];
|
45 |
+
r += source_ptr[0] * k, g += source_ptr[1] * k, b += source_ptr[2] * k;
|
46 |
+
ksum += k;
|
47 |
+
}
|
48 |
+
}
|
49 |
+
}
|
50 |
+
}
|
51 |
+
|
52 |
+
if (ksum > 0) r /= ksum, g /= ksum, b /= ksum;
|
53 |
+
|
54 |
+
if (!m_global_mask.empty()) {
|
55 |
+
ret.set_global_mask(y / 2, x / 2, is_gmasked);
|
56 |
+
}
|
57 |
+
if (ksum > 0) {
|
58 |
+
auto target_ptr = ret.get_mutable_image(y / 2, x / 2);
|
59 |
+
target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
|
60 |
+
ret.set_mask(y / 2, x / 2, 0);
|
61 |
+
} else {
|
62 |
+
ret.set_mask(y / 2, x / 2, 1);
|
63 |
+
}
|
64 |
+
}
|
65 |
+
}
|
66 |
+
|
67 |
+
return ret;
|
68 |
+
}
|
69 |
+
|
70 |
+
MaskedImage MaskedImage::upsample(int new_w, int new_h) const {
|
71 |
+
const auto size = this->size();
|
72 |
+
auto ret = MaskedImage(new_w, new_h);
|
73 |
+
if (!m_global_mask.empty()) ret.init_global_mask_mat();
|
74 |
+
for (int y = 0; y < new_h; ++y) {
|
75 |
+
for (int x = 0; x < new_w; ++x) {
|
76 |
+
int yy = y * size.height / new_h;
|
77 |
+
int xx = x * size.width / new_w;
|
78 |
+
|
79 |
+
if (is_globally_masked(yy, xx)) {
|
80 |
+
ret.set_global_mask(y, x, 1);
|
81 |
+
ret.set_mask(y, x, 1);
|
82 |
+
} else {
|
83 |
+
if (!m_global_mask.empty()) ret.set_global_mask(y, x, 0);
|
84 |
+
|
85 |
+
if (is_masked(yy, xx)) {
|
86 |
+
ret.set_mask(y, x, 1);
|
87 |
+
} else {
|
88 |
+
auto source_ptr = get_image(yy, xx);
|
89 |
+
auto target_ptr = ret.get_mutable_image(y, x);
|
90 |
+
for (int c = 0; c < 3; ++c)
|
91 |
+
target_ptr[c] = source_ptr[c];
|
92 |
+
ret.set_mask(y, x, 0);
|
93 |
+
}
|
94 |
+
}
|
95 |
+
}
|
96 |
+
}
|
97 |
+
|
98 |
+
return ret;
|
99 |
+
}
|
100 |
+
|
101 |
+
MaskedImage MaskedImage::upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const {
|
102 |
+
auto ret = upsample(new_w, new_h);
|
103 |
+
ret.set_global_mask_mat(new_global_mask);
|
104 |
+
return ret;
|
105 |
+
}
|
106 |
+
|
107 |
+
void MaskedImage::compute_image_gradients() {
|
108 |
+
if (m_image_grad_computed) {
|
109 |
+
return;
|
110 |
+
}
|
111 |
+
|
112 |
+
const auto size = m_image.size();
|
113 |
+
m_image_grady = cv::Mat(size, CV_8UC3);
|
114 |
+
m_image_gradx = cv::Mat(size, CV_8UC3);
|
115 |
+
m_image_grady = cv::Scalar::all(0);
|
116 |
+
m_image_gradx = cv::Scalar::all(0);
|
117 |
+
|
118 |
+
for (int i = 1; i < size.height - 1; ++i) {
|
119 |
+
const auto *ptr = m_image.ptr<unsigned char>(i, 0);
|
120 |
+
const auto *ptry1 = m_image.ptr<unsigned char>(i + 1, 0);
|
121 |
+
const auto *ptry2 = m_image.ptr<unsigned char>(i - 1, 0);
|
122 |
+
const auto *ptrx1 = m_image.ptr<unsigned char>(i, 0) + 3;
|
123 |
+
const auto *ptrx2 = m_image.ptr<unsigned char>(i, 0) - 3;
|
124 |
+
auto *mptry = m_image_grady.ptr<unsigned char>(i, 0);
|
125 |
+
auto *mptrx = m_image_gradx.ptr<unsigned char>(i, 0);
|
126 |
+
for (int j = 3; j < size.width * 3 - 3; ++j) {
|
127 |
+
mptry[j] = (ptry1[j] / 2 - ptry2[j] / 2) + 128;
|
128 |
+
mptrx[j] = (ptrx1[j] / 2 - ptrx2[j] / 2) + 128;
|
129 |
+
}
|
130 |
+
}
|
131 |
+
|
132 |
+
m_image_grad_computed = true;
|
133 |
+
}
|
134 |
+
|
135 |
+
void MaskedImage::compute_image_gradients() const {
|
136 |
+
const_cast<MaskedImage *>(this)->compute_image_gradients();
|
137 |
+
}
|
138 |
+
|
PyPatchMatch/csrc/masked_image.h
CHANGED
@@ -1,112 +1,112 @@
|
|
1 |
-
#pragma once
|
2 |
-
|
3 |
-
#include <opencv2/core.hpp>
|
4 |
-
|
5 |
-
class MaskedImage {
|
6 |
-
public:
|
7 |
-
MaskedImage() : m_image(), m_mask(), m_global_mask(), m_image_grady(), m_image_gradx(), m_image_grad_computed(false) {
|
8 |
-
// pass
|
9 |
-
}
|
10 |
-
MaskedImage(cv::Mat image, cv::Mat mask) : m_image(image), m_mask(mask), m_image_grad_computed(false) {
|
11 |
-
// pass
|
12 |
-
}
|
13 |
-
MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask) : m_image(image), m_mask(mask), m_global_mask(global_mask), m_image_grad_computed(false) {
|
14 |
-
// pass
|
15 |
-
}
|
16 |
-
MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask, cv::Mat grady, cv::Mat gradx, bool grad_computed) :
|
17 |
-
m_image(image), m_mask(mask), m_global_mask(global_mask),
|
18 |
-
m_image_grady(grady), m_image_gradx(gradx), m_image_grad_computed(grad_computed) {
|
19 |
-
// pass
|
20 |
-
}
|
21 |
-
MaskedImage(int width, int height) : m_global_mask(), m_image_grady(), m_image_gradx() {
|
22 |
-
m_image = cv::Mat(cv::Size(width, height), CV_8UC3);
|
23 |
-
m_image = cv::Scalar::all(0);
|
24 |
-
|
25 |
-
m_mask = cv::Mat(cv::Size(width, height), CV_8U);
|
26 |
-
m_mask = cv::Scalar::all(0);
|
27 |
-
}
|
28 |
-
inline MaskedImage clone() {
|
29 |
-
return MaskedImage(
|
30 |
-
m_image.clone(), m_mask.clone(), m_global_mask.clone(),
|
31 |
-
m_image_grady.clone(), m_image_gradx.clone(), m_image_grad_computed
|
32 |
-
);
|
33 |
-
}
|
34 |
-
|
35 |
-
inline cv::Size size() const {
|
36 |
-
return m_image.size();
|
37 |
-
}
|
38 |
-
inline const cv::Mat &image() const {
|
39 |
-
return m_image;
|
40 |
-
}
|
41 |
-
inline const cv::Mat &mask() const {
|
42 |
-
return m_mask;
|
43 |
-
}
|
44 |
-
inline const cv::Mat &global_mask() const {
|
45 |
-
return m_global_mask;
|
46 |
-
}
|
47 |
-
inline const cv::Mat &grady() const {
|
48 |
-
assert(m_image_grad_computed);
|
49 |
-
return m_image_grady;
|
50 |
-
}
|
51 |
-
inline const cv::Mat &gradx() const {
|
52 |
-
assert(m_image_grad_computed);
|
53 |
-
return m_image_gradx;
|
54 |
-
}
|
55 |
-
|
56 |
-
inline void init_global_mask_mat() {
|
57 |
-
m_global_mask = cv::Mat(m_mask.size(), CV_8U);
|
58 |
-
m_global_mask.setTo(cv::Scalar(0));
|
59 |
-
}
|
60 |
-
inline void set_global_mask_mat(const cv::Mat &other) {
|
61 |
-
m_global_mask = other;
|
62 |
-
}
|
63 |
-
|
64 |
-
inline bool is_masked(int y, int x) const {
|
65 |
-
return static_cast<bool>(m_mask.at<unsigned char>(y, x));
|
66 |
-
}
|
67 |
-
inline bool is_globally_masked(int y, int x) const {
|
68 |
-
return !m_global_mask.empty() && static_cast<bool>(m_global_mask.at<unsigned char>(y, x));
|
69 |
-
}
|
70 |
-
inline void set_mask(int y, int x, bool value) {
|
71 |
-
m_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
|
72 |
-
}
|
73 |
-
inline void set_global_mask(int y, int x, bool value) {
|
74 |
-
m_global_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
|
75 |
-
}
|
76 |
-
inline void clear_mask() {
|
77 |
-
m_mask.setTo(cv::Scalar(0));
|
78 |
-
}
|
79 |
-
|
80 |
-
inline const unsigned char *get_image(int y, int x) const {
|
81 |
-
return m_image.ptr<unsigned char>(y, x);
|
82 |
-
}
|
83 |
-
inline unsigned char *get_mutable_image(int y, int x) {
|
84 |
-
return m_image.ptr<unsigned char>(y, x);
|
85 |
-
}
|
86 |
-
|
87 |
-
inline unsigned char get_image(int y, int x, int c) const {
|
88 |
-
return m_image.ptr<unsigned char>(y, x)[c];
|
89 |
-
}
|
90 |
-
inline int get_image_int(int y, int x, int c) const {
|
91 |
-
return static_cast<int>(m_image.ptr<unsigned char>(y, x)[c]);
|
92 |
-
}
|
93 |
-
|
94 |
-
bool contains_mask(int y, int x, int patch_size) const;
|
95 |
-
MaskedImage downsample() const;
|
96 |
-
MaskedImage upsample(int new_w, int new_h) const;
|
97 |
-
MaskedImage upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const;
|
98 |
-
void compute_image_gradients();
|
99 |
-
void compute_image_gradients() const;
|
100 |
-
|
101 |
-
static const cv::Size kDownsampleKernelSize;
|
102 |
-
static const int kDownsampleKernel[6];
|
103 |
-
|
104 |
-
private:
|
105 |
-
cv::Mat m_image;
|
106 |
-
cv::Mat m_mask;
|
107 |
-
cv::Mat m_global_mask;
|
108 |
-
cv::Mat m_image_grady;
|
109 |
-
cv::Mat m_image_gradx;
|
110 |
-
bool m_image_grad_computed = false;
|
111 |
-
};
|
112 |
-
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <opencv2/core.hpp>
|
4 |
+
|
5 |
+
class MaskedImage {
|
6 |
+
public:
|
7 |
+
MaskedImage() : m_image(), m_mask(), m_global_mask(), m_image_grady(), m_image_gradx(), m_image_grad_computed(false) {
|
8 |
+
// pass
|
9 |
+
}
|
10 |
+
MaskedImage(cv::Mat image, cv::Mat mask) : m_image(image), m_mask(mask), m_image_grad_computed(false) {
|
11 |
+
// pass
|
12 |
+
}
|
13 |
+
MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask) : m_image(image), m_mask(mask), m_global_mask(global_mask), m_image_grad_computed(false) {
|
14 |
+
// pass
|
15 |
+
}
|
16 |
+
MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask, cv::Mat grady, cv::Mat gradx, bool grad_computed) :
|
17 |
+
m_image(image), m_mask(mask), m_global_mask(global_mask),
|
18 |
+
m_image_grady(grady), m_image_gradx(gradx), m_image_grad_computed(grad_computed) {
|
19 |
+
// pass
|
20 |
+
}
|
21 |
+
MaskedImage(int width, int height) : m_global_mask(), m_image_grady(), m_image_gradx() {
|
22 |
+
m_image = cv::Mat(cv::Size(width, height), CV_8UC3);
|
23 |
+
m_image = cv::Scalar::all(0);
|
24 |
+
|
25 |
+
m_mask = cv::Mat(cv::Size(width, height), CV_8U);
|
26 |
+
m_mask = cv::Scalar::all(0);
|
27 |
+
}
|
28 |
+
inline MaskedImage clone() {
|
29 |
+
return MaskedImage(
|
30 |
+
m_image.clone(), m_mask.clone(), m_global_mask.clone(),
|
31 |
+
m_image_grady.clone(), m_image_gradx.clone(), m_image_grad_computed
|
32 |
+
);
|
33 |
+
}
|
34 |
+
|
35 |
+
inline cv::Size size() const {
|
36 |
+
return m_image.size();
|
37 |
+
}
|
38 |
+
inline const cv::Mat &image() const {
|
39 |
+
return m_image;
|
40 |
+
}
|
41 |
+
inline const cv::Mat &mask() const {
|
42 |
+
return m_mask;
|
43 |
+
}
|
44 |
+
inline const cv::Mat &global_mask() const {
|
45 |
+
return m_global_mask;
|
46 |
+
}
|
47 |
+
inline const cv::Mat &grady() const {
|
48 |
+
assert(m_image_grad_computed);
|
49 |
+
return m_image_grady;
|
50 |
+
}
|
51 |
+
inline const cv::Mat &gradx() const {
|
52 |
+
assert(m_image_grad_computed);
|
53 |
+
return m_image_gradx;
|
54 |
+
}
|
55 |
+
|
56 |
+
inline void init_global_mask_mat() {
|
57 |
+
m_global_mask = cv::Mat(m_mask.size(), CV_8U);
|
58 |
+
m_global_mask.setTo(cv::Scalar(0));
|
59 |
+
}
|
60 |
+
inline void set_global_mask_mat(const cv::Mat &other) {
|
61 |
+
m_global_mask = other;
|
62 |
+
}
|
63 |
+
|
64 |
+
inline bool is_masked(int y, int x) const {
|
65 |
+
return static_cast<bool>(m_mask.at<unsigned char>(y, x));
|
66 |
+
}
|
67 |
+
inline bool is_globally_masked(int y, int x) const {
|
68 |
+
return !m_global_mask.empty() && static_cast<bool>(m_global_mask.at<unsigned char>(y, x));
|
69 |
+
}
|
70 |
+
inline void set_mask(int y, int x, bool value) {
|
71 |
+
m_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
|
72 |
+
}
|
73 |
+
inline void set_global_mask(int y, int x, bool value) {
|
74 |
+
m_global_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
|
75 |
+
}
|
76 |
+
inline void clear_mask() {
|
77 |
+
m_mask.setTo(cv::Scalar(0));
|
78 |
+
}
|
79 |
+
|
80 |
+
inline const unsigned char *get_image(int y, int x) const {
|
81 |
+
return m_image.ptr<unsigned char>(y, x);
|
82 |
+
}
|
83 |
+
inline unsigned char *get_mutable_image(int y, int x) {
|
84 |
+
return m_image.ptr<unsigned char>(y, x);
|
85 |
+
}
|
86 |
+
|
87 |
+
inline unsigned char get_image(int y, int x, int c) const {
|
88 |
+
return m_image.ptr<unsigned char>(y, x)[c];
|
89 |
+
}
|
90 |
+
inline int get_image_int(int y, int x, int c) const {
|
91 |
+
return static_cast<int>(m_image.ptr<unsigned char>(y, x)[c]);
|
92 |
+
}
|
93 |
+
|
94 |
+
bool contains_mask(int y, int x, int patch_size) const;
|
95 |
+
MaskedImage downsample() const;
|
96 |
+
MaskedImage upsample(int new_w, int new_h) const;
|
97 |
+
MaskedImage upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const;
|
98 |
+
void compute_image_gradients();
|
99 |
+
void compute_image_gradients() const;
|
100 |
+
|
101 |
+
static const cv::Size kDownsampleKernelSize;
|
102 |
+
static const int kDownsampleKernel[6];
|
103 |
+
|
104 |
+
private:
|
105 |
+
cv::Mat m_image;
|
106 |
+
cv::Mat m_mask;
|
107 |
+
cv::Mat m_global_mask;
|
108 |
+
cv::Mat m_image_grady;
|
109 |
+
cv::Mat m_image_gradx;
|
110 |
+
bool m_image_grad_computed = false;
|
111 |
+
};
|
112 |
+
|
PyPatchMatch/csrc/nnf.cpp
CHANGED
@@ -1,268 +1,268 @@
|
|
1 |
-
#include <algorithm>
|
2 |
-
#include <iostream>
|
3 |
-
#include <cmath>
|
4 |
-
|
5 |
-
#include "masked_image.h"
|
6 |
-
#include "nnf.h"
|
7 |
-
|
8 |
-
/**
|
9 |
-
* Nearest-Neighbor Field (see PatchMatch algorithm).
|
10 |
-
* This algorithme uses a version proposed by Xavier Philippeau.
|
11 |
-
*
|
12 |
-
*/
|
13 |
-
|
14 |
-
template <typename T>
|
15 |
-
T clamp(T value, T min_value, T max_value) {
|
16 |
-
return std::min(std::max(value, min_value), max_value);
|
17 |
-
}
|
18 |
-
|
19 |
-
void NearestNeighborField::_randomize_field(int max_retry, bool reset) {
|
20 |
-
auto this_size = source_size();
|
21 |
-
for (int i = 0; i < this_size.height; ++i) {
|
22 |
-
for (int j = 0; j < this_size.width; ++j) {
|
23 |
-
if (m_source.is_globally_masked(i, j)) continue;
|
24 |
-
|
25 |
-
auto this_ptr = mutable_ptr(i, j);
|
26 |
-
int distance = reset ? PatchDistanceMetric::kDistanceScale : this_ptr[2];
|
27 |
-
if (distance < PatchDistanceMetric::kDistanceScale) {
|
28 |
-
continue;
|
29 |
-
}
|
30 |
-
|
31 |
-
int i_target = 0, j_target = 0;
|
32 |
-
for (int t = 0; t < max_retry; ++t) {
|
33 |
-
i_target = rand() % this_size.height;
|
34 |
-
j_target = rand() % this_size.width;
|
35 |
-
if (m_target.is_globally_masked(i_target, j_target)) continue;
|
36 |
-
|
37 |
-
distance = _distance(i, j, i_target, j_target);
|
38 |
-
if (distance < PatchDistanceMetric::kDistanceScale)
|
39 |
-
break;
|
40 |
-
}
|
41 |
-
|
42 |
-
this_ptr[0] = i_target, this_ptr[1] = j_target, this_ptr[2] = distance;
|
43 |
-
}
|
44 |
-
}
|
45 |
-
}
|
46 |
-
|
47 |
-
void NearestNeighborField::_initialize_field_from(const NearestNeighborField &other, int max_retry) {
|
48 |
-
const auto &this_size = source_size();
|
49 |
-
const auto &other_size = other.source_size();
|
50 |
-
double fi = static_cast<double>(this_size.height) / other_size.height;
|
51 |
-
double fj = static_cast<double>(this_size.width) / other_size.width;
|
52 |
-
|
53 |
-
for (int i = 0; i < this_size.height; ++i) {
|
54 |
-
for (int j = 0; j < this_size.width; ++j) {
|
55 |
-
if (m_source.is_globally_masked(i, j)) continue;
|
56 |
-
|
57 |
-
int ilow = static_cast<int>(std::min(i / fi, static_cast<double>(other_size.height - 1)));
|
58 |
-
int jlow = static_cast<int>(std::min(j / fj, static_cast<double>(other_size.width - 1)));
|
59 |
-
auto this_value = mutable_ptr(i, j);
|
60 |
-
auto other_value = other.ptr(ilow, jlow);
|
61 |
-
|
62 |
-
this_value[0] = static_cast<int>(other_value[0] * fi);
|
63 |
-
this_value[1] = static_cast<int>(other_value[1] * fj);
|
64 |
-
this_value[2] = _distance(i, j, this_value[0], this_value[1]);
|
65 |
-
}
|
66 |
-
}
|
67 |
-
|
68 |
-
_randomize_field(max_retry, false);
|
69 |
-
}
|
70 |
-
|
71 |
-
void NearestNeighborField::minimize(int nr_pass) {
|
72 |
-
const auto &this_size = source_size();
|
73 |
-
while (nr_pass--) {
|
74 |
-
for (int i = 0; i < this_size.height; ++i)
|
75 |
-
for (int j = 0; j < this_size.width; ++j) {
|
76 |
-
if (m_source.is_globally_masked(i, j)) continue;
|
77 |
-
if (at(i, j, 2) > 0) _minimize_link(i, j, +1);
|
78 |
-
}
|
79 |
-
for (int i = this_size.height - 1; i >= 0; --i)
|
80 |
-
for (int j = this_size.width - 1; j >= 0; --j) {
|
81 |
-
if (m_source.is_globally_masked(i, j)) continue;
|
82 |
-
if (at(i, j, 2) > 0) _minimize_link(i, j, -1);
|
83 |
-
}
|
84 |
-
}
|
85 |
-
}
|
86 |
-
|
87 |
-
void NearestNeighborField::_minimize_link(int y, int x, int direction) {
|
88 |
-
const auto &this_size = source_size();
|
89 |
-
const auto &this_target_size = target_size();
|
90 |
-
auto this_ptr = mutable_ptr(y, x);
|
91 |
-
|
92 |
-
// propagation along the y direction.
|
93 |
-
if (y - direction >= 0 && y - direction < this_size.height && !m_source.is_globally_masked(y - direction, x)) {
|
94 |
-
int yp = at(y - direction, x, 0) + direction;
|
95 |
-
int xp = at(y - direction, x, 1);
|
96 |
-
int dp = _distance(y, x, yp, xp);
|
97 |
-
if (dp < at(y, x, 2)) {
|
98 |
-
this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
|
99 |
-
}
|
100 |
-
}
|
101 |
-
|
102 |
-
// propagation along the x direction.
|
103 |
-
if (x - direction >= 0 && x - direction < this_size.width && !m_source.is_globally_masked(y, x - direction)) {
|
104 |
-
int yp = at(y, x - direction, 0);
|
105 |
-
int xp = at(y, x - direction, 1) + direction;
|
106 |
-
int dp = _distance(y, x, yp, xp);
|
107 |
-
if (dp < at(y, x, 2)) {
|
108 |
-
this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
|
109 |
-
}
|
110 |
-
}
|
111 |
-
|
112 |
-
// random search with a progressive step size.
|
113 |
-
int random_scale = (std::min(this_target_size.height, this_target_size.width) - 1) / 2;
|
114 |
-
while (random_scale > 0) {
|
115 |
-
int yp = this_ptr[0] + (rand() % (2 * random_scale + 1) - random_scale);
|
116 |
-
int xp = this_ptr[1] + (rand() % (2 * random_scale + 1) - random_scale);
|
117 |
-
yp = clamp(yp, 0, target_size().height - 1);
|
118 |
-
xp = clamp(xp, 0, target_size().width - 1);
|
119 |
-
|
120 |
-
if (m_target.is_globally_masked(yp, xp)) {
|
121 |
-
random_scale /= 2;
|
122 |
-
}
|
123 |
-
|
124 |
-
int dp = _distance(y, x, yp, xp);
|
125 |
-
if (dp < at(y, x, 2)) {
|
126 |
-
this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
|
127 |
-
}
|
128 |
-
random_scale /= 2;
|
129 |
-
}
|
130 |
-
}
|
131 |
-
|
132 |
-
const int PatchDistanceMetric::kDistanceScale = 65535;
|
133 |
-
const int PatchSSDDistanceMetric::kSSDScale = 9 * 255 * 255;
|
134 |
-
|
135 |
-
namespace {
|
136 |
-
|
137 |
-
inline int pow2(int i) {
|
138 |
-
return i * i;
|
139 |
-
}
|
140 |
-
|
141 |
-
int distance_masked_images(
|
142 |
-
const MaskedImage &source, int ys, int xs,
|
143 |
-
const MaskedImage &target, int yt, int xt,
|
144 |
-
int patch_size
|
145 |
-
) {
|
146 |
-
long double distance = 0;
|
147 |
-
long double wsum = 0;
|
148 |
-
|
149 |
-
source.compute_image_gradients();
|
150 |
-
target.compute_image_gradients();
|
151 |
-
|
152 |
-
auto source_size = source.size();
|
153 |
-
auto target_size = target.size();
|
154 |
-
|
155 |
-
for (int dy = -patch_size; dy <= patch_size; ++dy) {
|
156 |
-
const int yys = ys + dy, yyt = yt + dy;
|
157 |
-
|
158 |
-
if (yys <= 0 || yys >= source_size.height - 1 || yyt <= 0 || yyt >= target_size.height - 1) {
|
159 |
-
distance += (long double)(PatchSSDDistanceMetric::kSSDScale) * (2 * patch_size + 1);
|
160 |
-
wsum += 2 * patch_size + 1;
|
161 |
-
continue;
|
162 |
-
}
|
163 |
-
|
164 |
-
const auto *p_si = source.image().ptr<unsigned char>(yys, 0);
|
165 |
-
const auto *p_ti = target.image().ptr<unsigned char>(yyt, 0);
|
166 |
-
const auto *p_sm = source.mask().ptr<unsigned char>(yys, 0);
|
167 |
-
const auto *p_tm = target.mask().ptr<unsigned char>(yyt, 0);
|
168 |
-
|
169 |
-
const unsigned char *p_sgm = nullptr;
|
170 |
-
const unsigned char *p_tgm = nullptr;
|
171 |
-
if (!source.global_mask().empty()) {
|
172 |
-
p_sgm = source.global_mask().ptr<unsigned char>(yys, 0);
|
173 |
-
p_tgm = target.global_mask().ptr<unsigned char>(yyt, 0);
|
174 |
-
}
|
175 |
-
|
176 |
-
const auto *p_sgy = source.grady().ptr<unsigned char>(yys, 0);
|
177 |
-
const auto *p_tgy = target.grady().ptr<unsigned char>(yyt, 0);
|
178 |
-
const auto *p_sgx = source.gradx().ptr<unsigned char>(yys, 0);
|
179 |
-
const auto *p_tgx = target.gradx().ptr<unsigned char>(yyt, 0);
|
180 |
-
|
181 |
-
for (int dx = -patch_size; dx <= patch_size; ++dx) {
|
182 |
-
int xxs = xs + dx, xxt = xt + dx;
|
183 |
-
wsum += 1;
|
184 |
-
|
185 |
-
if (xxs <= 0 || xxs >= source_size.width - 1 || xxt <= 0 || xxt >= source_size.width - 1) {
|
186 |
-
distance += PatchSSDDistanceMetric::kSSDScale;
|
187 |
-
continue;
|
188 |
-
}
|
189 |
-
|
190 |
-
if (p_sm[xxs] || p_tm[xxt] || (p_sgm && p_sgm[xxs]) || (p_tgm && p_tgm[xxt]) ) {
|
191 |
-
distance += PatchSSDDistanceMetric::kSSDScale;
|
192 |
-
continue;
|
193 |
-
}
|
194 |
-
|
195 |
-
int ssd = 0;
|
196 |
-
for (int c = 0; c < 3; ++c) {
|
197 |
-
int s_value = p_si[xxs * 3 + c];
|
198 |
-
int t_value = p_ti[xxt * 3 + c];
|
199 |
-
int s_gy = p_sgy[xxs * 3 + c];
|
200 |
-
int t_gy = p_tgy[xxt * 3 + c];
|
201 |
-
int s_gx = p_sgx[xxs * 3 + c];
|
202 |
-
int t_gx = p_tgx[xxt * 3 + c];
|
203 |
-
|
204 |
-
ssd += pow2(static_cast<int>(s_value) - t_value);
|
205 |
-
ssd += pow2(static_cast<int>(s_gx) - t_gx);
|
206 |
-
ssd += pow2(static_cast<int>(s_gy) - t_gy);
|
207 |
-
}
|
208 |
-
distance += ssd;
|
209 |
-
}
|
210 |
-
}
|
211 |
-
|
212 |
-
distance /= (long double)(PatchSSDDistanceMetric::kSSDScale);
|
213 |
-
|
214 |
-
int res = int(PatchDistanceMetric::kDistanceScale * distance / wsum);
|
215 |
-
if (res < 0 || res > PatchDistanceMetric::kDistanceScale) return PatchDistanceMetric::kDistanceScale;
|
216 |
-
return res;
|
217 |
-
}
|
218 |
-
|
219 |
-
}
|
220 |
-
|
221 |
-
int PatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
|
222 |
-
return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
|
223 |
-
}
|
224 |
-
|
225 |
-
int DebugPatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
|
226 |
-
fprintf(stderr, "DebugPatchSSDDistanceMetric: %d %d %d %d\n", source.size().width, source.size().height, m_width, m_height);
|
227 |
-
return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
|
228 |
-
}
|
229 |
-
|
230 |
-
int RegularityGuidedPatchDistanceMetricV1::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
|
231 |
-
double dx = remainder(double(source_x - target_x) / source.size().width, m_dx1);
|
232 |
-
double dy = remainder(double(source_y - target_y) / source.size().height, m_dy2);
|
233 |
-
|
234 |
-
double score1 = sqrt(dx * dx + dy *dy) / m_scale;
|
235 |
-
if (score1 < 0 || score1 > 1) score1 = 1;
|
236 |
-
score1 *= PatchDistanceMetric::kDistanceScale;
|
237 |
-
|
238 |
-
double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
|
239 |
-
double score = score1 * m_weight + score2 / (1 + m_weight);
|
240 |
-
return static_cast<int>(score / (1 + m_weight));
|
241 |
-
}
|
242 |
-
|
243 |
-
int RegularityGuidedPatchDistanceMetricV2::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
|
244 |
-
if (target_y < 0 || target_y >= target.size().height || target_x < 0 || target_x >= target.size().width)
|
245 |
-
return PatchDistanceMetric::kDistanceScale;
|
246 |
-
|
247 |
-
int source_scale = m_ijmap.size().height / source.size().height;
|
248 |
-
int target_scale = m_ijmap.size().height / target.size().height;
|
249 |
-
|
250 |
-
// fprintf(stderr, "RegularityGuidedPatchDistanceMetricV2 %d %d %d %d\n", source_y * source_scale, m_ijmap.size().height, source_x * source_scale, m_ijmap.size().width);
|
251 |
-
|
252 |
-
double score1 = PatchDistanceMetric::kDistanceScale;
|
253 |
-
if (!source.is_globally_masked(source_y, source_x) && !target.is_globally_masked(target_y, target_x)) {
|
254 |
-
auto source_ij = m_ijmap.ptr<float>(source_y * source_scale, source_x * source_scale);
|
255 |
-
auto target_ij = m_ijmap.ptr<float>(target_y * target_scale, target_x * target_scale);
|
256 |
-
|
257 |
-
float di = fabs(source_ij[0] - target_ij[0]); if (di > 0.5) di = 1 - di;
|
258 |
-
float dj = fabs(source_ij[1] - target_ij[1]); if (dj > 0.5) dj = 1 - dj;
|
259 |
-
score1 = sqrt(di * di + dj *dj) / 0.707;
|
260 |
-
if (score1 < 0 || score1 > 1) score1 = 1;
|
261 |
-
score1 *= PatchDistanceMetric::kDistanceScale;
|
262 |
-
}
|
263 |
-
|
264 |
-
double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
|
265 |
-
double score = score1 * m_weight + score2;
|
266 |
-
return int(score / (1 + m_weight));
|
267 |
-
}
|
268 |
-
|
|
|
1 |
+
#include <algorithm>
|
2 |
+
#include <iostream>
|
3 |
+
#include <cmath>
|
4 |
+
|
5 |
+
#include "masked_image.h"
|
6 |
+
#include "nnf.h"
|
7 |
+
|
8 |
+
/**
|
9 |
+
* Nearest-Neighbor Field (see PatchMatch algorithm).
|
10 |
+
* This algorithme uses a version proposed by Xavier Philippeau.
|
11 |
+
*
|
12 |
+
*/
|
13 |
+
|
14 |
+
template <typename T>
|
15 |
+
T clamp(T value, T min_value, T max_value) {
|
16 |
+
return std::min(std::max(value, min_value), max_value);
|
17 |
+
}
|
18 |
+
|
19 |
+
void NearestNeighborField::_randomize_field(int max_retry, bool reset) {
|
20 |
+
auto this_size = source_size();
|
21 |
+
for (int i = 0; i < this_size.height; ++i) {
|
22 |
+
for (int j = 0; j < this_size.width; ++j) {
|
23 |
+
if (m_source.is_globally_masked(i, j)) continue;
|
24 |
+
|
25 |
+
auto this_ptr = mutable_ptr(i, j);
|
26 |
+
int distance = reset ? PatchDistanceMetric::kDistanceScale : this_ptr[2];
|
27 |
+
if (distance < PatchDistanceMetric::kDistanceScale) {
|
28 |
+
continue;
|
29 |
+
}
|
30 |
+
|
31 |
+
int i_target = 0, j_target = 0;
|
32 |
+
for (int t = 0; t < max_retry; ++t) {
|
33 |
+
i_target = rand() % this_size.height;
|
34 |
+
j_target = rand() % this_size.width;
|
35 |
+
if (m_target.is_globally_masked(i_target, j_target)) continue;
|
36 |
+
|
37 |
+
distance = _distance(i, j, i_target, j_target);
|
38 |
+
if (distance < PatchDistanceMetric::kDistanceScale)
|
39 |
+
break;
|
40 |
+
}
|
41 |
+
|
42 |
+
this_ptr[0] = i_target, this_ptr[1] = j_target, this_ptr[2] = distance;
|
43 |
+
}
|
44 |
+
}
|
45 |
+
}
|
46 |
+
|
47 |
+
void NearestNeighborField::_initialize_field_from(const NearestNeighborField &other, int max_retry) {
|
48 |
+
const auto &this_size = source_size();
|
49 |
+
const auto &other_size = other.source_size();
|
50 |
+
double fi = static_cast<double>(this_size.height) / other_size.height;
|
51 |
+
double fj = static_cast<double>(this_size.width) / other_size.width;
|
52 |
+
|
53 |
+
for (int i = 0; i < this_size.height; ++i) {
|
54 |
+
for (int j = 0; j < this_size.width; ++j) {
|
55 |
+
if (m_source.is_globally_masked(i, j)) continue;
|
56 |
+
|
57 |
+
int ilow = static_cast<int>(std::min(i / fi, static_cast<double>(other_size.height - 1)));
|
58 |
+
int jlow = static_cast<int>(std::min(j / fj, static_cast<double>(other_size.width - 1)));
|
59 |
+
auto this_value = mutable_ptr(i, j);
|
60 |
+
auto other_value = other.ptr(ilow, jlow);
|
61 |
+
|
62 |
+
this_value[0] = static_cast<int>(other_value[0] * fi);
|
63 |
+
this_value[1] = static_cast<int>(other_value[1] * fj);
|
64 |
+
this_value[2] = _distance(i, j, this_value[0], this_value[1]);
|
65 |
+
}
|
66 |
+
}
|
67 |
+
|
68 |
+
_randomize_field(max_retry, false);
|
69 |
+
}
|
70 |
+
|
71 |
+
void NearestNeighborField::minimize(int nr_pass) {
|
72 |
+
const auto &this_size = source_size();
|
73 |
+
while (nr_pass--) {
|
74 |
+
for (int i = 0; i < this_size.height; ++i)
|
75 |
+
for (int j = 0; j < this_size.width; ++j) {
|
76 |
+
if (m_source.is_globally_masked(i, j)) continue;
|
77 |
+
if (at(i, j, 2) > 0) _minimize_link(i, j, +1);
|
78 |
+
}
|
79 |
+
for (int i = this_size.height - 1; i >= 0; --i)
|
80 |
+
for (int j = this_size.width - 1; j >= 0; --j) {
|
81 |
+
if (m_source.is_globally_masked(i, j)) continue;
|
82 |
+
if (at(i, j, 2) > 0) _minimize_link(i, j, -1);
|
83 |
+
}
|
84 |
+
}
|
85 |
+
}
|
86 |
+
|
87 |
+
void NearestNeighborField::_minimize_link(int y, int x, int direction) {
|
88 |
+
const auto &this_size = source_size();
|
89 |
+
const auto &this_target_size = target_size();
|
90 |
+
auto this_ptr = mutable_ptr(y, x);
|
91 |
+
|
92 |
+
// propagation along the y direction.
|
93 |
+
if (y - direction >= 0 && y - direction < this_size.height && !m_source.is_globally_masked(y - direction, x)) {
|
94 |
+
int yp = at(y - direction, x, 0) + direction;
|
95 |
+
int xp = at(y - direction, x, 1);
|
96 |
+
int dp = _distance(y, x, yp, xp);
|
97 |
+
if (dp < at(y, x, 2)) {
|
98 |
+
this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
|
99 |
+
}
|
100 |
+
}
|
101 |
+
|
102 |
+
// propagation along the x direction.
|
103 |
+
if (x - direction >= 0 && x - direction < this_size.width && !m_source.is_globally_masked(y, x - direction)) {
|
104 |
+
int yp = at(y, x - direction, 0);
|
105 |
+
int xp = at(y, x - direction, 1) + direction;
|
106 |
+
int dp = _distance(y, x, yp, xp);
|
107 |
+
if (dp < at(y, x, 2)) {
|
108 |
+
this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
|
109 |
+
}
|
110 |
+
}
|
111 |
+
|
112 |
+
// random search with a progressive step size.
|
113 |
+
int random_scale = (std::min(this_target_size.height, this_target_size.width) - 1) / 2;
|
114 |
+
while (random_scale > 0) {
|
115 |
+
int yp = this_ptr[0] + (rand() % (2 * random_scale + 1) - random_scale);
|
116 |
+
int xp = this_ptr[1] + (rand() % (2 * random_scale + 1) - random_scale);
|
117 |
+
yp = clamp(yp, 0, target_size().height - 1);
|
118 |
+
xp = clamp(xp, 0, target_size().width - 1);
|
119 |
+
|
120 |
+
if (m_target.is_globally_masked(yp, xp)) {
|
121 |
+
random_scale /= 2;
|
122 |
+
}
|
123 |
+
|
124 |
+
int dp = _distance(y, x, yp, xp);
|
125 |
+
if (dp < at(y, x, 2)) {
|
126 |
+
this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
|
127 |
+
}
|
128 |
+
random_scale /= 2;
|
129 |
+
}
|
130 |
+
}
|
131 |
+
|
132 |
+
const int PatchDistanceMetric::kDistanceScale = 65535;
|
133 |
+
const int PatchSSDDistanceMetric::kSSDScale = 9 * 255 * 255;
|
134 |
+
|
135 |
+
namespace {
|
136 |
+
|
137 |
+
inline int pow2(int i) {
|
138 |
+
return i * i;
|
139 |
+
}
|
140 |
+
|
141 |
+
int distance_masked_images(
|
142 |
+
const MaskedImage &source, int ys, int xs,
|
143 |
+
const MaskedImage &target, int yt, int xt,
|
144 |
+
int patch_size
|
145 |
+
) {
|
146 |
+
long double distance = 0;
|
147 |
+
long double wsum = 0;
|
148 |
+
|
149 |
+
source.compute_image_gradients();
|
150 |
+
target.compute_image_gradients();
|
151 |
+
|
152 |
+
auto source_size = source.size();
|
153 |
+
auto target_size = target.size();
|
154 |
+
|
155 |
+
for (int dy = -patch_size; dy <= patch_size; ++dy) {
|
156 |
+
const int yys = ys + dy, yyt = yt + dy;
|
157 |
+
|
158 |
+
if (yys <= 0 || yys >= source_size.height - 1 || yyt <= 0 || yyt >= target_size.height - 1) {
|
159 |
+
distance += (long double)(PatchSSDDistanceMetric::kSSDScale) * (2 * patch_size + 1);
|
160 |
+
wsum += 2 * patch_size + 1;
|
161 |
+
continue;
|
162 |
+
}
|
163 |
+
|
164 |
+
const auto *p_si = source.image().ptr<unsigned char>(yys, 0);
|
165 |
+
const auto *p_ti = target.image().ptr<unsigned char>(yyt, 0);
|
166 |
+
const auto *p_sm = source.mask().ptr<unsigned char>(yys, 0);
|
167 |
+
const auto *p_tm = target.mask().ptr<unsigned char>(yyt, 0);
|
168 |
+
|
169 |
+
const unsigned char *p_sgm = nullptr;
|
170 |
+
const unsigned char *p_tgm = nullptr;
|
171 |
+
if (!source.global_mask().empty()) {
|
172 |
+
p_sgm = source.global_mask().ptr<unsigned char>(yys, 0);
|
173 |
+
p_tgm = target.global_mask().ptr<unsigned char>(yyt, 0);
|
174 |
+
}
|
175 |
+
|
176 |
+
const auto *p_sgy = source.grady().ptr<unsigned char>(yys, 0);
|
177 |
+
const auto *p_tgy = target.grady().ptr<unsigned char>(yyt, 0);
|
178 |
+
const auto *p_sgx = source.gradx().ptr<unsigned char>(yys, 0);
|
179 |
+
const auto *p_tgx = target.gradx().ptr<unsigned char>(yyt, 0);
|
180 |
+
|
181 |
+
for (int dx = -patch_size; dx <= patch_size; ++dx) {
|
182 |
+
int xxs = xs + dx, xxt = xt + dx;
|
183 |
+
wsum += 1;
|
184 |
+
|
185 |
+
if (xxs <= 0 || xxs >= source_size.width - 1 || xxt <= 0 || xxt >= source_size.width - 1) {
|
186 |
+
distance += PatchSSDDistanceMetric::kSSDScale;
|
187 |
+
continue;
|
188 |
+
}
|
189 |
+
|
190 |
+
if (p_sm[xxs] || p_tm[xxt] || (p_sgm && p_sgm[xxs]) || (p_tgm && p_tgm[xxt]) ) {
|
191 |
+
distance += PatchSSDDistanceMetric::kSSDScale;
|
192 |
+
continue;
|
193 |
+
}
|
194 |
+
|
195 |
+
int ssd = 0;
|
196 |
+
for (int c = 0; c < 3; ++c) {
|
197 |
+
int s_value = p_si[xxs * 3 + c];
|
198 |
+
int t_value = p_ti[xxt * 3 + c];
|
199 |
+
int s_gy = p_sgy[xxs * 3 + c];
|
200 |
+
int t_gy = p_tgy[xxt * 3 + c];
|
201 |
+
int s_gx = p_sgx[xxs * 3 + c];
|
202 |
+
int t_gx = p_tgx[xxt * 3 + c];
|
203 |
+
|
204 |
+
ssd += pow2(static_cast<int>(s_value) - t_value);
|
205 |
+
ssd += pow2(static_cast<int>(s_gx) - t_gx);
|
206 |
+
ssd += pow2(static_cast<int>(s_gy) - t_gy);
|
207 |
+
}
|
208 |
+
distance += ssd;
|
209 |
+
}
|
210 |
+
}
|
211 |
+
|
212 |
+
distance /= (long double)(PatchSSDDistanceMetric::kSSDScale);
|
213 |
+
|
214 |
+
int res = int(PatchDistanceMetric::kDistanceScale * distance / wsum);
|
215 |
+
if (res < 0 || res > PatchDistanceMetric::kDistanceScale) return PatchDistanceMetric::kDistanceScale;
|
216 |
+
return res;
|
217 |
+
}
|
218 |
+
|
219 |
+
}
|
220 |
+
|
221 |
+
int PatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
|
222 |
+
return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
|
223 |
+
}
|
224 |
+
|
225 |
+
int DebugPatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
|
226 |
+
fprintf(stderr, "DebugPatchSSDDistanceMetric: %d %d %d %d\n", source.size().width, source.size().height, m_width, m_height);
|
227 |
+
return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
|
228 |
+
}
|
229 |
+
|
230 |
+
int RegularityGuidedPatchDistanceMetricV1::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
|
231 |
+
double dx = remainder(double(source_x - target_x) / source.size().width, m_dx1);
|
232 |
+
double dy = remainder(double(source_y - target_y) / source.size().height, m_dy2);
|
233 |
+
|
234 |
+
double score1 = sqrt(dx * dx + dy *dy) / m_scale;
|
235 |
+
if (score1 < 0 || score1 > 1) score1 = 1;
|
236 |
+
score1 *= PatchDistanceMetric::kDistanceScale;
|
237 |
+
|
238 |
+
double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
|
239 |
+
double score = score1 * m_weight + score2 / (1 + m_weight);
|
240 |
+
return static_cast<int>(score / (1 + m_weight));
|
241 |
+
}
|
242 |
+
|
243 |
+
int RegularityGuidedPatchDistanceMetricV2::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
|
244 |
+
if (target_y < 0 || target_y >= target.size().height || target_x < 0 || target_x >= target.size().width)
|
245 |
+
return PatchDistanceMetric::kDistanceScale;
|
246 |
+
|
247 |
+
int source_scale = m_ijmap.size().height / source.size().height;
|
248 |
+
int target_scale = m_ijmap.size().height / target.size().height;
|
249 |
+
|
250 |
+
// fprintf(stderr, "RegularityGuidedPatchDistanceMetricV2 %d %d %d %d\n", source_y * source_scale, m_ijmap.size().height, source_x * source_scale, m_ijmap.size().width);
|
251 |
+
|
252 |
+
double score1 = PatchDistanceMetric::kDistanceScale;
|
253 |
+
if (!source.is_globally_masked(source_y, source_x) && !target.is_globally_masked(target_y, target_x)) {
|
254 |
+
auto source_ij = m_ijmap.ptr<float>(source_y * source_scale, source_x * source_scale);
|
255 |
+
auto target_ij = m_ijmap.ptr<float>(target_y * target_scale, target_x * target_scale);
|
256 |
+
|
257 |
+
float di = fabs(source_ij[0] - target_ij[0]); if (di > 0.5) di = 1 - di;
|
258 |
+
float dj = fabs(source_ij[1] - target_ij[1]); if (dj > 0.5) dj = 1 - dj;
|
259 |
+
score1 = sqrt(di * di + dj *dj) / 0.707;
|
260 |
+
if (score1 < 0 || score1 > 1) score1 = 1;
|
261 |
+
score1 *= PatchDistanceMetric::kDistanceScale;
|
262 |
+
}
|
263 |
+
|
264 |
+
double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
|
265 |
+
double score = score1 * m_weight + score2;
|
266 |
+
return int(score / (1 + m_weight));
|
267 |
+
}
|
268 |
+
|
PyPatchMatch/csrc/nnf.h
CHANGED
@@ -1,133 +1,133 @@
|
|
1 |
-
#pragma once
|
2 |
-
|
3 |
-
#include <opencv2/core.hpp>
|
4 |
-
#include "masked_image.h"
|
5 |
-
|
6 |
-
class PatchDistanceMetric {
|
7 |
-
public:
|
8 |
-
PatchDistanceMetric(int patch_size) : m_patch_size(patch_size) {}
|
9 |
-
virtual ~PatchDistanceMetric() = default;
|
10 |
-
|
11 |
-
inline int patch_size() const { return m_patch_size; }
|
12 |
-
virtual int operator()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const = 0;
|
13 |
-
static const int kDistanceScale;
|
14 |
-
|
15 |
-
protected:
|
16 |
-
int m_patch_size;
|
17 |
-
};
|
18 |
-
|
19 |
-
class NearestNeighborField {
|
20 |
-
public:
|
21 |
-
NearestNeighborField() : m_source(), m_target(), m_field(), m_distance_metric(nullptr) {
|
22 |
-
// pass
|
23 |
-
}
|
24 |
-
NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, int max_retry = 20)
|
25 |
-
: m_source(source), m_target(target), m_distance_metric(metric) {
|
26 |
-
m_field = cv::Mat(m_source.size(), CV_32SC3);
|
27 |
-
_randomize_field(max_retry);
|
28 |
-
}
|
29 |
-
NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, const NearestNeighborField &other, int max_retry = 20)
|
30 |
-
: m_source(source), m_target(target), m_distance_metric(metric) {
|
31 |
-
m_field = cv::Mat(m_source.size(), CV_32SC3);
|
32 |
-
_initialize_field_from(other, max_retry);
|
33 |
-
}
|
34 |
-
|
35 |
-
const MaskedImage &source() const {
|
36 |
-
return m_source;
|
37 |
-
}
|
38 |
-
const MaskedImage &target() const {
|
39 |
-
return m_target;
|
40 |
-
}
|
41 |
-
inline cv::Size source_size() const {
|
42 |
-
return m_source.size();
|
43 |
-
}
|
44 |
-
inline cv::Size target_size() const {
|
45 |
-
return m_target.size();
|
46 |
-
}
|
47 |
-
inline void set_source(const MaskedImage &source) {
|
48 |
-
m_source = source;
|
49 |
-
}
|
50 |
-
inline void set_target(const MaskedImage &target) {
|
51 |
-
m_target = target;
|
52 |
-
}
|
53 |
-
|
54 |
-
inline int *mutable_ptr(int y, int x) {
|
55 |
-
return m_field.ptr<int>(y, x);
|
56 |
-
}
|
57 |
-
inline const int *ptr(int y, int x) const {
|
58 |
-
return m_field.ptr<int>(y, x);
|
59 |
-
}
|
60 |
-
|
61 |
-
inline int at(int y, int x, int c) const {
|
62 |
-
return m_field.ptr<int>(y, x)[c];
|
63 |
-
}
|
64 |
-
inline int &at(int y, int x, int c) {
|
65 |
-
return m_field.ptr<int>(y, x)[c];
|
66 |
-
}
|
67 |
-
inline void set_identity(int y, int x) {
|
68 |
-
auto ptr = mutable_ptr(y, x);
|
69 |
-
ptr[0] = y, ptr[1] = x, ptr[2] = 0;
|
70 |
-
}
|
71 |
-
|
72 |
-
void minimize(int nr_pass);
|
73 |
-
|
74 |
-
private:
|
75 |
-
inline int _distance(int source_y, int source_x, int target_y, int target_x) {
|
76 |
-
return (*m_distance_metric)(m_source, source_y, source_x, m_target, target_y, target_x);
|
77 |
-
}
|
78 |
-
|
79 |
-
void _randomize_field(int max_retry = 20, bool reset = true);
|
80 |
-
void _initialize_field_from(const NearestNeighborField &other, int max_retry);
|
81 |
-
void _minimize_link(int y, int x, int direction);
|
82 |
-
|
83 |
-
MaskedImage m_source;
|
84 |
-
MaskedImage m_target;
|
85 |
-
cv::Mat m_field; // { y_target, x_target, distance_scaled }
|
86 |
-
const PatchDistanceMetric *m_distance_metric;
|
87 |
-
};
|
88 |
-
|
89 |
-
|
90 |
-
class PatchSSDDistanceMetric : public PatchDistanceMetric {
|
91 |
-
public:
|
92 |
-
using PatchDistanceMetric::PatchDistanceMetric;
|
93 |
-
virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
|
94 |
-
static const int kSSDScale;
|
95 |
-
};
|
96 |
-
|
97 |
-
class DebugPatchSSDDistanceMetric : public PatchDistanceMetric {
|
98 |
-
public:
|
99 |
-
DebugPatchSSDDistanceMetric(int patch_size, int width, int height) : PatchDistanceMetric(patch_size), m_width(width), m_height(height) {}
|
100 |
-
virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
|
101 |
-
protected:
|
102 |
-
int m_width, m_height;
|
103 |
-
};
|
104 |
-
|
105 |
-
class RegularityGuidedPatchDistanceMetricV1 : public PatchDistanceMetric {
|
106 |
-
public:
|
107 |
-
RegularityGuidedPatchDistanceMetricV1(int patch_size, double dx1, double dy1, double dx2, double dy2, double weight)
|
108 |
-
: PatchDistanceMetric(patch_size), m_dx1(dx1), m_dy1(dy1), m_dx2(dx2), m_dy2(dy2), m_weight(weight) {
|
109 |
-
|
110 |
-
assert(m_dy1 == 0);
|
111 |
-
assert(m_dx2 == 0);
|
112 |
-
m_scale = sqrt(m_dx1 * m_dx1 + m_dy2 * m_dy2) / 4;
|
113 |
-
}
|
114 |
-
virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
|
115 |
-
|
116 |
-
protected:
|
117 |
-
double m_dx1, m_dy1, m_dx2, m_dy2;
|
118 |
-
double m_scale, m_weight;
|
119 |
-
};
|
120 |
-
|
121 |
-
class RegularityGuidedPatchDistanceMetricV2 : public PatchDistanceMetric {
|
122 |
-
public:
|
123 |
-
RegularityGuidedPatchDistanceMetricV2(int patch_size, cv::Mat ijmap, double weight)
|
124 |
-
: PatchDistanceMetric(patch_size), m_ijmap(ijmap), m_weight(weight) {
|
125 |
-
|
126 |
-
}
|
127 |
-
virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
|
128 |
-
|
129 |
-
protected:
|
130 |
-
cv::Mat m_ijmap;
|
131 |
-
double m_width, m_height, m_weight;
|
132 |
-
};
|
133 |
-
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <opencv2/core.hpp>
|
4 |
+
#include "masked_image.h"
|
5 |
+
|
6 |
+
class PatchDistanceMetric {
|
7 |
+
public:
|
8 |
+
PatchDistanceMetric(int patch_size) : m_patch_size(patch_size) {}
|
9 |
+
virtual ~PatchDistanceMetric() = default;
|
10 |
+
|
11 |
+
inline int patch_size() const { return m_patch_size; }
|
12 |
+
virtual int operator()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const = 0;
|
13 |
+
static const int kDistanceScale;
|
14 |
+
|
15 |
+
protected:
|
16 |
+
int m_patch_size;
|
17 |
+
};
|
18 |
+
|
19 |
+
class NearestNeighborField {
|
20 |
+
public:
|
21 |
+
NearestNeighborField() : m_source(), m_target(), m_field(), m_distance_metric(nullptr) {
|
22 |
+
// pass
|
23 |
+
}
|
24 |
+
NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, int max_retry = 20)
|
25 |
+
: m_source(source), m_target(target), m_distance_metric(metric) {
|
26 |
+
m_field = cv::Mat(m_source.size(), CV_32SC3);
|
27 |
+
_randomize_field(max_retry);
|
28 |
+
}
|
29 |
+
NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, const NearestNeighborField &other, int max_retry = 20)
|
30 |
+
: m_source(source), m_target(target), m_distance_metric(metric) {
|
31 |
+
m_field = cv::Mat(m_source.size(), CV_32SC3);
|
32 |
+
_initialize_field_from(other, max_retry);
|
33 |
+
}
|
34 |
+
|
35 |
+
const MaskedImage &source() const {
|
36 |
+
return m_source;
|
37 |
+
}
|
38 |
+
const MaskedImage &target() const {
|
39 |
+
return m_target;
|
40 |
+
}
|
41 |
+
inline cv::Size source_size() const {
|
42 |
+
return m_source.size();
|
43 |
+
}
|
44 |
+
inline cv::Size target_size() const {
|
45 |
+
return m_target.size();
|
46 |
+
}
|
47 |
+
inline void set_source(const MaskedImage &source) {
|
48 |
+
m_source = source;
|
49 |
+
}
|
50 |
+
inline void set_target(const MaskedImage &target) {
|
51 |
+
m_target = target;
|
52 |
+
}
|
53 |
+
|
54 |
+
inline int *mutable_ptr(int y, int x) {
|
55 |
+
return m_field.ptr<int>(y, x);
|
56 |
+
}
|
57 |
+
inline const int *ptr(int y, int x) const {
|
58 |
+
return m_field.ptr<int>(y, x);
|
59 |
+
}
|
60 |
+
|
61 |
+
inline int at(int y, int x, int c) const {
|
62 |
+
return m_field.ptr<int>(y, x)[c];
|
63 |
+
}
|
64 |
+
inline int &at(int y, int x, int c) {
|
65 |
+
return m_field.ptr<int>(y, x)[c];
|
66 |
+
}
|
67 |
+
inline void set_identity(int y, int x) {
|
68 |
+
auto ptr = mutable_ptr(y, x);
|
69 |
+
ptr[0] = y, ptr[1] = x, ptr[2] = 0;
|
70 |
+
}
|
71 |
+
|
72 |
+
void minimize(int nr_pass);
|
73 |
+
|
74 |
+
private:
|
75 |
+
inline int _distance(int source_y, int source_x, int target_y, int target_x) {
|
76 |
+
return (*m_distance_metric)(m_source, source_y, source_x, m_target, target_y, target_x);
|
77 |
+
}
|
78 |
+
|
79 |
+
void _randomize_field(int max_retry = 20, bool reset = true);
|
80 |
+
void _initialize_field_from(const NearestNeighborField &other, int max_retry);
|
81 |
+
void _minimize_link(int y, int x, int direction);
|
82 |
+
|
83 |
+
MaskedImage m_source;
|
84 |
+
MaskedImage m_target;
|
85 |
+
cv::Mat m_field; // { y_target, x_target, distance_scaled }
|
86 |
+
const PatchDistanceMetric *m_distance_metric;
|
87 |
+
};
|
88 |
+
|
89 |
+
|
90 |
+
class PatchSSDDistanceMetric : public PatchDistanceMetric {
|
91 |
+
public:
|
92 |
+
using PatchDistanceMetric::PatchDistanceMetric;
|
93 |
+
virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
|
94 |
+
static const int kSSDScale;
|
95 |
+
};
|
96 |
+
|
97 |
+
class DebugPatchSSDDistanceMetric : public PatchDistanceMetric {
|
98 |
+
public:
|
99 |
+
DebugPatchSSDDistanceMetric(int patch_size, int width, int height) : PatchDistanceMetric(patch_size), m_width(width), m_height(height) {}
|
100 |
+
virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
|
101 |
+
protected:
|
102 |
+
int m_width, m_height;
|
103 |
+
};
|
104 |
+
|
105 |
+
class RegularityGuidedPatchDistanceMetricV1 : public PatchDistanceMetric {
|
106 |
+
public:
|
107 |
+
RegularityGuidedPatchDistanceMetricV1(int patch_size, double dx1, double dy1, double dx2, double dy2, double weight)
|
108 |
+
: PatchDistanceMetric(patch_size), m_dx1(dx1), m_dy1(dy1), m_dx2(dx2), m_dy2(dy2), m_weight(weight) {
|
109 |
+
|
110 |
+
assert(m_dy1 == 0);
|
111 |
+
assert(m_dx2 == 0);
|
112 |
+
m_scale = sqrt(m_dx1 * m_dx1 + m_dy2 * m_dy2) / 4;
|
113 |
+
}
|
114 |
+
virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
|
115 |
+
|
116 |
+
protected:
|
117 |
+
double m_dx1, m_dy1, m_dx2, m_dy2;
|
118 |
+
double m_scale, m_weight;
|
119 |
+
};
|
120 |
+
|
121 |
+
class RegularityGuidedPatchDistanceMetricV2 : public PatchDistanceMetric {
|
122 |
+
public:
|
123 |
+
RegularityGuidedPatchDistanceMetricV2(int patch_size, cv::Mat ijmap, double weight)
|
124 |
+
: PatchDistanceMetric(patch_size), m_ijmap(ijmap), m_weight(weight) {
|
125 |
+
|
126 |
+
}
|
127 |
+
virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
|
128 |
+
|
129 |
+
protected:
|
130 |
+
cv::Mat m_ijmap;
|
131 |
+
double m_width, m_height, m_weight;
|
132 |
+
};
|
133 |
+
|
PyPatchMatch/csrc/pyinterface.cpp
CHANGED
@@ -1,107 +1,107 @@
|
|
1 |
-
#include "pyinterface.h"
|
2 |
-
#include "inpaint.h"
|
3 |
-
|
4 |
-
static unsigned int PM_seed = 1212;
|
5 |
-
static bool PM_verbose = false;
|
6 |
-
|
7 |
-
int _dtype_py_to_cv(int dtype_py);
|
8 |
-
int _dtype_cv_to_py(int dtype_cv);
|
9 |
-
cv::Mat _py_to_cv2(PM_mat_t pymat);
|
10 |
-
PM_mat_t _cv2_to_py(cv::Mat cvmat);
|
11 |
-
|
12 |
-
void PM_set_random_seed(unsigned int seed) {
|
13 |
-
PM_seed = seed;
|
14 |
-
}
|
15 |
-
|
16 |
-
void PM_set_verbose(int value) {
|
17 |
-
PM_verbose = static_cast<bool>(value);
|
18 |
-
}
|
19 |
-
|
20 |
-
void PM_free_pymat(PM_mat_t pymat) {
|
21 |
-
free(pymat.data_ptr);
|
22 |
-
}
|
23 |
-
|
24 |
-
PM_mat_t PM_inpaint(PM_mat_t source_py, PM_mat_t mask_py, int patch_size) {
|
25 |
-
cv::Mat source = _py_to_cv2(source_py);
|
26 |
-
cv::Mat mask = _py_to_cv2(mask_py);
|
27 |
-
auto metric = PatchSSDDistanceMetric(patch_size);
|
28 |
-
cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
|
29 |
-
return _cv2_to_py(result);
|
30 |
-
}
|
31 |
-
|
32 |
-
PM_mat_t PM_inpaint_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
|
33 |
-
cv::Mat source = _py_to_cv2(source_py);
|
34 |
-
cv::Mat mask = _py_to_cv2(mask_py);
|
35 |
-
cv::Mat ijmap = _py_to_cv2(ijmap_py);
|
36 |
-
|
37 |
-
auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
|
38 |
-
cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
|
39 |
-
return _cv2_to_py(result);
|
40 |
-
}
|
41 |
-
|
42 |
-
PM_mat_t PM_inpaint2(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, int patch_size) {
|
43 |
-
cv::Mat source = _py_to_cv2(source_py);
|
44 |
-
cv::Mat mask = _py_to_cv2(mask_py);
|
45 |
-
cv::Mat global_mask = _py_to_cv2(global_mask_py);
|
46 |
-
|
47 |
-
auto metric = PatchSSDDistanceMetric(patch_size);
|
48 |
-
cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
|
49 |
-
return _cv2_to_py(result);
|
50 |
-
}
|
51 |
-
|
52 |
-
PM_mat_t PM_inpaint2_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
|
53 |
-
cv::Mat source = _py_to_cv2(source_py);
|
54 |
-
cv::Mat mask = _py_to_cv2(mask_py);
|
55 |
-
cv::Mat global_mask = _py_to_cv2(global_mask_py);
|
56 |
-
cv::Mat ijmap = _py_to_cv2(ijmap_py);
|
57 |
-
|
58 |
-
auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
|
59 |
-
cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
|
60 |
-
return _cv2_to_py(result);
|
61 |
-
}
|
62 |
-
|
63 |
-
int _dtype_py_to_cv(int dtype_py) {
|
64 |
-
switch (dtype_py) {
|
65 |
-
case PM_UINT8: return CV_8U;
|
66 |
-
case PM_INT8: return CV_8S;
|
67 |
-
case PM_UINT16: return CV_16U;
|
68 |
-
case PM_INT16: return CV_16S;
|
69 |
-
case PM_INT32: return CV_32S;
|
70 |
-
case PM_FLOAT32: return CV_32F;
|
71 |
-
case PM_FLOAT64: return CV_64F;
|
72 |
-
}
|
73 |
-
|
74 |
-
return CV_8U;
|
75 |
-
}
|
76 |
-
|
77 |
-
int _dtype_cv_to_py(int dtype_cv) {
|
78 |
-
switch (dtype_cv) {
|
79 |
-
case CV_8U: return PM_UINT8;
|
80 |
-
case CV_8S: return PM_INT8;
|
81 |
-
case CV_16U: return PM_UINT16;
|
82 |
-
case CV_16S: return PM_INT16;
|
83 |
-
case CV_32S: return PM_INT32;
|
84 |
-
case CV_32F: return PM_FLOAT32;
|
85 |
-
case CV_64F: return PM_FLOAT64;
|
86 |
-
}
|
87 |
-
|
88 |
-
return PM_UINT8;
|
89 |
-
}
|
90 |
-
|
91 |
-
cv::Mat _py_to_cv2(PM_mat_t pymat) {
|
92 |
-
int dtype = _dtype_py_to_cv(pymat.dtype);
|
93 |
-
dtype = CV_MAKETYPE(pymat.dtype, pymat.shape.channels);
|
94 |
-
return cv::Mat(cv::Size(pymat.shape.width, pymat.shape.height), dtype, pymat.data_ptr).clone();
|
95 |
-
}
|
96 |
-
|
97 |
-
PM_mat_t _cv2_to_py(cv::Mat cvmat) {
|
98 |
-
PM_shape_t shape = {cvmat.size().width, cvmat.size().height, cvmat.channels()};
|
99 |
-
int dtype = _dtype_cv_to_py(cvmat.depth());
|
100 |
-
size_t dsize = cvmat.total() * cvmat.elemSize();
|
101 |
-
|
102 |
-
void *data_ptr = reinterpret_cast<void *>(malloc(dsize));
|
103 |
-
memcpy(data_ptr, reinterpret_cast<void *>(cvmat.data), dsize);
|
104 |
-
|
105 |
-
return PM_mat_t {data_ptr, shape, dtype};
|
106 |
-
}
|
107 |
-
|
|
|
1 |
+
#include "pyinterface.h"
|
2 |
+
#include "inpaint.h"
|
3 |
+
|
4 |
+
static unsigned int PM_seed = 1212;
|
5 |
+
static bool PM_verbose = false;
|
6 |
+
|
7 |
+
int _dtype_py_to_cv(int dtype_py);
|
8 |
+
int _dtype_cv_to_py(int dtype_cv);
|
9 |
+
cv::Mat _py_to_cv2(PM_mat_t pymat);
|
10 |
+
PM_mat_t _cv2_to_py(cv::Mat cvmat);
|
11 |
+
|
12 |
+
void PM_set_random_seed(unsigned int seed) {
|
13 |
+
PM_seed = seed;
|
14 |
+
}
|
15 |
+
|
16 |
+
void PM_set_verbose(int value) {
|
17 |
+
PM_verbose = static_cast<bool>(value);
|
18 |
+
}
|
19 |
+
|
20 |
+
void PM_free_pymat(PM_mat_t pymat) {
|
21 |
+
free(pymat.data_ptr);
|
22 |
+
}
|
23 |
+
|
24 |
+
PM_mat_t PM_inpaint(PM_mat_t source_py, PM_mat_t mask_py, int patch_size) {
|
25 |
+
cv::Mat source = _py_to_cv2(source_py);
|
26 |
+
cv::Mat mask = _py_to_cv2(mask_py);
|
27 |
+
auto metric = PatchSSDDistanceMetric(patch_size);
|
28 |
+
cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
|
29 |
+
return _cv2_to_py(result);
|
30 |
+
}
|
31 |
+
|
32 |
+
PM_mat_t PM_inpaint_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
|
33 |
+
cv::Mat source = _py_to_cv2(source_py);
|
34 |
+
cv::Mat mask = _py_to_cv2(mask_py);
|
35 |
+
cv::Mat ijmap = _py_to_cv2(ijmap_py);
|
36 |
+
|
37 |
+
auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
|
38 |
+
cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
|
39 |
+
return _cv2_to_py(result);
|
40 |
+
}
|
41 |
+
|
42 |
+
PM_mat_t PM_inpaint2(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, int patch_size) {
|
43 |
+
cv::Mat source = _py_to_cv2(source_py);
|
44 |
+
cv::Mat mask = _py_to_cv2(mask_py);
|
45 |
+
cv::Mat global_mask = _py_to_cv2(global_mask_py);
|
46 |
+
|
47 |
+
auto metric = PatchSSDDistanceMetric(patch_size);
|
48 |
+
cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
|
49 |
+
return _cv2_to_py(result);
|
50 |
+
}
|
51 |
+
|
52 |
+
PM_mat_t PM_inpaint2_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
|
53 |
+
cv::Mat source = _py_to_cv2(source_py);
|
54 |
+
cv::Mat mask = _py_to_cv2(mask_py);
|
55 |
+
cv::Mat global_mask = _py_to_cv2(global_mask_py);
|
56 |
+
cv::Mat ijmap = _py_to_cv2(ijmap_py);
|
57 |
+
|
58 |
+
auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
|
59 |
+
cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
|
60 |
+
return _cv2_to_py(result);
|
61 |
+
}
|
62 |
+
|
63 |
+
int _dtype_py_to_cv(int dtype_py) {
|
64 |
+
switch (dtype_py) {
|
65 |
+
case PM_UINT8: return CV_8U;
|
66 |
+
case PM_INT8: return CV_8S;
|
67 |
+
case PM_UINT16: return CV_16U;
|
68 |
+
case PM_INT16: return CV_16S;
|
69 |
+
case PM_INT32: return CV_32S;
|
70 |
+
case PM_FLOAT32: return CV_32F;
|
71 |
+
case PM_FLOAT64: return CV_64F;
|
72 |
+
}
|
73 |
+
|
74 |
+
return CV_8U;
|
75 |
+
}
|
76 |
+
|
77 |
+
int _dtype_cv_to_py(int dtype_cv) {
|
78 |
+
switch (dtype_cv) {
|
79 |
+
case CV_8U: return PM_UINT8;
|
80 |
+
case CV_8S: return PM_INT8;
|
81 |
+
case CV_16U: return PM_UINT16;
|
82 |
+
case CV_16S: return PM_INT16;
|
83 |
+
case CV_32S: return PM_INT32;
|
84 |
+
case CV_32F: return PM_FLOAT32;
|
85 |
+
case CV_64F: return PM_FLOAT64;
|
86 |
+
}
|
87 |
+
|
88 |
+
return PM_UINT8;
|
89 |
+
}
|
90 |
+
|
91 |
+
cv::Mat _py_to_cv2(PM_mat_t pymat) {
|
92 |
+
int dtype = _dtype_py_to_cv(pymat.dtype);
|
93 |
+
dtype = CV_MAKETYPE(pymat.dtype, pymat.shape.channels);
|
94 |
+
return cv::Mat(cv::Size(pymat.shape.width, pymat.shape.height), dtype, pymat.data_ptr).clone();
|
95 |
+
}
|
96 |
+
|
97 |
+
PM_mat_t _cv2_to_py(cv::Mat cvmat) {
|
98 |
+
PM_shape_t shape = {cvmat.size().width, cvmat.size().height, cvmat.channels()};
|
99 |
+
int dtype = _dtype_cv_to_py(cvmat.depth());
|
100 |
+
size_t dsize = cvmat.total() * cvmat.elemSize();
|
101 |
+
|
102 |
+
void *data_ptr = reinterpret_cast<void *>(malloc(dsize));
|
103 |
+
memcpy(data_ptr, reinterpret_cast<void *>(cvmat.data), dsize);
|
104 |
+
|
105 |
+
return PM_mat_t {data_ptr, shape, dtype};
|
106 |
+
}
|
107 |
+
|
PyPatchMatch/csrc/pyinterface.h
CHANGED
@@ -1,38 +1,38 @@
|
|
1 |
-
#include <opencv2/core.hpp>
|
2 |
-
#include <cstdlib>
|
3 |
-
#include <cstdio>
|
4 |
-
#include <cstring>
|
5 |
-
|
6 |
-
extern "C" {
|
7 |
-
|
8 |
-
struct PM_shape_t {
|
9 |
-
int width, height, channels;
|
10 |
-
};
|
11 |
-
|
12 |
-
enum PM_dtype_e {
|
13 |
-
PM_UINT8,
|
14 |
-
PM_INT8,
|
15 |
-
PM_UINT16,
|
16 |
-
PM_INT16,
|
17 |
-
PM_INT32,
|
18 |
-
PM_FLOAT32,
|
19 |
-
PM_FLOAT64,
|
20 |
-
};
|
21 |
-
|
22 |
-
struct PM_mat_t {
|
23 |
-
void *data_ptr;
|
24 |
-
PM_shape_t shape;
|
25 |
-
int dtype;
|
26 |
-
};
|
27 |
-
|
28 |
-
void PM_set_random_seed(unsigned int seed);
|
29 |
-
void PM_set_verbose(int value);
|
30 |
-
|
31 |
-
void PM_free_pymat(PM_mat_t pymat);
|
32 |
-
PM_mat_t PM_inpaint(PM_mat_t image, PM_mat_t mask, int patch_size);
|
33 |
-
PM_mat_t PM_inpaint_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t ijmap, int patch_size, float guide_weight);
|
34 |
-
PM_mat_t PM_inpaint2(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, int patch_size);
|
35 |
-
PM_mat_t PM_inpaint2_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, PM_mat_t ijmap, int patch_size, float guide_weight);
|
36 |
-
|
37 |
-
} /* extern "C" */
|
38 |
-
|
|
|
1 |
+
#include <opencv2/core.hpp>
|
2 |
+
#include <cstdlib>
|
3 |
+
#include <cstdio>
|
4 |
+
#include <cstring>
|
5 |
+
|
6 |
+
extern "C" {
|
7 |
+
|
8 |
+
struct PM_shape_t {
|
9 |
+
int width, height, channels;
|
10 |
+
};
|
11 |
+
|
12 |
+
enum PM_dtype_e {
|
13 |
+
PM_UINT8,
|
14 |
+
PM_INT8,
|
15 |
+
PM_UINT16,
|
16 |
+
PM_INT16,
|
17 |
+
PM_INT32,
|
18 |
+
PM_FLOAT32,
|
19 |
+
PM_FLOAT64,
|
20 |
+
};
|
21 |
+
|
22 |
+
struct PM_mat_t {
|
23 |
+
void *data_ptr;
|
24 |
+
PM_shape_t shape;
|
25 |
+
int dtype;
|
26 |
+
};
|
27 |
+
|
28 |
+
void PM_set_random_seed(unsigned int seed);
|
29 |
+
void PM_set_verbose(int value);
|
30 |
+
|
31 |
+
void PM_free_pymat(PM_mat_t pymat);
|
32 |
+
PM_mat_t PM_inpaint(PM_mat_t image, PM_mat_t mask, int patch_size);
|
33 |
+
PM_mat_t PM_inpaint_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t ijmap, int patch_size, float guide_weight);
|
34 |
+
PM_mat_t PM_inpaint2(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, int patch_size);
|
35 |
+
PM_mat_t PM_inpaint2_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, PM_mat_t ijmap, int patch_size, float guide_weight);
|
36 |
+
|
37 |
+
} /* extern "C" */
|
38 |
+
|
PyPatchMatch/examples/.gitignore
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
-
/cpp_example.exe
|
2 |
-
/images/*recovered.bmp
|
|
|
1 |
+
/cpp_example.exe
|
2 |
+
/images/*recovered.bmp
|
PyPatchMatch/examples/cpp_example.cpp
CHANGED
@@ -1,31 +1,31 @@
|
|
1 |
-
#include <iostream>
|
2 |
-
#include <opencv2/imgcodecs.hpp>
|
3 |
-
#include <opencv2/highgui.hpp>
|
4 |
-
|
5 |
-
#include "masked_image.h"
|
6 |
-
#include "nnf.h"
|
7 |
-
#include "inpaint.h"
|
8 |
-
|
9 |
-
int main() {
|
10 |
-
auto source = cv::imread("./images/forest_pruned.bmp", cv::IMREAD_COLOR);
|
11 |
-
|
12 |
-
auto mask = cv::Mat(source.size(), CV_8UC1);
|
13 |
-
mask = cv::Scalar::all(0);
|
14 |
-
for (int i = 0; i < source.size().height; ++i) {
|
15 |
-
for (int j = 0; j < source.size().width; ++j) {
|
16 |
-
auto source_ptr = source.ptr<unsigned char>(i, j);
|
17 |
-
if (source_ptr[0] == 255 && source_ptr[1] == 255 && source_ptr[2] == 255) {
|
18 |
-
mask.at<unsigned char>(i, j) = 1;
|
19 |
-
}
|
20 |
-
}
|
21 |
-
}
|
22 |
-
|
23 |
-
auto metric = PatchSSDDistanceMetric(3);
|
24 |
-
auto result = Inpainting(source, mask, &metric).run(true, true);
|
25 |
-
// cv::imwrite("./images/forest_recovered.bmp", result);
|
26 |
-
// cv::imshow("Result", result);
|
27 |
-
// cv::waitKey();
|
28 |
-
|
29 |
-
return 0;
|
30 |
-
}
|
31 |
-
|
|
|
1 |
+
#include <iostream>
|
2 |
+
#include <opencv2/imgcodecs.hpp>
|
3 |
+
#include <opencv2/highgui.hpp>
|
4 |
+
|
5 |
+
#include "masked_image.h"
|
6 |
+
#include "nnf.h"
|
7 |
+
#include "inpaint.h"
|
8 |
+
|
9 |
+
int main() {
|
10 |
+
auto source = cv::imread("./images/forest_pruned.bmp", cv::IMREAD_COLOR);
|
11 |
+
|
12 |
+
auto mask = cv::Mat(source.size(), CV_8UC1);
|
13 |
+
mask = cv::Scalar::all(0);
|
14 |
+
for (int i = 0; i < source.size().height; ++i) {
|
15 |
+
for (int j = 0; j < source.size().width; ++j) {
|
16 |
+
auto source_ptr = source.ptr<unsigned char>(i, j);
|
17 |
+
if (source_ptr[0] == 255 && source_ptr[1] == 255 && source_ptr[2] == 255) {
|
18 |
+
mask.at<unsigned char>(i, j) = 1;
|
19 |
+
}
|
20 |
+
}
|
21 |
+
}
|
22 |
+
|
23 |
+
auto metric = PatchSSDDistanceMetric(3);
|
24 |
+
auto result = Inpainting(source, mask, &metric).run(true, true);
|
25 |
+
// cv::imwrite("./images/forest_recovered.bmp", result);
|
26 |
+
// cv::imshow("Result", result);
|
27 |
+
// cv::waitKey();
|
28 |
+
|
29 |
+
return 0;
|
30 |
+
}
|
31 |
+
|
PyPatchMatch/examples/cpp_example_run.sh
CHANGED
@@ -1,18 +1,18 @@
|
|
1 |
-
#! /bin/bash
|
2 |
-
#
|
3 |
-
# cpp_example_run.sh
|
4 |
-
# Copyright (C) 2020 Jiayuan Mao <[email protected]>
|
5 |
-
#
|
6 |
-
# Distributed under terms of the MIT license.
|
7 |
-
#
|
8 |
-
|
9 |
-
set -x
|
10 |
-
|
11 |
-
CFLAGS="-std=c++14 -O2 $(pkg-config --cflags opencv)"
|
12 |
-
LDFLAGS="$(pkg-config --libs opencv)"
|
13 |
-
g++ $CFLAGS cpp_example.cpp -I../csrc/ -L../ -lpatchmatch $LDFLAGS -o cpp_example.exe
|
14 |
-
|
15 |
-
export DYLD_LIBRARY_PATH=../:$DYLD_LIBRARY_PATH # For macOS
|
16 |
-
export LD_LIBRARY_PATH=../:$LD_LIBRARY_PATH # For Linux
|
17 |
-
time ./cpp_example.exe
|
18 |
-
|
|
|
1 |
+
#! /bin/bash
|
2 |
+
#
|
3 |
+
# cpp_example_run.sh
|
4 |
+
# Copyright (C) 2020 Jiayuan Mao <[email protected]>
|
5 |
+
#
|
6 |
+
# Distributed under terms of the MIT license.
|
7 |
+
#
|
8 |
+
|
9 |
+
set -x
|
10 |
+
|
11 |
+
CFLAGS="-std=c++14 -O2 $(pkg-config --cflags opencv)"
|
12 |
+
LDFLAGS="$(pkg-config --libs opencv)"
|
13 |
+
g++ $CFLAGS cpp_example.cpp -I../csrc/ -L../ -lpatchmatch $LDFLAGS -o cpp_example.exe
|
14 |
+
|
15 |
+
export DYLD_LIBRARY_PATH=../:$DYLD_LIBRARY_PATH # For macOS
|
16 |
+
export LD_LIBRARY_PATH=../:$LD_LIBRARY_PATH # For Linux
|
17 |
+
time ./cpp_example.exe
|
18 |
+
|
PyPatchMatch/examples/py_example.py
CHANGED
@@ -1,21 +1,21 @@
|
|
1 |
-
#! /usr/bin/env python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
# File : test.py
|
4 |
-
# Author : Jiayuan Mao
|
5 |
-
# Email : [email protected]
|
6 |
-
# Date : 01/09/2020
|
7 |
-
#
|
8 |
-
# Distributed under terms of the MIT license.
|
9 |
-
|
10 |
-
from PIL import Image
|
11 |
-
|
12 |
-
import sys
|
13 |
-
sys.path.insert(0, '../')
|
14 |
-
import patch_match
|
15 |
-
|
16 |
-
|
17 |
-
if __name__ == '__main__':
|
18 |
-
source = Image.open('./images/forest_pruned.bmp')
|
19 |
-
result = patch_match.inpaint(source, patch_size=3)
|
20 |
-
Image.fromarray(result).save('./images/forest_recovered.bmp')
|
21 |
-
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# File : test.py
|
4 |
+
# Author : Jiayuan Mao
|
5 |
+
# Email : [email protected]
|
6 |
+
# Date : 01/09/2020
|
7 |
+
#
|
8 |
+
# Distributed under terms of the MIT license.
|
9 |
+
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
import sys
|
13 |
+
sys.path.insert(0, '../')
|
14 |
+
import patch_match
|
15 |
+
|
16 |
+
|
17 |
+
if __name__ == '__main__':
|
18 |
+
source = Image.open('./images/forest_pruned.bmp')
|
19 |
+
result = patch_match.inpaint(source, patch_size=3)
|
20 |
+
Image.fromarray(result).save('./images/forest_recovered.bmp')
|
21 |
+
|
PyPatchMatch/examples/py_example_global_mask.py
CHANGED
@@ -1,27 +1,27 @@
|
|
1 |
-
#! /usr/bin/env python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
# File : test.py
|
4 |
-
# Author : Jiayuan Mao
|
5 |
-
# Email : [email protected]
|
6 |
-
# Date : 01/09/2020
|
7 |
-
#
|
8 |
-
# Distributed under terms of the MIT license.
|
9 |
-
|
10 |
-
import numpy as np
|
11 |
-
from PIL import Image
|
12 |
-
|
13 |
-
import sys
|
14 |
-
sys.path.insert(0, '../')
|
15 |
-
import patch_match
|
16 |
-
|
17 |
-
|
18 |
-
if __name__ == '__main__':
|
19 |
-
patch_match.set_verbose(True)
|
20 |
-
source = Image.open('./images/forest_pruned.bmp')
|
21 |
-
source = np.array(source)
|
22 |
-
source[:100, :100] = 255
|
23 |
-
global_mask = np.zeros_like(source[..., 0])
|
24 |
-
global_mask[:100, :100] = 1
|
25 |
-
result = patch_match.inpaint(source, global_mask=global_mask, patch_size=3)
|
26 |
-
Image.fromarray(result).save('./images/forest_recovered.bmp')
|
27 |
-
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# File : test.py
|
4 |
+
# Author : Jiayuan Mao
|
5 |
+
# Email : [email protected]
|
6 |
+
# Date : 01/09/2020
|
7 |
+
#
|
8 |
+
# Distributed under terms of the MIT license.
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
import sys
|
14 |
+
sys.path.insert(0, '../')
|
15 |
+
import patch_match
|
16 |
+
|
17 |
+
|
18 |
+
if __name__ == '__main__':
|
19 |
+
patch_match.set_verbose(True)
|
20 |
+
source = Image.open('./images/forest_pruned.bmp')
|
21 |
+
source = np.array(source)
|
22 |
+
source[:100, :100] = 255
|
23 |
+
global_mask = np.zeros_like(source[..., 0])
|
24 |
+
global_mask[:100, :100] = 1
|
25 |
+
result = patch_match.inpaint(source, global_mask=global_mask, patch_size=3)
|
26 |
+
Image.fromarray(result).save('./images/forest_recovered.bmp')
|
27 |
+
|
PyPatchMatch/patch_match.py
CHANGED
@@ -1,201 +1,263 @@
|
|
1 |
-
#! /usr/bin/env python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
# File : patch_match.py
|
4 |
-
# Author : Jiayuan Mao
|
5 |
-
# Email : [email protected]
|
6 |
-
# Date : 01/09/2020
|
7 |
-
#
|
8 |
-
# Distributed under terms of the MIT license.
|
9 |
-
|
10 |
-
import ctypes
|
11 |
-
import os.path as osp
|
12 |
-
from typing import Optional, Union
|
13 |
-
|
14 |
-
import numpy as np
|
15 |
-
from PIL import Image
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
('
|
34 |
-
('
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
('
|
42 |
-
('
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
)
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
)
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# File : patch_match.py
|
4 |
+
# Author : Jiayuan Mao
|
5 |
+
# Email : [email protected]
|
6 |
+
# Date : 01/09/2020
|
7 |
+
#
|
8 |
+
# Distributed under terms of the MIT license.
|
9 |
+
|
10 |
+
import ctypes
|
11 |
+
import os.path as osp
|
12 |
+
from typing import Optional, Union
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
from PIL import Image
|
16 |
+
|
17 |
+
|
18 |
+
import os
|
19 |
+
if os.name!="nt":
|
20 |
+
# Otherwise, fall back to the subprocess.
|
21 |
+
import subprocess
|
22 |
+
print('Compiling and loading c extensions from "{}".'.format(osp.realpath(osp.dirname(__file__))))
|
23 |
+
# subprocess.check_call(['./travis.sh'], cwd=osp.dirname(__file__))
|
24 |
+
subprocess.check_call("make clean && make", cwd=osp.dirname(__file__), shell=True)
|
25 |
+
|
26 |
+
|
27 |
+
__all__ = ['set_random_seed', 'set_verbose', 'inpaint', 'inpaint_regularity']
|
28 |
+
|
29 |
+
|
30 |
+
class CShapeT(ctypes.Structure):
|
31 |
+
_fields_ = [
|
32 |
+
('width', ctypes.c_int),
|
33 |
+
('height', ctypes.c_int),
|
34 |
+
('channels', ctypes.c_int),
|
35 |
+
]
|
36 |
+
|
37 |
+
|
38 |
+
class CMatT(ctypes.Structure):
|
39 |
+
_fields_ = [
|
40 |
+
('data_ptr', ctypes.c_void_p),
|
41 |
+
('shape', CShapeT),
|
42 |
+
('dtype', ctypes.c_int)
|
43 |
+
]
|
44 |
+
|
45 |
+
import tempfile
|
46 |
+
from urllib.request import urlopen, Request
|
47 |
+
import shutil
|
48 |
+
from pathlib import Path
|
49 |
+
from tqdm import tqdm
|
50 |
+
|
51 |
+
def download_url_to_file(url, dst, hash_prefix=None, progress=True):
|
52 |
+
r"""Download object at the given URL to a local path.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
url (string): URL of the object to download
|
56 |
+
dst (string): Full path where object will be saved, e.g. ``/tmp/temporary_file``
|
57 |
+
hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.
|
58 |
+
Default: None
|
59 |
+
progress (bool, optional): whether or not to display a progress bar to stderr
|
60 |
+
Default: True
|
61 |
+
https://pytorch.org/docs/stable/_modules/torch/hub.html#load_state_dict_from_url
|
62 |
+
"""
|
63 |
+
file_size = None
|
64 |
+
req = Request(url)
|
65 |
+
u = urlopen(req)
|
66 |
+
meta = u.info()
|
67 |
+
if hasattr(meta, 'getheaders'):
|
68 |
+
content_length = meta.getheaders("Content-Length")
|
69 |
+
else:
|
70 |
+
content_length = meta.get_all("Content-Length")
|
71 |
+
if content_length is not None and len(content_length) > 0:
|
72 |
+
file_size = int(content_length[0])
|
73 |
+
|
74 |
+
# We deliberately save it in a temp file and move it after
|
75 |
+
# download is complete. This prevents a local working checkpoint
|
76 |
+
# being overridden by a broken download.
|
77 |
+
dst = os.path.expanduser(dst)
|
78 |
+
dst_dir = os.path.dirname(dst)
|
79 |
+
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
|
80 |
+
|
81 |
+
try:
|
82 |
+
with tqdm(total=file_size, disable=not progress,
|
83 |
+
unit='B', unit_scale=True, unit_divisor=1024) as pbar:
|
84 |
+
while True:
|
85 |
+
buffer = u.read(8192)
|
86 |
+
if len(buffer) == 0:
|
87 |
+
break
|
88 |
+
f.write(buffer)
|
89 |
+
pbar.update(len(buffer))
|
90 |
+
|
91 |
+
f.close()
|
92 |
+
shutil.move(f.name, dst)
|
93 |
+
finally:
|
94 |
+
f.close()
|
95 |
+
if os.path.exists(f.name):
|
96 |
+
os.remove(f.name)
|
97 |
+
|
98 |
+
if os.name!="nt":
|
99 |
+
PMLIB = ctypes.CDLL(osp.join(osp.dirname(__file__), 'libpatchmatch.so'))
|
100 |
+
else:
|
101 |
+
if not os.path.exists(osp.join(osp.dirname(__file__), 'libpatchmatch.dll')):
|
102 |
+
download_url_to_file(url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll",dst=osp.join(osp.dirname(__file__), 'libpatchmatch.dll'))
|
103 |
+
if not os.path.exists(osp.join(osp.dirname(__file__), 'opencv_world460.dll')):
|
104 |
+
download_url_to_file(url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll",dst=osp.join(osp.dirname(__file__), 'opencv_world460.dll'))
|
105 |
+
if not os.path.exists(osp.join(osp.dirname(__file__), 'libpatchmatch.dll')):
|
106 |
+
print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll and put it into the PyPatchMatch folder")
|
107 |
+
if not os.path.exists(osp.join(osp.dirname(__file__), 'opencv_world460.dll')):
|
108 |
+
print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll and put it into the PyPatchMatch folder")
|
109 |
+
PMLIB = ctypes.CDLL(osp.join(osp.dirname(__file__), 'libpatchmatch.dll'))
|
110 |
+
|
111 |
+
PMLIB.PM_set_random_seed.argtypes = [ctypes.c_uint]
|
112 |
+
PMLIB.PM_set_verbose.argtypes = [ctypes.c_int]
|
113 |
+
PMLIB.PM_free_pymat.argtypes = [CMatT]
|
114 |
+
PMLIB.PM_inpaint.argtypes = [CMatT, CMatT, ctypes.c_int]
|
115 |
+
PMLIB.PM_inpaint.restype = CMatT
|
116 |
+
PMLIB.PM_inpaint_regularity.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
|
117 |
+
PMLIB.PM_inpaint_regularity.restype = CMatT
|
118 |
+
PMLIB.PM_inpaint2.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int]
|
119 |
+
PMLIB.PM_inpaint2.restype = CMatT
|
120 |
+
PMLIB.PM_inpaint2_regularity.argtypes = [CMatT, CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
|
121 |
+
PMLIB.PM_inpaint2_regularity.restype = CMatT
|
122 |
+
|
123 |
+
|
124 |
+
def set_random_seed(seed: int):
|
125 |
+
PMLIB.PM_set_random_seed(ctypes.c_uint(seed))
|
126 |
+
|
127 |
+
|
128 |
+
def set_verbose(verbose: bool):
|
129 |
+
PMLIB.PM_set_verbose(ctypes.c_int(verbose))
|
130 |
+
|
131 |
+
|
132 |
+
def inpaint(
|
133 |
+
image: Union[np.ndarray, Image.Image],
|
134 |
+
mask: Optional[Union[np.ndarray, Image.Image]] = None,
|
135 |
+
*,
|
136 |
+
global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
|
137 |
+
patch_size: int = 15
|
138 |
+
) -> np.ndarray:
|
139 |
+
"""
|
140 |
+
PatchMatch based inpainting proposed in:
|
141 |
+
|
142 |
+
PatchMatch : A Randomized Correspondence Algorithm for Structural Image Editing
|
143 |
+
C.Barnes, E.Shechtman, A.Finkelstein and Dan B.Goldman
|
144 |
+
SIGGRAPH 2009
|
145 |
+
|
146 |
+
Args:
|
147 |
+
image (Union[np.ndarray, Image.Image]): the input image, should be 3-channel RGB/BGR.
|
148 |
+
mask (Union[np.array, Image.Image], optional): the mask of the hole(s) to be filled, should be 1-channel.
|
149 |
+
If not provided (None), the algorithm will treat all purely white pixels as the holes (255, 255, 255).
|
150 |
+
global_mask (Union[np.array, Image.Image], optional): the target mask of the output image.
|
151 |
+
patch_size (int): the patch size for the inpainting algorithm.
|
152 |
+
|
153 |
+
Return:
|
154 |
+
result (np.ndarray): the repaired image, of the same size as the input image.
|
155 |
+
"""
|
156 |
+
|
157 |
+
if isinstance(image, Image.Image):
|
158 |
+
image = np.array(image)
|
159 |
+
image = np.ascontiguousarray(image)
|
160 |
+
assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
|
161 |
+
|
162 |
+
if mask is None:
|
163 |
+
mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
|
164 |
+
mask = np.ascontiguousarray(mask)
|
165 |
+
else:
|
166 |
+
mask = _canonize_mask_array(mask)
|
167 |
+
|
168 |
+
if global_mask is None:
|
169 |
+
ret_pymat = PMLIB.PM_inpaint(np_to_pymat(image), np_to_pymat(mask), ctypes.c_int(patch_size))
|
170 |
+
else:
|
171 |
+
global_mask = _canonize_mask_array(global_mask)
|
172 |
+
ret_pymat = PMLIB.PM_inpaint2(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), ctypes.c_int(patch_size))
|
173 |
+
|
174 |
+
ret_npmat = pymat_to_np(ret_pymat)
|
175 |
+
PMLIB.PM_free_pymat(ret_pymat)
|
176 |
+
|
177 |
+
return ret_npmat
|
178 |
+
|
179 |
+
|
180 |
+
def inpaint_regularity(
|
181 |
+
image: Union[np.ndarray, Image.Image],
|
182 |
+
mask: Optional[Union[np.ndarray, Image.Image]],
|
183 |
+
ijmap: np.ndarray,
|
184 |
+
*,
|
185 |
+
global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
|
186 |
+
patch_size: int = 15, guide_weight: float = 0.25
|
187 |
+
) -> np.ndarray:
|
188 |
+
if isinstance(image, Image.Image):
|
189 |
+
image = np.array(image)
|
190 |
+
image = np.ascontiguousarray(image)
|
191 |
+
|
192 |
+
assert isinstance(ijmap, np.ndarray) and ijmap.ndim == 3 and ijmap.shape[2] == 3 and ijmap.dtype == 'float32'
|
193 |
+
ijmap = np.ascontiguousarray(ijmap)
|
194 |
+
|
195 |
+
assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
|
196 |
+
if mask is None:
|
197 |
+
mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
|
198 |
+
mask = np.ascontiguousarray(mask)
|
199 |
+
else:
|
200 |
+
mask = _canonize_mask_array(mask)
|
201 |
+
|
202 |
+
|
203 |
+
if global_mask is None:
|
204 |
+
ret_pymat = PMLIB.PM_inpaint_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
|
205 |
+
else:
|
206 |
+
global_mask = _canonize_mask_array(global_mask)
|
207 |
+
ret_pymat = PMLIB.PM_inpaint2_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
|
208 |
+
|
209 |
+
ret_npmat = pymat_to_np(ret_pymat)
|
210 |
+
PMLIB.PM_free_pymat(ret_pymat)
|
211 |
+
|
212 |
+
return ret_npmat
|
213 |
+
|
214 |
+
|
215 |
+
def _canonize_mask_array(mask):
|
216 |
+
if isinstance(mask, Image.Image):
|
217 |
+
mask = np.array(mask)
|
218 |
+
if mask.ndim == 2 and mask.dtype == 'uint8':
|
219 |
+
mask = mask[..., np.newaxis]
|
220 |
+
assert mask.ndim == 3 and mask.shape[2] == 1 and mask.dtype == 'uint8'
|
221 |
+
return np.ascontiguousarray(mask)
|
222 |
+
|
223 |
+
|
224 |
+
dtype_pymat_to_ctypes = [
|
225 |
+
ctypes.c_uint8,
|
226 |
+
ctypes.c_int8,
|
227 |
+
ctypes.c_uint16,
|
228 |
+
ctypes.c_int16,
|
229 |
+
ctypes.c_int32,
|
230 |
+
ctypes.c_float,
|
231 |
+
ctypes.c_double,
|
232 |
+
]
|
233 |
+
|
234 |
+
|
235 |
+
dtype_np_to_pymat = {
|
236 |
+
'uint8': 0,
|
237 |
+
'int8': 1,
|
238 |
+
'uint16': 2,
|
239 |
+
'int16': 3,
|
240 |
+
'int32': 4,
|
241 |
+
'float32': 5,
|
242 |
+
'float64': 6,
|
243 |
+
}
|
244 |
+
|
245 |
+
|
246 |
+
def np_to_pymat(npmat):
|
247 |
+
assert npmat.ndim == 3
|
248 |
+
return CMatT(
|
249 |
+
ctypes.cast(npmat.ctypes.data, ctypes.c_void_p),
|
250 |
+
CShapeT(npmat.shape[1], npmat.shape[0], npmat.shape[2]),
|
251 |
+
dtype_np_to_pymat[str(npmat.dtype)]
|
252 |
+
)
|
253 |
+
|
254 |
+
|
255 |
+
def pymat_to_np(pymat):
|
256 |
+
npmat = np.ctypeslib.as_array(
|
257 |
+
ctypes.cast(pymat.data_ptr, ctypes.POINTER(dtype_pymat_to_ctypes[pymat.dtype])),
|
258 |
+
(pymat.shape.height, pymat.shape.width, pymat.shape.channels)
|
259 |
+
)
|
260 |
+
ret = np.empty(npmat.shape, npmat.dtype)
|
261 |
+
ret[:] = npmat
|
262 |
+
return ret
|
263 |
+
|
PyPatchMatch/travis.sh
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
-
#! /bin/bash
|
2 |
-
#
|
3 |
-
# travis.sh
|
4 |
-
# Copyright (C) 2020 Jiayuan Mao <[email protected]>
|
5 |
-
#
|
6 |
-
# Distributed under terms of the MIT license.
|
7 |
-
#
|
8 |
-
|
9 |
-
make clean && make
|
|
|
1 |
+
#! /bin/bash
|
2 |
+
#
|
3 |
+
# travis.sh
|
4 |
+
# Copyright (C) 2020 Jiayuan Mao <[email protected]>
|
5 |
+
#
|
6 |
+
# Distributed under terms of the MIT license.
|
7 |
+
#
|
8 |
+
|
9 |
+
make clean && make
|
config.yaml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
shortcut:
|
2 |
+
clear: Escape
|
3 |
+
load: Ctrl+o
|
4 |
+
save: Ctrl+s
|
5 |
+
export: Ctrl+e
|
6 |
+
upload: Ctrl+u
|
7 |
+
selection: 1
|
8 |
+
canvas: 2
|
9 |
+
eraser: 3
|
10 |
+
outpaint: d
|
11 |
+
accept: a
|
12 |
+
cancel: c
|
13 |
+
retry: r
|
14 |
+
prev: q
|
15 |
+
next: e
|
16 |
+
zoom_in: z
|
17 |
+
zoom_out: x
|
18 |
+
random_seed: s
|
convert_checkpoint.py
ADDED
@@ -0,0 +1,706 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py
|
16 |
+
""" Conversion script for the LDM checkpoints. """
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
import os
|
20 |
+
|
21 |
+
import torch
|
22 |
+
|
23 |
+
|
24 |
+
try:
|
25 |
+
from omegaconf import OmegaConf
|
26 |
+
except ImportError:
|
27 |
+
raise ImportError(
|
28 |
+
"OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
|
29 |
+
)
|
30 |
+
|
31 |
+
from diffusers import (
|
32 |
+
AutoencoderKL,
|
33 |
+
DDIMScheduler,
|
34 |
+
LDMTextToImagePipeline,
|
35 |
+
LMSDiscreteScheduler,
|
36 |
+
PNDMScheduler,
|
37 |
+
StableDiffusionPipeline,
|
38 |
+
UNet2DConditionModel,
|
39 |
+
)
|
40 |
+
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
41 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
42 |
+
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
|
43 |
+
|
44 |
+
|
45 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
46 |
+
"""
|
47 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
48 |
+
"""
|
49 |
+
if n_shave_prefix_segments >= 0:
|
50 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
51 |
+
else:
|
52 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
53 |
+
|
54 |
+
|
55 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
56 |
+
"""
|
57 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
58 |
+
"""
|
59 |
+
mapping = []
|
60 |
+
for old_item in old_list:
|
61 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
62 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
63 |
+
|
64 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
65 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
66 |
+
|
67 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
68 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
69 |
+
|
70 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
71 |
+
|
72 |
+
mapping.append({"old": old_item, "new": new_item})
|
73 |
+
|
74 |
+
return mapping
|
75 |
+
|
76 |
+
|
77 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
78 |
+
"""
|
79 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
80 |
+
"""
|
81 |
+
mapping = []
|
82 |
+
for old_item in old_list:
|
83 |
+
new_item = old_item
|
84 |
+
|
85 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
86 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
87 |
+
|
88 |
+
mapping.append({"old": old_item, "new": new_item})
|
89 |
+
|
90 |
+
return mapping
|
91 |
+
|
92 |
+
|
93 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
94 |
+
"""
|
95 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
96 |
+
"""
|
97 |
+
mapping = []
|
98 |
+
for old_item in old_list:
|
99 |
+
new_item = old_item
|
100 |
+
|
101 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
102 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
103 |
+
|
104 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
105 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
106 |
+
|
107 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
108 |
+
|
109 |
+
mapping.append({"old": old_item, "new": new_item})
|
110 |
+
|
111 |
+
return mapping
|
112 |
+
|
113 |
+
|
114 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
115 |
+
"""
|
116 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
117 |
+
"""
|
118 |
+
mapping = []
|
119 |
+
for old_item in old_list:
|
120 |
+
new_item = old_item
|
121 |
+
|
122 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
123 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
124 |
+
|
125 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
126 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
127 |
+
|
128 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
129 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
130 |
+
|
131 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
132 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
133 |
+
|
134 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
135 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
136 |
+
|
137 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
138 |
+
|
139 |
+
mapping.append({"old": old_item, "new": new_item})
|
140 |
+
|
141 |
+
return mapping
|
142 |
+
|
143 |
+
|
144 |
+
def assign_to_checkpoint(
|
145 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
146 |
+
):
|
147 |
+
"""
|
148 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
149 |
+
to them. It splits attention layers, and takes into account additional replacements
|
150 |
+
that may arise.
|
151 |
+
|
152 |
+
Assigns the weights to the new checkpoint.
|
153 |
+
"""
|
154 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
155 |
+
|
156 |
+
# Splits the attention layers into three variables.
|
157 |
+
if attention_paths_to_split is not None:
|
158 |
+
for path, path_map in attention_paths_to_split.items():
|
159 |
+
old_tensor = old_checkpoint[path]
|
160 |
+
channels = old_tensor.shape[0] // 3
|
161 |
+
|
162 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
163 |
+
|
164 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
165 |
+
|
166 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
167 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
168 |
+
|
169 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
170 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
171 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
172 |
+
|
173 |
+
for path in paths:
|
174 |
+
new_path = path["new"]
|
175 |
+
|
176 |
+
# These have already been assigned
|
177 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
178 |
+
continue
|
179 |
+
|
180 |
+
# Global renaming happens here
|
181 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
182 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
183 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
184 |
+
|
185 |
+
if additional_replacements is not None:
|
186 |
+
for replacement in additional_replacements:
|
187 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
188 |
+
|
189 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
190 |
+
if "proj_attn.weight" in new_path:
|
191 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
192 |
+
else:
|
193 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
194 |
+
|
195 |
+
|
196 |
+
def conv_attn_to_linear(checkpoint):
|
197 |
+
keys = list(checkpoint.keys())
|
198 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
199 |
+
for key in keys:
|
200 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
201 |
+
if checkpoint[key].ndim > 2:
|
202 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
203 |
+
elif "proj_attn.weight" in key:
|
204 |
+
if checkpoint[key].ndim > 2:
|
205 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
206 |
+
|
207 |
+
|
208 |
+
def create_unet_diffusers_config(original_config):
|
209 |
+
"""
|
210 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
211 |
+
"""
|
212 |
+
unet_params = original_config.model.params.unet_config.params
|
213 |
+
|
214 |
+
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
215 |
+
|
216 |
+
down_block_types = []
|
217 |
+
resolution = 1
|
218 |
+
for i in range(len(block_out_channels)):
|
219 |
+
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
220 |
+
down_block_types.append(block_type)
|
221 |
+
if i != len(block_out_channels) - 1:
|
222 |
+
resolution *= 2
|
223 |
+
|
224 |
+
up_block_types = []
|
225 |
+
for i in range(len(block_out_channels)):
|
226 |
+
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
227 |
+
up_block_types.append(block_type)
|
228 |
+
resolution //= 2
|
229 |
+
|
230 |
+
config = dict(
|
231 |
+
sample_size=unet_params.image_size,
|
232 |
+
in_channels=unet_params.in_channels,
|
233 |
+
out_channels=unet_params.out_channels,
|
234 |
+
down_block_types=tuple(down_block_types),
|
235 |
+
up_block_types=tuple(up_block_types),
|
236 |
+
block_out_channels=tuple(block_out_channels),
|
237 |
+
layers_per_block=unet_params.num_res_blocks,
|
238 |
+
cross_attention_dim=unet_params.context_dim,
|
239 |
+
attention_head_dim=unet_params.num_heads,
|
240 |
+
)
|
241 |
+
|
242 |
+
return config
|
243 |
+
|
244 |
+
|
245 |
+
def create_vae_diffusers_config(original_config):
|
246 |
+
"""
|
247 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
248 |
+
"""
|
249 |
+
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
250 |
+
_ = original_config.model.params.first_stage_config.params.embed_dim
|
251 |
+
|
252 |
+
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
253 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
254 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
255 |
+
|
256 |
+
config = dict(
|
257 |
+
sample_size=vae_params.resolution,
|
258 |
+
in_channels=vae_params.in_channels,
|
259 |
+
out_channels=vae_params.out_ch,
|
260 |
+
down_block_types=tuple(down_block_types),
|
261 |
+
up_block_types=tuple(up_block_types),
|
262 |
+
block_out_channels=tuple(block_out_channels),
|
263 |
+
latent_channels=vae_params.z_channels,
|
264 |
+
layers_per_block=vae_params.num_res_blocks,
|
265 |
+
)
|
266 |
+
return config
|
267 |
+
|
268 |
+
|
269 |
+
def create_diffusers_schedular(original_config):
|
270 |
+
schedular = DDIMScheduler(
|
271 |
+
num_train_timesteps=original_config.model.params.timesteps,
|
272 |
+
beta_start=original_config.model.params.linear_start,
|
273 |
+
beta_end=original_config.model.params.linear_end,
|
274 |
+
beta_schedule="scaled_linear",
|
275 |
+
)
|
276 |
+
return schedular
|
277 |
+
|
278 |
+
|
279 |
+
def create_ldm_bert_config(original_config):
|
280 |
+
bert_params = original_config.model.parms.cond_stage_config.params
|
281 |
+
config = LDMBertConfig(
|
282 |
+
d_model=bert_params.n_embed,
|
283 |
+
encoder_layers=bert_params.n_layer,
|
284 |
+
encoder_ffn_dim=bert_params.n_embed * 4,
|
285 |
+
)
|
286 |
+
return config
|
287 |
+
|
288 |
+
|
289 |
+
def convert_ldm_unet_checkpoint(checkpoint, config):
|
290 |
+
"""
|
291 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
292 |
+
"""
|
293 |
+
|
294 |
+
# extract state_dict for UNet
|
295 |
+
unet_state_dict = {}
|
296 |
+
unet_key = "model.diffusion_model."
|
297 |
+
keys = list(checkpoint.keys())
|
298 |
+
for key in keys:
|
299 |
+
if key.startswith(unet_key):
|
300 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
301 |
+
|
302 |
+
new_checkpoint = {}
|
303 |
+
|
304 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
305 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
306 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
307 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
308 |
+
|
309 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
310 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
311 |
+
|
312 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
313 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
314 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
315 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
316 |
+
|
317 |
+
# Retrieves the keys for the input blocks only
|
318 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
319 |
+
input_blocks = {
|
320 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
321 |
+
for layer_id in range(num_input_blocks)
|
322 |
+
}
|
323 |
+
|
324 |
+
# Retrieves the keys for the middle blocks only
|
325 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
326 |
+
middle_blocks = {
|
327 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
328 |
+
for layer_id in range(num_middle_blocks)
|
329 |
+
}
|
330 |
+
|
331 |
+
# Retrieves the keys for the output blocks only
|
332 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
333 |
+
output_blocks = {
|
334 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
335 |
+
for layer_id in range(num_output_blocks)
|
336 |
+
}
|
337 |
+
|
338 |
+
for i in range(1, num_input_blocks):
|
339 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
340 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
341 |
+
|
342 |
+
resnets = [
|
343 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
344 |
+
]
|
345 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
346 |
+
|
347 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
348 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
349 |
+
f"input_blocks.{i}.0.op.weight"
|
350 |
+
)
|
351 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
352 |
+
f"input_blocks.{i}.0.op.bias"
|
353 |
+
)
|
354 |
+
|
355 |
+
paths = renew_resnet_paths(resnets)
|
356 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
357 |
+
assign_to_checkpoint(
|
358 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
359 |
+
)
|
360 |
+
|
361 |
+
if len(attentions):
|
362 |
+
paths = renew_attention_paths(attentions)
|
363 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
364 |
+
assign_to_checkpoint(
|
365 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
366 |
+
)
|
367 |
+
|
368 |
+
resnet_0 = middle_blocks[0]
|
369 |
+
attentions = middle_blocks[1]
|
370 |
+
resnet_1 = middle_blocks[2]
|
371 |
+
|
372 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
373 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
374 |
+
|
375 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
376 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
377 |
+
|
378 |
+
attentions_paths = renew_attention_paths(attentions)
|
379 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
380 |
+
assign_to_checkpoint(
|
381 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
382 |
+
)
|
383 |
+
|
384 |
+
for i in range(num_output_blocks):
|
385 |
+
block_id = i // (config["layers_per_block"] + 1)
|
386 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
387 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
388 |
+
output_block_list = {}
|
389 |
+
|
390 |
+
for layer in output_block_layers:
|
391 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
392 |
+
if layer_id in output_block_list:
|
393 |
+
output_block_list[layer_id].append(layer_name)
|
394 |
+
else:
|
395 |
+
output_block_list[layer_id] = [layer_name]
|
396 |
+
|
397 |
+
if len(output_block_list) > 1:
|
398 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
399 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
400 |
+
|
401 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
402 |
+
paths = renew_resnet_paths(resnets)
|
403 |
+
|
404 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
405 |
+
assign_to_checkpoint(
|
406 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
407 |
+
)
|
408 |
+
|
409 |
+
if ["conv.weight", "conv.bias"] in output_block_list.values():
|
410 |
+
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
411 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
412 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
413 |
+
]
|
414 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
415 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
416 |
+
]
|
417 |
+
|
418 |
+
# Clear attentions as they have been attributed above.
|
419 |
+
if len(attentions) == 2:
|
420 |
+
attentions = []
|
421 |
+
|
422 |
+
if len(attentions):
|
423 |
+
paths = renew_attention_paths(attentions)
|
424 |
+
meta_path = {
|
425 |
+
"old": f"output_blocks.{i}.1",
|
426 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
427 |
+
}
|
428 |
+
assign_to_checkpoint(
|
429 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
430 |
+
)
|
431 |
+
else:
|
432 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
433 |
+
for path in resnet_0_paths:
|
434 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
435 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
436 |
+
|
437 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
438 |
+
|
439 |
+
return new_checkpoint
|
440 |
+
|
441 |
+
|
442 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
443 |
+
# extract state dict for VAE
|
444 |
+
vae_state_dict = {}
|
445 |
+
vae_key = "first_stage_model."
|
446 |
+
keys = list(checkpoint.keys())
|
447 |
+
for key in keys:
|
448 |
+
if key.startswith(vae_key):
|
449 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
450 |
+
|
451 |
+
new_checkpoint = {}
|
452 |
+
|
453 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
454 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
455 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
456 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
457 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
458 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
459 |
+
|
460 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
461 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
462 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
463 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
464 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
465 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
466 |
+
|
467 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
468 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
469 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
470 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
471 |
+
|
472 |
+
# Retrieves the keys for the encoder down blocks only
|
473 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
474 |
+
down_blocks = {
|
475 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
476 |
+
}
|
477 |
+
|
478 |
+
# Retrieves the keys for the decoder up blocks only
|
479 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
480 |
+
up_blocks = {
|
481 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
482 |
+
}
|
483 |
+
|
484 |
+
for i in range(num_down_blocks):
|
485 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
486 |
+
|
487 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
488 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
489 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
490 |
+
)
|
491 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
492 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
493 |
+
)
|
494 |
+
|
495 |
+
paths = renew_vae_resnet_paths(resnets)
|
496 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
497 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
498 |
+
|
499 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
500 |
+
num_mid_res_blocks = 2
|
501 |
+
for i in range(1, num_mid_res_blocks + 1):
|
502 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
503 |
+
|
504 |
+
paths = renew_vae_resnet_paths(resnets)
|
505 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
506 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
507 |
+
|
508 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
509 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
510 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
511 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
512 |
+
conv_attn_to_linear(new_checkpoint)
|
513 |
+
|
514 |
+
for i in range(num_up_blocks):
|
515 |
+
block_id = num_up_blocks - 1 - i
|
516 |
+
resnets = [
|
517 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
518 |
+
]
|
519 |
+
|
520 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
521 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
522 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
523 |
+
]
|
524 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
525 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
526 |
+
]
|
527 |
+
|
528 |
+
paths = renew_vae_resnet_paths(resnets)
|
529 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
530 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
531 |
+
|
532 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
533 |
+
num_mid_res_blocks = 2
|
534 |
+
for i in range(1, num_mid_res_blocks + 1):
|
535 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
536 |
+
|
537 |
+
paths = renew_vae_resnet_paths(resnets)
|
538 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
539 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
540 |
+
|
541 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
542 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
543 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
544 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
545 |
+
conv_attn_to_linear(new_checkpoint)
|
546 |
+
return new_checkpoint
|
547 |
+
|
548 |
+
|
549 |
+
def convert_ldm_bert_checkpoint(checkpoint, config):
|
550 |
+
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
551 |
+
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
552 |
+
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
|
553 |
+
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
|
554 |
+
|
555 |
+
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
|
556 |
+
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
|
557 |
+
|
558 |
+
def _copy_linear(hf_linear, pt_linear):
|
559 |
+
hf_linear.weight = pt_linear.weight
|
560 |
+
hf_linear.bias = pt_linear.bias
|
561 |
+
|
562 |
+
def _copy_layer(hf_layer, pt_layer):
|
563 |
+
# copy layer norms
|
564 |
+
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
|
565 |
+
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
|
566 |
+
|
567 |
+
# copy attn
|
568 |
+
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
|
569 |
+
|
570 |
+
# copy MLP
|
571 |
+
pt_mlp = pt_layer[1][1]
|
572 |
+
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
|
573 |
+
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
|
574 |
+
|
575 |
+
def _copy_layers(hf_layers, pt_layers):
|
576 |
+
for i, hf_layer in enumerate(hf_layers):
|
577 |
+
if i != 0:
|
578 |
+
i += i
|
579 |
+
pt_layer = pt_layers[i : i + 2]
|
580 |
+
_copy_layer(hf_layer, pt_layer)
|
581 |
+
|
582 |
+
hf_model = LDMBertModel(config).eval()
|
583 |
+
|
584 |
+
# copy embeds
|
585 |
+
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
|
586 |
+
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
|
587 |
+
|
588 |
+
# copy layer norm
|
589 |
+
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
|
590 |
+
|
591 |
+
# copy hidden layers
|
592 |
+
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
|
593 |
+
|
594 |
+
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
|
595 |
+
|
596 |
+
return hf_model
|
597 |
+
|
598 |
+
|
599 |
+
def convert_ldm_clip_checkpoint(checkpoint):
|
600 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
601 |
+
|
602 |
+
keys = list(checkpoint.keys())
|
603 |
+
|
604 |
+
text_model_dict = {}
|
605 |
+
|
606 |
+
for key in keys:
|
607 |
+
if key.startswith("cond_stage_model.transformer"):
|
608 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
609 |
+
|
610 |
+
text_model.load_state_dict(text_model_dict)
|
611 |
+
|
612 |
+
return text_model
|
613 |
+
|
614 |
+
import os
|
615 |
+
def convert_checkpoint(checkpoint_path, inpainting=False):
|
616 |
+
parser = argparse.ArgumentParser()
|
617 |
+
|
618 |
+
parser.add_argument(
|
619 |
+
"--checkpoint_path", default=checkpoint_path, type=str, help="Path to the checkpoint to convert."
|
620 |
+
)
|
621 |
+
# !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
|
622 |
+
parser.add_argument(
|
623 |
+
"--original_config_file",
|
624 |
+
default=None,
|
625 |
+
type=str,
|
626 |
+
help="The YAML config file corresponding to the original architecture.",
|
627 |
+
)
|
628 |
+
parser.add_argument(
|
629 |
+
"--scheduler_type",
|
630 |
+
default="pndm",
|
631 |
+
type=str,
|
632 |
+
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
|
633 |
+
)
|
634 |
+
parser.add_argument("--dump_path", default=None, type=str, help="Path to the output model.")
|
635 |
+
|
636 |
+
args = parser.parse_args([])
|
637 |
+
if args.original_config_file is None:
|
638 |
+
if inpainting:
|
639 |
+
args.original_config_file = "./models/v1-inpainting-inference.yaml"
|
640 |
+
else:
|
641 |
+
args.original_config_file = "./models/v1-inference.yaml"
|
642 |
+
|
643 |
+
original_config = OmegaConf.load(args.original_config_file)
|
644 |
+
checkpoint = torch.load(args.checkpoint_path)["state_dict"]
|
645 |
+
|
646 |
+
num_train_timesteps = original_config.model.params.timesteps
|
647 |
+
beta_start = original_config.model.params.linear_start
|
648 |
+
beta_end = original_config.model.params.linear_end
|
649 |
+
if args.scheduler_type == "pndm":
|
650 |
+
scheduler = PNDMScheduler(
|
651 |
+
beta_end=beta_end,
|
652 |
+
beta_schedule="scaled_linear",
|
653 |
+
beta_start=beta_start,
|
654 |
+
num_train_timesteps=num_train_timesteps,
|
655 |
+
skip_prk_steps=True,
|
656 |
+
)
|
657 |
+
elif args.scheduler_type == "lms":
|
658 |
+
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
|
659 |
+
elif args.scheduler_type == "ddim":
|
660 |
+
scheduler = DDIMScheduler(
|
661 |
+
beta_start=beta_start,
|
662 |
+
beta_end=beta_end,
|
663 |
+
beta_schedule="scaled_linear",
|
664 |
+
clip_sample=False,
|
665 |
+
set_alpha_to_one=False,
|
666 |
+
)
|
667 |
+
else:
|
668 |
+
raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
|
669 |
+
|
670 |
+
# Convert the UNet2DConditionModel model.
|
671 |
+
unet_config = create_unet_diffusers_config(original_config)
|
672 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
|
673 |
+
|
674 |
+
unet = UNet2DConditionModel(**unet_config)
|
675 |
+
unet.load_state_dict(converted_unet_checkpoint)
|
676 |
+
|
677 |
+
# Convert the VAE model.
|
678 |
+
vae_config = create_vae_diffusers_config(original_config)
|
679 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
680 |
+
|
681 |
+
vae = AutoencoderKL(**vae_config)
|
682 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
683 |
+
|
684 |
+
# Convert the text model.
|
685 |
+
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
686 |
+
if text_model_type == "FrozenCLIPEmbedder":
|
687 |
+
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
688 |
+
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
689 |
+
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
690 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
691 |
+
pipe = StableDiffusionPipeline(
|
692 |
+
vae=vae,
|
693 |
+
text_encoder=text_model,
|
694 |
+
tokenizer=tokenizer,
|
695 |
+
unet=unet,
|
696 |
+
scheduler=scheduler,
|
697 |
+
safety_checker=safety_checker,
|
698 |
+
feature_extractor=feature_extractor,
|
699 |
+
)
|
700 |
+
else:
|
701 |
+
text_config = create_ldm_bert_config(original_config)
|
702 |
+
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
703 |
+
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
704 |
+
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
705 |
+
|
706 |
+
return pipe
|
css/w2ui.min.css
ADDED
The diff for this file is too large to render.
See raw diff
|
|
js/fabric.min.js
ADDED
The diff for this file is too large to render.
See raw diff
|
|
js/keyboard.js
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
window.my_setup_keyboard=setInterval(function(){
|
3 |
+
let app=document.querySelector("gradio-app");
|
4 |
+
app=app.shadowRoot??app;
|
5 |
+
let frame=app.querySelector("#sdinfframe").contentWindow;
|
6 |
+
console.log("Check iframe...");
|
7 |
+
if(frame.setup_shortcut)
|
8 |
+
{
|
9 |
+
frame.setup_shortcut(json);
|
10 |
+
clearInterval(window.my_setup_keyboard);
|
11 |
+
}
|
12 |
+
}, 1000);
|
13 |
+
var config=JSON.parse(json);
|
14 |
+
var key_map={};
|
15 |
+
Object.keys(config.shortcut).forEach(k=>{
|
16 |
+
key_map[config.shortcut[k]]=k;
|
17 |
+
});
|
18 |
+
document.addEventListener("keydown", e => {
|
19 |
+
if(e.target.tagName!="INPUT"&&e.target.tagName!="GRADIO-APP"&&e.target.tagName!="TEXTAREA")
|
20 |
+
{
|
21 |
+
let key=e.key;
|
22 |
+
if(e.ctrlKey)
|
23 |
+
{
|
24 |
+
key="Ctrl+"+e.key;
|
25 |
+
if(key in key_map)
|
26 |
+
{
|
27 |
+
e.preventDefault();
|
28 |
+
}
|
29 |
+
}
|
30 |
+
let app=document.querySelector("gradio-app");
|
31 |
+
app=app.shadowRoot??app;
|
32 |
+
let frame=app.querySelector("#sdinfframe").contentDocument;
|
33 |
+
frame.dispatchEvent(
|
34 |
+
new KeyboardEvent("keydown", {key: e.key, ctrlKey: e.ctrlKey})
|
35 |
+
);
|
36 |
+
}
|
37 |
+
})
|
js/mode.js
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
-
function(mode){
|
2 |
-
let app=document.querySelector("gradio-app");
|
3 |
-
let frame=app.querySelector("#sdinfframe").contentWindow;
|
4 |
-
frame.
|
5 |
-
return mode;
|
6 |
}
|
|
|
1 |
+
function(mode){
|
2 |
+
let app=document.querySelector("gradio-app").shadowRoot;
|
3 |
+
let frame=app.querySelector("#sdinfframe").contentWindow.document;
|
4 |
+
frame.querySelector("#mode").value=mode;
|
5 |
+
return mode;
|
6 |
}
|
js/outpaint.js
CHANGED
@@ -1,31 +1,23 @@
|
|
1 |
-
function(a){
|
2 |
-
if(!window.my_observe_outpaint)
|
3 |
-
{
|
4 |
-
console.log("setup outpaint here");
|
5 |
-
window.my_observe_outpaint = new MutationObserver(function (event) {
|
6 |
-
console.log(event);
|
7 |
-
let app=document.querySelector("gradio-app");
|
8 |
-
|
9 |
-
|
10 |
-
frame.postMessage(["outpaint",
|
11 |
-
});
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
document.querySelector("gradio-app").querySelector("#proceed").click();
|
24 |
-
}
|
25 |
-
});
|
26 |
-
}
|
27 |
-
let app=document.querySelector("gradio-app");
|
28 |
-
let frame=app.querySelector("#sdinfframe").contentWindow;
|
29 |
-
frame.postMessage(["transfer"],"*")
|
30 |
-
return a;
|
31 |
}
|
|
|
1 |
+
function(a){
|
2 |
+
if(!window.my_observe_outpaint)
|
3 |
+
{
|
4 |
+
console.log("setup outpaint here");
|
5 |
+
window.my_observe_outpaint = new MutationObserver(function (event) {
|
6 |
+
console.log(event);
|
7 |
+
let app=document.querySelector("gradio-app");
|
8 |
+
app=app.shadowRoot??app;
|
9 |
+
let frame=app.querySelector("#sdinfframe").contentWindow;
|
10 |
+
frame.postMessage(["outpaint", ""], "*");
|
11 |
+
});
|
12 |
+
var app=document.querySelector("gradio-app");
|
13 |
+
app=app.shadowRoot??app;
|
14 |
+
window.my_observe_outpaint_target=app.querySelector("#output span");
|
15 |
+
window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
|
16 |
+
attributes: false,
|
17 |
+
subtree: true,
|
18 |
+
childList: true,
|
19 |
+
characterData: true
|
20 |
+
});
|
21 |
+
}
|
22 |
+
return a;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
}
|
js/proceed.js
CHANGED
@@ -1,22 +1,42 @@
|
|
1 |
-
function(sel_buffer_str,
|
2 |
-
prompt_text,
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
}
|
|
|
1 |
+
function(sel_buffer_str,
|
2 |
+
prompt_text,
|
3 |
+
negative_prompt_text,
|
4 |
+
strength,
|
5 |
+
guidance,
|
6 |
+
step,
|
7 |
+
resize_check,
|
8 |
+
fill_mode,
|
9 |
+
enable_safety,
|
10 |
+
use_correction,
|
11 |
+
enable_img2img,
|
12 |
+
use_seed,
|
13 |
+
seed_val,
|
14 |
+
generate_num,
|
15 |
+
scheduler,
|
16 |
+
scheduler_eta,
|
17 |
+
state){
|
18 |
+
let app=document.querySelector("gradio-app");
|
19 |
+
app=app.shadowRoot??app;
|
20 |
+
sel_buffer=app.querySelector("#input textarea").value;
|
21 |
+
let use_correction_bak=false;
|
22 |
+
({resize_check,enable_safety,use_correction_bak,enable_img2img,use_seed,seed_val}=window.config_obj);
|
23 |
+
return [
|
24 |
+
sel_buffer,
|
25 |
+
prompt_text,
|
26 |
+
negative_prompt_text,
|
27 |
+
strength,
|
28 |
+
guidance,
|
29 |
+
step,
|
30 |
+
resize_check,
|
31 |
+
fill_mode,
|
32 |
+
enable_safety,
|
33 |
+
use_correction,
|
34 |
+
enable_img2img,
|
35 |
+
use_seed,
|
36 |
+
seed_val,
|
37 |
+
generate_num,
|
38 |
+
scheduler,
|
39 |
+
scheduler_eta,
|
40 |
+
state,
|
41 |
+
]
|
42 |
}
|
js/setup.js
CHANGED
@@ -1,22 +1,28 @@
|
|
1 |
-
function(token_val, width, height, size){
|
2 |
-
let app=document.querySelector("gradio-app");
|
3 |
-
app.
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
}
|
|
|
1 |
+
function(token_val, width, height, size, model_choice, model_path){
|
2 |
+
let app=document.querySelector("gradio-app");
|
3 |
+
app=app.shadowRoot??app;
|
4 |
+
app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
|
5 |
+
// app.querySelector("#setup_row").style.display="none";
|
6 |
+
app.querySelector("#model_path_input").style.display="none";
|
7 |
+
let frame=app.querySelector("#sdinfframe").contentWindow.document;
|
8 |
+
|
9 |
+
if(frame.querySelector("#setup").value=="0")
|
10 |
+
{
|
11 |
+
window.my_setup=setInterval(function(){
|
12 |
+
let app=document.querySelector("gradio-app");
|
13 |
+
app=app.shadowRoot??app;
|
14 |
+
let frame=app.querySelector("#sdinfframe").contentWindow.document;
|
15 |
+
console.log("Check PyScript...")
|
16 |
+
if(frame.querySelector("#setup").value=="1")
|
17 |
+
{
|
18 |
+
frame.querySelector("#draw").click();
|
19 |
+
clearInterval(window.my_setup);
|
20 |
+
}
|
21 |
+
}, 100)
|
22 |
+
}
|
23 |
+
else
|
24 |
+
{
|
25 |
+
frame.querySelector("#draw").click();
|
26 |
+
}
|
27 |
+
return [token_val, width, height, size, model_choice, model_path];
|
28 |
}
|
js/toolbar.js
ADDED
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://rawgit.com/vitmalina/w2ui/master/dist/w2ui.es6.min.js"
|
2 |
+
// import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://cdn.jsdelivr.net/gh/vitmalina/w2ui@master/dist/w2ui.es6.min.js"
|
3 |
+
|
4 |
+
// https://stackoverflow.com/questions/36280818/how-to-convert-file-to-base64-in-javascript
|
5 |
+
function getBase64(file) {
|
6 |
+
var reader = new FileReader();
|
7 |
+
reader.readAsDataURL(file);
|
8 |
+
reader.onload = function () {
|
9 |
+
add_image(reader.result);
|
10 |
+
// console.log(reader.result);
|
11 |
+
};
|
12 |
+
reader.onerror = function (error) {
|
13 |
+
console.log("Error: ", error);
|
14 |
+
};
|
15 |
+
}
|
16 |
+
|
17 |
+
function getText(file) {
|
18 |
+
var reader = new FileReader();
|
19 |
+
reader.readAsText(file);
|
20 |
+
reader.onload = function () {
|
21 |
+
window.postMessage(["load",reader.result],"*")
|
22 |
+
// console.log(reader.result);
|
23 |
+
};
|
24 |
+
reader.onerror = function (error) {
|
25 |
+
console.log("Error: ", error);
|
26 |
+
};
|
27 |
+
}
|
28 |
+
|
29 |
+
document.querySelector("#upload_file").addEventListener("change", (event)=>{
|
30 |
+
console.log(event);
|
31 |
+
let file = document.querySelector("#upload_file").files[0];
|
32 |
+
getBase64(file);
|
33 |
+
})
|
34 |
+
|
35 |
+
document.querySelector("#upload_state").addEventListener("change", (event)=>{
|
36 |
+
console.log(event);
|
37 |
+
let file = document.querySelector("#upload_state").files[0];
|
38 |
+
getText(file);
|
39 |
+
})
|
40 |
+
|
41 |
+
open_setting = function() {
|
42 |
+
if (!w2ui.foo) {
|
43 |
+
new w2form({
|
44 |
+
name: "foo",
|
45 |
+
style: "border: 0px; background-color: transparent;",
|
46 |
+
fields: [{
|
47 |
+
field: "canvas_width",
|
48 |
+
type: "int",
|
49 |
+
required: true,
|
50 |
+
html: {
|
51 |
+
label: "Canvas Width"
|
52 |
+
}
|
53 |
+
},
|
54 |
+
{
|
55 |
+
field: "canvas_height",
|
56 |
+
type: "int",
|
57 |
+
required: true,
|
58 |
+
html: {
|
59 |
+
label: "Canvas Height"
|
60 |
+
}
|
61 |
+
},
|
62 |
+
],
|
63 |
+
record: {
|
64 |
+
canvas_width: 1200,
|
65 |
+
canvas_height: 600,
|
66 |
+
},
|
67 |
+
actions: {
|
68 |
+
Save() {
|
69 |
+
this.validate();
|
70 |
+
let record = this.getCleanRecord();
|
71 |
+
window.postMessage(["resize",record.canvas_width,record.canvas_height],"*");
|
72 |
+
w2popup.close();
|
73 |
+
},
|
74 |
+
custom: {
|
75 |
+
text: "Cancel",
|
76 |
+
style: "text-transform: uppercase",
|
77 |
+
onClick(event) {
|
78 |
+
w2popup.close();
|
79 |
+
}
|
80 |
+
}
|
81 |
+
}
|
82 |
+
});
|
83 |
+
}
|
84 |
+
w2popup.open({
|
85 |
+
title: "Form in a Popup",
|
86 |
+
body: "<div id='form' style='width: 100%; height: 100%;''></div>",
|
87 |
+
style: "padding: 15px 0px 0px 0px",
|
88 |
+
width: 500,
|
89 |
+
height: 280,
|
90 |
+
showMax: true,
|
91 |
+
async onToggle(event) {
|
92 |
+
await event.complete
|
93 |
+
w2ui.foo.resize();
|
94 |
+
}
|
95 |
+
})
|
96 |
+
.then((event) => {
|
97 |
+
w2ui.foo.render("#form")
|
98 |
+
});
|
99 |
+
}
|
100 |
+
|
101 |
+
var button_lst=["clear", "load", "save", "export", "upload", "selection", "canvas", "eraser", "outpaint", "accept", "cancel", "retry", "prev", "current", "next", "eraser_size_btn", "eraser_size", "resize_selection", "scale", "zoom_in", "zoom_out", "help"];
|
102 |
+
var upload_button_lst=['clear', 'load', 'save', "upload", 'export', 'outpaint', 'resize_selection', 'help', "setting"];
|
103 |
+
var resize_button_lst=['clear', 'load', 'save', "upload", 'export', "selection", "canvas", "eraser", 'outpaint', 'resize_selection',"zoom_in", "zoom_out", 'help', "setting"];
|
104 |
+
var outpaint_button_lst=['clear', 'load', 'save', "canvas", "eraser", "upload", 'export', 'resize_selection', "zoom_in", "zoom_out",'help', "setting"];
|
105 |
+
var outpaint_result_lst=["accept", "cancel", "retry", "prev", "current", "next"];
|
106 |
+
var outpaint_result_func_lst=["accept", "retry", "prev", "current", "next"];
|
107 |
+
|
108 |
+
function check_button(id,text="",checked=true,tooltip="")
|
109 |
+
{
|
110 |
+
return { type: "check", id: id, text: text, icon: checked?"fa-solid fa-square-check":"fa-regular fa-square", checked: checked, tooltip: tooltip };
|
111 |
+
}
|
112 |
+
|
113 |
+
var toolbar=new w2toolbar({
|
114 |
+
box: "#toolbar",
|
115 |
+
name: "toolbar",
|
116 |
+
tooltip: "top",
|
117 |
+
items: [
|
118 |
+
{ type: "button", id: "clear", text: "Reset", tooltip: "Reset Canvas", icon: "fa-solid fa-rectangle-xmark" },
|
119 |
+
{ type: "break" },
|
120 |
+
{ type: "button", id: "load", tooltip: "Load Canvas", icon: "fa-solid fa-file-import" },
|
121 |
+
{ type: "button", id: "save", tooltip: "Save Canvas", icon: "fa-solid fa-file-export" },
|
122 |
+
{ type: "button", id: "export", tooltip: "Export Image", icon: "fa-solid fa-floppy-disk" },
|
123 |
+
{ type: "break" },
|
124 |
+
{ type: "button", id: "upload", text: "Upload Image", icon: "fa-solid fa-upload" },
|
125 |
+
{ type: "break" },
|
126 |
+
{ type: "radio", id: "selection", group: "1", tooltip: "Selection", icon: "fa-solid fa-arrows-up-down-left-right", checked: true },
|
127 |
+
{ type: "radio", id: "canvas", group: "1", tooltip: "Canvas", icon: "fa-solid fa-image" },
|
128 |
+
{ type: "radio", id: "eraser", group: "1", tooltip: "Eraser", icon: "fa-solid fa-eraser" },
|
129 |
+
{ type: "break" },
|
130 |
+
{ type: "button", id: "outpaint", text: "Outpaint", tooltip: "Run Outpainting", icon: "fa-solid fa-brush" },
|
131 |
+
{ type: "break" },
|
132 |
+
{ type: "button", id: "accept", text: "Accept", tooltip: "Accept current result", icon: "fa-solid fa-check", hidden: true, disable:true,},
|
133 |
+
{ type: "button", id: "cancel", text: "Cancel", tooltip: "Cancel current outpainting/error", icon: "fa-solid fa-ban", hidden: true},
|
134 |
+
{ type: "button", id: "retry", text: "Retry", tooltip: "Retry", icon: "fa-solid fa-rotate", hidden: true, disable:true,},
|
135 |
+
{ type: "button", id: "prev", tooltip: "Prev Result", icon: "fa-solid fa-caret-left", hidden: true, disable:true,},
|
136 |
+
{ type: "html", id: "current", hidden: true, disable:true,
|
137 |
+
async onRefresh(event) {
|
138 |
+
await event.complete
|
139 |
+
let fragment = query.html(`
|
140 |
+
<div class="w2ui-tb-text">
|
141 |
+
<div class="w2ui-tb-count">
|
142 |
+
<span>${this.sel_value ?? "1/1"}</span>
|
143 |
+
</div> </div>`)
|
144 |
+
query(this.box).find("#tb_toolbar_item_current").append(fragment)
|
145 |
+
}
|
146 |
+
},
|
147 |
+
{ type: "button", id: "next", tooltip: "Next Result", icon: "fa-solid fa-caret-right", hidden: true,disable:true,},
|
148 |
+
{ type: "button", id: "add_image", text: "Add Image", icon: "fa-solid fa-file-circle-plus", hidden: true,disable:true,},
|
149 |
+
{ type: "button", id: "delete_image", text: "Delete Image", icon: "fa-solid fa-trash-can", hidden: true,disable:true,},
|
150 |
+
{ type: "button", id: "confirm", text: "Confirm", icon: "fa-solid fa-check", hidden: true,disable:true,},
|
151 |
+
{ type: "button", id: "cancel_overlay", text: "Cancel", icon: "fa-solid fa-ban", hidden: true,disable:true,},
|
152 |
+
{ type: "break" },
|
153 |
+
{ type: "spacer" },
|
154 |
+
{ type: "break" },
|
155 |
+
{ type: "button", id: "eraser_size_btn", tooltip: "Eraser Size", text:"Size", icon: "fa-solid fa-eraser", hidden: true, count: 32},
|
156 |
+
{ type: "html", id: "eraser_size", hidden: true,
|
157 |
+
async onRefresh(event) {
|
158 |
+
await event.complete
|
159 |
+
// let fragment = query.html(`
|
160 |
+
// <input type="number" size="${this.eraser_size ? this.eraser_size.length:"2"}" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">
|
161 |
+
// <input type="range" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">`)
|
162 |
+
let fragment = query.html(`
|
163 |
+
<input type="range" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">
|
164 |
+
`)
|
165 |
+
fragment.filter("input").on("change", event => {
|
166 |
+
this.eraser_size = event.target.value;
|
167 |
+
window.overlay.freeDrawingBrush.width=this.eraser_size;
|
168 |
+
this.setCount("eraser_size_btn", event.target.value);
|
169 |
+
window.postMessage(["eraser_size", event.target.value],"*")
|
170 |
+
this.refresh();
|
171 |
+
})
|
172 |
+
query(this.box).find("#tb_toolbar_item_eraser_size").append(fragment)
|
173 |
+
}
|
174 |
+
},
|
175 |
+
// { type: "button", id: "resize_eraser", tooltip: "Resize Eraser", icon: "fa-solid fa-sliders" },
|
176 |
+
{ type: "button", id: "resize_selection", text: "Resize Selection", tooltip: "Resize Selection", icon: "fa-solid fa-expand" },
|
177 |
+
{ type: "break" },
|
178 |
+
{ type: "html", id: "scale",
|
179 |
+
async onRefresh(event) {
|
180 |
+
await event.complete
|
181 |
+
let fragment = query.html(`
|
182 |
+
<div class="">
|
183 |
+
<div style="padding: 4px; border: 1px solid silver">
|
184 |
+
<span>${this.scale_value ?? "100%"}</span>
|
185 |
+
</div></div>`)
|
186 |
+
query(this.box).find("#tb_toolbar_item_scale").append(fragment)
|
187 |
+
}
|
188 |
+
},
|
189 |
+
{ type: "button", id: "zoom_in", tooltip: "Zoom In", icon: "fa-solid fa-magnifying-glass-plus" },
|
190 |
+
{ type: "button", id: "zoom_out", tooltip: "Zoom Out", icon: "fa-solid fa-magnifying-glass-minus" },
|
191 |
+
{ type: "break" },
|
192 |
+
{ type: "button", id: "help", tooltip: "Help", icon: "fa-solid fa-circle-info" },
|
193 |
+
{ type: "new-line"},
|
194 |
+
{ type: "button", id: "setting", text: "Canvas Setting", tooltip: "Resize Canvas Here", icon: "fa-solid fa-sliders" },
|
195 |
+
{ type: "break" },
|
196 |
+
check_button("enable_img2img","Enable Img2Img",false),
|
197 |
+
// check_button("use_correction","Photometric Correction",false),
|
198 |
+
check_button("resize_check","Resize Small Input",true),
|
199 |
+
check_button("enable_safety","Enable Safety Checker",true),
|
200 |
+
check_button("square_selection","Square Selection Only",false),
|
201 |
+
{type: "break"},
|
202 |
+
check_button("use_seed","Use Seed:",false),
|
203 |
+
{ type: "html", id: "seed_val",
|
204 |
+
async onRefresh(event) {
|
205 |
+
await event.complete
|
206 |
+
let fragment = query.html(`
|
207 |
+
<input type="number" style="margin: 0px 3px; padding: 4px; width:100px;" value="${this.config_obj.seed_val ?? "0"}">`)
|
208 |
+
fragment.filter("input").on("change", event => {
|
209 |
+
this.config_obj.seed_val = event.target.value;
|
210 |
+
parent.config_obj=this.config_obj;
|
211 |
+
this.refresh();
|
212 |
+
})
|
213 |
+
query(this.box).find("#tb_toolbar_item_seed_val").append(fragment)
|
214 |
+
}
|
215 |
+
},
|
216 |
+
{ type: "button", id: "random_seed", tooltip: "Set a random seed", icon: "fa-solid fa-dice" },
|
217 |
+
],
|
218 |
+
onClick(event) {
|
219 |
+
switch(event.target){
|
220 |
+
case "setting":
|
221 |
+
open_setting();
|
222 |
+
break;
|
223 |
+
case "upload":
|
224 |
+
this.upload_mode=true
|
225 |
+
document.querySelector("#overlay_container").style.pointerEvents="auto";
|
226 |
+
this.click("canvas");
|
227 |
+
this.click("selection");
|
228 |
+
this.show("confirm","cancel_overlay","add_image","delete_image");
|
229 |
+
this.enable("confirm","cancel_overlay","add_image","delete_image");
|
230 |
+
this.disable(...upload_button_lst);
|
231 |
+
query("#upload_file").click();
|
232 |
+
if(this.upload_tip)
|
233 |
+
{
|
234 |
+
this.upload_tip=false;
|
235 |
+
w2utils.notify("Note that only visible images will be added to canvas",{timeout:10000,where:query("#container")})
|
236 |
+
}
|
237 |
+
break;
|
238 |
+
case "resize_selection":
|
239 |
+
this.resize_mode=true;
|
240 |
+
this.disable(...resize_button_lst);
|
241 |
+
this.enable("confirm","cancel_overlay");
|
242 |
+
this.show("confirm","cancel_overlay");
|
243 |
+
window.postMessage(["resize_selection",""],"*");
|
244 |
+
document.querySelector("#overlay_container").style.pointerEvents="auto";
|
245 |
+
break;
|
246 |
+
case "confirm":
|
247 |
+
if(this.upload_mode)
|
248 |
+
{
|
249 |
+
export_image();
|
250 |
+
}
|
251 |
+
else
|
252 |
+
{
|
253 |
+
let sel_box=this.selection_box;
|
254 |
+
window.postMessage(["resize_selection",sel_box.x,sel_box.y,sel_box.width,sel_box.height],"*");
|
255 |
+
}
|
256 |
+
case "cancel_overlay":
|
257 |
+
end_overlay();
|
258 |
+
this.hide("confirm","cancel_overlay","add_image","delete_image");
|
259 |
+
if(this.upload_mode){
|
260 |
+
this.enable(...upload_button_lst);
|
261 |
+
}
|
262 |
+
else
|
263 |
+
{
|
264 |
+
this.enable(...resize_button_lst);
|
265 |
+
window.postMessage(["resize_selection","",""],"*");
|
266 |
+
if(event.target=="cancel_overlay")
|
267 |
+
{
|
268 |
+
this.selection_box=this.selection_box_bak;
|
269 |
+
}
|
270 |
+
}
|
271 |
+
if(this.selection_box)
|
272 |
+
{
|
273 |
+
this.setCount("resize_selection",`${Math.floor(this.selection_box.width/8)*8}x${Math.floor(this.selection_box.height/8)*8}`);
|
274 |
+
}
|
275 |
+
this.disable("confirm","cancel_overlay","add_image","delete_image");
|
276 |
+
this.upload_mode=false;
|
277 |
+
this.resize_mode=false;
|
278 |
+
this.click("selection");
|
279 |
+
break;
|
280 |
+
case "add_image":
|
281 |
+
query("#upload_file").click();
|
282 |
+
break;
|
283 |
+
case "delete_image":
|
284 |
+
let active_obj = window.overlay.getActiveObject();
|
285 |
+
if(active_obj)
|
286 |
+
{
|
287 |
+
window.overlay.remove(active_obj);
|
288 |
+
window.overlay.renderAll();
|
289 |
+
}
|
290 |
+
else
|
291 |
+
{
|
292 |
+
w2utils.notify("You need to select an image first",{error:true,timeout:2000,where:query("#container")})
|
293 |
+
}
|
294 |
+
break;
|
295 |
+
case "load":
|
296 |
+
query("#upload_state").click();
|
297 |
+
this.selection_box=null;
|
298 |
+
this.setCount("resize_selection","");
|
299 |
+
break;
|
300 |
+
case "next":
|
301 |
+
case "prev":
|
302 |
+
window.postMessage(["outpaint", "", event.target], "*");
|
303 |
+
break;
|
304 |
+
case "outpaint":
|
305 |
+
this.click("selection");
|
306 |
+
this.disable(...outpaint_button_lst);
|
307 |
+
this.show(...outpaint_result_lst);
|
308 |
+
if(this.outpaint_tip)
|
309 |
+
{
|
310 |
+
this.outpaint_tip=false;
|
311 |
+
w2utils.notify("The canvas stays locked until you accept/cancel current outpainting",{timeout:10000,where:query("#container")})
|
312 |
+
}
|
313 |
+
document.querySelector("#container").style.pointerEvents="none";
|
314 |
+
case "retry":
|
315 |
+
this.disable(...outpaint_result_func_lst);
|
316 |
+
window.postMessage(["transfer",""],"*")
|
317 |
+
break;
|
318 |
+
case "accept":
|
319 |
+
case "cancel":
|
320 |
+
this.hide(...outpaint_result_lst);
|
321 |
+
this.disable(...outpaint_result_func_lst);
|
322 |
+
this.enable(...outpaint_button_lst);
|
323 |
+
document.querySelector("#container").style.pointerEvents="auto";
|
324 |
+
window.postMessage(["click", event.target],"*");
|
325 |
+
let app=parent.document.querySelector("gradio-app");
|
326 |
+
app=app.shadowRoot??app;
|
327 |
+
app.querySelector("#cancel").click();
|
328 |
+
break;
|
329 |
+
case "eraser":
|
330 |
+
case "selection":
|
331 |
+
case "canvas":
|
332 |
+
if(event.target=="eraser")
|
333 |
+
{
|
334 |
+
this.show("eraser_size","eraser_size_btn");
|
335 |
+
window.overlay.freeDrawingBrush.width=this.eraser_size;
|
336 |
+
window.overlay.isDrawingMode = true;
|
337 |
+
}
|
338 |
+
else
|
339 |
+
{
|
340 |
+
this.hide("eraser_size","eraser_size_btn");
|
341 |
+
window.overlay.isDrawingMode = false;
|
342 |
+
}
|
343 |
+
if(this.upload_mode)
|
344 |
+
{
|
345 |
+
if(event.target=="canvas")
|
346 |
+
{
|
347 |
+
window.postMessage(["mode", event.target],"*")
|
348 |
+
document.querySelector("#overlay_container").style.pointerEvents="none";
|
349 |
+
document.querySelector("#overlay_container").style.opacity = 0.5;
|
350 |
+
}
|
351 |
+
else
|
352 |
+
{
|
353 |
+
document.querySelector("#overlay_container").style.pointerEvents="auto";
|
354 |
+
document.querySelector("#overlay_container").style.opacity = 1.0;
|
355 |
+
}
|
356 |
+
}
|
357 |
+
else
|
358 |
+
{
|
359 |
+
window.postMessage(["mode", event.target],"*")
|
360 |
+
}
|
361 |
+
break;
|
362 |
+
case "help":
|
363 |
+
w2popup.open({
|
364 |
+
title: "Document",
|
365 |
+
body: "Usage: <a href='https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md' target='_blank'>https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md</a>"
|
366 |
+
})
|
367 |
+
break;
|
368 |
+
case "clear":
|
369 |
+
w2confirm("Reset canvas?").yes(() => {
|
370 |
+
window.postMessage(["click", event.target],"*");
|
371 |
+
}).no(() => {})
|
372 |
+
break;
|
373 |
+
case "random_seed":
|
374 |
+
this.config_obj.seed_val=Math.floor(Math.random() * 3000000000);
|
375 |
+
parent.config_obj=this.config_obj;
|
376 |
+
this.refresh();
|
377 |
+
break;
|
378 |
+
case "enable_img2img":
|
379 |
+
case "use_correction":
|
380 |
+
case "resize_check":
|
381 |
+
case "enable_safety":
|
382 |
+
case "use_seed":
|
383 |
+
case "square_selection":
|
384 |
+
let target=this.get(event.target);
|
385 |
+
target.icon=target.checked?"fa-regular fa-square":"fa-solid fa-square-check";
|
386 |
+
this.config_obj[event.target]=!target.checked;
|
387 |
+
parent.config_obj=this.config_obj;
|
388 |
+
this.refresh();
|
389 |
+
break;
|
390 |
+
case "save":
|
391 |
+
case "export":
|
392 |
+
ask_filename(event.target);
|
393 |
+
break;
|
394 |
+
default:
|
395 |
+
// clear, save, export, outpaint, retry
|
396 |
+
// break, save, export, accept, retry, outpaint
|
397 |
+
window.postMessage(["click", event.target],"*")
|
398 |
+
}
|
399 |
+
console.log("Target: "+ event.target, event)
|
400 |
+
}
|
401 |
+
})
|
402 |
+
window.w2ui=w2ui;
|
403 |
+
w2ui.toolbar.config_obj={
|
404 |
+
resize_check: true,
|
405 |
+
enable_safety: true,
|
406 |
+
use_correction: false,
|
407 |
+
enable_img2img: false,
|
408 |
+
use_seed: false,
|
409 |
+
seed_val: 0,
|
410 |
+
square_selection: false,
|
411 |
+
};
|
412 |
+
w2ui.toolbar.outpaint_tip=true;
|
413 |
+
w2ui.toolbar.upload_tip=true;
|
414 |
+
window.update_count=function(cur,total){
|
415 |
+
w2ui.toolbar.sel_value=`${cur}/${total}`;
|
416 |
+
w2ui.toolbar.refresh();
|
417 |
+
}
|
418 |
+
window.update_eraser=function(val,max_val){
|
419 |
+
w2ui.toolbar.eraser_size=`${val}`;
|
420 |
+
w2ui.toolbar.eraser_max=`${max_val}`;
|
421 |
+
w2ui.toolbar.setCount("eraser_size_btn", `${val}`);
|
422 |
+
w2ui.toolbar.refresh();
|
423 |
+
}
|
424 |
+
window.update_scale=function(val){
|
425 |
+
w2ui.toolbar.scale_value=`${val}`;
|
426 |
+
w2ui.toolbar.refresh();
|
427 |
+
}
|
428 |
+
window.enable_result_lst=function(){
|
429 |
+
w2ui.toolbar.enable(...outpaint_result_lst);
|
430 |
+
}
|
431 |
+
function onObjectScaled(e)
|
432 |
+
{
|
433 |
+
let object = e.target;
|
434 |
+
if(object.isType("rect"))
|
435 |
+
{
|
436 |
+
let width=object.getScaledWidth();
|
437 |
+
let height=object.getScaledHeight();
|
438 |
+
object.scale(1);
|
439 |
+
width=Math.max(Math.min(width,window.overlay.width-object.left),256);
|
440 |
+
height=Math.max(Math.min(height,window.overlay.height-object.top),256);
|
441 |
+
let l=Math.max(Math.min(object.left,window.overlay.width-width-object.strokeWidth),0);
|
442 |
+
let t=Math.max(Math.min(object.top,window.overlay.height-height-object.strokeWidth),0);
|
443 |
+
if(window.w2ui.toolbar.config_obj.square_selection)
|
444 |
+
{
|
445 |
+
let max_val = Math.min(Math.max(width,height),window.overlay.width,window.overlay.height);
|
446 |
+
width=max_val;
|
447 |
+
height=max_val;
|
448 |
+
}
|
449 |
+
object.set({ width: width, height: height, left:l,top:t})
|
450 |
+
window.w2ui.toolbar.selection_box={width: width, height: height, x:object.left, y:object.top};
|
451 |
+
window.w2ui.toolbar.setCount("resize_selection",`${Math.floor(width/8)*8}x${Math.floor(height/8)*8}`);
|
452 |
+
window.w2ui.toolbar.refresh();
|
453 |
+
}
|
454 |
+
}
|
455 |
+
function onObjectMoved(e)
|
456 |
+
{
|
457 |
+
let object = e.target;
|
458 |
+
if(object.isType("rect"))
|
459 |
+
{
|
460 |
+
let l=Math.max(Math.min(object.left,window.overlay.width-object.width-object.strokeWidth),0);
|
461 |
+
let t=Math.max(Math.min(object.top,window.overlay.height-object.height-object.strokeWidth),0);
|
462 |
+
object.set({left:l,top:t});
|
463 |
+
window.w2ui.toolbar.selection_box={width: object.width, height: object.height, x:object.left, y:object.top};
|
464 |
+
}
|
465 |
+
}
|
466 |
+
window.setup_overlay=function(width,height)
|
467 |
+
{
|
468 |
+
if(window.overlay)
|
469 |
+
{
|
470 |
+
window.overlay.setDimensions({width:width,height:height});
|
471 |
+
let app=parent.document.querySelector("gradio-app");
|
472 |
+
app=app.shadowRoot??app;
|
473 |
+
app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
|
474 |
+
document.querySelector("#container").style.height= height+"px";
|
475 |
+
document.querySelector("#container").style.width = width+"px";
|
476 |
+
}
|
477 |
+
else
|
478 |
+
{
|
479 |
+
canvas=new fabric.Canvas("overlay_canvas");
|
480 |
+
canvas.setDimensions({width:width,height:height});
|
481 |
+
let app=parent.document.querySelector("gradio-app");
|
482 |
+
app=app.shadowRoot??app;
|
483 |
+
app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
|
484 |
+
canvas.freeDrawingBrush = new fabric.EraserBrush(canvas);
|
485 |
+
canvas.on("object:scaling", onObjectScaled);
|
486 |
+
canvas.on("object:moving", onObjectMoved);
|
487 |
+
window.overlay=canvas;
|
488 |
+
}
|
489 |
+
document.querySelector("#overlay_container").style.pointerEvents="none";
|
490 |
+
}
|
491 |
+
window.update_overlay=function(width,height)
|
492 |
+
{
|
493 |
+
window.overlay.setDimensions({width:width,height:height},{backstoreOnly:true});
|
494 |
+
// document.querySelector("#overlay_container").style.pointerEvents="none";
|
495 |
+
}
|
496 |
+
window.adjust_selection=function(x,y,width,height)
|
497 |
+
{
|
498 |
+
var rect = new fabric.Rect({
|
499 |
+
left: x,
|
500 |
+
top: y,
|
501 |
+
fill: "rgba(0,0,0,0)",
|
502 |
+
strokeWidth: 3,
|
503 |
+
stroke: "rgba(0,0,0,0.7)",
|
504 |
+
cornerColor: "red",
|
505 |
+
cornerStrokeColor: "red",
|
506 |
+
borderColor: "rgba(255, 0, 0, 1.0)",
|
507 |
+
width: width,
|
508 |
+
height: height,
|
509 |
+
lockRotation: true,
|
510 |
+
});
|
511 |
+
rect.setControlsVisibility({ mtr: false });
|
512 |
+
window.overlay.add(rect);
|
513 |
+
window.overlay.setActiveObject(window.overlay.item(0));
|
514 |
+
window.w2ui.toolbar.selection_box={width: width, height: height, x:x, y:y};
|
515 |
+
window.w2ui.toolbar.selection_box_bak={width: width, height: height, x:x, y:y};
|
516 |
+
}
|
517 |
+
function add_image(url)
|
518 |
+
{
|
519 |
+
fabric.Image.fromURL(url,function(img){
|
520 |
+
window.overlay.add(img);
|
521 |
+
window.overlay.setActiveObject(img);
|
522 |
+
},{left:100,top:100});
|
523 |
+
}
|
524 |
+
function export_image()
|
525 |
+
{
|
526 |
+
data=window.overlay.toDataURL();
|
527 |
+
document.querySelector("#upload_content").value=data;
|
528 |
+
window.postMessage(["upload",""],"*");
|
529 |
+
end_overlay();
|
530 |
+
}
|
531 |
+
function end_overlay()
|
532 |
+
{
|
533 |
+
window.overlay.clear();
|
534 |
+
document.querySelector("#overlay_container").style.opacity = 1.0;
|
535 |
+
document.querySelector("#overlay_container").style.pointerEvents="none";
|
536 |
+
}
|
537 |
+
function ask_filename(target)
|
538 |
+
{
|
539 |
+
w2prompt({
|
540 |
+
label: "Enter filename",
|
541 |
+
value: `outpaint_${((new Date(Date.now() -(new Date()).getTimezoneOffset() * 60000))).toISOString().replace("T","_").replace(/[^0-9_]/g, "").substring(0,15)}`,
|
542 |
+
})
|
543 |
+
.change((event) => {
|
544 |
+
console.log("change", event.detail.originalEvent.target.value);
|
545 |
+
})
|
546 |
+
.ok((event) => {
|
547 |
+
console.log("value=", event.detail.value);
|
548 |
+
window.postMessage(["click",target,event.detail.value],"*");
|
549 |
+
})
|
550 |
+
.cancel((event) => {
|
551 |
+
console.log("cancel");
|
552 |
+
});
|
553 |
+
}
|
554 |
+
|
555 |
+
document.querySelector("#container").addEventListener("wheel",(e)=>{e.preventDefault()})
|
556 |
+
window.setup_shortcut=function(json)
|
557 |
+
{
|
558 |
+
var config=JSON.parse(json);
|
559 |
+
var key_map={};
|
560 |
+
Object.keys(config.shortcut).forEach(k=>{
|
561 |
+
key_map[config.shortcut[k]]=k;
|
562 |
+
})
|
563 |
+
document.addEventListener("keydown",(e)=>{
|
564 |
+
if(e.target.tagName!="INPUT")
|
565 |
+
{
|
566 |
+
let key=e.key;
|
567 |
+
if(e.ctrlKey)
|
568 |
+
{
|
569 |
+
key="Ctrl+"+e.key;
|
570 |
+
if(key in key_map)
|
571 |
+
{
|
572 |
+
e.preventDefault();
|
573 |
+
}
|
574 |
+
}
|
575 |
+
if(key in key_map)
|
576 |
+
{
|
577 |
+
w2ui.toolbar.click(key_map[key]);
|
578 |
+
}
|
579 |
+
}
|
580 |
+
})
|
581 |
+
}
|
js/upload.js
CHANGED
@@ -1,20 +1,19 @@
|
|
1 |
-
function(a,b){
|
2 |
-
if(!window.my_observe_upload)
|
3 |
-
{
|
4 |
-
console.log("setup upload here");
|
5 |
-
window.my_observe_upload = new MutationObserver(function (event) {
|
6 |
-
console.log(event);
|
7 |
-
var frame=document.querySelector("gradio-app").querySelector("#sdinfframe").contentWindow;
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
window.
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
return [a,b];
|
20 |
}
|
|
|
1 |
+
function(a,b){
|
2 |
+
if(!window.my_observe_upload)
|
3 |
+
{
|
4 |
+
console.log("setup upload here");
|
5 |
+
window.my_observe_upload = new MutationObserver(function (event) {
|
6 |
+
console.log(event);
|
7 |
+
var frame=document.querySelector("gradio-app").shadowRoot.querySelector("#sdinfframe").contentWindow.document;
|
8 |
+
frame.querySelector("#upload").click();
|
9 |
+
});
|
10 |
+
window.my_observe_upload_target = document.querySelector("gradio-app").shadowRoot.querySelector("#upload span");
|
11 |
+
window.my_observe_upload.observe(window.my_observe_upload_target, {
|
12 |
+
attributes: false,
|
13 |
+
subtree: true,
|
14 |
+
childList: true,
|
15 |
+
characterData: true
|
16 |
+
});
|
17 |
+
}
|
18 |
+
return [a,b];
|
|
|
19 |
}
|
js/w2ui.min.js
ADDED
The diff for this file is too large to render.
See raw diff
|
|
js/xss.js
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
var setup_outpaint=function(){
|
2 |
+
if(!window.my_observe_outpaint)
|
3 |
+
{
|
4 |
+
console.log("setup outpaint here");
|
5 |
+
window.my_observe_outpaint = new MutationObserver(function (event) {
|
6 |
+
console.log(event);
|
7 |
+
let app=document.querySelector("gradio-app");
|
8 |
+
app=app.shadowRoot??app;
|
9 |
+
let frame=app.querySelector("#sdinfframe").contentWindow;
|
10 |
+
frame.postMessage(["outpaint", ""], "*");
|
11 |
+
});
|
12 |
+
var app=document.querySelector("gradio-app");
|
13 |
+
app=app.shadowRoot??app;
|
14 |
+
window.my_observe_outpaint_target=app.querySelector("#output span");
|
15 |
+
window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
|
16 |
+
attributes: false,
|
17 |
+
subtree: true,
|
18 |
+
childList: true,
|
19 |
+
characterData: true
|
20 |
+
});
|
21 |
+
}
|
22 |
+
};
|
23 |
+
window.config_obj={
|
24 |
+
resize_check: true,
|
25 |
+
enable_safety: true,
|
26 |
+
use_correction: false,
|
27 |
+
enable_img2img: false,
|
28 |
+
use_seed: false,
|
29 |
+
seed_val: 0,
|
30 |
+
};
|
31 |
+
setup_outpaint();
|
models/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
4 |
+
params:
|
5 |
+
linear_start: 0.00085
|
6 |
+
linear_end: 0.0120
|
7 |
+
num_timesteps_cond: 1
|
8 |
+
log_every_t: 200
|
9 |
+
timesteps: 1000
|
10 |
+
first_stage_key: "jpg"
|
11 |
+
cond_stage_key: "txt"
|
12 |
+
image_size: 64
|
13 |
+
channels: 4
|
14 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
15 |
+
conditioning_key: crossattn
|
16 |
+
monitor: val/loss_simple_ema
|
17 |
+
scale_factor: 0.18215
|
18 |
+
use_ema: False
|
19 |
+
|
20 |
+
scheduler_config: # 10000 warmup steps
|
21 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
22 |
+
params:
|
23 |
+
warm_up_steps: [ 10000 ]
|
24 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
25 |
+
f_start: [ 1.e-6 ]
|
26 |
+
f_max: [ 1. ]
|
27 |
+
f_min: [ 1. ]
|
28 |
+
|
29 |
+
unet_config:
|
30 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
31 |
+
params:
|
32 |
+
image_size: 32 # unused
|
33 |
+
in_channels: 4
|
34 |
+
out_channels: 4
|
35 |
+
model_channels: 320
|
36 |
+
attention_resolutions: [ 4, 2, 1 ]
|
37 |
+
num_res_blocks: 2
|
38 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
39 |
+
num_heads: 8
|
40 |
+
use_spatial_transformer: True
|
41 |
+
transformer_depth: 1
|
42 |
+
context_dim: 768
|
43 |
+
use_checkpoint: True
|
44 |
+
legacy: False
|
45 |
+
|
46 |
+
first_stage_config:
|
47 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
48 |
+
params:
|
49 |
+
embed_dim: 4
|
50 |
+
monitor: val/rec_loss
|
51 |
+
ddconfig:
|
52 |
+
double_z: true
|
53 |
+
z_channels: 4
|
54 |
+
resolution: 256
|
55 |
+
in_channels: 3
|
56 |
+
out_ch: 3
|
57 |
+
ch: 128
|
58 |
+
ch_mult:
|
59 |
+
- 1
|
60 |
+
- 2
|
61 |
+
- 4
|
62 |
+
- 4
|
63 |
+
num_res_blocks: 2
|
64 |
+
attn_resolutions: []
|
65 |
+
dropout: 0.0
|
66 |
+
lossconfig:
|
67 |
+
target: torch.nn.Identity
|
68 |
+
|
69 |
+
cond_stage_config:
|
70 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
models/v1-inpainting-inference.yaml
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 7.5e-05
|
3 |
+
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
4 |
+
params:
|
5 |
+
linear_start: 0.00085
|
6 |
+
linear_end: 0.0120
|
7 |
+
num_timesteps_cond: 1
|
8 |
+
log_every_t: 200
|
9 |
+
timesteps: 1000
|
10 |
+
first_stage_key: "jpg"
|
11 |
+
cond_stage_key: "txt"
|
12 |
+
image_size: 64
|
13 |
+
channels: 4
|
14 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
15 |
+
conditioning_key: hybrid # important
|
16 |
+
monitor: val/loss_simple_ema
|
17 |
+
scale_factor: 0.18215
|
18 |
+
finetune_keys: null
|
19 |
+
|
20 |
+
scheduler_config: # 10000 warmup steps
|
21 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
22 |
+
params:
|
23 |
+
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
|
24 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
25 |
+
f_start: [ 1.e-6 ]
|
26 |
+
f_max: [ 1. ]
|
27 |
+
f_min: [ 1. ]
|
28 |
+
|
29 |
+
unet_config:
|
30 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
31 |
+
params:
|
32 |
+
image_size: 32 # unused
|
33 |
+
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
|
34 |
+
out_channels: 4
|
35 |
+
model_channels: 320
|
36 |
+
attention_resolutions: [ 4, 2, 1 ]
|
37 |
+
num_res_blocks: 2
|
38 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
39 |
+
num_heads: 8
|
40 |
+
use_spatial_transformer: True
|
41 |
+
transformer_depth: 1
|
42 |
+
context_dim: 768
|
43 |
+
use_checkpoint: True
|
44 |
+
legacy: False
|
45 |
+
|
46 |
+
first_stage_config:
|
47 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
48 |
+
params:
|
49 |
+
embed_dim: 4
|
50 |
+
monitor: val/rec_loss
|
51 |
+
ddconfig:
|
52 |
+
double_z: true
|
53 |
+
z_channels: 4
|
54 |
+
resolution: 256
|
55 |
+
in_channels: 3
|
56 |
+
out_ch: 3
|
57 |
+
ch: 128
|
58 |
+
ch_mult:
|
59 |
+
- 1
|
60 |
+
- 2
|
61 |
+
- 4
|
62 |
+
- 4
|
63 |
+
num_res_blocks: 2
|
64 |
+
attn_resolutions: []
|
65 |
+
dropout: 0.0
|
66 |
+
lossconfig:
|
67 |
+
target: torch.nn.Identity
|
68 |
+
|
69 |
+
cond_stage_config:
|
70 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|