Spaces:
Runtime error
Runtime error
add
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE.md +159 -0
- app.py +39 -0
- configs/mixtrain/seg_rec_poly_fuse_feature.yaml +97 -0
- configs/pretrain/seg_rec_poly_fuse_feature.yaml +94 -0
- evaluation/icdar2015/e2e/prepare_results.py +263 -0
- evaluation/icdar2015/e2e/rrc_evaluation_funcs.py +369 -0
- evaluation/icdar2015/e2e/script.py +461 -0
- evaluation/icdar2015/gt.zip +0 -0
- evaluation/rotated_icdar2013/e2e/prepare_results.py +267 -0
- evaluation/rotated_icdar2013/e2e/rrc_evaluation_funcs.py +369 -0
- evaluation/rotated_icdar2013/e2e/script.py +460 -0
- evaluation/rotated_icdar2013/gt/gt.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_-15.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_-30.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_-45.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_-60.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_-75.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_-90.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_0.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_15.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_30.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_45.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_60.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_75.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_85.zip +0 -0
- evaluation/rotated_icdar2013/gt/gt_90.zip +0 -0
- evaluation/totaltext/e2e/prepare_results.py +234 -0
- evaluation/totaltext/e2e/rrc_evaluation_funcs.py +369 -0
- evaluation/totaltext/e2e/rrc_evaluation_funcs_total_text.py +363 -0
- evaluation/totaltext/e2e/script.py +452 -0
- evaluation/totaltext/gt.zip +0 -0
- evaluation/weighted_editdistance.py +55 -0
- example1.jpg +0 -0
- example2.jpg +0 -0
- example3.jpg +0 -0
- maskrcnn_benchmark/config/__init__.py +2 -0
- maskrcnn_benchmark/config/defaults.py +373 -0
- maskrcnn_benchmark/config/paths_catalog.py +237 -0
- maskrcnn_benchmark/csrc/ROIAlign.h +46 -0
- maskrcnn_benchmark/csrc/ROIPool.h +48 -0
- maskrcnn_benchmark/csrc/SigmoidFocalLoss.h +41 -0
- maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp +257 -0
- maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp +75 -0
- maskrcnn_benchmark/csrc/cpu/vision.h +16 -0
- maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu +346 -0
- maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu +202 -0
- maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu +189 -0
- maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu +691 -0
- maskrcnn_benchmark/csrc/cuda/deform_conv_kernel_cuda.cu +874 -0
- maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu +87 -0
LICENSE.md
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Creative Commons Attribution-NonCommercial 4.0 International
|
2 |
+
|
3 |
+
Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
|
4 |
+
|
5 |
+
### Using Creative Commons Public Licenses
|
6 |
+
|
7 |
+
Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
|
8 |
+
|
9 |
+
* __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
|
10 |
+
|
11 |
+
* __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
|
12 |
+
|
13 |
+
## Creative Commons Attribution-NonCommercial 4.0 International Public License
|
14 |
+
|
15 |
+
By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
|
16 |
+
|
17 |
+
### Section 1 – Definitions.
|
18 |
+
|
19 |
+
a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
|
20 |
+
|
21 |
+
b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
|
22 |
+
|
23 |
+
c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
|
24 |
+
|
25 |
+
d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
|
26 |
+
|
27 |
+
e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
|
28 |
+
|
29 |
+
f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
|
30 |
+
|
31 |
+
g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
|
32 |
+
|
33 |
+
h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
|
34 |
+
|
35 |
+
i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
|
36 |
+
|
37 |
+
j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
|
38 |
+
|
39 |
+
k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
|
40 |
+
|
41 |
+
l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. __Your__ has a corresponding meaning.
|
42 |
+
|
43 |
+
### Section 2 – Scope.
|
44 |
+
|
45 |
+
a. ___License grant.___
|
46 |
+
|
47 |
+
1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
|
48 |
+
|
49 |
+
A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
|
50 |
+
|
51 |
+
B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
|
52 |
+
|
53 |
+
2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
|
54 |
+
|
55 |
+
3. __Term.__ The term of this Public License is specified in Section 6(a).
|
56 |
+
|
57 |
+
4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
|
58 |
+
|
59 |
+
5. __Downstream recipients.__
|
60 |
+
|
61 |
+
A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
|
62 |
+
|
63 |
+
B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
|
64 |
+
|
65 |
+
6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
|
66 |
+
|
67 |
+
b. ___Other rights.___
|
68 |
+
|
69 |
+
1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
|
70 |
+
|
71 |
+
2. Patent and trademark rights are not licensed under this Public License.
|
72 |
+
|
73 |
+
3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
|
74 |
+
|
75 |
+
### Section 3 – License Conditions.
|
76 |
+
|
77 |
+
Your exercise of the Licensed Rights is expressly made subject to the following conditions.
|
78 |
+
|
79 |
+
a. ___Attribution.___
|
80 |
+
|
81 |
+
1. If You Share the Licensed Material (including in modified form), You must:
|
82 |
+
|
83 |
+
A. retain the following if it is supplied by the Licensor with the Licensed Material:
|
84 |
+
|
85 |
+
i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
|
86 |
+
|
87 |
+
ii. a copyright notice;
|
88 |
+
|
89 |
+
iii. a notice that refers to this Public License;
|
90 |
+
|
91 |
+
iv. a notice that refers to the disclaimer of warranties;
|
92 |
+
|
93 |
+
v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
|
94 |
+
|
95 |
+
B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
|
96 |
+
|
97 |
+
C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
|
98 |
+
|
99 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
|
100 |
+
|
101 |
+
3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
|
102 |
+
|
103 |
+
4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
|
104 |
+
|
105 |
+
### Section 4 – Sui Generis Database Rights.
|
106 |
+
|
107 |
+
Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
|
108 |
+
|
109 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
|
110 |
+
|
111 |
+
b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
|
112 |
+
|
113 |
+
c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
|
114 |
+
|
115 |
+
For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
|
116 |
+
|
117 |
+
### Section 5 – Disclaimer of Warranties and Limitation of Liability.
|
118 |
+
|
119 |
+
a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
|
120 |
+
|
121 |
+
b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
|
122 |
+
|
123 |
+
c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
|
124 |
+
|
125 |
+
### Section 6 – Term and Termination.
|
126 |
+
|
127 |
+
a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
|
128 |
+
|
129 |
+
b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
|
130 |
+
|
131 |
+
1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
|
132 |
+
|
133 |
+
2. upon express reinstatement by the Licensor.
|
134 |
+
|
135 |
+
For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
|
136 |
+
|
137 |
+
c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
|
138 |
+
|
139 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
|
140 |
+
|
141 |
+
### Section 7 – Other Terms and Conditions.
|
142 |
+
|
143 |
+
a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
|
144 |
+
|
145 |
+
b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
|
146 |
+
|
147 |
+
### Section 8 – Interpretation.
|
148 |
+
|
149 |
+
a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
|
150 |
+
|
151 |
+
b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
|
152 |
+
|
153 |
+
c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
|
154 |
+
|
155 |
+
d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
|
156 |
+
|
157 |
+
> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
|
158 |
+
>
|
159 |
+
> Creative Commons may be contacted at creativecommons.org
|
app.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.system('python setup.py build develop')
|
3 |
+
os.system('pip install --upgrade --no-cache-dir gdown')
|
4 |
+
os.system('gdown -O output/mixtrain/ 1XQsikiNY7ILgZvmvOeUf9oPDG4fTp0zs')
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import pandas as pd
|
8 |
+
import gradio as gr
|
9 |
+
from tools.demo import TextDemo
|
10 |
+
from maskrcnn_benchmark.config import cfg
|
11 |
+
|
12 |
+
|
13 |
+
def infer(filepath):
|
14 |
+
cfg.merge_from_file('configs/mixtrain/seg_rec_poly_fuse_feature.yaml')
|
15 |
+
# manual override some options
|
16 |
+
cfg.merge_from_list(["MODEL.DEVICE", "cpu"])
|
17 |
+
|
18 |
+
text_demo = TextDemo(
|
19 |
+
cfg,
|
20 |
+
min_image_size=800,
|
21 |
+
confidence_threshold=0.7,
|
22 |
+
output_polygon=True
|
23 |
+
)
|
24 |
+
image = cv2.imread(filepath)
|
25 |
+
result_polygons, result_words = text_demo.run_on_opencv_image(image)
|
26 |
+
text_demo.visualization(image, result_polygons, result_words)
|
27 |
+
cv2.imwrite('result.jpg', image)
|
28 |
+
return 'result.jpg', pd.DataFrame(result_words)
|
29 |
+
|
30 |
+
|
31 |
+
iface = gr.Interface(
|
32 |
+
fn=infer,
|
33 |
+
title="Mask TextSpotter v3",
|
34 |
+
description="Mask TextSpotter v3 is an end-to-end trainable scene text spotter that adopts a Segmentation Proposal Network (SPN) instead of an RPN. Mask TextSpotter v3 significantly improves robustness to rotations, aspect ratios, and shapes.",
|
35 |
+
inputs=[gr.inputs.Image(label="image", type="filepath")],
|
36 |
+
outputs=[gr.outputs.Image(), gr.outputs.Dataframe(headers=['word'])],
|
37 |
+
examples=['example1.jpg', 'example2.jpg', 'example3.jpg'],
|
38 |
+
article="<a href=\"https://github.com/MhLiao/MaskTextSpotterV3\">GitHub Repo</a>",
|
39 |
+
).launch(enable_queue=True, cache_examples=True)
|
configs/mixtrain/seg_rec_poly_fuse_feature.yaml
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL:
|
2 |
+
META_ARCHITECTURE: "GeneralizedRCNN"
|
3 |
+
# WEIGHT: './output/path-to-pretrain-model' # for training
|
4 |
+
WEIGHT: './output/mixtrain/trained_model.pth' # for testing
|
5 |
+
BACKBONE:
|
6 |
+
CONV_BODY: "R-50-FPN"
|
7 |
+
OUT_CHANNELS: 256
|
8 |
+
RESNETS:
|
9 |
+
BACKBONE_OUT_CHANNELS: 256
|
10 |
+
RPN:
|
11 |
+
USE_FPN: True
|
12 |
+
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
|
13 |
+
PRE_NMS_TOP_N_TRAIN: 2000
|
14 |
+
PRE_NMS_TOP_N_TEST: 1000
|
15 |
+
POST_NMS_TOP_N_TEST: 1000
|
16 |
+
FPN_POST_NMS_TOP_N_TEST: 1000
|
17 |
+
SEG:
|
18 |
+
USE_FPN: True
|
19 |
+
USE_FUSE_FEATURE: True
|
20 |
+
TOP_N_TRAIN: 1000
|
21 |
+
TOP_N_TEST: 1000
|
22 |
+
BINARY_THRESH: 0.1
|
23 |
+
BOX_THRESH: 0.1
|
24 |
+
MIN_SIZE: 5
|
25 |
+
SHRINK_RATIO: 0.4
|
26 |
+
EXPAND_RATIO: 3.0
|
27 |
+
ROI_HEADS:
|
28 |
+
USE_FPN: True
|
29 |
+
BATCH_SIZE_PER_IMAGE: 512
|
30 |
+
ROI_BOX_HEAD:
|
31 |
+
POOLER_RESOLUTION: 7
|
32 |
+
POOLER_SCALES: (0.25,)
|
33 |
+
POOLER_SAMPLING_RATIO: 2
|
34 |
+
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
|
35 |
+
PREDICTOR: "FPNPredictor"
|
36 |
+
NUM_CLASSES: 2
|
37 |
+
USE_MASKED_FEATURE: True
|
38 |
+
ROI_MASK_HEAD:
|
39 |
+
POOLER_SCALES: (0.25,)
|
40 |
+
FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
|
41 |
+
PREDICTOR: "SeqCharMaskRCNNC4Predictor"
|
42 |
+
POOLER_RESOLUTION: 14
|
43 |
+
POOLER_RESOLUTION_H: 32
|
44 |
+
POOLER_RESOLUTION_W: 32
|
45 |
+
POOLER_SAMPLING_RATIO: 2
|
46 |
+
RESOLUTION: 28
|
47 |
+
RESOLUTION_H: 64
|
48 |
+
RESOLUTION_W: 64
|
49 |
+
SHARE_BOX_FEATURE_EXTRACTOR: False
|
50 |
+
CHAR_NUM_CLASSES: 37
|
51 |
+
USE_WEIGHTED_CHAR_MASK: True
|
52 |
+
MASK_BATCH_SIZE_PER_IM: 64
|
53 |
+
USE_MASKED_FEATURE: True
|
54 |
+
MASK_ON: True
|
55 |
+
CHAR_MASK_ON: True
|
56 |
+
SEG_ON: True
|
57 |
+
# TRAIN_DETECTION_ONLY: True
|
58 |
+
SEQUENCE:
|
59 |
+
SEQ_ON: True
|
60 |
+
NUM_CHAR: 38
|
61 |
+
BOS_TOKEN: 0
|
62 |
+
MAX_LENGTH: 32
|
63 |
+
TEACHER_FORCE_RATIO: 1.0
|
64 |
+
DATASETS:
|
65 |
+
# TRAIN: ("synthtext_train",)
|
66 |
+
TRAIN: ("synthtext_train","icdar_2013_train","icdar_2015_train","scut-eng-char_train","total_text_train")
|
67 |
+
RATIOS: [0.25,0.25,0.25,0.125,0.125]
|
68 |
+
# TEST: ("icdar_2015_test",)
|
69 |
+
TEST: ("total_text_test",)
|
70 |
+
# TEST: ("rotated_ic13_test_45",)
|
71 |
+
AUG: True
|
72 |
+
IGNORE_DIFFICULT: True
|
73 |
+
MAX_ROTATE_THETA: 90
|
74 |
+
DATALOADER:
|
75 |
+
SIZE_DIVISIBILITY: 32
|
76 |
+
NUM_WORKERS: 4
|
77 |
+
ASPECT_RATIO_GROUPING: False
|
78 |
+
SOLVER:
|
79 |
+
BASE_LR: 0.002 #0.02
|
80 |
+
WARMUP_FACTOR: 0.1
|
81 |
+
WEIGHT_DECAY: 0.0001
|
82 |
+
STEPS: (100000, 160000)
|
83 |
+
MAX_ITER: 300000
|
84 |
+
IMS_PER_BATCH: 8
|
85 |
+
RESUME: False
|
86 |
+
DISPLAY_FREQ: 20
|
87 |
+
OUTPUT_DIR: "./output/mixtrain"
|
88 |
+
TEST:
|
89 |
+
VIS: True
|
90 |
+
CHAR_THRESH: 192
|
91 |
+
IMS_PER_BATCH: 1
|
92 |
+
INPUT:
|
93 |
+
MIN_SIZE_TRAIN: (800, 1000, 1200, 1400)
|
94 |
+
MAX_SIZE_TRAIN: 2333
|
95 |
+
MIN_SIZE_TEST: 1000
|
96 |
+
# MIN_SIZE_TEST: 1440
|
97 |
+
MAX_SIZE_TEST: 4000
|
configs/pretrain/seg_rec_poly_fuse_feature.yaml
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL:
|
2 |
+
META_ARCHITECTURE: "GeneralizedRCNN"
|
3 |
+
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
|
4 |
+
BACKBONE:
|
5 |
+
CONV_BODY: "R-50-FPN"
|
6 |
+
OUT_CHANNELS: 256
|
7 |
+
RESNETS:
|
8 |
+
BACKBONE_OUT_CHANNELS: 256
|
9 |
+
RPN:
|
10 |
+
USE_FPN: True
|
11 |
+
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
|
12 |
+
PRE_NMS_TOP_N_TRAIN: 2000
|
13 |
+
PRE_NMS_TOP_N_TEST: 1000
|
14 |
+
POST_NMS_TOP_N_TEST: 1000
|
15 |
+
FPN_POST_NMS_TOP_N_TEST: 1000
|
16 |
+
SEG:
|
17 |
+
USE_FPN: True
|
18 |
+
USE_FUSE_FEATURE: True
|
19 |
+
TOP_N_TRAIN: 1000
|
20 |
+
TOP_N_TEST: 1000
|
21 |
+
BINARY_THRESH: 0.1
|
22 |
+
BOX_THRESH: 0.1
|
23 |
+
MIN_SIZE: 5
|
24 |
+
SHRINK_RATIO: 0.4
|
25 |
+
EXPAND_RATIO: 3.0
|
26 |
+
ROI_HEADS:
|
27 |
+
USE_FPN: True
|
28 |
+
BATCH_SIZE_PER_IMAGE: 512
|
29 |
+
ROI_BOX_HEAD:
|
30 |
+
POOLER_RESOLUTION: 7
|
31 |
+
POOLER_SCALES: (0.25,)
|
32 |
+
POOLER_SAMPLING_RATIO: 2
|
33 |
+
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
|
34 |
+
PREDICTOR: "FPNPredictor"
|
35 |
+
NUM_CLASSES: 2
|
36 |
+
USE_MASKED_FEATURE: True
|
37 |
+
ROI_MASK_HEAD:
|
38 |
+
POOLER_SCALES: (0.25,)
|
39 |
+
FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
|
40 |
+
PREDICTOR: "SeqCharMaskRCNNC4Predictor"
|
41 |
+
POOLER_RESOLUTION: 14
|
42 |
+
POOLER_RESOLUTION_H: 32
|
43 |
+
POOLER_RESOLUTION_W: 32
|
44 |
+
POOLER_SAMPLING_RATIO: 2
|
45 |
+
RESOLUTION: 28
|
46 |
+
RESOLUTION_H: 64
|
47 |
+
RESOLUTION_W: 64
|
48 |
+
SHARE_BOX_FEATURE_EXTRACTOR: False
|
49 |
+
CHAR_NUM_CLASSES: 37
|
50 |
+
USE_WEIGHTED_CHAR_MASK: True
|
51 |
+
MASK_BATCH_SIZE_PER_IM: 64
|
52 |
+
USE_MASKED_FEATURE: True
|
53 |
+
MASK_ON: True
|
54 |
+
CHAR_MASK_ON: True
|
55 |
+
SEG_ON: True
|
56 |
+
SEQUENCE:
|
57 |
+
SEQ_ON: True
|
58 |
+
NUM_CHAR: 38
|
59 |
+
BOS_TOKEN: 0
|
60 |
+
MAX_LENGTH: 32
|
61 |
+
TEACHER_FORCE_RATIO: 1.0
|
62 |
+
DATASETS:
|
63 |
+
TRAIN: ("synthtext_train",)
|
64 |
+
# TRAIN: ("synthtext_train","icdar_2013_train","icdar_2015_train","scut-eng-char_train","total_text_train")
|
65 |
+
# RATIOS: [0.25,0.25,0.25,0.125,0.125]
|
66 |
+
TEST: ("icdar_2015_test",)
|
67 |
+
# TEST: ("total_text_test",)
|
68 |
+
AUG: True
|
69 |
+
IGNORE_DIFFICULT: True
|
70 |
+
MAX_ROTATE_THETA: 90
|
71 |
+
DATALOADER:
|
72 |
+
SIZE_DIVISIBILITY: 32
|
73 |
+
NUM_WORKERS: 4
|
74 |
+
ASPECT_RATIO_GROUPING: False
|
75 |
+
SOLVER:
|
76 |
+
BASE_LR: 0.02 #0.02
|
77 |
+
WARMUP_FACTOR: 0.1
|
78 |
+
WEIGHT_DECAY: 0.0001
|
79 |
+
STEPS: (100000, 200000)
|
80 |
+
MAX_ITER: 300000
|
81 |
+
IMS_PER_BATCH: 8
|
82 |
+
RESUME: True
|
83 |
+
DISPLAY_FREQ: 20
|
84 |
+
OUTPUT_DIR: "./output/pretrain"
|
85 |
+
TEST:
|
86 |
+
VIS: False
|
87 |
+
CHAR_THRESH: 192
|
88 |
+
IMS_PER_BATCH: 1
|
89 |
+
INPUT:
|
90 |
+
MIN_SIZE_TRAIN: (600, 800)
|
91 |
+
# MIN_SIZE_TRAIN: (800, 1000, 1200, 1400)
|
92 |
+
MAX_SIZE_TRAIN: 2333
|
93 |
+
MIN_SIZE_TEST: 1440
|
94 |
+
MAX_SIZE_TEST: 4000
|
evaluation/icdar2015/e2e/prepare_results.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
sys.path.append('./')
|
6 |
+
import shapely
|
7 |
+
from shapely.geometry import Polygon,MultiPoint
|
8 |
+
import numpy as np
|
9 |
+
import editdistance
|
10 |
+
sys.path.append('../../')
|
11 |
+
from weighted_editdistance import weighted_edit_distance
|
12 |
+
from tqdm import tqdm
|
13 |
+
try:
|
14 |
+
import pickle
|
15 |
+
except ImportError:
|
16 |
+
import cPickle as pickle
|
17 |
+
|
18 |
+
def list_from_str(st):
|
19 |
+
line = st.split(',')
|
20 |
+
# box[0:4], polygon[4:12], word, seq_word, detection_score, rec_socre, seq_score, char_score_path
|
21 |
+
new_line = [float(a) for a in line[4:12]]+[float(line[-4])]+[line[-5]]+[line[-6]]+[float(line[-3])]+[float(line[-2])] + [line[-1]]
|
22 |
+
return new_line
|
23 |
+
|
24 |
+
def polygon_from_list(line):
|
25 |
+
"""
|
26 |
+
Create a shapely polygon object from gt or dt line.
|
27 |
+
"""
|
28 |
+
polygon_points = np.array(line).reshape(4, 2)
|
29 |
+
polygon = Polygon(polygon_points).convex_hull
|
30 |
+
return polygon
|
31 |
+
|
32 |
+
def polygon_iou(list1, list2):
|
33 |
+
"""
|
34 |
+
Intersection over union between two shapely polygons.
|
35 |
+
"""
|
36 |
+
polygon_points1 = np.array(list1).reshape(4, 2)
|
37 |
+
poly1 = Polygon(polygon_points1).convex_hull
|
38 |
+
polygon_points2 = np.array(list2).reshape(4, 2)
|
39 |
+
poly2 = Polygon(polygon_points2).convex_hull
|
40 |
+
union_poly = np.concatenate((polygon_points1,polygon_points2))
|
41 |
+
if not poly1.intersects(poly2): # this test is fast and can accelerate calculation
|
42 |
+
iou = 0
|
43 |
+
else:
|
44 |
+
try:
|
45 |
+
inter_area = poly1.intersection(poly2).area
|
46 |
+
#union_area = poly1.area + poly2.area - inter_area
|
47 |
+
union_area = MultiPoint(union_poly).convex_hull.area
|
48 |
+
iou = float(inter_area) / (union_area+1e-6)
|
49 |
+
except shapely.geos.TopologicalError:
|
50 |
+
print('shapely.geos.TopologicalError occured, iou set to 0')
|
51 |
+
iou = 0
|
52 |
+
return iou
|
53 |
+
|
54 |
+
def nms(boxes,overlap):
|
55 |
+
rec_scores = [b[-2] for b in boxes]
|
56 |
+
indices = sorted(range(len(rec_scores)), key=lambda k: -rec_scores[k])
|
57 |
+
box_num = len(boxes)
|
58 |
+
nms_flag = [True]*box_num
|
59 |
+
for i in range(box_num):
|
60 |
+
ii = indices[i]
|
61 |
+
if not nms_flag[ii]:
|
62 |
+
continue
|
63 |
+
for j in range(box_num):
|
64 |
+
jj = indices[j]
|
65 |
+
if j == i:
|
66 |
+
continue
|
67 |
+
if not nms_flag[jj]:
|
68 |
+
continue
|
69 |
+
box1 = boxes[ii]
|
70 |
+
box2 = boxes[jj]
|
71 |
+
box1_score = rec_scores[ii]
|
72 |
+
box2_score = rec_scores[jj]
|
73 |
+
str1 = box1[9]
|
74 |
+
str2 = box2[9]
|
75 |
+
box_i = [box1[0],box1[1],box1[4],box1[5]]
|
76 |
+
box_j = [box2[0],box2[1],box2[4],box2[5]]
|
77 |
+
poly1 = polygon_from_list(box1[0:8])
|
78 |
+
poly2 = polygon_from_list(box2[0:8])
|
79 |
+
iou = polygon_iou(box1[0:8],box2[0:8])
|
80 |
+
thresh = overlap
|
81 |
+
|
82 |
+
if iou > thresh:
|
83 |
+
if box1_score > box2_score:
|
84 |
+
nms_flag[jj] = False
|
85 |
+
if box1_score == box2_score and poly1.area > poly2.area:
|
86 |
+
nms_flag[jj] = False
|
87 |
+
if box1_score == box2_score and poly1.area<=poly2.area:
|
88 |
+
nms_flag[ii] = False
|
89 |
+
break
|
90 |
+
|
91 |
+
return nms_flag
|
92 |
+
|
93 |
+
def packing(save_dir, cache_dir, pack_name):
|
94 |
+
files = os.listdir(save_dir)
|
95 |
+
if not os.path.exists(cache_dir):
|
96 |
+
os.mkdir(cache_dir)
|
97 |
+
os.system('zip -r -q -j '+os.path.join(cache_dir, pack_name+'.zip')+' '+save_dir+'/*')
|
98 |
+
|
99 |
+
def test_single(results_dir,lexicon_type=3,cache_dir='./cache_dir',score_det=0.5,score_rec=0.5,score_rec_seq=0.5,overlap=0.2, use_lexicon=True, weighted_ed=True, use_seq=False, use_char=False, mix=False):
|
100 |
+
'''
|
101 |
+
results_dir: result directory
|
102 |
+
score_det: score of detection bounding box
|
103 |
+
score_rec: score of the mask recognition branch
|
104 |
+
socre_rec_seq: score of the sequence recognition branch
|
105 |
+
overlap: overlap threshold used for nms
|
106 |
+
lexicon_type: 1 for generic; 2 for weak; 3 for strong
|
107 |
+
use_seq: use the recognition result of sequence branch
|
108 |
+
use_mix: use both the recognition result of the mask and sequence branches, selected by score
|
109 |
+
'''
|
110 |
+
print('score_det:', 'score_det:', score_det, 'score_rec:', score_rec, 'score_rec_seq:', score_rec_seq, 'lexicon_type:', lexicon_type, 'weighted_ed:', weighted_ed, 'use_seq:', use_seq, 'use_char:', use_char, 'mix:', mix)
|
111 |
+
if not os.path.exists(cache_dir):
|
112 |
+
os.mkdir(cache_dir)
|
113 |
+
nms_dir = os.path.join(cache_dir,str(score_det)+'_'+str(score_rec)+'_'+str(score_rec_seq))
|
114 |
+
if not os.path.exists(nms_dir):
|
115 |
+
os.mkdir(nms_dir)
|
116 |
+
if lexicon_type==1:
|
117 |
+
# generic lexicon
|
118 |
+
lexicon_path = '../../lexicons/ic15/GenericVocabulary_new.txt'
|
119 |
+
lexicon_fid=open(lexicon_path, 'r')
|
120 |
+
pair_list = open('../../lexicons/ic15/GenericVocabulary_pair_list.txt', 'r')
|
121 |
+
pairs = dict()
|
122 |
+
for line in pair_list.readlines():
|
123 |
+
line=line.strip()
|
124 |
+
word = line.split(' ')[0].upper()
|
125 |
+
word_gt = line[len(word)+1:]
|
126 |
+
pairs[word] = word_gt
|
127 |
+
lexicon_fid=open(lexicon_path, 'r')
|
128 |
+
lexicon=[]
|
129 |
+
for line in lexicon_fid.readlines():
|
130 |
+
line=line.strip()
|
131 |
+
lexicon.append(line)
|
132 |
+
if lexicon_type==2:
|
133 |
+
# weak lexicon
|
134 |
+
lexicon_path = '../../lexicons/ic15/ch4_test_vocabulary_new.txt'
|
135 |
+
lexicon_fid=open(lexicon_path, 'r')
|
136 |
+
pair_list = open('../../lexicons/ic15/ch4_test_vocabulary_pair_list.txt', 'r')
|
137 |
+
pairs = dict()
|
138 |
+
for line in pair_list.readlines():
|
139 |
+
line=line.strip()
|
140 |
+
word = line.split(' ')[0].upper()
|
141 |
+
word_gt = line[len(word)+1:]
|
142 |
+
pairs[word] = word_gt
|
143 |
+
lexicon_fid=open(lexicon_path, 'r')
|
144 |
+
lexicon=[]
|
145 |
+
for line in lexicon_fid.readlines():
|
146 |
+
line=line.strip()
|
147 |
+
lexicon.append(line)
|
148 |
+
|
149 |
+
for i in tqdm(range(1,501)):
|
150 |
+
img = 'img_'+str(i)+'.jpg'
|
151 |
+
gt_img = 'gt_img_'+str(i)+'.txt'
|
152 |
+
if lexicon_type==3:
|
153 |
+
# weak
|
154 |
+
lexicon_path = '../../lexicons/ic15/new_strong_lexicon/new_voc_img_' + str(i) + '.txt'
|
155 |
+
lexicon_fid=open(lexicon_path, 'r')
|
156 |
+
pair_list = open('../../lexicons/ic15/new_strong_lexicon/pair_voc_img_' + str(i) + '.txt', 'r')
|
157 |
+
pairs = dict()
|
158 |
+
for line in pair_list.readlines():
|
159 |
+
line=line.strip()
|
160 |
+
word = line.split(' ')[0].upper()
|
161 |
+
word_gt = line[len(word)+1:]
|
162 |
+
pairs[word] = word_gt
|
163 |
+
lexicon_fid=open(lexicon_path, 'r')
|
164 |
+
lexicon=[]
|
165 |
+
for line in lexicon_fid.readlines():
|
166 |
+
line=line.strip()
|
167 |
+
lexicon.append(line)
|
168 |
+
result_path = os.path.join(results_dir,'res_img_'+str(i)+'.txt')
|
169 |
+
if os.path.isfile(result_path):
|
170 |
+
with open(result_path,'r') as f:
|
171 |
+
dt_lines = [a.strip() for a in f.readlines()]
|
172 |
+
dt_lines = [list_from_str(dt) for dt in dt_lines]
|
173 |
+
else:
|
174 |
+
dt_lines = []
|
175 |
+
dt_lines = [dt for dt in dt_lines if dt[-2]>score_rec_seq and dt[-3]>score_rec and dt[-6]>score_det]
|
176 |
+
nms_flag = nms(dt_lines,overlap)
|
177 |
+
boxes = []
|
178 |
+
for k in range(len(dt_lines)):
|
179 |
+
dt = dt_lines[k]
|
180 |
+
if nms_flag[k]:
|
181 |
+
if dt not in boxes:
|
182 |
+
boxes.append(dt)
|
183 |
+
|
184 |
+
with open(os.path.join(nms_dir,'res_img_'+str(i)+'.txt'),'w') as f:
|
185 |
+
for g in boxes:
|
186 |
+
gt_coors = [int(b) for b in g[0:8]]
|
187 |
+
with open('../../../' + g[-1], "rb") as input_file:
|
188 |
+
# with open(g[-1], "rb") as input_file:
|
189 |
+
dict_scores = pickle.load(input_file)
|
190 |
+
if use_char and use_seq:
|
191 |
+
if g[-2]>g[-3]:
|
192 |
+
word = g[-5]
|
193 |
+
scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1)
|
194 |
+
else:
|
195 |
+
word = g[-4]
|
196 |
+
scores = dict_scores['seg_char_scores']
|
197 |
+
elif use_seq:
|
198 |
+
word = g[-5]
|
199 |
+
scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1)
|
200 |
+
else:
|
201 |
+
word = g[-4]
|
202 |
+
scores = dict_scores['seg_char_scores']
|
203 |
+
match_word, match_dist = find_match_word(word, lexicon, pairs, scores, use_lexicon, weighted_ed)
|
204 |
+
if match_dist<1.5 or lexicon_type==1:
|
205 |
+
gt_coor_strs = [str(a) for a in gt_coors]+ [match_word]
|
206 |
+
f.write(','.join(gt_coor_strs)+'\r\n')
|
207 |
+
|
208 |
+
pack_name = str(score_det)+'_'+str(score_rec)+'_over'+str(overlap)
|
209 |
+
|
210 |
+
packing(nms_dir,cache_dir,pack_name)
|
211 |
+
submit_file_path = os.path.join(cache_dir, pack_name+'.zip')
|
212 |
+
return submit_file_path
|
213 |
+
|
214 |
+
def find_match_word(rec_str, lexicon, pairs, scores_numpy, use_ed = True, weighted_ed = False):
|
215 |
+
if not use_ed:
|
216 |
+
return rec_str
|
217 |
+
rec_str = rec_str.upper()
|
218 |
+
dist_min = 100
|
219 |
+
dist_min_pre = 100
|
220 |
+
match_word = ''
|
221 |
+
match_dist = 100
|
222 |
+
if not weighted_ed:
|
223 |
+
for word in lexicon:
|
224 |
+
word = word.upper()
|
225 |
+
ed = editdistance.eval(rec_str, word)
|
226 |
+
length_dist = abs(len(word) - len(rec_str))
|
227 |
+
# dist = ed + length_dist
|
228 |
+
dist = ed
|
229 |
+
if dist<dist_min:
|
230 |
+
dist_min = dist
|
231 |
+
match_word = pairs[word]
|
232 |
+
match_dist = dist
|
233 |
+
return match_word, match_dist
|
234 |
+
else:
|
235 |
+
small_lexicon_dict = dict()
|
236 |
+
for word in lexicon:
|
237 |
+
word = word.upper()
|
238 |
+
ed = editdistance.eval(rec_str, word)
|
239 |
+
small_lexicon_dict[word] = ed
|
240 |
+
dist = ed
|
241 |
+
if dist<dist_min_pre:
|
242 |
+
dist_min_pre = dist
|
243 |
+
small_lexicon = []
|
244 |
+
for word in small_lexicon_dict:
|
245 |
+
if small_lexicon_dict[word]<=dist_min_pre+2:
|
246 |
+
small_lexicon.append(word)
|
247 |
+
|
248 |
+
for word in small_lexicon:
|
249 |
+
word = word.upper()
|
250 |
+
ed = weighted_edit_distance(rec_str, word, scores_numpy)
|
251 |
+
dist = ed
|
252 |
+
if dist<dist_min:
|
253 |
+
dist_min = dist
|
254 |
+
match_word = pairs[word]
|
255 |
+
match_dist = dist
|
256 |
+
return match_word, match_dist
|
257 |
+
|
258 |
+
|
259 |
+
def prepare_results_for_evaluation(results_dir, lexicon_type, cache_dir, score_det, score_rec, score_rec_seq):
|
260 |
+
if not os.path.isdir(cache_dir):
|
261 |
+
os.mkdir(cache_dir)
|
262 |
+
result_path = test_single(results_dir,score_det=score_det,score_rec=score_rec,score_rec_seq=score_rec_seq,overlap=0.2,cache_dir=cache_dir,lexicon_type=lexicon_type, use_lexicon=True, weighted_ed=True, use_seq=True, use_char=True, mix=True)
|
263 |
+
return result_path
|
evaluation/icdar2015/e2e/rrc_evaluation_funcs.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python2
|
2 |
+
#encoding: UTF-8
|
3 |
+
import json
|
4 |
+
import sys;sys.path.append('./')
|
5 |
+
import zipfile
|
6 |
+
import re
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
import codecs
|
10 |
+
import importlib
|
11 |
+
try:
|
12 |
+
from StringIO import StringIO
|
13 |
+
except ImportError:
|
14 |
+
from io import StringIO
|
15 |
+
|
16 |
+
def print_help():
|
17 |
+
sys.stdout.write('Usage: python %s.py -g=<gtFile> -s=<submFile> [-o=<outputFolder> -p=<jsonParams>]' %sys.argv[0])
|
18 |
+
sys.exit(2)
|
19 |
+
|
20 |
+
|
21 |
+
def load_zip_file_keys(file,fileNameRegExp=''):
|
22 |
+
"""
|
23 |
+
Returns an array with the entries of the ZIP file that match with the regular expression.
|
24 |
+
The key's are the names or the file or the capturing group definied in the fileNameRegExp
|
25 |
+
"""
|
26 |
+
try:
|
27 |
+
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
|
28 |
+
except :
|
29 |
+
raise Exception('Error loading the ZIP archive.')
|
30 |
+
|
31 |
+
pairs = []
|
32 |
+
|
33 |
+
for name in archive.namelist():
|
34 |
+
addFile = True
|
35 |
+
keyName = name
|
36 |
+
if fileNameRegExp!="":
|
37 |
+
m = re.match(fileNameRegExp,name)
|
38 |
+
if m == None:
|
39 |
+
addFile = False
|
40 |
+
else:
|
41 |
+
if len(m.groups())>0:
|
42 |
+
keyName = m.group(1)
|
43 |
+
|
44 |
+
if addFile:
|
45 |
+
pairs.append( keyName )
|
46 |
+
|
47 |
+
return pairs
|
48 |
+
|
49 |
+
|
50 |
+
def load_zip_file(file,fileNameRegExp='',allEntries=False):
|
51 |
+
"""
|
52 |
+
Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
|
53 |
+
The key's are the names or the file or the capturing group definied in the fileNameRegExp
|
54 |
+
allEntries validates that all entries in the ZIP file pass the fileNameRegExp
|
55 |
+
"""
|
56 |
+
try:
|
57 |
+
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
|
58 |
+
except :
|
59 |
+
raise Exception('Error loading the ZIP archive')
|
60 |
+
|
61 |
+
pairs = []
|
62 |
+
for name in archive.namelist():
|
63 |
+
addFile = True
|
64 |
+
keyName = name
|
65 |
+
if fileNameRegExp!="":
|
66 |
+
m = re.match(fileNameRegExp,name)
|
67 |
+
if m == None:
|
68 |
+
addFile = False
|
69 |
+
else:
|
70 |
+
if len(m.groups())>0:
|
71 |
+
keyName = m.group(1)
|
72 |
+
|
73 |
+
if addFile:
|
74 |
+
pairs.append( [ keyName , archive.read(name)] )
|
75 |
+
else:
|
76 |
+
if allEntries:
|
77 |
+
raise Exception('ZIP entry not valid: %s' %name)
|
78 |
+
|
79 |
+
return dict(pairs)
|
80 |
+
|
81 |
+
def decode_utf8(raw):
|
82 |
+
"""
|
83 |
+
Returns a Unicode object on success, or None on failure
|
84 |
+
"""
|
85 |
+
try:
|
86 |
+
raw = codecs.decode(raw,'utf-8', 'replace')
|
87 |
+
#extracts BOM if exists
|
88 |
+
raw = raw.encode('utf8')
|
89 |
+
if raw.startswith(codecs.BOM_UTF8):
|
90 |
+
raw = raw.replace(codecs.BOM_UTF8, '', 1)
|
91 |
+
return raw.decode('utf-8')
|
92 |
+
except:
|
93 |
+
return None
|
94 |
+
|
95 |
+
def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
|
96 |
+
"""
|
97 |
+
This function validates that all lines of the file calling the Line validation function for each line
|
98 |
+
"""
|
99 |
+
utf8File = decode_utf8(file_contents)
|
100 |
+
if (utf8File is None) :
|
101 |
+
raise Exception("The file %s is not UTF-8" %fileName)
|
102 |
+
|
103 |
+
lines = utf8File.split( "\r\n" if CRLF else "\n" )
|
104 |
+
for line in lines:
|
105 |
+
line = line.replace("\r","").replace("\n","")
|
106 |
+
if(line != ""):
|
107 |
+
try:
|
108 |
+
validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
|
109 |
+
except Exception as e:
|
110 |
+
raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace'))
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0):
|
115 |
+
"""
|
116 |
+
Validate the format of the line. If the line is not valid an exception will be raised.
|
117 |
+
If maxWidth and maxHeight are specified, all points must be inside the imgage bounds.
|
118 |
+
Posible values are:
|
119 |
+
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
|
120 |
+
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
|
121 |
+
"""
|
122 |
+
get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
|
123 |
+
|
124 |
+
|
125 |
+
def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
|
126 |
+
"""
|
127 |
+
Validate the format of the line. If the line is not valid an exception will be raised.
|
128 |
+
If maxWidth and maxHeight are specified, all points must be inside the imgage bounds.
|
129 |
+
Posible values are:
|
130 |
+
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
|
131 |
+
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
|
132 |
+
Returns values from a textline. Points , [Confidences], [Transcriptions]
|
133 |
+
"""
|
134 |
+
confidence = 0.0
|
135 |
+
transcription = "";
|
136 |
+
points = []
|
137 |
+
|
138 |
+
numPoints = 4;
|
139 |
+
|
140 |
+
if LTRB:
|
141 |
+
|
142 |
+
numPoints = 4;
|
143 |
+
|
144 |
+
if withTranscription and withConfidence:
|
145 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
146 |
+
if m == None :
|
147 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
148 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
|
149 |
+
elif withConfidence:
|
150 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
|
151 |
+
if m == None :
|
152 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
|
153 |
+
elif withTranscription:
|
154 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line)
|
155 |
+
if m == None :
|
156 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
|
157 |
+
else:
|
158 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line)
|
159 |
+
if m == None :
|
160 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
|
161 |
+
|
162 |
+
xmin = int(m.group(1))
|
163 |
+
ymin = int(m.group(2))
|
164 |
+
xmax = int(m.group(3))
|
165 |
+
ymax = int(m.group(4))
|
166 |
+
if(xmax<xmin):
|
167 |
+
raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." %(xmax))
|
168 |
+
if(ymax<ymin):
|
169 |
+
raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." %(ymax))
|
170 |
+
|
171 |
+
points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
|
172 |
+
|
173 |
+
if (imWidth>0 and imHeight>0):
|
174 |
+
validate_point_inside_bounds(xmin,ymin,imWidth,imHeight);
|
175 |
+
validate_point_inside_bounds(xmax,ymax,imWidth,imHeight);
|
176 |
+
|
177 |
+
else:
|
178 |
+
|
179 |
+
numPoints = 8;
|
180 |
+
|
181 |
+
if withTranscription and withConfidence:
|
182 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
183 |
+
if m == None :
|
184 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
|
185 |
+
elif withConfidence:
|
186 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
|
187 |
+
if m == None :
|
188 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
|
189 |
+
elif withTranscription:
|
190 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line)
|
191 |
+
if m == None :
|
192 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
|
193 |
+
else:
|
194 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line)
|
195 |
+
if m == None :
|
196 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
|
197 |
+
|
198 |
+
points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
|
199 |
+
|
200 |
+
validate_clockwise_points(points)
|
201 |
+
|
202 |
+
if (imWidth>0 and imHeight>0):
|
203 |
+
validate_point_inside_bounds(points[0],points[1],imWidth,imHeight);
|
204 |
+
validate_point_inside_bounds(points[2],points[3],imWidth,imHeight);
|
205 |
+
validate_point_inside_bounds(points[4],points[5],imWidth,imHeight);
|
206 |
+
validate_point_inside_bounds(points[6],points[7],imWidth,imHeight);
|
207 |
+
|
208 |
+
|
209 |
+
if withConfidence:
|
210 |
+
try:
|
211 |
+
confidence = float(m.group(numPoints+1))
|
212 |
+
except ValueError:
|
213 |
+
raise Exception("Confidence value must be a float")
|
214 |
+
|
215 |
+
if withTranscription:
|
216 |
+
posTranscription = numPoints + (2 if withConfidence else 1)
|
217 |
+
transcription = m.group(posTranscription)
|
218 |
+
m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription)
|
219 |
+
if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters
|
220 |
+
transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
|
221 |
+
|
222 |
+
return points,confidence,transcription
|
223 |
+
|
224 |
+
|
225 |
+
def validate_point_inside_bounds(x,y,imWidth,imHeight):
|
226 |
+
if(x<0 or x>imWidth):
|
227 |
+
raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight))
|
228 |
+
if(y<0 or y>imHeight):
|
229 |
+
raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight))
|
230 |
+
|
231 |
+
def validate_clockwise_points(points):
|
232 |
+
"""
|
233 |
+
Validates that the points that the 4 points that dlimite a polygon are in clockwise order.
|
234 |
+
"""
|
235 |
+
|
236 |
+
if len(points) != 8:
|
237 |
+
raise Exception("Points list not valid." + str(len(points)))
|
238 |
+
|
239 |
+
point = [
|
240 |
+
[int(points[0]) , int(points[1])],
|
241 |
+
[int(points[2]) , int(points[3])],
|
242 |
+
[int(points[4]) , int(points[5])],
|
243 |
+
[int(points[6]) , int(points[7])]
|
244 |
+
]
|
245 |
+
edge = [
|
246 |
+
( point[1][0] - point[0][0])*( point[1][1] + point[0][1]),
|
247 |
+
( point[2][0] - point[1][0])*( point[2][1] + point[1][1]),
|
248 |
+
( point[3][0] - point[2][0])*( point[3][1] + point[2][1]),
|
249 |
+
( point[0][0] - point[3][0])*( point[0][1] + point[3][1])
|
250 |
+
]
|
251 |
+
|
252 |
+
summatory = edge[0] + edge[1] + edge[2] + edge[3];
|
253 |
+
if summatory>0:
|
254 |
+
raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.")
|
255 |
+
|
256 |
+
def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True):
|
257 |
+
"""
|
258 |
+
Returns all points, confindences and transcriptions of a file in lists. Valid line formats:
|
259 |
+
xmin,ymin,xmax,ymax,[confidence],[transcription]
|
260 |
+
x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
|
261 |
+
"""
|
262 |
+
pointsList = []
|
263 |
+
transcriptionsList = []
|
264 |
+
confidencesList = []
|
265 |
+
|
266 |
+
lines = content.split( "\r\n" if CRLF else "\n" )
|
267 |
+
for line in lines:
|
268 |
+
line = line.replace("\r","").replace("\n","")
|
269 |
+
if(line != "") :
|
270 |
+
points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight);
|
271 |
+
pointsList.append(points)
|
272 |
+
transcriptionsList.append(transcription)
|
273 |
+
confidencesList.append(confidence)
|
274 |
+
|
275 |
+
if withConfidence and len(confidencesList)>0 and sort_by_confidences:
|
276 |
+
import numpy as np
|
277 |
+
sorted_ind = np.argsort(-np.array(confidencesList))
|
278 |
+
confidencesList = [confidencesList[i] for i in sorted_ind]
|
279 |
+
pointsList = [pointsList[i] for i in sorted_ind]
|
280 |
+
transcriptionsList = [transcriptionsList[i] for i in sorted_ind]
|
281 |
+
|
282 |
+
return pointsList,confidencesList,transcriptionsList
|
283 |
+
|
284 |
+
def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True):
|
285 |
+
"""
|
286 |
+
This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample.
|
287 |
+
Params:
|
288 |
+
p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used.
|
289 |
+
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
|
290 |
+
validate_data_fn: points to a method that validates the corrct format of the submission
|
291 |
+
evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results
|
292 |
+
"""
|
293 |
+
|
294 |
+
if (p == None):
|
295 |
+
p = dict([s[1:].split('=') for s in sys.argv[1:]])
|
296 |
+
if(len(sys.argv)<3):
|
297 |
+
print_help()
|
298 |
+
|
299 |
+
evalParams = default_evaluation_params_fn()
|
300 |
+
if 'p' in p.keys():
|
301 |
+
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
|
302 |
+
|
303 |
+
resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'}
|
304 |
+
try:
|
305 |
+
validate_data_fn(p['g'], p['s'], evalParams)
|
306 |
+
evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
|
307 |
+
resDict.update(evalData)
|
308 |
+
|
309 |
+
except Exception as e:
|
310 |
+
resDict['Message']= str(e)
|
311 |
+
resDict['calculated']=False
|
312 |
+
|
313 |
+
if 'o' in p:
|
314 |
+
if not os.path.exists(p['o']):
|
315 |
+
os.makedirs(p['o'])
|
316 |
+
|
317 |
+
resultsOutputname = p['o'] + '/results.zip'
|
318 |
+
outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
|
319 |
+
|
320 |
+
del resDict['per_sample']
|
321 |
+
if 'output_items' in resDict.keys():
|
322 |
+
del resDict['output_items']
|
323 |
+
|
324 |
+
outZip.writestr('method.json',json.dumps(resDict))
|
325 |
+
|
326 |
+
if not resDict['calculated']:
|
327 |
+
if show_result:
|
328 |
+
sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n')
|
329 |
+
if 'o' in p:
|
330 |
+
outZip.close()
|
331 |
+
return resDict
|
332 |
+
|
333 |
+
if 'o' in p:
|
334 |
+
if per_sample == True:
|
335 |
+
for k,v in evalData['per_sample'].items():
|
336 |
+
outZip.writestr( k + '.json',json.dumps(v))
|
337 |
+
|
338 |
+
if 'output_items' in evalData.keys():
|
339 |
+
for k, v in evalData['output_items'].items():
|
340 |
+
outZip.writestr( k,v)
|
341 |
+
|
342 |
+
outZip.close()
|
343 |
+
|
344 |
+
if show_result:
|
345 |
+
sys.stdout.write("Calculated!")
|
346 |
+
sys.stdout.write(json.dumps(resDict['method']))
|
347 |
+
|
348 |
+
return resDict
|
349 |
+
|
350 |
+
|
351 |
+
def main_validation(default_evaluation_params_fn,validate_data_fn):
|
352 |
+
"""
|
353 |
+
This process validates a method
|
354 |
+
Params:
|
355 |
+
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
|
356 |
+
validate_data_fn: points to a method that validates the corrct format of the submission
|
357 |
+
"""
|
358 |
+
try:
|
359 |
+
p = dict([s[1:].split('=') for s in sys.argv[1:]])
|
360 |
+
evalParams = default_evaluation_params_fn()
|
361 |
+
if 'p' in p.keys():
|
362 |
+
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
|
363 |
+
|
364 |
+
validate_data_fn(p['g'], p['s'], evalParams)
|
365 |
+
print('SUCCESS')
|
366 |
+
sys.exit(0)
|
367 |
+
except Exception as e:
|
368 |
+
print(str(e))
|
369 |
+
sys.exit(101)
|
evaluation/icdar2015/e2e/script.py
ADDED
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# encoding=utf8
|
4 |
+
from collections import namedtuple
|
5 |
+
import rrc_evaluation_funcs
|
6 |
+
import importlib
|
7 |
+
from prepare_results import prepare_results_for_evaluation
|
8 |
+
|
9 |
+
def evaluation_imports():
|
10 |
+
"""
|
11 |
+
evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation.
|
12 |
+
"""
|
13 |
+
return {
|
14 |
+
'Polygon':'plg',
|
15 |
+
'numpy':'np'
|
16 |
+
}
|
17 |
+
|
18 |
+
def default_evaluation_params():
|
19 |
+
"""
|
20 |
+
default_evaluation_params: Default parameters to use for the validation and evaluation.
|
21 |
+
"""
|
22 |
+
return {
|
23 |
+
'IOU_CONSTRAINT' :0.5,
|
24 |
+
'AREA_PRECISION_CONSTRAINT' :0.5,
|
25 |
+
'WORD_SPOTTING' :False,
|
26 |
+
'MIN_LENGTH_CARE_WORD' :3,
|
27 |
+
'GT_SAMPLE_NAME_2_ID':'gt_img_([0-9]+).txt',
|
28 |
+
'DET_SAMPLE_NAME_2_ID':'res_img_([0-9]+).txt',
|
29 |
+
'LTRB':False, #LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4)
|
30 |
+
'CRLF':False, # Lines are delimited by Windows CRLF format
|
31 |
+
'CONFIDENCES':False, #Detections must include confidence value. MAP and MAR will be calculated,
|
32 |
+
'SPECIAL_CHARACTERS':'!?.:,*"()·[]/\'',
|
33 |
+
'ONLY_REMOVE_FIRST_LAST_CHARACTER' : True
|
34 |
+
}
|
35 |
+
|
36 |
+
def validate_data(gtFilePath, submFilePath, evaluationParams):
|
37 |
+
"""
|
38 |
+
Method validate_data: validates that all files in the results folder are correct (have the correct name contents).
|
39 |
+
Validates also that there are no missing files in the folder.
|
40 |
+
If some error detected, the method raises the error
|
41 |
+
"""
|
42 |
+
gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
|
43 |
+
|
44 |
+
subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
|
45 |
+
|
46 |
+
#Validate format of GroundTruth
|
47 |
+
for k in gt:
|
48 |
+
rrc_evaluation_funcs.validate_lines_in_file(k,gt[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True)
|
49 |
+
|
50 |
+
#Validate format of results
|
51 |
+
for k in subm:
|
52 |
+
if (k in gt) == False :
|
53 |
+
raise Exception("The sample %s not present in GT" %k)
|
54 |
+
|
55 |
+
rrc_evaluation_funcs.validate_lines_in_file(k,subm[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES'])
|
56 |
+
|
57 |
+
|
58 |
+
def evaluate_method(gtFilePath, submFilePath, evaluationParams):
|
59 |
+
"""
|
60 |
+
Method evaluate_method: evaluate method and returns the results
|
61 |
+
Results. Dictionary with the following values:
|
62 |
+
- method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
|
63 |
+
- samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 }
|
64 |
+
"""
|
65 |
+
for module,alias in evaluation_imports().items():
|
66 |
+
globals()[alias] = importlib.import_module(module)
|
67 |
+
|
68 |
+
def polygon_from_points(points,correctOffset=False):
|
69 |
+
"""
|
70 |
+
Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
|
71 |
+
"""
|
72 |
+
|
73 |
+
if correctOffset: #this will substract 1 from the coordinates that correspond to the xmax and ymax
|
74 |
+
points[2] -= 1
|
75 |
+
points[4] -= 1
|
76 |
+
points[5] -= 1
|
77 |
+
points[7] -= 1
|
78 |
+
|
79 |
+
resBoxes=np.empty([1,8],dtype='int32')
|
80 |
+
resBoxes[0,0]=int(points[0])
|
81 |
+
resBoxes[0,4]=int(points[1])
|
82 |
+
resBoxes[0,1]=int(points[2])
|
83 |
+
resBoxes[0,5]=int(points[3])
|
84 |
+
resBoxes[0,2]=int(points[4])
|
85 |
+
resBoxes[0,6]=int(points[5])
|
86 |
+
resBoxes[0,3]=int(points[6])
|
87 |
+
resBoxes[0,7]=int(points[7])
|
88 |
+
pointMat = resBoxes[0].reshape([2,4]).T
|
89 |
+
return plg.Polygon( pointMat)
|
90 |
+
|
91 |
+
def rectangle_to_polygon(rect):
|
92 |
+
resBoxes=np.empty([1,8],dtype='int32')
|
93 |
+
resBoxes[0,0]=int(rect.xmin)
|
94 |
+
resBoxes[0,4]=int(rect.ymax)
|
95 |
+
resBoxes[0,1]=int(rect.xmin)
|
96 |
+
resBoxes[0,5]=int(rect.ymin)
|
97 |
+
resBoxes[0,2]=int(rect.xmax)
|
98 |
+
resBoxes[0,6]=int(rect.ymin)
|
99 |
+
resBoxes[0,3]=int(rect.xmax)
|
100 |
+
resBoxes[0,7]=int(rect.ymax)
|
101 |
+
|
102 |
+
pointMat = resBoxes[0].reshape([2,4]).T
|
103 |
+
|
104 |
+
return plg.Polygon( pointMat)
|
105 |
+
|
106 |
+
def rectangle_to_points(rect):
|
107 |
+
points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)]
|
108 |
+
return points
|
109 |
+
|
110 |
+
def get_union(pD,pG):
|
111 |
+
areaA = pD.area();
|
112 |
+
areaB = pG.area();
|
113 |
+
return areaA + areaB - get_intersection(pD, pG);
|
114 |
+
|
115 |
+
def get_intersection_over_union(pD,pG):
|
116 |
+
try:
|
117 |
+
return get_intersection(pD, pG) / get_union(pD, pG);
|
118 |
+
except:
|
119 |
+
return 0
|
120 |
+
|
121 |
+
def get_intersection(pD,pG):
|
122 |
+
pInt = pD & pG
|
123 |
+
if len(pInt) == 0:
|
124 |
+
return 0
|
125 |
+
return pInt.area()
|
126 |
+
|
127 |
+
def compute_ap(confList, matchList,numGtCare):
|
128 |
+
correct = 0
|
129 |
+
AP = 0
|
130 |
+
if len(confList)>0:
|
131 |
+
confList = np.array(confList)
|
132 |
+
matchList = np.array(matchList)
|
133 |
+
sorted_ind = np.argsort(-confList)
|
134 |
+
confList = confList[sorted_ind]
|
135 |
+
matchList = matchList[sorted_ind]
|
136 |
+
for n in range(len(confList)):
|
137 |
+
match = matchList[n]
|
138 |
+
if match:
|
139 |
+
correct += 1
|
140 |
+
AP += float(correct)/(n + 1)
|
141 |
+
|
142 |
+
if numGtCare>0:
|
143 |
+
AP /= numGtCare
|
144 |
+
|
145 |
+
return AP
|
146 |
+
|
147 |
+
def transcription_match(transGt,transDet,specialCharacters='!?.:,*"()·[]/\'',onlyRemoveFirstLastCharacterGT=True):
|
148 |
+
|
149 |
+
if onlyRemoveFirstLastCharacterGT:
|
150 |
+
#special characters in GT are allowed only at initial or final position
|
151 |
+
if (transGt==transDet):
|
152 |
+
return True
|
153 |
+
|
154 |
+
if specialCharacters.find(transGt[0])>-1:
|
155 |
+
if transGt[1:]==transDet:
|
156 |
+
return True
|
157 |
+
|
158 |
+
if specialCharacters.find(transGt[-1])>-1:
|
159 |
+
if transGt[0:len(transGt)-1]==transDet:
|
160 |
+
return True
|
161 |
+
|
162 |
+
if specialCharacters.find(transGt[0])>-1 and specialCharacters.find(transGt[-1])>-1:
|
163 |
+
if transGt[1:len(transGt)-1]==transDet:
|
164 |
+
return True
|
165 |
+
return False
|
166 |
+
else:
|
167 |
+
#Special characters are removed from the begining and the end of both Detection and GroundTruth
|
168 |
+
while len(transGt)>0 and specialCharacters.find(transGt[0])>-1:
|
169 |
+
transGt = transGt[1:]
|
170 |
+
|
171 |
+
while len(transDet)>0 and specialCharacters.find(transDet[0])>-1:
|
172 |
+
transDet = transDet[1:]
|
173 |
+
|
174 |
+
while len(transGt)>0 and specialCharacters.find(transGt[-1])>-1 :
|
175 |
+
transGt = transGt[0:len(transGt)-1]
|
176 |
+
|
177 |
+
while len(transDet)>0 and specialCharacters.find(transDet[-1])>-1:
|
178 |
+
transDet = transDet[0:len(transDet)-1]
|
179 |
+
|
180 |
+
return transGt == transDet
|
181 |
+
|
182 |
+
|
183 |
+
def include_in_dictionary(transcription):
|
184 |
+
"""
|
185 |
+
Function used in Word Spotting that finds if the Ground Truth transcription meets the rules to enter into the dictionary. If not, the transcription will be cared as don't care
|
186 |
+
"""
|
187 |
+
#special case 's at final
|
188 |
+
if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S":
|
189 |
+
transcription = transcription[0:len(transcription)-2]
|
190 |
+
|
191 |
+
#hypens at init or final of the word
|
192 |
+
transcription = transcription.strip('-');
|
193 |
+
|
194 |
+
specialCharacters = "'!?.:,*\"()·[]/";
|
195 |
+
for character in specialCharacters:
|
196 |
+
transcription = transcription.replace(character,' ')
|
197 |
+
|
198 |
+
transcription = transcription.strip()
|
199 |
+
|
200 |
+
if len(transcription) != len(transcription.replace(" ","")) :
|
201 |
+
return False;
|
202 |
+
|
203 |
+
if len(transcription) < evaluationParams['MIN_LENGTH_CARE_WORD']:
|
204 |
+
return False;
|
205 |
+
|
206 |
+
notAllowed = "×÷·";
|
207 |
+
|
208 |
+
range1 = [ ord(u'a'), ord(u'z') ]
|
209 |
+
range2 = [ ord(u'A'), ord(u'Z') ]
|
210 |
+
range3 = [ ord(u'À'), ord(u'ƿ') ]
|
211 |
+
range4 = [ ord(u'DŽ'), ord(u'ɿ') ]
|
212 |
+
range5 = [ ord(u'Ά'), ord(u'Ͽ') ]
|
213 |
+
range6 = [ ord(u'-'), ord(u'-') ]
|
214 |
+
|
215 |
+
for char in transcription :
|
216 |
+
charCode = ord(char)
|
217 |
+
if(notAllowed.find(char) != -1):
|
218 |
+
return False
|
219 |
+
|
220 |
+
valid = ( charCode>=range1[0] and charCode<=range1[1] ) or ( charCode>=range2[0] and charCode<=range2[1] ) or ( charCode>=range3[0] and charCode<=range3[1] ) or ( charCode>=range4[0] and charCode<=range4[1] ) or ( charCode>=range5[0] and charCode<=range5[1] ) or ( charCode>=range6[0] and charCode<=range6[1] )
|
221 |
+
if valid == False:
|
222 |
+
return False
|
223 |
+
|
224 |
+
return True
|
225 |
+
|
226 |
+
def include_in_dictionary_transcription(transcription):
|
227 |
+
"""
|
228 |
+
Function applied to the Ground Truth transcriptions used in Word Spotting. It removes special characters or terminations
|
229 |
+
"""
|
230 |
+
#special case 's at final
|
231 |
+
if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S":
|
232 |
+
transcription = transcription[0:len(transcription)-2]
|
233 |
+
|
234 |
+
#hypens at init or final of the word
|
235 |
+
transcription = transcription.strip('-');
|
236 |
+
|
237 |
+
specialCharacters = "'!?.:,*\"()·[]/";
|
238 |
+
for character in specialCharacters:
|
239 |
+
transcription = transcription.replace(character,' ')
|
240 |
+
|
241 |
+
transcription = transcription.strip()
|
242 |
+
|
243 |
+
return transcription
|
244 |
+
|
245 |
+
perSampleMetrics = {}
|
246 |
+
|
247 |
+
matchedSum = 0
|
248 |
+
|
249 |
+
Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
|
250 |
+
|
251 |
+
gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID'])
|
252 |
+
subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True)
|
253 |
+
|
254 |
+
numGlobalCareGt = 0;
|
255 |
+
numGlobalCareDet = 0;
|
256 |
+
|
257 |
+
arrGlobalConfidences = [];
|
258 |
+
arrGlobalMatches = [];
|
259 |
+
|
260 |
+
for resFile in gt:
|
261 |
+
|
262 |
+
gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile])
|
263 |
+
if (gtFile is None) :
|
264 |
+
raise Exception("The file %s is not UTF-8" %resFile)
|
265 |
+
|
266 |
+
recall = 0
|
267 |
+
precision = 0
|
268 |
+
hmean = 0
|
269 |
+
detCorrect = 0
|
270 |
+
iouMat = np.empty([1,1])
|
271 |
+
gtPols = []
|
272 |
+
detPols = []
|
273 |
+
gtTrans = []
|
274 |
+
detTrans = []
|
275 |
+
gtPolPoints = []
|
276 |
+
detPolPoints = []
|
277 |
+
gtDontCarePolsNum = [] #Array of Ground Truth Polygons' keys marked as don't Care
|
278 |
+
detDontCarePolsNum = [] #Array of Detected Polygons' matched with a don't Care GT
|
279 |
+
detMatchedNums = []
|
280 |
+
pairs = []
|
281 |
+
|
282 |
+
arrSampleConfidences = [];
|
283 |
+
arrSampleMatch = [];
|
284 |
+
sampleAP = 0;
|
285 |
+
|
286 |
+
evaluationLog = ""
|
287 |
+
|
288 |
+
pointsList,_,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,False)
|
289 |
+
for n in range(len(pointsList)):
|
290 |
+
points = pointsList[n]
|
291 |
+
transcription = transcriptionsList[n]
|
292 |
+
dontCare = transcription == "###"
|
293 |
+
if evaluationParams['LTRB']:
|
294 |
+
gtRect = Rectangle(*points)
|
295 |
+
gtPol = rectangle_to_polygon(gtRect)
|
296 |
+
else:
|
297 |
+
gtPol = polygon_from_points(points)
|
298 |
+
gtPols.append(gtPol)
|
299 |
+
gtPolPoints.append(points)
|
300 |
+
|
301 |
+
#On word spotting we will filter some transcriptions with special characters
|
302 |
+
if evaluationParams['WORD_SPOTTING'] :
|
303 |
+
if dontCare == False :
|
304 |
+
if include_in_dictionary(transcription) == False :
|
305 |
+
dontCare = True
|
306 |
+
else:
|
307 |
+
transcription = include_in_dictionary_transcription(transcription)
|
308 |
+
|
309 |
+
gtTrans.append(transcription)
|
310 |
+
if dontCare:
|
311 |
+
gtDontCarePolsNum.append( len(gtPols)-1 )
|
312 |
+
|
313 |
+
evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum)>0 else "\n")
|
314 |
+
|
315 |
+
if resFile in subm:
|
316 |
+
|
317 |
+
detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile])
|
318 |
+
|
319 |
+
pointsList,confidencesList,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES'])
|
320 |
+
|
321 |
+
for n in range(len(pointsList)):
|
322 |
+
points = pointsList[n]
|
323 |
+
transcription = transcriptionsList[n]
|
324 |
+
|
325 |
+
if evaluationParams['LTRB']:
|
326 |
+
detRect = Rectangle(*points)
|
327 |
+
detPol = rectangle_to_polygon(detRect)
|
328 |
+
else:
|
329 |
+
detPol = polygon_from_points(points)
|
330 |
+
detPols.append(detPol)
|
331 |
+
detPolPoints.append(points)
|
332 |
+
detTrans.append(transcription)
|
333 |
+
|
334 |
+
if len(gtDontCarePolsNum)>0 :
|
335 |
+
for dontCarePol in gtDontCarePolsNum:
|
336 |
+
dontCarePol = gtPols[dontCarePol]
|
337 |
+
intersected_area = get_intersection(dontCarePol,detPol)
|
338 |
+
pdDimensions = detPol.area()
|
339 |
+
precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
|
340 |
+
if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ):
|
341 |
+
detDontCarePolsNum.append( len(detPols)-1 )
|
342 |
+
break
|
343 |
+
|
344 |
+
evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum)>0 else "\n")
|
345 |
+
|
346 |
+
if len(gtPols)>0 and len(detPols)>0:
|
347 |
+
#Calculate IoU and precision matrixs
|
348 |
+
outputShape=[len(gtPols),len(detPols)]
|
349 |
+
iouMat = np.empty(outputShape)
|
350 |
+
gtRectMat = np.zeros(len(gtPols),np.int8)
|
351 |
+
detRectMat = np.zeros(len(detPols),np.int8)
|
352 |
+
for gtNum in range(len(gtPols)):
|
353 |
+
for detNum in range(len(detPols)):
|
354 |
+
pG = gtPols[gtNum]
|
355 |
+
pD = detPols[detNum]
|
356 |
+
iouMat[gtNum,detNum] = get_intersection_over_union(pD,pG)
|
357 |
+
|
358 |
+
for gtNum in range(len(gtPols)):
|
359 |
+
for detNum in range(len(detPols)):
|
360 |
+
if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum :
|
361 |
+
if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']:
|
362 |
+
gtRectMat[gtNum] = 1
|
363 |
+
detRectMat[detNum] = 1
|
364 |
+
#detection matched only if transcription is equal
|
365 |
+
if evaluationParams['WORD_SPOTTING']:
|
366 |
+
correct = gtTrans[gtNum].upper() == detTrans[detNum].upper()
|
367 |
+
else:
|
368 |
+
correct = transcription_match(gtTrans[gtNum].upper(),detTrans[detNum].upper(),evaluationParams['SPECIAL_CHARACTERS'],evaluationParams['ONLY_REMOVE_FIRST_LAST_CHARACTER'])==True
|
369 |
+
detCorrect += (1 if correct else 0)
|
370 |
+
if correct:
|
371 |
+
detMatchedNums.append(detNum)
|
372 |
+
pairs.append({'gt':gtNum,'det':detNum,'correct':correct})
|
373 |
+
evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + " trans. correct: " + str(correct) + "\n"
|
374 |
+
|
375 |
+
if evaluationParams['CONFIDENCES']:
|
376 |
+
for detNum in range(len(detPols)):
|
377 |
+
if detNum not in detDontCarePolsNum :
|
378 |
+
#we exclude the don't care detections
|
379 |
+
match = detNum in detMatchedNums
|
380 |
+
|
381 |
+
arrSampleConfidences.append(confidencesList[detNum])
|
382 |
+
arrSampleMatch.append(match)
|
383 |
+
|
384 |
+
arrGlobalConfidences.append(confidencesList[detNum]);
|
385 |
+
arrGlobalMatches.append(match);
|
386 |
+
|
387 |
+
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
|
388 |
+
numDetCare = (len(detPols) - len(detDontCarePolsNum))
|
389 |
+
if numGtCare == 0:
|
390 |
+
recall = float(1)
|
391 |
+
precision = float(0) if numDetCare >0 else float(1)
|
392 |
+
sampleAP = precision
|
393 |
+
else:
|
394 |
+
recall = float(detCorrect) / numGtCare
|
395 |
+
precision = 0 if numDetCare==0 else float(detCorrect) / numDetCare
|
396 |
+
if evaluationParams['CONFIDENCES']:
|
397 |
+
sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare )
|
398 |
+
|
399 |
+
hmean = 0 if (precision + recall)==0 else 2.0 * precision * recall / (precision + recall)
|
400 |
+
|
401 |
+
matchedSum += detCorrect
|
402 |
+
numGlobalCareGt += numGtCare
|
403 |
+
numGlobalCareDet += numDetCare
|
404 |
+
|
405 |
+
perSampleMetrics[resFile] = {
|
406 |
+
'precision':precision,
|
407 |
+
'recall':recall,
|
408 |
+
'hmean':hmean,
|
409 |
+
'pairs':pairs,
|
410 |
+
'AP':sampleAP,
|
411 |
+
'iouMat':[] if len(detPols)>100 else iouMat.tolist(),
|
412 |
+
'gtPolPoints':gtPolPoints,
|
413 |
+
'detPolPoints':detPolPoints,
|
414 |
+
'gtTrans':gtTrans,
|
415 |
+
'detTrans':detTrans,
|
416 |
+
'gtDontCare':gtDontCarePolsNum,
|
417 |
+
'detDontCare':detDontCarePolsNum,
|
418 |
+
'evaluationParams': evaluationParams,
|
419 |
+
'evaluationLog': evaluationLog
|
420 |
+
}
|
421 |
+
|
422 |
+
# Compute AP
|
423 |
+
AP = 0
|
424 |
+
if evaluationParams['CONFIDENCES']:
|
425 |
+
AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
|
426 |
+
|
427 |
+
methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum)/numGlobalCareGt
|
428 |
+
methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum)/numGlobalCareDet
|
429 |
+
methodHmean = 0 if methodRecall + methodPrecision==0 else 2* methodRecall * methodPrecision / (methodRecall + methodPrecision)
|
430 |
+
|
431 |
+
methodMetrics = {'precision':methodPrecision, 'recall':methodRecall,'hmean': methodHmean, 'AP': AP }
|
432 |
+
|
433 |
+
resDict = {'calculated':True,'Message':'','method': methodMetrics,'per_sample': perSampleMetrics}
|
434 |
+
|
435 |
+
|
436 |
+
return resDict;
|
437 |
+
|
438 |
+
|
439 |
+
|
440 |
+
if __name__=='__main__':
|
441 |
+
'''
|
442 |
+
results_dir: result directory
|
443 |
+
score_det: score of detection bounding box
|
444 |
+
score_rec: score of the mask recognition branch
|
445 |
+
score_rec_seq: score of the sequence recognition branch
|
446 |
+
lexicon_type: 1 for generic; 2 for weak; 3 for strong
|
447 |
+
'''
|
448 |
+
results_dir = '../../../output/mixtrain/inference/icdar_2015_test/model_0250000_1440_results/'
|
449 |
+
lexicon_type = 3
|
450 |
+
score_det = 0.01
|
451 |
+
score_rec = 0.4
|
452 |
+
# score_rec_seq set to 0.7 for lexicon_type 3 or 2; 0.8 for lexicon_type 1
|
453 |
+
score_rec_seq = 0.7
|
454 |
+
evaluate_result_path = prepare_results_for_evaluation(results_dir,
|
455 |
+
lexicon_type=lexicon_type, cache_dir='./cache_files',
|
456 |
+
score_det=score_det, score_rec=score_rec, score_rec_seq=score_rec_seq)
|
457 |
+
p = {
|
458 |
+
'g': "../gt.zip",
|
459 |
+
's': evaluate_result_path
|
460 |
+
}
|
461 |
+
rrc_evaluation_funcs.main_evaluation(p,default_evaluation_params,validate_data,evaluate_method)
|
evaluation/icdar2015/gt.zip
ADDED
Binary file (250 kB). View file
|
|
evaluation/rotated_icdar2013/e2e/prepare_results.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
sys.path.append('./')
|
6 |
+
import shapely
|
7 |
+
from shapely.geometry import Polygon,MultiPoint
|
8 |
+
import numpy as np
|
9 |
+
import editdistance
|
10 |
+
sys.path.append('../../')
|
11 |
+
from weighted_editdistance import weighted_edit_distance
|
12 |
+
from tqdm import tqdm
|
13 |
+
try:
|
14 |
+
import pickle
|
15 |
+
except ImportError:
|
16 |
+
import cPickle as pickle
|
17 |
+
|
18 |
+
def list_from_str(st):
|
19 |
+
line = st.split(',')
|
20 |
+
# box[0:4], polygon[4:12], word, seq_word, detection_score, rec_socre, seq_score, char_score_path
|
21 |
+
new_line = [float(a) for a in line[4:12]]+[float(line[-4])]+[line[-5]]+[line[-6]]+[float(line[-3])]+[float(line[-2])] + [line[-1]]
|
22 |
+
return new_line
|
23 |
+
|
24 |
+
def polygon_from_list(line):
|
25 |
+
"""
|
26 |
+
Create a shapely polygon object from gt or dt line.
|
27 |
+
"""
|
28 |
+
polygon_points = np.array(line).reshape(4, 2)
|
29 |
+
polygon = Polygon(polygon_points).convex_hull
|
30 |
+
return polygon
|
31 |
+
|
32 |
+
def polygon_iou(list1, list2):
|
33 |
+
"""
|
34 |
+
Intersection over union between two shapely polygons.
|
35 |
+
"""
|
36 |
+
polygon_points1 = np.array(list1).reshape(4, 2)
|
37 |
+
poly1 = Polygon(polygon_points1).convex_hull
|
38 |
+
polygon_points2 = np.array(list2).reshape(4, 2)
|
39 |
+
poly2 = Polygon(polygon_points2).convex_hull
|
40 |
+
union_poly = np.concatenate((polygon_points1,polygon_points2))
|
41 |
+
if not poly1.intersects(poly2): # this test is fast and can accelerate calculation
|
42 |
+
iou = 0
|
43 |
+
else:
|
44 |
+
try:
|
45 |
+
inter_area = poly1.intersection(poly2).area
|
46 |
+
#union_area = poly1.area + poly2.area - inter_area
|
47 |
+
union_area = MultiPoint(union_poly).convex_hull.area
|
48 |
+
iou = float(inter_area) / (union_area+1e-6)
|
49 |
+
except shapely.geos.TopologicalError:
|
50 |
+
print('shapely.geos.TopologicalError occured, iou set to 0')
|
51 |
+
iou = 0
|
52 |
+
return iou
|
53 |
+
|
54 |
+
def nms(boxes,overlap):
|
55 |
+
rec_scores = [b[-2] for b in boxes]
|
56 |
+
indices = sorted(range(len(rec_scores)), key=lambda k: -rec_scores[k])
|
57 |
+
box_num = len(boxes)
|
58 |
+
nms_flag = [True]*box_num
|
59 |
+
for i in range(box_num):
|
60 |
+
ii = indices[i]
|
61 |
+
if not nms_flag[ii]:
|
62 |
+
continue
|
63 |
+
for j in range(box_num):
|
64 |
+
jj = indices[j]
|
65 |
+
if j == i:
|
66 |
+
continue
|
67 |
+
if not nms_flag[jj]:
|
68 |
+
continue
|
69 |
+
box1 = boxes[ii]
|
70 |
+
box2 = boxes[jj]
|
71 |
+
box1_score = rec_scores[ii]
|
72 |
+
box2_score = rec_scores[jj]
|
73 |
+
str1 = box1[9]
|
74 |
+
str2 = box2[9]
|
75 |
+
box_i = [box1[0],box1[1],box1[4],box1[5]]
|
76 |
+
box_j = [box2[0],box2[1],box2[4],box2[5]]
|
77 |
+
poly1 = polygon_from_list(box1[0:8])
|
78 |
+
poly2 = polygon_from_list(box2[0:8])
|
79 |
+
iou = polygon_iou(box1[0:8],box2[0:8])
|
80 |
+
thresh = overlap
|
81 |
+
|
82 |
+
if iou > thresh:
|
83 |
+
if box1_score > box2_score:
|
84 |
+
nms_flag[jj] = False
|
85 |
+
if box1_score == box2_score and poly1.area > poly2.area:
|
86 |
+
nms_flag[jj] = False
|
87 |
+
if box1_score == box2_score and poly1.area<=poly2.area:
|
88 |
+
nms_flag[ii] = False
|
89 |
+
break
|
90 |
+
|
91 |
+
return nms_flag
|
92 |
+
|
93 |
+
def packing(save_dir, cache_dir, pack_name):
|
94 |
+
files = os.listdir(save_dir)
|
95 |
+
if not os.path.exists(cache_dir):
|
96 |
+
os.mkdir(cache_dir)
|
97 |
+
os.system('zip -r -q -j '+os.path.join(cache_dir, pack_name+'.zip')+' '+save_dir+'/*')
|
98 |
+
|
99 |
+
def test_single(results_dir,lexicon_type=3,cache_dir='./cache_dir',score_det=0.5,score_rec=0.5,score_rec_seq=0.5,overlap=0.2, use_lexicon=True, weighted_ed=True, use_seq=False, use_char=False, mix=False):
|
100 |
+
'''
|
101 |
+
results_dir: result directory
|
102 |
+
score_det: score of detection bounding box
|
103 |
+
score_rec: score of the mask recognition branch
|
104 |
+
socre_rec_seq: score of the sequence recognition branch
|
105 |
+
overlap: overlap threshold used for nms
|
106 |
+
lexicon_type: 1 for generic; 2 for weak; 3 for strong
|
107 |
+
use_seq: use the recognition result of sequence branch
|
108 |
+
use_mix: use both the recognition result of the mask and sequence branches, selected by score
|
109 |
+
'''
|
110 |
+
print('score_det:', 'score_det:', score_det, 'score_rec:', score_rec, 'score_rec_seq:', score_rec_seq, 'lexicon_type:', lexicon_type, 'weighted_ed:', weighted_ed, 'use_seq:', use_seq, 'use_char:', use_char, 'mix:', mix)
|
111 |
+
if not os.path.exists(cache_dir):
|
112 |
+
os.mkdir(cache_dir)
|
113 |
+
nms_dir = os.path.join(cache_dir,str(score_det)+'_'+str(score_rec)+'_'+str(score_rec_seq))
|
114 |
+
if not os.path.exists(nms_dir):
|
115 |
+
os.mkdir(nms_dir)
|
116 |
+
if lexicon_type==1:
|
117 |
+
# generic lexicon
|
118 |
+
lexicon_path = '../../lexicons/ic13/GenericVocabulary_new.txt'
|
119 |
+
lexicon_fid=open(lexicon_path, 'r')
|
120 |
+
pair_list = open('../../lexicons/ic13/GenericVocabulary_pair_list.txt', 'r')
|
121 |
+
pairs = dict()
|
122 |
+
for line in pair_list.readlines():
|
123 |
+
line=line.strip()
|
124 |
+
word = line.split(' ')[0].upper()
|
125 |
+
word_gt = line[len(word)+1:]
|
126 |
+
pairs[word] = word_gt
|
127 |
+
lexicon_fid=open(lexicon_path, 'r')
|
128 |
+
lexicon=[]
|
129 |
+
for line in lexicon_fid.readlines():
|
130 |
+
line=line.strip()
|
131 |
+
lexicon.append(line)
|
132 |
+
if lexicon_type==2:
|
133 |
+
# weak lexicon
|
134 |
+
lexicon_path = '../../lexicons/ic13/ch4_test_vocabulary_new.txt'
|
135 |
+
lexicon_fid=open(lexicon_path, 'r')
|
136 |
+
pair_list = open('../../lexicons/ic13/ch4_test_vocabulary_pair_list.txt', 'r')
|
137 |
+
pairs = dict()
|
138 |
+
for line in pair_list.readlines():
|
139 |
+
line=line.strip()
|
140 |
+
word = line.split(' ')[0].upper()
|
141 |
+
word_gt = line[len(word)+1:]
|
142 |
+
pairs[word] = word_gt
|
143 |
+
lexicon_fid=open(lexicon_path, 'r')
|
144 |
+
lexicon=[]
|
145 |
+
for line in lexicon_fid.readlines():
|
146 |
+
line=line.strip()
|
147 |
+
lexicon.append(line)
|
148 |
+
|
149 |
+
for i in tqdm(range(1,234)):
|
150 |
+
img = 'img_'+str(i)+'.jpg'
|
151 |
+
gt_img = 'gt_img_'+str(i)+'.txt'
|
152 |
+
if lexicon_type==3:
|
153 |
+
# weak
|
154 |
+
lexicon_path = '../../lexicons/ic13/new_strong_lexicon/new_voc_img_' + str(i) + '.txt'
|
155 |
+
lexicon_fid=open(lexicon_path, 'r')
|
156 |
+
pair_list = open('../../lexicons/ic13/new_strong_lexicon/pair_voc_img_' + str(i) + '.txt', 'r')
|
157 |
+
pairs = dict()
|
158 |
+
for line in pair_list.readlines():
|
159 |
+
line=line.strip()
|
160 |
+
word = line.split(' ')[0].upper()
|
161 |
+
word_gt = line[len(word)+1:]
|
162 |
+
pairs[word] = word_gt
|
163 |
+
lexicon_fid=open(lexicon_path, 'r')
|
164 |
+
lexicon=[]
|
165 |
+
for line in lexicon_fid.readlines():
|
166 |
+
line=line.strip()
|
167 |
+
lexicon.append(line)
|
168 |
+
result_path = os.path.join(results_dir,'res_img_'+str(i)+'.txt')
|
169 |
+
if os.path.isfile(result_path):
|
170 |
+
with open(result_path,'r') as f:
|
171 |
+
dt_lines = [a.strip() for a in f.readlines()]
|
172 |
+
dt_lines = [list_from_str(dt) for dt in dt_lines]
|
173 |
+
else:
|
174 |
+
dt_lines = []
|
175 |
+
dt_lines = [dt for dt in dt_lines if dt[-2]>score_rec_seq and dt[-3]>score_rec and dt[-6]>score_det]
|
176 |
+
nms_flag = nms(dt_lines,overlap)
|
177 |
+
boxes = []
|
178 |
+
for k in range(len(dt_lines)):
|
179 |
+
dt = dt_lines[k]
|
180 |
+
if nms_flag[k]:
|
181 |
+
if dt not in boxes:
|
182 |
+
boxes.append(dt)
|
183 |
+
|
184 |
+
with open(os.path.join(nms_dir,'res_img_'+str(i)+'.txt'),'w') as f:
|
185 |
+
for g in boxes:
|
186 |
+
gt_coors = [int(b) for b in g[0:8]]
|
187 |
+
with open('../../../' + g[-1], "rb") as input_file:
|
188 |
+
# with open(g[-1], "rb") as input_file:
|
189 |
+
dict_scores = pickle.load(input_file)
|
190 |
+
if use_char and use_seq:
|
191 |
+
if g[-2]>g[-3]:
|
192 |
+
word = g[-5]
|
193 |
+
scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1)
|
194 |
+
else:
|
195 |
+
word = g[-4]
|
196 |
+
scores = dict_scores['seg_char_scores']
|
197 |
+
elif use_seq:
|
198 |
+
word = g[-5]
|
199 |
+
scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1)
|
200 |
+
else:
|
201 |
+
word = g[-4]
|
202 |
+
scores = dict_scores['seg_char_scores']
|
203 |
+
if not use_lexicon:
|
204 |
+
match_word = word
|
205 |
+
match_dist = 0.
|
206 |
+
else:
|
207 |
+
match_word, match_dist = find_match_word(word, lexicon, pairs, scores, use_lexicon, weighted_ed)
|
208 |
+
if match_dist<1.5 or lexicon_type==1:
|
209 |
+
gt_coor_strs = [str(a) for a in gt_coors]+ [match_word]
|
210 |
+
f.write(','.join(gt_coor_strs)+'\r\n')
|
211 |
+
|
212 |
+
pack_name = str(score_det)+'_'+str(score_rec)+'_over'+str(overlap)
|
213 |
+
|
214 |
+
packing(nms_dir,cache_dir,pack_name)
|
215 |
+
submit_file_path = os.path.join(cache_dir, pack_name+'.zip')
|
216 |
+
return submit_file_path
|
217 |
+
|
218 |
+
def find_match_word(rec_str, lexicon, pairs, scores_numpy, use_ed = True, weighted_ed = False):
|
219 |
+
if not use_ed:
|
220 |
+
return rec_str
|
221 |
+
rec_str = rec_str.upper()
|
222 |
+
dist_min = 100
|
223 |
+
dist_min_pre = 100
|
224 |
+
match_word = ''
|
225 |
+
match_dist = 100
|
226 |
+
if not weighted_ed:
|
227 |
+
for word in lexicon:
|
228 |
+
word = word.upper()
|
229 |
+
ed = editdistance.eval(rec_str, word)
|
230 |
+
length_dist = abs(len(word) - len(rec_str))
|
231 |
+
# dist = ed + length_dist
|
232 |
+
dist = ed
|
233 |
+
if dist<dist_min:
|
234 |
+
dist_min = dist
|
235 |
+
match_word = pairs[word]
|
236 |
+
match_dist = dist
|
237 |
+
return match_word, match_dist
|
238 |
+
else:
|
239 |
+
small_lexicon_dict = dict()
|
240 |
+
for word in lexicon:
|
241 |
+
word = word.upper()
|
242 |
+
ed = editdistance.eval(rec_str, word)
|
243 |
+
small_lexicon_dict[word] = ed
|
244 |
+
dist = ed
|
245 |
+
if dist<dist_min_pre:
|
246 |
+
dist_min_pre = dist
|
247 |
+
small_lexicon = []
|
248 |
+
for word in small_lexicon_dict:
|
249 |
+
if small_lexicon_dict[word]<=dist_min_pre+2:
|
250 |
+
small_lexicon.append(word)
|
251 |
+
|
252 |
+
for word in small_lexicon:
|
253 |
+
word = word.upper()
|
254 |
+
ed = weighted_edit_distance(rec_str, word, scores_numpy)
|
255 |
+
dist = ed
|
256 |
+
if dist<dist_min:
|
257 |
+
dist_min = dist
|
258 |
+
match_word = pairs[word]
|
259 |
+
match_dist = dist
|
260 |
+
return match_word, match_dist
|
261 |
+
|
262 |
+
|
263 |
+
def prepare_results_for_evaluation(results_dir, use_lexicon, cache_dir, score_det, score_rec, score_rec_seq):
|
264 |
+
if not os.path.isdir(cache_dir):
|
265 |
+
os.mkdir(cache_dir)
|
266 |
+
result_path = test_single(results_dir,score_det=score_det,score_rec=score_rec,score_rec_seq=score_rec_seq,overlap=0.2,cache_dir=cache_dir,lexicon_type=3, use_lexicon=use_lexicon, weighted_ed=True, use_seq=True, use_char=True, mix=True)
|
267 |
+
return result_path
|
evaluation/rotated_icdar2013/e2e/rrc_evaluation_funcs.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python2
|
2 |
+
#encoding: UTF-8
|
3 |
+
import json
|
4 |
+
import sys;sys.path.append('./')
|
5 |
+
import zipfile
|
6 |
+
import re
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
import codecs
|
10 |
+
import importlib
|
11 |
+
try:
|
12 |
+
from StringIO import StringIO
|
13 |
+
except ImportError:
|
14 |
+
from io import StringIO
|
15 |
+
|
16 |
+
def print_help():
|
17 |
+
sys.stdout.write('Usage: python %s.py -g=<gtFile> -s=<submFile> [-o=<outputFolder> -p=<jsonParams>]' %sys.argv[0])
|
18 |
+
sys.exit(2)
|
19 |
+
|
20 |
+
|
21 |
+
def load_zip_file_keys(file,fileNameRegExp=''):
|
22 |
+
"""
|
23 |
+
Returns an array with the entries of the ZIP file that match with the regular expression.
|
24 |
+
The key's are the names or the file or the capturing group definied in the fileNameRegExp
|
25 |
+
"""
|
26 |
+
try:
|
27 |
+
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
|
28 |
+
except :
|
29 |
+
raise Exception('Error loading the ZIP archive.')
|
30 |
+
|
31 |
+
pairs = []
|
32 |
+
|
33 |
+
for name in archive.namelist():
|
34 |
+
addFile = True
|
35 |
+
keyName = name
|
36 |
+
if fileNameRegExp!="":
|
37 |
+
m = re.match(fileNameRegExp,name)
|
38 |
+
if m == None:
|
39 |
+
addFile = False
|
40 |
+
else:
|
41 |
+
if len(m.groups())>0:
|
42 |
+
keyName = m.group(1)
|
43 |
+
|
44 |
+
if addFile:
|
45 |
+
pairs.append( keyName )
|
46 |
+
|
47 |
+
return pairs
|
48 |
+
|
49 |
+
|
50 |
+
def load_zip_file(file,fileNameRegExp='',allEntries=False):
|
51 |
+
"""
|
52 |
+
Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
|
53 |
+
The key's are the names or the file or the capturing group definied in the fileNameRegExp
|
54 |
+
allEntries validates that all entries in the ZIP file pass the fileNameRegExp
|
55 |
+
"""
|
56 |
+
try:
|
57 |
+
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
|
58 |
+
except :
|
59 |
+
raise Exception('Error loading the ZIP archive')
|
60 |
+
|
61 |
+
pairs = []
|
62 |
+
for name in archive.namelist():
|
63 |
+
addFile = True
|
64 |
+
keyName = name
|
65 |
+
if fileNameRegExp!="":
|
66 |
+
m = re.match(fileNameRegExp,name)
|
67 |
+
if m == None:
|
68 |
+
addFile = False
|
69 |
+
else:
|
70 |
+
if len(m.groups())>0:
|
71 |
+
keyName = m.group(1)
|
72 |
+
|
73 |
+
if addFile:
|
74 |
+
pairs.append( [ keyName , archive.read(name)] )
|
75 |
+
else:
|
76 |
+
if allEntries:
|
77 |
+
raise Exception('ZIP entry not valid: %s' %name)
|
78 |
+
|
79 |
+
return dict(pairs)
|
80 |
+
|
81 |
+
def decode_utf8(raw):
|
82 |
+
"""
|
83 |
+
Returns a Unicode object on success, or None on failure
|
84 |
+
"""
|
85 |
+
try:
|
86 |
+
raw = codecs.decode(raw,'utf-8', 'replace')
|
87 |
+
#extracts BOM if exists
|
88 |
+
raw = raw.encode('utf8')
|
89 |
+
if raw.startswith(codecs.BOM_UTF8):
|
90 |
+
raw = raw.replace(codecs.BOM_UTF8, '', 1)
|
91 |
+
return raw.decode('utf-8')
|
92 |
+
except:
|
93 |
+
return None
|
94 |
+
|
95 |
+
def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
|
96 |
+
"""
|
97 |
+
This function validates that all lines of the file calling the Line validation function for each line
|
98 |
+
"""
|
99 |
+
utf8File = decode_utf8(file_contents)
|
100 |
+
if (utf8File is None) :
|
101 |
+
raise Exception("The file %s is not UTF-8" %fileName)
|
102 |
+
|
103 |
+
lines = utf8File.split( "\r\n" if CRLF else "\n" )
|
104 |
+
for line in lines:
|
105 |
+
line = line.replace("\r","").replace("\n","")
|
106 |
+
if(line != ""):
|
107 |
+
try:
|
108 |
+
validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
|
109 |
+
except Exception as e:
|
110 |
+
raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace'))
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0):
|
115 |
+
"""
|
116 |
+
Validate the format of the line. If the line is not valid an exception will be raised.
|
117 |
+
If maxWidth and maxHeight are specified, all points must be inside the imgage bounds.
|
118 |
+
Posible values are:
|
119 |
+
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
|
120 |
+
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
|
121 |
+
"""
|
122 |
+
get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
|
123 |
+
|
124 |
+
|
125 |
+
def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
|
126 |
+
"""
|
127 |
+
Validate the format of the line. If the line is not valid an exception will be raised.
|
128 |
+
If maxWidth and maxHeight are specified, all points must be inside the imgage bounds.
|
129 |
+
Posible values are:
|
130 |
+
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
|
131 |
+
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
|
132 |
+
Returns values from a textline. Points , [Confidences], [Transcriptions]
|
133 |
+
"""
|
134 |
+
confidence = 0.0
|
135 |
+
transcription = "";
|
136 |
+
points = []
|
137 |
+
|
138 |
+
numPoints = 4;
|
139 |
+
|
140 |
+
if LTRB:
|
141 |
+
|
142 |
+
numPoints = 4;
|
143 |
+
|
144 |
+
if withTranscription and withConfidence:
|
145 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
146 |
+
if m == None :
|
147 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
148 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
|
149 |
+
elif withConfidence:
|
150 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
|
151 |
+
if m == None :
|
152 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
|
153 |
+
elif withTranscription:
|
154 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line)
|
155 |
+
if m == None :
|
156 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
|
157 |
+
else:
|
158 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line)
|
159 |
+
if m == None :
|
160 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
|
161 |
+
|
162 |
+
xmin = int(m.group(1))
|
163 |
+
ymin = int(m.group(2))
|
164 |
+
xmax = int(m.group(3))
|
165 |
+
ymax = int(m.group(4))
|
166 |
+
if(xmax<xmin):
|
167 |
+
raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." %(xmax))
|
168 |
+
if(ymax<ymin):
|
169 |
+
raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." %(ymax))
|
170 |
+
|
171 |
+
points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
|
172 |
+
|
173 |
+
if (imWidth>0 and imHeight>0):
|
174 |
+
validate_point_inside_bounds(xmin,ymin,imWidth,imHeight);
|
175 |
+
validate_point_inside_bounds(xmax,ymax,imWidth,imHeight);
|
176 |
+
|
177 |
+
else:
|
178 |
+
|
179 |
+
numPoints = 8;
|
180 |
+
|
181 |
+
if withTranscription and withConfidence:
|
182 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
183 |
+
if m == None :
|
184 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
|
185 |
+
elif withConfidence:
|
186 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
|
187 |
+
if m == None :
|
188 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
|
189 |
+
elif withTranscription:
|
190 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line)
|
191 |
+
if m == None :
|
192 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
|
193 |
+
else:
|
194 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line)
|
195 |
+
if m == None :
|
196 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
|
197 |
+
|
198 |
+
points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
|
199 |
+
|
200 |
+
validate_clockwise_points(points)
|
201 |
+
|
202 |
+
if (imWidth>0 and imHeight>0):
|
203 |
+
validate_point_inside_bounds(points[0],points[1],imWidth,imHeight);
|
204 |
+
validate_point_inside_bounds(points[2],points[3],imWidth,imHeight);
|
205 |
+
validate_point_inside_bounds(points[4],points[5],imWidth,imHeight);
|
206 |
+
validate_point_inside_bounds(points[6],points[7],imWidth,imHeight);
|
207 |
+
|
208 |
+
|
209 |
+
if withConfidence:
|
210 |
+
try:
|
211 |
+
confidence = float(m.group(numPoints+1))
|
212 |
+
except ValueError:
|
213 |
+
raise Exception("Confidence value must be a float")
|
214 |
+
|
215 |
+
if withTranscription:
|
216 |
+
posTranscription = numPoints + (2 if withConfidence else 1)
|
217 |
+
transcription = m.group(posTranscription)
|
218 |
+
m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription)
|
219 |
+
if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters
|
220 |
+
transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
|
221 |
+
|
222 |
+
return points,confidence,transcription
|
223 |
+
|
224 |
+
|
225 |
+
def validate_point_inside_bounds(x,y,imWidth,imHeight):
|
226 |
+
if(x<0 or x>imWidth):
|
227 |
+
raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight))
|
228 |
+
if(y<0 or y>imHeight):
|
229 |
+
raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight))
|
230 |
+
|
231 |
+
def validate_clockwise_points(points):
|
232 |
+
"""
|
233 |
+
Validates that the points that the 4 points that dlimite a polygon are in clockwise order.
|
234 |
+
"""
|
235 |
+
|
236 |
+
if len(points) != 8:
|
237 |
+
raise Exception("Points list not valid." + str(len(points)))
|
238 |
+
|
239 |
+
point = [
|
240 |
+
[int(points[0]) , int(points[1])],
|
241 |
+
[int(points[2]) , int(points[3])],
|
242 |
+
[int(points[4]) , int(points[5])],
|
243 |
+
[int(points[6]) , int(points[7])]
|
244 |
+
]
|
245 |
+
edge = [
|
246 |
+
( point[1][0] - point[0][0])*( point[1][1] + point[0][1]),
|
247 |
+
( point[2][0] - point[1][0])*( point[2][1] + point[1][1]),
|
248 |
+
( point[3][0] - point[2][0])*( point[3][1] + point[2][1]),
|
249 |
+
( point[0][0] - point[3][0])*( point[0][1] + point[3][1])
|
250 |
+
]
|
251 |
+
|
252 |
+
summatory = edge[0] + edge[1] + edge[2] + edge[3];
|
253 |
+
if summatory>0:
|
254 |
+
raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.")
|
255 |
+
|
256 |
+
def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True):
|
257 |
+
"""
|
258 |
+
Returns all points, confindences and transcriptions of a file in lists. Valid line formats:
|
259 |
+
xmin,ymin,xmax,ymax,[confidence],[transcription]
|
260 |
+
x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
|
261 |
+
"""
|
262 |
+
pointsList = []
|
263 |
+
transcriptionsList = []
|
264 |
+
confidencesList = []
|
265 |
+
|
266 |
+
lines = content.split( "\r\n" if CRLF else "\n" )
|
267 |
+
for line in lines:
|
268 |
+
line = line.replace("\r","").replace("\n","")
|
269 |
+
if(line != "") :
|
270 |
+
points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight);
|
271 |
+
pointsList.append(points)
|
272 |
+
transcriptionsList.append(transcription)
|
273 |
+
confidencesList.append(confidence)
|
274 |
+
|
275 |
+
if withConfidence and len(confidencesList)>0 and sort_by_confidences:
|
276 |
+
import numpy as np
|
277 |
+
sorted_ind = np.argsort(-np.array(confidencesList))
|
278 |
+
confidencesList = [confidencesList[i] for i in sorted_ind]
|
279 |
+
pointsList = [pointsList[i] for i in sorted_ind]
|
280 |
+
transcriptionsList = [transcriptionsList[i] for i in sorted_ind]
|
281 |
+
|
282 |
+
return pointsList,confidencesList,transcriptionsList
|
283 |
+
|
284 |
+
def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True):
|
285 |
+
"""
|
286 |
+
This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample.
|
287 |
+
Params:
|
288 |
+
p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used.
|
289 |
+
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
|
290 |
+
validate_data_fn: points to a method that validates the corrct format of the submission
|
291 |
+
evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results
|
292 |
+
"""
|
293 |
+
|
294 |
+
if (p == None):
|
295 |
+
p = dict([s[1:].split('=') for s in sys.argv[1:]])
|
296 |
+
if(len(sys.argv)<3):
|
297 |
+
print_help()
|
298 |
+
|
299 |
+
evalParams = default_evaluation_params_fn()
|
300 |
+
if 'p' in p.keys():
|
301 |
+
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
|
302 |
+
|
303 |
+
resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'}
|
304 |
+
try:
|
305 |
+
validate_data_fn(p['g'], p['s'], evalParams)
|
306 |
+
evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
|
307 |
+
resDict.update(evalData)
|
308 |
+
|
309 |
+
except Exception as e:
|
310 |
+
resDict['Message']= str(e)
|
311 |
+
resDict['calculated']=False
|
312 |
+
|
313 |
+
if 'o' in p:
|
314 |
+
if not os.path.exists(p['o']):
|
315 |
+
os.makedirs(p['o'])
|
316 |
+
|
317 |
+
resultsOutputname = p['o'] + '/results.zip'
|
318 |
+
outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
|
319 |
+
|
320 |
+
del resDict['per_sample']
|
321 |
+
if 'output_items' in resDict.keys():
|
322 |
+
del resDict['output_items']
|
323 |
+
|
324 |
+
outZip.writestr('method.json',json.dumps(resDict))
|
325 |
+
|
326 |
+
if not resDict['calculated']:
|
327 |
+
if show_result:
|
328 |
+
sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n')
|
329 |
+
if 'o' in p:
|
330 |
+
outZip.close()
|
331 |
+
return resDict
|
332 |
+
|
333 |
+
if 'o' in p:
|
334 |
+
if per_sample == True:
|
335 |
+
for k,v in evalData['per_sample'].items():
|
336 |
+
outZip.writestr( k + '.json',json.dumps(v))
|
337 |
+
|
338 |
+
if 'output_items' in evalData.keys():
|
339 |
+
for k, v in evalData['output_items'].items():
|
340 |
+
outZip.writestr( k,v)
|
341 |
+
|
342 |
+
outZip.close()
|
343 |
+
|
344 |
+
if show_result:
|
345 |
+
sys.stdout.write("Calculated!")
|
346 |
+
sys.stdout.write(json.dumps(resDict['method']))
|
347 |
+
|
348 |
+
return resDict
|
349 |
+
|
350 |
+
|
351 |
+
def main_validation(default_evaluation_params_fn,validate_data_fn):
|
352 |
+
"""
|
353 |
+
This process validates a method
|
354 |
+
Params:
|
355 |
+
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
|
356 |
+
validate_data_fn: points to a method that validates the corrct format of the submission
|
357 |
+
"""
|
358 |
+
try:
|
359 |
+
p = dict([s[1:].split('=') for s in sys.argv[1:]])
|
360 |
+
evalParams = default_evaluation_params_fn()
|
361 |
+
if 'p' in p.keys():
|
362 |
+
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
|
363 |
+
|
364 |
+
validate_data_fn(p['g'], p['s'], evalParams)
|
365 |
+
print('SUCCESS')
|
366 |
+
sys.exit(0)
|
367 |
+
except Exception as e:
|
368 |
+
print(str(e))
|
369 |
+
sys.exit(101)
|
evaluation/rotated_icdar2013/e2e/script.py
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# encoding=utf8
|
4 |
+
from collections import namedtuple
|
5 |
+
import rrc_evaluation_funcs
|
6 |
+
import importlib
|
7 |
+
from prepare_results import prepare_results_for_evaluation
|
8 |
+
|
9 |
+
def evaluation_imports():
|
10 |
+
"""
|
11 |
+
evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation.
|
12 |
+
"""
|
13 |
+
return {
|
14 |
+
'Polygon':'plg',
|
15 |
+
'numpy':'np'
|
16 |
+
}
|
17 |
+
|
18 |
+
def default_evaluation_params():
|
19 |
+
"""
|
20 |
+
default_evaluation_params: Default parameters to use for the validation and evaluation.
|
21 |
+
"""
|
22 |
+
return {
|
23 |
+
'IOU_CONSTRAINT' :0.5,
|
24 |
+
'AREA_PRECISION_CONSTRAINT' :0.5,
|
25 |
+
'WORD_SPOTTING' :False,
|
26 |
+
'MIN_LENGTH_CARE_WORD' :3,
|
27 |
+
'GT_SAMPLE_NAME_2_ID':'gt_img_([0-9]+).txt',
|
28 |
+
'DET_SAMPLE_NAME_2_ID':'res_img_([0-9]+).txt',
|
29 |
+
'LTRB':False, #LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4)
|
30 |
+
'CRLF':False, # Lines are delimited by Windows CRLF format
|
31 |
+
'CONFIDENCES':False, #Detections must include confidence value. MAP and MAR will be calculated,
|
32 |
+
'SPECIAL_CHARACTERS':'!?.:,*"()·[]/\'',
|
33 |
+
'ONLY_REMOVE_FIRST_LAST_CHARACTER' : True
|
34 |
+
}
|
35 |
+
|
36 |
+
def validate_data(gtFilePath, submFilePath, evaluationParams):
|
37 |
+
"""
|
38 |
+
Method validate_data: validates that all files in the results folder are correct (have the correct name contents).
|
39 |
+
Validates also that there are no missing files in the folder.
|
40 |
+
If some error detected, the method raises the error
|
41 |
+
"""
|
42 |
+
gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
|
43 |
+
|
44 |
+
subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
|
45 |
+
|
46 |
+
#Validate format of GroundTruth
|
47 |
+
for k in gt:
|
48 |
+
rrc_evaluation_funcs.validate_lines_in_file(k,gt[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True)
|
49 |
+
|
50 |
+
#Validate format of results
|
51 |
+
for k in subm:
|
52 |
+
if (k in gt) == False :
|
53 |
+
raise Exception("The sample %s not present in GT" %k)
|
54 |
+
|
55 |
+
rrc_evaluation_funcs.validate_lines_in_file(k,subm[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES'])
|
56 |
+
|
57 |
+
|
58 |
+
def evaluate_method(gtFilePath, submFilePath, evaluationParams):
|
59 |
+
"""
|
60 |
+
Method evaluate_method: evaluate method and returns the results
|
61 |
+
Results. Dictionary with the following values:
|
62 |
+
- method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
|
63 |
+
- samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 }
|
64 |
+
"""
|
65 |
+
for module,alias in evaluation_imports().items():
|
66 |
+
globals()[alias] = importlib.import_module(module)
|
67 |
+
|
68 |
+
def polygon_from_points(points,correctOffset=False):
|
69 |
+
"""
|
70 |
+
Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
|
71 |
+
"""
|
72 |
+
|
73 |
+
if correctOffset: #this will substract 1 from the coordinates that correspond to the xmax and ymax
|
74 |
+
points[2] -= 1
|
75 |
+
points[4] -= 1
|
76 |
+
points[5] -= 1
|
77 |
+
points[7] -= 1
|
78 |
+
|
79 |
+
resBoxes=np.empty([1,8],dtype='int32')
|
80 |
+
resBoxes[0,0]=int(points[0])
|
81 |
+
resBoxes[0,4]=int(points[1])
|
82 |
+
resBoxes[0,1]=int(points[2])
|
83 |
+
resBoxes[0,5]=int(points[3])
|
84 |
+
resBoxes[0,2]=int(points[4])
|
85 |
+
resBoxes[0,6]=int(points[5])
|
86 |
+
resBoxes[0,3]=int(points[6])
|
87 |
+
resBoxes[0,7]=int(points[7])
|
88 |
+
pointMat = resBoxes[0].reshape([2,4]).T
|
89 |
+
return plg.Polygon( pointMat)
|
90 |
+
|
91 |
+
def rectangle_to_polygon(rect):
|
92 |
+
resBoxes=np.empty([1,8],dtype='int32')
|
93 |
+
resBoxes[0,0]=int(rect.xmin)
|
94 |
+
resBoxes[0,4]=int(rect.ymax)
|
95 |
+
resBoxes[0,1]=int(rect.xmin)
|
96 |
+
resBoxes[0,5]=int(rect.ymin)
|
97 |
+
resBoxes[0,2]=int(rect.xmax)
|
98 |
+
resBoxes[0,6]=int(rect.ymin)
|
99 |
+
resBoxes[0,3]=int(rect.xmax)
|
100 |
+
resBoxes[0,7]=int(rect.ymax)
|
101 |
+
|
102 |
+
pointMat = resBoxes[0].reshape([2,4]).T
|
103 |
+
|
104 |
+
return plg.Polygon( pointMat)
|
105 |
+
|
106 |
+
def rectangle_to_points(rect):
|
107 |
+
points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)]
|
108 |
+
return points
|
109 |
+
|
110 |
+
def get_union(pD,pG):
|
111 |
+
areaA = pD.area();
|
112 |
+
areaB = pG.area();
|
113 |
+
return areaA + areaB - get_intersection(pD, pG);
|
114 |
+
|
115 |
+
def get_intersection_over_union(pD,pG):
|
116 |
+
try:
|
117 |
+
return get_intersection(pD, pG) / get_union(pD, pG);
|
118 |
+
except:
|
119 |
+
return 0
|
120 |
+
|
121 |
+
def get_intersection(pD,pG):
|
122 |
+
pInt = pD & pG
|
123 |
+
if len(pInt) == 0:
|
124 |
+
return 0
|
125 |
+
return pInt.area()
|
126 |
+
|
127 |
+
def compute_ap(confList, matchList,numGtCare):
|
128 |
+
correct = 0
|
129 |
+
AP = 0
|
130 |
+
if len(confList)>0:
|
131 |
+
confList = np.array(confList)
|
132 |
+
matchList = np.array(matchList)
|
133 |
+
sorted_ind = np.argsort(-confList)
|
134 |
+
confList = confList[sorted_ind]
|
135 |
+
matchList = matchList[sorted_ind]
|
136 |
+
for n in range(len(confList)):
|
137 |
+
match = matchList[n]
|
138 |
+
if match:
|
139 |
+
correct += 1
|
140 |
+
AP += float(correct)/(n + 1)
|
141 |
+
|
142 |
+
if numGtCare>0:
|
143 |
+
AP /= numGtCare
|
144 |
+
|
145 |
+
return AP
|
146 |
+
|
147 |
+
def transcription_match(transGt,transDet,specialCharacters='!?.:,*"()·[]/\'',onlyRemoveFirstLastCharacterGT=True):
|
148 |
+
|
149 |
+
if onlyRemoveFirstLastCharacterGT:
|
150 |
+
#special characters in GT are allowed only at initial or final position
|
151 |
+
if (transGt==transDet):
|
152 |
+
return True
|
153 |
+
|
154 |
+
if specialCharacters.find(transGt[0])>-1:
|
155 |
+
if transGt[1:]==transDet:
|
156 |
+
return True
|
157 |
+
|
158 |
+
if specialCharacters.find(transGt[-1])>-1:
|
159 |
+
if transGt[0:len(transGt)-1]==transDet:
|
160 |
+
return True
|
161 |
+
|
162 |
+
if specialCharacters.find(transGt[0])>-1 and specialCharacters.find(transGt[-1])>-1:
|
163 |
+
if transGt[1:len(transGt)-1]==transDet:
|
164 |
+
return True
|
165 |
+
return False
|
166 |
+
else:
|
167 |
+
#Special characters are removed from the begining and the end of both Detection and GroundTruth
|
168 |
+
while len(transGt)>0 and specialCharacters.find(transGt[0])>-1:
|
169 |
+
transGt = transGt[1:]
|
170 |
+
|
171 |
+
while len(transDet)>0 and specialCharacters.find(transDet[0])>-1:
|
172 |
+
transDet = transDet[1:]
|
173 |
+
|
174 |
+
while len(transGt)>0 and specialCharacters.find(transGt[-1])>-1 :
|
175 |
+
transGt = transGt[0:len(transGt)-1]
|
176 |
+
|
177 |
+
while len(transDet)>0 and specialCharacters.find(transDet[-1])>-1:
|
178 |
+
transDet = transDet[0:len(transDet)-1]
|
179 |
+
|
180 |
+
return transGt == transDet
|
181 |
+
|
182 |
+
|
183 |
+
def include_in_dictionary(transcription):
|
184 |
+
"""
|
185 |
+
Function used in Word Spotting that finds if the Ground Truth transcription meets the rules to enter into the dictionary. If not, the transcription will be cared as don't care
|
186 |
+
"""
|
187 |
+
#special case 's at final
|
188 |
+
if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S":
|
189 |
+
transcription = transcription[0:len(transcription)-2]
|
190 |
+
|
191 |
+
#hypens at init or final of the word
|
192 |
+
transcription = transcription.strip('-');
|
193 |
+
|
194 |
+
specialCharacters = "'!?.:,*\"()·[]/";
|
195 |
+
for character in specialCharacters:
|
196 |
+
transcription = transcription.replace(character,' ')
|
197 |
+
|
198 |
+
transcription = transcription.strip()
|
199 |
+
|
200 |
+
if len(transcription) != len(transcription.replace(" ","")) :
|
201 |
+
return False;
|
202 |
+
|
203 |
+
if len(transcription) < evaluationParams['MIN_LENGTH_CARE_WORD']:
|
204 |
+
return False;
|
205 |
+
|
206 |
+
notAllowed = "×÷·";
|
207 |
+
|
208 |
+
range1 = [ ord(u'a'), ord(u'z') ]
|
209 |
+
range2 = [ ord(u'A'), ord(u'Z') ]
|
210 |
+
range3 = [ ord(u'À'), ord(u'ƿ') ]
|
211 |
+
range4 = [ ord(u'DŽ'), ord(u'ɿ') ]
|
212 |
+
range5 = [ ord(u'Ά'), ord(u'Ͽ') ]
|
213 |
+
range6 = [ ord(u'-'), ord(u'-') ]
|
214 |
+
|
215 |
+
for char in transcription :
|
216 |
+
charCode = ord(char)
|
217 |
+
if(notAllowed.find(char) != -1):
|
218 |
+
return False
|
219 |
+
|
220 |
+
valid = ( charCode>=range1[0] and charCode<=range1[1] ) or ( charCode>=range2[0] and charCode<=range2[1] ) or ( charCode>=range3[0] and charCode<=range3[1] ) or ( charCode>=range4[0] and charCode<=range4[1] ) or ( charCode>=range5[0] and charCode<=range5[1] ) or ( charCode>=range6[0] and charCode<=range6[1] )
|
221 |
+
if valid == False:
|
222 |
+
return False
|
223 |
+
|
224 |
+
return True
|
225 |
+
|
226 |
+
def include_in_dictionary_transcription(transcription):
|
227 |
+
"""
|
228 |
+
Function applied to the Ground Truth transcriptions used in Word Spotting. It removes special characters or terminations
|
229 |
+
"""
|
230 |
+
#special case 's at final
|
231 |
+
if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S":
|
232 |
+
transcription = transcription[0:len(transcription)-2]
|
233 |
+
|
234 |
+
#hypens at init or final of the word
|
235 |
+
transcription = transcription.strip('-');
|
236 |
+
|
237 |
+
specialCharacters = "'!?.:,*\"()·[]/";
|
238 |
+
for character in specialCharacters:
|
239 |
+
transcription = transcription.replace(character,' ')
|
240 |
+
|
241 |
+
transcription = transcription.strip()
|
242 |
+
|
243 |
+
return transcription
|
244 |
+
|
245 |
+
perSampleMetrics = {}
|
246 |
+
|
247 |
+
matchedSum = 0
|
248 |
+
|
249 |
+
Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
|
250 |
+
|
251 |
+
gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID'])
|
252 |
+
subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True)
|
253 |
+
|
254 |
+
numGlobalCareGt = 0;
|
255 |
+
numGlobalCareDet = 0;
|
256 |
+
|
257 |
+
arrGlobalConfidences = [];
|
258 |
+
arrGlobalMatches = [];
|
259 |
+
|
260 |
+
for resFile in gt:
|
261 |
+
|
262 |
+
gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile])
|
263 |
+
if (gtFile is None) :
|
264 |
+
raise Exception("The file %s is not UTF-8" %resFile)
|
265 |
+
|
266 |
+
recall = 0
|
267 |
+
precision = 0
|
268 |
+
hmean = 0
|
269 |
+
detCorrect = 0
|
270 |
+
iouMat = np.empty([1,1])
|
271 |
+
gtPols = []
|
272 |
+
detPols = []
|
273 |
+
gtTrans = []
|
274 |
+
detTrans = []
|
275 |
+
gtPolPoints = []
|
276 |
+
detPolPoints = []
|
277 |
+
gtDontCarePolsNum = [] #Array of Ground Truth Polygons' keys marked as don't Care
|
278 |
+
detDontCarePolsNum = [] #Array of Detected Polygons' matched with a don't Care GT
|
279 |
+
detMatchedNums = []
|
280 |
+
pairs = []
|
281 |
+
|
282 |
+
arrSampleConfidences = [];
|
283 |
+
arrSampleMatch = [];
|
284 |
+
sampleAP = 0;
|
285 |
+
|
286 |
+
evaluationLog = ""
|
287 |
+
|
288 |
+
pointsList,_,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,False)
|
289 |
+
for n in range(len(pointsList)):
|
290 |
+
points = pointsList[n]
|
291 |
+
transcription = transcriptionsList[n]
|
292 |
+
dontCare = transcription == "###"
|
293 |
+
if evaluationParams['LTRB']:
|
294 |
+
gtRect = Rectangle(*points)
|
295 |
+
gtPol = rectangle_to_polygon(gtRect)
|
296 |
+
else:
|
297 |
+
gtPol = polygon_from_points(points)
|
298 |
+
gtPols.append(gtPol)
|
299 |
+
gtPolPoints.append(points)
|
300 |
+
|
301 |
+
#On word spotting we will filter some transcriptions with special characters
|
302 |
+
if evaluationParams['WORD_SPOTTING'] :
|
303 |
+
if dontCare == False :
|
304 |
+
if include_in_dictionary(transcription) == False :
|
305 |
+
dontCare = True
|
306 |
+
else:
|
307 |
+
transcription = include_in_dictionary_transcription(transcription)
|
308 |
+
|
309 |
+
gtTrans.append(transcription)
|
310 |
+
if dontCare:
|
311 |
+
gtDontCarePolsNum.append( len(gtPols)-1 )
|
312 |
+
|
313 |
+
evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum)>0 else "\n")
|
314 |
+
|
315 |
+
if resFile in subm:
|
316 |
+
|
317 |
+
detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile])
|
318 |
+
|
319 |
+
pointsList,confidencesList,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES'])
|
320 |
+
|
321 |
+
for n in range(len(pointsList)):
|
322 |
+
points = pointsList[n]
|
323 |
+
transcription = transcriptionsList[n]
|
324 |
+
|
325 |
+
if evaluationParams['LTRB']:
|
326 |
+
detRect = Rectangle(*points)
|
327 |
+
detPol = rectangle_to_polygon(detRect)
|
328 |
+
else:
|
329 |
+
detPol = polygon_from_points(points)
|
330 |
+
detPols.append(detPol)
|
331 |
+
detPolPoints.append(points)
|
332 |
+
detTrans.append(transcription)
|
333 |
+
|
334 |
+
if len(gtDontCarePolsNum)>0 :
|
335 |
+
for dontCarePol in gtDontCarePolsNum:
|
336 |
+
dontCarePol = gtPols[dontCarePol]
|
337 |
+
intersected_area = get_intersection(dontCarePol,detPol)
|
338 |
+
pdDimensions = detPol.area()
|
339 |
+
precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
|
340 |
+
if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ):
|
341 |
+
detDontCarePolsNum.append( len(detPols)-1 )
|
342 |
+
break
|
343 |
+
|
344 |
+
evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum)>0 else "\n")
|
345 |
+
|
346 |
+
if len(gtPols)>0 and len(detPols)>0:
|
347 |
+
#Calculate IoU and precision matrixs
|
348 |
+
outputShape=[len(gtPols),len(detPols)]
|
349 |
+
iouMat = np.empty(outputShape)
|
350 |
+
gtRectMat = np.zeros(len(gtPols),np.int8)
|
351 |
+
detRectMat = np.zeros(len(detPols),np.int8)
|
352 |
+
for gtNum in range(len(gtPols)):
|
353 |
+
for detNum in range(len(detPols)):
|
354 |
+
pG = gtPols[gtNum]
|
355 |
+
pD = detPols[detNum]
|
356 |
+
iouMat[gtNum,detNum] = get_intersection_over_union(pD,pG)
|
357 |
+
|
358 |
+
for gtNum in range(len(gtPols)):
|
359 |
+
for detNum in range(len(detPols)):
|
360 |
+
if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum :
|
361 |
+
if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']:
|
362 |
+
gtRectMat[gtNum] = 1
|
363 |
+
detRectMat[detNum] = 1
|
364 |
+
#detection matched only if transcription is equal
|
365 |
+
if evaluationParams['WORD_SPOTTING']:
|
366 |
+
correct = gtTrans[gtNum].upper() == detTrans[detNum].upper()
|
367 |
+
else:
|
368 |
+
correct = transcription_match(gtTrans[gtNum].upper(),detTrans[detNum].upper(),evaluationParams['SPECIAL_CHARACTERS'],evaluationParams['ONLY_REMOVE_FIRST_LAST_CHARACTER'])==True
|
369 |
+
detCorrect += (1 if correct else 0)
|
370 |
+
if correct:
|
371 |
+
detMatchedNums.append(detNum)
|
372 |
+
pairs.append({'gt':gtNum,'det':detNum,'correct':correct})
|
373 |
+
evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + " trans. correct: " + str(correct) + "\n"
|
374 |
+
|
375 |
+
if evaluationParams['CONFIDENCES']:
|
376 |
+
for detNum in range(len(detPols)):
|
377 |
+
if detNum not in detDontCarePolsNum :
|
378 |
+
#we exclude the don't care detections
|
379 |
+
match = detNum in detMatchedNums
|
380 |
+
|
381 |
+
arrSampleConfidences.append(confidencesList[detNum])
|
382 |
+
arrSampleMatch.append(match)
|
383 |
+
|
384 |
+
arrGlobalConfidences.append(confidencesList[detNum]);
|
385 |
+
arrGlobalMatches.append(match);
|
386 |
+
|
387 |
+
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
|
388 |
+
numDetCare = (len(detPols) - len(detDontCarePolsNum))
|
389 |
+
if numGtCare == 0:
|
390 |
+
recall = float(1)
|
391 |
+
precision = float(0) if numDetCare >0 else float(1)
|
392 |
+
sampleAP = precision
|
393 |
+
else:
|
394 |
+
recall = float(detCorrect) / numGtCare
|
395 |
+
precision = 0 if numDetCare==0 else float(detCorrect) / numDetCare
|
396 |
+
if evaluationParams['CONFIDENCES']:
|
397 |
+
sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare )
|
398 |
+
|
399 |
+
hmean = 0 if (precision + recall)==0 else 2.0 * precision * recall / (precision + recall)
|
400 |
+
|
401 |
+
matchedSum += detCorrect
|
402 |
+
numGlobalCareGt += numGtCare
|
403 |
+
numGlobalCareDet += numDetCare
|
404 |
+
|
405 |
+
perSampleMetrics[resFile] = {
|
406 |
+
'precision':precision,
|
407 |
+
'recall':recall,
|
408 |
+
'hmean':hmean,
|
409 |
+
'pairs':pairs,
|
410 |
+
'AP':sampleAP,
|
411 |
+
'iouMat':[] if len(detPols)>100 else iouMat.tolist(),
|
412 |
+
'gtPolPoints':gtPolPoints,
|
413 |
+
'detPolPoints':detPolPoints,
|
414 |
+
'gtTrans':gtTrans,
|
415 |
+
'detTrans':detTrans,
|
416 |
+
'gtDontCare':gtDontCarePolsNum,
|
417 |
+
'detDontCare':detDontCarePolsNum,
|
418 |
+
'evaluationParams': evaluationParams,
|
419 |
+
'evaluationLog': evaluationLog
|
420 |
+
}
|
421 |
+
|
422 |
+
# Compute AP
|
423 |
+
AP = 0
|
424 |
+
if evaluationParams['CONFIDENCES']:
|
425 |
+
AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
|
426 |
+
|
427 |
+
methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum)/numGlobalCareGt
|
428 |
+
methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum)/numGlobalCareDet
|
429 |
+
methodHmean = 0 if methodRecall + methodPrecision==0 else 2* methodRecall * methodPrecision / (methodRecall + methodPrecision)
|
430 |
+
|
431 |
+
methodMetrics = {'precision':methodPrecision, 'recall':methodRecall,'hmean': methodHmean, 'AP': AP }
|
432 |
+
|
433 |
+
resDict = {'calculated':True,'Message':'','method': methodMetrics,'per_sample': perSampleMetrics}
|
434 |
+
|
435 |
+
|
436 |
+
return resDict;
|
437 |
+
|
438 |
+
|
439 |
+
|
440 |
+
if __name__=='__main__':
|
441 |
+
'''
|
442 |
+
results_dir: result directory
|
443 |
+
score_det: score of detection bounding box
|
444 |
+
score_rec: score of the mask recognition branch
|
445 |
+
score_rec_seq: score of the sequence recognition branch
|
446 |
+
lexicon_type: 1 for generic; 2 for weak; 3 for strong
|
447 |
+
'''
|
448 |
+
angle = 45
|
449 |
+
results_dir = '../../../output/mixtrain/inference/rotated_ic13_test_' + str(angle) + '/model_0250000_1000_results/'
|
450 |
+
score_rec_seq = 0.9
|
451 |
+
score_rec = 0.4
|
452 |
+
score_det = 0.1
|
453 |
+
evaluate_result_path = prepare_results_for_evaluation(results_dir,
|
454 |
+
use_lexicon=False, cache_dir='./cache_files',
|
455 |
+
score_det=score_det, score_rec=score_rec, score_rec_seq=score_rec_seq)
|
456 |
+
p = {
|
457 |
+
'g': '../gt/gt_'+str(angle)+'.zip',
|
458 |
+
's': evaluate_result_path
|
459 |
+
}
|
460 |
+
rrc_evaluation_funcs.main_evaluation(p,default_evaluation_params,validate_data,evaluate_method)
|
evaluation/rotated_icdar2013/gt/gt.zip
ADDED
Binary file (65.2 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_-15.zip
ADDED
Binary file (64.9 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_-30.zip
ADDED
Binary file (65.2 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_-45.zip
ADDED
Binary file (65.2 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_-60.zip
ADDED
Binary file (65.2 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_-75.zip
ADDED
Binary file (64.9 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_-90.zip
ADDED
Binary file (59.9 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_0.zip
ADDED
Binary file (59.6 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_15.zip
ADDED
Binary file (65 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_30.zip
ADDED
Binary file (65.2 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_45.zip
ADDED
Binary file (65.2 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_60.zip
ADDED
Binary file (65.2 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_75.zip
ADDED
Binary file (64.9 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_85.zip
ADDED
Binary file (64.4 kB). View file
|
|
evaluation/rotated_icdar2013/gt/gt_90.zip
ADDED
Binary file (59.7 kB). View file
|
|
evaluation/totaltext/e2e/prepare_results.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
import glob
|
6 |
+
sys.path.append('./')
|
7 |
+
import shapely
|
8 |
+
from shapely.geometry import Polygon,MultiPoint
|
9 |
+
import numpy as np
|
10 |
+
import editdistance
|
11 |
+
sys.path.append('../../')
|
12 |
+
from weighted_editdistance import weighted_edit_distance
|
13 |
+
from tqdm import tqdm
|
14 |
+
try:
|
15 |
+
import pickle
|
16 |
+
except ImportError:
|
17 |
+
import cPickle as pickle
|
18 |
+
|
19 |
+
def list_from_str(st):
|
20 |
+
line = st.split(';')
|
21 |
+
segms = line[1].split(',')
|
22 |
+
scores = line[2].split(',')
|
23 |
+
new_line = [float(a) for a in segms]+[float(scores[-4])]+[scores[-5]]+[scores[-6]]+[float(scores[-3])]+[float(scores[-2])] + [scores[-1]]
|
24 |
+
return new_line
|
25 |
+
|
26 |
+
def polygon_from_list(line):
|
27 |
+
"""
|
28 |
+
Create a shapely polygon object from gt or dt line.
|
29 |
+
"""
|
30 |
+
polygon_points = np.array(line).reshape(-1, 2)
|
31 |
+
polygon = Polygon(polygon_points).convex_hull
|
32 |
+
return polygon
|
33 |
+
|
34 |
+
def polygon_iou(list1, list2):
|
35 |
+
"""
|
36 |
+
Intersection over union between two shapely polygons.
|
37 |
+
"""
|
38 |
+
polygon_points1 = np.array(list1).reshape(-1, 2)
|
39 |
+
poly1 = Polygon(polygon_points1).convex_hull
|
40 |
+
polygon_points2 = np.array(list2).reshape(-1, 2)
|
41 |
+
poly2 = Polygon(polygon_points2).convex_hull
|
42 |
+
union_poly = np.concatenate((polygon_points1,polygon_points2))
|
43 |
+
if not poly1.intersects(poly2): # this test is fast and can accelerate calculation
|
44 |
+
iou = 0
|
45 |
+
else:
|
46 |
+
try:
|
47 |
+
inter_area = poly1.intersection(poly2).area
|
48 |
+
#union_area = poly1.area + poly2.area - inter_area
|
49 |
+
union_area = MultiPoint(union_poly).convex_hull.area
|
50 |
+
iou = float(inter_area) / (union_area+1e-6)
|
51 |
+
except shapely.geos.TopologicalError:
|
52 |
+
print('shapely.geos.TopologicalError occured, iou set to 0')
|
53 |
+
iou = 0
|
54 |
+
return iou
|
55 |
+
|
56 |
+
def nms(boxes,overlap):
|
57 |
+
rec_scores = [b[-6] for b in boxes]
|
58 |
+
indices = sorted(range(len(rec_scores)), key=lambda k: -rec_scores[k])
|
59 |
+
box_num = len(boxes)
|
60 |
+
nms_flag = [True]*box_num
|
61 |
+
for i in range(box_num):
|
62 |
+
ii = indices[i]
|
63 |
+
if not nms_flag[ii]:
|
64 |
+
continue
|
65 |
+
for j in range(box_num):
|
66 |
+
jj = indices[j]
|
67 |
+
if j == i:
|
68 |
+
continue
|
69 |
+
if not nms_flag[jj]:
|
70 |
+
continue
|
71 |
+
box1 = boxes[ii]
|
72 |
+
box2 = boxes[jj]
|
73 |
+
box1_score = rec_scores[ii]
|
74 |
+
box2_score = rec_scores[jj]
|
75 |
+
str1 = box1[9]
|
76 |
+
str2 = box2[9]
|
77 |
+
box_i = [box1[0],box1[1],box1[4],box1[5]]
|
78 |
+
box_j = [box2[0],box2[1],box2[4],box2[5]]
|
79 |
+
poly1 = polygon_from_list(box1[0:-6])
|
80 |
+
poly2 = polygon_from_list(box2[0:-6])
|
81 |
+
iou = polygon_iou(box1[0:-6],box2[0:-6])
|
82 |
+
thresh = overlap
|
83 |
+
|
84 |
+
if iou > thresh:
|
85 |
+
if box1_score > box2_score:
|
86 |
+
nms_flag[jj] = False
|
87 |
+
if box1_score == box2_score and poly1.area > poly2.area:
|
88 |
+
nms_flag[jj] = False
|
89 |
+
if box1_score == box2_score and poly1.area<=poly2.area:
|
90 |
+
nms_flag[ii] = False
|
91 |
+
break
|
92 |
+
|
93 |
+
return nms_flag
|
94 |
+
|
95 |
+
def packing(save_dir, cache_dir, pack_name):
|
96 |
+
files = os.listdir(save_dir)
|
97 |
+
if not os.path.exists(cache_dir):
|
98 |
+
os.mkdir(cache_dir)
|
99 |
+
os.system('zip -r -q -j '+os.path.join(cache_dir, pack_name+'.zip')+' '+save_dir+'/*')
|
100 |
+
|
101 |
+
def test_single(results_dir,lexicon_type=3,cache_dir='./cache_dir',score_det=0.5,score_rec=0.5,score_rec_seq=0.5,overlap=0.2, use_lexicon=True, weighted_ed=True, use_seq=False, use_char=False, mix=False):
|
102 |
+
'''
|
103 |
+
results_dir: result directory
|
104 |
+
score_det: score of detection bounding box
|
105 |
+
score_rec: score of the mask recognition branch
|
106 |
+
socre_rec_seq: score of the sequence recognition branch
|
107 |
+
overlap: overlap threshold used for nms
|
108 |
+
lexicon_type: 1 for generic; 2 for weak; 3 for strong
|
109 |
+
use_seq: use the recognition result of sequence branch
|
110 |
+
use_mix: use both the recognition result of the mask and sequence branches, selected by score
|
111 |
+
'''
|
112 |
+
print('score_det:', 'score_det:', score_det, 'score_rec:', score_rec, 'score_rec_seq:', score_rec_seq, 'overlap:', overlap,'lexicon_type:', lexicon_type, 'weighted_ed:', weighted_ed, 'use_seq:', use_seq, 'use_char:', use_char, 'mix:', mix)
|
113 |
+
if not os.path.exists(cache_dir):
|
114 |
+
os.mkdir(cache_dir)
|
115 |
+
nms_dir = os.path.join(cache_dir,str(score_det)+'_'+str(score_rec)+'_'+str(score_rec_seq))
|
116 |
+
if not os.path.exists(nms_dir):
|
117 |
+
os.mkdir(nms_dir)
|
118 |
+
if use_lexicon and lexicon_type==2:
|
119 |
+
# weak lexicon
|
120 |
+
lexicon_path = '../../lexicons/totaltext/weak_voc_new.txt'
|
121 |
+
lexicon_fid=open(lexicon_path, 'r')
|
122 |
+
pair_list = open('../../lexicons/totaltext/weak_voc_pair_list.txt', 'r')
|
123 |
+
pairs = dict()
|
124 |
+
for line in pair_list.readlines():
|
125 |
+
line=line.strip()
|
126 |
+
word = line.split(' ')[0].upper()
|
127 |
+
word_gt = line[len(word)+1:]
|
128 |
+
pairs[word] = word_gt
|
129 |
+
lexicon_fid=open(lexicon_path, 'r')
|
130 |
+
lexicon=[]
|
131 |
+
for line in lexicon_fid.readlines():
|
132 |
+
line=line.strip()
|
133 |
+
lexicon.append(line)
|
134 |
+
|
135 |
+
for res_file in glob.glob("*.txt"):
|
136 |
+
result_path = os.path.join(results_dir,res_file)
|
137 |
+
if os.path.isfile(result_path):
|
138 |
+
with open(result_path,'r') as f:
|
139 |
+
dt_lines = [a.strip() for a in f.readlines()]
|
140 |
+
dt_lines = [list_from_str(dt) for dt in dt_lines]
|
141 |
+
else:
|
142 |
+
dt_lines = []
|
143 |
+
dt_lines = [dt for dt in dt_lines if dt[-2]>score_rec_seq and dt[-3]>score_rec and dt[-6]>score_det]
|
144 |
+
nms_flag = nms(dt_lines,overlap)
|
145 |
+
boxes = []
|
146 |
+
for k in range(len(dt_lines)):
|
147 |
+
dt = dt_lines[k]
|
148 |
+
if nms_flag[k]:
|
149 |
+
if dt not in boxes:
|
150 |
+
boxes.append(dt)
|
151 |
+
|
152 |
+
with open(os.path.join(nms_dir,'gt_'+res_file.split('.')[0].split('_')[1]+'.txt'),'w') as f:
|
153 |
+
for g in boxes:
|
154 |
+
gt_coors = [int(b) for b in g[0:-6]]
|
155 |
+
with open('../../../' + g[-1], "rb") as input_file:
|
156 |
+
dict_scores = pickle.load(input_file)
|
157 |
+
if use_char and use_seq:
|
158 |
+
if g[-2]>g[-3]:
|
159 |
+
word = g[-5]
|
160 |
+
scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1)
|
161 |
+
else:
|
162 |
+
word = g[-4]
|
163 |
+
scores = dict_scores['seg_char_scores']
|
164 |
+
elif use_seq:
|
165 |
+
word = g[-5]
|
166 |
+
scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1)
|
167 |
+
else:
|
168 |
+
word = g[-4]
|
169 |
+
scores = dict_scores['seg_char_scores']
|
170 |
+
if not use_lexicon:
|
171 |
+
match_word = word
|
172 |
+
match_dist = 0.
|
173 |
+
else:
|
174 |
+
match_word, match_dist = find_match_word(word, pairs, scores, use_lexicon, weighted_ed, lexicon)
|
175 |
+
if match_dist<1.5 or lexicon_type==1:
|
176 |
+
gt_coor_strs = [str(a) for a in gt_coors]+ [match_word]
|
177 |
+
f.write(','.join(gt_coor_strs)+'\r\n')
|
178 |
+
|
179 |
+
pack_name = str(score_det)+'_'+str(score_rec)+'_over'+str(overlap)
|
180 |
+
|
181 |
+
packing(nms_dir,cache_dir,pack_name)
|
182 |
+
submit_file_path = os.path.join(cache_dir, pack_name+'.zip')
|
183 |
+
return submit_file_path
|
184 |
+
|
185 |
+
def find_match_word(rec_str, pairs, scores_numpy, use_ed=True, weighted_ed=False, lexicon=None):
|
186 |
+
if not use_ed:
|
187 |
+
return rec_str
|
188 |
+
rec_str = rec_str.upper()
|
189 |
+
dist_min = 100
|
190 |
+
dist_min_pre = 100
|
191 |
+
match_word = ''
|
192 |
+
match_dist = 100
|
193 |
+
if not weighted_ed:
|
194 |
+
for word in lexicon:
|
195 |
+
word = word.upper()
|
196 |
+
ed = editdistance.eval(rec_str, word)
|
197 |
+
length_dist = abs(len(word) - len(rec_str))
|
198 |
+
# dist = ed + length_dist
|
199 |
+
dist = ed
|
200 |
+
if dist<dist_min:
|
201 |
+
dist_min = dist
|
202 |
+
match_word = pairs[word]
|
203 |
+
match_dist = dist
|
204 |
+
return match_word, match_dist
|
205 |
+
else:
|
206 |
+
small_lexicon_dict = dict()
|
207 |
+
for word in lexicon:
|
208 |
+
word = word.upper()
|
209 |
+
ed = editdistance.eval(rec_str, word)
|
210 |
+
small_lexicon_dict[word] = ed
|
211 |
+
dist = ed
|
212 |
+
if dist<dist_min_pre:
|
213 |
+
dist_min_pre = dist
|
214 |
+
small_lexicon = []
|
215 |
+
for word in small_lexicon_dict:
|
216 |
+
if small_lexicon_dict[word]<=dist_min_pre+2:
|
217 |
+
small_lexicon.append(word)
|
218 |
+
|
219 |
+
for word in small_lexicon:
|
220 |
+
word = word.upper()
|
221 |
+
ed = weighted_edit_distance(rec_str, word, scores_numpy)
|
222 |
+
dist = ed
|
223 |
+
if dist<dist_min:
|
224 |
+
dist_min = dist
|
225 |
+
match_word = pairs[word]
|
226 |
+
match_dist = dist
|
227 |
+
return match_word, match_dist
|
228 |
+
|
229 |
+
|
230 |
+
def prepare_results_for_evaluation(results_dir, use_lexicon, cache_dir, score_det, score_rec, score_rec_seq):
|
231 |
+
if not os.path.isdir(cache_dir):
|
232 |
+
os.mkdir(cache_dir)
|
233 |
+
result_path = test_single(results_dir,score_det=score_det,score_rec=score_rec,score_rec_seq=score_rec_seq,overlap=0.2,cache_dir=cache_dir,lexicon_type=2, use_lexicon=use_lexicon, weighted_ed=True, use_seq=True, use_char=True, mix=True)
|
234 |
+
return result_path
|
evaluation/totaltext/e2e/rrc_evaluation_funcs.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python2
|
2 |
+
#encoding: UTF-8
|
3 |
+
import json
|
4 |
+
import sys;sys.path.append('./')
|
5 |
+
import zipfile
|
6 |
+
import re
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
import codecs
|
10 |
+
import importlib
|
11 |
+
try:
|
12 |
+
from StringIO import StringIO
|
13 |
+
except ImportError:
|
14 |
+
from io import StringIO
|
15 |
+
|
16 |
+
def print_help():
|
17 |
+
sys.stdout.write('Usage: python %s.py -g=<gtFile> -s=<submFile> [-o=<outputFolder> -p=<jsonParams>]' %sys.argv[0])
|
18 |
+
sys.exit(2)
|
19 |
+
|
20 |
+
|
21 |
+
def load_zip_file_keys(file,fileNameRegExp=''):
|
22 |
+
"""
|
23 |
+
Returns an array with the entries of the ZIP file that match with the regular expression.
|
24 |
+
The key's are the names or the file or the capturing group definied in the fileNameRegExp
|
25 |
+
"""
|
26 |
+
try:
|
27 |
+
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
|
28 |
+
except :
|
29 |
+
raise Exception('Error loading the ZIP archive.')
|
30 |
+
|
31 |
+
pairs = []
|
32 |
+
|
33 |
+
for name in archive.namelist():
|
34 |
+
addFile = True
|
35 |
+
keyName = name
|
36 |
+
if fileNameRegExp!="":
|
37 |
+
m = re.match(fileNameRegExp,name)
|
38 |
+
if m == None:
|
39 |
+
addFile = False
|
40 |
+
else:
|
41 |
+
if len(m.groups())>0:
|
42 |
+
keyName = m.group(1)
|
43 |
+
|
44 |
+
if addFile:
|
45 |
+
pairs.append( keyName )
|
46 |
+
|
47 |
+
return pairs
|
48 |
+
|
49 |
+
|
50 |
+
def load_zip_file(file,fileNameRegExp='',allEntries=False):
|
51 |
+
"""
|
52 |
+
Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
|
53 |
+
The key's are the names or the file or the capturing group definied in the fileNameRegExp
|
54 |
+
allEntries validates that all entries in the ZIP file pass the fileNameRegExp
|
55 |
+
"""
|
56 |
+
try:
|
57 |
+
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
|
58 |
+
except :
|
59 |
+
raise Exception('Error loading the ZIP archive')
|
60 |
+
|
61 |
+
pairs = []
|
62 |
+
for name in archive.namelist():
|
63 |
+
addFile = True
|
64 |
+
keyName = name
|
65 |
+
if fileNameRegExp!="":
|
66 |
+
m = re.match(fileNameRegExp,name)
|
67 |
+
if m == None:
|
68 |
+
addFile = False
|
69 |
+
else:
|
70 |
+
if len(m.groups())>0:
|
71 |
+
keyName = m.group(1)
|
72 |
+
|
73 |
+
if addFile:
|
74 |
+
pairs.append( [ keyName , archive.read(name)] )
|
75 |
+
else:
|
76 |
+
if allEntries:
|
77 |
+
raise Exception('ZIP entry not valid: %s' %name)
|
78 |
+
|
79 |
+
return dict(pairs)
|
80 |
+
|
81 |
+
def decode_utf8(raw):
|
82 |
+
"""
|
83 |
+
Returns a Unicode object on success, or None on failure
|
84 |
+
"""
|
85 |
+
try:
|
86 |
+
raw = codecs.decode(raw,'utf-8', 'replace')
|
87 |
+
#extracts BOM if exists
|
88 |
+
raw = raw.encode('utf8')
|
89 |
+
if raw.startswith(codecs.BOM_UTF8):
|
90 |
+
raw = raw.replace(codecs.BOM_UTF8, '', 1)
|
91 |
+
return raw.decode('utf-8')
|
92 |
+
except:
|
93 |
+
return None
|
94 |
+
|
95 |
+
def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
|
96 |
+
"""
|
97 |
+
This function validates that all lines of the file calling the Line validation function for each line
|
98 |
+
"""
|
99 |
+
utf8File = decode_utf8(file_contents)
|
100 |
+
if (utf8File is None) :
|
101 |
+
raise Exception("The file %s is not UTF-8" %fileName)
|
102 |
+
|
103 |
+
lines = utf8File.split( "\r\n" if CRLF else "\n" )
|
104 |
+
for line in lines:
|
105 |
+
line = line.replace("\r","").replace("\n","")
|
106 |
+
if(line != ""):
|
107 |
+
try:
|
108 |
+
validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
|
109 |
+
except Exception as e:
|
110 |
+
raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace'))
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0):
|
115 |
+
"""
|
116 |
+
Validate the format of the line. If the line is not valid an exception will be raised.
|
117 |
+
If maxWidth and maxHeight are specified, all points must be inside the imgage bounds.
|
118 |
+
Posible values are:
|
119 |
+
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
|
120 |
+
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
|
121 |
+
"""
|
122 |
+
get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
|
123 |
+
|
124 |
+
|
125 |
+
def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
|
126 |
+
"""
|
127 |
+
Validate the format of the line. If the line is not valid an exception will be raised.
|
128 |
+
If maxWidth and maxHeight are specified, all points must be inside the imgage bounds.
|
129 |
+
Posible values are:
|
130 |
+
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
|
131 |
+
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
|
132 |
+
Returns values from a textline. Points , [Confidences], [Transcriptions]
|
133 |
+
"""
|
134 |
+
confidence = 0.0
|
135 |
+
transcription = "";
|
136 |
+
points = []
|
137 |
+
|
138 |
+
numPoints = 4;
|
139 |
+
|
140 |
+
if LTRB:
|
141 |
+
|
142 |
+
numPoints = 4;
|
143 |
+
|
144 |
+
if withTranscription and withConfidence:
|
145 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
146 |
+
if m == None :
|
147 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
148 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
|
149 |
+
elif withConfidence:
|
150 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
|
151 |
+
if m == None :
|
152 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
|
153 |
+
elif withTranscription:
|
154 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line)
|
155 |
+
if m == None :
|
156 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
|
157 |
+
else:
|
158 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line)
|
159 |
+
if m == None :
|
160 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
|
161 |
+
|
162 |
+
xmin = int(m.group(1))
|
163 |
+
ymin = int(m.group(2))
|
164 |
+
xmax = int(m.group(3))
|
165 |
+
ymax = int(m.group(4))
|
166 |
+
if(xmax<xmin):
|
167 |
+
raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." %(xmax))
|
168 |
+
if(ymax<ymin):
|
169 |
+
raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." %(ymax))
|
170 |
+
|
171 |
+
points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
|
172 |
+
|
173 |
+
if (imWidth>0 and imHeight>0):
|
174 |
+
validate_point_inside_bounds(xmin,ymin,imWidth,imHeight);
|
175 |
+
validate_point_inside_bounds(xmax,ymax,imWidth,imHeight);
|
176 |
+
|
177 |
+
else:
|
178 |
+
|
179 |
+
numPoints = 8;
|
180 |
+
|
181 |
+
if withTranscription and withConfidence:
|
182 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
183 |
+
if m == None :
|
184 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
|
185 |
+
elif withConfidence:
|
186 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
|
187 |
+
if m == None :
|
188 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
|
189 |
+
elif withTranscription:
|
190 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line)
|
191 |
+
if m == None :
|
192 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
|
193 |
+
else:
|
194 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line)
|
195 |
+
if m == None :
|
196 |
+
raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
|
197 |
+
|
198 |
+
points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
|
199 |
+
|
200 |
+
validate_clockwise_points(points)
|
201 |
+
|
202 |
+
if (imWidth>0 and imHeight>0):
|
203 |
+
validate_point_inside_bounds(points[0],points[1],imWidth,imHeight);
|
204 |
+
validate_point_inside_bounds(points[2],points[3],imWidth,imHeight);
|
205 |
+
validate_point_inside_bounds(points[4],points[5],imWidth,imHeight);
|
206 |
+
validate_point_inside_bounds(points[6],points[7],imWidth,imHeight);
|
207 |
+
|
208 |
+
|
209 |
+
if withConfidence:
|
210 |
+
try:
|
211 |
+
confidence = float(m.group(numPoints+1))
|
212 |
+
except ValueError:
|
213 |
+
raise Exception("Confidence value must be a float")
|
214 |
+
|
215 |
+
if withTranscription:
|
216 |
+
posTranscription = numPoints + (2 if withConfidence else 1)
|
217 |
+
transcription = m.group(posTranscription)
|
218 |
+
m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription)
|
219 |
+
if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters
|
220 |
+
transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
|
221 |
+
|
222 |
+
return points,confidence,transcription
|
223 |
+
|
224 |
+
|
225 |
+
def validate_point_inside_bounds(x,y,imWidth,imHeight):
|
226 |
+
if(x<0 or x>imWidth):
|
227 |
+
raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight))
|
228 |
+
if(y<0 or y>imHeight):
|
229 |
+
raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight))
|
230 |
+
|
231 |
+
def validate_clockwise_points(points):
|
232 |
+
"""
|
233 |
+
Validates that the points that the 4 points that dlimite a polygon are in clockwise order.
|
234 |
+
"""
|
235 |
+
|
236 |
+
if len(points) != 8:
|
237 |
+
raise Exception("Points list not valid." + str(len(points)))
|
238 |
+
|
239 |
+
point = [
|
240 |
+
[int(points[0]) , int(points[1])],
|
241 |
+
[int(points[2]) , int(points[3])],
|
242 |
+
[int(points[4]) , int(points[5])],
|
243 |
+
[int(points[6]) , int(points[7])]
|
244 |
+
]
|
245 |
+
edge = [
|
246 |
+
( point[1][0] - point[0][0])*( point[1][1] + point[0][1]),
|
247 |
+
( point[2][0] - point[1][0])*( point[2][1] + point[1][1]),
|
248 |
+
( point[3][0] - point[2][0])*( point[3][1] + point[2][1]),
|
249 |
+
( point[0][0] - point[3][0])*( point[0][1] + point[3][1])
|
250 |
+
]
|
251 |
+
|
252 |
+
summatory = edge[0] + edge[1] + edge[2] + edge[3];
|
253 |
+
if summatory>0:
|
254 |
+
raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.")
|
255 |
+
|
256 |
+
def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True):
|
257 |
+
"""
|
258 |
+
Returns all points, confindences and transcriptions of a file in lists. Valid line formats:
|
259 |
+
xmin,ymin,xmax,ymax,[confidence],[transcription]
|
260 |
+
x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
|
261 |
+
"""
|
262 |
+
pointsList = []
|
263 |
+
transcriptionsList = []
|
264 |
+
confidencesList = []
|
265 |
+
|
266 |
+
lines = content.split( "\r\n" if CRLF else "\n" )
|
267 |
+
for line in lines:
|
268 |
+
line = line.replace("\r","").replace("\n","")
|
269 |
+
if(line != "") :
|
270 |
+
points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight);
|
271 |
+
pointsList.append(points)
|
272 |
+
transcriptionsList.append(transcription)
|
273 |
+
confidencesList.append(confidence)
|
274 |
+
|
275 |
+
if withConfidence and len(confidencesList)>0 and sort_by_confidences:
|
276 |
+
import numpy as np
|
277 |
+
sorted_ind = np.argsort(-np.array(confidencesList))
|
278 |
+
confidencesList = [confidencesList[i] for i in sorted_ind]
|
279 |
+
pointsList = [pointsList[i] for i in sorted_ind]
|
280 |
+
transcriptionsList = [transcriptionsList[i] for i in sorted_ind]
|
281 |
+
|
282 |
+
return pointsList,confidencesList,transcriptionsList
|
283 |
+
|
284 |
+
def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True):
|
285 |
+
"""
|
286 |
+
This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample.
|
287 |
+
Params:
|
288 |
+
p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used.
|
289 |
+
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
|
290 |
+
validate_data_fn: points to a method that validates the corrct format of the submission
|
291 |
+
evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results
|
292 |
+
"""
|
293 |
+
|
294 |
+
if (p == None):
|
295 |
+
p = dict([s[1:].split('=') for s in sys.argv[1:]])
|
296 |
+
if(len(sys.argv)<3):
|
297 |
+
print_help()
|
298 |
+
|
299 |
+
evalParams = default_evaluation_params_fn()
|
300 |
+
if 'p' in p.keys():
|
301 |
+
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
|
302 |
+
|
303 |
+
resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'}
|
304 |
+
try:
|
305 |
+
validate_data_fn(p['g'], p['s'], evalParams)
|
306 |
+
evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
|
307 |
+
resDict.update(evalData)
|
308 |
+
|
309 |
+
except Exception as e:
|
310 |
+
resDict['Message']= str(e)
|
311 |
+
resDict['calculated']=False
|
312 |
+
|
313 |
+
if 'o' in p:
|
314 |
+
if not os.path.exists(p['o']):
|
315 |
+
os.makedirs(p['o'])
|
316 |
+
|
317 |
+
resultsOutputname = p['o'] + '/results.zip'
|
318 |
+
outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
|
319 |
+
|
320 |
+
del resDict['per_sample']
|
321 |
+
if 'output_items' in resDict.keys():
|
322 |
+
del resDict['output_items']
|
323 |
+
|
324 |
+
outZip.writestr('method.json',json.dumps(resDict))
|
325 |
+
|
326 |
+
if not resDict['calculated']:
|
327 |
+
if show_result:
|
328 |
+
sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n')
|
329 |
+
if 'o' in p:
|
330 |
+
outZip.close()
|
331 |
+
return resDict
|
332 |
+
|
333 |
+
if 'o' in p:
|
334 |
+
if per_sample == True:
|
335 |
+
for k,v in evalData['per_sample'].items():
|
336 |
+
outZip.writestr( k + '.json',json.dumps(v))
|
337 |
+
|
338 |
+
if 'output_items' in evalData.keys():
|
339 |
+
for k, v in evalData['output_items'].items():
|
340 |
+
outZip.writestr( k,v)
|
341 |
+
|
342 |
+
outZip.close()
|
343 |
+
|
344 |
+
if show_result:
|
345 |
+
sys.stdout.write("Calculated!")
|
346 |
+
sys.stdout.write(json.dumps(resDict['method']))
|
347 |
+
|
348 |
+
return resDict
|
349 |
+
|
350 |
+
|
351 |
+
def main_validation(default_evaluation_params_fn,validate_data_fn):
|
352 |
+
"""
|
353 |
+
This process validates a method
|
354 |
+
Params:
|
355 |
+
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
|
356 |
+
validate_data_fn: points to a method that validates the corrct format of the submission
|
357 |
+
"""
|
358 |
+
try:
|
359 |
+
p = dict([s[1:].split('=') for s in sys.argv[1:]])
|
360 |
+
evalParams = default_evaluation_params_fn()
|
361 |
+
if 'p' in p.keys():
|
362 |
+
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
|
363 |
+
|
364 |
+
validate_data_fn(p['g'], p['s'], evalParams)
|
365 |
+
print('SUCCESS')
|
366 |
+
sys.exit(0)
|
367 |
+
except Exception as e:
|
368 |
+
print(str(e))
|
369 |
+
sys.exit(101)
|
evaluation/totaltext/e2e/rrc_evaluation_funcs_total_text.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python2
|
2 |
+
#encoding: UTF-8
|
3 |
+
import json
|
4 |
+
import sys;sys.path.append('./')
|
5 |
+
import zipfile
|
6 |
+
import re
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
import codecs
|
10 |
+
import importlib
|
11 |
+
from io import StringIO
|
12 |
+
|
13 |
+
def print_help():
|
14 |
+
sys.stdout.write('Usage: python %s.py -g=<gtFile> -s=<submFile> -o=<outputFolder> [-i=<gtImagesFile> -p=<jsonParams>]' %sys.argv[0])
|
15 |
+
sys.exit(2)
|
16 |
+
|
17 |
+
|
18 |
+
def load_zip_file_keys(file,fileNameRegExp=''):
|
19 |
+
"""
|
20 |
+
Returns an array with the entries of the ZIP file that match with the regular expression.
|
21 |
+
The key's are the names or the file or the capturing group definied in the fileNameRegExp
|
22 |
+
"""
|
23 |
+
try:
|
24 |
+
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
|
25 |
+
except :
|
26 |
+
raise Exception('Error loading the ZIP archive.')
|
27 |
+
|
28 |
+
pairs = []
|
29 |
+
|
30 |
+
for name in archive.namelist():
|
31 |
+
addFile = True
|
32 |
+
keyName = name
|
33 |
+
# if fileNameRegExp!="":
|
34 |
+
# m = re.match(fileNameRegExp,name)
|
35 |
+
# if m == None:
|
36 |
+
# addFile = False
|
37 |
+
# else:
|
38 |
+
# if len(m.groups())>0:
|
39 |
+
# keyName = m.group(1)
|
40 |
+
|
41 |
+
if addFile:
|
42 |
+
pairs.append( keyName )
|
43 |
+
|
44 |
+
return pairs
|
45 |
+
|
46 |
+
|
47 |
+
def load_zip_file(file,fileNameRegExp='',allEntries=False):
|
48 |
+
"""
|
49 |
+
Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
|
50 |
+
The key's are the names or the file or the capturing group definied in the fileNameRegExp
|
51 |
+
allEntries validates that all entries in the ZIP file pass the fileNameRegExp
|
52 |
+
"""
|
53 |
+
try:
|
54 |
+
archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
|
55 |
+
except :
|
56 |
+
raise Exception('Error loading the ZIP archive')
|
57 |
+
|
58 |
+
pairs = []
|
59 |
+
for name in archive.namelist():
|
60 |
+
addFile = True
|
61 |
+
keyName = name
|
62 |
+
# if fileNameRegExp!="":
|
63 |
+
# m = re.match(fileNameRegExp,name)
|
64 |
+
# if m == None:
|
65 |
+
# addFile = False
|
66 |
+
# else:
|
67 |
+
# if len(m.groups())>0:
|
68 |
+
# keyName = m.group(1)
|
69 |
+
|
70 |
+
if addFile:
|
71 |
+
pairs.append( [ keyName , archive.read(name)] )
|
72 |
+
else:
|
73 |
+
if allEntries:
|
74 |
+
raise Exception('ZIP entry not valid: %s' %name)
|
75 |
+
|
76 |
+
return dict(pairs)
|
77 |
+
|
78 |
+
def decode_utf8(raw):
|
79 |
+
"""
|
80 |
+
Returns a Unicode object on success, or None on failure
|
81 |
+
"""
|
82 |
+
try:
|
83 |
+
raw = codecs.decode(raw,'utf-8', 'replace')
|
84 |
+
#extracts BOM if exists
|
85 |
+
raw = raw.encode('utf8')
|
86 |
+
if raw.startswith(codecs.BOM_UTF8):
|
87 |
+
raw = raw.replace(codecs.BOM_UTF8, '', 1)
|
88 |
+
return raw.decode('utf-8')
|
89 |
+
except:
|
90 |
+
return None
|
91 |
+
|
92 |
+
def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
|
93 |
+
"""
|
94 |
+
This function validates that all lines of the file calling the Line validation function for each line
|
95 |
+
"""
|
96 |
+
utf8File = decode_utf8(file_contents)
|
97 |
+
if (utf8File is None) :
|
98 |
+
raise Exception("The file %s is not UTF-8" %fileName)
|
99 |
+
|
100 |
+
lines = utf8File.split( "\r\n" if CRLF else "\n" )
|
101 |
+
for line in lines:
|
102 |
+
line = line.replace("\r","").replace("\n","")
|
103 |
+
if(line != ""):
|
104 |
+
try:
|
105 |
+
validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
|
106 |
+
except Exception as e:
|
107 |
+
raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace'))
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0):
|
112 |
+
"""
|
113 |
+
Validate the format of the line. If the line is not valid an exception will be raised.
|
114 |
+
If maxWidth and maxHeight are specified, all points must be inside the imgage bounds.
|
115 |
+
Posible values are:
|
116 |
+
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
|
117 |
+
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
|
118 |
+
"""
|
119 |
+
get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
|
120 |
+
|
121 |
+
|
122 |
+
def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
|
123 |
+
"""
|
124 |
+
Validate the format of the line. If the line is not valid an exception will be raised.
|
125 |
+
If maxWidth and maxHeight are specified, all points must be inside the imgage bounds.
|
126 |
+
Posible values are:
|
127 |
+
LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
|
128 |
+
LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
|
129 |
+
Returns values from a textline. Points , [Confidences], [Transcriptions]
|
130 |
+
"""
|
131 |
+
confidence = 0.0
|
132 |
+
transcription = "";
|
133 |
+
points = []
|
134 |
+
|
135 |
+
numPoints = 4;
|
136 |
+
if LTRB:
|
137 |
+
|
138 |
+
numPoints = 4;
|
139 |
+
|
140 |
+
if withTranscription and withConfidence:
|
141 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
142 |
+
if m == None :
|
143 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
144 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
|
145 |
+
elif withConfidence:
|
146 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
|
147 |
+
if m == None :
|
148 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
|
149 |
+
elif withTranscription:
|
150 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line)
|
151 |
+
if m == None :
|
152 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
|
153 |
+
else:
|
154 |
+
m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line)
|
155 |
+
if m == None :
|
156 |
+
raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
|
157 |
+
|
158 |
+
xmin = int(m.group(1))
|
159 |
+
ymin = int(m.group(2))
|
160 |
+
xmax = int(m.group(3))
|
161 |
+
ymax = int(m.group(4))
|
162 |
+
if(xmax<xmin):
|
163 |
+
raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." %(xmax))
|
164 |
+
if(ymax<ymin):
|
165 |
+
raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." %(ymax))
|
166 |
+
|
167 |
+
points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
|
168 |
+
|
169 |
+
if (imWidth>0 and imHeight>0):
|
170 |
+
validate_point_inside_bounds(xmin,ymin,imWidth,imHeight);
|
171 |
+
validate_point_inside_bounds(xmax,ymax,imWidth,imHeight);
|
172 |
+
|
173 |
+
else:
|
174 |
+
line_split = line.split(',')
|
175 |
+
# print(line_split)
|
176 |
+
numPoints = int((len(line_split) - 1) / 2)
|
177 |
+
points = [ float(line_split[i]) for i in range(2 * numPoints) ]
|
178 |
+
# print(points)
|
179 |
+
transcription = line_split[-1]
|
180 |
+
# numPoints = 8;
|
181 |
+
|
182 |
+
# if withTranscription and withConfidence:
|
183 |
+
# m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
|
184 |
+
# if m == None :
|
185 |
+
# raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
|
186 |
+
# elif withConfidence:
|
187 |
+
# m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
|
188 |
+
# if m == None :
|
189 |
+
# raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
|
190 |
+
# elif withTranscription:
|
191 |
+
# m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line)
|
192 |
+
# if m == None :
|
193 |
+
# raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
|
194 |
+
# else:
|
195 |
+
# m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line)
|
196 |
+
# if m == None :
|
197 |
+
# raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
|
198 |
+
|
199 |
+
# points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
|
200 |
+
|
201 |
+
# validate_clockwise_points(points)
|
202 |
+
|
203 |
+
# if (imWidth>0 and imHeight>0):
|
204 |
+
# validate_point_inside_bounds(points[0],points[1],imWidth,imHeight);
|
205 |
+
# validate_point_inside_bounds(points[2],points[3],imWidth,imHeight);
|
206 |
+
# validate_point_inside_bounds(points[4],points[5],imWidth,imHeight);
|
207 |
+
# validate_point_inside_bounds(points[6],points[7],imWidth,imHeight);
|
208 |
+
|
209 |
+
|
210 |
+
# if withConfidence:
|
211 |
+
# try:
|
212 |
+
# confidence = float(m.group(numPoints+1))
|
213 |
+
# except ValueError:
|
214 |
+
# raise Exception("Confidence value must be a float")
|
215 |
+
|
216 |
+
# if withTranscription:
|
217 |
+
# posTranscription = numPoints + (2 if withConfidence else 1)
|
218 |
+
# transcription = m.group(posTranscription)
|
219 |
+
# m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription)
|
220 |
+
# if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters
|
221 |
+
# transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
|
222 |
+
|
223 |
+
return points,confidence,transcription
|
224 |
+
|
225 |
+
|
226 |
+
def validate_point_inside_bounds(x,y,imWidth,imHeight):
|
227 |
+
if(x<0 or x>imWidth):
|
228 |
+
raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight))
|
229 |
+
if(y<0 or y>imHeight):
|
230 |
+
raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight))
|
231 |
+
|
232 |
+
def validate_clockwise_points(points):
|
233 |
+
"""
|
234 |
+
Validates that the points that the 4 points that dlimite a polygon are in clockwise order.
|
235 |
+
"""
|
236 |
+
|
237 |
+
if len(points) != 8:
|
238 |
+
raise Exception("Points list not valid." + str(len(points)))
|
239 |
+
|
240 |
+
point = [
|
241 |
+
[int(points[0]) , int(points[1])],
|
242 |
+
[int(points[2]) , int(points[3])],
|
243 |
+
[int(points[4]) , int(points[5])],
|
244 |
+
[int(points[6]) , int(points[7])]
|
245 |
+
]
|
246 |
+
edge = [
|
247 |
+
( point[1][0] - point[0][0])*( point[1][1] + point[0][1]),
|
248 |
+
( point[2][0] - point[1][0])*( point[2][1] + point[1][1]),
|
249 |
+
( point[3][0] - point[2][0])*( point[3][1] + point[2][1]),
|
250 |
+
( point[0][0] - point[3][0])*( point[0][1] + point[3][1])
|
251 |
+
]
|
252 |
+
|
253 |
+
summatory = edge[0] + edge[1] + edge[2] + edge[3];
|
254 |
+
if summatory>0:
|
255 |
+
raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.")
|
256 |
+
|
257 |
+
def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True):
|
258 |
+
"""
|
259 |
+
Returns all points, confindences and transcriptions of a file in lists. Valid line formats:
|
260 |
+
xmin,ymin,xmax,ymax,[confidence],[transcription]
|
261 |
+
x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
|
262 |
+
"""
|
263 |
+
pointsList = []
|
264 |
+
transcriptionsList = []
|
265 |
+
confidencesList = []
|
266 |
+
|
267 |
+
lines = content.split( "\r\n" if CRLF else "\n" )
|
268 |
+
for line in lines:
|
269 |
+
line = line.replace("\r","").replace("\n","")
|
270 |
+
if(line != "") :
|
271 |
+
points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight);
|
272 |
+
pointsList.append(points)
|
273 |
+
transcriptionsList.append(transcription)
|
274 |
+
confidencesList.append(confidence)
|
275 |
+
|
276 |
+
if withConfidence and len(confidencesList)>0 and sort_by_confidences:
|
277 |
+
confidencesList, pointsList,transcriptionsList = (list(t) for t in zip(*sorted(zip(confidencesList, pointsList, transcriptionsList), reverse=True)))
|
278 |
+
|
279 |
+
return pointsList,confidencesList,transcriptionsList
|
280 |
+
|
281 |
+
def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True):
|
282 |
+
"""
|
283 |
+
This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample.
|
284 |
+
Params:
|
285 |
+
p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used.
|
286 |
+
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
|
287 |
+
validate_data_fn: points to a method that validates the corrct format of the submission
|
288 |
+
evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results
|
289 |
+
"""
|
290 |
+
|
291 |
+
if (p == None):
|
292 |
+
p = dict([s[1:].split('=') for s in sys.argv[1:]])
|
293 |
+
if(len(sys.argv)<2):
|
294 |
+
print_help()
|
295 |
+
|
296 |
+
evalParams = default_evaluation_params_fn()
|
297 |
+
if 'p' in list(p.keys()):
|
298 |
+
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
|
299 |
+
|
300 |
+
resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'}
|
301 |
+
try:
|
302 |
+
validate_data_fn(p['g'], p['s'], evalParams)
|
303 |
+
evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
|
304 |
+
resDict.update(evalData)
|
305 |
+
|
306 |
+
except Exception as e:
|
307 |
+
resDict['Message']= str(e)
|
308 |
+
resDict['calculated']=False
|
309 |
+
|
310 |
+
if not os.path.exists(p['o']):
|
311 |
+
os.makedirs(p['o'])
|
312 |
+
|
313 |
+
resultsOutputname = p['o'] + '/results.zip'
|
314 |
+
outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
|
315 |
+
|
316 |
+
del resDict['per_sample']
|
317 |
+
if 'output_items' in list(resDict.keys()):
|
318 |
+
del resDict['output_items']
|
319 |
+
|
320 |
+
outZip.writestr('method.json',json.dumps(resDict))
|
321 |
+
|
322 |
+
if not resDict['calculated']:
|
323 |
+
if show_result:
|
324 |
+
sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n')
|
325 |
+
outZip.close()
|
326 |
+
return resDict
|
327 |
+
|
328 |
+
if per_sample == True:
|
329 |
+
for k,v in evalData['per_sample'].items():
|
330 |
+
outZip.writestr( k + '.json',json.dumps(v))
|
331 |
+
|
332 |
+
if 'output_items' in list(evalData.keys()):
|
333 |
+
for k, v in evalData['output_items'].items():
|
334 |
+
outZip.writestr( k,v)
|
335 |
+
|
336 |
+
outZip.close()
|
337 |
+
|
338 |
+
if show_result:
|
339 |
+
sys.stdout.write("Calculated!")
|
340 |
+
sys.stdout.write(json.dumps(resDict['method']))
|
341 |
+
|
342 |
+
return resDict
|
343 |
+
|
344 |
+
|
345 |
+
def main_validation(default_evaluation_params_fn,validate_data_fn):
|
346 |
+
"""
|
347 |
+
This process validates a method
|
348 |
+
Params:
|
349 |
+
default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
|
350 |
+
validate_data_fn: points to a method that validates the corrct format of the submission
|
351 |
+
"""
|
352 |
+
try:
|
353 |
+
p = dict([s[1:].split('=') for s in sys.argv[1:]])
|
354 |
+
evalParams = default_evaluation_params_fn()
|
355 |
+
if 'p' in list(p.keys()):
|
356 |
+
evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
|
357 |
+
|
358 |
+
validate_data_fn(p['g'], p['s'], evalParams)
|
359 |
+
print('SUCCESS')
|
360 |
+
sys.exit(0)
|
361 |
+
except Exception as e:
|
362 |
+
print(str(e))
|
363 |
+
sys.exit(101)
|
evaluation/totaltext/e2e/script.py
ADDED
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# encoding=utf8
|
4 |
+
from collections import namedtuple
|
5 |
+
import rrc_evaluation_funcs_total_text as rrc_evaluation_funcs
|
6 |
+
import importlib
|
7 |
+
from prepare_results import prepare_results_for_evaluation
|
8 |
+
|
9 |
+
def evaluation_imports():
|
10 |
+
"""
|
11 |
+
evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation.
|
12 |
+
"""
|
13 |
+
return {
|
14 |
+
'Polygon':'plg',
|
15 |
+
'numpy':'np'
|
16 |
+
}
|
17 |
+
|
18 |
+
def default_evaluation_params():
|
19 |
+
"""
|
20 |
+
default_evaluation_params: Default parameters to use for the validation and evaluation.
|
21 |
+
"""
|
22 |
+
return {
|
23 |
+
'IOU_CONSTRAINT' :0.5,
|
24 |
+
'AREA_PRECISION_CONSTRAINT' :0.5,
|
25 |
+
'WORD_SPOTTING' :False,
|
26 |
+
'MIN_LENGTH_CARE_WORD' :3,
|
27 |
+
'GT_SAMPLE_NAME_2_ID':'gt_img_([0-9]+).txt',
|
28 |
+
'DET_SAMPLE_NAME_2_ID':'res_img_([0-9]+).txt',
|
29 |
+
'LTRB':False, #LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4)
|
30 |
+
'CRLF':False, # Lines are delimited by Windows CRLF format
|
31 |
+
'CONFIDENCES':False, #Detections must include confidence value. MAP and MAR will be calculated,
|
32 |
+
'SPECIAL_CHARACTERS':'!?.:,*"()·[]/\'',
|
33 |
+
'ONLY_REMOVE_FIRST_LAST_CHARACTER' : True
|
34 |
+
}
|
35 |
+
|
36 |
+
def validate_data(gtFilePath, submFilePath, evaluationParams):
|
37 |
+
"""
|
38 |
+
Method validate_data: validates that all files in the results folder are correct (have the correct name contents).
|
39 |
+
Validates also that there are no missing files in the folder.
|
40 |
+
If some error detected, the method raises the error
|
41 |
+
"""
|
42 |
+
gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
|
43 |
+
|
44 |
+
subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
|
45 |
+
|
46 |
+
#Validate format of GroundTruth
|
47 |
+
for k in gt:
|
48 |
+
rrc_evaluation_funcs.validate_lines_in_file(k,gt[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True)
|
49 |
+
|
50 |
+
#Validate format of results
|
51 |
+
for k in subm:
|
52 |
+
if (k in gt) == False :
|
53 |
+
raise Exception("The sample %s not present in GT" %k)
|
54 |
+
|
55 |
+
rrc_evaluation_funcs.validate_lines_in_file(k,subm[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES'])
|
56 |
+
|
57 |
+
|
58 |
+
def evaluate_method(gtFilePath, submFilePath, evaluationParams):
|
59 |
+
"""
|
60 |
+
Method evaluate_method: evaluate method and returns the results
|
61 |
+
Results. Dictionary with the following values:
|
62 |
+
- method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
|
63 |
+
- samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 }
|
64 |
+
"""
|
65 |
+
for module,alias in evaluation_imports().items():
|
66 |
+
globals()[alias] = importlib.import_module(module)
|
67 |
+
|
68 |
+
def polygon_from_points(points,correctOffset=False):
|
69 |
+
"""
|
70 |
+
Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
|
71 |
+
"""
|
72 |
+
resBoxes=np.empty([1,len(points)],dtype='int32')
|
73 |
+
for i in range(int(len(points) / 2)):
|
74 |
+
resBoxes[0, i] = int(points[2*i])
|
75 |
+
resBoxes[0, int(len(points) / 2) + i] = int(points[2*i+1])
|
76 |
+
|
77 |
+
pointMat = resBoxes[0].reshape([2,-1]).T
|
78 |
+
return plg.Polygon( pointMat)
|
79 |
+
|
80 |
+
def rectangle_to_polygon(rect):
|
81 |
+
resBoxes=np.empty([1,8],dtype='int32')
|
82 |
+
resBoxes[0,0]=int(rect.xmin)
|
83 |
+
resBoxes[0,4]=int(rect.ymax)
|
84 |
+
resBoxes[0,1]=int(rect.xmin)
|
85 |
+
resBoxes[0,5]=int(rect.ymin)
|
86 |
+
resBoxes[0,2]=int(rect.xmax)
|
87 |
+
resBoxes[0,6]=int(rect.ymin)
|
88 |
+
resBoxes[0,3]=int(rect.xmax)
|
89 |
+
resBoxes[0,7]=int(rect.ymax)
|
90 |
+
|
91 |
+
pointMat = resBoxes[0].reshape([2,4]).T
|
92 |
+
|
93 |
+
return plg.Polygon( pointMat)
|
94 |
+
|
95 |
+
def rectangle_to_points(rect):
|
96 |
+
points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)]
|
97 |
+
return points
|
98 |
+
|
99 |
+
def get_union(pD,pG):
|
100 |
+
areaA = pD.area();
|
101 |
+
areaB = pG.area();
|
102 |
+
return areaA + areaB - get_intersection(pD, pG);
|
103 |
+
|
104 |
+
def get_intersection_over_union(pD,pG):
|
105 |
+
try:
|
106 |
+
return get_intersection(pD, pG) / get_union(pD, pG);
|
107 |
+
except:
|
108 |
+
return 0
|
109 |
+
|
110 |
+
def get_intersection(pD,pG):
|
111 |
+
pInt = pD & pG
|
112 |
+
if len(pInt) == 0:
|
113 |
+
return 0
|
114 |
+
return pInt.area()
|
115 |
+
|
116 |
+
def compute_ap(confList, matchList,numGtCare):
|
117 |
+
correct = 0
|
118 |
+
AP = 0
|
119 |
+
if len(confList)>0:
|
120 |
+
confList = np.array(confList)
|
121 |
+
matchList = np.array(matchList)
|
122 |
+
sorted_ind = np.argsort(-confList)
|
123 |
+
confList = confList[sorted_ind]
|
124 |
+
matchList = matchList[sorted_ind]
|
125 |
+
for n in range(len(confList)):
|
126 |
+
match = matchList[n]
|
127 |
+
if match:
|
128 |
+
correct += 1
|
129 |
+
AP += float(correct)/(n + 1)
|
130 |
+
|
131 |
+
if numGtCare>0:
|
132 |
+
AP /= numGtCare
|
133 |
+
|
134 |
+
return AP
|
135 |
+
|
136 |
+
def transcription_match(transGt,transDet,specialCharacters='!?.:,*"()·[]/\'',onlyRemoveFirstLastCharacterGT=True):
|
137 |
+
|
138 |
+
if onlyRemoveFirstLastCharacterGT:
|
139 |
+
#special characters in GT are allowed only at initial or final position
|
140 |
+
if (transGt==transDet):
|
141 |
+
return True
|
142 |
+
|
143 |
+
if specialCharacters.find(transGt[0])>-1:
|
144 |
+
if transGt[1:]==transDet:
|
145 |
+
return True
|
146 |
+
|
147 |
+
if specialCharacters.find(transGt[-1])>-1:
|
148 |
+
if transGt[0:len(transGt)-1]==transDet:
|
149 |
+
return True
|
150 |
+
|
151 |
+
if specialCharacters.find(transGt[0])>-1 and specialCharacters.find(transGt[-1])>-1:
|
152 |
+
if transGt[1:len(transGt)-1]==transDet:
|
153 |
+
return True
|
154 |
+
return False
|
155 |
+
else:
|
156 |
+
#Special characters are removed from the begining and the end of both Detection and GroundTruth
|
157 |
+
while len(transGt)>0 and specialCharacters.find(transGt[0])>-1:
|
158 |
+
transGt = transGt[1:]
|
159 |
+
|
160 |
+
while len(transDet)>0 and specialCharacters.find(transDet[0])>-1:
|
161 |
+
transDet = transDet[1:]
|
162 |
+
|
163 |
+
while len(transGt)>0 and specialCharacters.find(transGt[-1])>-1 :
|
164 |
+
transGt = transGt[0:len(transGt)-1]
|
165 |
+
|
166 |
+
while len(transDet)>0 and specialCharacters.find(transDet[-1])>-1:
|
167 |
+
transDet = transDet[0:len(transDet)-1]
|
168 |
+
|
169 |
+
return transGt == transDet
|
170 |
+
|
171 |
+
|
172 |
+
def include_in_dictionary(transcription):
|
173 |
+
"""
|
174 |
+
Function used in Word Spotting that finds if the Ground Truth transcription meets the rules to enter into the dictionary. If not, the transcription will be cared as don't care
|
175 |
+
"""
|
176 |
+
#special case 's at final
|
177 |
+
if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S":
|
178 |
+
transcription = transcription[0:len(transcription)-2]
|
179 |
+
|
180 |
+
#hypens at init or final of the word
|
181 |
+
transcription = transcription.strip('-');
|
182 |
+
|
183 |
+
specialCharacters = "'!?.:,*\"()·[]/";
|
184 |
+
for character in specialCharacters:
|
185 |
+
transcription = transcription.replace(character,' ')
|
186 |
+
|
187 |
+
transcription = transcription.strip()
|
188 |
+
|
189 |
+
if len(transcription) != len(transcription.replace(" ","")) :
|
190 |
+
return False;
|
191 |
+
|
192 |
+
if len(transcription) < evaluationParams['MIN_LENGTH_CARE_WORD']:
|
193 |
+
return False;
|
194 |
+
|
195 |
+
notAllowed = "×÷·";
|
196 |
+
|
197 |
+
range1 = [ ord(u'a'), ord(u'z') ]
|
198 |
+
range2 = [ ord(u'A'), ord(u'Z') ]
|
199 |
+
range3 = [ ord(u'À'), ord(u'ƿ') ]
|
200 |
+
range4 = [ ord(u'DŽ'), ord(u'ɿ') ]
|
201 |
+
range5 = [ ord(u'Ά'), ord(u'Ͽ') ]
|
202 |
+
range6 = [ ord(u'-'), ord(u'-') ]
|
203 |
+
|
204 |
+
for char in transcription :
|
205 |
+
charCode = ord(char)
|
206 |
+
if(notAllowed.find(char) != -1):
|
207 |
+
return False
|
208 |
+
|
209 |
+
valid = ( charCode>=range1[0] and charCode<=range1[1] ) or ( charCode>=range2[0] and charCode<=range2[1] ) or ( charCode>=range3[0] and charCode<=range3[1] ) or ( charCode>=range4[0] and charCode<=range4[1] ) or ( charCode>=range5[0] and charCode<=range5[1] ) or ( charCode>=range6[0] and charCode<=range6[1] )
|
210 |
+
if valid == False:
|
211 |
+
return False
|
212 |
+
|
213 |
+
return True
|
214 |
+
|
215 |
+
def include_in_dictionary_transcription(transcription):
|
216 |
+
"""
|
217 |
+
Function applied to the Ground Truth transcriptions used in Word Spotting. It removes special characters or terminations
|
218 |
+
"""
|
219 |
+
#special case 's at final
|
220 |
+
if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S":
|
221 |
+
transcription = transcription[0:len(transcription)-2]
|
222 |
+
|
223 |
+
#hypens at init or final of the word
|
224 |
+
transcription = transcription.strip('-');
|
225 |
+
|
226 |
+
specialCharacters = "'!?.:,*\"()·[]/";
|
227 |
+
for character in specialCharacters:
|
228 |
+
transcription = transcription.replace(character,' ')
|
229 |
+
|
230 |
+
transcription = transcription.strip()
|
231 |
+
|
232 |
+
return transcription
|
233 |
+
|
234 |
+
perSampleMetrics = {}
|
235 |
+
|
236 |
+
matchedSum = 0
|
237 |
+
|
238 |
+
Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
|
239 |
+
|
240 |
+
gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID'])
|
241 |
+
subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True)
|
242 |
+
|
243 |
+
numGlobalCareGt = 0;
|
244 |
+
numGlobalCareDet = 0;
|
245 |
+
|
246 |
+
arrGlobalConfidences = [];
|
247 |
+
arrGlobalMatches = [];
|
248 |
+
|
249 |
+
for resFile in gt:
|
250 |
+
|
251 |
+
gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile])
|
252 |
+
if (gtFile is None) :
|
253 |
+
raise Exception("The file %s is not UTF-8" %resFile)
|
254 |
+
|
255 |
+
recall = 0
|
256 |
+
precision = 0
|
257 |
+
hmean = 0
|
258 |
+
detCorrect = 0
|
259 |
+
iouMat = np.empty([1,1])
|
260 |
+
gtPols = []
|
261 |
+
detPols = []
|
262 |
+
gtTrans = []
|
263 |
+
detTrans = []
|
264 |
+
gtPolPoints = []
|
265 |
+
detPolPoints = []
|
266 |
+
gtDontCarePolsNum = [] #Array of Ground Truth Polygons' keys marked as don't Care
|
267 |
+
detDontCarePolsNum = [] #Array of Detected Polygons' matched with a don't Care GT
|
268 |
+
detMatchedNums = []
|
269 |
+
pairs = []
|
270 |
+
|
271 |
+
arrSampleConfidences = [];
|
272 |
+
arrSampleMatch = [];
|
273 |
+
sampleAP = 0;
|
274 |
+
|
275 |
+
evaluationLog = ""
|
276 |
+
|
277 |
+
pointsList,_,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,False)
|
278 |
+
for n in range(len(pointsList)):
|
279 |
+
points = pointsList[n]
|
280 |
+
transcription = transcriptionsList[n]
|
281 |
+
dontCare = transcription == "###"
|
282 |
+
if evaluationParams['LTRB']:
|
283 |
+
gtRect = Rectangle(*points)
|
284 |
+
gtPol = rectangle_to_polygon(gtRect)
|
285 |
+
else:
|
286 |
+
gtPol = polygon_from_points(points)
|
287 |
+
gtPols.append(gtPol)
|
288 |
+
gtPolPoints.append(points)
|
289 |
+
|
290 |
+
#On word spotting we will filter some transcriptions with special characters
|
291 |
+
if evaluationParams['WORD_SPOTTING'] :
|
292 |
+
if dontCare == False :
|
293 |
+
if include_in_dictionary(transcription) == False :
|
294 |
+
dontCare = True
|
295 |
+
else:
|
296 |
+
transcription = include_in_dictionary_transcription(transcription)
|
297 |
+
|
298 |
+
gtTrans.append(transcription)
|
299 |
+
if dontCare:
|
300 |
+
gtDontCarePolsNum.append( len(gtPols)-1 )
|
301 |
+
|
302 |
+
evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum)>0 else "\n")
|
303 |
+
|
304 |
+
if resFile in subm:
|
305 |
+
|
306 |
+
detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile])
|
307 |
+
|
308 |
+
pointsList,confidencesList,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES'])
|
309 |
+
|
310 |
+
for n in range(len(pointsList)):
|
311 |
+
points = pointsList[n]
|
312 |
+
transcription = transcriptionsList[n]
|
313 |
+
|
314 |
+
if evaluationParams['LTRB']:
|
315 |
+
detRect = Rectangle(*points)
|
316 |
+
detPol = rectangle_to_polygon(detRect)
|
317 |
+
else:
|
318 |
+
detPol = polygon_from_points(points)
|
319 |
+
detPols.append(detPol)
|
320 |
+
detPolPoints.append(points)
|
321 |
+
detTrans.append(transcription)
|
322 |
+
|
323 |
+
if len(gtDontCarePolsNum)>0 :
|
324 |
+
for dontCarePol in gtDontCarePolsNum:
|
325 |
+
dontCarePol = gtPols[dontCarePol]
|
326 |
+
intersected_area = get_intersection(dontCarePol,detPol)
|
327 |
+
pdDimensions = detPol.area()
|
328 |
+
precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
|
329 |
+
if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ):
|
330 |
+
detDontCarePolsNum.append( len(detPols)-1 )
|
331 |
+
break
|
332 |
+
|
333 |
+
evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum)>0 else "\n")
|
334 |
+
|
335 |
+
if len(gtPols)>0 and len(detPols)>0:
|
336 |
+
#Calculate IoU and precision matrixs
|
337 |
+
outputShape=[len(gtPols),len(detPols)]
|
338 |
+
iouMat = np.empty(outputShape)
|
339 |
+
gtRectMat = np.zeros(len(gtPols),np.int8)
|
340 |
+
detRectMat = np.zeros(len(detPols),np.int8)
|
341 |
+
for gtNum in range(len(gtPols)):
|
342 |
+
for detNum in range(len(detPols)):
|
343 |
+
pG = gtPols[gtNum]
|
344 |
+
pD = detPols[detNum]
|
345 |
+
iouMat[gtNum,detNum] = get_intersection_over_union(pD,pG)
|
346 |
+
|
347 |
+
for gtNum in range(len(gtPols)):
|
348 |
+
for detNum in range(len(detPols)):
|
349 |
+
if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum :
|
350 |
+
if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']:
|
351 |
+
gtRectMat[gtNum] = 1
|
352 |
+
detRectMat[detNum] = 1
|
353 |
+
#detection matched only if transcription is equal
|
354 |
+
if evaluationParams['WORD_SPOTTING']:
|
355 |
+
correct = gtTrans[gtNum].upper() == detTrans[detNum].upper()
|
356 |
+
else:
|
357 |
+
correct = transcription_match(gtTrans[gtNum].upper(),detTrans[detNum].upper(),evaluationParams['SPECIAL_CHARACTERS'],evaluationParams['ONLY_REMOVE_FIRST_LAST_CHARACTER'])==True
|
358 |
+
detCorrect += (1 if correct else 0)
|
359 |
+
if correct:
|
360 |
+
detMatchedNums.append(detNum)
|
361 |
+
pairs.append({'gt':gtNum,'det':detNum,'correct':correct})
|
362 |
+
evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + " trans. correct: " + str(correct) + "\n"
|
363 |
+
|
364 |
+
if evaluationParams['CONFIDENCES']:
|
365 |
+
for detNum in range(len(detPols)):
|
366 |
+
if detNum not in detDontCarePolsNum :
|
367 |
+
#we exclude the don't care detections
|
368 |
+
match = detNum in detMatchedNums
|
369 |
+
|
370 |
+
arrSampleConfidences.append(confidencesList[detNum])
|
371 |
+
arrSampleMatch.append(match)
|
372 |
+
|
373 |
+
arrGlobalConfidences.append(confidencesList[detNum]);
|
374 |
+
arrGlobalMatches.append(match);
|
375 |
+
|
376 |
+
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
|
377 |
+
numDetCare = (len(detPols) - len(detDontCarePolsNum))
|
378 |
+
if numGtCare == 0:
|
379 |
+
recall = float(1)
|
380 |
+
precision = float(0) if numDetCare >0 else float(1)
|
381 |
+
sampleAP = precision
|
382 |
+
else:
|
383 |
+
recall = float(detCorrect) / numGtCare
|
384 |
+
precision = 0 if numDetCare==0 else float(detCorrect) / numDetCare
|
385 |
+
if evaluationParams['CONFIDENCES']:
|
386 |
+
sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare )
|
387 |
+
|
388 |
+
hmean = 0 if (precision + recall)==0 else 2.0 * precision * recall / (precision + recall)
|
389 |
+
|
390 |
+
matchedSum += detCorrect
|
391 |
+
numGlobalCareGt += numGtCare
|
392 |
+
numGlobalCareDet += numDetCare
|
393 |
+
|
394 |
+
perSampleMetrics[resFile] = {
|
395 |
+
'precision':precision,
|
396 |
+
'recall':recall,
|
397 |
+
'hmean':hmean,
|
398 |
+
'pairs':pairs,
|
399 |
+
'AP':sampleAP,
|
400 |
+
'iouMat':[] if len(detPols)>100 else iouMat.tolist(),
|
401 |
+
'gtPolPoints':gtPolPoints,
|
402 |
+
'detPolPoints':detPolPoints,
|
403 |
+
'gtTrans':gtTrans,
|
404 |
+
'detTrans':detTrans,
|
405 |
+
'gtDontCare':gtDontCarePolsNum,
|
406 |
+
'detDontCare':detDontCarePolsNum,
|
407 |
+
'evaluationParams': evaluationParams,
|
408 |
+
'evaluationLog': evaluationLog
|
409 |
+
}
|
410 |
+
|
411 |
+
# Compute AP
|
412 |
+
AP = 0
|
413 |
+
if evaluationParams['CONFIDENCES']:
|
414 |
+
AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
|
415 |
+
|
416 |
+
methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum)/numGlobalCareGt
|
417 |
+
methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum)/numGlobalCareDet
|
418 |
+
methodHmean = 0 if methodRecall + methodPrecision==0 else 2* methodRecall * methodPrecision / (methodRecall + methodPrecision)
|
419 |
+
|
420 |
+
methodMetrics = {'precision':methodPrecision, 'recall':methodRecall,'hmean': methodHmean, 'AP': AP }
|
421 |
+
|
422 |
+
resDict = {'calculated':True,'Message':'','method': methodMetrics,'per_sample': perSampleMetrics}
|
423 |
+
|
424 |
+
|
425 |
+
return resDict;
|
426 |
+
|
427 |
+
|
428 |
+
|
429 |
+
if __name__=='__main__':
|
430 |
+
'''
|
431 |
+
results_dir: result directory
|
432 |
+
score_det: score of detection bounding box
|
433 |
+
score_rec: score of the mask recognition branch
|
434 |
+
score_rec_seq: score of the sequence recognition branch
|
435 |
+
lexicon_type: 1 for generic; 2 for weak; 3 for strong
|
436 |
+
'''
|
437 |
+
results_dir = '../../../output/mixtrain/inference/total_text_test/model_0250000_1000_results/'
|
438 |
+
score_det = 0.05
|
439 |
+
score_rec = 0.5
|
440 |
+
use_lexicon = False
|
441 |
+
score_rec_seq = 0.9
|
442 |
+
# use_lexicon = True
|
443 |
+
# score_rec_seq = 0.8
|
444 |
+
evaluate_result_path = prepare_results_for_evaluation(results_dir,
|
445 |
+
use_lexicon=use_lexicon, cache_dir='./cache_files',
|
446 |
+
score_det=score_det, score_rec=score_rec, score_rec_seq=score_rec_seq)
|
447 |
+
p = {
|
448 |
+
'g': "../gt.zip",
|
449 |
+
'o': "./cache_files",
|
450 |
+
's': evaluate_result_path
|
451 |
+
}
|
452 |
+
rrc_evaluation_funcs.main_evaluation(p,default_evaluation_params,validate_data,evaluate_method)
|
evaluation/totaltext/gt.zip
ADDED
Binary file (106 kB). View file
|
|
evaluation/weighted_editdistance.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def weighted_edit_distance(word1, word2, scores):
|
2 |
+
m = len(word1)
|
3 |
+
n = len(word2)
|
4 |
+
dp = [[0 for __ in range(m + 1)] for __ in range(n + 1)]
|
5 |
+
for j in range(m + 1):
|
6 |
+
dp[0][j] = j
|
7 |
+
for i in range(n + 1):
|
8 |
+
dp[i][0] = i
|
9 |
+
for i in range(1, n + 1): ## word2
|
10 |
+
for j in range(1, m + 1): ## word1
|
11 |
+
delect_cost = ed_delect_cost(j-1, i-1, word1, word2, scores) ## delect a[i]
|
12 |
+
insert_cost = ed_insert_cost(j-1, i-1, word1, word2, scores) ## insert b[j]
|
13 |
+
if word1[j - 1] != word2[i - 1]:
|
14 |
+
replace_cost = ed_replace_cost(j-1, i-1, word1, word2, scores) ## replace a[i] with b[j]
|
15 |
+
else:
|
16 |
+
replace_cost = 0
|
17 |
+
dp[i][j] = min(dp[i-1][j] + insert_cost, dp[i][j-1] + delect_cost, dp[i-1][j-1] + replace_cost)
|
18 |
+
|
19 |
+
return dp[n][m]
|
20 |
+
|
21 |
+
def ed_delect_cost(j, i, word1, word2, scores):
|
22 |
+
## delect a[i]
|
23 |
+
c = char2num(word1[j])
|
24 |
+
return scores[c][j]
|
25 |
+
|
26 |
+
|
27 |
+
def ed_insert_cost(i, j, word1, word2, scores):
|
28 |
+
## insert b[j]
|
29 |
+
if i < len(word1) - 1:
|
30 |
+
c1 = char2num(word1[i])
|
31 |
+
c2 = char2num(word1[i + 1])
|
32 |
+
return (scores[c1][i] + scores[c2][i+1])/2
|
33 |
+
else:
|
34 |
+
c1 = char2num(word1[i])
|
35 |
+
return scores[c1][i]
|
36 |
+
|
37 |
+
|
38 |
+
def ed_replace_cost(i, j, word1, word2, scores):
|
39 |
+
## replace a[i] with b[j]
|
40 |
+
c1 = char2num(word1[i])
|
41 |
+
c2 = char2num(word2[j])
|
42 |
+
# if word1 == "eeatpisaababarait".upper():
|
43 |
+
# print(scores[c2][i]/scores[c1][i])
|
44 |
+
|
45 |
+
return max(1 - scores[c2][i]/scores[c1][i]*5, 0)
|
46 |
+
|
47 |
+
def char2num(char):
|
48 |
+
if char in '0123456789':
|
49 |
+
num = ord(char) - ord('0') + 1
|
50 |
+
elif char in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ':
|
51 |
+
num = ord(char.lower()) - ord('a') + 11
|
52 |
+
else:
|
53 |
+
print('error symbol', char)
|
54 |
+
exit()
|
55 |
+
return num - 1
|
example1.jpg
ADDED
example2.jpg
ADDED
example3.jpg
ADDED
maskrcnn_benchmark/config/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
from .defaults import _C as cfg
|
maskrcnn_benchmark/config/defaults.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
3 |
+
import os
|
4 |
+
|
5 |
+
from yacs.config import CfgNode as CN
|
6 |
+
|
7 |
+
|
8 |
+
# -----------------------------------------------------------------------------
|
9 |
+
# Convention about Training / Test specific parameters
|
10 |
+
# -----------------------------------------------------------------------------
|
11 |
+
# Whenever an argument can be either used for training or for testing, the
|
12 |
+
# corresponding name will be post-fixed by a _TRAIN for a training parameter,
|
13 |
+
# or _TEST for a test-specific parameter.
|
14 |
+
# For example, the number of images during training will be
|
15 |
+
# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
|
16 |
+
# IMAGES_PER_BATCH_TEST
|
17 |
+
|
18 |
+
# -----------------------------------------------------------------------------
|
19 |
+
# Config definition
|
20 |
+
# -----------------------------------------------------------------------------
|
21 |
+
|
22 |
+
_C = CN()
|
23 |
+
|
24 |
+
_C.MODEL = CN()
|
25 |
+
_C.MODEL.RPN_ONLY = False
|
26 |
+
_C.MODEL.MASK_ON = False
|
27 |
+
_C.MODEL.SEG_ON = False
|
28 |
+
_C.MODEL.CHAR_MASK_ON = False
|
29 |
+
_C.MODEL.DEVICE = "cuda"
|
30 |
+
_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"
|
31 |
+
_C.MODEL.TRAIN_DETECTION_ONLY = False
|
32 |
+
_C.MODEL.RESNET34 = False
|
33 |
+
|
34 |
+
# If the WEIGHT starts with a catalog://, like :R-50, the code will look for
|
35 |
+
# the path in paths_catalog. Else, it will use it as the specified absolute
|
36 |
+
# path
|
37 |
+
_C.MODEL.WEIGHT = ""
|
38 |
+
|
39 |
+
_C.SEQUENCE = CN()
|
40 |
+
_C.SEQUENCE.SEQ_ON = False
|
41 |
+
_C.SEQUENCE.NUM_CHAR = 38
|
42 |
+
_C.SEQUENCE.BOS_TOKEN = 0
|
43 |
+
_C.SEQUENCE.MAX_LENGTH = 32
|
44 |
+
_C.SEQUENCE.TEACHER_FORCE_RATIO = 1.0
|
45 |
+
_C.SEQUENCE.TWO_CONV = False
|
46 |
+
_C.SEQUENCE.MEAN_SCORE = False
|
47 |
+
_C.SEQUENCE.RESIZE_HEIGHT = 16
|
48 |
+
_C.SEQUENCE.RESIZE_WIDTH = 64
|
49 |
+
|
50 |
+
|
51 |
+
# -----------------------------------------------------------------------------
|
52 |
+
# INPUT
|
53 |
+
# -----------------------------------------------------------------------------
|
54 |
+
_C.INPUT = CN()
|
55 |
+
# Size of the smallest side of the image during training
|
56 |
+
_C.INPUT.MIN_SIZE_TRAIN = (800,) # (800,)
|
57 |
+
# Maximum size of the side of the image during training
|
58 |
+
_C.INPUT.MAX_SIZE_TRAIN = 1333
|
59 |
+
# Size of the smallest side of the image during testing
|
60 |
+
_C.INPUT.MIN_SIZE_TEST = 800
|
61 |
+
# Maximum size of the side of the image during testing
|
62 |
+
_C.INPUT.MAX_SIZE_TEST = 1333
|
63 |
+
# Values to be used for image normalization
|
64 |
+
_C.INPUT.PIXEL_MEAN = [102.9801, 115.9465, 122.7717]
|
65 |
+
# Values to be used for image normalization
|
66 |
+
_C.INPUT.PIXEL_STD = [1.0, 1.0, 1.0]
|
67 |
+
# Convert image to BGR format (for Caffe2 models), in range 0-255
|
68 |
+
_C.INPUT.TO_BGR255 = True
|
69 |
+
_C.INPUT.STRICT_RESIZE = False
|
70 |
+
|
71 |
+
|
72 |
+
# -----------------------------------------------------------------------------
|
73 |
+
# Dataset
|
74 |
+
# -----------------------------------------------------------------------------
|
75 |
+
_C.DATASETS = CN()
|
76 |
+
# List of the dataset names for training, as present in paths_catalog.py
|
77 |
+
_C.DATASETS.TRAIN = ()
|
78 |
+
# List of the dataset names for testing, as present in paths_catalog.py
|
79 |
+
_C.DATASETS.TEST = ()
|
80 |
+
|
81 |
+
_C.DATASETS.RATIOS = []
|
82 |
+
|
83 |
+
_C.DATASETS.AUG = False
|
84 |
+
_C.DATASETS.RANDOM_CROP_PROB = 0.0
|
85 |
+
_C.DATASETS.IGNORE_DIFFICULT = False
|
86 |
+
_C.DATASETS.FIX_CROP = False
|
87 |
+
_C.DATASETS.CROP_SIZE = (512, 512)
|
88 |
+
_C.DATASETS.MAX_ROTATE_THETA = 30
|
89 |
+
_C.DATASETS.FIX_ROTATE = False
|
90 |
+
|
91 |
+
# -----------------------------------------------------------------------------
|
92 |
+
# DataLoader
|
93 |
+
# -----------------------------------------------------------------------------
|
94 |
+
_C.DATALOADER = CN()
|
95 |
+
# Number of data loading threads
|
96 |
+
_C.DATALOADER.NUM_WORKERS = 4
|
97 |
+
# If > 0, this enforces that each collated batch should have a size divisible
|
98 |
+
# by SIZE_DIVISIBILITY
|
99 |
+
_C.DATALOADER.SIZE_DIVISIBILITY = 0
|
100 |
+
# If True, each batch should contain only images for which the aspect ratio
|
101 |
+
# is compatible. This groups portrait images together, and landscape images
|
102 |
+
# are not batched with portrait images.
|
103 |
+
_C.DATALOADER.ASPECT_RATIO_GROUPING = True
|
104 |
+
|
105 |
+
# ---------------------------------------------------------------------------- #
|
106 |
+
# Backbone options
|
107 |
+
# ---------------------------------------------------------------------------- #
|
108 |
+
_C.MODEL.BACKBONE = CN()
|
109 |
+
|
110 |
+
# The backbone conv body to use
|
111 |
+
# The string must match a function that is imported in modeling.model_builder
|
112 |
+
# (e.g., 'FPN.add_fpn_ResNet101_conv5_body' to specify a ResNet-101-FPN
|
113 |
+
# backbone)
|
114 |
+
_C.MODEL.BACKBONE.CONV_BODY = "R-50-C4"
|
115 |
+
|
116 |
+
# Add StopGrad at a specified stage so the bottom layers are frozen
|
117 |
+
_C.MODEL.BACKBONE.FREEZE_CONV_BODY_AT = 2
|
118 |
+
_C.MODEL.BACKBONE.OUT_CHANNELS = 256 * 4
|
119 |
+
|
120 |
+
# ---------------------------------------------------------------------------- #
|
121 |
+
# ResNe[X]t options (ResNets = {ResNet, ResNeXt}
|
122 |
+
# Note that parts of a resnet may be used for both the backbone and the head
|
123 |
+
# These options apply to both
|
124 |
+
# ---------------------------------------------------------------------------- #
|
125 |
+
_C.MODEL.RESNETS = CN()
|
126 |
+
|
127 |
+
# Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt
|
128 |
+
_C.MODEL.RESNETS.NUM_GROUPS = 1
|
129 |
+
|
130 |
+
# Baseline width of each group
|
131 |
+
_C.MODEL.RESNETS.WIDTH_PER_GROUP = 64
|
132 |
+
|
133 |
+
# Place the stride 2 conv on the 1x1 filter
|
134 |
+
# Use True only for the original MSRA ResNet; use False for C2 and Torch models
|
135 |
+
_C.MODEL.RESNETS.STRIDE_IN_1X1 = True
|
136 |
+
|
137 |
+
# Residual transformation function
|
138 |
+
_C.MODEL.RESNETS.TRANS_FUNC = "BottleneckWithFixedBatchNorm"
|
139 |
+
# ResNet's stem function (conv1 and pool1)
|
140 |
+
_C.MODEL.RESNETS.STEM_FUNC = "StemWithFixedBatchNorm"
|
141 |
+
|
142 |
+
# Apply dilation in stage "res5"
|
143 |
+
_C.MODEL.RESNETS.RES5_DILATION = 1
|
144 |
+
|
145 |
+
_C.MODEL.RESNETS.BACKBONE_OUT_CHANNELS = 256 * 4
|
146 |
+
_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
|
147 |
+
_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
|
148 |
+
|
149 |
+
_C.MODEL.RESNETS.STAGE_WITH_DCN = (False, False, False, False)
|
150 |
+
_C.MODEL.RESNETS.WITH_MODULATED_DCN = False
|
151 |
+
_C.MODEL.RESNETS.DEFORMABLE_GROUPS = 1
|
152 |
+
_C.MODEL.RESNETS.LAYERS = (3, 4, 6, 3)
|
153 |
+
|
154 |
+
# ---------------------------------------------------------------------------- #
|
155 |
+
# FPN options
|
156 |
+
# ---------------------------------------------------------------------------- #
|
157 |
+
_C.MODEL.FPN = CN()
|
158 |
+
_C.MODEL.FPN.USE_GN = False
|
159 |
+
_C.MODEL.FPN.USE_RELU = False
|
160 |
+
|
161 |
+
# ---------------------------------------------------------------------------- #
|
162 |
+
# RPN options
|
163 |
+
# ---------------------------------------------------------------------------- #
|
164 |
+
_C.MODEL.RPN = CN()
|
165 |
+
_C.MODEL.RPN.USE_FPN = False
|
166 |
+
# Base RPN anchor sizes given in absolute pixels w.r.t. the scaled network input
|
167 |
+
_C.MODEL.RPN.ANCHOR_SIZES = (32, 64, 128, 256, 512)
|
168 |
+
# Stride of the feature map that RPN is attached.
|
169 |
+
# For FPN, number of strides should match number of scales
|
170 |
+
_C.MODEL.RPN.ANCHOR_STRIDE = (16,)
|
171 |
+
# RPN anchor aspect ratios
|
172 |
+
_C.MODEL.RPN.ASPECT_RATIOS = (0.5, 1.0, 2.0)
|
173 |
+
# Remove RPN anchors that go outside the image by RPN_STRADDLE_THRESH pixels
|
174 |
+
# Set to -1 or a large value, e.g. 100000, to disable pruning anchors
|
175 |
+
_C.MODEL.RPN.STRADDLE_THRESH = 0
|
176 |
+
# Minimum overlap required between an anchor and ground-truth box for the
|
177 |
+
# (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD
|
178 |
+
# ==> positive RPN example)
|
179 |
+
_C.MODEL.RPN.FG_IOU_THRESHOLD = 0.7
|
180 |
+
# Maximum overlap allowed between an anchor and ground-truth box for the
|
181 |
+
# (anchor, gt box) pair to be a negative examples (IoU < BG_IOU_THRESHOLD
|
182 |
+
# ==> negative RPN example)
|
183 |
+
_C.MODEL.RPN.BG_IOU_THRESHOLD = 0.3
|
184 |
+
# Total number of RPN examples per image
|
185 |
+
_C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256
|
186 |
+
# Target fraction of foreground (positive) examples per RPN minibatch
|
187 |
+
_C.MODEL.RPN.POSITIVE_FRACTION = 0.5
|
188 |
+
# Number of top scoring RPN proposals to keep before applying NMS
|
189 |
+
# When FPN is used, this is *per FPN level* (not total)
|
190 |
+
_C.MODEL.RPN.PRE_NMS_TOP_N_TRAIN = 12000
|
191 |
+
_C.MODEL.RPN.PRE_NMS_TOP_N_TEST = 6000
|
192 |
+
# Number of top scoring RPN proposals to keep after applying NMS
|
193 |
+
_C.MODEL.RPN.POST_NMS_TOP_N_TRAIN = 2000
|
194 |
+
_C.MODEL.RPN.POST_NMS_TOP_N_TEST = 1000
|
195 |
+
# NMS threshold used on RPN proposals
|
196 |
+
_C.MODEL.RPN.NMS_THRESH = 0.7
|
197 |
+
# Proposal height and width both need to be greater than RPN_MIN_SIZE
|
198 |
+
# (a the scale used during training or inference)
|
199 |
+
_C.MODEL.RPN.MIN_SIZE = 0
|
200 |
+
# Number of top scoring RPN proposals to keep after combining proposals from
|
201 |
+
# all FPN levels
|
202 |
+
_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN = 2000
|
203 |
+
_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 2000
|
204 |
+
|
205 |
+
_C.MODEL.SEG = CN()
|
206 |
+
_C.MODEL.SEG.USE_FPN = False
|
207 |
+
_C.MODEL.SEG.USE_FUSE_FEATURE = False
|
208 |
+
# Total number of SEG examples per image
|
209 |
+
_C.MODEL.SEG.BATCH_SIZE_PER_IMAGE = 256
|
210 |
+
# Target fraction of foreground (positive) examples per SEG minibatch
|
211 |
+
_C.MODEL.SEG.POSITIVE_FRACTION = 0.5
|
212 |
+
# NMS threshold used on SEG proposals
|
213 |
+
_C.MODEL.SEG.BINARY_THRESH = 0.5
|
214 |
+
_C.MODEL.SEG.USE_MULTIPLE_THRESH = False
|
215 |
+
_C.MODEL.SEG.MULTIPLE_THRESH = (0.2, 0.3, 0.5, 0.7)
|
216 |
+
_C.MODEL.SEG.BOX_THRESH = 0.7
|
217 |
+
# Proposal height and width both need to be greater than RPN_MIN_SIZE
|
218 |
+
# (a the scale used during training or inference)
|
219 |
+
_C.MODEL.SEG.MIN_SIZE = 0
|
220 |
+
_C.MODEL.SEG.SHRINK_RATIO = 0.5
|
221 |
+
# Number of top scoring RPN proposals to keep after combining proposals from
|
222 |
+
# all FPN levels
|
223 |
+
_C.MODEL.SEG.TOP_N_TRAIN = 1000
|
224 |
+
_C.MODEL.SEG.TOP_N_TEST = 1000
|
225 |
+
_C.MODEL.SEG.AUG_PROPOSALS = False
|
226 |
+
_C.MODEL.SEG.IGNORE_DIFFICULT = True
|
227 |
+
_C.MODEL.SEG.EXPAND_RATIO = 1.6
|
228 |
+
_C.MODEL.SEG.BOX_EXPAND_RATIO = 1.5
|
229 |
+
_C.MODEL.SEG.USE_SEG_POLY = False
|
230 |
+
_C.MODEL.SEG.USE_PPM = False
|
231 |
+
|
232 |
+
|
233 |
+
# ---------------------------------------------------------------------------- #
|
234 |
+
# ROI HEADS options
|
235 |
+
# ---------------------------------------------------------------------------- #
|
236 |
+
_C.MODEL.ROI_HEADS = CN()
|
237 |
+
_C.MODEL.ROI_HEADS.USE_FPN = False
|
238 |
+
# Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD)
|
239 |
+
_C.MODEL.ROI_HEADS.FG_IOU_THRESHOLD = 0.5
|
240 |
+
# Overlap threshold for an RoI to be considered background
|
241 |
+
# (class = 0 if overlap in [0, BG_IOU_THRESHOLD))
|
242 |
+
_C.MODEL.ROI_HEADS.BG_IOU_THRESHOLD = 0.5
|
243 |
+
# Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets
|
244 |
+
# These are empirically chosen to approximately lead to unit variance targets
|
245 |
+
_C.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0)
|
246 |
+
# RoI minibatch size *per image* (number of regions of interest [ROIs])
|
247 |
+
# Total number of RoIs per training minibatch =
|
248 |
+
# TRAIN.BATCH_SIZE_PER_IM * TRAIN.IMS_PER_BATCH * NUM_GPUS
|
249 |
+
# E.g., a common configuration is: 512 * 2 * 8 = 8192
|
250 |
+
_C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
|
251 |
+
# Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0)
|
252 |
+
_C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25
|
253 |
+
|
254 |
+
# Only used on test mode
|
255 |
+
|
256 |
+
# Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to
|
257 |
+
# balance obtaining high recall with not having too many low precision
|
258 |
+
# detections that will slow down inference post processing steps (like NMS)
|
259 |
+
# _C.MODEL.ROI_HEADS.SCORE_THRESH = 0.05
|
260 |
+
_C.MODEL.ROI_HEADS.SCORE_THRESH = 0.0
|
261 |
+
# Overlap threshold used for non-maximum suppression (suppress boxes with
|
262 |
+
# IoU >= this threshold)
|
263 |
+
_C.MODEL.ROI_HEADS.NMS = 0.5
|
264 |
+
# Maximum number of detections to return per image (100 is based on the limit
|
265 |
+
# established for the COCO dataset)
|
266 |
+
_C.MODEL.ROI_HEADS.DETECTIONS_PER_IMG = 100
|
267 |
+
|
268 |
+
|
269 |
+
_C.MODEL.ROI_BOX_HEAD = CN()
|
270 |
+
_C.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR = "ResNet50Conv5ROIFeatureExtractor"
|
271 |
+
_C.MODEL.ROI_BOX_HEAD.PREDICTOR = "FastRCNNPredictor"
|
272 |
+
_C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14
|
273 |
+
_C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0
|
274 |
+
_C.MODEL.ROI_BOX_HEAD.POOLER_SCALES = (1.0 / 16,)
|
275 |
+
_C.MODEL.ROI_BOX_HEAD.NUM_CLASSES = 81
|
276 |
+
# Hidden layer dimension when using an MLP for the RoI box head
|
277 |
+
_C.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM = 1024
|
278 |
+
_C.MODEL.ROI_BOX_HEAD.USE_REGRESSION = True
|
279 |
+
_C.MODEL.ROI_BOX_HEAD.INFERENCE_USE_BOX = True
|
280 |
+
_C.MODEL.ROI_BOX_HEAD.USE_MASKED_FEATURE = False
|
281 |
+
_C.MODEL.ROI_BOX_HEAD.SOFT_MASKED_FEATURE_RATIO = 0.
|
282 |
+
_C.MODEL.ROI_BOX_HEAD.MIX_OPTION = ""
|
283 |
+
|
284 |
+
|
285 |
+
_C.MODEL.ROI_MASK_HEAD = CN()
|
286 |
+
_C.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR = "ResNet50Conv5ROIFeatureExtractor"
|
287 |
+
_C.MODEL.ROI_MASK_HEAD.PREDICTOR = "MaskRCNNC4Predictor"
|
288 |
+
_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14
|
289 |
+
_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_H = 32
|
290 |
+
_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_W = 128
|
291 |
+
_C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0
|
292 |
+
_C.MODEL.ROI_MASK_HEAD.POOLER_SCALES = (1.0 / 16,)
|
293 |
+
_C.MODEL.ROI_MASK_HEAD.MLP_HEAD_DIM = 1024
|
294 |
+
_C.MODEL.ROI_MASK_HEAD.CONV_LAYERS = (256, 256, 256, 256)
|
295 |
+
_C.MODEL.ROI_MASK_HEAD.RESOLUTION = 14
|
296 |
+
_C.MODEL.ROI_MASK_HEAD.RESOLUTION_H = 32
|
297 |
+
_C.MODEL.ROI_MASK_HEAD.RESOLUTION_W = 128
|
298 |
+
_C.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True
|
299 |
+
_C.MODEL.ROI_MASK_HEAD.CHAR_NUM_CLASSES = 38
|
300 |
+
_C.MODEL.ROI_MASK_HEAD.USE_WEIGHTED_CHAR_MASK = False
|
301 |
+
_C.MODEL.ROI_MASK_HEAD.MASK_BATCH_SIZE_PER_IM = 64
|
302 |
+
_C.MODEL.ROI_MASK_HEAD.USE_MASKED_FEATURE = False
|
303 |
+
_C.MODEL.ROI_MASK_HEAD.SOFT_MASKED_FEATURE_RATIO = 0.
|
304 |
+
_C.MODEL.ROI_MASK_HEAD.MIX_OPTION = ""
|
305 |
+
|
306 |
+
# ---------------------------------------------------------------------------- #
|
307 |
+
# Solver
|
308 |
+
# ---------------------------------------------------------------------------- #
|
309 |
+
_C.SOLVER = CN()
|
310 |
+
_C.SOLVER.MAX_ITER = 40000
|
311 |
+
|
312 |
+
_C.SOLVER.BASE_LR = 0.001
|
313 |
+
_C.SOLVER.BIAS_LR_FACTOR = 2
|
314 |
+
|
315 |
+
_C.SOLVER.MOMENTUM = 0.9
|
316 |
+
|
317 |
+
_C.SOLVER.WEIGHT_DECAY = 0.0005
|
318 |
+
_C.SOLVER.WEIGHT_DECAY_BIAS = 0
|
319 |
+
|
320 |
+
_C.SOLVER.GAMMA = 0.1
|
321 |
+
_C.SOLVER.STEPS = (30000,)
|
322 |
+
|
323 |
+
_C.SOLVER.WARMUP_FACTOR = 1.0 / 3
|
324 |
+
_C.SOLVER.WARMUP_ITERS = 500
|
325 |
+
_C.SOLVER.WARMUP_METHOD = "linear"
|
326 |
+
|
327 |
+
_C.SOLVER.CHECKPOINT_PERIOD = 5000
|
328 |
+
|
329 |
+
# Number of images per batch
|
330 |
+
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
|
331 |
+
# see 2 images per batch
|
332 |
+
_C.SOLVER.IMS_PER_BATCH = 16
|
333 |
+
|
334 |
+
_C.SOLVER.RESUME = True
|
335 |
+
|
336 |
+
_C.SOLVER.USE_ADAM = False
|
337 |
+
|
338 |
+
_C.SOLVER.POW_SCHEDULE = False
|
339 |
+
|
340 |
+
_C.SOLVER.DISPLAY_FREQ = 20
|
341 |
+
|
342 |
+
# ---------------------------------------------------------------------------- #
|
343 |
+
# Specific test options
|
344 |
+
# ---------------------------------------------------------------------------- #
|
345 |
+
_C.TEST = CN()
|
346 |
+
_C.TEST.EXPECTED_RESULTS = []
|
347 |
+
_C.TEST.EXPECTED_RESULTS_SIGMA_TOL = 4
|
348 |
+
# Number of images per batch
|
349 |
+
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
|
350 |
+
# see 2 images per batch
|
351 |
+
_C.TEST.IMS_PER_BATCH = 8
|
352 |
+
_C.TEST.VIS = False
|
353 |
+
# from 0 to 255
|
354 |
+
_C.TEST.CHAR_THRESH = 128
|
355 |
+
|
356 |
+
|
357 |
+
# ---------------------------------------------------------------------------- #
|
358 |
+
# Misc options
|
359 |
+
# ---------------------------------------------------------------------------- #
|
360 |
+
_C.OUTPUT_DIR = "."
|
361 |
+
|
362 |
+
_C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py")
|
363 |
+
|
364 |
+
|
365 |
+
# ---------------------------------------------------------------------------- #
|
366 |
+
# Precision options
|
367 |
+
# ---------------------------------------------------------------------------- #
|
368 |
+
|
369 |
+
# Precision of input, allowable: (float32, float16)
|
370 |
+
_C.DTYPE = "float32"
|
371 |
+
|
372 |
+
# Enable verbosity in apex.amp
|
373 |
+
_C.AMP_VERBOSE = False
|
maskrcnn_benchmark/config/paths_catalog.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
"""Centralized catalog of paths."""
|
3 |
+
|
4 |
+
import os
|
5 |
+
|
6 |
+
|
7 |
+
class DatasetCatalog(object):
|
8 |
+
DATA_DIR = "datasets"
|
9 |
+
# DATA_DIR = "/share/mhliao/MaskTextSpotterV3/datasets/"
|
10 |
+
|
11 |
+
DATASETS = {
|
12 |
+
"coco_2014_train": (
|
13 |
+
"coco/train2014",
|
14 |
+
"coco/annotations/instances_train2014.json",
|
15 |
+
),
|
16 |
+
"coco_2014_val": ("coco/val2014", "coco/annotations/instances_val2014.json"),
|
17 |
+
"coco_2014_minival": (
|
18 |
+
"coco/val2014",
|
19 |
+
"coco/annotations/instances_minival2014.json",
|
20 |
+
),
|
21 |
+
"coco_2014_valminusminival": (
|
22 |
+
"coco/val2014",
|
23 |
+
"coco/annotations/instances_valminusminival2014.json",
|
24 |
+
),
|
25 |
+
"icdar_2013_train": ("icdar2013/train_images", "icdar2013/train_gts"),
|
26 |
+
"icdar_2013_test": ("icdar2013/test_images", "icdar2013/test_gts"),
|
27 |
+
"rotated_ic13_test_0": ("icdar2013/rotated_test_images_0", "icdar2013/rotated_test_gts_0"),
|
28 |
+
"rotated_ic13_test_15": ("icdar2013/rotated_test_images_15", "icdar2013/rotated_test_gts_15"),
|
29 |
+
"rotated_ic13_test_30": ("icdar2013/rotated_test_images_30", "icdar2013/rotated_test_gts_30"),
|
30 |
+
"rotated_ic13_test_45": ("icdar2013/rotated_test_images_45", "icdar2013/rotated_test_gts_45"),
|
31 |
+
"rotated_ic13_test_60": ("icdar2013/rotated_test_images_60", "icdar2013/rotated_test_gts_60"),
|
32 |
+
"rotated_ic13_test_75": ("icdar2013/rotated_test_images_75", "icdar2013/rotated_test_gts_75"),
|
33 |
+
"rotated_ic13_test_85": ("icdar2013/rotated_test_images_85", "icdar2013/rotated_test_gts_85"),
|
34 |
+
"rotated_ic13_test_90": ("icdar2013/rotated_test_images_90", "icdar2013/rotated_test_gts_90"),
|
35 |
+
"rotated_ic13_test_-15": ("icdar2013/rotated_test_images_-15", "icdar2013/rotated_test_gts_-15"),
|
36 |
+
"rotated_ic13_test_-30": ("icdar2013/rotated_test_images_-30", "icdar2013/rotated_test_gts_-30"),
|
37 |
+
"rotated_ic13_test_-45": ("icdar2013/rotated_test_images_-45", "icdar2013/rotated_test_gts_-45"),
|
38 |
+
"rotated_ic13_test_-60": ("icdar2013/rotated_test_images_-60", "icdar2013/rotated_test_gts_-60"),
|
39 |
+
"rotated_ic13_test_-75": ("icdar2013/rotated_test_images_-75", "icdar2013/rotated_test_gts_-75"),
|
40 |
+
"rotated_ic13_test_-90": ("icdar2013/rotated_test_images_-90", "icdar2013/rotated_test_gts_-90"),
|
41 |
+
"icdar_2015_train": ("icdar2015/train_images", "icdar2015/train_gts"),
|
42 |
+
"icdar_2015_test": (
|
43 |
+
"icdar2015/test_images",
|
44 |
+
# "icdar2015/test_gts",
|
45 |
+
),
|
46 |
+
"synthtext_train": ("synthtext/train_images", "synthtext/train_gts"),
|
47 |
+
"synthtext_test": ("synthtext/test_images", "synthtext/test_gts"),
|
48 |
+
"total_text_train": ("total_text/train_images", "total_text/train_gts"),
|
49 |
+
"td500_train": ("TD_TR/TD500/train_images", "TD500/train_gts"),
|
50 |
+
"td500_test": ("TD_TR/TD500/test_images", ),
|
51 |
+
"tr400_train": ("TD_TR/TR400/train_images", "TR400/train_gts"),
|
52 |
+
"total_text_test": (
|
53 |
+
"total_text/test_images",
|
54 |
+
# "total_text/test_gts",
|
55 |
+
),
|
56 |
+
"scut-eng-char_train": (
|
57 |
+
"scut-eng-char/train_images",
|
58 |
+
"scut-eng-char/train_gts",
|
59 |
+
),
|
60 |
+
}
|
61 |
+
|
62 |
+
@staticmethod
|
63 |
+
def get(name):
|
64 |
+
if "coco" in name:
|
65 |
+
data_dir = DatasetCatalog.DATA_DIR
|
66 |
+
attrs = DatasetCatalog.DATASETS[name]
|
67 |
+
args = dict(
|
68 |
+
root=os.path.join(data_dir, attrs[0]),
|
69 |
+
ann_file=os.path.join(data_dir, attrs[1]),
|
70 |
+
)
|
71 |
+
return dict(factory="COCODataset", args=args)
|
72 |
+
elif "icdar_2013" in name:
|
73 |
+
data_dir = DatasetCatalog.DATA_DIR
|
74 |
+
attrs = DatasetCatalog.DATASETS[name]
|
75 |
+
args = dict(
|
76 |
+
use_charann=True,
|
77 |
+
imgs_dir=os.path.join(data_dir, attrs[0]),
|
78 |
+
gts_dir=os.path.join(data_dir, attrs[1]),
|
79 |
+
# imgs_dir='/tmp/icdar2013/icdar2013/train_images',
|
80 |
+
# gts_dir='/tmp/icdar2013/icdar2013/train_gts',
|
81 |
+
)
|
82 |
+
return dict(args=args, factory="IcdarDataset")
|
83 |
+
elif "rotated_ic13" in name:
|
84 |
+
data_dir = DatasetCatalog.DATA_DIR
|
85 |
+
attrs = DatasetCatalog.DATASETS[name]
|
86 |
+
args = dict(
|
87 |
+
use_charann=True,
|
88 |
+
imgs_dir=os.path.join(data_dir, attrs[0]),
|
89 |
+
gts_dir=os.path.join(data_dir, attrs[1]),
|
90 |
+
)
|
91 |
+
return dict(args=args, factory="IcdarDataset")
|
92 |
+
elif "icdar_2015" in name:
|
93 |
+
data_dir = DatasetCatalog.DATA_DIR
|
94 |
+
attrs = DatasetCatalog.DATASETS[name]
|
95 |
+
if len(attrs) > 1:
|
96 |
+
gts_dir = os.path.join(data_dir, attrs[1])
|
97 |
+
else:
|
98 |
+
gts_dir = None
|
99 |
+
|
100 |
+
args = dict(
|
101 |
+
use_charann=False,
|
102 |
+
imgs_dir=os.path.join(data_dir, attrs[0]),
|
103 |
+
gts_dir=gts_dir,
|
104 |
+
# imgs_dir='/tmp/icdar2015/icdar2015/train_images/',
|
105 |
+
# gts_dir='/tmp/icdar2015/icdar2015/train_gts/',
|
106 |
+
)
|
107 |
+
return dict(args=args, factory="IcdarDataset")
|
108 |
+
elif "synthtext" in name:
|
109 |
+
data_dir = DatasetCatalog.DATA_DIR
|
110 |
+
attrs = DatasetCatalog.DATASETS[name]
|
111 |
+
args = dict(
|
112 |
+
use_charann=True,
|
113 |
+
list_file_path=os.path.join(data_dir, "synthtext/train_list.txt"),
|
114 |
+
imgs_dir=os.path.join(data_dir, attrs[0]),
|
115 |
+
gts_dir=os.path.join(data_dir, attrs[1]),
|
116 |
+
# imgs_dir='/tmp/synth/SynthText/',
|
117 |
+
# gts_dir='/tmp/synth_gt/SynthText_GT_E2E/',
|
118 |
+
)
|
119 |
+
return dict(args=args, factory="SynthtextDataset")
|
120 |
+
elif "total_text" in name:
|
121 |
+
data_dir = DatasetCatalog.DATA_DIR
|
122 |
+
# data_dir = '/tmp/total_text/'
|
123 |
+
attrs = DatasetCatalog.DATASETS[name]
|
124 |
+
if len(attrs) > 1:
|
125 |
+
gts_dir = os.path.join(data_dir, attrs[1])
|
126 |
+
else:
|
127 |
+
gts_dir = None
|
128 |
+
args = dict(
|
129 |
+
use_charann=False,
|
130 |
+
imgs_dir=os.path.join(data_dir, attrs[0]),
|
131 |
+
gts_dir=gts_dir,
|
132 |
+
# imgs_dir='/tmp/total_text/total_text/train_images/',
|
133 |
+
# gts_dir='/tmp/total_text/total_text/train_gts/',
|
134 |
+
)
|
135 |
+
return dict(args=args, factory="TotaltextDataset")
|
136 |
+
elif "scut-eng-char" in name:
|
137 |
+
data_dir = DatasetCatalog.DATA_DIR
|
138 |
+
attrs = DatasetCatalog.DATASETS[name]
|
139 |
+
args = dict(
|
140 |
+
use_charann=True,
|
141 |
+
imgs_dir=os.path.join(data_dir, attrs[0]),
|
142 |
+
gts_dir=os.path.join(data_dir, attrs[1]),
|
143 |
+
# imgs_dir='/tmp/scut-eng-char/scut-eng-char/train_images/',
|
144 |
+
# gts_dir='/tmp/scut-eng-char/scut-eng-char/train_gts/',
|
145 |
+
)
|
146 |
+
return dict(args=args, factory="ScutDataset")
|
147 |
+
elif "td500" in name:
|
148 |
+
data_dir = DatasetCatalog.DATA_DIR
|
149 |
+
attrs = DatasetCatalog.DATASETS[name]
|
150 |
+
if len(attrs) > 1:
|
151 |
+
gts_dir = os.path.join(data_dir, attrs[1])
|
152 |
+
else:
|
153 |
+
gts_dir = None
|
154 |
+
args = dict(
|
155 |
+
use_charann=False,
|
156 |
+
imgs_dir=os.path.join(data_dir, attrs[0]),
|
157 |
+
gts_dir=gts_dir,
|
158 |
+
)
|
159 |
+
return dict(args=args, factory="TotaltextDataset")
|
160 |
+
elif "tr400" in name:
|
161 |
+
data_dir = DatasetCatalog.DATA_DIR
|
162 |
+
attrs = DatasetCatalog.DATASETS[name]
|
163 |
+
if len(attrs) > 1:
|
164 |
+
gts_dir = os.path.join(data_dir, attrs[1])
|
165 |
+
else:
|
166 |
+
gts_dir = None
|
167 |
+
args = dict(
|
168 |
+
use_charann=False,
|
169 |
+
imgs_dir=os.path.join(data_dir, attrs[0]),
|
170 |
+
gts_dir=gts_dir,
|
171 |
+
)
|
172 |
+
return dict(args=args, factory="TotaltextDataset")
|
173 |
+
raise RuntimeError("Dataset not available: {}".format(name))
|
174 |
+
|
175 |
+
|
176 |
+
class ModelCatalog(object):
|
177 |
+
S3_C2_DETECTRON_URL = "https://dl.fbaipublicfiles.com/detectron"
|
178 |
+
C2_IMAGENET_MODELS = {
|
179 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
180 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
181 |
+
"MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl",
|
182 |
+
"MSRA/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl",
|
183 |
+
"MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl",
|
184 |
+
"MSRA/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl",
|
185 |
+
"FAIR/20171220/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl",
|
186 |
+
}
|
187 |
+
|
188 |
+
C2_DETECTRON_SUFFIX = "output/train/{}coco_2014_train%3A{}coco_2014_valminusminival/generalized_rcnn/model_final.pkl"
|
189 |
+
C2_DETECTRON_MODELS = {
|
190 |
+
"35857197/e2e_faster_rcnn_R-50-C4_1x": "01_33_49.iAX0mXvW",
|
191 |
+
"35857345/e2e_faster_rcnn_R-50-FPN_1x": "01_36_30.cUF7QR7I",
|
192 |
+
"35857890/e2e_faster_rcnn_R-101-FPN_1x": "01_38_50.sNxI7sX7",
|
193 |
+
"36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "06_31_39.5MIHi1fZ",
|
194 |
+
"35858791/e2e_mask_rcnn_R-50-C4_1x": "01_45_57.ZgkA7hPB",
|
195 |
+
"35858933/e2e_mask_rcnn_R-50-FPN_1x": "01_48_14.DzEQe4wC",
|
196 |
+
"35861795/e2e_mask_rcnn_R-101-FPN_1x": "02_31_37.KqyEK4tT",
|
197 |
+
"36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "06_35_59.RZotkLKI",
|
198 |
+
"37129812/e2e_mask_rcnn_X-152-32x8d-FPN-IN5k_1.44x": "09_35_36.8pzTQKYK",
|
199 |
+
# keypoints
|
200 |
+
"37697547/e2e_keypoint_rcnn_R-50-FPN_1x": "08_42_54.kdzV35ao"
|
201 |
+
}
|
202 |
+
|
203 |
+
@staticmethod
|
204 |
+
def get(name):
|
205 |
+
if name.startswith("Caffe2Detectron/COCO"):
|
206 |
+
return ModelCatalog.get_c2_detectron_12_2017_baselines(name)
|
207 |
+
if name.startswith("ImageNetPretrained"):
|
208 |
+
return ModelCatalog.get_c2_imagenet_pretrained(name)
|
209 |
+
raise RuntimeError("model not present in the catalog {}".format(name))
|
210 |
+
|
211 |
+
@staticmethod
|
212 |
+
def get_c2_imagenet_pretrained(name):
|
213 |
+
prefix = ModelCatalog.S3_C2_DETECTRON_URL
|
214 |
+
name = name[len("ImageNetPretrained/") :]
|
215 |
+
name = ModelCatalog.C2_IMAGENET_MODELS[name]
|
216 |
+
if 'resnet34' in name or 'resnet18' in name:
|
217 |
+
return name
|
218 |
+
url = "/".join([prefix, name])
|
219 |
+
return url
|
220 |
+
|
221 |
+
@staticmethod
|
222 |
+
def get_c2_detectron_12_2017_baselines(name):
|
223 |
+
# Detectron C2 models are stored following the structure
|
224 |
+
# prefix/<model_id>/2012_2017_baselines/<model_name>.yaml.<signature>/suffix
|
225 |
+
# we use as identifiers in the catalog Caffe2Detectron/COCO/<model_id>/<model_name>
|
226 |
+
prefix = ModelCatalog.S3_C2_DETECTRON_URL
|
227 |
+
suffix = ModelCatalog.C2_DETECTRON_SUFFIX
|
228 |
+
# remove identification prefix
|
229 |
+
name = name[len("Caffe2Detectron/COCO/") :]
|
230 |
+
# split in <model_id> and <model_name>
|
231 |
+
model_id, model_name = name.split("/")
|
232 |
+
# parsing to make it match the url address from the Caffe2 models
|
233 |
+
model_name = "{}.yaml".format(model_name)
|
234 |
+
signature = ModelCatalog.C2_DETECTRON_MODELS[name]
|
235 |
+
unique_name = ".".join([model_name, signature])
|
236 |
+
url = "/".join([prefix, model_id, "12_2017_baselines", unique_name, suffix])
|
237 |
+
return url
|
maskrcnn_benchmark/csrc/ROIAlign.h
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
#pragma once
|
3 |
+
|
4 |
+
#include "cpu/vision.h"
|
5 |
+
|
6 |
+
#ifdef WITH_CUDA
|
7 |
+
#include "cuda/vision.h"
|
8 |
+
#endif
|
9 |
+
|
10 |
+
// Interface for Python
|
11 |
+
at::Tensor ROIAlign_forward(const at::Tensor& input,
|
12 |
+
const at::Tensor& rois,
|
13 |
+
const float spatial_scale,
|
14 |
+
const int pooled_height,
|
15 |
+
const int pooled_width,
|
16 |
+
const int sampling_ratio) {
|
17 |
+
if (input.type().is_cuda()) {
|
18 |
+
#ifdef WITH_CUDA
|
19 |
+
return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
|
20 |
+
#else
|
21 |
+
AT_ERROR("Not compiled with GPU support");
|
22 |
+
#endif
|
23 |
+
}
|
24 |
+
return ROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
|
25 |
+
}
|
26 |
+
|
27 |
+
at::Tensor ROIAlign_backward(const at::Tensor& grad,
|
28 |
+
const at::Tensor& rois,
|
29 |
+
const float spatial_scale,
|
30 |
+
const int pooled_height,
|
31 |
+
const int pooled_width,
|
32 |
+
const int batch_size,
|
33 |
+
const int channels,
|
34 |
+
const int height,
|
35 |
+
const int width,
|
36 |
+
const int sampling_ratio) {
|
37 |
+
if (grad.type().is_cuda()) {
|
38 |
+
#ifdef WITH_CUDA
|
39 |
+
return ROIAlign_backward_cuda(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
|
40 |
+
#else
|
41 |
+
AT_ERROR("Not compiled with GPU support");
|
42 |
+
#endif
|
43 |
+
}
|
44 |
+
AT_ERROR("Not implemented on the CPU");
|
45 |
+
}
|
46 |
+
|
maskrcnn_benchmark/csrc/ROIPool.h
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
#pragma once
|
3 |
+
|
4 |
+
#include "cpu/vision.h"
|
5 |
+
|
6 |
+
#ifdef WITH_CUDA
|
7 |
+
#include "cuda/vision.h"
|
8 |
+
#endif
|
9 |
+
|
10 |
+
|
11 |
+
std::tuple<at::Tensor, at::Tensor> ROIPool_forward(const at::Tensor& input,
|
12 |
+
const at::Tensor& rois,
|
13 |
+
const float spatial_scale,
|
14 |
+
const int pooled_height,
|
15 |
+
const int pooled_width) {
|
16 |
+
if (input.type().is_cuda()) {
|
17 |
+
#ifdef WITH_CUDA
|
18 |
+
return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width);
|
19 |
+
#else
|
20 |
+
AT_ERROR("Not compiled with GPU support");
|
21 |
+
#endif
|
22 |
+
}
|
23 |
+
AT_ERROR("Not implemented on the CPU");
|
24 |
+
}
|
25 |
+
|
26 |
+
at::Tensor ROIPool_backward(const at::Tensor& grad,
|
27 |
+
const at::Tensor& input,
|
28 |
+
const at::Tensor& rois,
|
29 |
+
const at::Tensor& argmax,
|
30 |
+
const float spatial_scale,
|
31 |
+
const int pooled_height,
|
32 |
+
const int pooled_width,
|
33 |
+
const int batch_size,
|
34 |
+
const int channels,
|
35 |
+
const int height,
|
36 |
+
const int width) {
|
37 |
+
if (grad.type().is_cuda()) {
|
38 |
+
#ifdef WITH_CUDA
|
39 |
+
return ROIPool_backward_cuda(grad, input, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
|
40 |
+
#else
|
41 |
+
AT_ERROR("Not compiled with GPU support");
|
42 |
+
#endif
|
43 |
+
}
|
44 |
+
AT_ERROR("Not implemented on the CPU");
|
45 |
+
}
|
46 |
+
|
47 |
+
|
48 |
+
|
maskrcnn_benchmark/csrc/SigmoidFocalLoss.h
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "cpu/vision.h"
|
4 |
+
|
5 |
+
#ifdef WITH_CUDA
|
6 |
+
#include "cuda/vision.h"
|
7 |
+
#endif
|
8 |
+
|
9 |
+
// Interface for Python
|
10 |
+
at::Tensor SigmoidFocalLoss_forward(
|
11 |
+
const at::Tensor& logits,
|
12 |
+
const at::Tensor& targets,
|
13 |
+
const int num_classes,
|
14 |
+
const float gamma,
|
15 |
+
const float alpha) {
|
16 |
+
if (logits.type().is_cuda()) {
|
17 |
+
#ifdef WITH_CUDA
|
18 |
+
return SigmoidFocalLoss_forward_cuda(logits, targets, num_classes, gamma, alpha);
|
19 |
+
#else
|
20 |
+
AT_ERROR("Not compiled with GPU support");
|
21 |
+
#endif
|
22 |
+
}
|
23 |
+
AT_ERROR("Not implemented on the CPU");
|
24 |
+
}
|
25 |
+
|
26 |
+
at::Tensor SigmoidFocalLoss_backward(
|
27 |
+
const at::Tensor& logits,
|
28 |
+
const at::Tensor& targets,
|
29 |
+
const at::Tensor& d_losses,
|
30 |
+
const int num_classes,
|
31 |
+
const float gamma,
|
32 |
+
const float alpha) {
|
33 |
+
if (logits.type().is_cuda()) {
|
34 |
+
#ifdef WITH_CUDA
|
35 |
+
return SigmoidFocalLoss_backward_cuda(logits, targets, d_losses, num_classes, gamma, alpha);
|
36 |
+
#else
|
37 |
+
AT_ERROR("Not compiled with GPU support");
|
38 |
+
#endif
|
39 |
+
}
|
40 |
+
AT_ERROR("Not implemented on the CPU");
|
41 |
+
}
|
maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
#include "cpu/vision.h"
|
3 |
+
|
4 |
+
// implementation taken from Caffe2
|
5 |
+
template <typename T>
|
6 |
+
struct PreCalc {
|
7 |
+
int pos1;
|
8 |
+
int pos2;
|
9 |
+
int pos3;
|
10 |
+
int pos4;
|
11 |
+
T w1;
|
12 |
+
T w2;
|
13 |
+
T w3;
|
14 |
+
T w4;
|
15 |
+
};
|
16 |
+
|
17 |
+
template <typename T>
|
18 |
+
void pre_calc_for_bilinear_interpolate(
|
19 |
+
const int height,
|
20 |
+
const int width,
|
21 |
+
const int pooled_height,
|
22 |
+
const int pooled_width,
|
23 |
+
const int iy_upper,
|
24 |
+
const int ix_upper,
|
25 |
+
T roi_start_h,
|
26 |
+
T roi_start_w,
|
27 |
+
T bin_size_h,
|
28 |
+
T bin_size_w,
|
29 |
+
int roi_bin_grid_h,
|
30 |
+
int roi_bin_grid_w,
|
31 |
+
std::vector<PreCalc<T>>& pre_calc) {
|
32 |
+
int pre_calc_index = 0;
|
33 |
+
for (int ph = 0; ph < pooled_height; ph++) {
|
34 |
+
for (int pw = 0; pw < pooled_width; pw++) {
|
35 |
+
for (int iy = 0; iy < iy_upper; iy++) {
|
36 |
+
const T yy = roi_start_h + ph * bin_size_h +
|
37 |
+
static_cast<T>(iy + .5f) * bin_size_h /
|
38 |
+
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
39 |
+
for (int ix = 0; ix < ix_upper; ix++) {
|
40 |
+
const T xx = roi_start_w + pw * bin_size_w +
|
41 |
+
static_cast<T>(ix + .5f) * bin_size_w /
|
42 |
+
static_cast<T>(roi_bin_grid_w);
|
43 |
+
|
44 |
+
T x = xx;
|
45 |
+
T y = yy;
|
46 |
+
// deal with: inverse elements are out of feature map boundary
|
47 |
+
if (y < -1.0 || y > height || x < -1.0 || x > width) {
|
48 |
+
// empty
|
49 |
+
PreCalc<T> pc;
|
50 |
+
pc.pos1 = 0;
|
51 |
+
pc.pos2 = 0;
|
52 |
+
pc.pos3 = 0;
|
53 |
+
pc.pos4 = 0;
|
54 |
+
pc.w1 = 0;
|
55 |
+
pc.w2 = 0;
|
56 |
+
pc.w3 = 0;
|
57 |
+
pc.w4 = 0;
|
58 |
+
pre_calc[pre_calc_index] = pc;
|
59 |
+
pre_calc_index += 1;
|
60 |
+
continue;
|
61 |
+
}
|
62 |
+
|
63 |
+
if (y <= 0) {
|
64 |
+
y = 0;
|
65 |
+
}
|
66 |
+
if (x <= 0) {
|
67 |
+
x = 0;
|
68 |
+
}
|
69 |
+
|
70 |
+
int y_low = (int)y;
|
71 |
+
int x_low = (int)x;
|
72 |
+
int y_high;
|
73 |
+
int x_high;
|
74 |
+
|
75 |
+
if (y_low >= height - 1) {
|
76 |
+
y_high = y_low = height - 1;
|
77 |
+
y = (T)y_low;
|
78 |
+
} else {
|
79 |
+
y_high = y_low + 1;
|
80 |
+
}
|
81 |
+
|
82 |
+
if (x_low >= width - 1) {
|
83 |
+
x_high = x_low = width - 1;
|
84 |
+
x = (T)x_low;
|
85 |
+
} else {
|
86 |
+
x_high = x_low + 1;
|
87 |
+
}
|
88 |
+
|
89 |
+
T ly = y - y_low;
|
90 |
+
T lx = x - x_low;
|
91 |
+
T hy = 1. - ly, hx = 1. - lx;
|
92 |
+
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
93 |
+
|
94 |
+
// save weights and indices
|
95 |
+
PreCalc<T> pc;
|
96 |
+
pc.pos1 = y_low * width + x_low;
|
97 |
+
pc.pos2 = y_low * width + x_high;
|
98 |
+
pc.pos3 = y_high * width + x_low;
|
99 |
+
pc.pos4 = y_high * width + x_high;
|
100 |
+
pc.w1 = w1;
|
101 |
+
pc.w2 = w2;
|
102 |
+
pc.w3 = w3;
|
103 |
+
pc.w4 = w4;
|
104 |
+
pre_calc[pre_calc_index] = pc;
|
105 |
+
|
106 |
+
pre_calc_index += 1;
|
107 |
+
}
|
108 |
+
}
|
109 |
+
}
|
110 |
+
}
|
111 |
+
}
|
112 |
+
|
113 |
+
template <typename T>
|
114 |
+
void ROIAlignForward_cpu_kernel(
|
115 |
+
const int nthreads,
|
116 |
+
const T* bottom_data,
|
117 |
+
const T& spatial_scale,
|
118 |
+
const int channels,
|
119 |
+
const int height,
|
120 |
+
const int width,
|
121 |
+
const int pooled_height,
|
122 |
+
const int pooled_width,
|
123 |
+
const int sampling_ratio,
|
124 |
+
const T* bottom_rois,
|
125 |
+
//int roi_cols,
|
126 |
+
T* top_data) {
|
127 |
+
//AT_ASSERT(roi_cols == 4 || roi_cols == 5);
|
128 |
+
int roi_cols = 5;
|
129 |
+
|
130 |
+
int n_rois = nthreads / channels / pooled_width / pooled_height;
|
131 |
+
// (n, c, ph, pw) is an element in the pooled output
|
132 |
+
// can be parallelized using omp
|
133 |
+
// #pragma omp parallel for num_threads(32)
|
134 |
+
for (int n = 0; n < n_rois; n++) {
|
135 |
+
int index_n = n * channels * pooled_width * pooled_height;
|
136 |
+
|
137 |
+
// roi could have 4 or 5 columns
|
138 |
+
const T* offset_bottom_rois = bottom_rois + n * roi_cols;
|
139 |
+
int roi_batch_ind = 0;
|
140 |
+
if (roi_cols == 5) {
|
141 |
+
roi_batch_ind = offset_bottom_rois[0];
|
142 |
+
offset_bottom_rois++;
|
143 |
+
}
|
144 |
+
|
145 |
+
// Do not using rounding; this implementation detail is critical
|
146 |
+
T roi_start_w = offset_bottom_rois[0] * spatial_scale;
|
147 |
+
T roi_start_h = offset_bottom_rois[1] * spatial_scale;
|
148 |
+
T roi_end_w = offset_bottom_rois[2] * spatial_scale;
|
149 |
+
T roi_end_h = offset_bottom_rois[3] * spatial_scale;
|
150 |
+
// T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
|
151 |
+
// T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
|
152 |
+
// T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
|
153 |
+
// T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);
|
154 |
+
|
155 |
+
// Force malformed ROIs to be 1x1
|
156 |
+
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
|
157 |
+
T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
|
158 |
+
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
159 |
+
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
160 |
+
|
161 |
+
// We use roi_bin_grid to sample the grid and mimic integral
|
162 |
+
int roi_bin_grid_h = (sampling_ratio > 0)
|
163 |
+
? sampling_ratio
|
164 |
+
: ceil(roi_height / pooled_height); // e.g., = 2
|
165 |
+
int roi_bin_grid_w =
|
166 |
+
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
|
167 |
+
|
168 |
+
// We do average (integral) pooling inside a bin
|
169 |
+
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
|
170 |
+
|
171 |
+
// we want to precalculate indices and weights shared by all channels,
|
172 |
+
// this is the key point of optimization
|
173 |
+
std::vector<PreCalc<T>> pre_calc(
|
174 |
+
roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
|
175 |
+
pre_calc_for_bilinear_interpolate(
|
176 |
+
height,
|
177 |
+
width,
|
178 |
+
pooled_height,
|
179 |
+
pooled_width,
|
180 |
+
roi_bin_grid_h,
|
181 |
+
roi_bin_grid_w,
|
182 |
+
roi_start_h,
|
183 |
+
roi_start_w,
|
184 |
+
bin_size_h,
|
185 |
+
bin_size_w,
|
186 |
+
roi_bin_grid_h,
|
187 |
+
roi_bin_grid_w,
|
188 |
+
pre_calc);
|
189 |
+
|
190 |
+
for (int c = 0; c < channels; c++) {
|
191 |
+
int index_n_c = index_n + c * pooled_width * pooled_height;
|
192 |
+
const T* offset_bottom_data =
|
193 |
+
bottom_data + (roi_batch_ind * channels + c) * height * width;
|
194 |
+
int pre_calc_index = 0;
|
195 |
+
|
196 |
+
for (int ph = 0; ph < pooled_height; ph++) {
|
197 |
+
for (int pw = 0; pw < pooled_width; pw++) {
|
198 |
+
int index = index_n_c + ph * pooled_width + pw;
|
199 |
+
|
200 |
+
T output_val = 0.;
|
201 |
+
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
|
202 |
+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
|
203 |
+
PreCalc<T> pc = pre_calc[pre_calc_index];
|
204 |
+
output_val += pc.w1 * offset_bottom_data[pc.pos1] +
|
205 |
+
pc.w2 * offset_bottom_data[pc.pos2] +
|
206 |
+
pc.w3 * offset_bottom_data[pc.pos3] +
|
207 |
+
pc.w4 * offset_bottom_data[pc.pos4];
|
208 |
+
|
209 |
+
pre_calc_index += 1;
|
210 |
+
}
|
211 |
+
}
|
212 |
+
output_val /= count;
|
213 |
+
|
214 |
+
top_data[index] = output_val;
|
215 |
+
} // for pw
|
216 |
+
} // for ph
|
217 |
+
} // for c
|
218 |
+
} // for n
|
219 |
+
}
|
220 |
+
|
221 |
+
at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
|
222 |
+
const at::Tensor& rois,
|
223 |
+
const float spatial_scale,
|
224 |
+
const int pooled_height,
|
225 |
+
const int pooled_width,
|
226 |
+
const int sampling_ratio) {
|
227 |
+
AT_ASSERTM(!input.type().is_cuda(), "input must be a CPU tensor");
|
228 |
+
AT_ASSERTM(!rois.type().is_cuda(), "rois must be a CPU tensor");
|
229 |
+
|
230 |
+
auto num_rois = rois.size(0);
|
231 |
+
auto channels = input.size(1);
|
232 |
+
auto height = input.size(2);
|
233 |
+
auto width = input.size(3);
|
234 |
+
|
235 |
+
auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options());
|
236 |
+
auto output_size = num_rois * pooled_height * pooled_width * channels;
|
237 |
+
|
238 |
+
if (output.numel() == 0) {
|
239 |
+
return output;
|
240 |
+
}
|
241 |
+
|
242 |
+
AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
|
243 |
+
ROIAlignForward_cpu_kernel<scalar_t>(
|
244 |
+
output_size,
|
245 |
+
input.data<scalar_t>(),
|
246 |
+
spatial_scale,
|
247 |
+
channels,
|
248 |
+
height,
|
249 |
+
width,
|
250 |
+
pooled_height,
|
251 |
+
pooled_width,
|
252 |
+
sampling_ratio,
|
253 |
+
rois.data<scalar_t>(),
|
254 |
+
output.data<scalar_t>());
|
255 |
+
});
|
256 |
+
return output;
|
257 |
+
}
|
maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
#include "cpu/vision.h"
|
3 |
+
|
4 |
+
|
5 |
+
template <typename scalar_t>
|
6 |
+
at::Tensor nms_cpu_kernel(const at::Tensor& dets,
|
7 |
+
const at::Tensor& scores,
|
8 |
+
const float threshold) {
|
9 |
+
AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor");
|
10 |
+
AT_ASSERTM(!scores.type().is_cuda(), "scores must be a CPU tensor");
|
11 |
+
AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores");
|
12 |
+
|
13 |
+
if (dets.numel() == 0) {
|
14 |
+
return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
|
15 |
+
}
|
16 |
+
|
17 |
+
auto x1_t = dets.select(1, 0).contiguous();
|
18 |
+
auto y1_t = dets.select(1, 1).contiguous();
|
19 |
+
auto x2_t = dets.select(1, 2).contiguous();
|
20 |
+
auto y2_t = dets.select(1, 3).contiguous();
|
21 |
+
|
22 |
+
at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1);
|
23 |
+
|
24 |
+
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
|
25 |
+
|
26 |
+
auto ndets = dets.size(0);
|
27 |
+
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte).device(at::kCPU));
|
28 |
+
|
29 |
+
auto suppressed = suppressed_t.data<uint8_t>();
|
30 |
+
auto order = order_t.data<int64_t>();
|
31 |
+
auto x1 = x1_t.data<scalar_t>();
|
32 |
+
auto y1 = y1_t.data<scalar_t>();
|
33 |
+
auto x2 = x2_t.data<scalar_t>();
|
34 |
+
auto y2 = y2_t.data<scalar_t>();
|
35 |
+
auto areas = areas_t.data<scalar_t>();
|
36 |
+
|
37 |
+
for (int64_t _i = 0; _i < ndets; _i++) {
|
38 |
+
auto i = order[_i];
|
39 |
+
if (suppressed[i] == 1)
|
40 |
+
continue;
|
41 |
+
auto ix1 = x1[i];
|
42 |
+
auto iy1 = y1[i];
|
43 |
+
auto ix2 = x2[i];
|
44 |
+
auto iy2 = y2[i];
|
45 |
+
auto iarea = areas[i];
|
46 |
+
|
47 |
+
for (int64_t _j = _i + 1; _j < ndets; _j++) {
|
48 |
+
auto j = order[_j];
|
49 |
+
if (suppressed[j] == 1)
|
50 |
+
continue;
|
51 |
+
auto xx1 = std::max(ix1, x1[j]);
|
52 |
+
auto yy1 = std::max(iy1, y1[j]);
|
53 |
+
auto xx2 = std::min(ix2, x2[j]);
|
54 |
+
auto yy2 = std::min(iy2, y2[j]);
|
55 |
+
|
56 |
+
auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1);
|
57 |
+
auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1);
|
58 |
+
auto inter = w * h;
|
59 |
+
auto ovr = inter / (iarea + areas[j] - inter);
|
60 |
+
if (ovr >= threshold)
|
61 |
+
suppressed[j] = 1;
|
62 |
+
}
|
63 |
+
}
|
64 |
+
return at::nonzero(suppressed_t == 0).squeeze(1);
|
65 |
+
}
|
66 |
+
|
67 |
+
at::Tensor nms_cpu(const at::Tensor& dets,
|
68 |
+
const at::Tensor& scores,
|
69 |
+
const float threshold) {
|
70 |
+
at::Tensor result;
|
71 |
+
AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
|
72 |
+
result = nms_cpu_kernel<scalar_t>(dets, scores, threshold);
|
73 |
+
});
|
74 |
+
return result;
|
75 |
+
}
|
maskrcnn_benchmark/csrc/cpu/vision.h
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
#pragma once
|
3 |
+
#include <torch/extension.h>
|
4 |
+
|
5 |
+
|
6 |
+
at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
|
7 |
+
const at::Tensor& rois,
|
8 |
+
const float spatial_scale,
|
9 |
+
const int pooled_height,
|
10 |
+
const int pooled_width,
|
11 |
+
const int sampling_ratio);
|
12 |
+
|
13 |
+
|
14 |
+
at::Tensor nms_cpu(const at::Tensor& dets,
|
15 |
+
const at::Tensor& scores,
|
16 |
+
const float threshold);
|
maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
#include <ATen/ATen.h>
|
3 |
+
#include <ATen/cuda/CUDAContext.h>
|
4 |
+
|
5 |
+
#include <THC/THC.h>
|
6 |
+
#include <THC/THCAtomics.cuh>
|
7 |
+
#include <THC/THCDeviceUtils.cuh>
|
8 |
+
|
9 |
+
// TODO make it in a common file
|
10 |
+
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
11 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
|
12 |
+
i += blockDim.x * gridDim.x)
|
13 |
+
|
14 |
+
|
15 |
+
template <typename T>
|
16 |
+
__device__ T bilinear_interpolate(const T* bottom_data,
|
17 |
+
const int height, const int width,
|
18 |
+
T y, T x,
|
19 |
+
const int index /* index for debug only*/) {
|
20 |
+
|
21 |
+
// deal with cases that inverse elements are out of feature map boundary
|
22 |
+
if (y < -1.0 || y > height || x < -1.0 || x > width) {
|
23 |
+
//empty
|
24 |
+
return 0;
|
25 |
+
}
|
26 |
+
|
27 |
+
if (y <= 0) y = 0;
|
28 |
+
if (x <= 0) x = 0;
|
29 |
+
|
30 |
+
int y_low = (int) y;
|
31 |
+
int x_low = (int) x;
|
32 |
+
int y_high;
|
33 |
+
int x_high;
|
34 |
+
|
35 |
+
if (y_low >= height - 1) {
|
36 |
+
y_high = y_low = height - 1;
|
37 |
+
y = (T) y_low;
|
38 |
+
} else {
|
39 |
+
y_high = y_low + 1;
|
40 |
+
}
|
41 |
+
|
42 |
+
if (x_low >= width - 1) {
|
43 |
+
x_high = x_low = width - 1;
|
44 |
+
x = (T) x_low;
|
45 |
+
} else {
|
46 |
+
x_high = x_low + 1;
|
47 |
+
}
|
48 |
+
|
49 |
+
T ly = y - y_low;
|
50 |
+
T lx = x - x_low;
|
51 |
+
T hy = 1. - ly, hx = 1. - lx;
|
52 |
+
// do bilinear interpolation
|
53 |
+
T v1 = bottom_data[y_low * width + x_low];
|
54 |
+
T v2 = bottom_data[y_low * width + x_high];
|
55 |
+
T v3 = bottom_data[y_high * width + x_low];
|
56 |
+
T v4 = bottom_data[y_high * width + x_high];
|
57 |
+
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
58 |
+
|
59 |
+
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
60 |
+
|
61 |
+
return val;
|
62 |
+
}
|
63 |
+
|
64 |
+
template <typename T>
|
65 |
+
__global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
|
66 |
+
const T spatial_scale, const int channels,
|
67 |
+
const int height, const int width,
|
68 |
+
const int pooled_height, const int pooled_width,
|
69 |
+
const int sampling_ratio,
|
70 |
+
const T* bottom_rois, T* top_data) {
|
71 |
+
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
72 |
+
// (n, c, ph, pw) is an element in the pooled output
|
73 |
+
int pw = index % pooled_width;
|
74 |
+
int ph = (index / pooled_width) % pooled_height;
|
75 |
+
int c = (index / pooled_width / pooled_height) % channels;
|
76 |
+
int n = index / pooled_width / pooled_height / channels;
|
77 |
+
|
78 |
+
const T* offset_bottom_rois = bottom_rois + n * 5;
|
79 |
+
int roi_batch_ind = offset_bottom_rois[0];
|
80 |
+
|
81 |
+
// Do not using rounding; this implementation detail is critical
|
82 |
+
T roi_start_w = offset_bottom_rois[1] * spatial_scale;
|
83 |
+
T roi_start_h = offset_bottom_rois[2] * spatial_scale;
|
84 |
+
T roi_end_w = offset_bottom_rois[3] * spatial_scale;
|
85 |
+
T roi_end_h = offset_bottom_rois[4] * spatial_scale;
|
86 |
+
// T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
|
87 |
+
// T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
|
88 |
+
// T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
|
89 |
+
// T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
|
90 |
+
|
91 |
+
// Force malformed ROIs to be 1x1
|
92 |
+
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
|
93 |
+
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
|
94 |
+
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
95 |
+
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
96 |
+
|
97 |
+
const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
|
98 |
+
|
99 |
+
// We use roi_bin_grid to sample the grid and mimic integral
|
100 |
+
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
|
101 |
+
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
|
102 |
+
|
103 |
+
// We do average (integral) pooling inside a bin
|
104 |
+
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
|
105 |
+
|
106 |
+
T output_val = 0.;
|
107 |
+
for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1
|
108 |
+
{
|
109 |
+
const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
110 |
+
for (int ix = 0; ix < roi_bin_grid_w; ix ++)
|
111 |
+
{
|
112 |
+
const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
|
113 |
+
|
114 |
+
T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index);
|
115 |
+
output_val += val;
|
116 |
+
}
|
117 |
+
}
|
118 |
+
output_val /= count;
|
119 |
+
|
120 |
+
top_data[index] = output_val;
|
121 |
+
}
|
122 |
+
}
|
123 |
+
|
124 |
+
|
125 |
+
template <typename T>
|
126 |
+
__device__ void bilinear_interpolate_gradient(
|
127 |
+
const int height, const int width,
|
128 |
+
T y, T x,
|
129 |
+
T & w1, T & w2, T & w3, T & w4,
|
130 |
+
int & x_low, int & x_high, int & y_low, int & y_high,
|
131 |
+
const int index /* index for debug only*/) {
|
132 |
+
|
133 |
+
// deal with cases that inverse elements are out of feature map boundary
|
134 |
+
if (y < -1.0 || y > height || x < -1.0 || x > width) {
|
135 |
+
//empty
|
136 |
+
w1 = w2 = w3 = w4 = 0.;
|
137 |
+
x_low = x_high = y_low = y_high = -1;
|
138 |
+
return;
|
139 |
+
}
|
140 |
+
|
141 |
+
if (y <= 0) y = 0;
|
142 |
+
if (x <= 0) x = 0;
|
143 |
+
|
144 |
+
y_low = (int) y;
|
145 |
+
x_low = (int) x;
|
146 |
+
|
147 |
+
if (y_low >= height - 1) {
|
148 |
+
y_high = y_low = height - 1;
|
149 |
+
y = (T) y_low;
|
150 |
+
} else {
|
151 |
+
y_high = y_low + 1;
|
152 |
+
}
|
153 |
+
|
154 |
+
if (x_low >= width - 1) {
|
155 |
+
x_high = x_low = width - 1;
|
156 |
+
x = (T) x_low;
|
157 |
+
} else {
|
158 |
+
x_high = x_low + 1;
|
159 |
+
}
|
160 |
+
|
161 |
+
T ly = y - y_low;
|
162 |
+
T lx = x - x_low;
|
163 |
+
T hy = 1. - ly, hx = 1. - lx;
|
164 |
+
|
165 |
+
// reference in forward
|
166 |
+
// T v1 = bottom_data[y_low * width + x_low];
|
167 |
+
// T v2 = bottom_data[y_low * width + x_high];
|
168 |
+
// T v3 = bottom_data[y_high * width + x_low];
|
169 |
+
// T v4 = bottom_data[y_high * width + x_high];
|
170 |
+
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
171 |
+
|
172 |
+
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
173 |
+
|
174 |
+
return;
|
175 |
+
}
|
176 |
+
|
177 |
+
template <typename T>
|
178 |
+
__global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
|
179 |
+
const int num_rois, const T spatial_scale,
|
180 |
+
const int channels, const int height, const int width,
|
181 |
+
const int pooled_height, const int pooled_width,
|
182 |
+
const int sampling_ratio,
|
183 |
+
T* bottom_diff,
|
184 |
+
const T* bottom_rois) {
|
185 |
+
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
186 |
+
// (n, c, ph, pw) is an element in the pooled output
|
187 |
+
int pw = index % pooled_width;
|
188 |
+
int ph = (index / pooled_width) % pooled_height;
|
189 |
+
int c = (index / pooled_width / pooled_height) % channels;
|
190 |
+
int n = index / pooled_width / pooled_height / channels;
|
191 |
+
|
192 |
+
const T* offset_bottom_rois = bottom_rois + n * 5;
|
193 |
+
int roi_batch_ind = offset_bottom_rois[0];
|
194 |
+
|
195 |
+
// Do not using rounding; this implementation detail is critical
|
196 |
+
T roi_start_w = offset_bottom_rois[1] * spatial_scale;
|
197 |
+
T roi_start_h = offset_bottom_rois[2] * spatial_scale;
|
198 |
+
T roi_end_w = offset_bottom_rois[3] * spatial_scale;
|
199 |
+
T roi_end_h = offset_bottom_rois[4] * spatial_scale;
|
200 |
+
// T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
|
201 |
+
// T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
|
202 |
+
// T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
|
203 |
+
// T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
|
204 |
+
|
205 |
+
// Force malformed ROIs to be 1x1
|
206 |
+
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
|
207 |
+
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
|
208 |
+
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
209 |
+
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
210 |
+
|
211 |
+
T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width;
|
212 |
+
|
213 |
+
int top_offset = (n * channels + c) * pooled_height * pooled_width;
|
214 |
+
const T* offset_top_diff = top_diff + top_offset;
|
215 |
+
const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
|
216 |
+
|
217 |
+
// We use roi_bin_grid to sample the grid and mimic integral
|
218 |
+
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
|
219 |
+
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
|
220 |
+
|
221 |
+
// We do average (integral) pooling inside a bin
|
222 |
+
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
|
223 |
+
|
224 |
+
for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1
|
225 |
+
{
|
226 |
+
const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
227 |
+
for (int ix = 0; ix < roi_bin_grid_w; ix ++)
|
228 |
+
{
|
229 |
+
const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
|
230 |
+
|
231 |
+
T w1, w2, w3, w4;
|
232 |
+
int x_low, x_high, y_low, y_high;
|
233 |
+
|
234 |
+
bilinear_interpolate_gradient(height, width, y, x,
|
235 |
+
w1, w2, w3, w4,
|
236 |
+
x_low, x_high, y_low, y_high,
|
237 |
+
index);
|
238 |
+
|
239 |
+
T g1 = top_diff_this_bin * w1 / count;
|
240 |
+
T g2 = top_diff_this_bin * w2 / count;
|
241 |
+
T g3 = top_diff_this_bin * w3 / count;
|
242 |
+
T g4 = top_diff_this_bin * w4 / count;
|
243 |
+
|
244 |
+
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0)
|
245 |
+
{
|
246 |
+
atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast<T>(g1));
|
247 |
+
atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast<T>(g2));
|
248 |
+
atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3));
|
249 |
+
atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4));
|
250 |
+
} // if
|
251 |
+
} // ix
|
252 |
+
} // iy
|
253 |
+
} // CUDA_1D_KERNEL_LOOP
|
254 |
+
} // RoIAlignBackward
|
255 |
+
|
256 |
+
|
257 |
+
at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
|
258 |
+
const at::Tensor& rois,
|
259 |
+
const float spatial_scale,
|
260 |
+
const int pooled_height,
|
261 |
+
const int pooled_width,
|
262 |
+
const int sampling_ratio) {
|
263 |
+
AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
|
264 |
+
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
|
265 |
+
|
266 |
+
auto num_rois = rois.size(0);
|
267 |
+
auto channels = input.size(1);
|
268 |
+
auto height = input.size(2);
|
269 |
+
auto width = input.size(3);
|
270 |
+
|
271 |
+
auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options());
|
272 |
+
auto output_size = num_rois * pooled_height * pooled_width * channels;
|
273 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
274 |
+
|
275 |
+
dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
|
276 |
+
dim3 block(512);
|
277 |
+
|
278 |
+
if (output.numel() == 0) {
|
279 |
+
THCudaCheck(cudaGetLastError());
|
280 |
+
return output;
|
281 |
+
}
|
282 |
+
|
283 |
+
AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
|
284 |
+
RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
|
285 |
+
output_size,
|
286 |
+
input.contiguous().data<scalar_t>(),
|
287 |
+
spatial_scale,
|
288 |
+
channels,
|
289 |
+
height,
|
290 |
+
width,
|
291 |
+
pooled_height,
|
292 |
+
pooled_width,
|
293 |
+
sampling_ratio,
|
294 |
+
rois.contiguous().data<scalar_t>(),
|
295 |
+
output.data<scalar_t>());
|
296 |
+
});
|
297 |
+
THCudaCheck(cudaGetLastError());
|
298 |
+
return output;
|
299 |
+
}
|
300 |
+
|
301 |
+
// TODO remove the dependency on input and use instead its sizes -> save memory
|
302 |
+
at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
|
303 |
+
const at::Tensor& rois,
|
304 |
+
const float spatial_scale,
|
305 |
+
const int pooled_height,
|
306 |
+
const int pooled_width,
|
307 |
+
const int batch_size,
|
308 |
+
const int channels,
|
309 |
+
const int height,
|
310 |
+
const int width,
|
311 |
+
const int sampling_ratio) {
|
312 |
+
AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
|
313 |
+
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
|
314 |
+
|
315 |
+
auto num_rois = rois.size(0);
|
316 |
+
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
|
317 |
+
|
318 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
319 |
+
|
320 |
+
dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
|
321 |
+
dim3 block(512);
|
322 |
+
|
323 |
+
// handle possibly empty gradients
|
324 |
+
if (grad.numel() == 0) {
|
325 |
+
THCudaCheck(cudaGetLastError());
|
326 |
+
return grad_input;
|
327 |
+
}
|
328 |
+
|
329 |
+
AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] {
|
330 |
+
RoIAlignBackwardFeature<scalar_t><<<grid, block, 0, stream>>>(
|
331 |
+
grad.numel(),
|
332 |
+
grad.contiguous().data<scalar_t>(),
|
333 |
+
num_rois,
|
334 |
+
spatial_scale,
|
335 |
+
channels,
|
336 |
+
height,
|
337 |
+
width,
|
338 |
+
pooled_height,
|
339 |
+
pooled_width,
|
340 |
+
sampling_ratio,
|
341 |
+
grad_input.data<scalar_t>(),
|
342 |
+
rois.contiguous().data<scalar_t>());
|
343 |
+
});
|
344 |
+
THCudaCheck(cudaGetLastError());
|
345 |
+
return grad_input;
|
346 |
+
}
|
maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
#include <ATen/ATen.h>
|
3 |
+
#include <ATen/cuda/CUDAContext.h>
|
4 |
+
|
5 |
+
#include <THC/THC.h>
|
6 |
+
#include <THC/THCAtomics.cuh>
|
7 |
+
#include <THC/THCDeviceUtils.cuh>
|
8 |
+
|
9 |
+
|
10 |
+
// TODO make it in a common file
|
11 |
+
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
12 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
|
13 |
+
i += blockDim.x * gridDim.x)
|
14 |
+
|
15 |
+
|
16 |
+
template <typename T>
|
17 |
+
__global__ void RoIPoolFForward(const int nthreads, const T* bottom_data,
|
18 |
+
const T spatial_scale, const int channels, const int height,
|
19 |
+
const int width, const int pooled_height, const int pooled_width,
|
20 |
+
const T* bottom_rois, T* top_data, int* argmax_data) {
|
21 |
+
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
22 |
+
// (n, c, ph, pw) is an element in the pooled output
|
23 |
+
int pw = index % pooled_width;
|
24 |
+
int ph = (index / pooled_width) % pooled_height;
|
25 |
+
int c = (index / pooled_width / pooled_height) % channels;
|
26 |
+
int n = index / pooled_width / pooled_height / channels;
|
27 |
+
|
28 |
+
const T* offset_bottom_rois = bottom_rois + n * 5;
|
29 |
+
int roi_batch_ind = offset_bottom_rois[0];
|
30 |
+
int roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
|
31 |
+
int roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
|
32 |
+
int roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
|
33 |
+
int roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
|
34 |
+
|
35 |
+
// Force malformed ROIs to be 1x1
|
36 |
+
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
|
37 |
+
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
|
38 |
+
T bin_size_h = static_cast<T>(roi_height)
|
39 |
+
/ static_cast<T>(pooled_height);
|
40 |
+
T bin_size_w = static_cast<T>(roi_width)
|
41 |
+
/ static_cast<T>(pooled_width);
|
42 |
+
|
43 |
+
int hstart = static_cast<int>(floor(static_cast<T>(ph)
|
44 |
+
* bin_size_h));
|
45 |
+
int wstart = static_cast<int>(floor(static_cast<T>(pw)
|
46 |
+
* bin_size_w));
|
47 |
+
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1)
|
48 |
+
* bin_size_h));
|
49 |
+
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1)
|
50 |
+
* bin_size_w));
|
51 |
+
|
52 |
+
// Add roi offsets and clip to input boundaries
|
53 |
+
hstart = min(max(hstart + roi_start_h, 0), height);
|
54 |
+
hend = min(max(hend + roi_start_h, 0), height);
|
55 |
+
wstart = min(max(wstart + roi_start_w, 0), width);
|
56 |
+
wend = min(max(wend + roi_start_w, 0), width);
|
57 |
+
bool is_empty = (hend <= hstart) || (wend <= wstart);
|
58 |
+
|
59 |
+
// Define an empty pooling region to be zero
|
60 |
+
T maxval = is_empty ? 0 : -FLT_MAX;
|
61 |
+
// If nothing is pooled, argmax = -1 causes nothing to be backprop'd
|
62 |
+
int maxidx = -1;
|
63 |
+
const T* offset_bottom_data =
|
64 |
+
bottom_data + (roi_batch_ind * channels + c) * height * width;
|
65 |
+
for (int h = hstart; h < hend; ++h) {
|
66 |
+
for (int w = wstart; w < wend; ++w) {
|
67 |
+
int bottom_index = h * width + w;
|
68 |
+
if (offset_bottom_data[bottom_index] > maxval) {
|
69 |
+
maxval = offset_bottom_data[bottom_index];
|
70 |
+
maxidx = bottom_index;
|
71 |
+
}
|
72 |
+
}
|
73 |
+
}
|
74 |
+
top_data[index] = maxval;
|
75 |
+
argmax_data[index] = maxidx;
|
76 |
+
}
|
77 |
+
}
|
78 |
+
|
79 |
+
template <typename T>
|
80 |
+
__global__ void RoIPoolFBackward(const int nthreads, const T* top_diff,
|
81 |
+
const int* argmax_data, const int num_rois, const T spatial_scale,
|
82 |
+
const int channels, const int height, const int width,
|
83 |
+
const int pooled_height, const int pooled_width, T* bottom_diff,
|
84 |
+
const T* bottom_rois) {
|
85 |
+
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
86 |
+
// (n, c, ph, pw) is an element in the pooled output
|
87 |
+
int pw = index % pooled_width;
|
88 |
+
int ph = (index / pooled_width) % pooled_height;
|
89 |
+
int c = (index / pooled_width / pooled_height) % channels;
|
90 |
+
int n = index / pooled_width / pooled_height / channels;
|
91 |
+
|
92 |
+
const T* offset_bottom_rois = bottom_rois + n * 5;
|
93 |
+
int roi_batch_ind = offset_bottom_rois[0];
|
94 |
+
int bottom_offset = (roi_batch_ind * channels + c) * height * width;
|
95 |
+
int top_offset = (n * channels + c) * pooled_height * pooled_width;
|
96 |
+
const T* offset_top_diff = top_diff + top_offset;
|
97 |
+
T* offset_bottom_diff = bottom_diff + bottom_offset;
|
98 |
+
const int* offset_argmax_data = argmax_data + top_offset;
|
99 |
+
|
100 |
+
int argmax = offset_argmax_data[ph * pooled_width + pw];
|
101 |
+
if (argmax != -1) {
|
102 |
+
atomicAdd(
|
103 |
+
offset_bottom_diff + argmax,
|
104 |
+
static_cast<T>(offset_top_diff[ph * pooled_width + pw]));
|
105 |
+
|
106 |
+
}
|
107 |
+
}
|
108 |
+
}
|
109 |
+
|
110 |
+
std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
|
111 |
+
const at::Tensor& rois,
|
112 |
+
const float spatial_scale,
|
113 |
+
const int pooled_height,
|
114 |
+
const int pooled_width) {
|
115 |
+
AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
|
116 |
+
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
|
117 |
+
|
118 |
+
auto num_rois = rois.size(0);
|
119 |
+
auto channels = input.size(1);
|
120 |
+
auto height = input.size(2);
|
121 |
+
auto width = input.size(3);
|
122 |
+
|
123 |
+
auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options());
|
124 |
+
auto output_size = num_rois * pooled_height * pooled_width * channels;
|
125 |
+
auto argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt));
|
126 |
+
|
127 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
128 |
+
|
129 |
+
dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
|
130 |
+
dim3 block(512);
|
131 |
+
|
132 |
+
if (output.numel() == 0) {
|
133 |
+
THCudaCheck(cudaGetLastError());
|
134 |
+
return std::make_tuple(output, argmax);
|
135 |
+
}
|
136 |
+
|
137 |
+
AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIPool_forward", [&] {
|
138 |
+
RoIPoolFForward<scalar_t><<<grid, block, 0, stream>>>(
|
139 |
+
output_size,
|
140 |
+
input.contiguous().data<scalar_t>(),
|
141 |
+
spatial_scale,
|
142 |
+
channels,
|
143 |
+
height,
|
144 |
+
width,
|
145 |
+
pooled_height,
|
146 |
+
pooled_width,
|
147 |
+
rois.contiguous().data<scalar_t>(),
|
148 |
+
output.data<scalar_t>(),
|
149 |
+
argmax.data<int>());
|
150 |
+
});
|
151 |
+
THCudaCheck(cudaGetLastError());
|
152 |
+
return std::make_tuple(output, argmax);
|
153 |
+
}
|
154 |
+
|
155 |
+
// TODO remove the dependency on input and use instead its sizes -> save memory
|
156 |
+
at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
|
157 |
+
const at::Tensor& input,
|
158 |
+
const at::Tensor& rois,
|
159 |
+
const at::Tensor& argmax,
|
160 |
+
const float spatial_scale,
|
161 |
+
const int pooled_height,
|
162 |
+
const int pooled_width,
|
163 |
+
const int batch_size,
|
164 |
+
const int channels,
|
165 |
+
const int height,
|
166 |
+
const int width) {
|
167 |
+
AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
|
168 |
+
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
|
169 |
+
// TODO add more checks
|
170 |
+
|
171 |
+
auto num_rois = rois.size(0);
|
172 |
+
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
|
173 |
+
|
174 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
175 |
+
|
176 |
+
dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
|
177 |
+
dim3 block(512);
|
178 |
+
|
179 |
+
// handle possibly empty gradients
|
180 |
+
if (grad.numel() == 0) {
|
181 |
+
THCudaCheck(cudaGetLastError());
|
182 |
+
return grad_input;
|
183 |
+
}
|
184 |
+
|
185 |
+
AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIPool_backward", [&] {
|
186 |
+
RoIPoolFBackward<scalar_t><<<grid, block, 0, stream>>>(
|
187 |
+
grad.numel(),
|
188 |
+
grad.contiguous().data<scalar_t>(),
|
189 |
+
argmax.data<int>(),
|
190 |
+
num_rois,
|
191 |
+
spatial_scale,
|
192 |
+
channels,
|
193 |
+
height,
|
194 |
+
width,
|
195 |
+
pooled_height,
|
196 |
+
pooled_width,
|
197 |
+
grad_input.data<scalar_t>(),
|
198 |
+
rois.contiguous().data<scalar_t>());
|
199 |
+
});
|
200 |
+
THCudaCheck(cudaGetLastError());
|
201 |
+
return grad_input;
|
202 |
+
}
|
maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
// This file is modified from https://github.com/pytorch/pytorch/blob/master/modules/detectron/sigmoid_focal_loss_op.cu
|
3 |
+
// Cheng-Yang Fu
|
4 |
+
// [email protected]
|
5 |
+
#include <ATen/ATen.h>
|
6 |
+
#include <ATen/cuda/CUDAContext.h>
|
7 |
+
|
8 |
+
#include <THC/THC.h>
|
9 |
+
#include <THC/THCAtomics.cuh>
|
10 |
+
#include <THC/THCDeviceUtils.cuh>
|
11 |
+
|
12 |
+
#include <cfloat>
|
13 |
+
|
14 |
+
// TODO make it in a common file
|
15 |
+
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
16 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
|
17 |
+
i += blockDim.x * gridDim.x)
|
18 |
+
|
19 |
+
|
20 |
+
template <typename T>
|
21 |
+
__global__ void SigmoidFocalLossForward(const int nthreads,
|
22 |
+
const T* logits,
|
23 |
+
const int* targets,
|
24 |
+
const int num_classes,
|
25 |
+
const float gamma,
|
26 |
+
const float alpha,
|
27 |
+
const int num,
|
28 |
+
T* losses) {
|
29 |
+
CUDA_1D_KERNEL_LOOP(i, nthreads) {
|
30 |
+
|
31 |
+
int n = i / num_classes;
|
32 |
+
int d = i % num_classes; // current class[0~79];
|
33 |
+
int t = targets[n]; // target class [1~80];
|
34 |
+
|
35 |
+
// Decide it is positive or negative case.
|
36 |
+
T c1 = (t == (d+1));
|
37 |
+
T c2 = (t>=0 & t != (d+1));
|
38 |
+
|
39 |
+
T zn = (1.0 - alpha);
|
40 |
+
T zp = (alpha);
|
41 |
+
|
42 |
+
// p = 1. / 1. + expf(-x); p = sigmoid(x)
|
43 |
+
T p = 1. / (1. + expf(-logits[i]));
|
44 |
+
|
45 |
+
// (1-p)**gamma * log(p) where
|
46 |
+
T term1 = powf((1. - p), gamma) * logf(max(p, FLT_MIN));
|
47 |
+
|
48 |
+
// p**gamma * log(1-p)
|
49 |
+
T term2 = powf(p, gamma) *
|
50 |
+
(-1. * logits[i] * (logits[i] >= 0) -
|
51 |
+
logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0))));
|
52 |
+
|
53 |
+
losses[i] = 0.0;
|
54 |
+
losses[i] += -c1 * term1 * zp;
|
55 |
+
losses[i] += -c2 * term2 * zn;
|
56 |
+
|
57 |
+
} // CUDA_1D_KERNEL_LOOP
|
58 |
+
} // SigmoidFocalLossForward
|
59 |
+
|
60 |
+
|
61 |
+
template <typename T>
|
62 |
+
__global__ void SigmoidFocalLossBackward(const int nthreads,
|
63 |
+
const T* logits,
|
64 |
+
const int* targets,
|
65 |
+
const T* d_losses,
|
66 |
+
const int num_classes,
|
67 |
+
const float gamma,
|
68 |
+
const float alpha,
|
69 |
+
const int num,
|
70 |
+
T* d_logits) {
|
71 |
+
CUDA_1D_KERNEL_LOOP(i, nthreads) {
|
72 |
+
|
73 |
+
int n = i / num_classes;
|
74 |
+
int d = i % num_classes; // current class[0~79];
|
75 |
+
int t = targets[n]; // target class [1~80], 0 is background;
|
76 |
+
|
77 |
+
// Decide it is positive or negative case.
|
78 |
+
T c1 = (t == (d+1));
|
79 |
+
T c2 = (t>=0 & t != (d+1));
|
80 |
+
|
81 |
+
T zn = (1.0 - alpha);
|
82 |
+
T zp = (alpha);
|
83 |
+
// p = 1. / 1. + expf(-x); p = sigmoid(x)
|
84 |
+
T p = 1. / (1. + expf(-logits[i]));
|
85 |
+
|
86 |
+
// (1-p)**g * (1 - p - g*p*log(p)
|
87 |
+
T term1 = powf((1. - p), gamma) *
|
88 |
+
(1. - p - (p * gamma * logf(max(p, FLT_MIN))));
|
89 |
+
|
90 |
+
// (p**g) * (g*(1-p)*log(1-p) - p)
|
91 |
+
T term2 = powf(p, gamma) *
|
92 |
+
((-1. * logits[i] * (logits[i] >= 0) -
|
93 |
+
logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0)))) *
|
94 |
+
(1. - p) * gamma - p);
|
95 |
+
d_logits[i] = 0.0;
|
96 |
+
d_logits[i] += -c1 * term1 * zp;
|
97 |
+
d_logits[i] += -c2 * term2 * zn;
|
98 |
+
d_logits[i] = d_logits[i] * d_losses[i];
|
99 |
+
|
100 |
+
} // CUDA_1D_KERNEL_LOOP
|
101 |
+
} // SigmoidFocalLossBackward
|
102 |
+
|
103 |
+
|
104 |
+
at::Tensor SigmoidFocalLoss_forward_cuda(
|
105 |
+
const at::Tensor& logits,
|
106 |
+
const at::Tensor& targets,
|
107 |
+
const int num_classes,
|
108 |
+
const float gamma,
|
109 |
+
const float alpha) {
|
110 |
+
AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor");
|
111 |
+
AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
|
112 |
+
AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
|
113 |
+
|
114 |
+
const int num_samples = logits.size(0);
|
115 |
+
|
116 |
+
auto losses = at::empty({num_samples, logits.size(1)}, logits.options());
|
117 |
+
auto losses_size = num_samples * logits.size(1);
|
118 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
119 |
+
|
120 |
+
dim3 grid(std::min(THCCeilDiv((long)losses_size, 512L), 4096L));
|
121 |
+
|
122 |
+
dim3 block(512);
|
123 |
+
|
124 |
+
if (losses.numel() == 0) {
|
125 |
+
THCudaCheck(cudaGetLastError());
|
126 |
+
return losses;
|
127 |
+
}
|
128 |
+
|
129 |
+
AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_forward", [&] {
|
130 |
+
SigmoidFocalLossForward<scalar_t><<<grid, block, 0, stream>>>(
|
131 |
+
losses_size,
|
132 |
+
logits.contiguous().data<scalar_t>(),
|
133 |
+
targets.contiguous().data<int>(),
|
134 |
+
num_classes,
|
135 |
+
gamma,
|
136 |
+
alpha,
|
137 |
+
num_samples,
|
138 |
+
losses.data<scalar_t>());
|
139 |
+
});
|
140 |
+
THCudaCheck(cudaGetLastError());
|
141 |
+
return losses;
|
142 |
+
}
|
143 |
+
|
144 |
+
|
145 |
+
at::Tensor SigmoidFocalLoss_backward_cuda(
|
146 |
+
const at::Tensor& logits,
|
147 |
+
const at::Tensor& targets,
|
148 |
+
const at::Tensor& d_losses,
|
149 |
+
const int num_classes,
|
150 |
+
const float gamma,
|
151 |
+
const float alpha) {
|
152 |
+
AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor");
|
153 |
+
AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
|
154 |
+
AT_ASSERTM(d_losses.type().is_cuda(), "d_losses must be a CUDA tensor");
|
155 |
+
|
156 |
+
AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
|
157 |
+
|
158 |
+
const int num_samples = logits.size(0);
|
159 |
+
AT_ASSERTM(logits.size(1) == num_classes, "logits.size(1) should be num_classes");
|
160 |
+
|
161 |
+
auto d_logits = at::zeros({num_samples, num_classes}, logits.options());
|
162 |
+
auto d_logits_size = num_samples * logits.size(1);
|
163 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
164 |
+
|
165 |
+
dim3 grid(std::min(THCCeilDiv((long)d_logits_size, 512L), 4096L));
|
166 |
+
dim3 block(512);
|
167 |
+
|
168 |
+
if (d_logits.numel() == 0) {
|
169 |
+
THCudaCheck(cudaGetLastError());
|
170 |
+
return d_logits;
|
171 |
+
}
|
172 |
+
|
173 |
+
AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_backward", [&] {
|
174 |
+
SigmoidFocalLossBackward<scalar_t><<<grid, block, 0, stream>>>(
|
175 |
+
d_logits_size,
|
176 |
+
logits.contiguous().data<scalar_t>(),
|
177 |
+
targets.contiguous().data<int>(),
|
178 |
+
d_losses.contiguous().data<scalar_t>(),
|
179 |
+
num_classes,
|
180 |
+
gamma,
|
181 |
+
alpha,
|
182 |
+
num_samples,
|
183 |
+
d_logits.data<scalar_t>());
|
184 |
+
});
|
185 |
+
|
186 |
+
THCudaCheck(cudaGetLastError());
|
187 |
+
return d_logits;
|
188 |
+
}
|
189 |
+
|
maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu
ADDED
@@ -0,0 +1,691 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// modify from
|
2 |
+
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
|
3 |
+
|
4 |
+
#include <ATen/ATen.h>
|
5 |
+
#include <ATen/cuda/CUDAContext.h>
|
6 |
+
|
7 |
+
#include <THC/THC.h>
|
8 |
+
#include <THC/THCDeviceUtils.cuh>
|
9 |
+
|
10 |
+
#include <vector>
|
11 |
+
#include <iostream>
|
12 |
+
#include <cmath>
|
13 |
+
|
14 |
+
|
15 |
+
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
|
16 |
+
const int channels, const int height, const int width,
|
17 |
+
const int ksize_h, const int ksize_w, const int pad_h,
|
18 |
+
const int pad_w, const int stride_h, const int stride_w,
|
19 |
+
const int dilation_h, const int dilation_w,
|
20 |
+
const int parallel_imgs, const int deformable_group,
|
21 |
+
at::Tensor data_col);
|
22 |
+
|
23 |
+
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
|
24 |
+
const int channels, const int height, const int width,
|
25 |
+
const int ksize_h, const int ksize_w, const int pad_h,
|
26 |
+
const int pad_w, const int stride_h, const int stride_w,
|
27 |
+
const int dilation_h, const int dilation_w,
|
28 |
+
const int parallel_imgs, const int deformable_group,
|
29 |
+
at::Tensor grad_im);
|
30 |
+
|
31 |
+
void deformable_col2im_coord(
|
32 |
+
const at::Tensor data_col, const at::Tensor data_im,
|
33 |
+
const at::Tensor data_offset, const int channels, const int height,
|
34 |
+
const int width, const int ksize_h, const int ksize_w, const int pad_h,
|
35 |
+
const int pad_w, const int stride_h, const int stride_w,
|
36 |
+
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
37 |
+
const int deformable_group, at::Tensor grad_offset);
|
38 |
+
|
39 |
+
void modulated_deformable_im2col_cuda(
|
40 |
+
const at::Tensor data_im, const at::Tensor data_offset,
|
41 |
+
const at::Tensor data_mask, const int batch_size, const int channels,
|
42 |
+
const int height_im, const int width_im, const int height_col,
|
43 |
+
const int width_col, const int kernel_h, const int kenerl_w,
|
44 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
45 |
+
const int dilation_h, const int dilation_w, const int deformable_group,
|
46 |
+
at::Tensor data_col);
|
47 |
+
|
48 |
+
void modulated_deformable_col2im_cuda(
|
49 |
+
const at::Tensor data_col, const at::Tensor data_offset,
|
50 |
+
const at::Tensor data_mask, const int batch_size, const int channels,
|
51 |
+
const int height_im, const int width_im, const int height_col,
|
52 |
+
const int width_col, const int kernel_h, const int kenerl_w,
|
53 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
54 |
+
const int dilation_h, const int dilation_w, const int deformable_group,
|
55 |
+
at::Tensor grad_im);
|
56 |
+
|
57 |
+
void modulated_deformable_col2im_coord_cuda(
|
58 |
+
const at::Tensor data_col, const at::Tensor data_im,
|
59 |
+
const at::Tensor data_offset, const at::Tensor data_mask,
|
60 |
+
const int batch_size, const int channels, const int height_im,
|
61 |
+
const int width_im, const int height_col, const int width_col,
|
62 |
+
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
|
63 |
+
const int stride_h, const int stride_w, const int dilation_h,
|
64 |
+
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
|
65 |
+
at::Tensor grad_mask);
|
66 |
+
|
67 |
+
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
|
68 |
+
at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
|
69 |
+
int padW, int dilationH, int dilationW, int group,
|
70 |
+
int deformable_group)
|
71 |
+
{
|
72 |
+
AT_CHECK(weight.ndimension() == 4,
|
73 |
+
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
|
74 |
+
"but got: %s",
|
75 |
+
weight.ndimension());
|
76 |
+
|
77 |
+
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
78 |
+
|
79 |
+
AT_CHECK(kW > 0 && kH > 0,
|
80 |
+
"kernel size should be greater than zero, but got kH: %d kW: %d", kH,
|
81 |
+
kW);
|
82 |
+
|
83 |
+
AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
|
84 |
+
"kernel size should be consistent with weight, ",
|
85 |
+
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
|
86 |
+
kW, weight.size(2), weight.size(3));
|
87 |
+
|
88 |
+
AT_CHECK(dW > 0 && dH > 0,
|
89 |
+
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
|
90 |
+
|
91 |
+
AT_CHECK(
|
92 |
+
dilationW > 0 && dilationH > 0,
|
93 |
+
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
|
94 |
+
dilationH, dilationW);
|
95 |
+
|
96 |
+
int ndim = input.ndimension();
|
97 |
+
int dimf = 0;
|
98 |
+
int dimh = 1;
|
99 |
+
int dimw = 2;
|
100 |
+
|
101 |
+
if (ndim == 4) {
|
102 |
+
dimf++;
|
103 |
+
dimh++;
|
104 |
+
dimw++;
|
105 |
+
}
|
106 |
+
|
107 |
+
AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
|
108 |
+
ndim);
|
109 |
+
|
110 |
+
long nInputPlane = weight.size(1) * group;
|
111 |
+
long inputHeight = input.size(dimh);
|
112 |
+
long inputWidth = input.size(dimw);
|
113 |
+
long nOutputPlane = weight.size(0);
|
114 |
+
long outputHeight =
|
115 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
116 |
+
long outputWidth =
|
117 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
118 |
+
|
119 |
+
AT_CHECK(nInputPlane % deformable_group == 0,
|
120 |
+
"input channels must divide deformable group size");
|
121 |
+
|
122 |
+
if (outputWidth < 1 || outputHeight < 1)
|
123 |
+
AT_ERROR(
|
124 |
+
"Given input size: (%ld x %ld x %ld). "
|
125 |
+
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
|
126 |
+
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
|
127 |
+
outputWidth);
|
128 |
+
|
129 |
+
AT_CHECK(input.size(1) == nInputPlane,
|
130 |
+
"invalid number of input planes, expected: %d, but got: %d",
|
131 |
+
nInputPlane, input.size(1));
|
132 |
+
|
133 |
+
AT_CHECK((inputHeight >= kH && inputWidth >= kW),
|
134 |
+
"input image is smaller than kernel");
|
135 |
+
|
136 |
+
AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
|
137 |
+
"invalid spatial size of offset, expected height: %d width: %d, but "
|
138 |
+
"got height: %d width: %d",
|
139 |
+
outputHeight, outputWidth, offset.size(2), offset.size(3));
|
140 |
+
|
141 |
+
AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
|
142 |
+
"invalid number of channels of offset");
|
143 |
+
|
144 |
+
if (gradOutput != NULL) {
|
145 |
+
AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
|
146 |
+
"invalid number of gradOutput planes, expected: %d, but got: %d",
|
147 |
+
nOutputPlane, gradOutput->size(dimf));
|
148 |
+
|
149 |
+
AT_CHECK((gradOutput->size(dimh) == outputHeight &&
|
150 |
+
gradOutput->size(dimw) == outputWidth),
|
151 |
+
"invalid size of gradOutput, expected height: %d width: %d , but "
|
152 |
+
"got height: %d width: %d",
|
153 |
+
outputHeight, outputWidth, gradOutput->size(dimh),
|
154 |
+
gradOutput->size(dimw));
|
155 |
+
}
|
156 |
+
}
|
157 |
+
|
158 |
+
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
|
159 |
+
at::Tensor offset, at::Tensor output,
|
160 |
+
at::Tensor columns, at::Tensor ones, int kW,
|
161 |
+
int kH, int dW, int dH, int padW, int padH,
|
162 |
+
int dilationW, int dilationH, int group,
|
163 |
+
int deformable_group, int im2col_step)
|
164 |
+
{
|
165 |
+
// todo: resize columns to include im2col: done
|
166 |
+
// todo: add im2col_step as input
|
167 |
+
// todo: add new output buffer and transpose it to output (or directly
|
168 |
+
// transpose output) todo: possibly change data indexing because of
|
169 |
+
// parallel_imgs
|
170 |
+
|
171 |
+
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
|
172 |
+
dilationH, dilationW, group, deformable_group);
|
173 |
+
|
174 |
+
input = input.contiguous();
|
175 |
+
offset = offset.contiguous();
|
176 |
+
weight = weight.contiguous();
|
177 |
+
|
178 |
+
int batch = 1;
|
179 |
+
if (input.ndimension() == 3) {
|
180 |
+
// Force batch
|
181 |
+
batch = 0;
|
182 |
+
input.unsqueeze_(0);
|
183 |
+
offset.unsqueeze_(0);
|
184 |
+
}
|
185 |
+
|
186 |
+
// todo: assert batchsize dividable by im2col_step
|
187 |
+
|
188 |
+
long batchSize = input.size(0);
|
189 |
+
long nInputPlane = input.size(1);
|
190 |
+
long inputHeight = input.size(2);
|
191 |
+
long inputWidth = input.size(3);
|
192 |
+
|
193 |
+
long nOutputPlane = weight.size(0);
|
194 |
+
|
195 |
+
long outputWidth =
|
196 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
197 |
+
long outputHeight =
|
198 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
199 |
+
|
200 |
+
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
201 |
+
|
202 |
+
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
|
203 |
+
outputHeight, outputWidth});
|
204 |
+
columns = at::zeros(
|
205 |
+
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
206 |
+
input.options());
|
207 |
+
|
208 |
+
if (ones.ndimension() != 2 ||
|
209 |
+
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
|
210 |
+
ones = at::ones({outputHeight, outputWidth}, input.options());
|
211 |
+
}
|
212 |
+
|
213 |
+
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
214 |
+
inputHeight, inputWidth});
|
215 |
+
offset =
|
216 |
+
offset.view({batchSize / im2col_step, im2col_step,
|
217 |
+
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
218 |
+
|
219 |
+
at::Tensor output_buffer =
|
220 |
+
at::zeros({batchSize / im2col_step, nOutputPlane,
|
221 |
+
im2col_step * outputHeight, outputWidth},
|
222 |
+
output.options());
|
223 |
+
|
224 |
+
output_buffer = output_buffer.view(
|
225 |
+
{output_buffer.size(0), group, output_buffer.size(1) / group,
|
226 |
+
output_buffer.size(2), output_buffer.size(3)});
|
227 |
+
|
228 |
+
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
229 |
+
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
230 |
+
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
231 |
+
dilationW, im2col_step, deformable_group, columns);
|
232 |
+
|
233 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
234 |
+
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
235 |
+
weight.size(2), weight.size(3)});
|
236 |
+
|
237 |
+
for (int g = 0; g < group; g++) {
|
238 |
+
output_buffer[elt][g] = output_buffer[elt][g]
|
239 |
+
.flatten(1)
|
240 |
+
.addmm_(weight[g].flatten(1), columns[g])
|
241 |
+
.view_as(output_buffer[elt][g]);
|
242 |
+
}
|
243 |
+
}
|
244 |
+
|
245 |
+
output_buffer = output_buffer.view(
|
246 |
+
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
|
247 |
+
output_buffer.size(3), output_buffer.size(4)});
|
248 |
+
|
249 |
+
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
|
250 |
+
im2col_step, outputHeight, outputWidth});
|
251 |
+
output_buffer.transpose_(1, 2);
|
252 |
+
output.copy_(output_buffer);
|
253 |
+
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
254 |
+
|
255 |
+
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
256 |
+
offset = offset.view(
|
257 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
258 |
+
|
259 |
+
if (batch == 0) {
|
260 |
+
output = output.view({nOutputPlane, outputHeight, outputWidth});
|
261 |
+
input = input.view({nInputPlane, inputHeight, inputWidth});
|
262 |
+
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
263 |
+
}
|
264 |
+
|
265 |
+
return 1;
|
266 |
+
}
|
267 |
+
|
268 |
+
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
|
269 |
+
at::Tensor gradOutput, at::Tensor gradInput,
|
270 |
+
at::Tensor gradOffset, at::Tensor weight,
|
271 |
+
at::Tensor columns, int kW, int kH, int dW,
|
272 |
+
int dH, int padW, int padH, int dilationW,
|
273 |
+
int dilationH, int group,
|
274 |
+
int deformable_group, int im2col_step)
|
275 |
+
{
|
276 |
+
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
|
277 |
+
dilationH, dilationW, group, deformable_group);
|
278 |
+
|
279 |
+
input = input.contiguous();
|
280 |
+
offset = offset.contiguous();
|
281 |
+
gradOutput = gradOutput.contiguous();
|
282 |
+
weight = weight.contiguous();
|
283 |
+
|
284 |
+
int batch = 1;
|
285 |
+
|
286 |
+
if (input.ndimension() == 3) {
|
287 |
+
// Force batch
|
288 |
+
batch = 0;
|
289 |
+
input = input.view({1, input.size(0), input.size(1), input.size(2)});
|
290 |
+
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
|
291 |
+
gradOutput = gradOutput.view(
|
292 |
+
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
293 |
+
}
|
294 |
+
|
295 |
+
long batchSize = input.size(0);
|
296 |
+
long nInputPlane = input.size(1);
|
297 |
+
long inputHeight = input.size(2);
|
298 |
+
long inputWidth = input.size(3);
|
299 |
+
|
300 |
+
long nOutputPlane = weight.size(0);
|
301 |
+
|
302 |
+
long outputWidth =
|
303 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
304 |
+
long outputHeight =
|
305 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
306 |
+
|
307 |
+
AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
|
308 |
+
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
309 |
+
columns = at::zeros(
|
310 |
+
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
311 |
+
input.options());
|
312 |
+
|
313 |
+
// change order of grad output
|
314 |
+
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
315 |
+
nOutputPlane, outputHeight, outputWidth});
|
316 |
+
gradOutput.transpose_(1, 2);
|
317 |
+
|
318 |
+
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
319 |
+
inputHeight, inputWidth});
|
320 |
+
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
321 |
+
inputHeight, inputWidth});
|
322 |
+
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
|
323 |
+
deformable_group * 2 * kH * kW, outputHeight,
|
324 |
+
outputWidth});
|
325 |
+
offset =
|
326 |
+
offset.view({batchSize / im2col_step, im2col_step,
|
327 |
+
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
328 |
+
|
329 |
+
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
330 |
+
// divide into groups
|
331 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
332 |
+
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
333 |
+
weight.size(2), weight.size(3)});
|
334 |
+
gradOutput = gradOutput.view(
|
335 |
+
{gradOutput.size(0), group, gradOutput.size(1) / group,
|
336 |
+
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
|
337 |
+
|
338 |
+
for (int g = 0; g < group; g++) {
|
339 |
+
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
340 |
+
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
|
341 |
+
}
|
342 |
+
|
343 |
+
columns =
|
344 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
345 |
+
gradOutput = gradOutput.view(
|
346 |
+
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
|
347 |
+
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
|
348 |
+
|
349 |
+
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
|
350 |
+
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
|
351 |
+
dilationH, dilationW, im2col_step, deformable_group,
|
352 |
+
gradOffset[elt]);
|
353 |
+
|
354 |
+
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
|
355 |
+
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
356 |
+
dilationW, im2col_step, deformable_group, gradInput[elt]);
|
357 |
+
}
|
358 |
+
|
359 |
+
gradOutput.transpose_(1, 2);
|
360 |
+
gradOutput =
|
361 |
+
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
362 |
+
|
363 |
+
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
364 |
+
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
365 |
+
gradOffset = gradOffset.view(
|
366 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
367 |
+
offset = offset.view(
|
368 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
369 |
+
|
370 |
+
if (batch == 0) {
|
371 |
+
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
372 |
+
input = input.view({nInputPlane, inputHeight, inputWidth});
|
373 |
+
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
|
374 |
+
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
375 |
+
gradOffset =
|
376 |
+
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
|
377 |
+
}
|
378 |
+
|
379 |
+
return 1;
|
380 |
+
}
|
381 |
+
|
382 |
+
int deform_conv_backward_parameters_cuda(
|
383 |
+
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
|
384 |
+
at::Tensor gradWeight, // at::Tensor gradBias,
|
385 |
+
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
|
386 |
+
int padW, int padH, int dilationW, int dilationH, int group,
|
387 |
+
int deformable_group, float scale, int im2col_step)
|
388 |
+
{
|
389 |
+
// todo: transpose and reshape outGrad
|
390 |
+
// todo: reshape columns
|
391 |
+
// todo: add im2col_step as input
|
392 |
+
|
393 |
+
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
|
394 |
+
padW, dilationH, dilationW, group, deformable_group);
|
395 |
+
|
396 |
+
input = input.contiguous();
|
397 |
+
offset = offset.contiguous();
|
398 |
+
gradOutput = gradOutput.contiguous();
|
399 |
+
|
400 |
+
int batch = 1;
|
401 |
+
|
402 |
+
if (input.ndimension() == 3) {
|
403 |
+
// Force batch
|
404 |
+
batch = 0;
|
405 |
+
input = input.view(
|
406 |
+
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
|
407 |
+
gradOutput = gradOutput.view(
|
408 |
+
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
409 |
+
}
|
410 |
+
|
411 |
+
long batchSize = input.size(0);
|
412 |
+
long nInputPlane = input.size(1);
|
413 |
+
long inputHeight = input.size(2);
|
414 |
+
long inputWidth = input.size(3);
|
415 |
+
|
416 |
+
long nOutputPlane = gradWeight.size(0);
|
417 |
+
|
418 |
+
long outputWidth =
|
419 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
420 |
+
long outputHeight =
|
421 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
422 |
+
|
423 |
+
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
424 |
+
|
425 |
+
columns = at::zeros(
|
426 |
+
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
427 |
+
input.options());
|
428 |
+
|
429 |
+
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
430 |
+
nOutputPlane, outputHeight, outputWidth});
|
431 |
+
gradOutput.transpose_(1, 2);
|
432 |
+
|
433 |
+
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
|
434 |
+
gradOutputBuffer =
|
435 |
+
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
|
436 |
+
outputHeight, outputWidth});
|
437 |
+
gradOutputBuffer.copy_(gradOutput);
|
438 |
+
gradOutputBuffer =
|
439 |
+
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
|
440 |
+
im2col_step * outputHeight, outputWidth});
|
441 |
+
|
442 |
+
gradOutput.transpose_(1, 2);
|
443 |
+
gradOutput =
|
444 |
+
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
445 |
+
|
446 |
+
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
447 |
+
inputHeight, inputWidth});
|
448 |
+
offset =
|
449 |
+
offset.view({batchSize / im2col_step, im2col_step,
|
450 |
+
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
451 |
+
|
452 |
+
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
453 |
+
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
454 |
+
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
455 |
+
dilationW, im2col_step, deformable_group, columns);
|
456 |
+
|
457 |
+
// divide into group
|
458 |
+
gradOutputBuffer = gradOutputBuffer.view(
|
459 |
+
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
|
460 |
+
gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
|
461 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
462 |
+
gradWeight =
|
463 |
+
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
|
464 |
+
gradWeight.size(2), gradWeight.size(3)});
|
465 |
+
|
466 |
+
for (int g = 0; g < group; g++) {
|
467 |
+
gradWeight[g] = gradWeight[g]
|
468 |
+
.flatten(1)
|
469 |
+
.addmm_(gradOutputBuffer[elt][g].flatten(1),
|
470 |
+
columns[g].transpose(1, 0), 1.0, scale)
|
471 |
+
.view_as(gradWeight[g]);
|
472 |
+
}
|
473 |
+
gradOutputBuffer = gradOutputBuffer.view(
|
474 |
+
{gradOutputBuffer.size(0),
|
475 |
+
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
|
476 |
+
gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
|
477 |
+
columns =
|
478 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
479 |
+
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
|
480 |
+
gradWeight.size(2), gradWeight.size(3),
|
481 |
+
gradWeight.size(4)});
|
482 |
+
}
|
483 |
+
|
484 |
+
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
485 |
+
offset = offset.view(
|
486 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
487 |
+
|
488 |
+
if (batch == 0) {
|
489 |
+
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
490 |
+
input = input.view({nInputPlane, inputHeight, inputWidth});
|
491 |
+
}
|
492 |
+
|
493 |
+
return 1;
|
494 |
+
}
|
495 |
+
|
496 |
+
void modulated_deform_conv_cuda_forward(
|
497 |
+
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
498 |
+
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
|
499 |
+
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
|
500 |
+
const int pad_h, const int pad_w, const int dilation_h,
|
501 |
+
const int dilation_w, const int group, const int deformable_group,
|
502 |
+
const bool with_bias)
|
503 |
+
{
|
504 |
+
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
505 |
+
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
506 |
+
|
507 |
+
const int batch = input.size(0);
|
508 |
+
const int channels = input.size(1);
|
509 |
+
const int height = input.size(2);
|
510 |
+
const int width = input.size(3);
|
511 |
+
|
512 |
+
const int channels_out = weight.size(0);
|
513 |
+
const int channels_kernel = weight.size(1);
|
514 |
+
const int kernel_h_ = weight.size(2);
|
515 |
+
const int kernel_w_ = weight.size(3);
|
516 |
+
|
517 |
+
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
518 |
+
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
519 |
+
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
520 |
+
if (channels != channels_kernel * group)
|
521 |
+
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
522 |
+
channels, channels_kernel * group);
|
523 |
+
|
524 |
+
const int height_out =
|
525 |
+
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
526 |
+
const int width_out =
|
527 |
+
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
528 |
+
|
529 |
+
if (ones.ndimension() != 2 ||
|
530 |
+
ones.size(0) * ones.size(1) < height_out * width_out) {
|
531 |
+
// Resize plane and fill with ones...
|
532 |
+
ones = at::ones({height_out, width_out}, input.options());
|
533 |
+
}
|
534 |
+
|
535 |
+
// resize output
|
536 |
+
output = output.view({batch, channels_out, height_out, width_out}).zero_();
|
537 |
+
// resize temporary columns
|
538 |
+
columns =
|
539 |
+
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
|
540 |
+
input.options());
|
541 |
+
|
542 |
+
output = output.view({output.size(0), group, output.size(1) / group,
|
543 |
+
output.size(2), output.size(3)});
|
544 |
+
|
545 |
+
for (int b = 0; b < batch; b++) {
|
546 |
+
modulated_deformable_im2col_cuda(
|
547 |
+
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
548 |
+
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
549 |
+
dilation_h, dilation_w, deformable_group, columns);
|
550 |
+
|
551 |
+
// divide into group
|
552 |
+
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
553 |
+
weight.size(2), weight.size(3)});
|
554 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
555 |
+
|
556 |
+
for (int g = 0; g < group; g++) {
|
557 |
+
output[b][g] = output[b][g]
|
558 |
+
.flatten(1)
|
559 |
+
.addmm_(weight[g].flatten(1), columns[g])
|
560 |
+
.view_as(output[b][g]);
|
561 |
+
}
|
562 |
+
|
563 |
+
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
564 |
+
weight.size(3), weight.size(4)});
|
565 |
+
columns =
|
566 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
567 |
+
}
|
568 |
+
|
569 |
+
output = output.view({output.size(0), output.size(1) * output.size(2),
|
570 |
+
output.size(3), output.size(4)});
|
571 |
+
|
572 |
+
if (with_bias) {
|
573 |
+
output += bias.view({1, bias.size(0), 1, 1});
|
574 |
+
}
|
575 |
+
}
|
576 |
+
|
577 |
+
void modulated_deform_conv_cuda_backward(
|
578 |
+
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
579 |
+
at::Tensor offset, at::Tensor mask, at::Tensor columns,
|
580 |
+
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
|
581 |
+
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
|
582 |
+
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
|
583 |
+
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
|
584 |
+
const bool with_bias)
|
585 |
+
{
|
586 |
+
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
587 |
+
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
588 |
+
|
589 |
+
const int batch = input.size(0);
|
590 |
+
const int channels = input.size(1);
|
591 |
+
const int height = input.size(2);
|
592 |
+
const int width = input.size(3);
|
593 |
+
|
594 |
+
const int channels_kernel = weight.size(1);
|
595 |
+
const int kernel_h_ = weight.size(2);
|
596 |
+
const int kernel_w_ = weight.size(3);
|
597 |
+
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
598 |
+
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
599 |
+
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
600 |
+
if (channels != channels_kernel * group)
|
601 |
+
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
602 |
+
channels, channels_kernel * group);
|
603 |
+
|
604 |
+
const int height_out =
|
605 |
+
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
606 |
+
const int width_out =
|
607 |
+
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
608 |
+
|
609 |
+
if (ones.ndimension() != 2 ||
|
610 |
+
ones.size(0) * ones.size(1) < height_out * width_out) {
|
611 |
+
// Resize plane and fill with ones...
|
612 |
+
ones = at::ones({height_out, width_out}, input.options());
|
613 |
+
}
|
614 |
+
|
615 |
+
grad_input = grad_input.view({batch, channels, height, width});
|
616 |
+
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
|
617 |
+
input.options());
|
618 |
+
|
619 |
+
grad_output =
|
620 |
+
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
|
621 |
+
grad_output.size(2), grad_output.size(3)});
|
622 |
+
|
623 |
+
for (int b = 0; b < batch; b++) {
|
624 |
+
// divide int group
|
625 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
626 |
+
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
627 |
+
weight.size(2), weight.size(3)});
|
628 |
+
|
629 |
+
for (int g = 0; g < group; g++) {
|
630 |
+
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
631 |
+
grad_output[b][g].flatten(1), 0.0f, 1.0f);
|
632 |
+
}
|
633 |
+
|
634 |
+
columns =
|
635 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
636 |
+
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
637 |
+
weight.size(3), weight.size(4)});
|
638 |
+
|
639 |
+
// gradient w.r.t. input coordinate data
|
640 |
+
modulated_deformable_col2im_coord_cuda(
|
641 |
+
columns, input[b], offset[b], mask[b], 1, channels, height, width,
|
642 |
+
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
|
643 |
+
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
|
644 |
+
grad_mask[b]);
|
645 |
+
// gradient w.r.t. input data
|
646 |
+
modulated_deformable_col2im_cuda(
|
647 |
+
columns, offset[b], mask[b], 1, channels, height, width, height_out,
|
648 |
+
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
649 |
+
dilation_h, dilation_w, deformable_group, grad_input[b]);
|
650 |
+
|
651 |
+
// gradient w.r.t. weight, dWeight should accumulate across the batch and
|
652 |
+
// group
|
653 |
+
modulated_deformable_im2col_cuda(
|
654 |
+
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
655 |
+
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
656 |
+
dilation_h, dilation_w, deformable_group, columns);
|
657 |
+
|
658 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
659 |
+
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
|
660 |
+
grad_weight.size(1), grad_weight.size(2),
|
661 |
+
grad_weight.size(3)});
|
662 |
+
if (with_bias)
|
663 |
+
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
|
664 |
+
|
665 |
+
for (int g = 0; g < group; g++) {
|
666 |
+
grad_weight[g] =
|
667 |
+
grad_weight[g]
|
668 |
+
.flatten(1)
|
669 |
+
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
|
670 |
+
.view_as(grad_weight[g]);
|
671 |
+
if (with_bias) {
|
672 |
+
grad_bias[g] =
|
673 |
+
grad_bias[g]
|
674 |
+
.view({-1, 1})
|
675 |
+
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
|
676 |
+
.view(-1);
|
677 |
+
}
|
678 |
+
}
|
679 |
+
|
680 |
+
columns =
|
681 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
682 |
+
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
|
683 |
+
grad_weight.size(2), grad_weight.size(3),
|
684 |
+
grad_weight.size(4)});
|
685 |
+
if (with_bias)
|
686 |
+
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
|
687 |
+
}
|
688 |
+
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
|
689 |
+
grad_output.size(2), grad_output.size(3),
|
690 |
+
grad_output.size(4)});
|
691 |
+
}
|
maskrcnn_benchmark/csrc/cuda/deform_conv_kernel_cuda.cu
ADDED
@@ -0,0 +1,874 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
|
3 |
+
*
|
4 |
+
* COPYRIGHT
|
5 |
+
*
|
6 |
+
* All contributions by the University of California:
|
7 |
+
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
|
8 |
+
* All rights reserved.
|
9 |
+
*
|
10 |
+
* All other contributions:
|
11 |
+
* Copyright (c) 2014-2017, the respective contributors
|
12 |
+
* All rights reserved.
|
13 |
+
*
|
14 |
+
* Caffe uses a shared copyright model: each contributor holds copyright over
|
15 |
+
* their contributions to Caffe. The project versioning records all such
|
16 |
+
* contribution and copyright details. If a contributor wants to further mark
|
17 |
+
* their specific copyright on a particular contribution, they should indicate
|
18 |
+
* their copyright solely in the commit message of the change when it is
|
19 |
+
* committed.
|
20 |
+
*
|
21 |
+
* LICENSE
|
22 |
+
*
|
23 |
+
* Redistribution and use in source and binary forms, with or without
|
24 |
+
* modification, are permitted provided that the following conditions are met:
|
25 |
+
*
|
26 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
27 |
+
* list of conditions and the following disclaimer.
|
28 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
29 |
+
* this list of conditions and the following disclaimer in the documentation
|
30 |
+
* and/or other materials provided with the distribution.
|
31 |
+
*
|
32 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
33 |
+
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
34 |
+
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
35 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
36 |
+
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
37 |
+
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
38 |
+
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
39 |
+
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
40 |
+
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
41 |
+
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
42 |
+
*
|
43 |
+
* CONTRIBUTION AGREEMENT
|
44 |
+
*
|
45 |
+
* By contributing to the BVLC/caffe repository through pull-request, comment,
|
46 |
+
* or otherwise, the contributor releases their content to the
|
47 |
+
* license and copyright terms herein.
|
48 |
+
*
|
49 |
+
***************** END Caffe Copyright Notice and Disclaimer ********************
|
50 |
+
*
|
51 |
+
* Copyright (c) 2018 Microsoft
|
52 |
+
* Licensed under The MIT License [see LICENSE for details]
|
53 |
+
* \file modulated_deformable_im2col.cuh
|
54 |
+
* \brief Function definitions of converting an image to
|
55 |
+
* column matrix based on kernel, padding, dilation, and offset.
|
56 |
+
* These functions are mainly used in deformable convolution operators.
|
57 |
+
* \ref: https://arxiv.org/abs/1703.06211
|
58 |
+
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
|
59 |
+
*/
|
60 |
+
|
61 |
+
// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
|
62 |
+
|
63 |
+
|
64 |
+
#include <ATen/ATen.h>
|
65 |
+
#include <THC/THCAtomics.cuh>
|
66 |
+
#include <stdio.h>
|
67 |
+
#include <math.h>
|
68 |
+
#include <float.h>
|
69 |
+
|
70 |
+
using namespace at;
|
71 |
+
|
72 |
+
#define CUDA_KERNEL_LOOP(i, n) \
|
73 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
74 |
+
i += blockDim.x * gridDim.x)
|
75 |
+
|
76 |
+
const int CUDA_NUM_THREADS = 1024;
|
77 |
+
const int kMaxGridNum = 65535;
|
78 |
+
inline int GET_BLOCKS(const int N)
|
79 |
+
{
|
80 |
+
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
|
81 |
+
}
|
82 |
+
|
83 |
+
/*
|
84 |
+
const int CUDA_NUM_THREADS = 1024;
|
85 |
+
|
86 |
+
inline int GET_BLOCKS(const int N)
|
87 |
+
{
|
88 |
+
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
|
89 |
+
}*/
|
90 |
+
|
91 |
+
template <typename scalar_t>
|
92 |
+
__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
|
93 |
+
const int height, const int width, scalar_t h, scalar_t w)
|
94 |
+
{
|
95 |
+
|
96 |
+
int h_low = floor(h);
|
97 |
+
int w_low = floor(w);
|
98 |
+
int h_high = h_low + 1;
|
99 |
+
int w_high = w_low + 1;
|
100 |
+
|
101 |
+
scalar_t lh = h - h_low;
|
102 |
+
scalar_t lw = w - w_low;
|
103 |
+
scalar_t hh = 1 - lh, hw = 1 - lw;
|
104 |
+
|
105 |
+
scalar_t v1 = 0;
|
106 |
+
if (h_low >= 0 && w_low >= 0)
|
107 |
+
v1 = bottom_data[h_low * data_width + w_low];
|
108 |
+
scalar_t v2 = 0;
|
109 |
+
if (h_low >= 0 && w_high <= width - 1)
|
110 |
+
v2 = bottom_data[h_low * data_width + w_high];
|
111 |
+
scalar_t v3 = 0;
|
112 |
+
if (h_high <= height - 1 && w_low >= 0)
|
113 |
+
v3 = bottom_data[h_high * data_width + w_low];
|
114 |
+
scalar_t v4 = 0;
|
115 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
116 |
+
v4 = bottom_data[h_high * data_width + w_high];
|
117 |
+
|
118 |
+
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
119 |
+
|
120 |
+
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
121 |
+
return val;
|
122 |
+
}
|
123 |
+
|
124 |
+
template <typename scalar_t>
|
125 |
+
__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
|
126 |
+
const int h, const int w, const int height, const int width)
|
127 |
+
{
|
128 |
+
|
129 |
+
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
|
130 |
+
{
|
131 |
+
//empty
|
132 |
+
return 0;
|
133 |
+
}
|
134 |
+
|
135 |
+
int argmax_h_low = floor(argmax_h);
|
136 |
+
int argmax_w_low = floor(argmax_w);
|
137 |
+
int argmax_h_high = argmax_h_low + 1;
|
138 |
+
int argmax_w_high = argmax_w_low + 1;
|
139 |
+
|
140 |
+
scalar_t weight = 0;
|
141 |
+
if (h == argmax_h_low && w == argmax_w_low)
|
142 |
+
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
|
143 |
+
if (h == argmax_h_low && w == argmax_w_high)
|
144 |
+
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
|
145 |
+
if (h == argmax_h_high && w == argmax_w_low)
|
146 |
+
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
|
147 |
+
if (h == argmax_h_high && w == argmax_w_high)
|
148 |
+
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
|
149 |
+
return weight;
|
150 |
+
}
|
151 |
+
|
152 |
+
template <typename scalar_t>
|
153 |
+
__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
|
154 |
+
const int height, const int width, const scalar_t *im_data,
|
155 |
+
const int data_width, const int bp_dir)
|
156 |
+
{
|
157 |
+
|
158 |
+
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
|
159 |
+
{
|
160 |
+
//empty
|
161 |
+
return 0;
|
162 |
+
}
|
163 |
+
|
164 |
+
int argmax_h_low = floor(argmax_h);
|
165 |
+
int argmax_w_low = floor(argmax_w);
|
166 |
+
int argmax_h_high = argmax_h_low + 1;
|
167 |
+
int argmax_w_high = argmax_w_low + 1;
|
168 |
+
|
169 |
+
scalar_t weight = 0;
|
170 |
+
|
171 |
+
if (bp_dir == 0)
|
172 |
+
{
|
173 |
+
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
174 |
+
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
|
175 |
+
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
176 |
+
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
|
177 |
+
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
178 |
+
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
|
179 |
+
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
180 |
+
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
|
181 |
+
}
|
182 |
+
else if (bp_dir == 1)
|
183 |
+
{
|
184 |
+
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
185 |
+
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
|
186 |
+
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
187 |
+
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
|
188 |
+
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
189 |
+
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
|
190 |
+
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
191 |
+
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
|
192 |
+
}
|
193 |
+
|
194 |
+
return weight;
|
195 |
+
}
|
196 |
+
|
197 |
+
template <typename scalar_t>
|
198 |
+
__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
|
199 |
+
const int height, const int width, const int kernel_h, const int kernel_w,
|
200 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
201 |
+
const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
|
202 |
+
const int batch_size, const int num_channels, const int deformable_group,
|
203 |
+
const int height_col, const int width_col,
|
204 |
+
scalar_t *data_col)
|
205 |
+
{
|
206 |
+
CUDA_KERNEL_LOOP(index, n)
|
207 |
+
{
|
208 |
+
// index index of output matrix
|
209 |
+
const int w_col = index % width_col;
|
210 |
+
const int h_col = (index / width_col) % height_col;
|
211 |
+
const int b_col = (index / width_col / height_col) % batch_size;
|
212 |
+
const int c_im = (index / width_col / height_col) / batch_size;
|
213 |
+
const int c_col = c_im * kernel_h * kernel_w;
|
214 |
+
|
215 |
+
// compute deformable group index
|
216 |
+
const int deformable_group_index = c_im / channel_per_deformable_group;
|
217 |
+
|
218 |
+
const int h_in = h_col * stride_h - pad_h;
|
219 |
+
const int w_in = w_col * stride_w - pad_w;
|
220 |
+
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
|
221 |
+
//const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
|
222 |
+
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
|
223 |
+
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
|
224 |
+
|
225 |
+
for (int i = 0; i < kernel_h; ++i)
|
226 |
+
{
|
227 |
+
for (int j = 0; j < kernel_w; ++j)
|
228 |
+
{
|
229 |
+
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
|
230 |
+
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
|
231 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
232 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
233 |
+
scalar_t val = static_cast<scalar_t>(0);
|
234 |
+
const scalar_t h_im = h_in + i * dilation_h + offset_h;
|
235 |
+
const scalar_t w_im = w_in + j * dilation_w + offset_w;
|
236 |
+
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
|
237 |
+
{
|
238 |
+
//const scalar_t map_h = i * dilation_h + offset_h;
|
239 |
+
//const scalar_t map_w = j * dilation_w + offset_w;
|
240 |
+
//const int cur_height = height - h_in;
|
241 |
+
//const int cur_width = width - w_in;
|
242 |
+
//val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
|
243 |
+
val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
|
244 |
+
}
|
245 |
+
*data_col_ptr = val;
|
246 |
+
data_col_ptr += batch_size * height_col * width_col;
|
247 |
+
}
|
248 |
+
}
|
249 |
+
}
|
250 |
+
}
|
251 |
+
|
252 |
+
void deformable_im2col(
|
253 |
+
const at::Tensor data_im, const at::Tensor data_offset, const int channels,
|
254 |
+
const int height, const int width, const int ksize_h, const int ksize_w,
|
255 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
256 |
+
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
257 |
+
const int deformable_group, at::Tensor data_col)
|
258 |
+
{
|
259 |
+
// num_axes should be smaller than block size
|
260 |
+
// todo: check parallel_imgs is correctly passed in
|
261 |
+
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
262 |
+
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
263 |
+
int num_kernels = channels * height_col * width_col * parallel_imgs;
|
264 |
+
int channel_per_deformable_group = channels / deformable_group;
|
265 |
+
|
266 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
267 |
+
data_im.type(), "deformable_im2col_gpu", ([&] {
|
268 |
+
const scalar_t *data_im_ = data_im.data<scalar_t>();
|
269 |
+
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
270 |
+
scalar_t *data_col_ = data_col.data<scalar_t>();
|
271 |
+
|
272 |
+
deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
273 |
+
num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
|
274 |
+
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
|
275 |
+
channel_per_deformable_group, parallel_imgs, channels, deformable_group,
|
276 |
+
height_col, width_col, data_col_);
|
277 |
+
}));
|
278 |
+
|
279 |
+
cudaError_t err = cudaGetLastError();
|
280 |
+
if (err != cudaSuccess)
|
281 |
+
{
|
282 |
+
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
|
283 |
+
}
|
284 |
+
}
|
285 |
+
|
286 |
+
template <typename scalar_t>
|
287 |
+
__global__ void deformable_col2im_gpu_kernel(
|
288 |
+
const int n, const scalar_t *data_col, const scalar_t *data_offset,
|
289 |
+
const int channels, const int height, const int width,
|
290 |
+
const int kernel_h, const int kernel_w,
|
291 |
+
const int pad_h, const int pad_w,
|
292 |
+
const int stride_h, const int stride_w,
|
293 |
+
const int dilation_h, const int dilation_w,
|
294 |
+
const int channel_per_deformable_group,
|
295 |
+
const int batch_size, const int deformable_group,
|
296 |
+
const int height_col, const int width_col,
|
297 |
+
scalar_t *grad_im)
|
298 |
+
{
|
299 |
+
CUDA_KERNEL_LOOP(index, n)
|
300 |
+
{
|
301 |
+
const int j = (index / width_col / height_col / batch_size) % kernel_w;
|
302 |
+
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
303 |
+
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
|
304 |
+
// compute the start and end of the output
|
305 |
+
|
306 |
+
const int deformable_group_index = c / channel_per_deformable_group;
|
307 |
+
|
308 |
+
int w_out = index % width_col;
|
309 |
+
int h_out = (index / width_col) % height_col;
|
310 |
+
int b = (index / width_col / height_col) % batch_size;
|
311 |
+
int w_in = w_out * stride_w - pad_w;
|
312 |
+
int h_in = h_out * stride_h - pad_h;
|
313 |
+
|
314 |
+
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
|
315 |
+
2 * kernel_h * kernel_w * height_col * width_col;
|
316 |
+
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
|
317 |
+
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
|
318 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
319 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
320 |
+
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
|
321 |
+
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
|
322 |
+
|
323 |
+
const scalar_t cur_top_grad = data_col[index];
|
324 |
+
const int cur_h = (int)cur_inv_h_data;
|
325 |
+
const int cur_w = (int)cur_inv_w_data;
|
326 |
+
for (int dy = -2; dy <= 2; dy++)
|
327 |
+
{
|
328 |
+
for (int dx = -2; dx <= 2; dx++)
|
329 |
+
{
|
330 |
+
if (cur_h + dy >= 0 && cur_h + dy < height &&
|
331 |
+
cur_w + dx >= 0 && cur_w + dx < width &&
|
332 |
+
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
|
333 |
+
abs(cur_inv_w_data - (cur_w + dx)) < 1)
|
334 |
+
{
|
335 |
+
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
|
336 |
+
scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
|
337 |
+
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
|
338 |
+
}
|
339 |
+
}
|
340 |
+
}
|
341 |
+
}
|
342 |
+
}
|
343 |
+
|
344 |
+
void deformable_col2im(
|
345 |
+
const at::Tensor data_col, const at::Tensor data_offset, const int channels,
|
346 |
+
const int height, const int width, const int ksize_h,
|
347 |
+
const int ksize_w, const int pad_h, const int pad_w,
|
348 |
+
const int stride_h, const int stride_w,
|
349 |
+
const int dilation_h, const int dilation_w,
|
350 |
+
const int parallel_imgs, const int deformable_group,
|
351 |
+
at::Tensor grad_im)
|
352 |
+
{
|
353 |
+
|
354 |
+
// todo: make sure parallel_imgs is passed in correctly
|
355 |
+
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
356 |
+
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
357 |
+
int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
|
358 |
+
int channel_per_deformable_group = channels / deformable_group;
|
359 |
+
|
360 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
361 |
+
data_col.type(), "deformable_col2im_gpu", ([&] {
|
362 |
+
const scalar_t *data_col_ = data_col.data<scalar_t>();
|
363 |
+
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
364 |
+
scalar_t *grad_im_ = grad_im.data<scalar_t>();
|
365 |
+
|
366 |
+
deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
367 |
+
num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
|
368 |
+
ksize_w, pad_h, pad_w, stride_h, stride_w,
|
369 |
+
dilation_h, dilation_w, channel_per_deformable_group,
|
370 |
+
parallel_imgs, deformable_group, height_col, width_col, grad_im_);
|
371 |
+
}));
|
372 |
+
|
373 |
+
cudaError_t err = cudaGetLastError();
|
374 |
+
if (err != cudaSuccess)
|
375 |
+
{
|
376 |
+
printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
|
377 |
+
}
|
378 |
+
}
|
379 |
+
|
380 |
+
template <typename scalar_t>
|
381 |
+
__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
|
382 |
+
const scalar_t *data_im, const scalar_t *data_offset,
|
383 |
+
const int channels, const int height, const int width,
|
384 |
+
const int kernel_h, const int kernel_w,
|
385 |
+
const int pad_h, const int pad_w,
|
386 |
+
const int stride_h, const int stride_w,
|
387 |
+
const int dilation_h, const int dilation_w,
|
388 |
+
const int channel_per_deformable_group,
|
389 |
+
const int batch_size, const int offset_channels, const int deformable_group,
|
390 |
+
const int height_col, const int width_col, scalar_t *grad_offset)
|
391 |
+
{
|
392 |
+
CUDA_KERNEL_LOOP(index, n)
|
393 |
+
{
|
394 |
+
scalar_t val = 0;
|
395 |
+
int w = index % width_col;
|
396 |
+
int h = (index / width_col) % height_col;
|
397 |
+
int c = (index / width_col / height_col) % offset_channels;
|
398 |
+
int b = (index / width_col / height_col) / offset_channels;
|
399 |
+
// compute the start and end of the output
|
400 |
+
|
401 |
+
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
|
402 |
+
const int col_step = kernel_h * kernel_w;
|
403 |
+
int cnt = 0;
|
404 |
+
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
|
405 |
+
batch_size * width_col * height_col;
|
406 |
+
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
|
407 |
+
channel_per_deformable_group / kernel_h / kernel_w * height * width;
|
408 |
+
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
|
409 |
+
kernel_h * kernel_w * height_col * width_col;
|
410 |
+
|
411 |
+
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
|
412 |
+
|
413 |
+
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
|
414 |
+
{
|
415 |
+
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
|
416 |
+
const int bp_dir = offset_c % 2;
|
417 |
+
|
418 |
+
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
|
419 |
+
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
420 |
+
int w_out = col_pos % width_col;
|
421 |
+
int h_out = (col_pos / width_col) % height_col;
|
422 |
+
int w_in = w_out * stride_w - pad_w;
|
423 |
+
int h_in = h_out * stride_h - pad_h;
|
424 |
+
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
|
425 |
+
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
|
426 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
427 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
428 |
+
scalar_t inv_h = h_in + i * dilation_h + offset_h;
|
429 |
+
scalar_t inv_w = w_in + j * dilation_w + offset_w;
|
430 |
+
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
|
431 |
+
{
|
432 |
+
inv_h = inv_w = -2;
|
433 |
+
}
|
434 |
+
const scalar_t weight = get_coordinate_weight(
|
435 |
+
inv_h, inv_w,
|
436 |
+
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
|
437 |
+
val += weight * data_col_ptr[col_pos];
|
438 |
+
cnt += 1;
|
439 |
+
}
|
440 |
+
|
441 |
+
grad_offset[index] = val;
|
442 |
+
}
|
443 |
+
}
|
444 |
+
|
445 |
+
void deformable_col2im_coord(
|
446 |
+
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
|
447 |
+
const int channels, const int height, const int width, const int ksize_h,
|
448 |
+
const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
|
449 |
+
const int stride_w, const int dilation_h, const int dilation_w,
|
450 |
+
const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
|
451 |
+
{
|
452 |
+
|
453 |
+
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
454 |
+
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
455 |
+
int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
|
456 |
+
int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
|
457 |
+
|
458 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
459 |
+
data_col.type(), "deformable_col2im_coord_gpu", ([&] {
|
460 |
+
const scalar_t *data_col_ = data_col.data<scalar_t>();
|
461 |
+
const scalar_t *data_im_ = data_im.data<scalar_t>();
|
462 |
+
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
463 |
+
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
|
464 |
+
|
465 |
+
deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
466 |
+
num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
|
467 |
+
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
|
468 |
+
dilation_h, dilation_w, channel_per_deformable_group,
|
469 |
+
parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
|
470 |
+
height_col, width_col, grad_offset_);
|
471 |
+
}));
|
472 |
+
}
|
473 |
+
|
474 |
+
template <typename scalar_t>
|
475 |
+
__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
|
476 |
+
const int height, const int width, scalar_t h, scalar_t w)
|
477 |
+
{
|
478 |
+
int h_low = floor(h);
|
479 |
+
int w_low = floor(w);
|
480 |
+
int h_high = h_low + 1;
|
481 |
+
int w_high = w_low + 1;
|
482 |
+
|
483 |
+
scalar_t lh = h - h_low;
|
484 |
+
scalar_t lw = w - w_low;
|
485 |
+
scalar_t hh = 1 - lh, hw = 1 - lw;
|
486 |
+
|
487 |
+
scalar_t v1 = 0;
|
488 |
+
if (h_low >= 0 && w_low >= 0)
|
489 |
+
v1 = bottom_data[h_low * data_width + w_low];
|
490 |
+
scalar_t v2 = 0;
|
491 |
+
if (h_low >= 0 && w_high <= width - 1)
|
492 |
+
v2 = bottom_data[h_low * data_width + w_high];
|
493 |
+
scalar_t v3 = 0;
|
494 |
+
if (h_high <= height - 1 && w_low >= 0)
|
495 |
+
v3 = bottom_data[h_high * data_width + w_low];
|
496 |
+
scalar_t v4 = 0;
|
497 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
498 |
+
v4 = bottom_data[h_high * data_width + w_high];
|
499 |
+
|
500 |
+
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
501 |
+
|
502 |
+
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
503 |
+
return val;
|
504 |
+
}
|
505 |
+
|
506 |
+
template <typename scalar_t>
|
507 |
+
__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
|
508 |
+
const int h, const int w, const int height, const int width)
|
509 |
+
{
|
510 |
+
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
|
511 |
+
{
|
512 |
+
//empty
|
513 |
+
return 0;
|
514 |
+
}
|
515 |
+
|
516 |
+
int argmax_h_low = floor(argmax_h);
|
517 |
+
int argmax_w_low = floor(argmax_w);
|
518 |
+
int argmax_h_high = argmax_h_low + 1;
|
519 |
+
int argmax_w_high = argmax_w_low + 1;
|
520 |
+
|
521 |
+
scalar_t weight = 0;
|
522 |
+
if (h == argmax_h_low && w == argmax_w_low)
|
523 |
+
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
|
524 |
+
if (h == argmax_h_low && w == argmax_w_high)
|
525 |
+
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
|
526 |
+
if (h == argmax_h_high && w == argmax_w_low)
|
527 |
+
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
|
528 |
+
if (h == argmax_h_high && w == argmax_w_high)
|
529 |
+
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
|
530 |
+
return weight;
|
531 |
+
}
|
532 |
+
|
533 |
+
template <typename scalar_t>
|
534 |
+
__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
|
535 |
+
const int height, const int width, const scalar_t *im_data,
|
536 |
+
const int data_width, const int bp_dir)
|
537 |
+
{
|
538 |
+
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
|
539 |
+
{
|
540 |
+
//empty
|
541 |
+
return 0;
|
542 |
+
}
|
543 |
+
|
544 |
+
int argmax_h_low = floor(argmax_h);
|
545 |
+
int argmax_w_low = floor(argmax_w);
|
546 |
+
int argmax_h_high = argmax_h_low + 1;
|
547 |
+
int argmax_w_high = argmax_w_low + 1;
|
548 |
+
|
549 |
+
scalar_t weight = 0;
|
550 |
+
|
551 |
+
if (bp_dir == 0)
|
552 |
+
{
|
553 |
+
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
554 |
+
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
|
555 |
+
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
556 |
+
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
|
557 |
+
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
558 |
+
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
|
559 |
+
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
560 |
+
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
|
561 |
+
}
|
562 |
+
else if (bp_dir == 1)
|
563 |
+
{
|
564 |
+
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
565 |
+
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
|
566 |
+
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
567 |
+
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
|
568 |
+
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
569 |
+
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
|
570 |
+
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
571 |
+
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
|
572 |
+
}
|
573 |
+
|
574 |
+
return weight;
|
575 |
+
}
|
576 |
+
|
577 |
+
template <typename scalar_t>
|
578 |
+
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
|
579 |
+
const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
|
580 |
+
const int height, const int width, const int kernel_h, const int kernel_w,
|
581 |
+
const int pad_h, const int pad_w,
|
582 |
+
const int stride_h, const int stride_w,
|
583 |
+
const int dilation_h, const int dilation_w,
|
584 |
+
const int channel_per_deformable_group,
|
585 |
+
const int batch_size, const int num_channels, const int deformable_group,
|
586 |
+
const int height_col, const int width_col,
|
587 |
+
scalar_t *data_col)
|
588 |
+
{
|
589 |
+
CUDA_KERNEL_LOOP(index, n)
|
590 |
+
{
|
591 |
+
// index index of output matrix
|
592 |
+
const int w_col = index % width_col;
|
593 |
+
const int h_col = (index / width_col) % height_col;
|
594 |
+
const int b_col = (index / width_col / height_col) % batch_size;
|
595 |
+
const int c_im = (index / width_col / height_col) / batch_size;
|
596 |
+
const int c_col = c_im * kernel_h * kernel_w;
|
597 |
+
|
598 |
+
// compute deformable group index
|
599 |
+
const int deformable_group_index = c_im / channel_per_deformable_group;
|
600 |
+
|
601 |
+
const int h_in = h_col * stride_h - pad_h;
|
602 |
+
const int w_in = w_col * stride_w - pad_w;
|
603 |
+
|
604 |
+
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
|
605 |
+
//const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
|
606 |
+
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
|
607 |
+
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
|
608 |
+
|
609 |
+
const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
|
610 |
+
|
611 |
+
for (int i = 0; i < kernel_h; ++i)
|
612 |
+
{
|
613 |
+
for (int j = 0; j < kernel_w; ++j)
|
614 |
+
{
|
615 |
+
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
|
616 |
+
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
|
617 |
+
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
|
618 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
619 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
620 |
+
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
|
621 |
+
scalar_t val = static_cast<scalar_t>(0);
|
622 |
+
const scalar_t h_im = h_in + i * dilation_h + offset_h;
|
623 |
+
const scalar_t w_im = w_in + j * dilation_w + offset_w;
|
624 |
+
//if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
|
625 |
+
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
|
626 |
+
{
|
627 |
+
//const float map_h = i * dilation_h + offset_h;
|
628 |
+
//const float map_w = j * dilation_w + offset_w;
|
629 |
+
//const int cur_height = height - h_in;
|
630 |
+
//const int cur_width = width - w_in;
|
631 |
+
//val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
|
632 |
+
val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
|
633 |
+
}
|
634 |
+
*data_col_ptr = val * mask;
|
635 |
+
data_col_ptr += batch_size * height_col * width_col;
|
636 |
+
//data_col_ptr += height_col * width_col;
|
637 |
+
}
|
638 |
+
}
|
639 |
+
}
|
640 |
+
}
|
641 |
+
|
642 |
+
template <typename scalar_t>
|
643 |
+
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
|
644 |
+
const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
|
645 |
+
const int channels, const int height, const int width,
|
646 |
+
const int kernel_h, const int kernel_w,
|
647 |
+
const int pad_h, const int pad_w,
|
648 |
+
const int stride_h, const int stride_w,
|
649 |
+
const int dilation_h, const int dilation_w,
|
650 |
+
const int channel_per_deformable_group,
|
651 |
+
const int batch_size, const int deformable_group,
|
652 |
+
const int height_col, const int width_col,
|
653 |
+
scalar_t *grad_im)
|
654 |
+
{
|
655 |
+
CUDA_KERNEL_LOOP(index, n)
|
656 |
+
{
|
657 |
+
const int j = (index / width_col / height_col / batch_size) % kernel_w;
|
658 |
+
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
659 |
+
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
|
660 |
+
// compute the start and end of the output
|
661 |
+
|
662 |
+
const int deformable_group_index = c / channel_per_deformable_group;
|
663 |
+
|
664 |
+
int w_out = index % width_col;
|
665 |
+
int h_out = (index / width_col) % height_col;
|
666 |
+
int b = (index / width_col / height_col) % batch_size;
|
667 |
+
int w_in = w_out * stride_w - pad_w;
|
668 |
+
int h_in = h_out * stride_h - pad_h;
|
669 |
+
|
670 |
+
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
|
671 |
+
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
|
672 |
+
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
|
673 |
+
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
|
674 |
+
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
|
675 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
676 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
677 |
+
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
|
678 |
+
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
|
679 |
+
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
|
680 |
+
|
681 |
+
const scalar_t cur_top_grad = data_col[index] * mask;
|
682 |
+
const int cur_h = (int)cur_inv_h_data;
|
683 |
+
const int cur_w = (int)cur_inv_w_data;
|
684 |
+
for (int dy = -2; dy <= 2; dy++)
|
685 |
+
{
|
686 |
+
for (int dx = -2; dx <= 2; dx++)
|
687 |
+
{
|
688 |
+
if (cur_h + dy >= 0 && cur_h + dy < height &&
|
689 |
+
cur_w + dx >= 0 && cur_w + dx < width &&
|
690 |
+
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
|
691 |
+
abs(cur_inv_w_data - (cur_w + dx)) < 1)
|
692 |
+
{
|
693 |
+
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
|
694 |
+
scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
|
695 |
+
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
|
696 |
+
}
|
697 |
+
}
|
698 |
+
}
|
699 |
+
}
|
700 |
+
}
|
701 |
+
|
702 |
+
template <typename scalar_t>
|
703 |
+
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
|
704 |
+
const scalar_t *data_col, const scalar_t *data_im,
|
705 |
+
const scalar_t *data_offset, const scalar_t *data_mask,
|
706 |
+
const int channels, const int height, const int width,
|
707 |
+
const int kernel_h, const int kernel_w,
|
708 |
+
const int pad_h, const int pad_w,
|
709 |
+
const int stride_h, const int stride_w,
|
710 |
+
const int dilation_h, const int dilation_w,
|
711 |
+
const int channel_per_deformable_group,
|
712 |
+
const int batch_size, const int offset_channels, const int deformable_group,
|
713 |
+
const int height_col, const int width_col,
|
714 |
+
scalar_t *grad_offset, scalar_t *grad_mask)
|
715 |
+
{
|
716 |
+
CUDA_KERNEL_LOOP(index, n)
|
717 |
+
{
|
718 |
+
scalar_t val = 0, mval = 0;
|
719 |
+
int w = index % width_col;
|
720 |
+
int h = (index / width_col) % height_col;
|
721 |
+
int c = (index / width_col / height_col) % offset_channels;
|
722 |
+
int b = (index / width_col / height_col) / offset_channels;
|
723 |
+
// compute the start and end of the output
|
724 |
+
|
725 |
+
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
|
726 |
+
const int col_step = kernel_h * kernel_w;
|
727 |
+
int cnt = 0;
|
728 |
+
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
|
729 |
+
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
|
730 |
+
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
|
731 |
+
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
|
732 |
+
|
733 |
+
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
|
734 |
+
|
735 |
+
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
|
736 |
+
{
|
737 |
+
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
|
738 |
+
const int bp_dir = offset_c % 2;
|
739 |
+
|
740 |
+
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
|
741 |
+
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
742 |
+
int w_out = col_pos % width_col;
|
743 |
+
int h_out = (col_pos / width_col) % height_col;
|
744 |
+
int w_in = w_out * stride_w - pad_w;
|
745 |
+
int h_in = h_out * stride_h - pad_h;
|
746 |
+
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
|
747 |
+
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
|
748 |
+
const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
|
749 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
750 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
751 |
+
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
|
752 |
+
scalar_t inv_h = h_in + i * dilation_h + offset_h;
|
753 |
+
scalar_t inv_w = w_in + j * dilation_w + offset_w;
|
754 |
+
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
|
755 |
+
{
|
756 |
+
inv_h = inv_w = -2;
|
757 |
+
}
|
758 |
+
else
|
759 |
+
{
|
760 |
+
mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
|
761 |
+
}
|
762 |
+
const scalar_t weight = dmcn_get_coordinate_weight(
|
763 |
+
inv_h, inv_w,
|
764 |
+
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
|
765 |
+
val += weight * data_col_ptr[col_pos] * mask;
|
766 |
+
cnt += 1;
|
767 |
+
}
|
768 |
+
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
|
769 |
+
grad_offset[index] = val;
|
770 |
+
if (offset_c % 2 == 0)
|
771 |
+
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
|
772 |
+
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
|
773 |
+
}
|
774 |
+
}
|
775 |
+
|
776 |
+
void modulated_deformable_im2col_cuda(
|
777 |
+
const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
|
778 |
+
const int batch_size, const int channels, const int height_im, const int width_im,
|
779 |
+
const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
|
780 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
781 |
+
const int dilation_h, const int dilation_w,
|
782 |
+
const int deformable_group, at::Tensor data_col)
|
783 |
+
{
|
784 |
+
// num_axes should be smaller than block size
|
785 |
+
const int channel_per_deformable_group = channels / deformable_group;
|
786 |
+
const int num_kernels = channels * batch_size * height_col * width_col;
|
787 |
+
|
788 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
789 |
+
data_im.type(), "modulated_deformable_im2col_gpu", ([&] {
|
790 |
+
const scalar_t *data_im_ = data_im.data<scalar_t>();
|
791 |
+
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
792 |
+
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
|
793 |
+
scalar_t *data_col_ = data_col.data<scalar_t>();
|
794 |
+
|
795 |
+
modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
796 |
+
num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
|
797 |
+
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
|
798 |
+
batch_size, channels, deformable_group, height_col, width_col, data_col_);
|
799 |
+
}));
|
800 |
+
|
801 |
+
cudaError_t err = cudaGetLastError();
|
802 |
+
if (err != cudaSuccess)
|
803 |
+
{
|
804 |
+
printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
|
805 |
+
}
|
806 |
+
}
|
807 |
+
|
808 |
+
void modulated_deformable_col2im_cuda(
|
809 |
+
const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
|
810 |
+
const int batch_size, const int channels, const int height_im, const int width_im,
|
811 |
+
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
|
812 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
813 |
+
const int dilation_h, const int dilation_w,
|
814 |
+
const int deformable_group, at::Tensor grad_im)
|
815 |
+
{
|
816 |
+
|
817 |
+
const int channel_per_deformable_group = channels / deformable_group;
|
818 |
+
const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
|
819 |
+
|
820 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
821 |
+
data_col.type(), "modulated_deformable_col2im_gpu", ([&] {
|
822 |
+
const scalar_t *data_col_ = data_col.data<scalar_t>();
|
823 |
+
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
824 |
+
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
|
825 |
+
scalar_t *grad_im_ = grad_im.data<scalar_t>();
|
826 |
+
|
827 |
+
modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
828 |
+
num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
|
829 |
+
kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w,
|
830 |
+
dilation_h, dilation_w, channel_per_deformable_group,
|
831 |
+
batch_size, deformable_group, height_col, width_col, grad_im_);
|
832 |
+
}));
|
833 |
+
|
834 |
+
cudaError_t err = cudaGetLastError();
|
835 |
+
if (err != cudaSuccess)
|
836 |
+
{
|
837 |
+
printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
|
838 |
+
}
|
839 |
+
}
|
840 |
+
|
841 |
+
void modulated_deformable_col2im_coord_cuda(
|
842 |
+
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
|
843 |
+
const int batch_size, const int channels, const int height_im, const int width_im,
|
844 |
+
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
|
845 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
846 |
+
const int dilation_h, const int dilation_w,
|
847 |
+
const int deformable_group,
|
848 |
+
at::Tensor grad_offset, at::Tensor grad_mask)
|
849 |
+
{
|
850 |
+
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
|
851 |
+
const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
|
852 |
+
|
853 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
854 |
+
data_col.type(), "modulated_deformable_col2im_coord_gpu", ([&] {
|
855 |
+
const scalar_t *data_col_ = data_col.data<scalar_t>();
|
856 |
+
const scalar_t *data_im_ = data_im.data<scalar_t>();
|
857 |
+
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
858 |
+
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
|
859 |
+
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
|
860 |
+
scalar_t *grad_mask_ = grad_mask.data<scalar_t>();
|
861 |
+
|
862 |
+
modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
863 |
+
num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
|
864 |
+
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
865 |
+
dilation_h, dilation_w, channel_per_deformable_group,
|
866 |
+
batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
|
867 |
+
grad_offset_, grad_mask_);
|
868 |
+
}));
|
869 |
+
cudaError_t err = cudaGetLastError();
|
870 |
+
if (err != cudaSuccess)
|
871 |
+
{
|
872 |
+
printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
|
873 |
+
}
|
874 |
+
}
|
maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// modify from
|
2 |
+
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
|
3 |
+
|
4 |
+
// based on
|
5 |
+
// author: Charles Shang
|
6 |
+
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
|
7 |
+
|
8 |
+
#include <ATen/ATen.h>
|
9 |
+
#include <ATen/cuda/CUDAContext.h>
|
10 |
+
|
11 |
+
#include <THC/THC.h>
|
12 |
+
#include <THC/THCDeviceUtils.cuh>
|
13 |
+
|
14 |
+
#include <vector>
|
15 |
+
#include <iostream>
|
16 |
+
#include <cmath>
|
17 |
+
|
18 |
+
|
19 |
+
void DeformablePSROIPoolForward(
|
20 |
+
const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
|
21 |
+
at::Tensor out, at::Tensor top_count, const int batch, const int channels,
|
22 |
+
const int height, const int width, const int num_bbox,
|
23 |
+
const int channels_trans, const int no_trans, const float spatial_scale,
|
24 |
+
const int output_dim, const int group_size, const int pooled_size,
|
25 |
+
const int part_size, const int sample_per_part, const float trans_std);
|
26 |
+
|
27 |
+
void DeformablePSROIPoolBackwardAcc(
|
28 |
+
const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
|
29 |
+
const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
|
30 |
+
at::Tensor trans_grad, const int batch, const int channels,
|
31 |
+
const int height, const int width, const int num_bbox,
|
32 |
+
const int channels_trans, const int no_trans, const float spatial_scale,
|
33 |
+
const int output_dim, const int group_size, const int pooled_size,
|
34 |
+
const int part_size, const int sample_per_part, const float trans_std);
|
35 |
+
|
36 |
+
void deform_psroi_pooling_cuda_forward(
|
37 |
+
at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
|
38 |
+
at::Tensor top_count, const int no_trans, const float spatial_scale,
|
39 |
+
const int output_dim, const int group_size, const int pooled_size,
|
40 |
+
const int part_size, const int sample_per_part, const float trans_std)
|
41 |
+
{
|
42 |
+
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
43 |
+
|
44 |
+
const int batch = input.size(0);
|
45 |
+
const int channels = input.size(1);
|
46 |
+
const int height = input.size(2);
|
47 |
+
const int width = input.size(3);
|
48 |
+
const int channels_trans = no_trans ? 2 : trans.size(1);
|
49 |
+
|
50 |
+
const int num_bbox = bbox.size(0);
|
51 |
+
if (num_bbox != out.size(0))
|
52 |
+
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
|
53 |
+
out.size(0), num_bbox);
|
54 |
+
|
55 |
+
DeformablePSROIPoolForward(
|
56 |
+
input, bbox, trans, out, top_count, batch, channels, height, width,
|
57 |
+
num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
|
58 |
+
pooled_size, part_size, sample_per_part, trans_std);
|
59 |
+
}
|
60 |
+
|
61 |
+
void deform_psroi_pooling_cuda_backward(
|
62 |
+
at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
|
63 |
+
at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
|
64 |
+
const int no_trans, const float spatial_scale, const int output_dim,
|
65 |
+
const int group_size, const int pooled_size, const int part_size,
|
66 |
+
const int sample_per_part, const float trans_std)
|
67 |
+
{
|
68 |
+
AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
|
69 |
+
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
70 |
+
|
71 |
+
const int batch = input.size(0);
|
72 |
+
const int channels = input.size(1);
|
73 |
+
const int height = input.size(2);
|
74 |
+
const int width = input.size(3);
|
75 |
+
const int channels_trans = no_trans ? 2 : trans.size(1);
|
76 |
+
|
77 |
+
const int num_bbox = bbox.size(0);
|
78 |
+
if (num_bbox != out_grad.size(0))
|
79 |
+
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
|
80 |
+
out_grad.size(0), num_bbox);
|
81 |
+
|
82 |
+
DeformablePSROIPoolBackwardAcc(
|
83 |
+
out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
|
84 |
+
channels, height, width, num_bbox, channels_trans, no_trans,
|
85 |
+
spatial_scale, output_dim, group_size, pooled_size, part_size,
|
86 |
+
sample_per_part, trans_std);
|
87 |
+
}
|