3v324v23 committed
Commit
c310e19
1 Parent(s): 5c20bda
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. LICENSE.md +159 -0
  2. app.py +39 -0
  3. configs/mixtrain/seg_rec_poly_fuse_feature.yaml +97 -0
  4. configs/pretrain/seg_rec_poly_fuse_feature.yaml +94 -0
  5. evaluation/icdar2015/e2e/prepare_results.py +263 -0
  6. evaluation/icdar2015/e2e/rrc_evaluation_funcs.py +369 -0
  7. evaluation/icdar2015/e2e/script.py +461 -0
  8. evaluation/icdar2015/gt.zip +0 -0
  9. evaluation/rotated_icdar2013/e2e/prepare_results.py +267 -0
  10. evaluation/rotated_icdar2013/e2e/rrc_evaluation_funcs.py +369 -0
  11. evaluation/rotated_icdar2013/e2e/script.py +460 -0
  12. evaluation/rotated_icdar2013/gt/gt.zip +0 -0
  13. evaluation/rotated_icdar2013/gt/gt_-15.zip +0 -0
  14. evaluation/rotated_icdar2013/gt/gt_-30.zip +0 -0
  15. evaluation/rotated_icdar2013/gt/gt_-45.zip +0 -0
  16. evaluation/rotated_icdar2013/gt/gt_-60.zip +0 -0
  17. evaluation/rotated_icdar2013/gt/gt_-75.zip +0 -0
  18. evaluation/rotated_icdar2013/gt/gt_-90.zip +0 -0
  19. evaluation/rotated_icdar2013/gt/gt_0.zip +0 -0
  20. evaluation/rotated_icdar2013/gt/gt_15.zip +0 -0
  21. evaluation/rotated_icdar2013/gt/gt_30.zip +0 -0
  22. evaluation/rotated_icdar2013/gt/gt_45.zip +0 -0
  23. evaluation/rotated_icdar2013/gt/gt_60.zip +0 -0
  24. evaluation/rotated_icdar2013/gt/gt_75.zip +0 -0
  25. evaluation/rotated_icdar2013/gt/gt_85.zip +0 -0
  26. evaluation/rotated_icdar2013/gt/gt_90.zip +0 -0
  27. evaluation/totaltext/e2e/prepare_results.py +234 -0
  28. evaluation/totaltext/e2e/rrc_evaluation_funcs.py +369 -0
  29. evaluation/totaltext/e2e/rrc_evaluation_funcs_total_text.py +363 -0
  30. evaluation/totaltext/e2e/script.py +452 -0
  31. evaluation/totaltext/gt.zip +0 -0
  32. evaluation/weighted_editdistance.py +55 -0
  33. example1.jpg +0 -0
  34. example2.jpg +0 -0
  35. example3.jpg +0 -0
  36. maskrcnn_benchmark/config/__init__.py +2 -0
  37. maskrcnn_benchmark/config/defaults.py +373 -0
  38. maskrcnn_benchmark/config/paths_catalog.py +237 -0
  39. maskrcnn_benchmark/csrc/ROIAlign.h +46 -0
  40. maskrcnn_benchmark/csrc/ROIPool.h +48 -0
  41. maskrcnn_benchmark/csrc/SigmoidFocalLoss.h +41 -0
  42. maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp +257 -0
  43. maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp +75 -0
  44. maskrcnn_benchmark/csrc/cpu/vision.h +16 -0
  45. maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu +346 -0
  46. maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu +202 -0
  47. maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu +189 -0
  48. maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu +691 -0
  49. maskrcnn_benchmark/csrc/cuda/deform_conv_kernel_cuda.cu +874 -0
  50. maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu +87 -0
LICENSE.md ADDED
@@ -0,0 +1,159 @@
+ # Creative Commons Attribution-NonCommercial 4.0 International
+
+ Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
+
+ ### Using Creative Commons Public Licenses
+
+ Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
+
+ * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
+
+ * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
+
+ ## Creative Commons Attribution-NonCommercial 4.0 International Public License
+
+ By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
+
+ ### Section 1 – Definitions.
+
+ a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
+
+ b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
+
+ c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
+
+ d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
+
+ e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
+
+ f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
+
+ g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
+
+ h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
+
+ i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
+
+ j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
+
+ k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
+
+ l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. __Your__ has a corresponding meaning.
+
+ ### Section 2 – Scope.
+
+ a. ___License grant.___
+
+ 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
+
+ A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
+
+ B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
+
+ 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
+
+ 3. __Term.__ The term of this Public License is specified in Section 6(a).
+
+ 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
+
+ 5. __Downstream recipients.__
+
+ A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
+
+ B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
+
+ 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
+
+ b. ___Other rights.___
+
+ 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this Public License.
+
+ 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
+
+ ### Section 3 – License Conditions.
+
+ Your exercise of the Licensed Rights is expressly made subject to the following conditions.
+
+ a. ___Attribution.___
+
+ 1. If You Share the Licensed Material (including in modified form), You must:
+
+ A. retain the following if it is supplied by the Licensor with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
+
+ B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
+
+ C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
+
+ 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
+
+ ### Section 4 – Sui Generis Database Rights.
+
+ Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
+
+ b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
+
+ For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
+
+ ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
+
+ a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
+
+ b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
+
+ c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
+
+ ### Section 6 – Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
+
+ ### Section 7 – Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
+
+ ### Section 8 – Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
+
+ c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
+
+ > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
+ >
+ > Creative Commons may be contacted at creativecommons.org
app.py ADDED
@@ -0,0 +1,39 @@
+ import os
+ os.system('python setup.py build develop')
+ os.system('pip install --upgrade --no-cache-dir gdown')
+ os.system('gdown -O output/mixtrain/ 1XQsikiNY7ILgZvmvOeUf9oPDG4fTp0zs')
+
+ import cv2
+ import pandas as pd
+ import gradio as gr
+ from tools.demo import TextDemo
+ from maskrcnn_benchmark.config import cfg
+
+
+ def infer(filepath):
+     cfg.merge_from_file('configs/mixtrain/seg_rec_poly_fuse_feature.yaml')
+     # manually override some options
+     cfg.merge_from_list(["MODEL.DEVICE", "cpu"])
+
+     text_demo = TextDemo(
+         cfg,
+         min_image_size=800,
+         confidence_threshold=0.7,
+         output_polygon=True
+     )
+     image = cv2.imread(filepath)
+     result_polygons, result_words = text_demo.run_on_opencv_image(image)
+     text_demo.visualization(image, result_polygons, result_words)
+     cv2.imwrite('result.jpg', image)
+     return 'result.jpg', pd.DataFrame(result_words)
+
+
+ iface = gr.Interface(
+     fn=infer,
+     title="Mask TextSpotter v3",
+     description="Mask TextSpotter v3 is an end-to-end trainable scene text spotter that adopts a Segmentation Proposal Network (SPN) instead of an RPN. Mask TextSpotter v3 significantly improves robustness to rotations, aspect ratios, and shapes.",
+     inputs=[gr.inputs.Image(label="image", type="filepath")],
+     outputs=[gr.outputs.Image(), gr.outputs.Dataframe(headers=['word'])],
+     examples=['example1.jpg', 'example2.jpg', 'example3.jpg'],
+     article="<a href=\"https://github.com/MhLiao/MaskTextSpotterV3\">GitHub Repo</a>",
+ ).launch(enable_queue=True, cache_examples=True)
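
The Space entry point above downloads a checkpoint at import time and wires `infer` into a Gradio UI. For a quick smoke test, `infer` can also be called directly from the repository root; a minimal sketch, assuming the build and checkpoint download above have completed and that the `launch(...)` call is guarded or commented out so importing app.py does not block:

```python
# Minimal sketch: exercise the handler directly instead of launching Gradio.
# Assumes app.py's setup steps (build + checkpoint download) already ran and
# that the module-level launch() is disabled for this test.
from app import infer

result_path, words = infer('example1.jpg')   # example1.jpg ships with the repo
print('visualization saved to', result_path)
print(words)  # pandas DataFrame with one recognized word per row
```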
configs/mixtrain/seg_rec_poly_fuse_feature.yaml ADDED
@@ -0,0 +1,97 @@
+ MODEL:
+   META_ARCHITECTURE: "GeneralizedRCNN"
+   # WEIGHT: './output/path-to-pretrain-model' # for training
+   WEIGHT: './output/mixtrain/trained_model.pth' # for testing
+   BACKBONE:
+     CONV_BODY: "R-50-FPN"
+     OUT_CHANNELS: 256
+   RESNETS:
+     BACKBONE_OUT_CHANNELS: 256
+   RPN:
+     USE_FPN: True
+     ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+     PRE_NMS_TOP_N_TRAIN: 2000
+     PRE_NMS_TOP_N_TEST: 1000
+     POST_NMS_TOP_N_TEST: 1000
+     FPN_POST_NMS_TOP_N_TEST: 1000
+   SEG:
+     USE_FPN: True
+     USE_FUSE_FEATURE: True
+     TOP_N_TRAIN: 1000
+     TOP_N_TEST: 1000
+     BINARY_THRESH: 0.1
+     BOX_THRESH: 0.1
+     MIN_SIZE: 5
+     SHRINK_RATIO: 0.4
+     EXPAND_RATIO: 3.0
+   ROI_HEADS:
+     USE_FPN: True
+     BATCH_SIZE_PER_IMAGE: 512
+   ROI_BOX_HEAD:
+     POOLER_RESOLUTION: 7
+     POOLER_SCALES: (0.25,)
+     POOLER_SAMPLING_RATIO: 2
+     FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+     PREDICTOR: "FPNPredictor"
+     NUM_CLASSES: 2
+     USE_MASKED_FEATURE: True
+   ROI_MASK_HEAD:
+     POOLER_SCALES: (0.25,)
+     FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
+     PREDICTOR: "SeqCharMaskRCNNC4Predictor"
+     POOLER_RESOLUTION: 14
+     POOLER_RESOLUTION_H: 32
+     POOLER_RESOLUTION_W: 32
+     POOLER_SAMPLING_RATIO: 2
+     RESOLUTION: 28
+     RESOLUTION_H: 64
+     RESOLUTION_W: 64
+     SHARE_BOX_FEATURE_EXTRACTOR: False
+     CHAR_NUM_CLASSES: 37
+     USE_WEIGHTED_CHAR_MASK: True
+     MASK_BATCH_SIZE_PER_IM: 64
+     USE_MASKED_FEATURE: True
+   MASK_ON: True
+   CHAR_MASK_ON: True
+   SEG_ON: True
+   # TRAIN_DETECTION_ONLY: True
+ SEQUENCE:
+   SEQ_ON: True
+   NUM_CHAR: 38
+   BOS_TOKEN: 0
+   MAX_LENGTH: 32
+   TEACHER_FORCE_RATIO: 1.0
+ DATASETS:
+   # TRAIN: ("synthtext_train",)
+   TRAIN: ("synthtext_train","icdar_2013_train","icdar_2015_train","scut-eng-char_train","total_text_train")
+   RATIOS: [0.25,0.25,0.25,0.125,0.125]
+   # TEST: ("icdar_2015_test",)
+   TEST: ("total_text_test",)
+   # TEST: ("rotated_ic13_test_45",)
+   AUG: True
+   IGNORE_DIFFICULT: True
+   MAX_ROTATE_THETA: 90
+ DATALOADER:
+   SIZE_DIVISIBILITY: 32
+   NUM_WORKERS: 4
+   ASPECT_RATIO_GROUPING: False
+ SOLVER:
+   BASE_LR: 0.002 # 0.02
+   WARMUP_FACTOR: 0.1
+   WEIGHT_DECAY: 0.0001
+   STEPS: (100000, 160000)
+   MAX_ITER: 300000
+   IMS_PER_BATCH: 8
+   RESUME: False
+   DISPLAY_FREQ: 20
+ OUTPUT_DIR: "./output/mixtrain"
+ TEST:
+   VIS: True
+   CHAR_THRESH: 192
+   IMS_PER_BATCH: 1
+ INPUT:
+   MIN_SIZE_TRAIN: (800, 1000, 1200, 1400)
+   MAX_SIZE_TRAIN: 2333
+   MIN_SIZE_TEST: 1000
+   # MIN_SIZE_TEST: 1440
+   MAX_SIZE_TEST: 4000
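
This mixtrain config is consumed through the yacs-style `cfg` object from `maskrcnn_benchmark.config`, exactly as app.py above does. A minimal sketch of loading it and overriding individual keys from code (key names as defined in this file and in `maskrcnn_benchmark/config/defaults.py`):

```python
# Minimal sketch: load the mixtrain config and override options, yacs-style.
from maskrcnn_benchmark.config import cfg

cfg.merge_from_file('configs/mixtrain/seg_rec_poly_fuse_feature.yaml')
cfg.merge_from_list(['MODEL.DEVICE', 'cpu', 'TEST.IMS_PER_BATCH', 1])
cfg.freeze()  # standard yacs call: make the config immutable afterwards

print(cfg.SOLVER.BASE_LR)  # 0.002 for this file
print(cfg.DATASETS.TEST)   # ('total_text_test',)
```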
configs/pretrain/seg_rec_poly_fuse_feature.yaml ADDED
@@ -0,0 +1,94 @@
+ MODEL:
+   META_ARCHITECTURE: "GeneralizedRCNN"
+   WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
+   BACKBONE:
+     CONV_BODY: "R-50-FPN"
+     OUT_CHANNELS: 256
+   RESNETS:
+     BACKBONE_OUT_CHANNELS: 256
+   RPN:
+     USE_FPN: True
+     ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+     PRE_NMS_TOP_N_TRAIN: 2000
+     PRE_NMS_TOP_N_TEST: 1000
+     POST_NMS_TOP_N_TEST: 1000
+     FPN_POST_NMS_TOP_N_TEST: 1000
+   SEG:
+     USE_FPN: True
+     USE_FUSE_FEATURE: True
+     TOP_N_TRAIN: 1000
+     TOP_N_TEST: 1000
+     BINARY_THRESH: 0.1
+     BOX_THRESH: 0.1
+     MIN_SIZE: 5
+     SHRINK_RATIO: 0.4
+     EXPAND_RATIO: 3.0
+   ROI_HEADS:
+     USE_FPN: True
+     BATCH_SIZE_PER_IMAGE: 512
+   ROI_BOX_HEAD:
+     POOLER_RESOLUTION: 7
+     POOLER_SCALES: (0.25,)
+     POOLER_SAMPLING_RATIO: 2
+     FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+     PREDICTOR: "FPNPredictor"
+     NUM_CLASSES: 2
+     USE_MASKED_FEATURE: True
+   ROI_MASK_HEAD:
+     POOLER_SCALES: (0.25,)
+     FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
+     PREDICTOR: "SeqCharMaskRCNNC4Predictor"
+     POOLER_RESOLUTION: 14
+     POOLER_RESOLUTION_H: 32
+     POOLER_RESOLUTION_W: 32
+     POOLER_SAMPLING_RATIO: 2
+     RESOLUTION: 28
+     RESOLUTION_H: 64
+     RESOLUTION_W: 64
+     SHARE_BOX_FEATURE_EXTRACTOR: False
+     CHAR_NUM_CLASSES: 37
+     USE_WEIGHTED_CHAR_MASK: True
+     MASK_BATCH_SIZE_PER_IM: 64
+     USE_MASKED_FEATURE: True
+   MASK_ON: True
+   CHAR_MASK_ON: True
+   SEG_ON: True
+ SEQUENCE:
+   SEQ_ON: True
+   NUM_CHAR: 38
+   BOS_TOKEN: 0
+   MAX_LENGTH: 32
+   TEACHER_FORCE_RATIO: 1.0
+ DATASETS:
+   TRAIN: ("synthtext_train",)
+   # TRAIN: ("synthtext_train","icdar_2013_train","icdar_2015_train","scut-eng-char_train","total_text_train")
+   # RATIOS: [0.25,0.25,0.25,0.125,0.125]
+   TEST: ("icdar_2015_test",)
+   # TEST: ("total_text_test",)
+   AUG: True
+   IGNORE_DIFFICULT: True
+   MAX_ROTATE_THETA: 90
+ DATALOADER:
+   SIZE_DIVISIBILITY: 32
+   NUM_WORKERS: 4
+   ASPECT_RATIO_GROUPING: False
+ SOLVER:
+   BASE_LR: 0.02 # 0.02
+   WARMUP_FACTOR: 0.1
+   WEIGHT_DECAY: 0.0001
+   STEPS: (100000, 200000)
+   MAX_ITER: 300000
+   IMS_PER_BATCH: 8
+   RESUME: True
+   DISPLAY_FREQ: 20
+ OUTPUT_DIR: "./output/pretrain"
+ TEST:
+   VIS: False
+   CHAR_THRESH: 192
+   IMS_PER_BATCH: 1
+ INPUT:
+   MIN_SIZE_TRAIN: (600, 800)
+   # MIN_SIZE_TRAIN: (800, 1000, 1200, 1400)
+   MAX_SIZE_TRAIN: 2333
+   MIN_SIZE_TEST: 1440
+   MAX_SIZE_TEST: 4000
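
Compared with the mixtrain file, this pretrain config starts from ImageNet weights (`catalog://ImageNetPretrained/MSRA/R-50`), trains on SynthText alone, and uses a 10x higher base learning rate. The mixtrain config's `DATASETS.RATIOS` key presumably sets per-dataset sampling proportions; purely as an illustration of that idea (not the repository's actual dataloader), ratio-weighted mixing can look like:

```python
# Illustrative only: ratio-weighted dataset mixing as suggested by
# DATASETS.RATIOS in the mixtrain config. The repo's sampler may differ.
import random

RATIOS = {
    'synthtext_train': 0.25,
    'icdar_2013_train': 0.25,
    'icdar_2015_train': 0.25,
    'scut-eng-char_train': 0.125,
    'total_text_train': 0.125,
}

def sample_source():
    names, weights = zip(*RATIOS.items())
    return random.choices(names, weights=weights, k=1)[0]

print(sample_source())  # each call draws one dataset name per the ratios
```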
evaluation/icdar2015/e2e/prepare_results.py ADDED
@@ -0,0 +1,263 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ import sys
+ import os
+ sys.path.append('./')
+ import shapely
+ from shapely.geometry import Polygon, MultiPoint
+ import numpy as np
+ import editdistance
+ sys.path.append('../../')
+ from weighted_editdistance import weighted_edit_distance
+ from tqdm import tqdm
+ try:
+     import pickle
+ except ImportError:
+     import cPickle as pickle
+
+ def list_from_str(st):
+     line = st.split(',')
+     # box[0:4], polygon[4:12], word, seq_word, detection_score, rec_score, seq_score, char_score_path
+     new_line = [float(a) for a in line[4:12]] + [float(line[-4])] + [line[-5]] + [line[-6]] + [float(line[-3])] + [float(line[-2])] + [line[-1]]
+     return new_line
+
+ def polygon_from_list(line):
+     """
+     Create a shapely polygon object from a gt or dt line.
+     """
+     polygon_points = np.array(line).reshape(4, 2)
+     polygon = Polygon(polygon_points).convex_hull
+     return polygon
+
+ def polygon_iou(list1, list2):
+     """
+     Intersection over union between two shapely polygons.
+     """
+     polygon_points1 = np.array(list1).reshape(4, 2)
+     poly1 = Polygon(polygon_points1).convex_hull
+     polygon_points2 = np.array(list2).reshape(4, 2)
+     poly2 = Polygon(polygon_points2).convex_hull
+     union_poly = np.concatenate((polygon_points1, polygon_points2))
+     if not poly1.intersects(poly2):  # this test is fast and can accelerate the calculation
+         iou = 0
+     else:
+         try:
+             inter_area = poly1.intersection(poly2).area
+             # union_area = poly1.area + poly2.area - inter_area
+             union_area = MultiPoint(union_poly).convex_hull.area
+             iou = float(inter_area) / (union_area + 1e-6)
+         except shapely.geos.TopologicalError:
+             print('shapely.geos.TopologicalError occurred, iou set to 0')
+             iou = 0
+     return iou
+
+ def nms(boxes, overlap):
+     rec_scores = [b[-2] for b in boxes]
+     indices = sorted(range(len(rec_scores)), key=lambda k: -rec_scores[k])
+     box_num = len(boxes)
+     nms_flag = [True] * box_num
+     for i in range(box_num):
+         ii = indices[i]
+         if not nms_flag[ii]:
+             continue
+         for j in range(box_num):
+             jj = indices[j]
+             if j == i:
+                 continue
+             if not nms_flag[jj]:
+                 continue
+             box1 = boxes[ii]
+             box2 = boxes[jj]
+             box1_score = rec_scores[ii]
+             box2_score = rec_scores[jj]
+             str1 = box1[9]
+             str2 = box2[9]
+             box_i = [box1[0], box1[1], box1[4], box1[5]]
+             box_j = [box2[0], box2[1], box2[4], box2[5]]
+             poly1 = polygon_from_list(box1[0:8])
+             poly2 = polygon_from_list(box2[0:8])
+             iou = polygon_iou(box1[0:8], box2[0:8])
+             thresh = overlap
+
+             if iou > thresh:
+                 if box1_score > box2_score:
+                     nms_flag[jj] = False
+                 if box1_score == box2_score and poly1.area > poly2.area:
+                     nms_flag[jj] = False
+                 if box1_score == box2_score and poly1.area <= poly2.area:
+                     nms_flag[ii] = False
+                     break
+
+     return nms_flag
+
+ def packing(save_dir, cache_dir, pack_name):
+     files = os.listdir(save_dir)
+     if not os.path.exists(cache_dir):
+         os.mkdir(cache_dir)
+     os.system('zip -r -q -j ' + os.path.join(cache_dir, pack_name + '.zip') + ' ' + save_dir + '/*')
+
+ def test_single(results_dir, lexicon_type=3, cache_dir='./cache_dir', score_det=0.5, score_rec=0.5, score_rec_seq=0.5, overlap=0.2, use_lexicon=True, weighted_ed=True, use_seq=False, use_char=False, mix=False):
+     '''
+     results_dir: result directory
+     score_det: score of the detection bounding box
+     score_rec: score of the mask recognition branch
+     score_rec_seq: score of the sequence recognition branch
+     overlap: overlap threshold used for nms
+     lexicon_type: 1 for generic; 2 for weak; 3 for strong
+     use_seq: use the recognition result of the sequence branch
+     mix: use the recognition results of both the mask and sequence branches, selected by score
+     '''
+     print('score_det:', score_det, 'score_rec:', score_rec, 'score_rec_seq:', score_rec_seq, 'lexicon_type:', lexicon_type, 'weighted_ed:', weighted_ed, 'use_seq:', use_seq, 'use_char:', use_char, 'mix:', mix)
+     if not os.path.exists(cache_dir):
+         os.mkdir(cache_dir)
+     nms_dir = os.path.join(cache_dir, str(score_det) + '_' + str(score_rec) + '_' + str(score_rec_seq))
+     if not os.path.exists(nms_dir):
+         os.mkdir(nms_dir)
+     if lexicon_type == 1:
+         # generic lexicon
+         lexicon_path = '../../lexicons/ic15/GenericVocabulary_new.txt'
+         lexicon_fid = open(lexicon_path, 'r')
+         pair_list = open('../../lexicons/ic15/GenericVocabulary_pair_list.txt', 'r')
+         pairs = dict()
+         for line in pair_list.readlines():
+             line = line.strip()
+             word = line.split(' ')[0].upper()
+             word_gt = line[len(word) + 1:]
+             pairs[word] = word_gt
+         lexicon_fid = open(lexicon_path, 'r')
+         lexicon = []
+         for line in lexicon_fid.readlines():
+             line = line.strip()
+             lexicon.append(line)
+     if lexicon_type == 2:
+         # weak lexicon
+         lexicon_path = '../../lexicons/ic15/ch4_test_vocabulary_new.txt'
+         lexicon_fid = open(lexicon_path, 'r')
+         pair_list = open('../../lexicons/ic15/ch4_test_vocabulary_pair_list.txt', 'r')
+         pairs = dict()
+         for line in pair_list.readlines():
+             line = line.strip()
+             word = line.split(' ')[0].upper()
+             word_gt = line[len(word) + 1:]
+             pairs[word] = word_gt
+         lexicon_fid = open(lexicon_path, 'r')
+         lexicon = []
+         for line in lexicon_fid.readlines():
+             line = line.strip()
+             lexicon.append(line)
+
+     for i in tqdm(range(1, 501)):
+         img = 'img_' + str(i) + '.jpg'
+         gt_img = 'gt_img_' + str(i) + '.txt'
+         if lexicon_type == 3:
+             # strong (per-image) lexicon
+             lexicon_path = '../../lexicons/ic15/new_strong_lexicon/new_voc_img_' + str(i) + '.txt'
+             lexicon_fid = open(lexicon_path, 'r')
+             pair_list = open('../../lexicons/ic15/new_strong_lexicon/pair_voc_img_' + str(i) + '.txt', 'r')
+             pairs = dict()
+             for line in pair_list.readlines():
+                 line = line.strip()
+                 word = line.split(' ')[0].upper()
+                 word_gt = line[len(word) + 1:]
+                 pairs[word] = word_gt
+             lexicon_fid = open(lexicon_path, 'r')
+             lexicon = []
+             for line in lexicon_fid.readlines():
+                 line = line.strip()
+                 lexicon.append(line)
+         result_path = os.path.join(results_dir, 'res_img_' + str(i) + '.txt')
+         if os.path.isfile(result_path):
+             with open(result_path, 'r') as f:
+                 dt_lines = [a.strip() for a in f.readlines()]
+             dt_lines = [list_from_str(dt) for dt in dt_lines]
+         else:
+             dt_lines = []
+         dt_lines = [dt for dt in dt_lines if dt[-2] > score_rec_seq and dt[-3] > score_rec and dt[-6] > score_det]
+         nms_flag = nms(dt_lines, overlap)
+         boxes = []
+         for k in range(len(dt_lines)):
+             dt = dt_lines[k]
+             if nms_flag[k]:
+                 if dt not in boxes:
+                     boxes.append(dt)
+
+         with open(os.path.join(nms_dir, 'res_img_' + str(i) + '.txt'), 'w') as f:
+             for g in boxes:
+                 gt_coors = [int(b) for b in g[0:8]]
+                 with open('../../../' + g[-1], "rb") as input_file:
+                     # with open(g[-1], "rb") as input_file:
+                     dict_scores = pickle.load(input_file)
+                 if use_char and use_seq:
+                     if g[-2] > g[-3]:
+                         word = g[-5]
+                         scores = dict_scores['seq_char_scores'][:, 1:-1].swapaxes(0, 1)
+                     else:
+                         word = g[-4]
+                         scores = dict_scores['seg_char_scores']
+                 elif use_seq:
+                     word = g[-5]
+                     scores = dict_scores['seq_char_scores'][:, 1:-1].swapaxes(0, 1)
+                 else:
+                     word = g[-4]
+                     scores = dict_scores['seg_char_scores']
+                 match_word, match_dist = find_match_word(word, lexicon, pairs, scores, use_lexicon, weighted_ed)
+                 if match_dist < 1.5 or lexicon_type == 1:
+                     gt_coor_strs = [str(a) for a in gt_coors] + [match_word]
+                     f.write(','.join(gt_coor_strs) + '\r\n')
+
+     pack_name = str(score_det) + '_' + str(score_rec) + '_over' + str(overlap)
+
+     packing(nms_dir, cache_dir, pack_name)
+     submit_file_path = os.path.join(cache_dir, pack_name + '.zip')
+     return submit_file_path
+
+ def find_match_word(rec_str, lexicon, pairs, scores_numpy, use_ed=True, weighted_ed=False):
+     if not use_ed:
+         return rec_str, 0  # no lexicon matching: keep the raw string (callers expect a 2-tuple)
+     rec_str = rec_str.upper()
+     dist_min = 100
+     dist_min_pre = 100
+     match_word = ''
+     match_dist = 100
+     if not weighted_ed:
+         for word in lexicon:
+             word = word.upper()
+             ed = editdistance.eval(rec_str, word)
+             length_dist = abs(len(word) - len(rec_str))
+             # dist = ed + length_dist
+             dist = ed
+             if dist < dist_min:
+                 dist_min = dist
+                 match_word = pairs[word]
+                 match_dist = dist
+         return match_word, match_dist
+     else:
+         small_lexicon_dict = dict()
+         for word in lexicon:
+             word = word.upper()
+             ed = editdistance.eval(rec_str, word)
+             small_lexicon_dict[word] = ed
+             dist = ed
+             if dist < dist_min_pre:
+                 dist_min_pre = dist
+         small_lexicon = []
+         for word in small_lexicon_dict:
+             if small_lexicon_dict[word] <= dist_min_pre + 2:
+                 small_lexicon.append(word)
+
+         for word in small_lexicon:
+             word = word.upper()
+             ed = weighted_edit_distance(rec_str, word, scores_numpy)
+             dist = ed
+             if dist < dist_min:
+                 dist_min = dist
+                 match_word = pairs[word]
+                 match_dist = dist
+         return match_word, match_dist
+
+
+ def prepare_results_for_evaluation(results_dir, lexicon_type, cache_dir, score_det, score_rec, score_rec_seq):
+     if not os.path.isdir(cache_dir):
+         os.mkdir(cache_dir)
+     result_path = test_single(results_dir, score_det=score_det, score_rec=score_rec, score_rec_seq=score_rec_seq, overlap=0.2, cache_dir=cache_dir, lexicon_type=lexicon_type, use_lexicon=True, weighted_ed=True, use_seq=True, use_char=True, mix=True)
+     return result_path
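
`find_match_word` snaps a recognized string to its closest lexicon entry, either by plain edit distance or by an edit distance weighted with per-character scores (`weighted_edit_distance`). A self-contained toy illustration of the unweighted branch, using the same `editdistance` package (the lexicon and `pairs` mapping here are made up):

```python
# Toy illustration of the `not weighted_ed` branch of find_match_word.
import editdistance

lexicon = ['HOTEL', 'MOTEL', 'HOSTEL']        # hypothetical lexicon
pairs = {w: w.title() for w in lexicon}       # uppercase form -> display form

def match(rec_str):
    rec_str = rec_str.upper()
    best = min(lexicon, key=lambda w: editdistance.eval(rec_str, w))
    return pairs[best], editdistance.eval(rec_str, best)

print(match('H0TEL'))  # ('Hotel', 1): one substitution away from 'HOTEL'
```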
evaluation/icdar2015/e2e/rrc_evaluation_funcs.py ADDED
@@ -0,0 +1,369 @@
+ #!/usr/bin/env python2
+ # encoding: UTF-8
+ import json
+ import sys; sys.path.append('./')
+ import zipfile
+ import re
+ import sys
+ import os
+ import codecs
+ import importlib
+ try:
+     from StringIO import StringIO
+ except ImportError:
+     from io import StringIO
+
+ def print_help():
+     sys.stdout.write('Usage: python %s.py -g=<gtFile> -s=<submFile> [-o=<outputFolder> -p=<jsonParams>]' % sys.argv[0])
+     sys.exit(2)
+
+
+ def load_zip_file_keys(file, fileNameRegExp=''):
+     """
+     Returns an array with the entries of the ZIP file that match the regular expression.
+     The keys are the file names, or the capturing group defined in fileNameRegExp.
+     """
+     try:
+         archive = zipfile.ZipFile(file, mode='r', allowZip64=True)
+     except:
+         raise Exception('Error loading the ZIP archive.')
+
+     pairs = []
+
+     for name in archive.namelist():
+         addFile = True
+         keyName = name
+         if fileNameRegExp != "":
+             m = re.match(fileNameRegExp, name)
+             if m == None:
+                 addFile = False
+             else:
+                 if len(m.groups()) > 0:
+                     keyName = m.group(1)
+
+         if addFile:
+             pairs.append(keyName)
+
+     return pairs
+
+
+ def load_zip_file(file, fileNameRegExp='', allEntries=False):
+     """
+     Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
+     The keys are the file names, or the capturing group defined in fileNameRegExp.
+     allEntries validates that all entries in the ZIP file pass the fileNameRegExp.
+     """
+     try:
+         archive = zipfile.ZipFile(file, mode='r', allowZip64=True)
+     except:
+         raise Exception('Error loading the ZIP archive')
+
+     pairs = []
+     for name in archive.namelist():
+         addFile = True
+         keyName = name
+         if fileNameRegExp != "":
+             m = re.match(fileNameRegExp, name)
+             if m == None:
+                 addFile = False
+             else:
+                 if len(m.groups()) > 0:
+                     keyName = m.group(1)
+
+         if addFile:
+             pairs.append([keyName, archive.read(name)])
+         else:
+             if allEntries:
+                 raise Exception('ZIP entry not valid: %s' % name)
+
+     return dict(pairs)
+
+ def decode_utf8(raw):
+     """
+     Returns a Unicode object on success, or None on failure
+     """
+     try:
+         raw = codecs.decode(raw, 'utf-8', 'replace')
+         # strip the BOM if it exists
+         raw = raw.encode('utf8')
+         if raw.startswith(codecs.BOM_UTF8):
+             raw = raw.replace(codecs.BOM_UTF8, '', 1)
+         return raw.decode('utf-8')
+     except:
+         return None
+
+ def validate_lines_in_file(fileName, file_contents, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0):
+     """
+     Validates all lines of the file, calling the line validation function for each line.
+     """
+     utf8File = decode_utf8(file_contents)
+     if (utf8File is None):
+         raise Exception("The file %s is not UTF-8" % fileName)
+
+     lines = utf8File.split("\r\n" if CRLF else "\n")
+     for line in lines:
+         line = line.replace("\r", "").replace("\n", "")
+         if (line != ""):
+             try:
+                 validate_tl_line(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
+             except Exception as e:
+                 raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" % (fileName, line, str(e))).encode('utf-8', 'replace'))
+
+
+ def validate_tl_line(line, LTRB=True, withTranscription=True, withConfidence=True, imWidth=0, imHeight=0):
+     """
+     Validate the format of the line. If the line is not valid an exception will be raised.
+     If maxWidth and maxHeight are specified, all points must be inside the image bounds.
+     Possible values are:
+     LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
+     LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
+     """
+     get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
+
+
+ def get_tl_line_values(line, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0):
+     """
+     Validate the format of the line. If the line is not valid an exception will be raised.
+     If maxWidth and maxHeight are specified, all points must be inside the image bounds.
+     Possible values are:
+     LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
+     LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
+     Returns values from a textline. Points, [Confidences], [Transcriptions]
+     """
+     confidence = 0.0
+     transcription = ""
+     points = []
+
+     numPoints = 4
+
+     if LTRB:
+
+         numPoints = 4
+
+         if withTranscription and withConfidence:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
+         elif withConfidence:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
+         elif withTranscription:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
+         else:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
+
+         xmin = int(m.group(1))
+         ymin = int(m.group(2))
+         xmax = int(m.group(3))
+         ymax = int(m.group(4))
+         if (xmax < xmin):
+             raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." % (xmax))
+         if (ymax < ymin):
+             raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." % (ymax))
+
+         points = [float(m.group(i)) for i in range(1, (numPoints + 1))]
+
+         if (imWidth > 0 and imHeight > 0):
+             validate_point_inside_bounds(xmin, ymin, imWidth, imHeight)
+             validate_point_inside_bounds(xmax, ymax, imWidth, imHeight)
+
+     else:
+
+         numPoints = 8
+
+         if withTranscription and withConfidence:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
+         elif withConfidence:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
+         elif withTranscription:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
+         else:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
+
+         points = [float(m.group(i)) for i in range(1, (numPoints + 1))]
+
+         validate_clockwise_points(points)
+
+         if (imWidth > 0 and imHeight > 0):
+             validate_point_inside_bounds(points[0], points[1], imWidth, imHeight)
+             validate_point_inside_bounds(points[2], points[3], imWidth, imHeight)
+             validate_point_inside_bounds(points[4], points[5], imWidth, imHeight)
+             validate_point_inside_bounds(points[6], points[7], imWidth, imHeight)
+
+
+     if withConfidence:
+         try:
+             confidence = float(m.group(numPoints + 1))
+         except ValueError:
+             raise Exception("Confidence value must be a float")
+
+     if withTranscription:
+         posTranscription = numPoints + (2 if withConfidence else 1)
+         transcription = m.group(posTranscription)
+         m2 = re.match(r'^\s*\"(.*)\"\s*$', transcription)
+         if m2 != None:  # transcription with double quotes: extract the value and replace escaped characters
+             transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
+
+     return points, confidence, transcription
+
+
+ def validate_point_inside_bounds(x, y, imWidth, imHeight):
+     if (x < 0 or x > imWidth):
+         raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" % (x, imWidth, imHeight))
+     if (y < 0 or y > imHeight):
+         raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s)" % (y, imWidth, imHeight))
+
+ def validate_clockwise_points(points):
+     """
+     Validates that the 4 points that delimit a polygon are in clockwise order.
+     """
+
+     if len(points) != 8:
+         raise Exception("Points list not valid." + str(len(points)))
+
+     point = [
+         [int(points[0]), int(points[1])],
+         [int(points[2]), int(points[3])],
+         [int(points[4]), int(points[5])],
+         [int(points[6]), int(points[7])]
+     ]
+     edge = [
+         (point[1][0] - point[0][0]) * (point[1][1] + point[0][1]),
+         (point[2][0] - point[1][0]) * (point[2][1] + point[1][1]),
+         (point[3][0] - point[2][0]) * (point[3][1] + point[2][1]),
+         (point[0][0] - point[3][0]) * (point[0][1] + point[3][1])
+     ]
+
+     summatory = edge[0] + edge[1] + edge[2] + edge[3]
+     if summatory > 0:
+         raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise', remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and the Y axis extending downwards.")
+
+ def get_tl_line_values_from_file_contents(content, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0, sort_by_confidences=True):
+     """
+     Returns all points, confidences and transcriptions of a file in lists. Valid line formats:
+     xmin,ymin,xmax,ymax,[confidence],[transcription]
+     x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
+     """
+     pointsList = []
+     transcriptionsList = []
+     confidencesList = []
+
+     lines = content.split("\r\n" if CRLF else "\n")
+     for line in lines:
+         line = line.replace("\r", "").replace("\n", "")
+         if (line != ""):
+             points, confidence, transcription = get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
+             pointsList.append(points)
+             transcriptionsList.append(transcription)
+             confidencesList.append(confidence)
+
+     if withConfidence and len(confidencesList) > 0 and sort_by_confidences:
+         import numpy as np
+         sorted_ind = np.argsort(-np.array(confidencesList))
+         confidencesList = [confidencesList[i] for i in sorted_ind]
+         pointsList = [pointsList[i] for i in sorted_ind]
+         transcriptionsList = [transcriptionsList[i] for i in sorted_ind]
+
+     return pointsList, confidencesList, transcriptionsList
+
+ def main_evaluation(p, default_evaluation_params_fn, validate_data_fn, evaluate_method_fn, show_result=True, per_sample=True):
+     """
+     This process validates a method, evaluates it and, if it succeeds, generates a ZIP file with a JSON entry for each sample.
+     Params:
+     p: Dictionary of parameters with the GT/submission locations. If None is passed, the parameters sent by the system are used.
+     default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
+     validate_data_fn: points to a method that validates the correct format of the submission
+     evaluate_method_fn: points to a function that evaluates the submission and returns a Dictionary with the results
+     """
+
+     if (p == None):
+         p = dict([s[1:].split('=') for s in sys.argv[1:]])
+         if (len(sys.argv) < 3):
+             print_help()
+
+     evalParams = default_evaluation_params_fn()
+     if 'p' in p.keys():
+         evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]))
+
+     resDict = {'calculated': True, 'Message': '', 'method': '{}', 'per_sample': '{}'}
+     try:
+         validate_data_fn(p['g'], p['s'], evalParams)
+         evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
+         resDict.update(evalData)
+
+     except Exception as e:
+         resDict['Message'] = str(e)
+         resDict['calculated'] = False
+
+     if 'o' in p:
+         if not os.path.exists(p['o']):
+             os.makedirs(p['o'])
+
+         resultsOutputname = p['o'] + '/results.zip'
+         outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
+
+         del resDict['per_sample']
+         if 'output_items' in resDict.keys():
+             del resDict['output_items']
+
+         outZip.writestr('method.json', json.dumps(resDict))
+
+     if not resDict['calculated']:
+         if show_result:
+             sys.stderr.write('Error!\n' + resDict['Message'] + '\n\n')
+         if 'o' in p:
+             outZip.close()
+         return resDict
+
+     if 'o' in p:
+         if per_sample == True:
+             for k, v in evalData['per_sample'].items():
+                 outZip.writestr(k + '.json', json.dumps(v))
+
+             if 'output_items' in evalData.keys():
+                 for k, v in evalData['output_items'].items():
+                     outZip.writestr(k, v)
+
+         outZip.close()
+
+     if show_result:
+         sys.stdout.write("Calculated!")
+         sys.stdout.write(json.dumps(resDict['method']))
+
+     return resDict
+
+
+ def main_validation(default_evaluation_params_fn, validate_data_fn):
+     """
+     This process validates a method.
+     Params:
+     default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
+     validate_data_fn: points to a method that validates the correct format of the submission
+     """
+     try:
+         p = dict([s[1:].split('=') for s in sys.argv[1:]])
+         evalParams = default_evaluation_params_fn()
+         if 'p' in p.keys():
+             evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]))
+
+         validate_data_fn(p['g'], p['s'], evalParams)
+         print('SUCCESS')
+         sys.exit(0)
+     except Exception as e:
+         print(str(e))
+         sys.exit(101)
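
`validate_clockwise_points` relies on the shoelace-style edge sum Σ(x₂−x₁)(y₂+y₁): under image coordinates (origin top-left, Y growing downwards) a quadrilateral listed in clockwise order gives a non-positive sum, so the function rejects sums greater than zero. A quick self-check of that convention:

```python
# Self-check for the orientation test in validate_clockwise_points.
def edge_sum(pts):
    p = [(pts[i], pts[i + 1]) for i in range(0, 8, 2)]
    return sum((p[(i + 1) % 4][0] - p[i][0]) * (p[(i + 1) % 4][1] + p[i][1])
               for i in range(4))

print(edge_sum([0, 0, 10, 0, 10, 10, 0, 10]))  # -200: clockwise on screen, accepted
print(edge_sum([0, 0, 0, 10, 10, 10, 10, 0]))  # +200: counter-clockwise, rejected
```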
evaluation/icdar2015/e2e/script.py ADDED
@@ -0,0 +1,461 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ # encoding=utf8
+ from collections import namedtuple
+ import rrc_evaluation_funcs
+ import importlib
+ from prepare_results import prepare_results_for_evaluation
+ 
+ def evaluation_imports():
+     """
+     evaluation_imports: Dictionary (key = module name, value = alias) with the python modules used in the evaluation.
+     """
+     return {
+         'Polygon': 'plg',
+         'numpy': 'np'
+     }
+ 
+ def default_evaluation_params():
+     """
+     default_evaluation_params: Default parameters to use for the validation and evaluation.
+     """
+     return {
+         'IOU_CONSTRAINT': 0.5,
+         'AREA_PRECISION_CONSTRAINT': 0.5,
+         'WORD_SPOTTING': False,
+         'MIN_LENGTH_CARE_WORD': 3,
+         'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt',
+         'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt',
+         'LTRB': False,  # LTRB: 2 points (left,top,right,bottom) or 4 points (x1,y1,x2,y2,x3,y3,x4,y4)
+         'CRLF': False,  # lines are delimited by Windows CRLF format
+         'CONFIDENCES': False,  # detections must include a confidence value; AP will be calculated
+         'SPECIAL_CHARACTERS': '!?.:,*"()·[]/\'',
+         'ONLY_REMOVE_FIRST_LAST_CHARACTER': True
+     }
+ 
+ def validate_data(gtFilePath, submFilePath, evaluationParams):
+     """
+     Method validate_data: validates that all files in the results folder are correct (have the correct name and contents).
+     Also validates that there are no missing files in the folder.
+     If an error is detected, the method raises it.
+     """
+     gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
+ 
+     subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
+ 
+     # validate format of GroundTruth
+     for k in gt:
+         rrc_evaluation_funcs.validate_lines_in_file(k, gt[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True)
+ 
+     # validate format of results
+     for k in subm:
+         if (k in gt) == False:
+             raise Exception("The sample %s is not present in GT" % k)
+ 
+         rrc_evaluation_funcs.validate_lines_in_file(k, subm[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True, evaluationParams['CONFIDENCES'])
+ 
+ 
+ def evaluate_method(gtFilePath, submFilePath, evaluationParams):
+     """
+     Method evaluate_method: evaluates the method and returns the results.
+     Results: dictionary with the following values:
+     - method (required)  Global method metrics. Ex: { 'Precision':0.8, 'Recall':0.9 }
+     - samples (optional) Per-sample metrics. Ex: { 'sample1': { 'Precision':0.8, 'Recall':0.9 }, 'sample2': { 'Precision':0.8, 'Recall':0.9 } }
+     """
+     for module, alias in evaluation_imports().items():
+         globals()[alias] = importlib.import_module(module)
+ 
+     def polygon_from_points(points, correctOffset=False):
+         """
+         Returns a Polygon object for use with the Polygon2 class, from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
+         """
+         if correctOffset:  # this will subtract 1 from the coordinates that correspond to the xmax and ymax
+             points[2] -= 1
+             points[4] -= 1
+             points[5] -= 1
+             points[7] -= 1
+ 
+         resBoxes = np.empty([1, 8], dtype='int32')
+         resBoxes[0, 0] = int(points[0])
+         resBoxes[0, 4] = int(points[1])
+         resBoxes[0, 1] = int(points[2])
+         resBoxes[0, 5] = int(points[3])
+         resBoxes[0, 2] = int(points[4])
+         resBoxes[0, 6] = int(points[5])
+         resBoxes[0, 3] = int(points[6])
+         resBoxes[0, 7] = int(points[7])
+         pointMat = resBoxes[0].reshape([2, 4]).T
+         return plg.Polygon(pointMat)
+ 
+     def rectangle_to_polygon(rect):
+         resBoxes = np.empty([1, 8], dtype='int32')
+         resBoxes[0, 0] = int(rect.xmin)
+         resBoxes[0, 4] = int(rect.ymax)
+         resBoxes[0, 1] = int(rect.xmin)
+         resBoxes[0, 5] = int(rect.ymin)
+         resBoxes[0, 2] = int(rect.xmax)
+         resBoxes[0, 6] = int(rect.ymin)
+         resBoxes[0, 3] = int(rect.xmax)
+         resBoxes[0, 7] = int(rect.ymax)
+ 
+         pointMat = resBoxes[0].reshape([2, 4]).T
+ 
+         return plg.Polygon(pointMat)
+ 
+     def rectangle_to_points(rect):
+         points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)]
+         return points
+ 
+     def get_union(pD, pG):
+         areaA = pD.area()
+         areaB = pG.area()
+         return areaA + areaB - get_intersection(pD, pG)
+ 
+     def get_intersection_over_union(pD, pG):
+         try:
+             return get_intersection(pD, pG) / get_union(pD, pG)
+         except:
+             return 0
+ 
+     def get_intersection(pD, pG):
+         pInt = pD & pG
+         if len(pInt) == 0:
+             return 0
+         return pInt.area()
+ 
+     def compute_ap(confList, matchList, numGtCare):
+         correct = 0
+         AP = 0
+         if len(confList) > 0:
+             confList = np.array(confList)
+             matchList = np.array(matchList)
+             sorted_ind = np.argsort(-confList)
+             confList = confList[sorted_ind]
+             matchList = matchList[sorted_ind]
+             for n in range(len(confList)):
+                 match = matchList[n]
+                 if match:
+                     correct += 1
+                     AP += float(correct) / (n + 1)
+ 
+         if numGtCare > 0:
+             AP /= numGtCare
+ 
+         return AP
+ 
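+     # Worked example (illustrative, not part of the metric definition): compute_ap
+     # above is the usual ranked average precision. With confidences [0.9, 0.8, 0.7],
+     # matches [True, False, True] and numGtCare = 2, the running precisions at the
+     # two correct detections are 1/1 and 2/3, so AP = (1.0 + 2/3) / 2 ≈ 0.833.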
+     def transcription_match(transGt, transDet, specialCharacters='!?.:,*"()·[]/\'', onlyRemoveFirstLastCharacterGT=True):
+ 
+         if onlyRemoveFirstLastCharacterGT:
+             # special characters in GT are allowed only at the initial or final position
+             if (transGt == transDet):
+                 return True
+ 
+             if specialCharacters.find(transGt[0]) > -1:
+                 if transGt[1:] == transDet:
+                     return True
+ 
+             if specialCharacters.find(transGt[-1]) > -1:
+                 if transGt[0:len(transGt) - 1] == transDet:
+                     return True
+ 
+             if specialCharacters.find(transGt[0]) > -1 and specialCharacters.find(transGt[-1]) > -1:
+                 if transGt[1:len(transGt) - 1] == transDet:
+                     return True
+             return False
+         else:
+             # special characters are removed from the beginning and the end of both the detection and the GroundTruth
+             while len(transGt) > 0 and specialCharacters.find(transGt[0]) > -1:
+                 transGt = transGt[1:]
+ 
+             while len(transDet) > 0 and specialCharacters.find(transDet[0]) > -1:
+                 transDet = transDet[1:]
+ 
+             while len(transGt) > 0 and specialCharacters.find(transGt[-1]) > -1:
+                 transGt = transGt[0:len(transGt) - 1]
+ 
+             while len(transDet) > 0 and specialCharacters.find(transDet[-1]) > -1:
+                 transDet = transDet[0:len(transDet) - 1]
+ 
+             return transGt == transDet
+ 
+ 
+     def include_in_dictionary(transcription):
+         """
+         Function used in Word Spotting that checks whether the Ground Truth transcription meets the rules to enter the dictionary. If not, the transcription will be treated as don't care.
+         """
+         # special case: 's at the end
+         if transcription[len(transcription) - 2:] == "'s" or transcription[len(transcription) - 2:] == "'S":
+             transcription = transcription[0:len(transcription) - 2]
+ 
+         # hyphens at the start or end of the word
+         transcription = transcription.strip('-')
+ 
+         specialCharacters = "'!?.:,*\"()·[]/"
+         for character in specialCharacters:
+             transcription = transcription.replace(character, ' ')
+ 
+         transcription = transcription.strip()
+ 
+         if len(transcription) != len(transcription.replace(" ", "")):
+             return False
+ 
+         if len(transcription) < evaluationParams['MIN_LENGTH_CARE_WORD']:
+             return False
+ 
+         notAllowed = "×÷·"
+ 
+         range1 = [ord(u'a'), ord(u'z')]
+         range2 = [ord(u'A'), ord(u'Z')]
+         range3 = [ord(u'À'), ord(u'ƿ')]
+         range4 = [ord(u'DŽ'), ord(u'ɿ')]
+         range5 = [ord(u'Ά'), ord(u'Ͽ')]
+         range6 = [ord(u'-'), ord(u'-')]
+ 
+         for char in transcription:
+             charCode = ord(char)
+             if (notAllowed.find(char) != -1):
+                 return False
+ 
+             valid = (charCode >= range1[0] and charCode <= range1[1]) or (charCode >= range2[0] and charCode <= range2[1]) or (charCode >= range3[0] and charCode <= range3[1]) or (charCode >= range4[0] and charCode <= range4[1]) or (charCode >= range5[0] and charCode <= range5[1]) or (charCode >= range6[0] and charCode <= range6[1])
+             if valid == False:
+                 return False
+ 
+         return True
+ 
+     def include_in_dictionary_transcription(transcription):
+         """
+         Function applied to the Ground Truth transcriptions used in Word Spotting. It removes special characters and terminations.
+         """
+         # special case: 's at the end
+         if transcription[len(transcription) - 2:] == "'s" or transcription[len(transcription) - 2:] == "'S":
+             transcription = transcription[0:len(transcription) - 2]
+ 
+         # hyphens at the start or end of the word
+         transcription = transcription.strip('-')
+ 
+         specialCharacters = "'!?.:,*\"()·[]/"
+         for character in specialCharacters:
+             transcription = transcription.replace(character, ' ')
+ 
+         transcription = transcription.strip()
+ 
+         return transcription
+ 
+     perSampleMetrics = {}
+ 
+     matchedSum = 0
+ 
+     Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
+ 
+     gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
+     subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
+ 
+     numGlobalCareGt = 0
+     numGlobalCareDet = 0
+ 
+     arrGlobalConfidences = []
+     arrGlobalMatches = []
+ 
+     for resFile in gt:
+ 
+         gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile])
+         if (gtFile is None):
+             raise Exception("The file %s is not UTF-8" % resFile)
+ 
+         recall = 0
+         precision = 0
+         hmean = 0
+         detCorrect = 0
+         iouMat = np.empty([1, 1])
+         gtPols = []
+         detPols = []
+         gtTrans = []
+         detTrans = []
+         gtPolPoints = []
+         detPolPoints = []
+         gtDontCarePolsNum = []  # indices of Ground Truth polygons marked as don't care
+         detDontCarePolsNum = []  # indices of detected polygons matched with a don't care GT
+         detMatchedNums = []
+         pairs = []
+ 
+         arrSampleConfidences = []
+         arrSampleMatch = []
+         sampleAP = 0
+ 
+         evaluationLog = ""
+ 
+         pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile, evaluationParams['CRLF'], evaluationParams['LTRB'], True, False)
+         for n in range(len(pointsList)):
+             points = pointsList[n]
+             transcription = transcriptionsList[n]
+             dontCare = transcription == "###"
+             if evaluationParams['LTRB']:
+                 gtRect = Rectangle(*points)
+                 gtPol = rectangle_to_polygon(gtRect)
+             else:
+                 gtPol = polygon_from_points(points)
+             gtPols.append(gtPol)
+             gtPolPoints.append(points)
+ 
+             # in Word Spotting mode we filter out some transcriptions with special characters
+             if evaluationParams['WORD_SPOTTING']:
+                 if dontCare == False:
+                     if include_in_dictionary(transcription) == False:
+                         dontCare = True
+                     else:
+                         transcription = include_in_dictionary_transcription(transcription)
+ 
+             gtTrans.append(transcription)
+             if dontCare:
+                 gtDontCarePolsNum.append(len(gtPols) - 1)
+ 
+         evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")
+ 
+         if resFile in subm:
+ 
+             detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile])
+ 
+             pointsList, confidencesList, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile, evaluationParams['CRLF'], evaluationParams['LTRB'], True, evaluationParams['CONFIDENCES'])
+ 
+             for n in range(len(pointsList)):
+                 points = pointsList[n]
+                 transcription = transcriptionsList[n]
+ 
+                 if evaluationParams['LTRB']:
+                     detRect = Rectangle(*points)
+                     detPol = rectangle_to_polygon(detRect)
+                 else:
+                     detPol = polygon_from_points(points)
+                 detPols.append(detPol)
+                 detPolPoints.append(points)
+                 detTrans.append(transcription)
+ 
+                 if len(gtDontCarePolsNum) > 0:
+                     for dontCarePol in gtDontCarePolsNum:
+                         dontCarePol = gtPols[dontCarePol]
+                         intersected_area = get_intersection(dontCarePol, detPol)
+                         pdDimensions = detPol.area()
+                         precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
+                         if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT']):
+                             detDontCarePolsNum.append(len(detPols) - 1)
+                             break
+ 
+             evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")
+ 
+             if len(gtPols) > 0 and len(detPols) > 0:
+                 # calculate the IoU and precision matrices
+                 outputShape = [len(gtPols), len(detPols)]
+                 iouMat = np.empty(outputShape)
+                 gtRectMat = np.zeros(len(gtPols), np.int8)
+                 detRectMat = np.zeros(len(detPols), np.int8)
+                 for gtNum in range(len(gtPols)):
+                     for detNum in range(len(detPols)):
+                         pG = gtPols[gtNum]
+                         pD = detPols[detNum]
+                         iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
+ 
+                 for gtNum in range(len(gtPols)):
+                     for detNum in range(len(detPols)):
+                         if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
+                             if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']:
+                                 gtRectMat[gtNum] = 1
+                                 detRectMat[detNum] = 1
+                                 # the detection is matched only if the transcription is equal
+                                 if evaluationParams['WORD_SPOTTING']:
+                                     correct = gtTrans[gtNum].upper() == detTrans[detNum].upper()
+                                 else:
+                                     correct = transcription_match(gtTrans[gtNum].upper(), detTrans[detNum].upper(), evaluationParams['SPECIAL_CHARACTERS'], evaluationParams['ONLY_REMOVE_FIRST_LAST_CHARACTER']) == True
+                                 detCorrect += (1 if correct else 0)
+                                 if correct:
+                                     detMatchedNums.append(detNum)
+                                 pairs.append({'gt': gtNum, 'det': detNum, 'correct': correct})
+                                 evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + " trans. correct: " + str(correct) + "\n"
+ 
+             if evaluationParams['CONFIDENCES']:
+                 for detNum in range(len(detPols)):
+                     if detNum not in detDontCarePolsNum:
+                         # we exclude the don't care detections
+                         match = detNum in detMatchedNums
+ 
+                         arrSampleConfidences.append(confidencesList[detNum])
+                         arrSampleMatch.append(match)
+ 
+                         arrGlobalConfidences.append(confidencesList[detNum])
+                         arrGlobalMatches.append(match)
+ 
+         numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
+         numDetCare = (len(detPols) - len(detDontCarePolsNum))
+         if numGtCare == 0:
+             recall = float(1)
+             precision = float(0) if numDetCare > 0 else float(1)
+             sampleAP = precision
+         else:
+             recall = float(detCorrect) / numGtCare
+             precision = 0 if numDetCare == 0 else float(detCorrect) / numDetCare
+             if evaluationParams['CONFIDENCES']:
+                 sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare)
+ 
+         hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)
+ 
+         matchedSum += detCorrect
+         numGlobalCareGt += numGtCare
+         numGlobalCareDet += numDetCare
+ 
+         perSampleMetrics[resFile] = {
+             'precision': precision,
+             'recall': recall,
+             'hmean': hmean,
+             'pairs': pairs,
+             'AP': sampleAP,
+             'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
+             'gtPolPoints': gtPolPoints,
+             'detPolPoints': detPolPoints,
+             'gtTrans': gtTrans,
+             'detTrans': detTrans,
+             'gtDontCare': gtDontCarePolsNum,
+             'detDontCare': detDontCarePolsNum,
+             'evaluationParams': evaluationParams,
+             'evaluationLog': evaluationLog
+         }
+ 
+     # compute AP
+     AP = 0
+     if evaluationParams['CONFIDENCES']:
+         AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
+ 
+     methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt
+     methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet
+     methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / (methodRecall + methodPrecision)
+ 
+     methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean, 'AP': AP}
+ 
+     resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics}
+ 
+     return resDict
+ 
+ 
+ if __name__ == '__main__':
+     '''
+     results_dir: result directory
+     score_det: score threshold of the detection bounding box
+     score_rec: score threshold of the mask recognition branch
+     score_rec_seq: score threshold of the sequence recognition branch
+     lexicon_type: 1 for generic; 2 for weak; 3 for strong
+     '''
+     results_dir = '../../../output/mixtrain/inference/icdar_2015_test/model_0250000_1440_results/'
+     lexicon_type = 3
+     score_det = 0.01
+     score_rec = 0.4
+     # score_rec_seq is set to 0.7 for lexicon_type 3 or 2; 0.8 for lexicon_type 1
+     score_rec_seq = 0.7
+     evaluate_result_path = prepare_results_for_evaluation(results_dir,
+         lexicon_type=lexicon_type, cache_dir='./cache_files',
+         score_det=score_det, score_rec=score_rec, score_rec_seq=score_rec_seq)
+     p = {
+         'g': "../gt.zip",
+         's': evaluate_result_path
+     }
+     rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params, validate_data, evaluate_method)
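
In evaluate_method above, a detection counts as correct only when its IoU with a still-unmatched "care" GT polygon exceeds IOU_CONSTRAINT (0.5) and the transcriptions agree case-insensitively. A minimal sketch of the geometric test using the same Polygon package (aliased as plg above); the coordinates are made up for illustration:

    import Polygon as plg
    gt = plg.Polygon([(0, 0), (100, 0), (100, 50), (0, 50)])
    det = plg.Polygon([(10, 0), (110, 0), (110, 50), (10, 50)])
    pInt = gt & det
    inter = pInt.area() if len(pInt) > 0 else 0.0   # 4500.0
    iou = inter / (gt.area() + det.area() - inter)  # 4500 / 5500 ≈ 0.82 > 0.5
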
evaluation/icdar2015/gt.zip ADDED
Binary file (250 kB).
 
evaluation/rotated_icdar2013/e2e/prepare_results.py ADDED
@@ -0,0 +1,267 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ import sys
+ import os
+ sys.path.append('./')
+ import shapely
+ from shapely.geometry import Polygon, MultiPoint
+ import numpy as np
+ import editdistance
+ sys.path.append('../../')
+ from weighted_editdistance import weighted_edit_distance
+ from tqdm import tqdm
+ try:
+     import pickle
+ except ImportError:
+     import cPickle as pickle
+ 
+ def list_from_str(st):
+     line = st.split(',')
+     # box[0:4], polygon[4:12], word, seq_word, detection_score, rec_score, seq_score, char_score_path
+     new_line = [float(a) for a in line[4:12]] + [float(line[-4])] + [line[-5]] + [line[-6]] + [float(line[-3])] + [float(line[-2])] + [line[-1]]
+     return new_line
+ 
+ def polygon_from_list(line):
+     """
+     Create a shapely polygon object from a gt or dt line.
+     """
+     polygon_points = np.array(line).reshape(4, 2)
+     polygon = Polygon(polygon_points).convex_hull
+     return polygon
+ 
+ def polygon_iou(list1, list2):
+     """
+     Intersection over union between two shapely polygons.
+     """
+     polygon_points1 = np.array(list1).reshape(4, 2)
+     poly1 = Polygon(polygon_points1).convex_hull
+     polygon_points2 = np.array(list2).reshape(4, 2)
+     poly2 = Polygon(polygon_points2).convex_hull
+     union_poly = np.concatenate((polygon_points1, polygon_points2))
+     if not poly1.intersects(poly2):  # this test is fast and can accelerate the calculation
+         iou = 0
+     else:
+         try:
+             inter_area = poly1.intersection(poly2).area
+             # union_area = poly1.area + poly2.area - inter_area
+             union_area = MultiPoint(union_poly).convex_hull.area
+             iou = float(inter_area) / (union_area + 1e-6)
+         except shapely.geos.TopologicalError:
+             print('shapely.geos.TopologicalError occurred, iou set to 0')
+             iou = 0
+     return iou
+ 
+ def nms(boxes, overlap):
+     rec_scores = [b[-2] for b in boxes]
+     indices = sorted(range(len(rec_scores)), key=lambda k: -rec_scores[k])
+     box_num = len(boxes)
+     nms_flag = [True] * box_num
+     for i in range(box_num):
+         ii = indices[i]
+         if not nms_flag[ii]:
+             continue
+         for j in range(box_num):
+             jj = indices[j]
+             if j == i:
+                 continue
+             if not nms_flag[jj]:
+                 continue
+             box1 = boxes[ii]
+             box2 = boxes[jj]
+             box1_score = rec_scores[ii]
+             box2_score = rec_scores[jj]
+             str1 = box1[9]
+             str2 = box2[9]
+             box_i = [box1[0], box1[1], box1[4], box1[5]]
+             box_j = [box2[0], box2[1], box2[4], box2[5]]
+             poly1 = polygon_from_list(box1[0:8])
+             poly2 = polygon_from_list(box2[0:8])
+             iou = polygon_iou(box1[0:8], box2[0:8])
+             thresh = overlap
+ 
+             if iou > thresh:
+                 if box1_score > box2_score:
+                     nms_flag[jj] = False
+                 if box1_score == box2_score and poly1.area > poly2.area:
+                     nms_flag[jj] = False
+                 if box1_score == box2_score and poly1.area <= poly2.area:
+                     nms_flag[ii] = False
+                     break
+ 
+     return nms_flag
+ 
+ def packing(save_dir, cache_dir, pack_name):
94
+ files = os.listdir(save_dir)
95
+ if not os.path.exists(cache_dir):
96
+ os.mkdir(cache_dir)
97
+ os.system('zip -r -q -j '+os.path.join(cache_dir, pack_name+'.zip')+' '+save_dir+'/*')
98
+
99
+ def test_single(results_dir,lexicon_type=3,cache_dir='./cache_dir',score_det=0.5,score_rec=0.5,score_rec_seq=0.5,overlap=0.2, use_lexicon=True, weighted_ed=True, use_seq=False, use_char=False, mix=False):
100
+ '''
101
+ results_dir: result directory
102
+ score_det: score of detection bounding box
103
+ score_rec: score of the mask recognition branch
104
+ socre_rec_seq: score of the sequence recognition branch
105
+ overlap: overlap threshold used for nms
106
+ lexicon_type: 1 for generic; 2 for weak; 3 for strong
107
+ use_seq: use the recognition result of sequence branch
108
+ use_mix: use both the recognition result of the mask and sequence branches, selected by score
109
+ '''
110
+ print('score_det:', 'score_det:', score_det, 'score_rec:', score_rec, 'score_rec_seq:', score_rec_seq, 'lexicon_type:', lexicon_type, 'weighted_ed:', weighted_ed, 'use_seq:', use_seq, 'use_char:', use_char, 'mix:', mix)
111
+ if not os.path.exists(cache_dir):
112
+ os.mkdir(cache_dir)
113
+ nms_dir = os.path.join(cache_dir,str(score_det)+'_'+str(score_rec)+'_'+str(score_rec_seq))
114
+ if not os.path.exists(nms_dir):
115
+ os.mkdir(nms_dir)
116
+ if lexicon_type==1:
117
+ # generic lexicon
118
+ lexicon_path = '../../lexicons/ic13/GenericVocabulary_new.txt'
119
+ lexicon_fid=open(lexicon_path, 'r')
120
+ pair_list = open('../../lexicons/ic13/GenericVocabulary_pair_list.txt', 'r')
121
+ pairs = dict()
122
+ for line in pair_list.readlines():
123
+ line=line.strip()
124
+ word = line.split(' ')[0].upper()
125
+ word_gt = line[len(word)+1:]
126
+ pairs[word] = word_gt
127
+ lexicon_fid=open(lexicon_path, 'r')
128
+ lexicon=[]
129
+ for line in lexicon_fid.readlines():
130
+ line=line.strip()
131
+ lexicon.append(line)
132
+ if lexicon_type==2:
133
+ # weak lexicon
134
+ lexicon_path = '../../lexicons/ic13/ch4_test_vocabulary_new.txt'
135
+ lexicon_fid=open(lexicon_path, 'r')
136
+ pair_list = open('../../lexicons/ic13/ch4_test_vocabulary_pair_list.txt', 'r')
137
+ pairs = dict()
138
+ for line in pair_list.readlines():
139
+ line=line.strip()
140
+ word = line.split(' ')[0].upper()
141
+ word_gt = line[len(word)+1:]
142
+ pairs[word] = word_gt
143
+ lexicon_fid=open(lexicon_path, 'r')
144
+ lexicon=[]
145
+ for line in lexicon_fid.readlines():
146
+ line=line.strip()
147
+ lexicon.append(line)
148
+
149
+ for i in tqdm(range(1,234)):
150
+ img = 'img_'+str(i)+'.jpg'
151
+ gt_img = 'gt_img_'+str(i)+'.txt'
152
+ if lexicon_type==3:
153
+ # weak
154
+ lexicon_path = '../../lexicons/ic13/new_strong_lexicon/new_voc_img_' + str(i) + '.txt'
155
+ lexicon_fid=open(lexicon_path, 'r')
156
+ pair_list = open('../../lexicons/ic13/new_strong_lexicon/pair_voc_img_' + str(i) + '.txt', 'r')
157
+ pairs = dict()
158
+ for line in pair_list.readlines():
159
+ line=line.strip()
160
+ word = line.split(' ')[0].upper()
161
+ word_gt = line[len(word)+1:]
162
+ pairs[word] = word_gt
163
+ lexicon_fid=open(lexicon_path, 'r')
164
+ lexicon=[]
165
+ for line in lexicon_fid.readlines():
166
+ line=line.strip()
167
+ lexicon.append(line)
168
+ result_path = os.path.join(results_dir,'res_img_'+str(i)+'.txt')
169
+ if os.path.isfile(result_path):
170
+ with open(result_path,'r') as f:
171
+ dt_lines = [a.strip() for a in f.readlines()]
172
+ dt_lines = [list_from_str(dt) for dt in dt_lines]
173
+ else:
174
+ dt_lines = []
175
+ dt_lines = [dt for dt in dt_lines if dt[-2]>score_rec_seq and dt[-3]>score_rec and dt[-6]>score_det]
176
+ nms_flag = nms(dt_lines,overlap)
177
+ boxes = []
178
+ for k in range(len(dt_lines)):
179
+ dt = dt_lines[k]
180
+ if nms_flag[k]:
181
+ if dt not in boxes:
182
+ boxes.append(dt)
183
+
184
+ with open(os.path.join(nms_dir,'res_img_'+str(i)+'.txt'),'w') as f:
185
+ for g in boxes:
186
+ gt_coors = [int(b) for b in g[0:8]]
187
+ with open('../../../' + g[-1], "rb") as input_file:
188
+ # with open(g[-1], "rb") as input_file:
189
+ dict_scores = pickle.load(input_file)
190
+ if use_char and use_seq:
191
+ if g[-2]>g[-3]:
192
+ word = g[-5]
193
+ scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1)
194
+ else:
195
+ word = g[-4]
196
+ scores = dict_scores['seg_char_scores']
197
+ elif use_seq:
198
+ word = g[-5]
199
+ scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1)
200
+ else:
201
+ word = g[-4]
202
+ scores = dict_scores['seg_char_scores']
203
+ if not use_lexicon:
204
+ match_word = word
205
+ match_dist = 0.
206
+ else:
207
+ match_word, match_dist = find_match_word(word, lexicon, pairs, scores, use_lexicon, weighted_ed)
208
+ if match_dist<1.5 or lexicon_type==1:
209
+ gt_coor_strs = [str(a) for a in gt_coors]+ [match_word]
210
+ f.write(','.join(gt_coor_strs)+'\r\n')
211
+
212
+ pack_name = str(score_det)+'_'+str(score_rec)+'_over'+str(overlap)
213
+
214
+ packing(nms_dir,cache_dir,pack_name)
215
+ submit_file_path = os.path.join(cache_dir, pack_name+'.zip')
216
+ return submit_file_path
217
+
218
+ def find_match_word(rec_str, lexicon, pairs, scores_numpy, use_ed = True, weighted_ed = False):
219
+ if not use_ed:
220
+ return rec_str
221
+ rec_str = rec_str.upper()
222
+ dist_min = 100
223
+ dist_min_pre = 100
224
+ match_word = ''
225
+ match_dist = 100
226
+ if not weighted_ed:
227
+ for word in lexicon:
228
+ word = word.upper()
229
+ ed = editdistance.eval(rec_str, word)
230
+ length_dist = abs(len(word) - len(rec_str))
231
+ # dist = ed + length_dist
232
+ dist = ed
233
+ if dist<dist_min:
234
+ dist_min = dist
235
+ match_word = pairs[word]
236
+ match_dist = dist
237
+ return match_word, match_dist
238
+ else:
239
+ small_lexicon_dict = dict()
240
+ for word in lexicon:
241
+ word = word.upper()
242
+ ed = editdistance.eval(rec_str, word)
243
+ small_lexicon_dict[word] = ed
244
+ dist = ed
245
+ if dist<dist_min_pre:
246
+ dist_min_pre = dist
247
+ small_lexicon = []
248
+ for word in small_lexicon_dict:
249
+ if small_lexicon_dict[word]<=dist_min_pre+2:
250
+ small_lexicon.append(word)
251
+
252
+ for word in small_lexicon:
253
+ word = word.upper()
254
+ ed = weighted_edit_distance(rec_str, word, scores_numpy)
255
+ dist = ed
256
+ if dist<dist_min:
257
+ dist_min = dist
258
+ match_word = pairs[word]
259
+ match_dist = dist
260
+ return match_word, match_dist
261
+
262
+
263
+ def prepare_results_for_evaluation(results_dir, use_lexicon, cache_dir, score_det, score_rec, score_rec_seq):
264
+ if not os.path.isdir(cache_dir):
265
+ os.mkdir(cache_dir)
266
+ result_path = test_single(results_dir,score_det=score_det,score_rec=score_rec,score_rec_seq=score_rec_seq,overlap=0.2,cache_dir=cache_dir,lexicon_type=3, use_lexicon=use_lexicon, weighted_ed=True, use_seq=True, use_char=True, mix=True)
267
+ return result_path
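
find_match_word above is a nearest-word lookup: the unweighted branch scans the lexicon under plain edit distance, while the weighted branch first prunes the lexicon with plain edit distance and then rescores the survivors with weighted_edit_distance using the per-character score matrix. A minimal sketch of the unweighted branch, with a made-up lexicon and pairs map:

    import editdistance
    lexicon = ['HELLO', 'WORLD']
    pairs = {'HELLO': 'Hello', 'WORLD': 'World'}
    rec = 'HELL0'  # recognizer output with one wrong character
    best = min(lexicon, key=lambda w: editdistance.eval(rec, w))
    match_word = pairs[best]                   # 'Hello'
    match_dist = editdistance.eval(rec, best)  # 1, below the 1.5 cut-off used in test_single
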
evaluation/rotated_icdar2013/e2e/rrc_evaluation_funcs.py ADDED
@@ -0,0 +1,369 @@
+ #!/usr/bin/env python2
+ # encoding: UTF-8
+ import json
+ import sys; sys.path.append('./')
+ import zipfile
+ import re
+ import sys
+ import os
+ import codecs
+ import importlib
+ try:
+     from StringIO import StringIO
+ except ImportError:
+     from io import StringIO
+ 
+ def print_help():
+     sys.stdout.write('Usage: python %s.py -g=<gtFile> -s=<submFile> [-o=<outputFolder> -p=<jsonParams>]' % sys.argv[0])
+     sys.exit(2)
+ 
+ 
+ def load_zip_file_keys(file, fileNameRegExp=''):
+     """
+     Returns an array with the entries of the ZIP file that match the regular expression.
+     The keys are the file names or the capturing group defined in fileNameRegExp.
+     """
+     try:
+         archive = zipfile.ZipFile(file, mode='r', allowZip64=True)
+     except:
+         raise Exception('Error loading the ZIP archive.')
+ 
+     pairs = []
+ 
+     for name in archive.namelist():
+         addFile = True
+         keyName = name
+         if fileNameRegExp != "":
+             m = re.match(fileNameRegExp, name)
+             if m == None:
+                 addFile = False
+             else:
+                 if len(m.groups()) > 0:
+                     keyName = m.group(1)
+ 
+         if addFile:
+             pairs.append(keyName)
+ 
+     return pairs
+ 
+ 
+ def load_zip_file(file, fileNameRegExp='', allEntries=False):
+     """
+     Returns a dictionary with the contents (filtered by fileNameRegExp) of a ZIP file.
+     The keys are the file names or the capturing group defined in fileNameRegExp.
+     allEntries validates that all entries in the ZIP file pass the fileNameRegExp.
+     """
+     try:
+         archive = zipfile.ZipFile(file, mode='r', allowZip64=True)
+     except:
+         raise Exception('Error loading the ZIP archive')
+ 
+     pairs = []
+     for name in archive.namelist():
+         addFile = True
+         keyName = name
+         if fileNameRegExp != "":
+             m = re.match(fileNameRegExp, name)
+             if m == None:
+                 addFile = False
+             else:
+                 if len(m.groups()) > 0:
+                     keyName = m.group(1)
+ 
+         if addFile:
+             pairs.append([keyName, archive.read(name)])
+         else:
+             if allEntries:
+                 raise Exception('ZIP entry not valid: %s' % name)
+ 
+     return dict(pairs)
+ 
+ def decode_utf8(raw):
+     """
+     Returns a Unicode object on success, or None on failure
+     """
+     try:
+         raw = codecs.decode(raw, 'utf-8', 'replace')
+         # extract the BOM if it exists
+         raw = raw.encode('utf8')
+         if raw.startswith(codecs.BOM_UTF8):
+             raw = raw.replace(codecs.BOM_UTF8, '', 1)
+         return raw.decode('utf-8')
+     except:
+         return None
+ 
+ def validate_lines_in_file(fileName, file_contents, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0):
+     """
+     This function validates all lines of the file, calling the line validation function for each line
+     """
+     utf8File = decode_utf8(file_contents)
+     if (utf8File is None):
+         raise Exception("The file %s is not UTF-8" % fileName)
+ 
+     lines = utf8File.split("\r\n" if CRLF else "\n")
+     for line in lines:
+         line = line.replace("\r", "").replace("\n", "")
+         if (line != ""):
+             try:
+                 validate_tl_line(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
+             except Exception as e:
+                 raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" % (fileName, line, str(e))).encode('utf-8', 'replace'))
+ 
+ 
+ def validate_tl_line(line, LTRB=True, withTranscription=True, withConfidence=True, imWidth=0, imHeight=0):
+     """
+     Validate the format of the line. If the line is not valid an exception will be raised.
+     If maxWidth and maxHeight are specified, all points must be inside the image bounds.
+     Possible values are:
+     LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
+     LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
+     """
+     get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
+ 
+ 
+ def get_tl_line_values(line, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0):
+     """
+     Validate the format of the line. If the line is not valid an exception will be raised.
+     If maxWidth and maxHeight are specified, all points must be inside the image bounds.
+     Possible values are:
+     LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
+     LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
+     Returns the values of a text line: points, [confidence], [transcription]
+     """
+     confidence = 0.0
+     transcription = ""
+     points = []
+ 
+     numPoints = 4
+ 
+     if LTRB:
+ 
+         numPoints = 4
+ 
+         if withTranscription and withConfidence:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
+         elif withConfidence:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
+         elif withTranscription:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
+         else:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
+ 
+         xmin = int(m.group(1))
+         ymin = int(m.group(2))
+         xmax = int(m.group(3))
+         ymax = int(m.group(4))
+         if (xmax < xmin):
+             raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." % (xmax))
+         if (ymax < ymin):
+             raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." % (ymax))
+ 
+         points = [float(m.group(i)) for i in range(1, (numPoints + 1))]
+ 
+         if (imWidth > 0 and imHeight > 0):
+             validate_point_inside_bounds(xmin, ymin, imWidth, imHeight)
+             validate_point_inside_bounds(xmax, ymax, imWidth, imHeight)
+ 
+     else:
+ 
+         numPoints = 8
+ 
+         if withTranscription and withConfidence:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
+         elif withConfidence:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
+         elif withTranscription:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
+         else:
+             m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$', line)
+             if m == None:
+                 raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
+ 
+         points = [float(m.group(i)) for i in range(1, (numPoints + 1))]
+ 
+         validate_clockwise_points(points)
+ 
+         if (imWidth > 0 and imHeight > 0):
+             validate_point_inside_bounds(points[0], points[1], imWidth, imHeight)
+             validate_point_inside_bounds(points[2], points[3], imWidth, imHeight)
+             validate_point_inside_bounds(points[4], points[5], imWidth, imHeight)
+             validate_point_inside_bounds(points[6], points[7], imWidth, imHeight)
+ 
+     if withConfidence:
+         try:
+             confidence = float(m.group(numPoints + 1))
+         except ValueError:
+             raise Exception("Confidence value must be a float")
+ 
+     if withTranscription:
+         posTranscription = numPoints + (2 if withConfidence else 1)
+         transcription = m.group(posTranscription)
+         m2 = re.match(r'^\s*\"(.*)\"\s*$', transcription)
+         if m2 != None:  # transcription with double quotes: extract the value and replace escaped characters
+             transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
+ 
+     return points, confidence, transcription
+ 
+ 
+ def validate_point_inside_bounds(x, y, imWidth, imHeight):
+     if (x < 0 or x > imWidth):
+         raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" % (x, imWidth, imHeight))
+     if (y < 0 or y > imHeight):
+         raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s)" % (y, imWidth, imHeight))
+ 
+ def validate_clockwise_points(points):
+     """
+     Validates that the 4 points that delimit a polygon are in clockwise order.
+     """
+ 
+     if len(points) != 8:
+         raise Exception("Points list not valid." + str(len(points)))
+ 
+     point = [
+         [int(points[0]), int(points[1])],
+         [int(points[2]), int(points[3])],
+         [int(points[4]), int(points[5])],
+         [int(points[6]), int(points[7])]
+     ]
+     edge = [
+         (point[1][0] - point[0][0]) * (point[1][1] + point[0][1]),
+         (point[2][0] - point[1][0]) * (point[2][1] + point[1][1]),
+         (point[3][0] - point[2][0]) * (point[3][1] + point[2][1]),
+         (point[0][0] - point[3][0]) * (point[0][1] + point[3][1])
+     ]
+ 
+     summatory = edge[0] + edge[1] + edge[2] + edge[3]
+     if summatory > 0:
+         raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and the Y axis extending downwards.")
+ 
+ def get_tl_line_values_from_file_contents(content, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0, sort_by_confidences=True):
+     """
+     Returns all points, confidences and transcriptions of a file in lists. Valid line formats:
+     xmin,ymin,xmax,ymax,[confidence],[transcription]
+     x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
+     """
+     pointsList = []
+     transcriptionsList = []
+     confidencesList = []
+ 
+     lines = content.split("\r\n" if CRLF else "\n")
+     for line in lines:
+         line = line.replace("\r", "").replace("\n", "")
+         if (line != ""):
+             points, confidence, transcription = get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
+             pointsList.append(points)
+             transcriptionsList.append(transcription)
+             confidencesList.append(confidence)
+ 
+     if withConfidence and len(confidencesList) > 0 and sort_by_confidences:
+         import numpy as np
+         sorted_ind = np.argsort(-np.array(confidencesList))
+         confidencesList = [confidencesList[i] for i in sorted_ind]
+         pointsList = [pointsList[i] for i in sorted_ind]
+         transcriptionsList = [transcriptionsList[i] for i in sorted_ind]
+ 
+     return pointsList, confidencesList, transcriptionsList
+ 
+ def main_evaluation(p, default_evaluation_params_fn, validate_data_fn, evaluate_method_fn, show_result=True, per_sample=True):
+     """
+     This process validates a method, evaluates it and, if it succeeds, generates a ZIP file with a JSON entry for each sample.
+     Params:
+     p: Dictionary of parameters with the GT/submission locations. If None is passed, the parameters sent by the system are used.
+     default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
+     validate_data_fn: points to a method that validates the correct format of the submission
+     evaluate_method_fn: points to a function that evaluates the submission and returns a dictionary with the results
+     """
+ 
+     if (p == None):
+         p = dict([s[1:].split('=') for s in sys.argv[1:]])
+         if (len(sys.argv) < 3):
+             print_help()
+ 
+     evalParams = default_evaluation_params_fn()
+     if 'p' in p.keys():
+         evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]))
+ 
+     resDict = {'calculated': True, 'Message': '', 'method': '{}', 'per_sample': '{}'}
+     try:
+         validate_data_fn(p['g'], p['s'], evalParams)
+         evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
+         resDict.update(evalData)
+ 
+     except Exception as e:
+         resDict['Message'] = str(e)
+         resDict['calculated'] = False
+ 
+     if 'o' in p:
+         if not os.path.exists(p['o']):
+             os.makedirs(p['o'])
+ 
+         resultsOutputname = p['o'] + '/results.zip'
+         outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
+ 
+         del resDict['per_sample']
+         if 'output_items' in resDict.keys():
+             del resDict['output_items']
+ 
+         outZip.writestr('method.json', json.dumps(resDict))
+ 
+     if not resDict['calculated']:
+         if show_result:
+             sys.stderr.write('Error!\n' + resDict['Message'] + '\n\n')
+         if 'o' in p:
+             outZip.close()
+         return resDict
+ 
+     if 'o' in p:
+         if per_sample == True:
+             for k, v in evalData['per_sample'].items():
+                 outZip.writestr(k + '.json', json.dumps(v))
+ 
+         if 'output_items' in evalData.keys():
+             for k, v in evalData['output_items'].items():
+                 outZip.writestr(k, v)
+ 
+         outZip.close()
+ 
+     if show_result:
+         sys.stdout.write("Calculated!")
+         sys.stdout.write(json.dumps(resDict['method']))
+ 
+     return resDict
+ 
+ 
+ def main_validation(default_evaluation_params_fn, validate_data_fn):
+     """
+     This process validates a method.
+     Params:
+     default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
+     validate_data_fn: points to a method that validates the correct format of the submission
+     """
+     try:
+         p = dict([s[1:].split('=') for s in sys.argv[1:]])
+         evalParams = default_evaluation_params_fn()
+         if 'p' in p.keys():
+             evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]))
+ 
+         validate_data_fn(p['g'], p['s'], evalParams)
+         print('SUCCESS')
+         sys.exit(0)
+     except Exception as e:
+         print(str(e))
+         sys.exit(101)
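
get_tl_line_values above parses one text region per line; with LTRB=False and withTranscription=True it expects eight quadrilateral coordinates in clockwise order followed by the transcription, which is the format of the res_img_*.txt submissions. A minimal sketch with made-up coordinates:

    line = '38,43,920,43,920,215,38,215,Hello'
    points, confidence, transcription = get_tl_line_values(line, LTRB=False, withTranscription=True)
    # points == [38.0, 43.0, 920.0, 43.0, 920.0, 215.0, 38.0, 215.0]
    # confidence == 0.0 (none supplied), transcription == 'Hello'
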
evaluation/rotated_icdar2013/e2e/script.py ADDED
@@ -0,0 +1,460 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ # encoding=utf8
+ from collections import namedtuple
+ import rrc_evaluation_funcs
+ import importlib
+ from prepare_results import prepare_results_for_evaluation
+ 
+ def evaluation_imports():
+     """
+     evaluation_imports: Dictionary (key = module name, value = alias) with the python modules used in the evaluation.
+     """
+     return {
+         'Polygon': 'plg',
+         'numpy': 'np'
+     }
+ 
+ def default_evaluation_params():
+     """
+     default_evaluation_params: Default parameters to use for the validation and evaluation.
+     """
+     return {
+         'IOU_CONSTRAINT': 0.5,
+         'AREA_PRECISION_CONSTRAINT': 0.5,
+         'WORD_SPOTTING': False,
+         'MIN_LENGTH_CARE_WORD': 3,
+         'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt',
+         'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt',
+         'LTRB': False,  # LTRB: 2 points (left,top,right,bottom) or 4 points (x1,y1,x2,y2,x3,y3,x4,y4)
+         'CRLF': False,  # lines are delimited by Windows CRLF format
+         'CONFIDENCES': False,  # detections must include a confidence value; AP will be calculated
+         'SPECIAL_CHARACTERS': '!?.:,*"()·[]/\'',
+         'ONLY_REMOVE_FIRST_LAST_CHARACTER': True
+     }
+ 
+ def validate_data(gtFilePath, submFilePath, evaluationParams):
+     """
+     Method validate_data: validates that all files in the results folder are correct (have the correct name and contents).
+     Also validates that there are no missing files in the folder.
+     If an error is detected, the method raises it.
+     """
+     gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
+ 
+     subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
+ 
+     # validate format of GroundTruth
+     for k in gt:
+         rrc_evaluation_funcs.validate_lines_in_file(k, gt[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True)
+ 
+     # validate format of results
+     for k in subm:
+         if (k in gt) == False:
+             raise Exception("The sample %s is not present in GT" % k)
+ 
+         rrc_evaluation_funcs.validate_lines_in_file(k, subm[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True, evaluationParams['CONFIDENCES'])
+ 
+ 
+ def evaluate_method(gtFilePath, submFilePath, evaluationParams):
+     """
+     Method evaluate_method: evaluates the method and returns the results.
+     Results: dictionary with the following values:
+     - method (required)  Global method metrics. Ex: { 'Precision':0.8, 'Recall':0.9 }
+     - samples (optional) Per-sample metrics. Ex: { 'sample1': { 'Precision':0.8, 'Recall':0.9 }, 'sample2': { 'Precision':0.8, 'Recall':0.9 } }
+     """
+     for module, alias in evaluation_imports().items():
+         globals()[alias] = importlib.import_module(module)
+ 
+     def polygon_from_points(points, correctOffset=False):
+         """
+         Returns a Polygon object for use with the Polygon2 class, from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
+         """
+         if correctOffset:  # this will subtract 1 from the coordinates that correspond to the xmax and ymax
+             points[2] -= 1
+             points[4] -= 1
+             points[5] -= 1
+             points[7] -= 1
+ 
+         resBoxes = np.empty([1, 8], dtype='int32')
+         resBoxes[0, 0] = int(points[0])
+         resBoxes[0, 4] = int(points[1])
+         resBoxes[0, 1] = int(points[2])
+         resBoxes[0, 5] = int(points[3])
+         resBoxes[0, 2] = int(points[4])
+         resBoxes[0, 6] = int(points[5])
+         resBoxes[0, 3] = int(points[6])
+         resBoxes[0, 7] = int(points[7])
+         pointMat = resBoxes[0].reshape([2, 4]).T
+         return plg.Polygon(pointMat)
+ 
+     def rectangle_to_polygon(rect):
+         resBoxes = np.empty([1, 8], dtype='int32')
+         resBoxes[0, 0] = int(rect.xmin)
+         resBoxes[0, 4] = int(rect.ymax)
+         resBoxes[0, 1] = int(rect.xmin)
+         resBoxes[0, 5] = int(rect.ymin)
+         resBoxes[0, 2] = int(rect.xmax)
+         resBoxes[0, 6] = int(rect.ymin)
+         resBoxes[0, 3] = int(rect.xmax)
+         resBoxes[0, 7] = int(rect.ymax)
+ 
+         pointMat = resBoxes[0].reshape([2, 4]).T
+ 
+         return plg.Polygon(pointMat)
+ 
+     def rectangle_to_points(rect):
+         points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)]
+         return points
+ 
+     def get_union(pD, pG):
+         areaA = pD.area()
+         areaB = pG.area()
+         return areaA + areaB - get_intersection(pD, pG)
+ 
+     def get_intersection_over_union(pD, pG):
+         try:
+             return get_intersection(pD, pG) / get_union(pD, pG)
+         except:
+             return 0
+ 
+     def get_intersection(pD, pG):
+         pInt = pD & pG
+         if len(pInt) == 0:
+             return 0
+         return pInt.area()
+ 
+     def compute_ap(confList, matchList, numGtCare):
+         correct = 0
+         AP = 0
+         if len(confList) > 0:
+             confList = np.array(confList)
+             matchList = np.array(matchList)
+             sorted_ind = np.argsort(-confList)
+             confList = confList[sorted_ind]
+             matchList = matchList[sorted_ind]
+             for n in range(len(confList)):
+                 match = matchList[n]
+                 if match:
+                     correct += 1
+                     AP += float(correct) / (n + 1)
+ 
+         if numGtCare > 0:
+             AP /= numGtCare
+ 
+         return AP
+ 
+     def transcription_match(transGt, transDet, specialCharacters='!?.:,*"()·[]/\'', onlyRemoveFirstLastCharacterGT=True):
+ 
+         if onlyRemoveFirstLastCharacterGT:
+             # special characters in GT are allowed only at the initial or final position
+             if (transGt == transDet):
+                 return True
+ 
+             if specialCharacters.find(transGt[0]) > -1:
+                 if transGt[1:] == transDet:
+                     return True
+ 
+             if specialCharacters.find(transGt[-1]) > -1:
+                 if transGt[0:len(transGt) - 1] == transDet:
+                     return True
+ 
+             if specialCharacters.find(transGt[0]) > -1 and specialCharacters.find(transGt[-1]) > -1:
+                 if transGt[1:len(transGt) - 1] == transDet:
+                     return True
+             return False
+         else:
+             # special characters are removed from the beginning and the end of both the detection and the GroundTruth
+             while len(transGt) > 0 and specialCharacters.find(transGt[0]) > -1:
+                 transGt = transGt[1:]
+ 
+             while len(transDet) > 0 and specialCharacters.find(transDet[0]) > -1:
+                 transDet = transDet[1:]
+ 
+             while len(transGt) > 0 and specialCharacters.find(transGt[-1]) > -1:
+                 transGt = transGt[0:len(transGt) - 1]
+ 
+             while len(transDet) > 0 and specialCharacters.find(transDet[-1]) > -1:
+                 transDet = transDet[0:len(transDet) - 1]
+ 
+             return transGt == transDet
+ 
+ 
+     def include_in_dictionary(transcription):
+         """
+         Function used in Word Spotting that checks whether the Ground Truth transcription meets the rules to enter the dictionary. If not, the transcription will be treated as don't care.
+         """
+         # special case: 's at the end
+         if transcription[len(transcription) - 2:] == "'s" or transcription[len(transcription) - 2:] == "'S":
+             transcription = transcription[0:len(transcription) - 2]
+ 
+         # hyphens at the start or end of the word
+         transcription = transcription.strip('-')
+ 
+         specialCharacters = "'!?.:,*\"()·[]/"
+         for character in specialCharacters:
+             transcription = transcription.replace(character, ' ')
+ 
+         transcription = transcription.strip()
+ 
+         if len(transcription) != len(transcription.replace(" ", "")):
+             return False
+ 
+         if len(transcription) < evaluationParams['MIN_LENGTH_CARE_WORD']:
+             return False
+ 
+         notAllowed = "×÷·"
+ 
+         range1 = [ord(u'a'), ord(u'z')]
+         range2 = [ord(u'A'), ord(u'Z')]
+         range3 = [ord(u'À'), ord(u'ƿ')]
+         range4 = [ord(u'DŽ'), ord(u'ɿ')]
+         range5 = [ord(u'Ά'), ord(u'Ͽ')]
+         range6 = [ord(u'-'), ord(u'-')]
+ 
+         for char in transcription:
+             charCode = ord(char)
+             if (notAllowed.find(char) != -1):
+                 return False
+ 
+             valid = (charCode >= range1[0] and charCode <= range1[1]) or (charCode >= range2[0] and charCode <= range2[1]) or (charCode >= range3[0] and charCode <= range3[1]) or (charCode >= range4[0] and charCode <= range4[1]) or (charCode >= range5[0] and charCode <= range5[1]) or (charCode >= range6[0] and charCode <= range6[1])
+             if valid == False:
+                 return False
+ 
+         return True
+ 
+     def include_in_dictionary_transcription(transcription):
+         """
+         Function applied to the Ground Truth transcriptions used in Word Spotting. It removes special characters and terminations.
+         """
+         # special case: 's at the end
+         if transcription[len(transcription) - 2:] == "'s" or transcription[len(transcription) - 2:] == "'S":
+             transcription = transcription[0:len(transcription) - 2]
+ 
+         # hyphens at the start or end of the word
+         transcription = transcription.strip('-')
+ 
+         specialCharacters = "'!?.:,*\"()·[]/"
+         for character in specialCharacters:
+             transcription = transcription.replace(character, ' ')
+ 
+         transcription = transcription.strip()
+ 
+         return transcription
+ 
+     perSampleMetrics = {}
+ 
+     matchedSum = 0
+ 
+     Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
+ 
+     gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
+     subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
+ 
+     numGlobalCareGt = 0
+     numGlobalCareDet = 0
+ 
+     arrGlobalConfidences = []
+     arrGlobalMatches = []
+ 
+     for resFile in gt:
+ 
+         gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile])
+         if (gtFile is None):
+             raise Exception("The file %s is not UTF-8" % resFile)
+ 
+         recall = 0
+         precision = 0
+         hmean = 0
+         detCorrect = 0
+         iouMat = np.empty([1, 1])
+         gtPols = []
+         detPols = []
+         gtTrans = []
+         detTrans = []
+         gtPolPoints = []
+         detPolPoints = []
+         gtDontCarePolsNum = []  # indices of Ground Truth polygons marked as don't care
+         detDontCarePolsNum = []  # indices of detected polygons matched with a don't care GT
+         detMatchedNums = []
+         pairs = []
+ 
+         arrSampleConfidences = []
+         arrSampleMatch = []
+         sampleAP = 0
+ 
+         evaluationLog = ""
+ 
+         pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile, evaluationParams['CRLF'], evaluationParams['LTRB'], True, False)
+         for n in range(len(pointsList)):
+             points = pointsList[n]
+             transcription = transcriptionsList[n]
+             dontCare = transcription == "###"
+             if evaluationParams['LTRB']:
+                 gtRect = Rectangle(*points)
+                 gtPol = rectangle_to_polygon(gtRect)
+             else:
+                 gtPol = polygon_from_points(points)
+             gtPols.append(gtPol)
+             gtPolPoints.append(points)
+ 
+             # in Word Spotting mode we filter out some transcriptions with special characters
+             if evaluationParams['WORD_SPOTTING']:
+                 if dontCare == False:
+                     if include_in_dictionary(transcription) == False:
+                         dontCare = True
+                     else:
+                         transcription = include_in_dictionary_transcription(transcription)
+ 
+             gtTrans.append(transcription)
+             if dontCare:
+                 gtDontCarePolsNum.append(len(gtPols) - 1)
+ 
+         evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")
+ 
+         if resFile in subm:
+ 
+             detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile])
+ 
+             pointsList, confidencesList, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile, evaluationParams['CRLF'], evaluationParams['LTRB'], True, evaluationParams['CONFIDENCES'])
+ 
+             for n in range(len(pointsList)):
+                 points = pointsList[n]
+                 transcription = transcriptionsList[n]
+ 
+                 if evaluationParams['LTRB']:
+                     detRect = Rectangle(*points)
+                     detPol = rectangle_to_polygon(detRect)
+                 else:
+                     detPol = polygon_from_points(points)
+                 detPols.append(detPol)
+                 detPolPoints.append(points)
+                 detTrans.append(transcription)
+ 
+                 if len(gtDontCarePolsNum) > 0:
+                     for dontCarePol in gtDontCarePolsNum:
+                         dontCarePol = gtPols[dontCarePol]
+                         intersected_area = get_intersection(dontCarePol, detPol)
+                         pdDimensions = detPol.area()
+                         precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
+                         if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT']):
+                             detDontCarePolsNum.append(len(detPols) - 1)
+                             break
+ 
+             evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")
+ 
+             if len(gtPols) > 0 and len(detPols) > 0:
+                 # calculate the IoU and precision matrices
+                 outputShape = [len(gtPols), len(detPols)]
+                 iouMat = np.empty(outputShape)
+                 gtRectMat = np.zeros(len(gtPols), np.int8)
+                 detRectMat = np.zeros(len(detPols), np.int8)
+                 for gtNum in range(len(gtPols)):
+                     for detNum in range(len(detPols)):
+                         pG = gtPols[gtNum]
+                         pD = detPols[detNum]
+                         iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
+ 
+                 for gtNum in range(len(gtPols)):
+                     for detNum in range(len(detPols)):
+                         if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
+                             if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']:
+                                 gtRectMat[gtNum] = 1
+                                 detRectMat[detNum] = 1
+                                 # the detection is matched only if the transcription is equal
+                                 if evaluationParams['WORD_SPOTTING']:
+                                     correct = gtTrans[gtNum].upper() == detTrans[detNum].upper()
+                                 else:
+                                     correct = transcription_match(gtTrans[gtNum].upper(), detTrans[detNum].upper(), evaluationParams['SPECIAL_CHARACTERS'], evaluationParams['ONLY_REMOVE_FIRST_LAST_CHARACTER']) == True
+                                 detCorrect += (1 if correct else 0)
+                                 if correct:
+                                     detMatchedNums.append(detNum)
+                                 pairs.append({'gt': gtNum, 'det': detNum, 'correct': correct})
+                                 evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + " trans. correct: " + str(correct) + "\n"
+ 
+             if evaluationParams['CONFIDENCES']:
+                 for detNum in range(len(detPols)):
+                     if detNum not in detDontCarePolsNum:
+                         # we exclude the don't care detections
+                         match = detNum in detMatchedNums
+ 
+                         arrSampleConfidences.append(confidencesList[detNum])
+                         arrSampleMatch.append(match)
+ 
+                         arrGlobalConfidences.append(confidencesList[detNum])
+                         arrGlobalMatches.append(match)
+ 
+         numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
+         numDetCare = (len(detPols) - len(detDontCarePolsNum))
+         if numGtCare == 0:
+             recall = float(1)
+             precision = float(0) if numDetCare > 0 else float(1)
+             sampleAP = precision
+         else:
+             recall = float(detCorrect) / numGtCare
+             precision = 0 if numDetCare == 0 else float(detCorrect) / numDetCare
+             if evaluationParams['CONFIDENCES']:
+                 sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare)
+ 
+         hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)
+ 
+         matchedSum += detCorrect
+         numGlobalCareGt += numGtCare
+         numGlobalCareDet += numDetCare
+ 
+         perSampleMetrics[resFile] = {
+             'precision': precision,
+             'recall': recall,
+             'hmean': hmean,
+             'pairs': pairs,
+             'AP': sampleAP,
+             'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
+             'gtPolPoints': gtPolPoints,
+             'detPolPoints': detPolPoints,
+             'gtTrans': gtTrans,
+             'detTrans': detTrans,
+             'gtDontCare': gtDontCarePolsNum,
+             'detDontCare': detDontCarePolsNum,
+             'evaluationParams': evaluationParams,
+             'evaluationLog': evaluationLog
+         }
+ 
+     # compute AP
+     AP = 0
+     if evaluationParams['CONFIDENCES']:
+         AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
+ 
+     methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt
+     methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet
+     methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / (methodRecall + methodPrecision)
+ 
+     methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean, 'AP': AP}
+ 
+     resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics}
+ 
+     return resDict
+ 
+ 
+ if __name__ == '__main__':
+     '''
+     results_dir: result directory
+     score_det: score threshold of the detection bounding box
+     score_rec: score threshold of the mask recognition branch
+     score_rec_seq: score threshold of the sequence recognition branch
+     lexicon_type: 1 for generic; 2 for weak; 3 for strong
+     '''
+     angle = 45
+     results_dir = '../../../output/mixtrain/inference/rotated_ic13_test_' + str(angle) + '/model_0250000_1000_results/'
450
+ score_rec_seq = 0.9
451
+ score_rec = 0.4
452
+ score_det = 0.1
453
+ evaluate_result_path = prepare_results_for_evaluation(results_dir,
454
+ use_lexicon=False, cache_dir='./cache_files',
455
+ score_det=score_det, score_rec=score_rec, score_rec_seq=score_rec_seq)
456
+ p = {
457
+ 'g': '../gt/gt_'+str(angle)+'.zip',
458
+ 's': evaluate_result_path
459
+ }
460
+ rrc_evaluation_funcs.main_evaluation(p,default_evaluation_params,validate_data,evaluate_method)
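A minimal sketch of how this entry point could sweep every angle that has a ground-truth archive below, reusing the thresholds defined above; the per-angle cache_dir is an assumption added here to keep the cached NMS files of different runs apart:

    for angle in [-90, -75, -60, -45, -30, -15, 0, 15, 30, 45, 60, 75, 85, 90]:
        results_dir = ('../../../output/mixtrain/inference/rotated_ic13_test_'
                       + str(angle) + '/model_0250000_1000_results/')
        evaluate_result_path = prepare_results_for_evaluation(
            results_dir, use_lexicon=False, cache_dir='./cache_files_' + str(angle),
            score_det=score_det, score_rec=score_rec, score_rec_seq=score_rec_seq)
        p = {'g': '../gt/gt_' + str(angle) + '.zip', 's': evaluate_result_path}
        rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params,
                                             validate_data, evaluate_method)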
evaluation/rotated_icdar2013/gt/gt.zip ADDED
Binary file (65.2 kB)
evaluation/rotated_icdar2013/gt/gt_-15.zip ADDED
Binary file (64.9 kB)
evaluation/rotated_icdar2013/gt/gt_-30.zip ADDED
Binary file (65.2 kB)
evaluation/rotated_icdar2013/gt/gt_-45.zip ADDED
Binary file (65.2 kB)
evaluation/rotated_icdar2013/gt/gt_-60.zip ADDED
Binary file (65.2 kB)
evaluation/rotated_icdar2013/gt/gt_-75.zip ADDED
Binary file (64.9 kB)
evaluation/rotated_icdar2013/gt/gt_-90.zip ADDED
Binary file (59.9 kB)
evaluation/rotated_icdar2013/gt/gt_0.zip ADDED
Binary file (59.6 kB)
evaluation/rotated_icdar2013/gt/gt_15.zip ADDED
Binary file (65 kB)
evaluation/rotated_icdar2013/gt/gt_30.zip ADDED
Binary file (65.2 kB)
evaluation/rotated_icdar2013/gt/gt_45.zip ADDED
Binary file (65.2 kB)
evaluation/rotated_icdar2013/gt/gt_60.zip ADDED
Binary file (65.2 kB)
evaluation/rotated_icdar2013/gt/gt_75.zip ADDED
Binary file (64.9 kB)
evaluation/rotated_icdar2013/gt/gt_85.zip ADDED
Binary file (64.4 kB)
evaluation/rotated_icdar2013/gt/gt_90.zip ADDED
Binary file (59.7 kB)
evaluation/totaltext/e2e/prepare_results.py ADDED
@@ -0,0 +1,234 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ import sys
4
+ import os
5
+ import glob
6
+ sys.path.append('./')
7
+ import shapely
8
+ from shapely.geometry import Polygon,MultiPoint
9
+ import numpy as np
10
+ import editdistance
11
+ sys.path.append('../../')
12
+ from weighted_editdistance import weighted_edit_distance
13
+ from tqdm import tqdm
14
+ try:
15
+ import pickle
16
+ except ImportError:
17
+ import cPickle as pickle
18
+
19
+ def list_from_str(st):
20
+ line = st.split(';')
21
+ segms = line[1].split(',')
22
+ scores = line[2].split(',')
23
+ new_line = [float(a) for a in segms]+[float(scores[-4])]+[scores[-5]]+[scores[-6]]+[float(scores[-3])]+[float(scores[-2])] + [scores[-1]]
24
+ return new_line
25
+
26
+ def polygon_from_list(line):
27
+ """
28
+ Create a shapely polygon object from gt or dt line.
29
+ """
30
+ polygon_points = np.array(line).reshape(-1, 2)
31
+ polygon = Polygon(polygon_points).convex_hull
32
+ return polygon
33
+
34
+ def polygon_iou(list1, list2):
35
+ """
36
+ Intersection over union between two shapely polygons.
37
+ """
38
+ polygon_points1 = np.array(list1).reshape(-1, 2)
39
+ poly1 = Polygon(polygon_points1).convex_hull
40
+ polygon_points2 = np.array(list2).reshape(-1, 2)
41
+ poly2 = Polygon(polygon_points2).convex_hull
42
+ union_poly = np.concatenate((polygon_points1,polygon_points2))
43
+ if not poly1.intersects(poly2): # this test is fast and can accelerate calculation
44
+ iou = 0
45
+ else:
46
+ try:
47
+ inter_area = poly1.intersection(poly2).area
48
+ #union_area = poly1.area + poly2.area - inter_area
49
+ union_area = MultiPoint(union_poly).convex_hull.area
50
+ iou = float(inter_area) / (union_area+1e-6)
51
+ except shapely.geos.TopologicalError:
52
+ print('shapely.geos.TopologicalError occurred, iou set to 0')
53
+ iou = 0
54
+ return iou
55
+
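A quick standalone check of polygon_iou (using the shapely and numpy imports above). Note that the "union" here is the convex hull of the combined points, so the value can differ from strict polygon IoU:

    # Two 2x2 squares overlapping in a 1x1 region: intersection area is 1, and
    # the convex hull of all eight corners has area 8, so the function reports
    # ~0.125, whereas the strict IoU would be 1/7.
    print(polygon_iou([0, 0, 2, 0, 2, 2, 0, 2], [1, 1, 3, 1, 3, 3, 1, 3]))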
56
+ def nms(boxes,overlap):
57
+ rec_scores = [b[-6] for b in boxes]
58
+ indices = sorted(range(len(rec_scores)), key=lambda k: -rec_scores[k])
59
+ box_num = len(boxes)
60
+ nms_flag = [True]*box_num
61
+ for i in range(box_num):
62
+ ii = indices[i]
63
+ if not nms_flag[ii]:
64
+ continue
65
+ for j in range(box_num):
66
+ jj = indices[j]
67
+ if j == i:
68
+ continue
69
+ if not nms_flag[jj]:
70
+ continue
71
+ box1 = boxes[ii]
72
+ box2 = boxes[jj]
73
+ box1_score = rec_scores[ii]
74
+ box2_score = rec_scores[jj]
75
+ str1 = box1[9]
76
+ str2 = box2[9]
77
+ box_i = [box1[0],box1[1],box1[4],box1[5]]
78
+ box_j = [box2[0],box2[1],box2[4],box2[5]]
79
+ poly1 = polygon_from_list(box1[0:-6])
80
+ poly2 = polygon_from_list(box2[0:-6])
81
+ iou = polygon_iou(box1[0:-6],box2[0:-6])
82
+ thresh = overlap
83
+
84
+ if iou > thresh:
85
+ if box1_score > box2_score:
86
+ nms_flag[jj] = False
87
+ if box1_score == box2_score and poly1.area > poly2.area:
88
+ nms_flag[jj] = False
89
+ if box1_score == box2_score and poly1.area<=poly2.area:
90
+ nms_flag[ii] = False
91
+ break
92
+
93
+ return nms_flag
94
+
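An illustrative call of nms() with two hand-made detections in the layout this file expects: polygon coordinates (eight here, though any even count is accepted) followed by six trailing fields (detection score, sequence word, mask word, rec score, seq score, pickle path). All values below are invented for the example:

    box_a = [0, 0, 10, 0, 10, 4, 0, 4, 0.9, 'HELLO', 'HELLO', 0.8, 0.85, 'a.pkl']
    box_b = [1, 0, 11, 0, 11, 4, 1, 4, 0.7, 'HELL0', 'HELL0', 0.6, 0.65, 'b.pkl']
    # The boxes overlap heavily, so the lower-scored one is suppressed.
    print(nms([box_a, box_b], overlap=0.2))  # [True, False]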
95
+ def packing(save_dir, cache_dir, pack_name):
96
+ files = os.listdir(save_dir)
97
+ if not os.path.exists(cache_dir):
98
+ os.mkdir(cache_dir)
99
+ os.system('zip -r -q -j '+os.path.join(cache_dir, pack_name+'.zip')+' '+save_dir+'/*')
100
+
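packing() shells out to the external zip binary; a portable standard-library equivalent (a sketch mirroring zip's -j behaviour of junking directory paths) could look like this:

    import os, zipfile

    def packing_stdlib(save_dir, cache_dir, pack_name):
        # Write every file in save_dir to the archive root, as 'zip -j' does.
        if not os.path.exists(cache_dir):
            os.mkdir(cache_dir)
        zip_path = os.path.join(cache_dir, pack_name + '.zip')
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            for name in os.listdir(save_dir):
                zf.write(os.path.join(save_dir, name), arcname=name)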
101
+ def test_single(results_dir,lexicon_type=3,cache_dir='./cache_dir',score_det=0.5,score_rec=0.5,score_rec_seq=0.5,overlap=0.2, use_lexicon=True, weighted_ed=True, use_seq=False, use_char=False, mix=False):
102
+ '''
103
+ results_dir: result directory
104
+ score_det: score of detection bounding box
105
+ score_rec: score of the mask recognition branch
106
+ score_rec_seq: score of the sequence recognition branch
107
+ overlap: overlap threshold used for nms
108
+ lexicon_type: 1 for generic; 2 for weak; 3 for strong
109
+ use_seq: use the recognition result of sequence branch
110
+ mix: use both the recognition results of the mask and sequence branches, selected by score
111
+ '''
112
+ print('score_det:', score_det, 'score_rec:', score_rec, 'score_rec_seq:', score_rec_seq, 'overlap:', overlap, 'lexicon_type:', lexicon_type, 'weighted_ed:', weighted_ed, 'use_seq:', use_seq, 'use_char:', use_char, 'mix:', mix)
113
+ if not os.path.exists(cache_dir):
114
+ os.mkdir(cache_dir)
115
+ nms_dir = os.path.join(cache_dir,str(score_det)+'_'+str(score_rec)+'_'+str(score_rec_seq))
116
+ if not os.path.exists(nms_dir):
117
+ os.mkdir(nms_dir)
118
+ if use_lexicon and lexicon_type==2:
119
+ # weak lexicon
120
+ lexicon_path = '../../lexicons/totaltext/weak_voc_new.txt'
121
+ lexicon_fid=open(lexicon_path, 'r')
122
+ pair_list = open('../../lexicons/totaltext/weak_voc_pair_list.txt', 'r')
123
+ pairs = dict()
124
+ for line in pair_list.readlines():
125
+ line=line.strip()
126
+ word = line.split(' ')[0].upper()
127
+ word_gt = line[len(word)+1:]
128
+ pairs[word] = word_gt
129
+ lexicon_fid=open(lexicon_path, 'r')
130
+ lexicon=[]
131
+ for line in lexicon_fid.readlines():
132
+ line=line.strip()
133
+ lexicon.append(line)
134
+
135
+ for res_file in sorted(os.path.basename(p) for p in glob.glob(os.path.join(results_dir, '*.txt'))):  # glob inside results_dir rather than the cwd
136
+ result_path = os.path.join(results_dir,res_file)
137
+ if os.path.isfile(result_path):
138
+ with open(result_path,'r') as f:
139
+ dt_lines = [a.strip() for a in f.readlines()]
140
+ dt_lines = [list_from_str(dt) for dt in dt_lines]
141
+ else:
142
+ dt_lines = []
143
+ dt_lines = [dt for dt in dt_lines if dt[-2]>score_rec_seq and dt[-3]>score_rec and dt[-6]>score_det]
144
+ nms_flag = nms(dt_lines,overlap)
145
+ boxes = []
146
+ for k in range(len(dt_lines)):
147
+ dt = dt_lines[k]
148
+ if nms_flag[k]:
149
+ if dt not in boxes:
150
+ boxes.append(dt)
151
+
152
+ with open(os.path.join(nms_dir,'gt_'+res_file.split('.')[0].split('_')[1]+'.txt'),'w') as f:
153
+ for g in boxes:
154
+ gt_coors = [int(b) for b in g[0:-6]]
155
+ with open('../../../' + g[-1], "rb") as input_file:
156
+ dict_scores = pickle.load(input_file)
157
+ if use_char and use_seq:
158
+ if g[-2]>g[-3]:
159
+ word = g[-5]
160
+ scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1)
161
+ else:
162
+ word = g[-4]
163
+ scores = dict_scores['seg_char_scores']
164
+ elif use_seq:
165
+ word = g[-5]
166
+ scores = dict_scores['seq_char_scores'][:,1:-1].swapaxes(0,1)
167
+ else:
168
+ word = g[-4]
169
+ scores = dict_scores['seg_char_scores']
170
+ if not use_lexicon:
171
+ match_word = word
172
+ match_dist = 0.
173
+ else:
174
+ match_word, match_dist = find_match_word(word, pairs, scores, use_lexicon, weighted_ed, lexicon)
175
+ if match_dist<1.5 or lexicon_type==1:
176
+ gt_coor_strs = [str(a) for a in gt_coors]+ [match_word]
177
+ f.write(','.join(gt_coor_strs)+'\r\n')
178
+
179
+ pack_name = str(score_det)+'_'+str(score_rec)+'_over'+str(overlap)
180
+
181
+ packing(nms_dir,cache_dir,pack_name)
182
+ submit_file_path = os.path.join(cache_dir, pack_name+'.zip')
183
+ return submit_file_path
184
+
185
+ def find_match_word(rec_str, pairs, scores_numpy, use_ed=True, weighted_ed=False, lexicon=None):
186
+ if not use_ed:
187
+ return rec_str, 0.  # keep the (word, distance) contract of the other branches
188
+ rec_str = rec_str.upper()
189
+ dist_min = 100
190
+ dist_min_pre = 100
191
+ match_word = ''
192
+ match_dist = 100
193
+ if not weighted_ed:
194
+ for word in lexicon:
195
+ word = word.upper()
196
+ ed = editdistance.eval(rec_str, word)
197
+ length_dist = abs(len(word) - len(rec_str))
198
+ # dist = ed + length_dist
199
+ dist = ed
200
+ if dist<dist_min:
201
+ dist_min = dist
202
+ match_word = pairs[word]
203
+ match_dist = dist
204
+ return match_word, match_dist
205
+ else:
206
+ small_lexicon_dict = dict()
207
+ for word in lexicon:
208
+ word = word.upper()
209
+ ed = editdistance.eval(rec_str, word)
210
+ small_lexicon_dict[word] = ed
211
+ dist = ed
212
+ if dist<dist_min_pre:
213
+ dist_min_pre = dist
214
+ small_lexicon = []
215
+ for word in small_lexicon_dict:
216
+ if small_lexicon_dict[word]<=dist_min_pre+2:
217
+ small_lexicon.append(word)
218
+
219
+ for word in small_lexicon:
220
+ word = word.upper()
221
+ ed = weighted_edit_distance(rec_str, word, scores_numpy)
222
+ dist = ed
223
+ if dist<dist_min:
224
+ dist_min = dist
225
+ match_word = pairs[word]
226
+ match_dist = dist
227
+ return match_word, match_dist
228
+
229
+
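A tiny worked example of the unweighted lexicon-matching path; the pairs mapping (upper-cased lexicon word to its cased ground-truth form) is invented for illustration:

    lexicon = ['HOTEL', 'MOTEL']
    pairs = {'HOTEL': 'Hotel', 'MOTEL': 'Motel'}
    word, dist = find_match_word('H0TEL', pairs, None, use_ed=True,
                                 weighted_ed=False, lexicon=lexicon)
    print(word, dist)  # Hotel 1  (one substitution away)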
230
+ def prepare_results_for_evaluation(results_dir, use_lexicon, cache_dir, score_det, score_rec, score_rec_seq):
231
+ if not os.path.isdir(cache_dir):
232
+ os.mkdir(cache_dir)
233
+ result_path = test_single(results_dir,score_det=score_det,score_rec=score_rec,score_rec_seq=score_rec_seq,overlap=0.2,cache_dir=cache_dir,lexicon_type=2, use_lexicon=use_lexicon, weighted_ed=True, use_seq=True, use_char=True, mix=True)
234
+ return result_path
evaluation/totaltext/e2e/rrc_evaluation_funcs.py ADDED
@@ -0,0 +1,369 @@
1
+ #!/usr/bin/env python2
2
+ #encoding: UTF-8
3
+ import json
4
+ import sys;sys.path.append('./')
5
+ import zipfile
6
+ import re
7
+ import sys
8
+ import os
9
+ import codecs
10
+ import importlib
11
+ try:
12
+ from StringIO import StringIO
13
+ except ImportError:
14
+ from io import StringIO
15
+
16
+ def print_help():
17
+ sys.stdout.write('Usage: python %s.py -g=<gtFile> -s=<submFile> [-o=<outputFolder> -p=<jsonParams>]' %sys.argv[0])
18
+ sys.exit(2)
19
+
20
+
21
+ def load_zip_file_keys(file,fileNameRegExp=''):
22
+ """
23
+ Returns an array with the entries of the ZIP file that match with the regular expression.
24
+ The keys are the names of the file or the capturing group defined in the fileNameRegExp
25
+ """
26
+ try:
27
+ archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
28
+ except :
29
+ raise Exception('Error loading the ZIP archive.')
30
+
31
+ pairs = []
32
+
33
+ for name in archive.namelist():
34
+ addFile = True
35
+ keyName = name
36
+ if fileNameRegExp!="":
37
+ m = re.match(fileNameRegExp,name)
38
+ if m == None:
39
+ addFile = False
40
+ else:
41
+ if len(m.groups())>0:
42
+ keyName = m.group(1)
43
+
44
+ if addFile:
45
+ pairs.append( keyName )
46
+
47
+ return pairs
48
+
49
+
50
+ def load_zip_file(file,fileNameRegExp='',allEntries=False):
51
+ """
52
+ Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
53
+ The keys are the names of the file or the capturing group defined in the fileNameRegExp
54
+ allEntries validates that all entries in the ZIP file pass the fileNameRegExp
55
+ """
56
+ try:
57
+ archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
58
+ except :
59
+ raise Exception('Error loading the ZIP archive')
60
+
61
+ pairs = []
62
+ for name in archive.namelist():
63
+ addFile = True
64
+ keyName = name
65
+ if fileNameRegExp!="":
66
+ m = re.match(fileNameRegExp,name)
67
+ if m == None:
68
+ addFile = False
69
+ else:
70
+ if len(m.groups())>0:
71
+ keyName = m.group(1)
72
+
73
+ if addFile:
74
+ pairs.append( [ keyName , archive.read(name)] )
75
+ else:
76
+ if allEntries:
77
+ raise Exception('ZIP entry not valid: %s' %name)
78
+
79
+ return dict(pairs)
80
+
81
+ def decode_utf8(raw):
82
+ """
83
+ Returns a Unicode object on success, or None on failure
84
+ """
85
+ try:
86
+ raw = codecs.decode(raw,'utf-8', 'replace')
87
+ #extracts BOM if exists
88
+ raw = raw.encode('utf8')
89
+ if raw.startswith(codecs.BOM_UTF8):
90
+ raw = raw.replace(codecs.BOM_UTF8, '', 1)
91
+ return raw.decode('utf-8')
92
+ except:
93
+ return None
94
+
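A quick check of decode_utf8: a UTF-8 byte string with a leading BOM decodes to the same text with the BOM stripped (codecs is already imported above):

    raw = codecs.BOM_UTF8 + u'café'.encode('utf-8')
    print(decode_utf8(raw))  # café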
95
+ def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
96
+ """
97
+ This function validates that all lines of the file calling the Line validation function for each line
98
+ """
99
+ utf8File = decode_utf8(file_contents)
100
+ if (utf8File is None) :
101
+ raise Exception("The file %s is not UTF-8" %fileName)
102
+
103
+ lines = utf8File.split( "\r\n" if CRLF else "\n" )
104
+ for line in lines:
105
+ line = line.replace("\r","").replace("\n","")
106
+ if(line != ""):
107
+ try:
108
+ validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
109
+ except Exception as e:
110
+ raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace'))
111
+
112
+
113
+
114
+ def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0):
115
+ """
116
+ Validate the format of the line. If the line is not valid an exception will be raised.
117
+ If maxWidth and maxHeight are specified, all points must be inside the image bounds.
118
+ Possible values are:
119
+ LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
120
+ LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
121
+ """
122
+ get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
123
+
124
+
125
+ def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
126
+ """
127
+ Validate the format of the line. If the line is not valid an exception will be raised.
128
+ If maxWidth and maxHeight are specified, all points must be inside the image bounds.
129
+ Possible values are:
130
+ LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
131
+ LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
132
+ Returns values from a textline. Points , [Confidences], [Transcriptions]
133
+ """
134
+ confidence = 0.0
135
+ transcription = "";
136
+ points = []
137
+
138
+ numPoints = 4;
139
+
140
+ if LTRB:
141
+
142
+ numPoints = 4;
143
+
144
+ if withTranscription and withConfidence:
145
+ m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
146
+ if m == None :
148
+ raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
149
+ elif withConfidence:
150
+ m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
151
+ if m == None :
152
+ raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
153
+ elif withTranscription:
154
+ m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line)
155
+ if m == None :
156
+ raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
157
+ else:
158
+ m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line)
159
+ if m == None :
160
+ raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
161
+
162
+ xmin = int(m.group(1))
163
+ ymin = int(m.group(2))
164
+ xmax = int(m.group(3))
165
+ ymax = int(m.group(4))
166
+ if(xmax<xmin):
167
+ raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." %(xmax))
168
+ if(ymax<ymin):
169
+ raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." %(ymax))
170
+
171
+ points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
172
+
173
+ if (imWidth>0 and imHeight>0):
174
+ validate_point_inside_bounds(xmin,ymin,imWidth,imHeight);
175
+ validate_point_inside_bounds(xmax,ymax,imWidth,imHeight);
176
+
177
+ else:
178
+
179
+ numPoints = 8;
180
+
181
+ if withTranscription and withConfidence:
182
+ m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
183
+ if m == None :
184
+ raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
185
+ elif withConfidence:
186
+ m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
187
+ if m == None :
188
+ raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
189
+ elif withTranscription:
190
+ m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line)
191
+ if m == None :
192
+ raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
193
+ else:
194
+ m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line)
195
+ if m == None :
196
+ raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
197
+
198
+ points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
199
+
200
+ validate_clockwise_points(points)
201
+
202
+ if (imWidth>0 and imHeight>0):
203
+ validate_point_inside_bounds(points[0],points[1],imWidth,imHeight);
204
+ validate_point_inside_bounds(points[2],points[3],imWidth,imHeight);
205
+ validate_point_inside_bounds(points[4],points[5],imWidth,imHeight);
206
+ validate_point_inside_bounds(points[6],points[7],imWidth,imHeight);
207
+
208
+
209
+ if withConfidence:
210
+ try:
211
+ confidence = float(m.group(numPoints+1))
212
+ except ValueError:
213
+ raise Exception("Confidence value must be a float")
214
+
215
+ if withTranscription:
216
+ posTranscription = numPoints + (2 if withConfidence else 1)
217
+ transcription = m.group(posTranscription)
218
+ m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription)
219
+ if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters
220
+ transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
221
+
222
+ return points,confidence,transcription
223
+
224
+
225
+ def validate_point_inside_bounds(x,y,imWidth,imHeight):
226
+ if(x<0 or x>imWidth):
227
+ raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight))
228
+ if(y<0 or y>imHeight):
229
+ raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight))
230
+
231
+ def validate_clockwise_points(points):
232
+ """
233
+ Validates that the 4 points that delimit a polygon are in clockwise order.
234
+ """
235
+
236
+ if len(points) != 8:
237
+ raise Exception("Points list not valid." + str(len(points)))
238
+
239
+ point = [
240
+ [int(points[0]) , int(points[1])],
241
+ [int(points[2]) , int(points[3])],
242
+ [int(points[4]) , int(points[5])],
243
+ [int(points[6]) , int(points[7])]
244
+ ]
245
+ edge = [
246
+ ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]),
247
+ ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]),
248
+ ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]),
249
+ ( point[0][0] - point[3][0])*( point[0][1] + point[3][1])
250
+ ]
251
+
252
+ summatory = edge[0] + edge[1] + edge[2] + edge[3];
253
+ if summatory>0:
254
+ raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.")
255
+
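The check above is the shoelace formula: the edge sum is twice the signed polygon area, which is negative for clockwise order in image coordinates (y axis pointing down), so a positive sum raises. For example:

    validate_clockwise_points([0, 0, 10, 0, 10, 10, 0, 10])    # clockwise: passes
    # validate_clockwise_points([0, 0, 0, 10, 10, 10, 10, 0])  # counter-clockwise: raises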
256
+ def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True):
257
+ """
258
+ Returns all points, confidences and transcriptions of a file in lists. Valid line formats:
259
+ xmin,ymin,xmax,ymax,[confidence],[transcription]
260
+ x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
261
+ """
262
+ pointsList = []
263
+ transcriptionsList = []
264
+ confidencesList = []
265
+
266
+ lines = content.split( "\r\n" if CRLF else "\n" )
267
+ for line in lines:
268
+ line = line.replace("\r","").replace("\n","")
269
+ if(line != "") :
270
+ points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight);
271
+ pointsList.append(points)
272
+ transcriptionsList.append(transcription)
273
+ confidencesList.append(confidence)
274
+
275
+ if withConfidence and len(confidencesList)>0 and sort_by_confidences:
276
+ import numpy as np
277
+ sorted_ind = np.argsort(-np.array(confidencesList))
278
+ confidencesList = [confidencesList[i] for i in sorted_ind]
279
+ pointsList = [pointsList[i] for i in sorted_ind]
280
+ transcriptionsList = [transcriptionsList[i] for i in sorted_ind]
281
+
282
+ return pointsList,confidencesList,transcriptionsList
283
+
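Example of parsing a small two-line ground-truth file in the 4-point format (LTRB=False) with transcriptions and no confidences:

    content = u'0,0,10,0,10,10,0,10,word\n5,5,9,5,9,9,5,9,###'
    pts, confs, trans = get_tl_line_values_from_file_contents(
        content, CRLF=False, LTRB=False, withTranscription=True)
    print(trans)  # ['word', '###']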
284
+ def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True):
285
+ """
286
+ This process validates a method, evaluates it and, if it succeeds, generates a ZIP file with a JSON entry for each sample.
287
+ Params:
288
+ p: Dictionary of parameters with the GT/submission locations. If None is passed, the parameters sent by the system are used.
289
+ default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
290
+ validate_data_fn: points to a method that validates the correct format of the submission
291
+ evaluate_method_fn: points to a function that evaluates the submission and returns a Dictionary with the results
292
+ """
293
+
294
+ if (p == None):
295
+ p = dict([s[1:].split('=') for s in sys.argv[1:]])
296
+ if(len(sys.argv)<3):
297
+ print_help()
298
+
299
+ evalParams = default_evaluation_params_fn()
300
+ if 'p' in p.keys():
301
+ evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
302
+
303
+ resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'}
304
+ try:
305
+ validate_data_fn(p['g'], p['s'], evalParams)
306
+ evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
307
+ resDict.update(evalData)
308
+
309
+ except Exception as e:
310
+ resDict['Message']= str(e)
311
+ resDict['calculated']=False
312
+
313
+ if 'o' in p:
314
+ if not os.path.exists(p['o']):
315
+ os.makedirs(p['o'])
316
+
317
+ resultsOutputname = p['o'] + '/results.zip'
318
+ outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
319
+
320
+ del resDict['per_sample']
321
+ if 'output_items' in resDict.keys():
322
+ del resDict['output_items']
323
+
324
+ outZip.writestr('method.json',json.dumps(resDict))
325
+
326
+ if not resDict['calculated']:
327
+ if show_result:
328
+ sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n')
329
+ if 'o' in p:
330
+ outZip.close()
331
+ return resDict
332
+
333
+ if 'o' in p:
334
+ if per_sample == True:
335
+ for k,v in evalData['per_sample'].items():
336
+ outZip.writestr( k + '.json',json.dumps(v))
337
+
338
+ if 'output_items' in evalData.keys():
339
+ for k, v in evalData['output_items'].items():
340
+ outZip.writestr( k,v)
341
+
342
+ outZip.close()
343
+
344
+ if show_result:
345
+ sys.stdout.write("Calculated!")
346
+ sys.stdout.write(json.dumps(resDict['method']))
347
+
348
+ return resDict
349
+
350
+
351
+ def main_validation(default_evaluation_params_fn,validate_data_fn):
352
+ """
353
+ This process validates a method
354
+ Params:
355
+ default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
356
+ validate_data_fn: points to a method that validates the correct format of the submission
357
+ """
358
+ try:
359
+ p = dict([s[1:].split('=') for s in sys.argv[1:]])
360
+ evalParams = default_evaluation_params_fn()
361
+ if 'p' in p.keys():
362
+ evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
363
+
364
+ validate_data_fn(p['g'], p['s'], evalParams)
365
+ print('SUCCESS')
366
+ sys.exit(0)
367
+ except Exception as e:
368
+ print(str(e))
369
+ sys.exit(101)
evaluation/totaltext/e2e/rrc_evaluation_funcs_total_text.py ADDED
@@ -0,0 +1,363 @@
1
+ #!/usr/bin/env python
2
+ #encoding: UTF-8
3
+ import json
4
+ import sys;sys.path.append('./')
5
+ import zipfile
6
+ import re
7
+ import sys
8
+ import os
9
+ import codecs
10
+ import importlib
11
+ from io import StringIO
12
+
13
+ def print_help():
14
+ sys.stdout.write('Usage: python %s.py -g=<gtFile> -s=<submFile> -o=<outputFolder> [-i=<gtImagesFile> -p=<jsonParams>]' %sys.argv[0])
15
+ sys.exit(2)
16
+
17
+
18
+ def load_zip_file_keys(file,fileNameRegExp=''):
19
+ """
20
+ Returns an array with the entries of the ZIP file that match with the regular expression.
21
+ The keys are the names of the file or the capturing group defined in the fileNameRegExp
22
+ """
23
+ try:
24
+ archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
25
+ except :
26
+ raise Exception('Error loading the ZIP archive.')
27
+
28
+ pairs = []
29
+
30
+ for name in archive.namelist():
31
+ addFile = True
32
+ keyName = name
33
+ # if fileNameRegExp!="":
34
+ # m = re.match(fileNameRegExp,name)
35
+ # if m == None:
36
+ # addFile = False
37
+ # else:
38
+ # if len(m.groups())>0:
39
+ # keyName = m.group(1)
40
+
41
+ if addFile:
42
+ pairs.append( keyName )
43
+
44
+ return pairs
45
+
46
+
47
+ def load_zip_file(file,fileNameRegExp='',allEntries=False):
48
+ """
49
+ Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
50
+ The keys are the names of the file or the capturing group defined in the fileNameRegExp
51
+ allEntries validates that all entries in the ZIP file pass the fileNameRegExp
52
+ """
53
+ try:
54
+ archive=zipfile.ZipFile(file, mode='r', allowZip64=True)
55
+ except :
56
+ raise Exception('Error loading the ZIP archive')
57
+
58
+ pairs = []
59
+ for name in archive.namelist():
60
+ addFile = True
61
+ keyName = name
62
+ # if fileNameRegExp!="":
63
+ # m = re.match(fileNameRegExp,name)
64
+ # if m == None:
65
+ # addFile = False
66
+ # else:
67
+ # if len(m.groups())>0:
68
+ # keyName = m.group(1)
69
+
70
+ if addFile:
71
+ pairs.append( [ keyName , archive.read(name)] )
72
+ else:
73
+ if allEntries:
74
+ raise Exception('ZIP entry not valid: %s' %name)
75
+
76
+ return dict(pairs)
77
+
78
+ def decode_utf8(raw):
79
+ """
80
+ Returns a Unicode object on success, or None on failure
81
+ """
82
+ try:
83
+ raw = codecs.decode(raw,'utf-8', 'replace')
84
+ #extracts BOM if exists
85
+ raw = raw.encode('utf8')
86
+ if raw.startswith(codecs.BOM_UTF8):
87
+ raw = raw.replace(codecs.BOM_UTF8, '', 1)
88
+ return raw.decode('utf-8')
89
+ except:
90
+ return None
91
+
92
+ def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
93
+ """
94
+ This function validates all lines of the file, calling the line validation function on each line
95
+ """
96
+ utf8File = decode_utf8(file_contents)
97
+ if (utf8File is None) :
98
+ raise Exception("The file %s is not UTF-8" %fileName)
99
+
100
+ lines = utf8File.split( "\r\n" if CRLF else "\n" )
101
+ for line in lines:
102
+ line = line.replace("\r","").replace("\n","")
103
+ if(line != ""):
104
+ try:
105
+ validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
106
+ except Exception as e:
107
+ raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace'))
108
+
109
+
110
+
111
+ def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0):
112
+ """
113
+ Validate the format of the line. If the line is not valid an exception will be raised.
114
+ If maxWidth and maxHeight are specified, all points must be inside the image bounds.
115
+ Possible values are:
116
+ LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
117
+ LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
118
+ """
119
+ get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight)
120
+
121
+
122
+ def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0):
123
+ """
124
+ Validate the format of the line. If the line is not valid an exception will be raised.
125
+ If maxWidth and maxHeight are specified, all points must be inside the image bounds.
126
+ Possible values are:
127
+ LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
128
+ LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
129
+ Returns values from a textline. Points , [Confidences], [Transcriptions]
130
+ """
131
+ confidence = 0.0
132
+ transcription = "";
133
+ points = []
134
+
135
+ numPoints = 4;
136
+ if LTRB:
137
+
138
+ numPoints = 4;
139
+
140
+ if withTranscription and withConfidence:
141
+ m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
142
+ if m == None :
144
+ raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
145
+ elif withConfidence:
146
+ m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
147
+ if m == None :
148
+ raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
149
+ elif withTranscription:
150
+ m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line)
151
+ if m == None :
152
+ raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
153
+ else:
154
+ m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line)
155
+ if m == None :
156
+ raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
157
+
158
+ xmin = int(m.group(1))
159
+ ymin = int(m.group(2))
160
+ xmax = int(m.group(3))
161
+ ymax = int(m.group(4))
162
+ if(xmax<xmin):
163
+ raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." %(xmax))
164
+ if(ymax<ymin):
165
+ raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." %(ymax))
166
+
167
+ points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
168
+
169
+ if (imWidth>0 and imHeight>0):
170
+ validate_point_inside_bounds(xmin,ymin,imWidth,imHeight);
171
+ validate_point_inside_bounds(xmax,ymax,imWidth,imHeight);
172
+
173
+ else:
174
+ line_split = line.split(',')
175
+ # print(line_split)
176
+ numPoints = int((len(line_split) - 1) / 2)
177
+ points = [ float(line_split[i]) for i in range(2 * numPoints) ]
178
+ # print(points)
179
+ transcription = line_split[-1]
180
+ # numPoints = 8;
181
+
182
+ # if withTranscription and withConfidence:
183
+ # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line)
184
+ # if m == None :
185
+ # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
186
+ # elif withConfidence:
187
+ # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line)
188
+ # if m == None :
189
+ # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
190
+ # elif withTranscription:
191
+ # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line)
192
+ # if m == None :
193
+ # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
194
+ # else:
195
+ # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line)
196
+ # if m == None :
197
+ # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
198
+
199
+ # points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ]
200
+
201
+ # validate_clockwise_points(points)
202
+
203
+ # if (imWidth>0 and imHeight>0):
204
+ # validate_point_inside_bounds(points[0],points[1],imWidth,imHeight);
205
+ # validate_point_inside_bounds(points[2],points[3],imWidth,imHeight);
206
+ # validate_point_inside_bounds(points[4],points[5],imWidth,imHeight);
207
+ # validate_point_inside_bounds(points[6],points[7],imWidth,imHeight);
208
+
209
+
210
+ # if withConfidence:
211
+ # try:
212
+ # confidence = float(m.group(numPoints+1))
213
+ # except ValueError:
214
+ # raise Exception("Confidence value must be a float")
215
+
216
+ # if withTranscription:
217
+ # posTranscription = numPoints + (2 if withConfidence else 1)
218
+ # transcription = m.group(posTranscription)
219
+ # m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription)
220
+ # if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters
221
+ # transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
222
+
223
+ return points,confidence,transcription
224
+
225
+
226
+ def validate_point_inside_bounds(x,y,imWidth,imHeight):
227
+ if(x<0 or x>imWidth):
228
+ raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight))
229
+ if(y<0 or y>imHeight):
230
+ raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight))
231
+
232
+ def validate_clockwise_points(points):
233
+ """
234
+ Validates that the 4 points that delimit a polygon are in clockwise order.
235
+ """
236
+
237
+ if len(points) != 8:
238
+ raise Exception("Points list not valid." + str(len(points)))
239
+
240
+ point = [
241
+ [int(points[0]) , int(points[1])],
242
+ [int(points[2]) , int(points[3])],
243
+ [int(points[4]) , int(points[5])],
244
+ [int(points[6]) , int(points[7])]
245
+ ]
246
+ edge = [
247
+ ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]),
248
+ ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]),
249
+ ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]),
250
+ ( point[0][0] - point[3][0])*( point[0][1] + point[3][1])
251
+ ]
252
+
253
+ summatory = edge[0] + edge[1] + edge[2] + edge[3];
254
+ if summatory>0:
255
+ raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.")
256
+
257
+ def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True):
258
+ """
259
+ Returns all points, confidences and transcriptions of a file in lists. Valid line formats:
260
+ xmin,ymin,xmax,ymax,[confidence],[transcription]
261
+ x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
262
+ """
263
+ pointsList = []
264
+ transcriptionsList = []
265
+ confidencesList = []
266
+
267
+ lines = content.split( "\r\n" if CRLF else "\n" )
268
+ for line in lines:
269
+ line = line.replace("\r","").replace("\n","")
270
+ if(line != "") :
271
+ points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight);
272
+ pointsList.append(points)
273
+ transcriptionsList.append(transcription)
274
+ confidencesList.append(confidence)
275
+
276
+ if withConfidence and len(confidencesList)>0 and sort_by_confidences:
277
+ confidencesList, pointsList,transcriptionsList = (list(t) for t in zip(*sorted(zip(confidencesList, pointsList, transcriptionsList), reverse=True)))
278
+
279
+ return pointsList,confidencesList,transcriptionsList
280
+
281
+ def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True):
282
+ """
283
+ This process validates a method, evaluates it and, if it succeeds, generates a ZIP file with a JSON entry for each sample.
284
+ Params:
285
+ p: Dictionary of parameters with the GT/submission locations. If None is passed, the parameters sent by the system are used.
286
+ default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
287
+ validate_data_fn: points to a method that validates the correct format of the submission
288
+ evaluate_method_fn: points to a function that evaluates the submission and returns a Dictionary with the results
289
+ """
290
+
291
+ if (p == None):
292
+ p = dict([s[1:].split('=') for s in sys.argv[1:]])
293
+ if(len(sys.argv)<2):
294
+ print_help()
295
+
296
+ evalParams = default_evaluation_params_fn()
297
+ if 'p' in list(p.keys()):
298
+ evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
299
+
300
+ resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'}
301
+ try:
302
+ validate_data_fn(p['g'], p['s'], evalParams)
303
+ evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
304
+ resDict.update(evalData)
305
+
306
+ except Exception as e:
307
+ resDict['Message']= str(e)
308
+ resDict['calculated']=False
309
+
310
+ if not os.path.exists(p['o']):
311
+ os.makedirs(p['o'])
312
+
313
+ resultsOutputname = p['o'] + '/results.zip'
314
+ outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
315
+
316
+ del resDict['per_sample']
317
+ if 'output_items' in list(resDict.keys()):
318
+ del resDict['output_items']
319
+
320
+ outZip.writestr('method.json',json.dumps(resDict))
321
+
322
+ if not resDict['calculated']:
323
+ if show_result:
324
+ sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n')
325
+ outZip.close()
326
+ return resDict
327
+
328
+ if per_sample == True:
329
+ for k,v in evalData['per_sample'].items():
330
+ outZip.writestr( k + '.json',json.dumps(v))
331
+
332
+ if 'output_items' in list(evalData.keys()):
333
+ for k, v in evalData['output_items'].items():
334
+ outZip.writestr( k,v)
335
+
336
+ outZip.close()
337
+
338
+ if show_result:
339
+ sys.stdout.write("Calculated!")
340
+ sys.stdout.write(json.dumps(resDict['method']))
341
+
342
+ return resDict
343
+
344
+
345
+ def main_validation(default_evaluation_params_fn,validate_data_fn):
346
+ """
347
+ This process validates a method
348
+ Params:
349
+ default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
350
+ validate_data_fn: points to a method that validates the correct format of the submission
351
+ """
352
+ try:
353
+ p = dict([s[1:].split('=') for s in sys.argv[1:]])
354
+ evalParams = default_evaluation_params_fn()
355
+ if 'p' in list(p.keys()):
356
+ evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) )
357
+
358
+ validate_data_fn(p['g'], p['s'], evalParams)
359
+ print('SUCCESS')
360
+ sys.exit(0)
361
+ except Exception as e:
362
+ print(str(e))
363
+ sys.exit(101)
evaluation/totaltext/e2e/script.py ADDED
@@ -0,0 +1,452 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # encoding=utf8
4
+ from collections import namedtuple
5
+ import rrc_evaluation_funcs_total_text as rrc_evaluation_funcs
6
+ import importlib
7
+ from prepare_results import prepare_results_for_evaluation
8
+
9
+ def evaluation_imports():
10
+ """
11
+ evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation.
12
+ """
13
+ return {
14
+ 'Polygon':'plg',
15
+ 'numpy':'np'
16
+ }
17
+
18
+ def default_evaluation_params():
19
+ """
20
+ default_evaluation_params: Default parameters to use for the validation and evaluation.
21
+ """
22
+ return {
23
+ 'IOU_CONSTRAINT' :0.5,
24
+ 'AREA_PRECISION_CONSTRAINT' :0.5,
25
+ 'WORD_SPOTTING' :False,
26
+ 'MIN_LENGTH_CARE_WORD' :3,
27
+ 'GT_SAMPLE_NAME_2_ID':'gt_img_([0-9]+).txt',
28
+ 'DET_SAMPLE_NAME_2_ID':'res_img_([0-9]+).txt',
29
+ 'LTRB':False, #LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4)
30
+ 'CRLF':False, # Lines are delimited by Windows CRLF format
31
+ 'CONFIDENCES':False, #Detections must include confidence value. MAP and MAR will be calculated,
32
+ 'SPECIAL_CHARACTERS':'!?.:,*"()·[]/\'',
33
+ 'ONLY_REMOVE_FIRST_LAST_CHARACTER' : True
34
+ }
35
+
36
+ def validate_data(gtFilePath, submFilePath, evaluationParams):
37
+ """
38
+ Method validate_data: validates that all files in the results folder are correct (have the correct name contents).
39
+ Validates also that there are no missing files in the folder.
40
+ If some error detected, the method raises the error
41
+ """
42
+ gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
43
+
44
+ subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
45
+
46
+ #Validate format of GroundTruth
47
+ for k in gt:
48
+ rrc_evaluation_funcs.validate_lines_in_file(k,gt[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True)
49
+
50
+ #Validate format of results
51
+ for k in subm:
52
+ if (k in gt) == False :
53
+ raise Exception("The sample %s not present in GT" %k)
54
+
55
+ rrc_evaluation_funcs.validate_lines_in_file(k,subm[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES'])
56
+
57
+
58
+ def evaluate_method(gtFilePath, submFilePath, evaluationParams):
59
+ """
60
+ Method evaluate_method: evaluate method and returns the results
61
+ Results. Dictionary with the following values:
62
+ - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
63
+ - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 }
64
+ """
65
+ for module,alias in evaluation_imports().items():
66
+ globals()[alias] = importlib.import_module(module)
67
+
68
+ def polygon_from_points(points,correctOffset=False):
69
+ """
70
+ Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
71
+ """
72
+ resBoxes=np.empty([1,len(points)],dtype='int32')
73
+ for i in range(int(len(points) / 2)):
74
+ resBoxes[0, i] = int(points[2*i])
75
+ resBoxes[0, int(len(points) / 2) + i] = int(points[2*i+1])
76
+
77
+ pointMat = resBoxes[0].reshape([2,-1]).T
78
+ return plg.Polygon( pointMat)
79
+
80
+ def rectangle_to_polygon(rect):
81
+ resBoxes=np.empty([1,8],dtype='int32')
82
+ resBoxes[0,0]=int(rect.xmin)
83
+ resBoxes[0,4]=int(rect.ymax)
84
+ resBoxes[0,1]=int(rect.xmin)
85
+ resBoxes[0,5]=int(rect.ymin)
86
+ resBoxes[0,2]=int(rect.xmax)
87
+ resBoxes[0,6]=int(rect.ymin)
88
+ resBoxes[0,3]=int(rect.xmax)
89
+ resBoxes[0,7]=int(rect.ymax)
90
+
91
+ pointMat = resBoxes[0].reshape([2,4]).T
92
+
93
+ return plg.Polygon( pointMat)
94
+
95
+ def rectangle_to_points(rect):
96
+ points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)]
97
+ return points
98
+
99
+ def get_union(pD,pG):
100
+ areaA = pD.area();
101
+ areaB = pG.area();
102
+ return areaA + areaB - get_intersection(pD, pG);
103
+
104
+ def get_intersection_over_union(pD,pG):
105
+ try:
106
+ return get_intersection(pD, pG) / get_union(pD, pG);
107
+ except:
108
+ return 0
109
+
110
+ def get_intersection(pD,pG):
111
+ pInt = pD & pG
112
+ if len(pInt) == 0:
113
+ return 0
114
+ return pInt.area()
115
+
116
+ def compute_ap(confList, matchList,numGtCare):
117
+ correct = 0
118
+ AP = 0
119
+ if len(confList)>0:
120
+ confList = np.array(confList)
121
+ matchList = np.array(matchList)
122
+ sorted_ind = np.argsort(-confList)
123
+ confList = confList[sorted_ind]
124
+ matchList = matchList[sorted_ind]
125
+ for n in range(len(confList)):
126
+ match = matchList[n]
127
+ if match:
128
+ correct += 1
129
+ AP += float(correct)/(n + 1)
130
+
131
+ if numGtCare>0:
132
+ AP /= numGtCare
133
+
134
+ return AP
135
+
136
+ def transcription_match(transGt,transDet,specialCharacters='!?.:,*"()·[]/\'',onlyRemoveFirstLastCharacterGT=True):
137
+
138
+ if onlyRemoveFirstLastCharacterGT:
139
+ #special characters in GT are allowed only at initial or final position
140
+ if (transGt==transDet):
141
+ return True
142
+
143
+ if specialCharacters.find(transGt[0])>-1:
144
+ if transGt[1:]==transDet:
145
+ return True
146
+
147
+ if specialCharacters.find(transGt[-1])>-1:
148
+ if transGt[0:len(transGt)-1]==transDet:
149
+ return True
150
+
151
+ if specialCharacters.find(transGt[0])>-1 and specialCharacters.find(transGt[-1])>-1:
152
+ if transGt[1:len(transGt)-1]==transDet:
153
+ return True
154
+ return False
155
+ else:
156
+ #Special characters are removed from the begining and the end of both Detection and GroundTruth
157
+ while len(transGt)>0 and specialCharacters.find(transGt[0])>-1:
158
+ transGt = transGt[1:]
159
+
160
+ while len(transDet)>0 and specialCharacters.find(transDet[0])>-1:
161
+ transDet = transDet[1:]
162
+
163
+ while len(transGt)>0 and specialCharacters.find(transGt[-1])>-1 :
164
+ transGt = transGt[0:len(transGt)-1]
165
+
166
+ while len(transDet)>0 and specialCharacters.find(transDet[-1])>-1:
167
+ transDet = transDet[0:len(transDet)-1]
168
+
169
+ return transGt == transDet
170
+
171
+
172
+ def include_in_dictionary(transcription):
173
+ """
174
+ Function used in Word Spotting that finds if the Ground Truth transcription meets the rules to enter into the dictionary. If not, the transcription will be cared as don't care
175
+ """
176
+ #special case 's at final
177
+ if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S":
178
+ transcription = transcription[0:len(transcription)-2]
179
+
180
+ #hypens at init or final of the word
181
+ transcription = transcription.strip('-');
182
+
183
+ specialCharacters = "'!?.:,*\"()·[]/";
184
+ for character in specialCharacters:
185
+ transcription = transcription.replace(character,' ')
186
+
187
+ transcription = transcription.strip()
188
+
189
+ if len(transcription) != len(transcription.replace(" ","")) :
190
+ return False;
191
+
192
+ if len(transcription) < evaluationParams['MIN_LENGTH_CARE_WORD']:
193
+ return False;
194
+
195
+ notAllowed = "×÷·";
196
+
197
+ range1 = [ ord(u'a'), ord(u'z') ]
198
+ range2 = [ ord(u'A'), ord(u'Z') ]
199
+ range3 = [ ord(u'À'), ord(u'ƿ') ]
200
+ range4 = [ ord(u'DŽ'), ord(u'ɿ') ]
201
+ range5 = [ ord(u'Ά'), ord(u'Ͽ') ]
202
+ range6 = [ ord(u'-'), ord(u'-') ]
203
+
204
+ for char in transcription :
205
+ charCode = ord(char)
206
+ if(notAllowed.find(char) != -1):
207
+ return False
208
+
209
+ valid = ( charCode>=range1[0] and charCode<=range1[1] ) or ( charCode>=range2[0] and charCode<=range2[1] ) or ( charCode>=range3[0] and charCode<=range3[1] ) or ( charCode>=range4[0] and charCode<=range4[1] ) or ( charCode>=range5[0] and charCode<=range5[1] ) or ( charCode>=range6[0] and charCode<=range6[1] )
210
+ if valid == False:
211
+ return False
212
+
213
+ return True
214
+
215
+ def include_in_dictionary_transcription(transcription):
216
+ """
217
+ Function applied to the Ground Truth transcriptions used in Word Spotting. It removes special characters or terminations
218
+ """
219
+ #special case 's at final
220
+ if transcription[len(transcription)-2:]=="'s" or transcription[len(transcription)-2:]=="'S":
221
+ transcription = transcription[0:len(transcription)-2]
222
+
223
+ #hypens at init or final of the word
224
+ transcription = transcription.strip('-');
225
+
226
+ specialCharacters = "'!?.:,*\"()·[]/";
227
+ for character in specialCharacters:
228
+ transcription = transcription.replace(character,' ')
229
+
230
+ transcription = transcription.strip()
231
+
232
+ return transcription
233
+
234
+ perSampleMetrics = {}
235
+
236
+ matchedSum = 0
237
+
238
+ Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
239
+
240
+ gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID'])
241
+ subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True)
242
+
243
+ numGlobalCareGt = 0;
244
+ numGlobalCareDet = 0;
245
+
246
+ arrGlobalConfidences = [];
247
+ arrGlobalMatches = [];
248
+
249
+ for resFile in gt:
250
+
251
+ gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile])
252
+ if (gtFile is None) :
253
+ raise Exception("The file %s is not UTF-8" %resFile)
254
+
255
+ recall = 0
256
+ precision = 0
257
+ hmean = 0
258
+ detCorrect = 0
259
+ iouMat = np.empty([1,1])
260
+ gtPols = []
261
+ detPols = []
262
+ gtTrans = []
263
+ detTrans = []
264
+ gtPolPoints = []
265
+ detPolPoints = []
266
+ gtDontCarePolsNum = [] #Array of Ground Truth Polygons' keys marked as don't Care
267
+ detDontCarePolsNum = [] #Array of Detected Polygons' matched with a don't Care GT
268
+ detMatchedNums = []
269
+ pairs = []
270
+
271
+ arrSampleConfidences = [];
272
+ arrSampleMatch = [];
273
+ sampleAP = 0;
274
+
275
+ evaluationLog = ""
276
+
277
+ pointsList,_,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,False)
278
+ for n in range(len(pointsList)):
279
+ points = pointsList[n]
280
+ transcription = transcriptionsList[n]
281
+ dontCare = transcription == "###"
282
+ if evaluationParams['LTRB']:
283
+ gtRect = Rectangle(*points)
284
+ gtPol = rectangle_to_polygon(gtRect)
285
+ else:
286
+ gtPol = polygon_from_points(points)
287
+ gtPols.append(gtPol)
288
+ gtPolPoints.append(points)
289
+
290
+ # For word spotting, filter out ground-truth transcriptions containing special characters
291
+ if evaluationParams['WORD_SPOTTING'] :
292
+ if dontCare == False :
293
+ if include_in_dictionary(transcription) == False :
294
+ dontCare = True
295
+ else:
296
+ transcription = include_in_dictionary_transcription(transcription)
297
+
298
+ gtTrans.append(transcription)
299
+ if dontCare:
300
+ gtDontCarePolsNum.append( len(gtPols)-1 )
301
+
302
+ evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum)>0 else "\n")
303
+
304
+ if resFile in subm:
305
+
306
+ detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile])
307
+
308
+ pointsList,confidencesList,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,evaluationParams['CONFIDENCES'])
309
+
310
+ for n in range(len(pointsList)):
311
+ points = pointsList[n]
312
+ transcription = transcriptionsList[n]
313
+
314
+ if evaluationParams['LTRB']:
315
+ detRect = Rectangle(*points)
316
+ detPol = rectangle_to_polygon(detRect)
317
+ else:
318
+ detPol = polygon_from_points(points)
319
+ detPols.append(detPol)
320
+ detPolPoints.append(points)
321
+ detTrans.append(transcription)
322
+
323
+ if len(gtDontCarePolsNum)>0 :
324
+ for dontCarePol in gtDontCarePolsNum:
325
+ dontCarePol = gtPols[dontCarePol]
326
+ intersected_area = get_intersection(dontCarePol,detPol)
327
+ pdDimensions = detPol.area()
328
+ precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
329
+ if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ):
330
+ detDontCarePolsNum.append( len(detPols)-1 )
331
+ break
332
+
333
+ evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum)>0 else "\n")
334
+
335
+ if len(gtPols)>0 and len(detPols)>0:
336
+ # Calculate the IoU and precision matrices
337
+ outputShape=[len(gtPols),len(detPols)]
338
+ iouMat = np.empty(outputShape)
339
+ gtRectMat = np.zeros(len(gtPols),np.int8)
340
+ detRectMat = np.zeros(len(detPols),np.int8)
341
+ for gtNum in range(len(gtPols)):
342
+ for detNum in range(len(detPols)):
343
+ pG = gtPols[gtNum]
344
+ pD = detPols[detNum]
345
+ iouMat[gtNum,detNum] = get_intersection_over_union(pD,pG)
346
+
347
+ for gtNum in range(len(gtPols)):
348
+ for detNum in range(len(detPols)):
349
+ if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum :
350
+ if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']:
351
+ gtRectMat[gtNum] = 1
352
+ detRectMat[detNum] = 1
353
+ # a detection is matched only if its transcription is equal
354
+ if evaluationParams['WORD_SPOTTING']:
355
+ correct = gtTrans[gtNum].upper() == detTrans[detNum].upper()
356
+ else:
357
+ correct = transcription_match(gtTrans[gtNum].upper(),detTrans[detNum].upper(),evaluationParams['SPECIAL_CHARACTERS'],evaluationParams['ONLY_REMOVE_FIRST_LAST_CHARACTER'])==True
358
+ detCorrect += (1 if correct else 0)
359
+ if correct:
360
+ detMatchedNums.append(detNum)
361
+ pairs.append({'gt':gtNum,'det':detNum,'correct':correct})
362
+ evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + " trans. correct: " + str(correct) + "\n"
363
+
364
+ if evaluationParams['CONFIDENCES']:
365
+ for detNum in range(len(detPols)):
366
+ if detNum not in detDontCarePolsNum :
367
+ #we exclude the don't care detections
368
+ match = detNum in detMatchedNums
369
+
370
+ arrSampleConfidences.append(confidencesList[detNum])
371
+ arrSampleMatch.append(match)
372
+
373
+ arrGlobalConfidences.append(confidencesList[detNum]);
374
+ arrGlobalMatches.append(match);
375
+
376
+ numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
377
+ numDetCare = (len(detPols) - len(detDontCarePolsNum))
378
+ if numGtCare == 0:
379
+ recall = float(1)
380
+ precision = float(0) if numDetCare >0 else float(1)
381
+ sampleAP = precision
382
+ else:
383
+ recall = float(detCorrect) / numGtCare
384
+ precision = 0 if numDetCare==0 else float(detCorrect) / numDetCare
385
+ if evaluationParams['CONFIDENCES']:
386
+ sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare )
387
+
388
+ hmean = 0 if (precision + recall)==0 else 2.0 * precision * recall / (precision + recall)
389
+
390
+ matchedSum += detCorrect
391
+ numGlobalCareGt += numGtCare
392
+ numGlobalCareDet += numDetCare
393
+
394
+ perSampleMetrics[resFile] = {
395
+ 'precision':precision,
396
+ 'recall':recall,
397
+ 'hmean':hmean,
398
+ 'pairs':pairs,
399
+ 'AP':sampleAP,
400
+ 'iouMat':[] if len(detPols)>100 else iouMat.tolist(),
401
+ 'gtPolPoints':gtPolPoints,
402
+ 'detPolPoints':detPolPoints,
403
+ 'gtTrans':gtTrans,
404
+ 'detTrans':detTrans,
405
+ 'gtDontCare':gtDontCarePolsNum,
406
+ 'detDontCare':detDontCarePolsNum,
407
+ 'evaluationParams': evaluationParams,
408
+ 'evaluationLog': evaluationLog
409
+ }
410
+
411
+ # Compute AP
412
+ AP = 0
413
+ if evaluationParams['CONFIDENCES']:
414
+ AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
415
+
416
+ methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum)/numGlobalCareGt
417
+ methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum)/numGlobalCareDet
418
+ methodHmean = 0 if methodRecall + methodPrecision==0 else 2* methodRecall * methodPrecision / (methodRecall + methodPrecision)
419
+
420
+ methodMetrics = {'precision':methodPrecision, 'recall':methodRecall,'hmean': methodHmean, 'AP': AP }
421
+
422
+ resDict = {'calculated':True,'Message':'','method': methodMetrics,'per_sample': perSampleMetrics}
423
+
424
+
425
+ return resDict;
426
+
427
+
428
+
429
+ if __name__=='__main__':
430
+ '''
431
+ results_dir: directory containing the raw inference results
432
+ score_det: confidence threshold for the detection bounding boxes
433
+ score_rec: confidence threshold for the mask recognition branch
434
+ score_rec_seq: confidence threshold for the sequence recognition branch
435
+ lexicon_type: 1 for generic; 2 for weak; 3 for strong
436
+ '''
437
+ results_dir = '../../../output/mixtrain/inference/total_text_test/model_0250000_1000_results/'
438
+ score_det = 0.05
439
+ score_rec = 0.5
440
+ use_lexicon = False
441
+ score_rec_seq = 0.9
442
+ # use_lexicon = True
443
+ # score_rec_seq = 0.8
444
+ evaluate_result_path = prepare_results_for_evaluation(results_dir,
445
+ use_lexicon=use_lexicon, cache_dir='./cache_files',
446
+ score_det=score_det, score_rec=score_rec, score_rec_seq=score_rec_seq)
447
+ p = {
448
+ 'g': "../gt.zip",
449
+ 'o': "./cache_files",
450
+ 's': evaluate_result_path
451
+ }
452
+ rrc_evaluation_funcs.main_evaluation(p,default_evaluation_params,validate_data,evaluate_method)
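As a sanity check on the aggregation above, a toy computation of the corpus-level metrics (the counts are made up):

matched_sum, num_gt_care, num_det_care = 850, 1000, 950
recall = matched_sum / num_gt_care                     # 0.85
precision = matched_sum / num_det_care                 # ~0.8947
hmean = 2 * precision * recall / (precision + recall)  # ~0.8718
print(round(recall, 4), round(precision, 4), round(hmean, 4))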
evaluation/totaltext/gt.zip ADDED
Binary file (106 kB)
evaluation/weighted_editdistance.py ADDED
@@ -0,0 +1,55 @@
1
+ def weighted_edit_distance(word1, word2, scores):
2
+ m = len(word1)
3
+ n = len(word2)
4
+ dp = [[0 for __ in range(m + 1)] for __ in range(n + 1)]
5
+ for j in range(m + 1):
6
+ dp[0][j] = j
7
+ for i in range(n + 1):
8
+ dp[i][0] = i
9
+ for i in range(1, n + 1): ## word2
10
+ for j in range(1, m + 1): ## word1
11
+ delete_cost = ed_delete_cost(j-1, i-1, word1, word2, scores) ## cost of deleting word1[j-1]
12
+ insert_cost = ed_insert_cost(j-1, i-1, word1, word2, scores) ## insert b[j]
13
+ if word1[j - 1] != word2[i - 1]:
14
+ replace_cost = ed_replace_cost(j-1, i-1, word1, word2, scores) ## replace a[i] with b[j]
15
+ else:
16
+ replace_cost = 0
17
+ dp[i][j] = min(dp[i-1][j] + insert_cost, dp[i][j-1] + delete_cost, dp[i-1][j-1] + replace_cost)
18
+
19
+ return dp[n][m]
20
+
21
+ def ed_delete_cost(j, i, word1, word2, scores):
22
+ ## cost of deleting word1[j]
23
+ c = char2num(word1[j])
24
+ return scores[c][j]
25
+
26
+
27
+ def ed_insert_cost(i, j, word1, word2, scores):
28
+ ## insert b[j]
29
+ if i < len(word1) - 1:
30
+ c1 = char2num(word1[i])
31
+ c2 = char2num(word1[i + 1])
32
+ return (scores[c1][i] + scores[c2][i+1])/2
33
+ else:
34
+ c1 = char2num(word1[i])
35
+ return scores[c1][i]
36
+
37
+
38
+ def ed_replace_cost(i, j, word1, word2, scores):
39
+ ## replace a[i] with b[j]
40
+ c1 = char2num(word1[i])
41
+ c2 = char2num(word2[j])
42
+ # if word1 == "eeatpisaababarait".upper():
43
+ # print(scores[c2][i]/scores[c1][i])
44
+
45
+ return max(1 - scores[c2][i]/scores[c1][i]*5, 0)
46
+
47
+ def char2num(char):
48
+ if char in '0123456789':
49
+ num = ord(char) - ord('0') + 1
50
+ elif char in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ':
51
+ num = ord(char.lower()) - ord('a') + 11
52
+ else:
53
+ raise ValueError('unsupported character: {}'.format(char))
55
+ return num - 1
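A minimal usage sketch for this module (assuming it is importable as weighted_editdistance; the scores array follows the 36-class indexing of char2num, one column per character of word1):

import numpy as np
from weighted_editdistance import weighted_edit_distance, char2num

rec_word = "TAST"  # hypothetical recognized string
scores = np.full((36, len(rec_word)), 0.01)  # 10 digits + 26 letters
for j, ch in enumerate(rec_word):
    scores[char2num(ch)][j] = 0.9  # peak on the recognized character

# pick the lexicon entry with the smallest weighted edit distance
lexicon = ["TEST", "TASTE", "TOAST"]
best = min(lexicon, key=lambda w: weighted_edit_distance(rec_word, w, scores))
print(best)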
example1.jpg ADDED
example2.jpg ADDED
example3.jpg ADDED
maskrcnn_benchmark/config/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ from .defaults import _C as cfg
maskrcnn_benchmark/config/defaults.py ADDED
@@ -0,0 +1,373 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3
+ import os
4
+
5
+ from yacs.config import CfgNode as CN
6
+
7
+
8
+ # -----------------------------------------------------------------------------
9
+ # Convention about Training / Test specific parameters
10
+ # -----------------------------------------------------------------------------
11
+ # Whenever an argument can be either used for training or for testing, the
12
+ # corresponding name will be post-fixed by a _TRAIN for a training parameter,
13
+ # or _TEST for a test-specific parameter.
14
+ # For example, the number of images during training will be
15
+ # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
16
+ # IMAGES_PER_BATCH_TEST
17
+
18
+ # -----------------------------------------------------------------------------
19
+ # Config definition
20
+ # -----------------------------------------------------------------------------
21
+
22
+ _C = CN()
23
+
24
+ _C.MODEL = CN()
25
+ _C.MODEL.RPN_ONLY = False
26
+ _C.MODEL.MASK_ON = False
27
+ _C.MODEL.SEG_ON = False
28
+ _C.MODEL.CHAR_MASK_ON = False
29
+ _C.MODEL.DEVICE = "cuda"
30
+ _C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"
31
+ _C.MODEL.TRAIN_DETECTION_ONLY = False
32
+ _C.MODEL.RESNET34 = False
33
+
34
+ # If the WEIGHT starts with a catalog://, like :R-50, the code will look for
35
+ # the path in paths_catalog. Else, it will use it as the specified absolute
36
+ # path
37
+ _C.MODEL.WEIGHT = ""
38
+
39
+ _C.SEQUENCE = CN()
40
+ _C.SEQUENCE.SEQ_ON = False
41
+ _C.SEQUENCE.NUM_CHAR = 38
42
+ _C.SEQUENCE.BOS_TOKEN = 0
43
+ _C.SEQUENCE.MAX_LENGTH = 32
44
+ _C.SEQUENCE.TEACHER_FORCE_RATIO = 1.0
45
+ _C.SEQUENCE.TWO_CONV = False
46
+ _C.SEQUENCE.MEAN_SCORE = False
47
+ _C.SEQUENCE.RESIZE_HEIGHT = 16
48
+ _C.SEQUENCE.RESIZE_WIDTH = 64
49
+
50
+
51
+ # -----------------------------------------------------------------------------
52
+ # INPUT
53
+ # -----------------------------------------------------------------------------
54
+ _C.INPUT = CN()
55
+ # Size of the smallest side of the image during training
56
+ _C.INPUT.MIN_SIZE_TRAIN = (800,) # (800,)
57
+ # Maximum size of the side of the image during training
58
+ _C.INPUT.MAX_SIZE_TRAIN = 1333
59
+ # Size of the smallest side of the image during testing
60
+ _C.INPUT.MIN_SIZE_TEST = 800
61
+ # Maximum size of the side of the image during testing
62
+ _C.INPUT.MAX_SIZE_TEST = 1333
63
+ # Values to be used for image normalization
64
+ _C.INPUT.PIXEL_MEAN = [102.9801, 115.9465, 122.7717]
65
+ # Values to be used for image normalization
66
+ _C.INPUT.PIXEL_STD = [1.0, 1.0, 1.0]
67
+ # Convert image to BGR format (for Caffe2 models), in range 0-255
68
+ _C.INPUT.TO_BGR255 = True
69
+ _C.INPUT.STRICT_RESIZE = False
70
+
71
+
72
+ # -----------------------------------------------------------------------------
73
+ # Dataset
74
+ # -----------------------------------------------------------------------------
75
+ _C.DATASETS = CN()
76
+ # List of the dataset names for training, as present in paths_catalog.py
77
+ _C.DATASETS.TRAIN = ()
78
+ # List of the dataset names for testing, as present in paths_catalog.py
79
+ _C.DATASETS.TEST = ()
80
+
81
+ _C.DATASETS.RATIOS = []
82
+
83
+ _C.DATASETS.AUG = False
84
+ _C.DATASETS.RANDOM_CROP_PROB = 0.0
85
+ _C.DATASETS.IGNORE_DIFFICULT = False
86
+ _C.DATASETS.FIX_CROP = False
87
+ _C.DATASETS.CROP_SIZE = (512, 512)
88
+ _C.DATASETS.MAX_ROTATE_THETA = 30
89
+ _C.DATASETS.FIX_ROTATE = False
90
+
91
+ # -----------------------------------------------------------------------------
92
+ # DataLoader
93
+ # -----------------------------------------------------------------------------
94
+ _C.DATALOADER = CN()
95
+ # Number of data loading threads
96
+ _C.DATALOADER.NUM_WORKERS = 4
97
+ # If > 0, this enforces that each collated batch should have a size divisible
98
+ # by SIZE_DIVISIBILITY
99
+ _C.DATALOADER.SIZE_DIVISIBILITY = 0
100
+ # If True, each batch should contain only images for which the aspect ratio
101
+ # is compatible. This groups portrait images together, and landscape images
102
+ # are not batched with portrait images.
103
+ _C.DATALOADER.ASPECT_RATIO_GROUPING = True
104
+
105
+ # ---------------------------------------------------------------------------- #
106
+ # Backbone options
107
+ # ---------------------------------------------------------------------------- #
108
+ _C.MODEL.BACKBONE = CN()
109
+
110
+ # The backbone conv body to use
111
+ # The string must match a function that is imported in modeling.model_builder
112
+ # (e.g., 'FPN.add_fpn_ResNet101_conv5_body' to specify a ResNet-101-FPN
113
+ # backbone)
114
+ _C.MODEL.BACKBONE.CONV_BODY = "R-50-C4"
115
+
116
+ # Add StopGrad at a specified stage so the bottom layers are frozen
117
+ _C.MODEL.BACKBONE.FREEZE_CONV_BODY_AT = 2
118
+ _C.MODEL.BACKBONE.OUT_CHANNELS = 256 * 4
119
+
120
+ # ---------------------------------------------------------------------------- #
121
+ # ResNe[X]t options (ResNets = {ResNet, ResNeXt})
122
+ # Note that parts of a resnet may be used for both the backbone and the head
123
+ # These options apply to both
124
+ # ---------------------------------------------------------------------------- #
125
+ _C.MODEL.RESNETS = CN()
126
+
127
+ # Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt
128
+ _C.MODEL.RESNETS.NUM_GROUPS = 1
129
+
130
+ # Baseline width of each group
131
+ _C.MODEL.RESNETS.WIDTH_PER_GROUP = 64
132
+
133
+ # Place the stride 2 conv on the 1x1 filter
134
+ # Use True only for the original MSRA ResNet; use False for C2 and Torch models
135
+ _C.MODEL.RESNETS.STRIDE_IN_1X1 = True
136
+
137
+ # Residual transformation function
138
+ _C.MODEL.RESNETS.TRANS_FUNC = "BottleneckWithFixedBatchNorm"
139
+ # ResNet's stem function (conv1 and pool1)
140
+ _C.MODEL.RESNETS.STEM_FUNC = "StemWithFixedBatchNorm"
141
+
142
+ # Apply dilation in stage "res5"
143
+ _C.MODEL.RESNETS.RES5_DILATION = 1
144
+
145
+ _C.MODEL.RESNETS.BACKBONE_OUT_CHANNELS = 256 * 4
146
+ _C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
147
+ _C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
148
+
149
+ _C.MODEL.RESNETS.STAGE_WITH_DCN = (False, False, False, False)
150
+ _C.MODEL.RESNETS.WITH_MODULATED_DCN = False
151
+ _C.MODEL.RESNETS.DEFORMABLE_GROUPS = 1
152
+ _C.MODEL.RESNETS.LAYERS = (3, 4, 6, 3)
153
+
154
+ # ---------------------------------------------------------------------------- #
155
+ # FPN options
156
+ # ---------------------------------------------------------------------------- #
157
+ _C.MODEL.FPN = CN()
158
+ _C.MODEL.FPN.USE_GN = False
159
+ _C.MODEL.FPN.USE_RELU = False
160
+
161
+ # ---------------------------------------------------------------------------- #
162
+ # RPN options
163
+ # ---------------------------------------------------------------------------- #
164
+ _C.MODEL.RPN = CN()
165
+ _C.MODEL.RPN.USE_FPN = False
166
+ # Base RPN anchor sizes given in absolute pixels w.r.t. the scaled network input
167
+ _C.MODEL.RPN.ANCHOR_SIZES = (32, 64, 128, 256, 512)
168
+ # Stride of the feature map that RPN is attached to.
169
+ # For FPN, number of strides should match number of scales
170
+ _C.MODEL.RPN.ANCHOR_STRIDE = (16,)
171
+ # RPN anchor aspect ratios
172
+ _C.MODEL.RPN.ASPECT_RATIOS = (0.5, 1.0, 2.0)
173
+ # Remove RPN anchors that go outside the image by RPN_STRADDLE_THRESH pixels
174
+ # Set to -1 or a large value, e.g. 100000, to disable pruning anchors
175
+ _C.MODEL.RPN.STRADDLE_THRESH = 0
176
+ # Minimum overlap required between an anchor and ground-truth box for the
177
+ # (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD
178
+ # ==> positive RPN example)
179
+ _C.MODEL.RPN.FG_IOU_THRESHOLD = 0.7
180
+ # Maximum overlap allowed between an anchor and ground-truth box for the
181
+ # (anchor, gt box) pair to be a negative examples (IoU < BG_IOU_THRESHOLD
182
+ # ==> negative RPN example)
183
+ _C.MODEL.RPN.BG_IOU_THRESHOLD = 0.3
184
+ # Total number of RPN examples per image
185
+ _C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256
186
+ # Target fraction of foreground (positive) examples per RPN minibatch
187
+ _C.MODEL.RPN.POSITIVE_FRACTION = 0.5
188
+ # Number of top scoring RPN proposals to keep before applying NMS
189
+ # When FPN is used, this is *per FPN level* (not total)
190
+ _C.MODEL.RPN.PRE_NMS_TOP_N_TRAIN = 12000
191
+ _C.MODEL.RPN.PRE_NMS_TOP_N_TEST = 6000
192
+ # Number of top scoring RPN proposals to keep after applying NMS
193
+ _C.MODEL.RPN.POST_NMS_TOP_N_TRAIN = 2000
194
+ _C.MODEL.RPN.POST_NMS_TOP_N_TEST = 1000
195
+ # NMS threshold used on RPN proposals
196
+ _C.MODEL.RPN.NMS_THRESH = 0.7
197
+ # Proposal height and width both need to be greater than RPN_MIN_SIZE
198
+ # (at the scale used during training or inference)
199
+ _C.MODEL.RPN.MIN_SIZE = 0
200
+ # Number of top scoring RPN proposals to keep after combining proposals from
201
+ # all FPN levels
202
+ _C.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN = 2000
203
+ _C.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 2000
204
+
205
+ _C.MODEL.SEG = CN()
206
+ _C.MODEL.SEG.USE_FPN = False
207
+ _C.MODEL.SEG.USE_FUSE_FEATURE = False
208
+ # Total number of SEG examples per image
209
+ _C.MODEL.SEG.BATCH_SIZE_PER_IMAGE = 256
210
+ # Target fraction of foreground (positive) examples per SEG minibatch
211
+ _C.MODEL.SEG.POSITIVE_FRACTION = 0.5
212
+ # NMS threshold used on SEG proposals
213
+ _C.MODEL.SEG.BINARY_THRESH = 0.5
214
+ _C.MODEL.SEG.USE_MULTIPLE_THRESH = False
215
+ _C.MODEL.SEG.MULTIPLE_THRESH = (0.2, 0.3, 0.5, 0.7)
216
+ _C.MODEL.SEG.BOX_THRESH = 0.7
217
+ # Proposal height and width both need to be greater than RPN_MIN_SIZE
218
+ # (at the scale used during training or inference)
219
+ _C.MODEL.SEG.MIN_SIZE = 0
220
+ _C.MODEL.SEG.SHRINK_RATIO = 0.5
221
+ # Number of top scoring RPN proposals to keep after combining proposals from
222
+ # all FPN levels
223
+ _C.MODEL.SEG.TOP_N_TRAIN = 1000
224
+ _C.MODEL.SEG.TOP_N_TEST = 1000
225
+ _C.MODEL.SEG.AUG_PROPOSALS = False
226
+ _C.MODEL.SEG.IGNORE_DIFFICULT = True
227
+ _C.MODEL.SEG.EXPAND_RATIO = 1.6
228
+ _C.MODEL.SEG.BOX_EXPAND_RATIO = 1.5
229
+ _C.MODEL.SEG.USE_SEG_POLY = False
230
+ _C.MODEL.SEG.USE_PPM = False
231
+
232
+
233
+ # ---------------------------------------------------------------------------- #
234
+ # ROI HEADS options
235
+ # ---------------------------------------------------------------------------- #
236
+ _C.MODEL.ROI_HEADS = CN()
237
+ _C.MODEL.ROI_HEADS.USE_FPN = False
238
+ # Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD)
239
+ _C.MODEL.ROI_HEADS.FG_IOU_THRESHOLD = 0.5
240
+ # Overlap threshold for an RoI to be considered background
241
+ # (class = 0 if overlap in [0, BG_IOU_THRESHOLD))
242
+ _C.MODEL.ROI_HEADS.BG_IOU_THRESHOLD = 0.5
243
+ # Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets
244
+ # These are empirically chosen to approximately lead to unit variance targets
245
+ _C.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0)
246
+ # RoI minibatch size *per image* (number of regions of interest [ROIs])
247
+ # Total number of RoIs per training minibatch =
248
+ # TRAIN.BATCH_SIZE_PER_IM * TRAIN.IMS_PER_BATCH * NUM_GPUS
249
+ # E.g., a common configuration is: 512 * 2 * 8 = 8192
250
+ _C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
251
+ # Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0)
252
+ _C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25
253
+
254
+ # Only used on test mode
255
+
256
+ # Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to
257
+ # balance obtaining high recall with not having too many low precision
258
+ # detections that will slow down inference post processing steps (like NMS)
259
+ # _C.MODEL.ROI_HEADS.SCORE_THRESH = 0.05
260
+ _C.MODEL.ROI_HEADS.SCORE_THRESH = 0.0
261
+ # Overlap threshold used for non-maximum suppression (suppress boxes with
262
+ # IoU >= this threshold)
263
+ _C.MODEL.ROI_HEADS.NMS = 0.5
264
+ # Maximum number of detections to return per image (100 is based on the limit
265
+ # established for the COCO dataset)
266
+ _C.MODEL.ROI_HEADS.DETECTIONS_PER_IMG = 100
267
+
268
+
269
+ _C.MODEL.ROI_BOX_HEAD = CN()
270
+ _C.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR = "ResNet50Conv5ROIFeatureExtractor"
271
+ _C.MODEL.ROI_BOX_HEAD.PREDICTOR = "FastRCNNPredictor"
272
+ _C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14
273
+ _C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0
274
+ _C.MODEL.ROI_BOX_HEAD.POOLER_SCALES = (1.0 / 16,)
275
+ _C.MODEL.ROI_BOX_HEAD.NUM_CLASSES = 81
276
+ # Hidden layer dimension when using an MLP for the RoI box head
277
+ _C.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM = 1024
278
+ _C.MODEL.ROI_BOX_HEAD.USE_REGRESSION = True
279
+ _C.MODEL.ROI_BOX_HEAD.INFERENCE_USE_BOX = True
280
+ _C.MODEL.ROI_BOX_HEAD.USE_MASKED_FEATURE = False
281
+ _C.MODEL.ROI_BOX_HEAD.SOFT_MASKED_FEATURE_RATIO = 0.
282
+ _C.MODEL.ROI_BOX_HEAD.MIX_OPTION = ""
283
+
284
+
285
+ _C.MODEL.ROI_MASK_HEAD = CN()
286
+ _C.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR = "ResNet50Conv5ROIFeatureExtractor"
287
+ _C.MODEL.ROI_MASK_HEAD.PREDICTOR = "MaskRCNNC4Predictor"
288
+ _C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14
289
+ _C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_H = 32
290
+ _C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION_W = 128
291
+ _C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0
292
+ _C.MODEL.ROI_MASK_HEAD.POOLER_SCALES = (1.0 / 16,)
293
+ _C.MODEL.ROI_MASK_HEAD.MLP_HEAD_DIM = 1024
294
+ _C.MODEL.ROI_MASK_HEAD.CONV_LAYERS = (256, 256, 256, 256)
295
+ _C.MODEL.ROI_MASK_HEAD.RESOLUTION = 14
296
+ _C.MODEL.ROI_MASK_HEAD.RESOLUTION_H = 32
297
+ _C.MODEL.ROI_MASK_HEAD.RESOLUTION_W = 128
298
+ _C.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True
299
+ _C.MODEL.ROI_MASK_HEAD.CHAR_NUM_CLASSES = 38
300
+ _C.MODEL.ROI_MASK_HEAD.USE_WEIGHTED_CHAR_MASK = False
301
+ _C.MODEL.ROI_MASK_HEAD.MASK_BATCH_SIZE_PER_IM = 64
302
+ _C.MODEL.ROI_MASK_HEAD.USE_MASKED_FEATURE = False
303
+ _C.MODEL.ROI_MASK_HEAD.SOFT_MASKED_FEATURE_RATIO = 0.
304
+ _C.MODEL.ROI_MASK_HEAD.MIX_OPTION = ""
305
+
306
+ # ---------------------------------------------------------------------------- #
307
+ # Solver
308
+ # ---------------------------------------------------------------------------- #
309
+ _C.SOLVER = CN()
310
+ _C.SOLVER.MAX_ITER = 40000
311
+
312
+ _C.SOLVER.BASE_LR = 0.001
313
+ _C.SOLVER.BIAS_LR_FACTOR = 2
314
+
315
+ _C.SOLVER.MOMENTUM = 0.9
316
+
317
+ _C.SOLVER.WEIGHT_DECAY = 0.0005
318
+ _C.SOLVER.WEIGHT_DECAY_BIAS = 0
319
+
320
+ _C.SOLVER.GAMMA = 0.1
321
+ _C.SOLVER.STEPS = (30000,)
322
+
323
+ _C.SOLVER.WARMUP_FACTOR = 1.0 / 3
324
+ _C.SOLVER.WARMUP_ITERS = 500
325
+ _C.SOLVER.WARMUP_METHOD = "linear"
326
+
327
+ _C.SOLVER.CHECKPOINT_PERIOD = 5000
328
+
329
+ # Number of images per batch
330
+ # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
331
+ # see 2 images per batch
332
+ _C.SOLVER.IMS_PER_BATCH = 16
333
+
334
+ _C.SOLVER.RESUME = True
335
+
336
+ _C.SOLVER.USE_ADAM = False
337
+
338
+ _C.SOLVER.POW_SCHEDULE = False
339
+
340
+ _C.SOLVER.DISPLAY_FREQ = 20
341
+
342
+ # ---------------------------------------------------------------------------- #
343
+ # Specific test options
344
+ # ---------------------------------------------------------------------------- #
345
+ _C.TEST = CN()
346
+ _C.TEST.EXPECTED_RESULTS = []
347
+ _C.TEST.EXPECTED_RESULTS_SIGMA_TOL = 4
348
+ # Number of images per batch
349
+ # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
350
+ # see 2 images per batch
351
+ _C.TEST.IMS_PER_BATCH = 8
352
+ _C.TEST.VIS = False
353
+ # from 0 to 255
354
+ _C.TEST.CHAR_THRESH = 128
355
+
356
+
357
+ # ---------------------------------------------------------------------------- #
358
+ # Misc options
359
+ # ---------------------------------------------------------------------------- #
360
+ _C.OUTPUT_DIR = "."
361
+
362
+ _C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py")
363
+
364
+
365
+ # ---------------------------------------------------------------------------- #
366
+ # Precision options
367
+ # ---------------------------------------------------------------------------- #
368
+
369
+ # Precision of input, allowable: (float32, float16)
370
+ _C.DTYPE = "float32"
371
+
372
+ # Enable verbosity in apex.amp
373
+ _C.AMP_VERBOSE = False
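A minimal sketch of how this yacs-based configuration is consumed (the YAML path is one of the configs shipped in this repo; the overridden keys are only examples):

from maskrcnn_benchmark.config import cfg

cfg.merge_from_file("configs/pretrain/seg_rec_poly_fuse_feature.yaml")
cfg.merge_from_list(["SOLVER.IMS_PER_BATCH", 8, "SEQUENCE.SEQ_ON", True])
cfg.freeze()  # make the configuration immutable before building the model
print(cfg.MODEL.SEG_ON, cfg.SOLVER.BASE_LR)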
maskrcnn_benchmark/config/paths_catalog.py ADDED
@@ -0,0 +1,237 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ """Centralized catalog of paths."""
3
+
4
+ import os
5
+
6
+
7
+ class DatasetCatalog(object):
8
+ DATA_DIR = "datasets"
9
+ # DATA_DIR = "/share/mhliao/MaskTextSpotterV3/datasets/"
10
+
11
+ DATASETS = {
12
+ "coco_2014_train": (
13
+ "coco/train2014",
14
+ "coco/annotations/instances_train2014.json",
15
+ ),
16
+ "coco_2014_val": ("coco/val2014", "coco/annotations/instances_val2014.json"),
17
+ "coco_2014_minival": (
18
+ "coco/val2014",
19
+ "coco/annotations/instances_minival2014.json",
20
+ ),
21
+ "coco_2014_valminusminival": (
22
+ "coco/val2014",
23
+ "coco/annotations/instances_valminusminival2014.json",
24
+ ),
25
+ "icdar_2013_train": ("icdar2013/train_images", "icdar2013/train_gts"),
26
+ "icdar_2013_test": ("icdar2013/test_images", "icdar2013/test_gts"),
27
+ "rotated_ic13_test_0": ("icdar2013/rotated_test_images_0", "icdar2013/rotated_test_gts_0"),
28
+ "rotated_ic13_test_15": ("icdar2013/rotated_test_images_15", "icdar2013/rotated_test_gts_15"),
29
+ "rotated_ic13_test_30": ("icdar2013/rotated_test_images_30", "icdar2013/rotated_test_gts_30"),
30
+ "rotated_ic13_test_45": ("icdar2013/rotated_test_images_45", "icdar2013/rotated_test_gts_45"),
31
+ "rotated_ic13_test_60": ("icdar2013/rotated_test_images_60", "icdar2013/rotated_test_gts_60"),
32
+ "rotated_ic13_test_75": ("icdar2013/rotated_test_images_75", "icdar2013/rotated_test_gts_75"),
33
+ "rotated_ic13_test_85": ("icdar2013/rotated_test_images_85", "icdar2013/rotated_test_gts_85"),
34
+ "rotated_ic13_test_90": ("icdar2013/rotated_test_images_90", "icdar2013/rotated_test_gts_90"),
35
+ "rotated_ic13_test_-15": ("icdar2013/rotated_test_images_-15", "icdar2013/rotated_test_gts_-15"),
36
+ "rotated_ic13_test_-30": ("icdar2013/rotated_test_images_-30", "icdar2013/rotated_test_gts_-30"),
37
+ "rotated_ic13_test_-45": ("icdar2013/rotated_test_images_-45", "icdar2013/rotated_test_gts_-45"),
38
+ "rotated_ic13_test_-60": ("icdar2013/rotated_test_images_-60", "icdar2013/rotated_test_gts_-60"),
39
+ "rotated_ic13_test_-75": ("icdar2013/rotated_test_images_-75", "icdar2013/rotated_test_gts_-75"),
40
+ "rotated_ic13_test_-90": ("icdar2013/rotated_test_images_-90", "icdar2013/rotated_test_gts_-90"),
41
+ "icdar_2015_train": ("icdar2015/train_images", "icdar2015/train_gts"),
42
+ "icdar_2015_test": (
43
+ "icdar2015/test_images",
44
+ # "icdar2015/test_gts",
45
+ ),
46
+ "synthtext_train": ("synthtext/train_images", "synthtext/train_gts"),
47
+ "synthtext_test": ("synthtext/test_images", "synthtext/test_gts"),
48
+ "total_text_train": ("total_text/train_images", "total_text/train_gts"),
49
+ "td500_train": ("TD_TR/TD500/train_images", "TD500/train_gts"),
50
+ "td500_test": ("TD_TR/TD500/test_images", ),
51
+ "tr400_train": ("TD_TR/TR400/train_images", "TR400/train_gts"),
52
+ "total_text_test": (
53
+ "total_text/test_images",
54
+ # "total_text/test_gts",
55
+ ),
56
+ "scut-eng-char_train": (
57
+ "scut-eng-char/train_images",
58
+ "scut-eng-char/train_gts",
59
+ ),
60
+ }
61
+
62
+ @staticmethod
63
+ def get(name):
64
+ if "coco" in name:
65
+ data_dir = DatasetCatalog.DATA_DIR
66
+ attrs = DatasetCatalog.DATASETS[name]
67
+ args = dict(
68
+ root=os.path.join(data_dir, attrs[0]),
69
+ ann_file=os.path.join(data_dir, attrs[1]),
70
+ )
71
+ return dict(factory="COCODataset", args=args)
72
+ elif "icdar_2013" in name:
73
+ data_dir = DatasetCatalog.DATA_DIR
74
+ attrs = DatasetCatalog.DATASETS[name]
75
+ args = dict(
76
+ use_charann=True,
77
+ imgs_dir=os.path.join(data_dir, attrs[0]),
78
+ gts_dir=os.path.join(data_dir, attrs[1]),
79
+ # imgs_dir='/tmp/icdar2013/icdar2013/train_images',
80
+ # gts_dir='/tmp/icdar2013/icdar2013/train_gts',
81
+ )
82
+ return dict(args=args, factory="IcdarDataset")
83
+ elif "rotated_ic13" in name:
84
+ data_dir = DatasetCatalog.DATA_DIR
85
+ attrs = DatasetCatalog.DATASETS[name]
86
+ args = dict(
87
+ use_charann=True,
88
+ imgs_dir=os.path.join(data_dir, attrs[0]),
89
+ gts_dir=os.path.join(data_dir, attrs[1]),
90
+ )
91
+ return dict(args=args, factory="IcdarDataset")
92
+ elif "icdar_2015" in name:
93
+ data_dir = DatasetCatalog.DATA_DIR
94
+ attrs = DatasetCatalog.DATASETS[name]
95
+ if len(attrs) > 1:
96
+ gts_dir = os.path.join(data_dir, attrs[1])
97
+ else:
98
+ gts_dir = None
99
+
100
+ args = dict(
101
+ use_charann=False,
102
+ imgs_dir=os.path.join(data_dir, attrs[0]),
103
+ gts_dir=gts_dir,
104
+ # imgs_dir='/tmp/icdar2015/icdar2015/train_images/',
105
+ # gts_dir='/tmp/icdar2015/icdar2015/train_gts/',
106
+ )
107
+ return dict(args=args, factory="IcdarDataset")
108
+ elif "synthtext" in name:
109
+ data_dir = DatasetCatalog.DATA_DIR
110
+ attrs = DatasetCatalog.DATASETS[name]
111
+ args = dict(
112
+ use_charann=True,
113
+ list_file_path=os.path.join(data_dir, "synthtext/train_list.txt"),
114
+ imgs_dir=os.path.join(data_dir, attrs[0]),
115
+ gts_dir=os.path.join(data_dir, attrs[1]),
116
+ # imgs_dir='/tmp/synth/SynthText/',
117
+ # gts_dir='/tmp/synth_gt/SynthText_GT_E2E/',
118
+ )
119
+ return dict(args=args, factory="SynthtextDataset")
120
+ elif "total_text" in name:
121
+ data_dir = DatasetCatalog.DATA_DIR
122
+ # data_dir = '/tmp/total_text/'
123
+ attrs = DatasetCatalog.DATASETS[name]
124
+ if len(attrs) > 1:
125
+ gts_dir = os.path.join(data_dir, attrs[1])
126
+ else:
127
+ gts_dir = None
128
+ args = dict(
129
+ use_charann=False,
130
+ imgs_dir=os.path.join(data_dir, attrs[0]),
131
+ gts_dir=gts_dir,
132
+ # imgs_dir='/tmp/total_text/total_text/train_images/',
133
+ # gts_dir='/tmp/total_text/total_text/train_gts/',
134
+ )
135
+ return dict(args=args, factory="TotaltextDataset")
136
+ elif "scut-eng-char" in name:
137
+ data_dir = DatasetCatalog.DATA_DIR
138
+ attrs = DatasetCatalog.DATASETS[name]
139
+ args = dict(
140
+ use_charann=True,
141
+ imgs_dir=os.path.join(data_dir, attrs[0]),
142
+ gts_dir=os.path.join(data_dir, attrs[1]),
143
+ # imgs_dir='/tmp/scut-eng-char/scut-eng-char/train_images/',
144
+ # gts_dir='/tmp/scut-eng-char/scut-eng-char/train_gts/',
145
+ )
146
+ return dict(args=args, factory="ScutDataset")
147
+ elif "td500" in name:
148
+ data_dir = DatasetCatalog.DATA_DIR
149
+ attrs = DatasetCatalog.DATASETS[name]
150
+ if len(attrs) > 1:
151
+ gts_dir = os.path.join(data_dir, attrs[1])
152
+ else:
153
+ gts_dir = None
154
+ args = dict(
155
+ use_charann=False,
156
+ imgs_dir=os.path.join(data_dir, attrs[0]),
157
+ gts_dir=gts_dir,
158
+ )
159
+ return dict(args=args, factory="TotaltextDataset")
160
+ elif "tr400" in name:
161
+ data_dir = DatasetCatalog.DATA_DIR
162
+ attrs = DatasetCatalog.DATASETS[name]
163
+ if len(attrs) > 1:
164
+ gts_dir = os.path.join(data_dir, attrs[1])
165
+ else:
166
+ gts_dir = None
167
+ args = dict(
168
+ use_charann=False,
169
+ imgs_dir=os.path.join(data_dir, attrs[0]),
170
+ gts_dir=gts_dir,
171
+ )
172
+ return dict(args=args, factory="TotaltextDataset")
173
+ raise RuntimeError("Dataset not available: {}".format(name))
174
+
175
+
176
+ class ModelCatalog(object):
177
+ S3_C2_DETECTRON_URL = "https://dl.fbaipublicfiles.com/detectron"
178
+ C2_IMAGENET_MODELS = {
179
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
180
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
181
+ "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl",
182
+ "MSRA/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl",
183
+ "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl",
184
+ "MSRA/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl",
185
+ "FAIR/20171220/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl",
186
+ }
187
+
188
+ C2_DETECTRON_SUFFIX = "output/train/{}coco_2014_train%3A{}coco_2014_valminusminival/generalized_rcnn/model_final.pkl"
189
+ C2_DETECTRON_MODELS = {
190
+ "35857197/e2e_faster_rcnn_R-50-C4_1x": "01_33_49.iAX0mXvW",
191
+ "35857345/e2e_faster_rcnn_R-50-FPN_1x": "01_36_30.cUF7QR7I",
192
+ "35857890/e2e_faster_rcnn_R-101-FPN_1x": "01_38_50.sNxI7sX7",
193
+ "36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "06_31_39.5MIHi1fZ",
194
+ "35858791/e2e_mask_rcnn_R-50-C4_1x": "01_45_57.ZgkA7hPB",
195
+ "35858933/e2e_mask_rcnn_R-50-FPN_1x": "01_48_14.DzEQe4wC",
196
+ "35861795/e2e_mask_rcnn_R-101-FPN_1x": "02_31_37.KqyEK4tT",
197
+ "36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "06_35_59.RZotkLKI",
198
+ "37129812/e2e_mask_rcnn_X-152-32x8d-FPN-IN5k_1.44x": "09_35_36.8pzTQKYK",
199
+ # keypoints
200
+ "37697547/e2e_keypoint_rcnn_R-50-FPN_1x": "08_42_54.kdzV35ao"
201
+ }
202
+
203
+ @staticmethod
204
+ def get(name):
205
+ if name.startswith("Caffe2Detectron/COCO"):
206
+ return ModelCatalog.get_c2_detectron_12_2017_baselines(name)
207
+ if name.startswith("ImageNetPretrained"):
208
+ return ModelCatalog.get_c2_imagenet_pretrained(name)
209
+ raise RuntimeError("model not present in the catalog {}".format(name))
210
+
211
+ @staticmethod
212
+ def get_c2_imagenet_pretrained(name):
213
+ prefix = ModelCatalog.S3_C2_DETECTRON_URL
214
+ name = name[len("ImageNetPretrained/") :]
215
+ name = ModelCatalog.C2_IMAGENET_MODELS[name]
216
+ if 'resnet34' in name or 'resnet18' in name:
217
+ return name
218
+ url = "/".join([prefix, name])
219
+ return url
220
+
221
+ @staticmethod
222
+ def get_c2_detectron_12_2017_baselines(name):
223
+ # Detectron C2 models are stored following the structure
224
+ # prefix/<model_id>/2012_2017_baselines/<model_name>.yaml.<signature>/suffix
225
+ # we use as identifiers in the catalog Caffe2Detectron/COCO/<model_id>/<model_name>
226
+ prefix = ModelCatalog.S3_C2_DETECTRON_URL
227
+ suffix = ModelCatalog.C2_DETECTRON_SUFFIX
228
+ # remove identification prefix
229
+ name = name[len("Caffe2Detectron/COCO/") :]
230
+ # split in <model_id> and <model_name>
231
+ model_id, model_name = name.split("/")
232
+ # parsing to make it match the url address from the Caffe2 models
233
+ model_name = "{}.yaml".format(model_name)
234
+ signature = ModelCatalog.C2_DETECTRON_MODELS[name]
235
+ unique_name = ".".join([model_name, signature])
236
+ url = "/".join([prefix, model_id, "12_2017_baselines", unique_name, suffix])
237
+ return url
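A quick sketch of how the catalog resolves a dataset name into constructor arguments:

from maskrcnn_benchmark.config.paths_catalog import DatasetCatalog

spec = DatasetCatalog.get("icdar_2015_train")
print(spec["factory"])              # IcdarDataset
print(spec["args"]["imgs_dir"])     # datasets/icdar2015/train_images
print(spec["args"]["use_charann"])  # False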
maskrcnn_benchmark/csrc/ROIAlign.h ADDED
@@ -0,0 +1,46 @@
1
+ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ #pragma once
3
+
4
+ #include "cpu/vision.h"
5
+
6
+ #ifdef WITH_CUDA
7
+ #include "cuda/vision.h"
8
+ #endif
9
+
10
+ // Interface for Python
11
+ at::Tensor ROIAlign_forward(const at::Tensor& input,
12
+ const at::Tensor& rois,
13
+ const float spatial_scale,
14
+ const int pooled_height,
15
+ const int pooled_width,
16
+ const int sampling_ratio) {
17
+ if (input.type().is_cuda()) {
18
+ #ifdef WITH_CUDA
19
+ return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
20
+ #else
21
+ AT_ERROR("Not compiled with GPU support");
22
+ #endif
23
+ }
24
+ return ROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
25
+ }
26
+
27
+ at::Tensor ROIAlign_backward(const at::Tensor& grad,
28
+ const at::Tensor& rois,
29
+ const float spatial_scale,
30
+ const int pooled_height,
31
+ const int pooled_width,
32
+ const int batch_size,
33
+ const int channels,
34
+ const int height,
35
+ const int width,
36
+ const int sampling_ratio) {
37
+ if (grad.type().is_cuda()) {
38
+ #ifdef WITH_CUDA
39
+ return ROIAlign_backward_cuda(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
40
+ #else
41
+ AT_ERROR("Not compiled with GPU support");
42
+ #endif
43
+ }
44
+ AT_ERROR("Not implemented on the CPU");
45
+ }
46
+
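The kernels behind this header expect rois as an N x 5 tensor whose first column is the batch index. A sketch of the equivalent call semantics using torchvision.ops.roi_align as a stand-in (not this repo's own binding; aligned=False matches this legacy kernel's coordinate convention):

import torch
from torchvision.ops import roi_align

features = torch.randn(2, 256, 50, 50)          # (batch, channels, H, W)
rois = torch.tensor([[0., 10., 10., 80., 60.],  # batch_idx, x1, y1, x2, y2
                     [1., 20., 30., 90., 90.]])
pooled = roi_align(features, rois, output_size=(14, 14),
                   spatial_scale=1.0 / 16, sampling_ratio=2, aligned=False)
print(pooled.shape)  # torch.Size([2, 256, 14, 14])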
maskrcnn_benchmark/csrc/ROIPool.h ADDED
@@ -0,0 +1,48 @@
1
+ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ #pragma once
3
+
4
+ #include "cpu/vision.h"
5
+
6
+ #ifdef WITH_CUDA
7
+ #include "cuda/vision.h"
8
+ #endif
9
+
10
+
11
+ std::tuple<at::Tensor, at::Tensor> ROIPool_forward(const at::Tensor& input,
12
+ const at::Tensor& rois,
13
+ const float spatial_scale,
14
+ const int pooled_height,
15
+ const int pooled_width) {
16
+ if (input.type().is_cuda()) {
17
+ #ifdef WITH_CUDA
18
+ return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width);
19
+ #else
20
+ AT_ERROR("Not compiled with GPU support");
21
+ #endif
22
+ }
23
+ AT_ERROR("Not implemented on the CPU");
24
+ }
25
+
26
+ at::Tensor ROIPool_backward(const at::Tensor& grad,
27
+ const at::Tensor& input,
28
+ const at::Tensor& rois,
29
+ const at::Tensor& argmax,
30
+ const float spatial_scale,
31
+ const int pooled_height,
32
+ const int pooled_width,
33
+ const int batch_size,
34
+ const int channels,
35
+ const int height,
36
+ const int width) {
37
+ if (grad.type().is_cuda()) {
38
+ #ifdef WITH_CUDA
39
+ return ROIPool_backward_cuda(grad, input, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
40
+ #else
41
+ AT_ERROR("Not compiled with GPU support");
42
+ #endif
43
+ }
44
+ AT_ERROR("Not implemented on the CPU");
45
+ }
46
+
47
+
48
+
maskrcnn_benchmark/csrc/SigmoidFocalLoss.h ADDED
@@ -0,0 +1,41 @@
1
+ #pragma once
2
+
3
+ #include "cpu/vision.h"
4
+
5
+ #ifdef WITH_CUDA
6
+ #include "cuda/vision.h"
7
+ #endif
8
+
9
+ // Interface for Python
10
+ at::Tensor SigmoidFocalLoss_forward(
11
+ const at::Tensor& logits,
12
+ const at::Tensor& targets,
13
+ const int num_classes,
14
+ const float gamma,
15
+ const float alpha) {
16
+ if (logits.type().is_cuda()) {
17
+ #ifdef WITH_CUDA
18
+ return SigmoidFocalLoss_forward_cuda(logits, targets, num_classes, gamma, alpha);
19
+ #else
20
+ AT_ERROR("Not compiled with GPU support");
21
+ #endif
22
+ }
23
+ AT_ERROR("Not implemented on the CPU");
24
+ }
25
+
26
+ at::Tensor SigmoidFocalLoss_backward(
27
+ const at::Tensor& logits,
28
+ const at::Tensor& targets,
29
+ const at::Tensor& d_losses,
30
+ const int num_classes,
31
+ const float gamma,
32
+ const float alpha) {
33
+ if (logits.type().is_cuda()) {
34
+ #ifdef WITH_CUDA
35
+ return SigmoidFocalLoss_backward_cuda(logits, targets, d_losses, num_classes, gamma, alpha);
36
+ #else
37
+ AT_ERROR("Not compiled with GPU support");
38
+ #endif
39
+ }
40
+ AT_ERROR("Not implemented on the CPU");
41
+ }
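For reference, these entry points expose the standard RetinaNet sigmoid focal loss. A hedged PyTorch sketch of that formulation (the kernels take integer class targets; one-hot targets are used here for clarity):

import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, targets_onehot, gamma=2.0, alpha=0.25):
    # FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
    ce = F.binary_cross_entropy_with_logits(logits, targets_onehot, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * targets_onehot + (1 - p) * (1 - targets_onehot)
    alpha_t = alpha * targets_onehot + (1 - alpha) * (1 - targets_onehot)
    return (alpha_t * (1 - p_t) ** gamma * ce).sum()

loss = sigmoid_focal_loss(torch.randn(4, 80), torch.zeros(4, 80))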
maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp ADDED
@@ -0,0 +1,257 @@
1
+ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ #include "cpu/vision.h"
3
+
4
+ // implementation taken from Caffe2
5
+ template <typename T>
6
+ struct PreCalc {
7
+ int pos1;
8
+ int pos2;
9
+ int pos3;
10
+ int pos4;
11
+ T w1;
12
+ T w2;
13
+ T w3;
14
+ T w4;
15
+ };
16
+
17
+ template <typename T>
18
+ void pre_calc_for_bilinear_interpolate(
19
+ const int height,
20
+ const int width,
21
+ const int pooled_height,
22
+ const int pooled_width,
23
+ const int iy_upper,
24
+ const int ix_upper,
25
+ T roi_start_h,
26
+ T roi_start_w,
27
+ T bin_size_h,
28
+ T bin_size_w,
29
+ int roi_bin_grid_h,
30
+ int roi_bin_grid_w,
31
+ std::vector<PreCalc<T>>& pre_calc) {
32
+ int pre_calc_index = 0;
33
+ for (int ph = 0; ph < pooled_height; ph++) {
34
+ for (int pw = 0; pw < pooled_width; pw++) {
35
+ for (int iy = 0; iy < iy_upper; iy++) {
36
+ const T yy = roi_start_h + ph * bin_size_h +
37
+ static_cast<T>(iy + .5f) * bin_size_h /
38
+ static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
39
+ for (int ix = 0; ix < ix_upper; ix++) {
40
+ const T xx = roi_start_w + pw * bin_size_w +
41
+ static_cast<T>(ix + .5f) * bin_size_w /
42
+ static_cast<T>(roi_bin_grid_w);
43
+
44
+ T x = xx;
45
+ T y = yy;
46
+ // deal with: inverse elements are out of feature map boundary
47
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
48
+ // empty
49
+ PreCalc<T> pc;
50
+ pc.pos1 = 0;
51
+ pc.pos2 = 0;
52
+ pc.pos3 = 0;
53
+ pc.pos4 = 0;
54
+ pc.w1 = 0;
55
+ pc.w2 = 0;
56
+ pc.w3 = 0;
57
+ pc.w4 = 0;
58
+ pre_calc[pre_calc_index] = pc;
59
+ pre_calc_index += 1;
60
+ continue;
61
+ }
62
+
63
+ if (y <= 0) {
64
+ y = 0;
65
+ }
66
+ if (x <= 0) {
67
+ x = 0;
68
+ }
69
+
70
+ int y_low = (int)y;
71
+ int x_low = (int)x;
72
+ int y_high;
73
+ int x_high;
74
+
75
+ if (y_low >= height - 1) {
76
+ y_high = y_low = height - 1;
77
+ y = (T)y_low;
78
+ } else {
79
+ y_high = y_low + 1;
80
+ }
81
+
82
+ if (x_low >= width - 1) {
83
+ x_high = x_low = width - 1;
84
+ x = (T)x_low;
85
+ } else {
86
+ x_high = x_low + 1;
87
+ }
88
+
89
+ T ly = y - y_low;
90
+ T lx = x - x_low;
91
+ T hy = 1. - ly, hx = 1. - lx;
92
+ T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
93
+
94
+ // save weights and indices
95
+ PreCalc<T> pc;
96
+ pc.pos1 = y_low * width + x_low;
97
+ pc.pos2 = y_low * width + x_high;
98
+ pc.pos3 = y_high * width + x_low;
99
+ pc.pos4 = y_high * width + x_high;
100
+ pc.w1 = w1;
101
+ pc.w2 = w2;
102
+ pc.w3 = w3;
103
+ pc.w4 = w4;
104
+ pre_calc[pre_calc_index] = pc;
105
+
106
+ pre_calc_index += 1;
107
+ }
108
+ }
109
+ }
110
+ }
111
+ }
112
+
113
+ template <typename T>
114
+ void ROIAlignForward_cpu_kernel(
115
+ const int nthreads,
116
+ const T* bottom_data,
117
+ const T& spatial_scale,
118
+ const int channels,
119
+ const int height,
120
+ const int width,
121
+ const int pooled_height,
122
+ const int pooled_width,
123
+ const int sampling_ratio,
124
+ const T* bottom_rois,
125
+ //int roi_cols,
126
+ T* top_data) {
127
+ //AT_ASSERT(roi_cols == 4 || roi_cols == 5);
128
+ int roi_cols = 5;
129
+
130
+ int n_rois = nthreads / channels / pooled_width / pooled_height;
131
+ // (n, c, ph, pw) is an element in the pooled output
132
+ // can be parallelized using omp
133
+ // #pragma omp parallel for num_threads(32)
134
+ for (int n = 0; n < n_rois; n++) {
135
+ int index_n = n * channels * pooled_width * pooled_height;
136
+
137
+ // roi could have 4 or 5 columns
138
+ const T* offset_bottom_rois = bottom_rois + n * roi_cols;
139
+ int roi_batch_ind = 0;
140
+ if (roi_cols == 5) {
141
+ roi_batch_ind = offset_bottom_rois[0];
142
+ offset_bottom_rois++;
143
+ }
144
+
145
+ // Do not use rounding; this implementation detail is critical
146
+ T roi_start_w = offset_bottom_rois[0] * spatial_scale;
147
+ T roi_start_h = offset_bottom_rois[1] * spatial_scale;
148
+ T roi_end_w = offset_bottom_rois[2] * spatial_scale;
149
+ T roi_end_h = offset_bottom_rois[3] * spatial_scale;
150
+ // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
151
+ // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
152
+ // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
153
+ // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);
154
+
155
+ // Force malformed ROIs to be 1x1
156
+ T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
157
+ T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
158
+ T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
159
+ T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
160
+
161
+ // We use roi_bin_grid to sample the grid and mimic integral
162
+ int roi_bin_grid_h = (sampling_ratio > 0)
163
+ ? sampling_ratio
164
+ : ceil(roi_height / pooled_height); // e.g., = 2
165
+ int roi_bin_grid_w =
166
+ (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
167
+
168
+ // We do average (integral) pooling inside a bin
169
+ const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
170
+
171
+ // we want to precalculate indices and weights shared by all channels,
172
+ // this is the key point of optimization
173
+ std::vector<PreCalc<T>> pre_calc(
174
+ roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
175
+ pre_calc_for_bilinear_interpolate(
176
+ height,
177
+ width,
178
+ pooled_height,
179
+ pooled_width,
180
+ roi_bin_grid_h,
181
+ roi_bin_grid_w,
182
+ roi_start_h,
183
+ roi_start_w,
184
+ bin_size_h,
185
+ bin_size_w,
186
+ roi_bin_grid_h,
187
+ roi_bin_grid_w,
188
+ pre_calc);
189
+
190
+ for (int c = 0; c < channels; c++) {
191
+ int index_n_c = index_n + c * pooled_width * pooled_height;
192
+ const T* offset_bottom_data =
193
+ bottom_data + (roi_batch_ind * channels + c) * height * width;
194
+ int pre_calc_index = 0;
195
+
196
+ for (int ph = 0; ph < pooled_height; ph++) {
197
+ for (int pw = 0; pw < pooled_width; pw++) {
198
+ int index = index_n_c + ph * pooled_width + pw;
199
+
200
+ T output_val = 0.;
201
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
202
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
203
+ PreCalc<T> pc = pre_calc[pre_calc_index];
204
+ output_val += pc.w1 * offset_bottom_data[pc.pos1] +
205
+ pc.w2 * offset_bottom_data[pc.pos2] +
206
+ pc.w3 * offset_bottom_data[pc.pos3] +
207
+ pc.w4 * offset_bottom_data[pc.pos4];
208
+
209
+ pre_calc_index += 1;
210
+ }
211
+ }
212
+ output_val /= count;
213
+
214
+ top_data[index] = output_val;
215
+ } // for pw
216
+ } // for ph
217
+ } // for c
218
+ } // for n
219
+ }
220
+
221
+ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
222
+ const at::Tensor& rois,
223
+ const float spatial_scale,
224
+ const int pooled_height,
225
+ const int pooled_width,
226
+ const int sampling_ratio) {
227
+ AT_ASSERTM(!input.type().is_cuda(), "input must be a CPU tensor");
228
+ AT_ASSERTM(!rois.type().is_cuda(), "rois must be a CPU tensor");
229
+
230
+ auto num_rois = rois.size(0);
231
+ auto channels = input.size(1);
232
+ auto height = input.size(2);
233
+ auto width = input.size(3);
234
+
235
+ auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options());
236
+ auto output_size = num_rois * pooled_height * pooled_width * channels;
237
+
238
+ if (output.numel() == 0) {
239
+ return output;
240
+ }
241
+
242
+ AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
243
+ ROIAlignForward_cpu_kernel<scalar_t>(
244
+ output_size,
245
+ input.data<scalar_t>(),
246
+ spatial_scale,
247
+ channels,
248
+ height,
249
+ width,
250
+ pooled_height,
251
+ pooled_width,
252
+ sampling_ratio,
253
+ rois.data<scalar_t>(),
254
+ output.data<scalar_t>());
255
+ });
256
+ return output;
257
+ }
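A toy check of the bilinear weights that pre_calc_for_bilinear_interpolate caches for each sampling point:

y, x = 2.3, 4.6                  # a sampling point inside the feature map
y_low, x_low = int(y), int(x)
ly, lx = y - y_low, x - x_low
hy, hx = 1.0 - ly, 1.0 - lx
w1, w2, w3, w4 = hy * hx, hy * lx, ly * hx, ly * lx
assert abs(w1 + w2 + w3 + w4 - 1.0) < 1e-9  # the four weights form a convex combination
print(w1, w2, w3, w4)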
maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp ADDED
@@ -0,0 +1,75 @@
1
+ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ #include "cpu/vision.h"
3
+
4
+
5
+ template <typename scalar_t>
6
+ at::Tensor nms_cpu_kernel(const at::Tensor& dets,
7
+ const at::Tensor& scores,
8
+ const float threshold) {
9
+ AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor");
10
+ AT_ASSERTM(!scores.type().is_cuda(), "scores must be a CPU tensor");
11
+ AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores");
12
+
13
+ if (dets.numel() == 0) {
14
+ return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
15
+ }
16
+
17
+ auto x1_t = dets.select(1, 0).contiguous();
18
+ auto y1_t = dets.select(1, 1).contiguous();
19
+ auto x2_t = dets.select(1, 2).contiguous();
20
+ auto y2_t = dets.select(1, 3).contiguous();
21
+
22
+ at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1);
23
+
24
+ auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
25
+
26
+ auto ndets = dets.size(0);
27
+ at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte).device(at::kCPU));
28
+
29
+ auto suppressed = suppressed_t.data<uint8_t>();
30
+ auto order = order_t.data<int64_t>();
31
+ auto x1 = x1_t.data<scalar_t>();
32
+ auto y1 = y1_t.data<scalar_t>();
33
+ auto x2 = x2_t.data<scalar_t>();
34
+ auto y2 = y2_t.data<scalar_t>();
35
+ auto areas = areas_t.data<scalar_t>();
36
+
37
+ for (int64_t _i = 0; _i < ndets; _i++) {
38
+ auto i = order[_i];
39
+ if (suppressed[i] == 1)
40
+ continue;
41
+ auto ix1 = x1[i];
42
+ auto iy1 = y1[i];
43
+ auto ix2 = x2[i];
44
+ auto iy2 = y2[i];
45
+ auto iarea = areas[i];
46
+
47
+ for (int64_t _j = _i + 1; _j < ndets; _j++) {
48
+ auto j = order[_j];
49
+ if (suppressed[j] == 1)
50
+ continue;
51
+ auto xx1 = std::max(ix1, x1[j]);
52
+ auto yy1 = std::max(iy1, y1[j]);
53
+ auto xx2 = std::min(ix2, x2[j]);
54
+ auto yy2 = std::min(iy2, y2[j]);
55
+
56
+ auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1);
57
+ auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1);
58
+ auto inter = w * h;
59
+ auto ovr = inter / (iarea + areas[j] - inter);
60
+ if (ovr >= threshold)
61
+ suppressed[j] = 1;
62
+ }
63
+ }
64
+ return at::nonzero(suppressed_t == 0).squeeze(1);
65
+ }
66
+
67
+ at::Tensor nms_cpu(const at::Tensor& dets,
68
+ const at::Tensor& scores,
69
+ const float threshold) {
70
+ at::Tensor result;
71
+ AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
72
+ result = nms_cpu_kernel<scalar_t>(dets, scores, threshold);
73
+ });
74
+ return result;
75
+ }
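The same greedy suppression in vectorized NumPy, mirroring nms_cpu_kernel above (including its legacy "+ 1" box extents); a sketch for illustration only:

import numpy as np

def nms(dets, scores, thresh):
    x1, y1, x2, y2 = dets.T
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
        ovr = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[1:][ovr < thresh]  # the C++ kernel suppresses ovr >= threshold
    return keep

boxes = np.array([[0., 0., 10., 10.], [1., 1., 10., 10.], [20., 20., 30., 30.]])
print(nms(boxes, np.array([0.9, 0.8, 0.7]), 0.5))  # [0, 2]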
maskrcnn_benchmark/csrc/cpu/vision.h ADDED
@@ -0,0 +1,16 @@
1
+ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ #pragma once
3
+ #include <torch/extension.h>
4
+
5
+
6
+ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
7
+ const at::Tensor& rois,
8
+ const float spatial_scale,
9
+ const int pooled_height,
10
+ const int pooled_width,
11
+ const int sampling_ratio);
12
+
13
+
14
+ at::Tensor nms_cpu(const at::Tensor& dets,
15
+ const at::Tensor& scores,
16
+ const float threshold);
maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu ADDED
@@ -0,0 +1,346 @@
1
+ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ #include <ATen/ATen.h>
3
+ #include <ATen/cuda/CUDAContext.h>
4
+
5
+ #include <THC/THC.h>
6
+ #include <THC/THCAtomics.cuh>
7
+ #include <THC/THCDeviceUtils.cuh>
8
+
9
+ // TODO make it in a common file
10
+ #define CUDA_1D_KERNEL_LOOP(i, n) \
11
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
12
+ i += blockDim.x * gridDim.x)
13
+
14
+
15
+ template <typename T>
16
+ __device__ T bilinear_interpolate(const T* bottom_data,
17
+ const int height, const int width,
18
+ T y, T x,
19
+ const int index /* index for debug only*/) {
20
+
21
+ // handle cases where the sampled point falls outside the feature map boundary
22
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
23
+ //empty
24
+ return 0;
25
+ }
26
+
27
+ if (y <= 0) y = 0;
28
+ if (x <= 0) x = 0;
29
+
30
+ int y_low = (int) y;
31
+ int x_low = (int) x;
32
+ int y_high;
33
+ int x_high;
34
+
35
+ if (y_low >= height - 1) {
36
+ y_high = y_low = height - 1;
37
+ y = (T) y_low;
38
+ } else {
39
+ y_high = y_low + 1;
40
+ }
41
+
42
+ if (x_low >= width - 1) {
43
+ x_high = x_low = width - 1;
44
+ x = (T) x_low;
45
+ } else {
46
+ x_high = x_low + 1;
47
+ }
48
+
49
+ T ly = y - y_low;
50
+ T lx = x - x_low;
51
+ T hy = 1. - ly, hx = 1. - lx;
52
+ // do bilinear interpolation
53
+ T v1 = bottom_data[y_low * width + x_low];
54
+ T v2 = bottom_data[y_low * width + x_high];
55
+ T v3 = bottom_data[y_high * width + x_low];
56
+ T v4 = bottom_data[y_high * width + x_high];
57
+ T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
58
+
59
+ T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
60
+
61
+ return val;
62
+ }
63
+
64
+ template <typename T>
65
+ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
66
+ const T spatial_scale, const int channels,
67
+ const int height, const int width,
68
+ const int pooled_height, const int pooled_width,
69
+ const int sampling_ratio,
70
+ const T* bottom_rois, T* top_data) {
71
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
72
+ // (n, c, ph, pw) is an element in the pooled output
73
+ int pw = index % pooled_width;
74
+ int ph = (index / pooled_width) % pooled_height;
75
+ int c = (index / pooled_width / pooled_height) % channels;
76
+ int n = index / pooled_width / pooled_height / channels;
77
+
78
+ const T* offset_bottom_rois = bottom_rois + n * 5;
79
+ int roi_batch_ind = offset_bottom_rois[0];
80
+
81
+ // Do not using rounding; this implementation detail is critical
82
+ T roi_start_w = offset_bottom_rois[1] * spatial_scale;
83
+ T roi_start_h = offset_bottom_rois[2] * spatial_scale;
84
+ T roi_end_w = offset_bottom_rois[3] * spatial_scale;
85
+ T roi_end_h = offset_bottom_rois[4] * spatial_scale;
86
+ // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
87
+ // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
88
+ // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
89
+ // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
90
+
91
+ // Force malformed ROIs to be 1x1
92
+ T roi_width = max(roi_end_w - roi_start_w, (T)1.);
93
+ T roi_height = max(roi_end_h - roi_start_h, (T)1.);
94
+ T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
95
+ T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
96
+
97
+ const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
98
+
99
+ // We use roi_bin_grid to sample the grid and mimic integral
100
+ int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
101
+ int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
102
+
103
+ // We do average (integral) pooling inside a bin
104
+ const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
105
+
106
+ T output_val = 0.;
107
+ for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1
108
+ {
109
+ const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
110
+ for (int ix = 0; ix < roi_bin_grid_w; ix ++)
111
+ {
112
+ const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
113
+
114
+ T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index);
115
+ output_val += val;
116
+ }
117
+ }
118
+ output_val /= count;
119
+
120
+ top_data[index] = output_val;
121
+ }
122
+ }
123
+
124
+
125
+ template <typename T>
126
+ __device__ void bilinear_interpolate_gradient(
127
+ const int height, const int width,
128
+ T y, T x,
129
+ T & w1, T & w2, T & w3, T & w4,
130
+ int & x_low, int & x_high, int & y_low, int & y_high,
131
+ const int index /* index for debug only*/) {
132
+
133
+ // deal with cases that inverse elements are out of feature map boundary
134
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
135
+ //empty
136
+ w1 = w2 = w3 = w4 = 0.;
137
+ x_low = x_high = y_low = y_high = -1;
138
+ return;
139
+ }
140
+
141
+ if (y <= 0) y = 0;
142
+ if (x <= 0) x = 0;
143
+
144
+ y_low = (int) y;
145
+ x_low = (int) x;
146
+
147
+ if (y_low >= height - 1) {
148
+ y_high = y_low = height - 1;
149
+ y = (T) y_low;
150
+ } else {
151
+ y_high = y_low + 1;
152
+ }
153
+
154
+ if (x_low >= width - 1) {
155
+ x_high = x_low = width - 1;
156
+ x = (T) x_low;
157
+ } else {
158
+ x_high = x_low + 1;
159
+ }
160
+
161
+ T ly = y - y_low;
162
+ T lx = x - x_low;
163
+ T hy = 1. - ly, hx = 1. - lx;
164
+
165
+ // reference in forward
166
+ // T v1 = bottom_data[y_low * width + x_low];
167
+ // T v2 = bottom_data[y_low * width + x_high];
168
+ // T v3 = bottom_data[y_high * width + x_low];
169
+ // T v4 = bottom_data[y_high * width + x_high];
170
+ // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
171
+
172
+ w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
173
+
174
+ return;
175
+ }
176
+
177
+ template <typename T>
178
+ __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
179
+ const int num_rois, const T spatial_scale,
180
+ const int channels, const int height, const int width,
181
+ const int pooled_height, const int pooled_width,
182
+ const int sampling_ratio,
183
+ T* bottom_diff,
184
+ const T* bottom_rois) {
185
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
186
+ // (n, c, ph, pw) is an element in the pooled output
187
+ int pw = index % pooled_width;
188
+ int ph = (index / pooled_width) % pooled_height;
189
+ int c = (index / pooled_width / pooled_height) % channels;
190
+ int n = index / pooled_width / pooled_height / channels;
191
+
192
+ const T* offset_bottom_rois = bottom_rois + n * 5;
193
+ int roi_batch_ind = offset_bottom_rois[0];
194
+
195
+ // Do not using rounding; this implementation detail is critical
196
+ T roi_start_w = offset_bottom_rois[1] * spatial_scale;
197
+ T roi_start_h = offset_bottom_rois[2] * spatial_scale;
198
+ T roi_end_w = offset_bottom_rois[3] * spatial_scale;
199
+ T roi_end_h = offset_bottom_rois[4] * spatial_scale;
200
+ // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
201
+ // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
202
+ // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
203
+ // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
204
+
205
+ // Force malformed ROIs to be 1x1
206
+ T roi_width = max(roi_end_w - roi_start_w, (T)1.);
207
+ T roi_height = max(roi_end_h - roi_start_h, (T)1.);
208
+ T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
209
+ T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
210
+
211
+ T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width;
212
+
213
+ int top_offset = (n * channels + c) * pooled_height * pooled_width;
214
+ const T* offset_top_diff = top_diff + top_offset;
215
+ const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
216
+
217
+ // We use roi_bin_grid to sample the grid and mimic integral
218
+ int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
219
+ int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
220
+
221
+ // We do average (integral) pooling inside a bin
222
+ const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
223
+
224
+ for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1
225
+ {
226
+ const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
227
+ for (int ix = 0; ix < roi_bin_grid_w; ix ++)
228
+ {
229
+ const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
230
+
231
+ T w1, w2, w3, w4;
232
+ int x_low, x_high, y_low, y_high;
233
+
234
+ bilinear_interpolate_gradient(height, width, y, x,
235
+ w1, w2, w3, w4,
236
+ x_low, x_high, y_low, y_high,
237
+ index);
238
+
239
+ T g1 = top_diff_this_bin * w1 / count;
240
+ T g2 = top_diff_this_bin * w2 / count;
241
+ T g3 = top_diff_this_bin * w3 / count;
242
+ T g4 = top_diff_this_bin * w4 / count;
243
+
244
+ if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0)
245
+ {
246
+ atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast<T>(g1));
247
+ atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast<T>(g2));
248
+ atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3));
249
+ atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4));
250
+ } // if
251
+ } // ix
252
+ } // iy
253
+ } // CUDA_1D_KERNEL_LOOP
254
+ } // RoIAlignBackward
255
+
256
+
257
+ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
258
+ const at::Tensor& rois,
259
+ const float spatial_scale,
260
+ const int pooled_height,
261
+ const int pooled_width,
262
+ const int sampling_ratio) {
263
+ AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
264
+ AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
265
+
266
+ auto num_rois = rois.size(0);
267
+ auto channels = input.size(1);
268
+ auto height = input.size(2);
269
+ auto width = input.size(3);
270
+
271
+ auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options());
272
+ auto output_size = num_rois * pooled_height * pooled_width * channels;
273
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
274
+
275
+ dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
276
+ dim3 block(512);
277
+
278
+ if (output.numel() == 0) {
279
+ THCudaCheck(cudaGetLastError());
280
+ return output;
281
+ }
282
+
283
+ AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
284
+ RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
285
+ output_size,
286
+ input.contiguous().data<scalar_t>(),
287
+ spatial_scale,
288
+ channels,
289
+ height,
290
+ width,
291
+ pooled_height,
292
+ pooled_width,
293
+ sampling_ratio,
294
+ rois.contiguous().data<scalar_t>(),
295
+ output.data<scalar_t>());
296
+ });
297
+ THCudaCheck(cudaGetLastError());
298
+ return output;
299
+ }
300
+
301
+ // TODO remove the dependency on input and use instead its sizes -> save memory
302
+ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
303
+ const at::Tensor& rois,
304
+ const float spatial_scale,
305
+ const int pooled_height,
306
+ const int pooled_width,
307
+ const int batch_size,
308
+ const int channels,
309
+ const int height,
310
+ const int width,
311
+ const int sampling_ratio) {
312
+ AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
313
+ AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
314
+
315
+ auto num_rois = rois.size(0);
316
+ auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
317
+
318
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
319
+
320
+ dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
321
+ dim3 block(512);
322
+
323
+ // handle possibly empty gradients
324
+ if (grad.numel() == 0) {
325
+ THCudaCheck(cudaGetLastError());
326
+ return grad_input;
327
+ }
328
+
329
+ AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] {
330
+ RoIAlignBackwardFeature<scalar_t><<<grid, block, 0, stream>>>(
331
+ grad.numel(),
332
+ grad.contiguous().data<scalar_t>(),
333
+ num_rois,
334
+ spatial_scale,
335
+ channels,
336
+ height,
337
+ width,
338
+ pooled_height,
339
+ pooled_width,
340
+ sampling_ratio,
341
+ grad_input.data<scalar_t>(),
342
+ rois.contiguous().data<scalar_t>());
343
+ });
344
+ THCudaCheck(cudaGetLastError());
345
+ return grad_input;
346
+ }
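
Note: the forward kernel above decomposes each flat output index into (n, c, ph, pw) and averages roi_bin_grid_h x roi_bin_grid_w bilinearly sampled points per output bin. Below is a minimal host-side C++ sketch of that per-bin computation; the `bilinear` helper, the 4x4 feature map, and the bin coordinates are illustrative stand-ins, not part of this commit.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Host-side reference for one RoIAlign bin: clamped bilinear reads averaged
// over the sample grid, mirroring the device loop above. Illustrative only.
float bilinear(const std::vector<float>& img, int h, int w, float y, float x) {
  if (y < -1.f || y > h || x < -1.f || x > w) return 0.f;  // outside: contribute 0
  y = std::max(y, 0.f); x = std::max(x, 0.f);
  int y0 = std::min((int)y, h - 1), x0 = std::min((int)x, w - 1);
  int y1 = std::min(y0 + 1, h - 1), x1 = std::min(x0 + 1, w - 1);
  float ly = y - y0, lx = x - x0, hy = 1.f - ly, hx = 1.f - lx;
  return hy * hx * img[y0 * w + x0] + hy * lx * img[y0 * w + x1] +
         ly * hx * img[y1 * w + x0] + ly * lx * img[y1 * w + x1];
}

int main() {
  // 4x4 feature map holding its own linear index, so the result is easy to check.
  int H = 4, W = 4;
  std::vector<float> feat(H * W);
  for (int i = 0; i < H * W; ++i) feat[i] = (float)i;

  // One bin of a hypothetical RoI: start (0.5, 0.5), bin size 2x2, 2x2 samples.
  float y0 = 0.5f, x0 = 0.5f, bin_h = 2.f, bin_w = 2.f;
  int grid = 2;
  float acc = 0.f;
  for (int iy = 0; iy < grid; ++iy)
    for (int ix = 0; ix < grid; ++ix)
      acc += bilinear(feat, H, W,
                      y0 + (iy + .5f) * bin_h / grid,
                      x0 + (ix + .5f) * bin_w / grid);
  std::printf("pooled value = %f\n", acc / (grid * grid));  // 7.5 here
  return 0;
}

Compiled with a C++11 compiler, this prints 7.5 (the average of samples at (1,1), (1,2), (2,1), (2,2)), which is what RoIAlignForward would produce for that bin.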
maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu ADDED
@@ -0,0 +1,202 @@
+ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+ #include <ATen/ATen.h>
+ #include <ATen/cuda/CUDAContext.h>
+
+ #include <THC/THC.h>
+ #include <THC/THCAtomics.cuh>
+ #include <THC/THCDeviceUtils.cuh>
+
+
+ // TODO make it in a common file
+ #define CUDA_1D_KERNEL_LOOP(i, n)                            \
+   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
+        i += blockDim.x * gridDim.x)
+
+
+ template <typename T>
+ __global__ void RoIPoolFForward(const int nthreads, const T* bottom_data,
+     const T spatial_scale, const int channels, const int height,
+     const int width, const int pooled_height, const int pooled_width,
+     const T* bottom_rois, T* top_data, int* argmax_data) {
+   CUDA_1D_KERNEL_LOOP(index, nthreads) {
+     // (n, c, ph, pw) is an element in the pooled output
+     int pw = index % pooled_width;
+     int ph = (index / pooled_width) % pooled_height;
+     int c = (index / pooled_width / pooled_height) % channels;
+     int n = index / pooled_width / pooled_height / channels;
+
+     const T* offset_bottom_rois = bottom_rois + n * 5;
+     int roi_batch_ind = offset_bottom_rois[0];
+     int roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
+     int roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
+     int roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
+     int roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
+
+     // Force malformed ROIs to be 1x1
+     int roi_width = max(roi_end_w - roi_start_w + 1, 1);
+     int roi_height = max(roi_end_h - roi_start_h + 1, 1);
+     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+     int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
+     int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
+     int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
+     int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
+
+     // Add roi offsets and clip to input boundaries
+     hstart = min(max(hstart + roi_start_h, 0), height);
+     hend = min(max(hend + roi_start_h, 0), height);
+     wstart = min(max(wstart + roi_start_w, 0), width);
+     wend = min(max(wend + roi_start_w, 0), width);
+     bool is_empty = (hend <= hstart) || (wend <= wstart);
+
+     // Define an empty pooling region to be zero
+     T maxval = is_empty ? 0 : -FLT_MAX;
+     // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
+     int maxidx = -1;
+     const T* offset_bottom_data =
+         bottom_data + (roi_batch_ind * channels + c) * height * width;
+     for (int h = hstart; h < hend; ++h) {
+       for (int w = wstart; w < wend; ++w) {
+         int bottom_index = h * width + w;
+         if (offset_bottom_data[bottom_index] > maxval) {
+           maxval = offset_bottom_data[bottom_index];
+           maxidx = bottom_index;
+         }
+       }
+     }
+     top_data[index] = maxval;
+     argmax_data[index] = maxidx;
+   }
+ }
+
+ template <typename T>
+ __global__ void RoIPoolFBackward(const int nthreads, const T* top_diff,
+     const int* argmax_data, const int num_rois, const T spatial_scale,
+     const int channels, const int height, const int width,
+     const int pooled_height, const int pooled_width, T* bottom_diff,
+     const T* bottom_rois) {
+   CUDA_1D_KERNEL_LOOP(index, nthreads) {
+     // (n, c, ph, pw) is an element in the pooled output
+     int pw = index % pooled_width;
+     int ph = (index / pooled_width) % pooled_height;
+     int c = (index / pooled_width / pooled_height) % channels;
+     int n = index / pooled_width / pooled_height / channels;
+
+     const T* offset_bottom_rois = bottom_rois + n * 5;
+     int roi_batch_ind = offset_bottom_rois[0];
+     int bottom_offset = (roi_batch_ind * channels + c) * height * width;
+     int top_offset = (n * channels + c) * pooled_height * pooled_width;
+     const T* offset_top_diff = top_diff + top_offset;
+     T* offset_bottom_diff = bottom_diff + bottom_offset;
+     const int* offset_argmax_data = argmax_data + top_offset;
+
+     int argmax = offset_argmax_data[ph * pooled_width + pw];
+     if (argmax != -1) {
+       atomicAdd(
+           offset_bottom_diff + argmax,
+           static_cast<T>(offset_top_diff[ph * pooled_width + pw]));
+     }
+   }
+ }
+
+ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
+                                                         const at::Tensor& rois,
+                                                         const float spatial_scale,
+                                                         const int pooled_height,
+                                                         const int pooled_width) {
+   AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+   AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
+
+   auto num_rois = rois.size(0);
+   auto channels = input.size(1);
+   auto height = input.size(2);
+   auto width = input.size(3);
+
+   auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options());
+   auto output_size = num_rois * pooled_height * pooled_width * channels;
+   auto argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt));
+
+   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+   dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
+   dim3 block(512);
+
+   if (output.numel() == 0) {
+     THCudaCheck(cudaGetLastError());
+     return std::make_tuple(output, argmax);
+   }
+
+   AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIPool_forward", [&] {
+     RoIPoolFForward<scalar_t><<<grid, block, 0, stream>>>(
+         output_size,
+         input.contiguous().data<scalar_t>(),
+         spatial_scale,
+         channels,
+         height,
+         width,
+         pooled_height,
+         pooled_width,
+         rois.contiguous().data<scalar_t>(),
+         output.data<scalar_t>(),
+         argmax.data<int>());
+   });
+   THCudaCheck(cudaGetLastError());
+   return std::make_tuple(output, argmax);
+ }
+
+ // TODO remove the dependency on input and use instead its sizes -> save memory
+ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
+                                  const at::Tensor& input,
+                                  const at::Tensor& rois,
+                                  const at::Tensor& argmax,
+                                  const float spatial_scale,
+                                  const int pooled_height,
+                                  const int pooled_width,
+                                  const int batch_size,
+                                  const int channels,
+                                  const int height,
+                                  const int width) {
+   AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
+   AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
+   // TODO add more checks
+
+   auto num_rois = rois.size(0);
+   auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
+
+   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+   dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
+   dim3 block(512);
+
+   // handle possibly empty gradients
+   if (grad.numel() == 0) {
+     THCudaCheck(cudaGetLastError());
+     return grad_input;
+   }
+
+   AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIPool_backward", [&] {
+     RoIPoolFBackward<scalar_t><<<grid, block, 0, stream>>>(
+         grad.numel(),
+         grad.contiguous().data<scalar_t>(),
+         argmax.data<int>(),
+         num_rois,
+         spatial_scale,
+         channels,
+         height,
+         width,
+         pooled_height,
+         pooled_width,
+         grad_input.data<scalar_t>(),
+         rois.contiguous().data<scalar_t>());
+   });
+   THCudaCheck(cudaGetLastError());
+   return grad_input;
+ }
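
Note: unlike RoIAlign, RoIPoolFForward snaps RoI coordinates with round(), derives integer bin boundaries with floor/ceil, and takes a hard max while recording the argmax for the backward pass. A host-side C++ sketch of that bin arithmetic follows; the 8x8 map and RoI extents are made-up values, not taken from the commit.

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical numbers: an 8x8 feature map and a RoI spanning rows/cols 1..6,
  // pooled to 2x2. For output cell (ph, pw) the kernel scans
  // [floor(ph*bin), ceil((ph+1)*bin)) shifted by the RoI start and clipped.
  int H = 8, W = 8, pooled = 2;
  int roi_start_h = 1, roi_start_w = 1, roi_end_h = 6, roi_end_w = 6;
  std::vector<float> feat(H * W);
  for (int i = 0; i < H * W; ++i) feat[i] = (float)i;

  int roi_h = std::max(roi_end_h - roi_start_h + 1, 1);  // forced >= 1x1
  int roi_w = std::max(roi_end_w - roi_start_w + 1, 1);
  float bin_h = (float)roi_h / pooled, bin_w = (float)roi_w / pooled;

  for (int ph = 0; ph < pooled; ++ph)
    for (int pw = 0; pw < pooled; ++pw) {
      int hs = std::min(std::max((int)std::floor(ph * bin_h) + roi_start_h, 0), H);
      int he = std::min(std::max((int)std::ceil((ph + 1) * bin_h) + roi_start_h, 0), H);
      int ws = std::min(std::max((int)std::floor(pw * bin_w) + roi_start_w, 0), W);
      int we = std::min(std::max((int)std::ceil((pw + 1) * bin_w) + roi_start_w, 0), W);
      bool empty = (he <= hs) || (we <= ws);
      float maxval = empty ? 0.f : -FLT_MAX;
      int argmax = -1;  // -1 means nothing flows back in the backward pass
      for (int h = hs; h < he; ++h)
        for (int w = ws; w < we; ++w)
          if (feat[h * W + w] > maxval) { maxval = feat[h * W + w]; argmax = h * W + w; }
      std::printf("bin(%d,%d): max=%g argmax=%d\n", ph, pw, maxval, argmax);
    }
  return 0;
}

The recorded argmax is exactly what RoIPoolFBackward uses: the incoming gradient for a pooled cell is routed, via atomicAdd, to the single input element that won the max.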
maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu ADDED
@@ -0,0 +1,189 @@
+ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+ // This file is modified from https://github.com/pytorch/pytorch/blob/master/modules/detectron/sigmoid_focal_loss_op.cu
+ // Cheng-Yang Fu
+
+ #include <ATen/ATen.h>
+ #include <ATen/cuda/CUDAContext.h>
+
+ #include <THC/THC.h>
+ #include <THC/THCAtomics.cuh>
+ #include <THC/THCDeviceUtils.cuh>
+
+ #include <cfloat>
+
+ // TODO make it in a common file
+ #define CUDA_1D_KERNEL_LOOP(i, n)                            \
+   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
+        i += blockDim.x * gridDim.x)
+
+
+ template <typename T>
+ __global__ void SigmoidFocalLossForward(const int nthreads,
+     const T* logits,
+     const int* targets,
+     const int num_classes,
+     const float gamma,
+     const float alpha,
+     const int num,
+     T* losses) {
+   CUDA_1D_KERNEL_LOOP(i, nthreads) {
+
+     int n = i / num_classes;
+     int d = i % num_classes; // current class [0~79]
+     int t = targets[n];      // target class [1~80]
+
+     // Decide whether it is a positive or negative case.
+     T c1 = (t == (d+1));
+     T c2 = (t >= 0 & t != (d+1));
+
+     T zn = (1.0 - alpha);
+     T zp = (alpha);
+
+     // p = 1. / (1. + expf(-x)); p = sigmoid(x)
+     T p = 1. / (1. + expf(-logits[i]));
+
+     // (1-p)**gamma * log(p)
+     T term1 = powf((1. - p), gamma) * logf(max(p, FLT_MIN));
+
+     // p**gamma * log(1-p)
+     T term2 = powf(p, gamma) *
+         (-1. * logits[i] * (logits[i] >= 0) -
+          logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0))));
+
+     losses[i] = 0.0;
+     losses[i] += -c1 * term1 * zp;
+     losses[i] += -c2 * term2 * zn;
+
+   } // CUDA_1D_KERNEL_LOOP
+ } // SigmoidFocalLossForward
+
+
+ template <typename T>
+ __global__ void SigmoidFocalLossBackward(const int nthreads,
+     const T* logits,
+     const int* targets,
+     const T* d_losses,
+     const int num_classes,
+     const float gamma,
+     const float alpha,
+     const int num,
+     T* d_logits) {
+   CUDA_1D_KERNEL_LOOP(i, nthreads) {
+
+     int n = i / num_classes;
+     int d = i % num_classes; // current class [0~79]
+     int t = targets[n];      // target class [1~80], 0 is background
+
+     // Decide whether it is a positive or negative case.
+     T c1 = (t == (d+1));
+     T c2 = (t >= 0 & t != (d+1));
+
+     T zn = (1.0 - alpha);
+     T zp = (alpha);
+     // p = 1. / (1. + expf(-x)); p = sigmoid(x)
+     T p = 1. / (1. + expf(-logits[i]));
+
+     // (1-p)**g * (1 - p - g*p*log(p))
+     T term1 = powf((1. - p), gamma) *
+         (1. - p - (p * gamma * logf(max(p, FLT_MIN))));
+
+     // (p**g) * (g*(1-p)*log(1-p) - p)
+     T term2 = powf(p, gamma) *
+         ((-1. * logits[i] * (logits[i] >= 0) -
+           logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0)))) *
+          (1. - p) * gamma - p);
+     d_logits[i] = 0.0;
+     d_logits[i] += -c1 * term1 * zp;
+     d_logits[i] += -c2 * term2 * zn;
+     d_logits[i] = d_logits[i] * d_losses[i];
+
+   } // CUDA_1D_KERNEL_LOOP
+ } // SigmoidFocalLossBackward
+
+
+ at::Tensor SigmoidFocalLoss_forward_cuda(
+     const at::Tensor& logits,
+     const at::Tensor& targets,
+     const int num_classes,
+     const float gamma,
+     const float alpha) {
+   AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor");
+   AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
+   AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
+
+   const int num_samples = logits.size(0);
+
+   auto losses = at::empty({num_samples, logits.size(1)}, logits.options());
+   auto losses_size = num_samples * logits.size(1);
+   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+   dim3 grid(std::min(THCCeilDiv((long)losses_size, 512L), 4096L));
+   dim3 block(512);
+
+   if (losses.numel() == 0) {
+     THCudaCheck(cudaGetLastError());
+     return losses;
+   }
+
+   AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_forward", [&] {
+     SigmoidFocalLossForward<scalar_t><<<grid, block, 0, stream>>>(
+         losses_size,
+         logits.contiguous().data<scalar_t>(),
+         targets.contiguous().data<int>(),
+         num_classes,
+         gamma,
+         alpha,
+         num_samples,
+         losses.data<scalar_t>());
+   });
+   THCudaCheck(cudaGetLastError());
+   return losses;
+ }
+
+
+ at::Tensor SigmoidFocalLoss_backward_cuda(
+     const at::Tensor& logits,
+     const at::Tensor& targets,
+     const at::Tensor& d_losses,
+     const int num_classes,
+     const float gamma,
+     const float alpha) {
+   AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor");
+   AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
+   AT_ASSERTM(d_losses.type().is_cuda(), "d_losses must be a CUDA tensor");
+
+   AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
+
+   const int num_samples = logits.size(0);
+   AT_ASSERTM(logits.size(1) == num_classes, "logits.size(1) should be num_classes");
+
+   auto d_logits = at::zeros({num_samples, num_classes}, logits.options());
+   auto d_logits_size = num_samples * logits.size(1);
+   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+   dim3 grid(std::min(THCCeilDiv((long)d_logits_size, 512L), 4096L));
+   dim3 block(512);
+
+   if (d_logits.numel() == 0) {
+     THCudaCheck(cudaGetLastError());
+     return d_logits;
+   }
+
+   AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_backward", [&] {
+     SigmoidFocalLossBackward<scalar_t><<<grid, block, 0, stream>>>(
+         d_logits_size,
+         logits.contiguous().data<scalar_t>(),
+         targets.contiguous().data<int>(),
+         d_losses.contiguous().data<scalar_t>(),
+         num_classes,
+         gamma,
+         alpha,
+         num_samples,
+         d_logits.data<scalar_t>());
+   });
+
+   THCudaCheck(cudaGetLastError());
+   return d_logits;
+ }
+
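
Note: both kernels avoid evaluating logf(1 - p) directly, which underflows for large logits; they rewrite it in terms of the raw logit x. A standalone C++ check of that identity and of the resulting negative-class focal term follows; gamma, alpha, and the sample logits are arbitrary values chosen for illustration.

#include <cmath>
#include <cstdio>

// Checks the numerically stable identity used above:
// log(1 - sigmoid(x)) = -x*(x>=0) - log(1 + exp(x - 2*x*(x>=0))),
// which stays finite for large |x|. Values below are illustrative.
int main() {
  float gamma = 2.f, alpha = 0.25f;
  const float xs[] = {-30.f, -2.f, 0.f, 2.f, 30.f};
  for (float x : xs) {
    float p = 1.f / (1.f + std::exp(-x));
    float stable_log1mp =
        -x * (x >= 0) - std::log(1.f + std::exp(x - 2.f * x * (x >= 0)));
    // Negative-class term, matching losses[i] += -c2 * term2 * zn above:
    float neg_loss = -(1.f - alpha) * std::pow(p, gamma) * stable_log1mp;
    std::printf("x=%6.1f  log(1-p)=%g  neg-class loss=%g\n",
                x, stable_log1mp, neg_loss);
  }
  return 0;
}

At x = 30, a naive logf(1 - p) would return -inf in float32, while the stable form returns roughly -30, keeping the loss and its gradient finite.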
maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu ADDED
@@ -0,0 +1,691 @@
+ // modified from
+ // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+ #include <ATen/ATen.h>
+ #include <ATen/cuda/CUDAContext.h>
+
+ #include <THC/THC.h>
+ #include <THC/THCDeviceUtils.cuh>
+
+ #include <vector>
+ #include <iostream>
+ #include <cmath>
+
+
+ void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+                        const int channels, const int height, const int width,
+                        const int ksize_h, const int ksize_w, const int pad_h,
+                        const int pad_w, const int stride_h, const int stride_w,
+                        const int dilation_h, const int dilation_w,
+                        const int parallel_imgs, const int deformable_group,
+                        at::Tensor data_col);
+
+ void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+                        const int channels, const int height, const int width,
+                        const int ksize_h, const int ksize_w, const int pad_h,
+                        const int pad_w, const int stride_h, const int stride_w,
+                        const int dilation_h, const int dilation_w,
+                        const int parallel_imgs, const int deformable_group,
+                        at::Tensor grad_im);
+
+ void deformable_col2im_coord(
+     const at::Tensor data_col, const at::Tensor data_im,
+     const at::Tensor data_offset, const int channels, const int height,
+     const int width, const int ksize_h, const int ksize_w, const int pad_h,
+     const int pad_w, const int stride_h, const int stride_w,
+     const int dilation_h, const int dilation_w, const int parallel_imgs,
+     const int deformable_group, at::Tensor grad_offset);
+
+ void modulated_deformable_im2col_cuda(
+     const at::Tensor data_im, const at::Tensor data_offset,
+     const at::Tensor data_mask, const int batch_size, const int channels,
+     const int height_im, const int width_im, const int height_col,
+     const int width_col, const int kernel_h, const int kenerl_w,
+     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+     const int dilation_h, const int dilation_w, const int deformable_group,
+     at::Tensor data_col);
+
+ void modulated_deformable_col2im_cuda(
+     const at::Tensor data_col, const at::Tensor data_offset,
+     const at::Tensor data_mask, const int batch_size, const int channels,
+     const int height_im, const int width_im, const int height_col,
+     const int width_col, const int kernel_h, const int kenerl_w,
+     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+     const int dilation_h, const int dilation_w, const int deformable_group,
+     at::Tensor grad_im);
+
+ void modulated_deformable_col2im_coord_cuda(
+     const at::Tensor data_col, const at::Tensor data_im,
+     const at::Tensor data_offset, const at::Tensor data_mask,
+     const int batch_size, const int channels, const int height_im,
+     const int width_im, const int height_col, const int width_col,
+     const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
+     const int stride_h, const int stride_w, const int dilation_h,
+     const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+     at::Tensor grad_mask);
+
+ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+                  at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+                  int padW, int dilationH, int dilationW, int group,
+                  int deformable_group)
+ {
+   AT_CHECK(weight.ndimension() == 4,
+            "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+            "but got: %d",
+            weight.ndimension());
+
+   AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+   AT_CHECK(kW > 0 && kH > 0,
+            "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+            kW);
+
+   AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+            "kernel size should be consistent with weight, ",
+            "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+            kW, weight.size(2), weight.size(3));
+
+   AT_CHECK(dW > 0 && dH > 0,
+            "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+   AT_CHECK(
+       dilationW > 0 && dilationH > 0,
+       "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+       dilationH, dilationW);
+
+   int ndim = input.ndimension();
+   int dimf = 0;
+   int dimh = 1;
+   int dimw = 2;
+
+   if (ndim == 4) {
+     dimf++;
+     dimh++;
+     dimw++;
+   }
+
+   AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %d",
+            ndim);
+
+   long nInputPlane = weight.size(1) * group;
+   long inputHeight = input.size(dimh);
+   long inputWidth = input.size(dimw);
+   long nOutputPlane = weight.size(0);
+   long outputHeight =
+       (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+   long outputWidth =
+       (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+   AT_CHECK(nInputPlane % deformable_group == 0,
+            "input channels must divide deformable group size");
+
+   if (outputWidth < 1 || outputHeight < 1)
+     AT_ERROR(
+         "Given input size: (%ld x %ld x %ld). "
+         "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+         nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+         outputWidth);
+
+   AT_CHECK(input.size(1) == nInputPlane,
+            "invalid number of input planes, expected: %d, but got: %d",
+            nInputPlane, input.size(1));
+
+   AT_CHECK((inputHeight >= kH && inputWidth >= kW),
+            "input image is smaller than kernel");
+
+   AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+            "invalid spatial size of offset, expected height: %d width: %d, but "
+            "got height: %d width: %d",
+            outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+   AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+            "invalid number of channels of offset");
+
+   if (gradOutput != NULL) {
+     AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
+              "invalid number of gradOutput planes, expected: %d, but got: %d",
+              nOutputPlane, gradOutput->size(dimf));
+
+     AT_CHECK((gradOutput->size(dimh) == outputHeight &&
+               gradOutput->size(dimw) == outputWidth),
+              "invalid size of gradOutput, expected height: %d width: %d , but "
+              "got height: %d width: %d",
+              outputHeight, outputWidth, gradOutput->size(dimh),
+              gradOutput->size(dimw));
+   }
+ }
+
+ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+                              at::Tensor offset, at::Tensor output,
+                              at::Tensor columns, at::Tensor ones, int kW,
+                              int kH, int dW, int dH, int padW, int padH,
+                              int dilationW, int dilationH, int group,
+                              int deformable_group, int im2col_step)
+ {
+   // todo: resize columns to include im2col: done
+   // todo: add im2col_step as input
+   // todo: add new output buffer and transpose it to output (or directly
+   //       transpose output)
+   // todo: possibly change data indexing because of parallel_imgs
+
+   shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+               dilationH, dilationW, group, deformable_group);
+
+   input = input.contiguous();
+   offset = offset.contiguous();
+   weight = weight.contiguous();
+
+   int batch = 1;
+   if (input.ndimension() == 3) {
+     // Force batch
+     batch = 0;
+     input.unsqueeze_(0);
+     offset.unsqueeze_(0);
+   }
+
+   // todo: assert batchsize dividable by im2col_step
+
+   long batchSize = input.size(0);
+   long nInputPlane = input.size(1);
+   long inputHeight = input.size(2);
+   long inputWidth = input.size(3);
+
+   long nOutputPlane = weight.size(0);
+
+   long outputWidth =
+       (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+   long outputHeight =
+       (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+   AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+   output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+                         outputHeight, outputWidth});
+   columns = at::zeros(
+       {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+       input.options());
+
+   if (ones.ndimension() != 2 ||
+       ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+     ones = at::ones({outputHeight, outputWidth}, input.options());
+   }
+
+   input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                       inputHeight, inputWidth});
+   offset =
+       offset.view({batchSize / im2col_step, im2col_step,
+                    deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+   at::Tensor output_buffer =
+       at::zeros({batchSize / im2col_step, nOutputPlane,
+                  im2col_step * outputHeight, outputWidth},
+                 output.options());
+
+   output_buffer = output_buffer.view(
+       {output_buffer.size(0), group, output_buffer.size(1) / group,
+        output_buffer.size(2), output_buffer.size(3)});
+
+   for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+     deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+                       inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+                       dilationW, im2col_step, deformable_group, columns);
+
+     columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+     weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                           weight.size(2), weight.size(3)});
+
+     for (int g = 0; g < group; g++) {
+       output_buffer[elt][g] = output_buffer[elt][g]
+                                   .flatten(1)
+                                   .addmm_(weight[g].flatten(1), columns[g])
+                                   .view_as(output_buffer[elt][g]);
+     }
+   }
+
+   output_buffer = output_buffer.view(
+       {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+        output_buffer.size(3), output_buffer.size(4)});
+
+   output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+                                       im2col_step, outputHeight, outputWidth});
+   output_buffer.transpose_(1, 2);
+   output.copy_(output_buffer);
+   output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+   input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+   offset = offset.view(
+       {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+   if (batch == 0) {
+     output = output.view({nOutputPlane, outputHeight, outputWidth});
+     input = input.view({nInputPlane, inputHeight, inputWidth});
+     offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+   }
+
+   return 1;
+ }
+
+ int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+                                     at::Tensor gradOutput, at::Tensor gradInput,
+                                     at::Tensor gradOffset, at::Tensor weight,
+                                     at::Tensor columns, int kW, int kH, int dW,
+                                     int dH, int padW, int padH, int dilationW,
+                                     int dilationH, int group,
+                                     int deformable_group, int im2col_step)
+ {
+   shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+               dilationH, dilationW, group, deformable_group);
+
+   input = input.contiguous();
+   offset = offset.contiguous();
+   gradOutput = gradOutput.contiguous();
+   weight = weight.contiguous();
+
+   int batch = 1;
+
+   if (input.ndimension() == 3) {
+     // Force batch
+     batch = 0;
+     input = input.view({1, input.size(0), input.size(1), input.size(2)});
+     offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+     gradOutput = gradOutput.view(
+         {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+   }
+
+   long batchSize = input.size(0);
+   long nInputPlane = input.size(1);
+   long inputHeight = input.size(2);
+   long inputWidth = input.size(3);
+
+   long nOutputPlane = weight.size(0);
+
+   long outputWidth =
+       (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+   long outputHeight =
+       (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+   AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+   gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+   columns = at::zeros(
+       {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+       input.options());
+
+   // change order of grad output
+   gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+                                 nOutputPlane, outputHeight, outputWidth});
+   gradOutput.transpose_(1, 2);
+
+   gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                               inputHeight, inputWidth});
+   input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                       inputHeight, inputWidth});
+   gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+                                 deformable_group * 2 * kH * kW, outputHeight,
+                                 outputWidth});
+   offset =
+       offset.view({batchSize / im2col_step, im2col_step,
+                    deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+   for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+     // divide into groups
+     columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+     weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                           weight.size(2), weight.size(3)});
+     gradOutput = gradOutput.view(
+         {gradOutput.size(0), group, gradOutput.size(1) / group,
+          gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+     for (int g = 0; g < group; g++) {
+       columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+                                      gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+     }
+
+     columns =
+         columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+     gradOutput = gradOutput.view(
+         {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+          gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+
+     deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+                             inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+                             dilationH, dilationW, im2col_step, deformable_group,
+                             gradOffset[elt]);
+
+     deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+                       inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+                       dilationW, im2col_step, deformable_group, gradInput[elt]);
+   }
+
+   gradOutput.transpose_(1, 2);
+   gradOutput =
+       gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+   gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+   input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+   gradOffset = gradOffset.view(
+       {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+   offset = offset.view(
+       {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+   if (batch == 0) {
+     gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+     input = input.view({nInputPlane, inputHeight, inputWidth});
+     gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+     offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+     gradOffset =
+         gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+   }
+
+   return 1;
+ }
+
+ int deform_conv_backward_parameters_cuda(
+     at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+     at::Tensor gradWeight, // at::Tensor gradBias,
+     at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+     int padW, int padH, int dilationW, int dilationH, int group,
+     int deformable_group, float scale, int im2col_step)
+ {
+   // todo: transpose and reshape outGrad
+   // todo: reshape columns
+   // todo: add im2col_step as input
+
+   shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+               padW, dilationH, dilationW, group, deformable_group);
+
+   input = input.contiguous();
+   offset = offset.contiguous();
+   gradOutput = gradOutput.contiguous();
+
+   int batch = 1;
+
+   if (input.ndimension() == 3) {
+     // Force batch
+     batch = 0;
+     input = input.view(
+         at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+     gradOutput = gradOutput.view(
+         {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+   }
+
+   long batchSize = input.size(0);
+   long nInputPlane = input.size(1);
+   long inputHeight = input.size(2);
+   long inputWidth = input.size(3);
+
+   long nOutputPlane = gradWeight.size(0);
+
+   long outputWidth =
+       (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+   long outputHeight =
+       (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+   AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+   columns = at::zeros(
+       {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+       input.options());
+
+   gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+                                 nOutputPlane, outputHeight, outputWidth});
+   gradOutput.transpose_(1, 2);
+
+   at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+   gradOutputBuffer =
+       gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+                              outputHeight, outputWidth});
+   gradOutputBuffer.copy_(gradOutput);
+   gradOutputBuffer =
+       gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+                              im2col_step * outputHeight, outputWidth});
+
+   gradOutput.transpose_(1, 2);
+   gradOutput =
+       gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+   input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                       inputHeight, inputWidth});
+   offset =
+       offset.view({batchSize / im2col_step, im2col_step,
+                    deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+   for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+     deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+                       inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+                       dilationW, im2col_step, deformable_group, columns);
+
+     // divide into groups
+     gradOutputBuffer = gradOutputBuffer.view(
+         {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+          gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+     columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+     gradWeight =
+         gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+                          gradWeight.size(2), gradWeight.size(3)});
+
+     for (int g = 0; g < group; g++) {
+       gradWeight[g] = gradWeight[g]
+                           .flatten(1)
+                           .addmm_(gradOutputBuffer[elt][g].flatten(1),
+                                   columns[g].transpose(1, 0), 1.0, scale)
+                           .view_as(gradWeight[g]);
+     }
+     gradOutputBuffer = gradOutputBuffer.view(
+         {gradOutputBuffer.size(0),
+          gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+          gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+     columns =
+         columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+     gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+                                   gradWeight.size(2), gradWeight.size(3),
+                                   gradWeight.size(4)});
+   }
+
+   input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+   offset = offset.view(
+       {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+   if (batch == 0) {
+     gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+     input = input.view({nInputPlane, inputHeight, inputWidth});
+   }
+
+   return 1;
+ }
+
+ void modulated_deform_conv_cuda_forward(
+     at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+     at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+     int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+     const int pad_h, const int pad_w, const int dilation_h,
+     const int dilation_w, const int group, const int deformable_group,
+     const bool with_bias)
+ {
+   AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+   AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+   const int batch = input.size(0);
+   const int channels = input.size(1);
+   const int height = input.size(2);
+   const int width = input.size(3);
+
+   const int channels_out = weight.size(0);
+   const int channels_kernel = weight.size(1);
+   const int kernel_h_ = weight.size(2);
+   const int kernel_w_ = weight.size(3);
+
+   if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+     AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+              kernel_h, kernel_w, kernel_h_, kernel_w_);
+   if (channels != channels_kernel * group)
+     AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+              channels, channels_kernel * group);
+
+   const int height_out =
+       (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+   const int width_out =
+       (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+   if (ones.ndimension() != 2 ||
+       ones.size(0) * ones.size(1) < height_out * width_out) {
+     // Resize plane and fill with ones...
+     ones = at::ones({height_out, width_out}, input.options());
+   }
+
+   // resize output
+   output = output.view({batch, channels_out, height_out, width_out}).zero_();
+   // resize temporary columns
+   columns =
+       at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+                 input.options());
+
+   output = output.view({output.size(0), group, output.size(1) / group,
+                         output.size(2), output.size(3)});
+
+   for (int b = 0; b < batch; b++) {
+     modulated_deformable_im2col_cuda(
+         input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+         width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+         dilation_h, dilation_w, deformable_group, columns);
+
+     // divide into groups
+     weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                           weight.size(2), weight.size(3)});
+     columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+     for (int g = 0; g < group; g++) {
+       output[b][g] = output[b][g]
+                          .flatten(1)
+                          .addmm_(weight[g].flatten(1), columns[g])
+                          .view_as(output[b][g]);
+     }
+
+     weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+                           weight.size(3), weight.size(4)});
+     columns =
+         columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+   }
+
+   output = output.view({output.size(0), output.size(1) * output.size(2),
+                         output.size(3), output.size(4)});
+
+   if (with_bias) {
+     output += bias.view({1, bias.size(0), 1, 1});
+   }
+ }
+
+ void modulated_deform_conv_cuda_backward(
+     at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+     at::Tensor offset, at::Tensor mask, at::Tensor columns,
+     at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+     at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+     int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+     int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+     const bool with_bias)
+ {
+   AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+   AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+   const int batch = input.size(0);
+   const int channels = input.size(1);
+   const int height = input.size(2);
+   const int width = input.size(3);
+
+   const int channels_kernel = weight.size(1);
+   const int kernel_h_ = weight.size(2);
+   const int kernel_w_ = weight.size(3);
+   if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+     AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+              kernel_h, kernel_w, kernel_h_, kernel_w_);
+   if (channels != channels_kernel * group)
+     AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+              channels, channels_kernel * group);
+
+   const int height_out =
+       (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+   const int width_out =
+       (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+   if (ones.ndimension() != 2 ||
+       ones.size(0) * ones.size(1) < height_out * width_out) {
+     // Resize plane and fill with ones...
+     ones = at::ones({height_out, width_out}, input.options());
+   }
+
+   grad_input = grad_input.view({batch, channels, height, width});
+   columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+                       input.options());
+
+   grad_output =
+       grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+                         grad_output.size(2), grad_output.size(3)});
+
+   for (int b = 0; b < batch; b++) {
+     // divide into groups
+     columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+     weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                           weight.size(2), weight.size(3)});
+
+     for (int g = 0; g < group; g++) {
+       columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+                         grad_output[b][g].flatten(1), 0.0f, 1.0f);
+     }
+
+     columns =
+         columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+     weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+                           weight.size(3), weight.size(4)});
+
+     // gradient w.r.t. input coordinate data
+     modulated_deformable_col2im_coord_cuda(
+         columns, input[b], offset[b], mask[b], 1, channels, height, width,
+         height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+         stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+         grad_mask[b]);
+     // gradient w.r.t. input data
+     modulated_deformable_col2im_cuda(
+         columns, offset[b], mask[b], 1, channels, height, width, height_out,
+         width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+         dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+     // gradient w.r.t. weight, dWeight should accumulate across the batch and
+     // group
+     modulated_deformable_im2col_cuda(
+         input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+         width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+         dilation_h, dilation_w, deformable_group, columns);
+
+     columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+     grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+                                     grad_weight.size(1), grad_weight.size(2),
+                                     grad_weight.size(3)});
+     if (with_bias)
+       grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
+     for (int g = 0; g < group; g++) {
+       grad_weight[g] =
+           grad_weight[g]
+               .flatten(1)
+               .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+               .view_as(grad_weight[g]);
+       if (with_bias) {
+         grad_bias[g] =
+             grad_bias[g]
+                 .view({-1, 1})
+                 .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+                 .view(-1);
+       }
+     }
+
+     columns =
+         columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+     grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+                                     grad_weight.size(2), grad_weight.size(3),
+                                     grad_weight.size(4)});
+     if (with_bias)
+       grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+   }
+   grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+                                   grad_output.size(2), grad_output.size(3),
+                                   grad_output.size(4)});
+ }
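
Note: every entry point above recomputes the standard convolution output size, (in + 2*pad - (dilation*(k-1)+1))/stride + 1, and then allocates a temporary `columns` buffer of shape {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth} for the im2col matrix-multiply formulation. A small C++ sketch of that arithmetic; the layer parameters below are hypothetical, not taken from any config in this commit.

#include <cstdio>

// Output-size formula shared by all the deform-conv entry points above.
int conv_out(int in, int pad, int dilation, int k, int stride) {
  return (in + 2 * pad - (dilation * (k - 1) + 1)) / stride + 1;
}

int main() {
  // e.g. a 3x3 deformable conv, stride 1, pad 1, dilation 1 on a 56x56 map
  int H = 56, W = 56, kH = 3, kW = 3, padH = 1, padW = 1,
      dH = 1, dW = 1, dilH = 1, dilW = 1, nInputPlane = 256, im2col_step = 1;
  int outH = conv_out(H, padH, dilH, kH, dH);
  int outW = conv_out(W, padW, dilW, kW, dW);
  std::printf("output: %dx%d\n", outH, outW);  // 56x56: the same-padding case
  // Shape of the temporary `columns` buffer the host functions allocate:
  std::printf("columns: {%d, %d}\n",
              nInputPlane * kW * kH, im2col_step * outH * outW);
  // The offset tensor carries 2 values (dy, dx) per kernel tap per group:
  // channels = deformable_group * 2 * kH * kW, spatial size outH x outW.
  return 0;
}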
maskrcnn_benchmark/csrc/cuda/deform_conv_kernel_cuda.cu ADDED
@@ -0,0 +1,874 @@
+ /*!
+  ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
+  *
+  * COPYRIGHT
+  *
+  * All contributions by the University of California:
+  * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+  * All rights reserved.
+  *
+  * All other contributions:
+  * Copyright (c) 2014-2017, the respective contributors
+  * All rights reserved.
+  *
+  * Caffe uses a shared copyright model: each contributor holds copyright over
+  * their contributions to Caffe. The project versioning records all such
+  * contribution and copyright details. If a contributor wants to further mark
+  * their specific copyright on a particular contribution, they should indicate
+  * their copyright solely in the commit message of the change when it is
+  * committed.
+  *
+  * LICENSE
+  *
+  * Redistribution and use in source and binary forms, with or without
+  * modification, are permitted provided that the following conditions are met:
+  *
+  * 1. Redistributions of source code must retain the above copyright notice, this
+  * list of conditions and the following disclaimer.
+  * 2. Redistributions in binary form must reproduce the above copyright notice,
+  * this list of conditions and the following disclaimer in the documentation
+  * and/or other materials provided with the distribution.
+  *
+  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+  *
+  * CONTRIBUTION AGREEMENT
+  *
+  * By contributing to the BVLC/caffe repository through pull-request, comment,
+  * or otherwise, the contributor releases their content to the
+  * license and copyright terms herein.
+  *
+  ***************** END Caffe Copyright Notice and Disclaimer ********************
+  *
+  * Copyright (c) 2018 Microsoft
+  * Licensed under The MIT License [see LICENSE for details]
+  * \file modulated_deformable_im2col.cuh
+  * \brief Function definitions of converting an image to
+  * column matrix based on kernel, padding, dilation, and offset.
+  * These functions are mainly used in deformable convolution operators.
+  * \ref: https://arxiv.org/abs/1703.06211
+  * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+  */
+
+ // modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+ #include <ATen/ATen.h>
+ #include <THC/THCAtomics.cuh>
+ #include <stdio.h>
+ #include <math.h>
+ #include <float.h>
+
+ using namespace at;
+
+ #define CUDA_KERNEL_LOOP(i, n)                                 \
+   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+        i += blockDim.x * gridDim.x)
+
+ const int CUDA_NUM_THREADS = 1024;
+ const int kMaxGridNum = 65535;
+ inline int GET_BLOCKS(const int N)
+ {
+   return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+ }
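+
+ // Launch-configuration note: the kernels below are launched with
+ // GET_BLOCKS(N) blocks of CUDA_NUM_THREADS threads each, and walk their
+ // work items with the grid-stride loop CUDA_KERNEL_LOOP, so any N is
+ // covered even when the grid is clamped to kMaxGridNum. For example,
+ // N = 2,000,000 gives min(65535, ceil(2000000 / 1024)) = 1954 blocks,
+ // and each thread strides by blockDim.x * gridDim.x until it passes N.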
+
+ /*
+ const int CUDA_NUM_THREADS = 1024;
+
+ inline int GET_BLOCKS(const int N)
+ {
+   return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+ }*/
+
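+ // Bilinear sampler used by the im2col kernel: reads the four integer
+ // neighbours of the fractional location (h, w) inside one channel plane of
+ // bottom_data and blends them with the usual bilinear weights. Neighbours
+ // falling outside [0, height) x [0, width) contribute zero, which acts as
+ // implicit zero padding for offsets that point off the feature map.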
+ template <typename scalar_t>
+ __device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+                                                const int height, const int width, scalar_t h, scalar_t w)
+ {
+   int h_low = floor(h);
+   int w_low = floor(w);
+   int h_high = h_low + 1;
+   int w_high = w_low + 1;
+
+   scalar_t lh = h - h_low;
+   scalar_t lw = w - w_low;
+   scalar_t hh = 1 - lh, hw = 1 - lw;
+
+   scalar_t v1 = 0;
+   if (h_low >= 0 && w_low >= 0)
+     v1 = bottom_data[h_low * data_width + w_low];
+   scalar_t v2 = 0;
+   if (h_low >= 0 && w_high <= width - 1)
+     v2 = bottom_data[h_low * data_width + w_high];
+   scalar_t v3 = 0;
+   if (h_high <= height - 1 && w_low >= 0)
+     v3 = bottom_data[h_high * data_width + w_low];
+   scalar_t v4 = 0;
+   if (h_high <= height - 1 && w_high <= width - 1)
+     v4 = bottom_data[h_high * data_width + w_high];
+
+   scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+   scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+   return val;
+ }
+
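+ // Backward counterpart of the bilinear sampler: returns the weight with
+ // which the input pixel (h, w) contributed to the sample taken at the
+ // fractional location (argmax_h, argmax_w), i.e. the partial derivative of
+ // the sampled value with respect to that pixel. Zero if the sample lies
+ // entirely outside the feature map.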
+ template <typename scalar_t>
+ __device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                         const int h, const int w, const int height, const int width)
+ {
+   if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+   {
+     // empty
+     return 0;
+   }
+
+   int argmax_h_low = floor(argmax_h);
+   int argmax_w_low = floor(argmax_w);
+   int argmax_h_high = argmax_h_low + 1;
+   int argmax_w_high = argmax_w_low + 1;
+
+   scalar_t weight = 0;
+   if (h == argmax_h_low && w == argmax_w_low)
+     weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+   if (h == argmax_h_low && w == argmax_w_high)
+     weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+   if (h == argmax_h_high && w == argmax_w_low)
+     weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+   if (h == argmax_h_high && w == argmax_w_high)
+     weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+   return weight;
+ }
+
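+ // Gradient of the bilinearly sampled value with respect to the sampling
+ // coordinate itself: bp_dir == 0 differentiates along h, bp_dir == 1 along
+ // w. This is what turns gradients on the sampled values into gradients on
+ // the learned offsets.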
+ template <typename scalar_t>
+ __device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                           const int height, const int width, const scalar_t *im_data,
+                                           const int data_width, const int bp_dir)
+ {
+   if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+   {
+     // empty
+     return 0;
+   }
+
+   int argmax_h_low = floor(argmax_h);
+   int argmax_w_low = floor(argmax_w);
+   int argmax_h_high = argmax_h_low + 1;
+   int argmax_w_high = argmax_w_low + 1;
+
+   scalar_t weight = 0;
+
+   if (bp_dir == 0)
+   {
+     if (argmax_h_low >= 0 && argmax_w_low >= 0)
+       weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+     if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+       weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+     if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+       weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+     if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+       weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+   }
+   else if (bp_dir == 1)
+   {
+     if (argmax_h_low >= 0 && argmax_w_low >= 0)
+       weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+     if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+       weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+     if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+       weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+     if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+       weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+   }
+
+   return weight;
+ }
+
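+ // Deformable im2col: like the classic im2col lowering, but each kernel tap
+ // (i, j) is read at a fractional position displaced by a learned offset and
+ // sampled bilinearly. One thread handles one (channel, batch, h_col, w_col)
+ // output location and writes kernel_h * kernel_w entries of the column
+ // buffer, whose layout is roughly
+ // (channels * kernel_h * kernel_w, batch_size, height_col, width_col).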
+ template <typename scalar_t>
+ __global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
+                                              const int height, const int width, const int kernel_h, const int kernel_w,
+                                              const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+                                              const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
+                                              const int batch_size, const int num_channels, const int deformable_group,
+                                              const int height_col, const int width_col,
+                                              scalar_t *data_col)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     // index of the output matrix element handled by this thread
+     const int w_col = index % width_col;
+     const int h_col = (index / width_col) % height_col;
+     const int b_col = (index / width_col / height_col) % batch_size;
+     const int c_im = (index / width_col / height_col) / batch_size;
+     const int c_col = c_im * kernel_h * kernel_w;
+
+     // compute deformable group index
+     const int deformable_group_index = c_im / channel_per_deformable_group;
+
+     const int h_in = h_col * stride_h - pad_h;
+     const int w_in = w_col * stride_w - pad_w;
+     scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+     //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+     const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+     const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+     for (int i = 0; i < kernel_h; ++i)
+     {
+       for (int j = 0; j < kernel_w; ++j)
+       {
+         const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+         const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+         const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+         const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+         scalar_t val = static_cast<scalar_t>(0);
+         const scalar_t h_im = h_in + i * dilation_h + offset_h;
+         const scalar_t w_im = w_in + j * dilation_w + offset_w;
+         if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+         {
+           //const scalar_t map_h = i * dilation_h + offset_h;
+           //const scalar_t map_w = j * dilation_w + offset_w;
+           //const int cur_height = height - h_in;
+           //const int cur_width = width - w_in;
+           //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+           val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+         }
+         *data_col_ptr = val;
+         data_col_ptr += batch_size * height_col * width_col;
+       }
+     }
+   }
+ }
+
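+ // Host-side launcher for the kernel above: derives the output spatial size
+ // from the convolution hyper-parameters and dispatches over floating-point
+ // dtypes. A rough usage sketch (the shapes are inferred from the kernel,
+ // not asserted here): for an input of shape (parallel_imgs, channels,
+ // height, width), data_col should be pre-allocated to hold
+ // channels * ksize_h * ksize_w rows of
+ // parallel_imgs * height_col * width_col values each.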
+ void deformable_im2col(
+     const at::Tensor data_im, const at::Tensor data_offset, const int channels,
+     const int height, const int width, const int ksize_h, const int ksize_w,
+     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+     const int dilation_h, const int dilation_w, const int parallel_imgs,
+     const int deformable_group, at::Tensor data_col)
+ {
+   // num_axes should be smaller than block size
+   // todo: check parallel_imgs is correctly passed in
+   int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+   int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+   int num_kernels = channels * height_col * width_col * parallel_imgs;
+   int channel_per_deformable_group = channels / deformable_group;
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+       data_im.type(), "deformable_im2col_gpu", ([&] {
+         const scalar_t *data_im_ = data_im.data<scalar_t>();
+         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+         scalar_t *data_col_ = data_col.data<scalar_t>();
+
+         deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+             num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
+             pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+             channel_per_deformable_group, parallel_imgs, channels, deformable_group,
+             height_col, width_col, data_col_);
+       }));
+
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
+   }
+ }
+
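+ // Deformable col2im: inverse scatter of the lowering above. Each column
+ // gradient is distributed back onto the input image through the bilinear
+ // weights; atomicAdd is required because many column entries map to the
+ // same input pixel. The 5x5 (dy, dx) window conservatively covers the at
+ // most four integer neighbours within distance < 1 of the fractional
+ // sample position.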
+ template <typename scalar_t>
+ __global__ void deformable_col2im_gpu_kernel(
+     const int n, const scalar_t *data_col, const scalar_t *data_offset,
+     const int channels, const int height, const int width,
+     const int kernel_h, const int kernel_w,
+     const int pad_h, const int pad_w,
+     const int stride_h, const int stride_w,
+     const int dilation_h, const int dilation_w,
+     const int channel_per_deformable_group,
+     const int batch_size, const int deformable_group,
+     const int height_col, const int width_col,
+     scalar_t *grad_im)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     const int j = (index / width_col / height_col / batch_size) % kernel_w;
+     const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+     const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+     // compute the start and end of the output
+
+     const int deformable_group_index = c / channel_per_deformable_group;
+
+     int w_out = index % width_col;
+     int h_out = (index / width_col) % height_col;
+     int b = (index / width_col / height_col) % batch_size;
+     int w_in = w_out * stride_w - pad_w;
+     int h_in = h_out * stride_h - pad_h;
+
+     const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
+                                                         2 * kernel_h * kernel_w * height_col * width_col;
+     const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+     const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+     const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+     const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+     const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+     const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+     const scalar_t cur_top_grad = data_col[index];
+     const int cur_h = (int)cur_inv_h_data;
+     const int cur_w = (int)cur_inv_w_data;
+     for (int dy = -2; dy <= 2; dy++)
+     {
+       for (int dx = -2; dx <= 2; dx++)
+       {
+         if (cur_h + dy >= 0 && cur_h + dy < height &&
+             cur_w + dx >= 0 && cur_w + dx < width &&
+             abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+             abs(cur_inv_w_data - (cur_w + dx)) < 1)
+         {
+           int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+           scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+           atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+         }
+       }
+     }
+   }
+ }
+
+ void deformable_col2im(
+     const at::Tensor data_col, const at::Tensor data_offset, const int channels,
+     const int height, const int width, const int ksize_h,
+     const int ksize_w, const int pad_h, const int pad_w,
+     const int stride_h, const int stride_w,
+     const int dilation_h, const int dilation_w,
+     const int parallel_imgs, const int deformable_group,
+     at::Tensor grad_im)
+ {
+   // todo: make sure parallel_imgs is passed in correctly
+   int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+   int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+   int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
+   int channel_per_deformable_group = channels / deformable_group;
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+       data_col.type(), "deformable_col2im_gpu", ([&] {
+         const scalar_t *data_col_ = data_col.data<scalar_t>();
+         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+         scalar_t *grad_im_ = grad_im.data<scalar_t>();
+
+         deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+             num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
+             ksize_w, pad_h, pad_w, stride_h, stride_w,
+             dilation_h, dilation_w, channel_per_deformable_group,
+             parallel_imgs, deformable_group, height_col, width_col, grad_im_);
+       }));
+
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
+   }
+ }
+
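+ // Gradient with respect to the offset channels: for each offset component
+ // (the channels alternate h and w), accumulate, over all input channels of
+ // its deformable group, the coordinate-weight of the sample times the
+ // incoming column gradient.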
+ template <typename scalar_t>
+ __global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
+                                                    const scalar_t *data_im, const scalar_t *data_offset,
+                                                    const int channels, const int height, const int width,
+                                                    const int kernel_h, const int kernel_w,
+                                                    const int pad_h, const int pad_w,
+                                                    const int stride_h, const int stride_w,
+                                                    const int dilation_h, const int dilation_w,
+                                                    const int channel_per_deformable_group,
+                                                    const int batch_size, const int offset_channels, const int deformable_group,
+                                                    const int height_col, const int width_col, scalar_t *grad_offset)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     scalar_t val = 0;
+     int w = index % width_col;
+     int h = (index / width_col) % height_col;
+     int c = (index / width_col / height_col) % offset_channels;
+     int b = (index / width_col / height_col) / offset_channels;
+     // compute the start and end of the output
+
+     const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+     const int col_step = kernel_h * kernel_w;
+     int cnt = 0;
+     const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
+                                                   batch_size * width_col * height_col;
+     const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
+                                                 channel_per_deformable_group / kernel_h / kernel_w * height * width;
+     const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
+                                                         kernel_h * kernel_w * height_col * width_col;
+
+     const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+     for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+     {
+       const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+       const int bp_dir = offset_c % 2;
+
+       int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+       int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+       int w_out = col_pos % width_col;
+       int h_out = (col_pos / width_col) % height_col;
+       int w_in = w_out * stride_w - pad_w;
+       int h_in = h_out * stride_h - pad_h;
+       const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+       const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+       const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+       const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+       scalar_t inv_h = h_in + i * dilation_h + offset_h;
+       scalar_t inv_w = w_in + j * dilation_w + offset_w;
+       if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+       {
+         inv_h = inv_w = -2;
+       }
+       const scalar_t weight = get_coordinate_weight(
+           inv_h, inv_w,
+           height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+       val += weight * data_col_ptr[col_pos];
+       cnt += 1;
+     }
+
+     grad_offset[index] = val;
+   }
+ }
+
+ void deformable_col2im_coord(
+     const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
+     const int channels, const int height, const int width, const int ksize_h,
+     const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
+     const int stride_w, const int dilation_h, const int dilation_w,
+     const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
+ {
+   int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+   int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+   int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
+   int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+       data_col.type(), "deformable_col2im_coord_gpu", ([&] {
+         const scalar_t *data_col_ = data_col.data<scalar_t>();
+         const scalar_t *data_im_ = data_im.data<scalar_t>();
+         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+         scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
+
+         deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+             num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
+             ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
+             dilation_h, dilation_w, channel_per_deformable_group,
+             parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
+             height_col, width_col, grad_offset_);
+       }));
+
+   // check for launch errors, matching the other launchers in this file
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in deformable_col2im_coord: %s\n", cudaGetErrorString(err));
+   }
+ }
+
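+ // The dmcn_* helpers below are near-verbatim copies of the three device
+ // helpers above, kept separate for the modulated (DCNv2) path, in which
+ // every sample is additionally scaled by a learned mask value.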
+ template <typename scalar_t>
+ __device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+                                          const int height, const int width, scalar_t h, scalar_t w)
+ {
+   int h_low = floor(h);
+   int w_low = floor(w);
+   int h_high = h_low + 1;
+   int w_high = w_low + 1;
+
+   scalar_t lh = h - h_low;
+   scalar_t lw = w - w_low;
+   scalar_t hh = 1 - lh, hw = 1 - lw;
+
+   scalar_t v1 = 0;
+   if (h_low >= 0 && w_low >= 0)
+     v1 = bottom_data[h_low * data_width + w_low];
+   scalar_t v2 = 0;
+   if (h_low >= 0 && w_high <= width - 1)
+     v2 = bottom_data[h_low * data_width + w_high];
+   scalar_t v3 = 0;
+   if (h_high <= height - 1 && w_low >= 0)
+     v3 = bottom_data[h_high * data_width + w_low];
+   scalar_t v4 = 0;
+   if (h_high <= height - 1 && w_high <= width - 1)
+     v4 = bottom_data[h_high * data_width + w_high];
+
+   scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+   scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+   return val;
+ }
+
+ template <typename scalar_t>
+ __device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                              const int h, const int w, const int height, const int width)
+ {
+   if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+   {
+     // empty
+     return 0;
+   }
+
+   int argmax_h_low = floor(argmax_h);
+   int argmax_w_low = floor(argmax_w);
+   int argmax_h_high = argmax_h_low + 1;
+   int argmax_w_high = argmax_w_low + 1;
+
+   scalar_t weight = 0;
+   if (h == argmax_h_low && w == argmax_w_low)
+     weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+   if (h == argmax_h_low && w == argmax_w_high)
+     weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+   if (h == argmax_h_high && w == argmax_w_low)
+     weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+   if (h == argmax_h_high && w == argmax_w_high)
+     weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+   return weight;
+ }
+
+ template <typename scalar_t>
+ __device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                                const int height, const int width, const scalar_t *im_data,
+                                                const int data_width, const int bp_dir)
+ {
+   if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+   {
+     // empty
+     return 0;
+   }
+
+   int argmax_h_low = floor(argmax_h);
+   int argmax_w_low = floor(argmax_w);
+   int argmax_h_high = argmax_h_low + 1;
+   int argmax_w_high = argmax_w_low + 1;
+
+   scalar_t weight = 0;
+
+   if (bp_dir == 0)
+   {
+     if (argmax_h_low >= 0 && argmax_w_low >= 0)
+       weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+     if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+       weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+     if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+       weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+     if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+       weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+   }
+   else if (bp_dir == 1)
+   {
+     if (argmax_h_low >= 0 && argmax_w_low >= 0)
+       weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+     if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+       weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+     if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+       weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+     if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+       weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+   }
+
+   return weight;
+ }
+
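+ // Modulated deformable im2col: identical to deformable_im2col_gpu_kernel
+ // except that each bilinearly sampled value is multiplied by a per-location
+ // mask read from data_mask (one mask channel per kernel tap).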
+ template <typename scalar_t>
+ __global__ void modulated_deformable_im2col_gpu_kernel(const int n,
+                                                        const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
+                                                        const int height, const int width, const int kernel_h, const int kernel_w,
+                                                        const int pad_h, const int pad_w,
+                                                        const int stride_h, const int stride_w,
+                                                        const int dilation_h, const int dilation_w,
+                                                        const int channel_per_deformable_group,
+                                                        const int batch_size, const int num_channels, const int deformable_group,
+                                                        const int height_col, const int width_col,
+                                                        scalar_t *data_col)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     // index of the output matrix element handled by this thread
+     const int w_col = index % width_col;
+     const int h_col = (index / width_col) % height_col;
+     const int b_col = (index / width_col / height_col) % batch_size;
+     const int c_im = (index / width_col / height_col) / batch_size;
+     const int c_col = c_im * kernel_h * kernel_w;
+
+     // compute deformable group index
+     const int deformable_group_index = c_im / channel_per_deformable_group;
+
+     const int h_in = h_col * stride_h - pad_h;
+     const int w_in = w_col * stride_w - pad_w;
+
+     scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+     //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+     const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+     const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+     const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+     for (int i = 0; i < kernel_h; ++i)
+     {
+       for (int j = 0; j < kernel_w; ++j)
+       {
+         const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+         const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+         const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+         const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+         const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+         const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+         scalar_t val = static_cast<scalar_t>(0);
+         const scalar_t h_im = h_in + i * dilation_h + offset_h;
+         const scalar_t w_im = w_in + j * dilation_w + offset_w;
+         //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
+         if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+         {
+           //const float map_h = i * dilation_h + offset_h;
+           //const float map_w = j * dilation_w + offset_w;
+           //const int cur_height = height - h_in;
+           //const int cur_width = width - w_in;
+           //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+           val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+         }
+         *data_col_ptr = val * mask;
+         data_col_ptr += batch_size * height_col * width_col;
+         //data_col_ptr += height_col * width_col;
+       }
+     }
+   }
+ }
+
+ template <typename scalar_t>
+ __global__ void modulated_deformable_col2im_gpu_kernel(const int n,
+                                                        const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
+                                                        const int channels, const int height, const int width,
+                                                        const int kernel_h, const int kernel_w,
+                                                        const int pad_h, const int pad_w,
+                                                        const int stride_h, const int stride_w,
+                                                        const int dilation_h, const int dilation_w,
+                                                        const int channel_per_deformable_group,
+                                                        const int batch_size, const int deformable_group,
+                                                        const int height_col, const int width_col,
+                                                        scalar_t *grad_im)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     const int j = (index / width_col / height_col / batch_size) % kernel_w;
+     const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+     const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+     // compute the start and end of the output
+
+     const int deformable_group_index = c / channel_per_deformable_group;
+
+     int w_out = index % width_col;
+     int h_out = (index / width_col) % height_col;
+     int b = (index / width_col / height_col) % batch_size;
+     int w_in = w_out * stride_w - pad_w;
+     int h_in = h_out * stride_h - pad_h;
+
+     const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+     const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+     const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+     const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+     const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
+     const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+     const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+     const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+     const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+     const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+     const scalar_t cur_top_grad = data_col[index] * mask;
+     const int cur_h = (int)cur_inv_h_data;
+     const int cur_w = (int)cur_inv_w_data;
+     for (int dy = -2; dy <= 2; dy++)
+     {
+       for (int dx = -2; dx <= 2; dx++)
+       {
+         if (cur_h + dy >= 0 && cur_h + dy < height &&
+             cur_w + dx >= 0 && cur_w + dx < width &&
+             abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+             abs(cur_inv_w_data - (cur_w + dx)) < 1)
+         {
+           int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+           scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+           atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+         }
+       }
+     }
+   }
+ }
+
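+ // Joint backward for offsets and mask: val accumulates the offset gradient
+ // (coordinate-weight times mask times column gradient), while mval
+ // accumulates the mask gradient (column gradient times the sampled value).
+ // mval is written once per (h, w) offset pair, on the even offset channel.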
+ template <typename scalar_t>
+ __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
+                                                              const scalar_t *data_col, const scalar_t *data_im,
+                                                              const scalar_t *data_offset, const scalar_t *data_mask,
+                                                              const int channels, const int height, const int width,
+                                                              const int kernel_h, const int kernel_w,
+                                                              const int pad_h, const int pad_w,
+                                                              const int stride_h, const int stride_w,
+                                                              const int dilation_h, const int dilation_w,
+                                                              const int channel_per_deformable_group,
+                                                              const int batch_size, const int offset_channels, const int deformable_group,
+                                                              const int height_col, const int width_col,
+                                                              scalar_t *grad_offset, scalar_t *grad_mask)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     scalar_t val = 0, mval = 0;
+     int w = index % width_col;
+     int h = (index / width_col) % height_col;
+     int c = (index / width_col / height_col) % offset_channels;
+     int b = (index / width_col / height_col) / offset_channels;
+     // compute the start and end of the output
+
+     const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+     const int col_step = kernel_h * kernel_w;
+     int cnt = 0;
+     const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
+     const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
+     const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+     const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+     const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+     for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+     {
+       const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+       const int bp_dir = offset_c % 2;
+
+       int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+       int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+       int w_out = col_pos % width_col;
+       int h_out = (col_pos / width_col) % height_col;
+       int w_in = w_out * stride_w - pad_w;
+       int h_in = h_out * stride_h - pad_h;
+       const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+       const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+       const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
+       const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+       const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+       const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+       scalar_t inv_h = h_in + i * dilation_h + offset_h;
+       scalar_t inv_w = w_in + j * dilation_w + offset_w;
+       if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+       {
+         inv_h = inv_w = -2;
+       }
+       else
+       {
+         mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
+       }
+       const scalar_t weight = dmcn_get_coordinate_weight(
+           inv_h, inv_w,
+           height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+       val += weight * data_col_ptr[col_pos] * mask;
+       cnt += 1;
+     }
+     // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
+     grad_offset[index] = val;
+     if (offset_c % 2 == 0)
+       // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
+       grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
+   }
+ }
+
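+ // Host-side launchers for the modulated path. A rough calling sketch (the
+ // shapes are inferred from the kernels above, not asserted here):
+ // data_offset carries 2 * kernel_h * kernel_w * deformable_group channels
+ // and data_mask carries kernel_h * kernel_w * deformable_group channels,
+ // both at the output resolution (height_col, width_col).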
+ void modulated_deformable_im2col_cuda(
+     const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+     const int batch_size, const int channels, const int height_im, const int width_im,
+     const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+     const int dilation_h, const int dilation_w,
+     const int deformable_group, at::Tensor data_col)
+ {
+   // num_axes should be smaller than block size
+   const int channel_per_deformable_group = channels / deformable_group;
+   const int num_kernels = channels * batch_size * height_col * width_col;
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+       data_im.type(), "modulated_deformable_im2col_gpu", ([&] {
+         const scalar_t *data_im_ = data_im.data<scalar_t>();
+         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+         const scalar_t *data_mask_ = data_mask.data<scalar_t>();
+         scalar_t *data_col_ = data_col.data<scalar_t>();
+
+         modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+             num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kernel_w,
+             pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+             batch_size, channels, deformable_group, height_col, width_col, data_col_);
+       }));
+
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+   }
+ }
+
+ void modulated_deformable_col2im_cuda(
+     const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
+     const int batch_size, const int channels, const int height_im, const int width_im,
+     const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+     const int dilation_h, const int dilation_w,
+     const int deformable_group, at::Tensor grad_im)
+ {
+   const int channel_per_deformable_group = channels / deformable_group;
+   const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+       data_col.type(), "modulated_deformable_col2im_gpu", ([&] {
+         const scalar_t *data_col_ = data_col.data<scalar_t>();
+         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+         const scalar_t *data_mask_ = data_mask.data<scalar_t>();
+         scalar_t *grad_im_ = grad_im.data<scalar_t>();
+
+         modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+             num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
+             kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+             dilation_h, dilation_w, channel_per_deformable_group,
+             batch_size, deformable_group, height_col, width_col, grad_im_);
+       }));
+
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+   }
+ }
+
+ void modulated_deformable_col2im_coord_cuda(
+     const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+     const int batch_size, const int channels, const int height_im, const int width_im,
+     const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+     const int dilation_h, const int dilation_w,
+     const int deformable_group,
+     at::Tensor grad_offset, at::Tensor grad_mask)
+ {
+   const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+   const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+       data_col.type(), "modulated_deformable_col2im_coord_gpu", ([&] {
+         const scalar_t *data_col_ = data_col.data<scalar_t>();
+         const scalar_t *data_im_ = data_im.data<scalar_t>();
+         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+         const scalar_t *data_mask_ = data_mask.data<scalar_t>();
+         scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
+         scalar_t *grad_mask_ = grad_mask.data<scalar_t>();
+
+         modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+             num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
+             kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+             dilation_h, dilation_w, channel_per_deformable_group,
+             batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+             grad_offset_, grad_mask_);
+       }));
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+   }
+ }
maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu ADDED
@@ -0,0 +1,87 @@
+ // modified from
+ // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
+
+ // based on
+ // author: Charles Shang
+ // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
+
+ #include <ATen/ATen.h>
+ #include <ATen/cuda/CUDAContext.h>
+
+ #include <THC/THC.h>
+ #include <THC/THCDeviceUtils.cuh>
+
+ #include <vector>
+ #include <iostream>
+ #include <cmath>
+
+ void DeformablePSROIPoolForward(
+     const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
+     at::Tensor out, at::Tensor top_count, const int batch, const int channels,
+     const int height, const int width, const int num_bbox,
+     const int channels_trans, const int no_trans, const float spatial_scale,
+     const int output_dim, const int group_size, const int pooled_size,
+     const int part_size, const int sample_per_part, const float trans_std);
+
+ void DeformablePSROIPoolBackwardAcc(
+     const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
+     const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
+     at::Tensor trans_grad, const int batch, const int channels,
+     const int height, const int width, const int num_bbox,
+     const int channels_trans, const int no_trans, const float spatial_scale,
+     const int output_dim, const int group_size, const int pooled_size,
+     const int part_size, const int sample_per_part, const float trans_std);
+
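+ // Thin wrappers around the deformable PS-RoI pooling kernels declared
+ // above (their definitions live in a separate compilation unit): they
+ // validate contiguity, read the tensor shapes, and forward everything to
+ // the CUDA implementations.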
+ void deform_psroi_pooling_cuda_forward(
+     at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
+     at::Tensor top_count, const int no_trans, const float spatial_scale,
+     const int output_dim, const int group_size, const int pooled_size,
+     const int part_size, const int sample_per_part, const float trans_std)
+ {
+   AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+
+   const int batch = input.size(0);
+   const int channels = input.size(1);
+   const int height = input.size(2);
+   const int width = input.size(3);
+   const int channels_trans = no_trans ? 2 : trans.size(1);
+
+   const int num_bbox = bbox.size(0);
+   if (num_bbox != out.size(0))
+     AT_ERROR("Output shape and bbox number won't match: (%d vs %d).",
+              out.size(0), num_bbox);
+
+   DeformablePSROIPoolForward(
+       input, bbox, trans, out, top_count, batch, channels, height, width,
+       num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
+       pooled_size, part_size, sample_per_part, trans_std);
+ }
+
+ void deform_psroi_pooling_cuda_backward(
+     at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
+     at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
+     const int no_trans, const float spatial_scale, const int output_dim,
+     const int group_size, const int pooled_size, const int part_size,
+     const int sample_per_part, const float trans_std)
+ {
+   AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
+   AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+
+   const int batch = input.size(0);
+   const int channels = input.size(1);
+   const int height = input.size(2);
+   const int width = input.size(3);
+   const int channels_trans = no_trans ? 2 : trans.size(1);
+
+   const int num_bbox = bbox.size(0);
+   if (num_bbox != out_grad.size(0))
+     AT_ERROR("Output shape and bbox number won't match: (%d vs %d).",
+              out_grad.size(0), num_bbox);
+
+   DeformablePSROIPoolBackwardAcc(
+       out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
+       channels, height, width, num_bbox, channels_trans, no_trans,
+       spatial_scale, output_dim, group_size, pooled_size, part_size,
+       sample_per_part, trans_std);
+ }