amaye15 commited on
Commit
757ff04
1 Parent(s): debb548

End of training

Browse files
Files changed (7) hide show
  1. README.md +191 -0
  2. config.json +261 -0
  3. config.toml +27 -0
  4. model.safetensors +3 -0
  5. preprocessor_config.json +37 -0
  6. train.ipynb +1298 -0
  7. training_args.bin +3 -0
README.md ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ base_model: google/siglip-base-patch16-224
4
+ tags:
5
+ - generated_from_trainer
6
+ datasets:
7
+ - stanford-dogs
8
+ metrics:
9
+ - accuracy
10
+ - f1
11
+ - precision
12
+ - recall
13
+ model-index:
14
+ - name: google-siglip-base-patch16-224-batch64-lr5e-05-standford-dogs
15
+ results:
16
+ - task:
17
+ name: Image Classification
18
+ type: image-classification
19
+ dataset:
20
+ name: stanford-dogs
21
+ type: stanford-dogs
22
+ config: default
23
+ split: full
24
+ args: default
25
+ metrics:
26
+ - name: Accuracy
27
+ type: accuracy
28
+ value: 0.8364917395529641
29
+ - name: F1
30
+ type: f1
31
+ value: 0.8328749982143954
32
+ - name: Precision
33
+ type: precision
34
+ value: 0.8377481660081763
35
+ - name: Recall
36
+ type: recall
37
+ value: 0.8330663170433035
38
+ ---
39
+
40
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
41
+ should probably proofread and complete it, then remove this comment. -->
42
+
43
+ # google-siglip-base-patch16-224-batch64-lr5e-05-standford-dogs
44
+
45
+ This model is a fine-tuned version of [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) on the stanford-dogs dataset.
46
+ It achieves the following results on the evaluation set:
47
+ - Loss: 0.5612
48
+ - Accuracy: 0.8365
49
+ - F1: 0.8329
50
+ - Precision: 0.8377
51
+ - Recall: 0.8331
52
+
53
+ ## Model description
54
+
55
+ More information needed
56
+
57
+ ## Intended uses & limitations
58
+
59
+ More information needed
60
+
61
+ ## Training and evaluation data
62
+
63
+ More information needed
64
+
65
+ ## Training procedure
66
+
67
+ ### Training hyperparameters
68
+
69
+ The following hyperparameters were used during training:
70
+ - learning_rate: 5e-05
71
+ - train_batch_size: 64
72
+ - eval_batch_size: 64
73
+ - seed: 42
74
+ - gradient_accumulation_steps: 4
75
+ - total_train_batch_size: 256
76
+ - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
77
+ - lr_scheduler_type: linear
78
+ - training_steps: 1000
79
+
80
+ ### Training results
81
+
82
+ | Training Loss | Epoch | Step | Validation Loss | Accuracy | F1 | Precision | Recall |
83
+ |:-------------:|:-------:|:----:|:---------------:|:--------:|:------:|:---------:|:------:|
84
+ | 4.822 | 0.1550 | 10 | 4.2549 | 0.0782 | 0.0493 | 0.0987 | 0.0726 |
85
+ | 4.236 | 0.3101 | 20 | 3.5279 | 0.1907 | 0.1507 | 0.2201 | 0.1830 |
86
+ | 3.5066 | 0.4651 | 30 | 2.5316 | 0.3319 | 0.2941 | 0.4180 | 0.3205 |
87
+ | 2.8064 | 0.6202 | 40 | 2.1243 | 0.4361 | 0.4090 | 0.5324 | 0.4282 |
88
+ | 2.441 | 0.7752 | 50 | 1.5798 | 0.5510 | 0.5250 | 0.6242 | 0.5438 |
89
+ | 2.0985 | 0.9302 | 60 | 1.4242 | 0.5843 | 0.5577 | 0.6400 | 0.5768 |
90
+ | 1.8689 | 1.0853 | 70 | 1.1481 | 0.6625 | 0.6456 | 0.7143 | 0.6565 |
91
+ | 1.6588 | 1.2403 | 80 | 1.1937 | 0.6465 | 0.6361 | 0.7062 | 0.6439 |
92
+ | 1.5807 | 1.3953 | 90 | 0.9818 | 0.7058 | 0.6890 | 0.7438 | 0.6981 |
93
+ | 1.4851 | 1.5504 | 100 | 1.0181 | 0.7000 | 0.6839 | 0.7373 | 0.6959 |
94
+ | 1.5033 | 1.7054 | 110 | 1.0169 | 0.6914 | 0.6845 | 0.7490 | 0.6883 |
95
+ | 1.3022 | 1.8605 | 120 | 0.9087 | 0.7276 | 0.7170 | 0.7643 | 0.7222 |
96
+ | 1.3106 | 2.0155 | 130 | 0.8385 | 0.7432 | 0.7352 | 0.7667 | 0.7363 |
97
+ | 1.1721 | 2.1705 | 140 | 0.8957 | 0.7128 | 0.7026 | 0.7592 | 0.7075 |
98
+ | 1.131 | 2.3256 | 150 | 0.8730 | 0.7259 | 0.7149 | 0.7687 | 0.7196 |
99
+ | 1.1223 | 2.4806 | 160 | 0.8132 | 0.7546 | 0.7457 | 0.7855 | 0.7482 |
100
+ | 1.0688 | 2.6357 | 170 | 0.7485 | 0.7704 | 0.7601 | 0.7863 | 0.7631 |
101
+ | 1.0686 | 2.7907 | 180 | 0.7559 | 0.7651 | 0.7587 | 0.7920 | 0.7609 |
102
+ | 0.9733 | 2.9457 | 190 | 0.7779 | 0.7553 | 0.7458 | 0.7797 | 0.7521 |
103
+ | 0.9287 | 3.1008 | 200 | 0.7048 | 0.7818 | 0.7756 | 0.7981 | 0.7756 |
104
+ | 0.8746 | 3.2558 | 210 | 0.6848 | 0.7867 | 0.7774 | 0.8034 | 0.7822 |
105
+ | 0.7982 | 3.4109 | 220 | 0.6930 | 0.7884 | 0.7796 | 0.8025 | 0.7846 |
106
+ | 0.823 | 3.5659 | 230 | 0.7041 | 0.7804 | 0.7717 | 0.7975 | 0.7752 |
107
+ | 0.8713 | 3.7209 | 240 | 0.7418 | 0.7755 | 0.7646 | 0.8053 | 0.7711 |
108
+ | 0.8651 | 3.8760 | 250 | 0.6847 | 0.7828 | 0.7773 | 0.8048 | 0.7782 |
109
+ | 0.784 | 4.0310 | 260 | 0.6662 | 0.7923 | 0.7841 | 0.8097 | 0.7860 |
110
+ | 0.6894 | 4.1860 | 270 | 0.6980 | 0.7843 | 0.7781 | 0.8024 | 0.7779 |
111
+ | 0.7727 | 4.3411 | 280 | 0.6629 | 0.7833 | 0.7804 | 0.8030 | 0.7798 |
112
+ | 0.6978 | 4.4961 | 290 | 0.6820 | 0.7845 | 0.7800 | 0.8011 | 0.7820 |
113
+ | 0.7032 | 4.6512 | 300 | 0.6148 | 0.8032 | 0.7969 | 0.8094 | 0.7985 |
114
+ | 0.6978 | 4.8062 | 310 | 0.6457 | 0.7940 | 0.7872 | 0.8085 | 0.7892 |
115
+ | 0.66 | 4.9612 | 320 | 0.6242 | 0.8088 | 0.8033 | 0.8246 | 0.8058 |
116
+ | 0.5706 | 5.1163 | 330 | 0.6404 | 0.7966 | 0.7905 | 0.8097 | 0.7928 |
117
+ | 0.5456 | 5.2713 | 340 | 0.7147 | 0.7872 | 0.7767 | 0.8060 | 0.7819 |
118
+ | 0.5869 | 5.4264 | 350 | 0.6267 | 0.8066 | 0.8016 | 0.8188 | 0.8025 |
119
+ | 0.6022 | 5.5814 | 360 | 0.6197 | 0.8061 | 0.8028 | 0.8209 | 0.8027 |
120
+ | 0.5676 | 5.7364 | 370 | 0.6061 | 0.8059 | 0.8005 | 0.8140 | 0.8024 |
121
+ | 0.5456 | 5.8915 | 380 | 0.6018 | 0.8069 | 0.8006 | 0.8254 | 0.8033 |
122
+ | 0.56 | 6.0465 | 390 | 0.6126 | 0.8090 | 0.8037 | 0.8206 | 0.8045 |
123
+ | 0.4582 | 6.2016 | 400 | 0.6122 | 0.8115 | 0.8062 | 0.8196 | 0.8061 |
124
+ | 0.4594 | 6.3566 | 410 | 0.6058 | 0.8122 | 0.8081 | 0.8235 | 0.8082 |
125
+ | 0.4868 | 6.5116 | 420 | 0.5890 | 0.8195 | 0.8131 | 0.8300 | 0.8141 |
126
+ | 0.4841 | 6.6667 | 430 | 0.5909 | 0.8175 | 0.8119 | 0.8250 | 0.8133 |
127
+ | 0.4537 | 6.8217 | 440 | 0.5889 | 0.8195 | 0.8153 | 0.8261 | 0.8164 |
128
+ | 0.4807 | 6.9767 | 450 | 0.6105 | 0.8144 | 0.8104 | 0.8300 | 0.8106 |
129
+ | 0.4051 | 7.1318 | 460 | 0.5917 | 0.8171 | 0.8103 | 0.8217 | 0.8131 |
130
+ | 0.3727 | 7.2868 | 470 | 0.6037 | 0.8166 | 0.8116 | 0.8262 | 0.8125 |
131
+ | 0.4034 | 7.4419 | 480 | 0.6407 | 0.8032 | 0.8003 | 0.8146 | 0.8015 |
132
+ | 0.3684 | 7.5969 | 490 | 0.6205 | 0.8061 | 0.7997 | 0.8176 | 0.8008 |
133
+ | 0.416 | 7.7519 | 500 | 0.5855 | 0.8258 | 0.8207 | 0.8364 | 0.8211 |
134
+ | 0.3947 | 7.9070 | 510 | 0.5802 | 0.8214 | 0.8179 | 0.8283 | 0.8179 |
135
+ | 0.3731 | 8.0620 | 520 | 0.5870 | 0.8239 | 0.8191 | 0.8324 | 0.8188 |
136
+ | 0.3203 | 8.2171 | 530 | 0.5783 | 0.8265 | 0.8211 | 0.8302 | 0.8216 |
137
+ | 0.337 | 8.3721 | 540 | 0.5836 | 0.8200 | 0.8162 | 0.8247 | 0.8166 |
138
+ | 0.3396 | 8.5271 | 550 | 0.5992 | 0.8156 | 0.8121 | 0.8253 | 0.8115 |
139
+ | 0.3355 | 8.6822 | 560 | 0.5755 | 0.8229 | 0.8182 | 0.8281 | 0.8187 |
140
+ | 0.3273 | 8.8372 | 570 | 0.5819 | 0.8246 | 0.8194 | 0.8268 | 0.8208 |
141
+ | 0.3181 | 8.9922 | 580 | 0.5840 | 0.8205 | 0.8174 | 0.8279 | 0.8168 |
142
+ | 0.2855 | 9.1473 | 590 | 0.5997 | 0.8144 | 0.8098 | 0.8213 | 0.8103 |
143
+ | 0.254 | 9.3023 | 600 | 0.5863 | 0.8183 | 0.8132 | 0.8251 | 0.8133 |
144
+ | 0.2781 | 9.4574 | 610 | 0.5779 | 0.8224 | 0.8169 | 0.8275 | 0.8195 |
145
+ | 0.2691 | 9.6124 | 620 | 0.5816 | 0.8219 | 0.8177 | 0.8257 | 0.8186 |
146
+ | 0.3018 | 9.7674 | 630 | 0.5814 | 0.8297 | 0.8250 | 0.8370 | 0.8253 |
147
+ | 0.2615 | 9.9225 | 640 | 0.5761 | 0.8299 | 0.8261 | 0.8377 | 0.8262 |
148
+ | 0.2707 | 10.0775 | 650 | 0.5640 | 0.8326 | 0.8283 | 0.8385 | 0.8284 |
149
+ | 0.2482 | 10.2326 | 660 | 0.5685 | 0.8246 | 0.8206 | 0.8284 | 0.8218 |
150
+ | 0.2493 | 10.3876 | 670 | 0.5717 | 0.8241 | 0.8208 | 0.8311 | 0.8199 |
151
+ | 0.2167 | 10.5426 | 680 | 0.5741 | 0.8246 | 0.8204 | 0.8273 | 0.8204 |
152
+ | 0.2628 | 10.6977 | 690 | 0.5791 | 0.8248 | 0.8205 | 0.8281 | 0.8216 |
153
+ | 0.2316 | 10.8527 | 700 | 0.5770 | 0.8321 | 0.8272 | 0.8348 | 0.8284 |
154
+ | 0.2326 | 11.0078 | 710 | 0.5755 | 0.8280 | 0.8249 | 0.8348 | 0.8249 |
155
+ | 0.2001 | 11.1628 | 720 | 0.5783 | 0.8336 | 0.8299 | 0.8354 | 0.8310 |
156
+ | 0.1759 | 11.3178 | 730 | 0.5804 | 0.8345 | 0.8302 | 0.8367 | 0.8311 |
157
+ | 0.202 | 11.4729 | 740 | 0.5820 | 0.8316 | 0.8278 | 0.8353 | 0.8280 |
158
+ | 0.2191 | 11.6279 | 750 | 0.5724 | 0.8324 | 0.8279 | 0.8341 | 0.8287 |
159
+ | 0.1955 | 11.7829 | 760 | 0.5957 | 0.8226 | 0.8181 | 0.8268 | 0.8198 |
160
+ | 0.1972 | 11.9380 | 770 | 0.5722 | 0.8294 | 0.8254 | 0.8318 | 0.8263 |
161
+ | 0.1848 | 12.0930 | 780 | 0.5731 | 0.8311 | 0.8269 | 0.8339 | 0.8281 |
162
+ | 0.1613 | 12.2481 | 790 | 0.5682 | 0.8382 | 0.8344 | 0.8397 | 0.8356 |
163
+ | 0.1665 | 12.4031 | 800 | 0.5565 | 0.8350 | 0.8325 | 0.8365 | 0.8325 |
164
+ | 0.1739 | 12.5581 | 810 | 0.5738 | 0.8360 | 0.8328 | 0.8395 | 0.8326 |
165
+ | 0.1744 | 12.7132 | 820 | 0.5628 | 0.8360 | 0.8327 | 0.8387 | 0.8328 |
166
+ | 0.1737 | 12.8682 | 830 | 0.5712 | 0.8355 | 0.8320 | 0.8395 | 0.8324 |
167
+ | 0.1635 | 13.0233 | 840 | 0.5745 | 0.8309 | 0.8256 | 0.8328 | 0.8269 |
168
+ | 0.1689 | 13.1783 | 850 | 0.5781 | 0.8326 | 0.8288 | 0.8358 | 0.8294 |
169
+ | 0.1611 | 13.3333 | 860 | 0.5740 | 0.8328 | 0.8280 | 0.8349 | 0.8289 |
170
+ | 0.1624 | 13.4884 | 870 | 0.5656 | 0.8324 | 0.8279 | 0.8328 | 0.8287 |
171
+ | 0.1635 | 13.6434 | 880 | 0.5618 | 0.8319 | 0.8276 | 0.8328 | 0.8280 |
172
+ | 0.1395 | 13.7984 | 890 | 0.5648 | 0.8350 | 0.8311 | 0.8368 | 0.8312 |
173
+ | 0.1489 | 13.9535 | 900 | 0.5666 | 0.8341 | 0.8304 | 0.8370 | 0.8304 |
174
+ | 0.1174 | 14.1085 | 910 | 0.5700 | 0.8358 | 0.8321 | 0.8400 | 0.8320 |
175
+ | 0.1274 | 14.2636 | 920 | 0.5720 | 0.8331 | 0.8295 | 0.8366 | 0.8295 |
176
+ | 0.134 | 14.4186 | 930 | 0.5657 | 0.8353 | 0.8311 | 0.8369 | 0.8317 |
177
+ | 0.1327 | 14.5736 | 940 | 0.5662 | 0.8343 | 0.8308 | 0.8367 | 0.8307 |
178
+ | 0.1165 | 14.7287 | 950 | 0.5654 | 0.8341 | 0.8301 | 0.8355 | 0.8303 |
179
+ | 0.1277 | 14.8837 | 960 | 0.5661 | 0.8345 | 0.8308 | 0.8360 | 0.8310 |
180
+ | 0.1221 | 15.0388 | 970 | 0.5615 | 0.8370 | 0.8335 | 0.8388 | 0.8335 |
181
+ | 0.1194 | 15.1938 | 980 | 0.5632 | 0.8353 | 0.8318 | 0.8369 | 0.8319 |
182
+ | 0.1126 | 15.3488 | 990 | 0.5616 | 0.8362 | 0.8326 | 0.8376 | 0.8327 |
183
+ | 0.1256 | 15.5039 | 1000 | 0.5612 | 0.8365 | 0.8329 | 0.8377 | 0.8331 |
184
+
185
+
186
+ ### Framework versions
187
+
188
+ - Transformers 4.40.2
189
+ - Pytorch 2.3.0
190
+ - Datasets 2.19.1
191
+ - Tokenizers 0.19.1
config.json ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "google/siglip-base-patch16-224",
3
+ "architectures": [
4
+ "SiglipForImageClassification"
5
+ ],
6
+ "id2label": {
7
+ "0": "Affenpinscher",
8
+ "1": "Afghan Hound",
9
+ "2": "African Hunting Dog",
10
+ "3": "Airedale",
11
+ "4": "American Staffordshire Terrier",
12
+ "5": "Appenzeller",
13
+ "6": "Australian Terrier",
14
+ "7": "Basenji",
15
+ "8": "Basset",
16
+ "9": "Beagle",
17
+ "10": "Bedlington Terrier",
18
+ "11": "Bernese Mountain Dog",
19
+ "12": "Black And Tan Coonhound",
20
+ "13": "Blenheim Spaniel",
21
+ "14": "Bloodhound",
22
+ "15": "Bluetick",
23
+ "16": "Border Collie",
24
+ "17": "Border Terrier",
25
+ "18": "Borzoi",
26
+ "19": "Boston Bull",
27
+ "20": "Bouvier Des Flandres",
28
+ "21": "Boxer",
29
+ "22": "Brabancon Griffon",
30
+ "23": "Briard",
31
+ "24": "Brittany Spaniel",
32
+ "25": "Bull Mastiff",
33
+ "26": "Cairn",
34
+ "27": "Cardigan",
35
+ "28": "Chesapeake Bay Retriever",
36
+ "29": "Chihuahua",
37
+ "30": "Chow",
38
+ "31": "Clumber",
39
+ "32": "Cocker Spaniel",
40
+ "33": "Collie",
41
+ "34": "Curly Coated Retriever",
42
+ "35": "Dandie Dinmont",
43
+ "36": "Dhole",
44
+ "37": "Dingo",
45
+ "38": "Doberman",
46
+ "39": "English Foxhound",
47
+ "40": "English Setter",
48
+ "41": "English Springer",
49
+ "42": "Entlebucher",
50
+ "43": "Eskimo Dog",
51
+ "44": "Flat Coated Retriever",
52
+ "45": "French Bulldog",
53
+ "46": "German Shepherd",
54
+ "47": "German Short Haired Pointer",
55
+ "48": "Giant Schnauzer",
56
+ "49": "Golden Retriever",
57
+ "50": "Gordon Setter",
58
+ "51": "Great Dane",
59
+ "52": "Great Pyrenees",
60
+ "53": "Greater Swiss Mountain Dog",
61
+ "54": "Groenendael",
62
+ "55": "Ibizan Hound",
63
+ "56": "Irish Setter",
64
+ "57": "Irish Terrier",
65
+ "58": "Irish Water Spaniel",
66
+ "59": "Irish Wolfhound",
67
+ "60": "Italian Greyhound",
68
+ "61": "Japanese Spaniel",
69
+ "62": "Keeshond",
70
+ "63": "Kelpie",
71
+ "64": "Kerry Blue Terrier",
72
+ "65": "Komondor",
73
+ "66": "Kuvasz",
74
+ "67": "Labrador Retriever",
75
+ "68": "Lakeland Terrier",
76
+ "69": "Leonberg",
77
+ "70": "Lhasa",
78
+ "71": "Malamute",
79
+ "72": "Malinois",
80
+ "73": "Maltese Dog",
81
+ "74": "Mexican Hairless",
82
+ "75": "Miniature Pinscher",
83
+ "76": "Miniature Poodle",
84
+ "77": "Miniature Schnauzer",
85
+ "78": "Newfoundland",
86
+ "79": "Norfolk Terrier",
87
+ "80": "Norwegian Elkhound",
88
+ "81": "Norwich Terrier",
89
+ "82": "Old English Sheepdog",
90
+ "83": "Otterhound",
91
+ "84": "Papillon",
92
+ "85": "Pekinese",
93
+ "86": "Pembroke",
94
+ "87": "Pomeranian",
95
+ "88": "Pug",
96
+ "89": "Redbone",
97
+ "90": "Rhodesian Ridgeback",
98
+ "91": "Rottweiler",
99
+ "92": "Saint Bernard",
100
+ "93": "Saluki",
101
+ "94": "Samoyed",
102
+ "95": "Schipperke",
103
+ "96": "Scotch Terrier",
104
+ "97": "Scottish Deerhound",
105
+ "98": "Sealyham Terrier",
106
+ "99": "Shetland Sheepdog",
107
+ "100": "Shih Tzu",
108
+ "101": "Siberian Husky",
109
+ "102": "Silky Terrier",
110
+ "103": "Soft Coated Wheaten Terrier",
111
+ "104": "Staffordshire Bullterrier",
112
+ "105": "Standard Poodle",
113
+ "106": "Standard Schnauzer",
114
+ "107": "Sussex Spaniel",
115
+ "108": "Tibetan Mastiff",
116
+ "109": "Tibetan Terrier",
117
+ "110": "Toy Poodle",
118
+ "111": "Toy Terrier",
119
+ "112": "Vizsla",
120
+ "113": "Walker Hound",
121
+ "114": "Weimaraner",
122
+ "115": "Welsh Springer Spaniel",
123
+ "116": "West Highland White Terrier",
124
+ "117": "Whippet",
125
+ "118": "Wire Haired Fox Terrier",
126
+ "119": "Yorkshire Terrier"
127
+ },
128
+ "initializer_factor": 1.0,
129
+ "label2id": {
130
+ "Affenpinscher": 0,
131
+ "Afghan Hound": 1,
132
+ "African Hunting Dog": 2,
133
+ "Airedale": 3,
134
+ "American Staffordshire Terrier": 4,
135
+ "Appenzeller": 5,
136
+ "Australian Terrier": 6,
137
+ "Basenji": 7,
138
+ "Basset": 8,
139
+ "Beagle": 9,
140
+ "Bedlington Terrier": 10,
141
+ "Bernese Mountain Dog": 11,
142
+ "Black And Tan Coonhound": 12,
143
+ "Blenheim Spaniel": 13,
144
+ "Bloodhound": 14,
145
+ "Bluetick": 15,
146
+ "Border Collie": 16,
147
+ "Border Terrier": 17,
148
+ "Borzoi": 18,
149
+ "Boston Bull": 19,
150
+ "Bouvier Des Flandres": 20,
151
+ "Boxer": 21,
152
+ "Brabancon Griffon": 22,
153
+ "Briard": 23,
154
+ "Brittany Spaniel": 24,
155
+ "Bull Mastiff": 25,
156
+ "Cairn": 26,
157
+ "Cardigan": 27,
158
+ "Chesapeake Bay Retriever": 28,
159
+ "Chihuahua": 29,
160
+ "Chow": 30,
161
+ "Clumber": 31,
162
+ "Cocker Spaniel": 32,
163
+ "Collie": 33,
164
+ "Curly Coated Retriever": 34,
165
+ "Dandie Dinmont": 35,
166
+ "Dhole": 36,
167
+ "Dingo": 37,
168
+ "Doberman": 38,
169
+ "English Foxhound": 39,
170
+ "English Setter": 40,
171
+ "English Springer": 41,
172
+ "Entlebucher": 42,
173
+ "Eskimo Dog": 43,
174
+ "Flat Coated Retriever": 44,
175
+ "French Bulldog": 45,
176
+ "German Shepherd": 46,
177
+ "German Short Haired Pointer": 47,
178
+ "Giant Schnauzer": 48,
179
+ "Golden Retriever": 49,
180
+ "Gordon Setter": 50,
181
+ "Great Dane": 51,
182
+ "Great Pyrenees": 52,
183
+ "Greater Swiss Mountain Dog": 53,
184
+ "Groenendael": 54,
185
+ "Ibizan Hound": 55,
186
+ "Irish Setter": 56,
187
+ "Irish Terrier": 57,
188
+ "Irish Water Spaniel": 58,
189
+ "Irish Wolfhound": 59,
190
+ "Italian Greyhound": 60,
191
+ "Japanese Spaniel": 61,
192
+ "Keeshond": 62,
193
+ "Kelpie": 63,
194
+ "Kerry Blue Terrier": 64,
195
+ "Komondor": 65,
196
+ "Kuvasz": 66,
197
+ "Labrador Retriever": 67,
198
+ "Lakeland Terrier": 68,
199
+ "Leonberg": 69,
200
+ "Lhasa": 70,
201
+ "Malamute": 71,
202
+ "Malinois": 72,
203
+ "Maltese Dog": 73,
204
+ "Mexican Hairless": 74,
205
+ "Miniature Pinscher": 75,
206
+ "Miniature Poodle": 76,
207
+ "Miniature Schnauzer": 77,
208
+ "Newfoundland": 78,
209
+ "Norfolk Terrier": 79,
210
+ "Norwegian Elkhound": 80,
211
+ "Norwich Terrier": 81,
212
+ "Old English Sheepdog": 82,
213
+ "Otterhound": 83,
214
+ "Papillon": 84,
215
+ "Pekinese": 85,
216
+ "Pembroke": 86,
217
+ "Pomeranian": 87,
218
+ "Pug": 88,
219
+ "Redbone": 89,
220
+ "Rhodesian Ridgeback": 90,
221
+ "Rottweiler": 91,
222
+ "Saint Bernard": 92,
223
+ "Saluki": 93,
224
+ "Samoyed": 94,
225
+ "Schipperke": 95,
226
+ "Scotch Terrier": 96,
227
+ "Scottish Deerhound": 97,
228
+ "Sealyham Terrier": 98,
229
+ "Shetland Sheepdog": 99,
230
+ "Shih Tzu": 100,
231
+ "Siberian Husky": 101,
232
+ "Silky Terrier": 102,
233
+ "Soft Coated Wheaten Terrier": 103,
234
+ "Staffordshire Bullterrier": 104,
235
+ "Standard Poodle": 105,
236
+ "Standard Schnauzer": 106,
237
+ "Sussex Spaniel": 107,
238
+ "Tibetan Mastiff": 108,
239
+ "Tibetan Terrier": 109,
240
+ "Toy Poodle": 110,
241
+ "Toy Terrier": 111,
242
+ "Vizsla": 112,
243
+ "Walker Hound": 113,
244
+ "Weimaraner": 114,
245
+ "Welsh Springer Spaniel": 115,
246
+ "West Highland White Terrier": 116,
247
+ "Whippet": 117,
248
+ "Wire Haired Fox Terrier": 118,
249
+ "Yorkshire Terrier": 119
250
+ },
251
+ "model_type": "siglip",
252
+ "problem_type": "single_label_classification",
253
+ "text_config": {
254
+ "model_type": "siglip_text_model"
255
+ },
256
+ "torch_dtype": "float32",
257
+ "transformers_version": "4.40.2",
258
+ "vision_config": {
259
+ "model_type": "siglip_vision_model"
260
+ }
261
+ }
config.toml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [training_args]
2
+ output_dir="/Users/andrewmayes/Openclassroom/CanineNet/code/"
3
+ evaluation_strategy="steps"
4
+ save_strategy="steps"
5
+ learning_rate=5e-5
6
+ #per_device_train_batch_size=32 # 512
7
+ #per_device_eval_batch_size=32 # 512
8
+ # num_train_epochs=5,
9
+ eval_delay=0 # 50
10
+ eval_steps=0.01
11
+ #eval_accumulation_steps
12
+ gradient_accumulation_steps=4
13
+ gradient_checkpointing=true
14
+ optim="adafactor"
15
+ max_steps=1000 # 100
16
+ #logging_dir=""
17
+ #log_level="error"
18
+ load_best_model_at_end=true
19
+ metric_for_best_model="f1"
20
+ greater_is_better=true
21
+ #use_mps_device=true
22
+ logging_steps=0.01
23
+ save_steps=0.01
24
+ #auto_find_batch_size=true
25
+ report_to="mlflow"
26
+ save_total_limit=2
27
+ #hub_model_id="amaye15/SwinV2-Base-Document-Classifier"
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e7e157ca5975cd1e5ead2a93dfb08575ef6c00e75f75cc082dcd1c0f6b18d51
3
+ size 371930976
preprocessor_config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_valid_processor_keys": [
3
+ "images",
4
+ "do_resize",
5
+ "size",
6
+ "resample",
7
+ "do_rescale",
8
+ "rescale_factor",
9
+ "do_normalize",
10
+ "image_mean",
11
+ "image_std",
12
+ "return_tensors",
13
+ "data_format",
14
+ "input_data_format"
15
+ ],
16
+ "do_normalize": true,
17
+ "do_rescale": true,
18
+ "do_resize": true,
19
+ "image_mean": [
20
+ 0.5,
21
+ 0.5,
22
+ 0.5
23
+ ],
24
+ "image_processor_type": "SiglipImageProcessor",
25
+ "image_std": [
26
+ 0.5,
27
+ 0.5,
28
+ 0.5
29
+ ],
30
+ "processor_class": "SiglipProcessor",
31
+ "resample": 3,
32
+ "rescale_factor": 0.00392156862745098,
33
+ "size": {
34
+ "height": 224,
35
+ "width": 224
36
+ }
37
+ }
train.ipynb ADDED
@@ -0,0 +1,1298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Install"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 2,
13
+ "metadata": {},
14
+ "outputs": [
15
+ {
16
+ "name": "stdout",
17
+ "output_type": "stream",
18
+ "text": [
19
+ "Requirement already satisfied: uv in /Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages (0.1.42)\n",
20
+ "Note: you may need to restart the kernel to use updated packages.\n"
21
+ ]
22
+ }
23
+ ],
24
+ "source": [
25
+ "%pip install uv"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "!uv pip install dagshub setuptools accelerate toml torch torchvision transformers mlflow datasets ipywidgets python-dotenv evaluate"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "markdown",
39
+ "metadata": {},
40
+ "source": [
41
+ "# Setup"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": 1,
47
+ "metadata": {},
48
+ "outputs": [
49
+ {
50
+ "data": {
51
+ "text/html": [
52
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Initialized MLflow to track repo <span style=\"color: #008000; text-decoration-color: #008000\">\"amaye15/CanineNet\"</span>\n",
53
+ "</pre>\n"
54
+ ],
55
+ "text/plain": [
56
+ "Initialized MLflow to track repo \u001b[32m\"amaye15/CanineNet\"\u001b[0m\n"
57
+ ]
58
+ },
59
+ "metadata": {},
60
+ "output_type": "display_data"
61
+ },
62
+ {
63
+ "data": {
64
+ "text/html": [
65
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Repository amaye15/CanineNet initialized!\n",
66
+ "</pre>\n"
67
+ ],
68
+ "text/plain": [
69
+ "Repository amaye15/CanineNet initialized!\n"
70
+ ]
71
+ },
72
+ "metadata": {},
73
+ "output_type": "display_data"
74
+ }
75
+ ],
76
+ "source": [
77
+ "import os\n",
78
+ "import toml\n",
79
+ "import torch\n",
80
+ "import mlflow\n",
81
+ "import dagshub\n",
82
+ "import datasets\n",
83
+ "import evaluate\n",
84
+ "from dotenv import load_dotenv\n",
85
+ "from torchvision.transforms import v2\n",
86
+ "from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer\n",
87
+ "\n",
88
+ "ENV_PATH = \"/Users/andrewmayes/Openclassroom/CanineNet/.env\"\n",
89
+ "CONFIG_PATH = \"/Users/andrewmayes/Openclassroom/CanineNet/code/config.toml\"\n",
90
+ "CONFIG = toml.load(CONFIG_PATH)\n",
91
+ "\n",
92
+ "load_dotenv(ENV_PATH)\n",
93
+ "\n",
94
+ "dagshub.init(repo_name=os.environ['MLFLOW_TRACKING_PROJECTNAME'], repo_owner=os.environ['MLFLOW_TRACKING_USERNAME'], mlflow=True, dvc=True)\n",
95
+ "\n",
96
+ "os.environ['MLFLOW_TRACKING_USERNAME'] = \"amaye15\"\n",
97
+ "\n",
98
+ "mlflow.set_tracking_uri(f'https://dagshub.com/' + os.environ['MLFLOW_TRACKING_USERNAME']\n",
99
+ " + '/' + os.environ['MLFLOW_TRACKING_PROJECTNAME'] + '.mlflow')\n",
100
+ "\n",
101
+ "CREATE_DATASET = True\n",
102
+ "ORIGINAL_DATASET = \"Alanox/stanford-dogs\"\n",
103
+ "MODIFIED_DATASET = \"amaye15/stanford-dogs\"\n",
104
+ "REMOVE_COLUMNS = [\"name\", \"annotations\"]\n",
105
+ "RENAME_COLUMNS = {\"image\":\"pixel_values\", \"target\":\"label\"}\n",
106
+ "SPLIT = 0.2\n",
107
+ "\n",
108
+ "METRICS = [\"accuracy\", \"f1\", \"precision\", \"recall\"]\n",
109
+ "# MODELS = 'google/vit-base-patch16-224'\n",
110
+ "# MODELS = \"google/siglip-base-patch16-224\"\n",
111
+ "\n"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "markdown",
116
+ "metadata": {},
117
+ "source": [
118
+ "# Dataset"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 2,
124
+ "metadata": {},
125
+ "outputs": [
126
+ {
127
+ "name": "stdout",
128
+ "output_type": "stream",
129
+ "text": [
130
+ "Affenpinscher: 0\n",
131
+ "Afghan Hound: 1\n",
132
+ "African Hunting Dog: 2\n",
133
+ "Airedale: 3\n",
134
+ "American Staffordshire Terrier: 4\n",
135
+ "Appenzeller: 5\n",
136
+ "Australian Terrier: 6\n",
137
+ "Basenji: 7\n",
138
+ "Basset: 8\n",
139
+ "Beagle: 9\n",
140
+ "Bedlington Terrier: 10\n",
141
+ "Bernese Mountain Dog: 11\n",
142
+ "Black And Tan Coonhound: 12\n",
143
+ "Blenheim Spaniel: 13\n",
144
+ "Bloodhound: 14\n",
145
+ "Bluetick: 15\n",
146
+ "Border Collie: 16\n",
147
+ "Border Terrier: 17\n",
148
+ "Borzoi: 18\n",
149
+ "Boston Bull: 19\n",
150
+ "Bouvier Des Flandres: 20\n",
151
+ "Boxer: 21\n",
152
+ "Brabancon Griffon: 22\n",
153
+ "Briard: 23\n",
154
+ "Brittany Spaniel: 24\n",
155
+ "Bull Mastiff: 25\n",
156
+ "Cairn: 26\n",
157
+ "Cardigan: 27\n",
158
+ "Chesapeake Bay Retriever: 28\n",
159
+ "Chihuahua: 29\n",
160
+ "Chow: 30\n",
161
+ "Clumber: 31\n",
162
+ "Cocker Spaniel: 32\n",
163
+ "Collie: 33\n",
164
+ "Curly Coated Retriever: 34\n",
165
+ "Dandie Dinmont: 35\n",
166
+ "Dhole: 36\n",
167
+ "Dingo: 37\n",
168
+ "Doberman: 38\n",
169
+ "English Foxhound: 39\n",
170
+ "English Setter: 40\n",
171
+ "English Springer: 41\n",
172
+ "Entlebucher: 42\n",
173
+ "Eskimo Dog: 43\n",
174
+ "Flat Coated Retriever: 44\n",
175
+ "French Bulldog: 45\n",
176
+ "German Shepherd: 46\n",
177
+ "German Short Haired Pointer: 47\n",
178
+ "Giant Schnauzer: 48\n",
179
+ "Golden Retriever: 49\n",
180
+ "Gordon Setter: 50\n",
181
+ "Great Dane: 51\n",
182
+ "Great Pyrenees: 52\n",
183
+ "Greater Swiss Mountain Dog: 53\n",
184
+ "Groenendael: 54\n",
185
+ "Ibizan Hound: 55\n",
186
+ "Irish Setter: 56\n",
187
+ "Irish Terrier: 57\n",
188
+ "Irish Water Spaniel: 58\n",
189
+ "Irish Wolfhound: 59\n",
190
+ "Italian Greyhound: 60\n",
191
+ "Japanese Spaniel: 61\n",
192
+ "Keeshond: 62\n",
193
+ "Kelpie: 63\n",
194
+ "Kerry Blue Terrier: 64\n",
195
+ "Komondor: 65\n",
196
+ "Kuvasz: 66\n",
197
+ "Labrador Retriever: 67\n",
198
+ "Lakeland Terrier: 68\n",
199
+ "Leonberg: 69\n",
200
+ "Lhasa: 70\n",
201
+ "Malamute: 71\n",
202
+ "Malinois: 72\n",
203
+ "Maltese Dog: 73\n",
204
+ "Mexican Hairless: 74\n",
205
+ "Miniature Pinscher: 75\n",
206
+ "Miniature Poodle: 76\n",
207
+ "Miniature Schnauzer: 77\n",
208
+ "Newfoundland: 78\n",
209
+ "Norfolk Terrier: 79\n",
210
+ "Norwegian Elkhound: 80\n",
211
+ "Norwich Terrier: 81\n",
212
+ "Old English Sheepdog: 82\n",
213
+ "Otterhound: 83\n",
214
+ "Papillon: 84\n",
215
+ "Pekinese: 85\n",
216
+ "Pembroke: 86\n",
217
+ "Pomeranian: 87\n",
218
+ "Pug: 88\n",
219
+ "Redbone: 89\n",
220
+ "Rhodesian Ridgeback: 90\n",
221
+ "Rottweiler: 91\n",
222
+ "Saint Bernard: 92\n",
223
+ "Saluki: 93\n",
224
+ "Samoyed: 94\n",
225
+ "Schipperke: 95\n",
226
+ "Scotch Terrier: 96\n",
227
+ "Scottish Deerhound: 97\n",
228
+ "Sealyham Terrier: 98\n",
229
+ "Shetland Sheepdog: 99\n",
230
+ "Shih Tzu: 100\n",
231
+ "Siberian Husky: 101\n",
232
+ "Silky Terrier: 102\n",
233
+ "Soft Coated Wheaten Terrier: 103\n",
234
+ "Staffordshire Bullterrier: 104\n",
235
+ "Standard Poodle: 105\n",
236
+ "Standard Schnauzer: 106\n",
237
+ "Sussex Spaniel: 107\n",
238
+ "Tibetan Mastiff: 108\n",
239
+ "Tibetan Terrier: 109\n",
240
+ "Toy Poodle: 110\n",
241
+ "Toy Terrier: 111\n",
242
+ "Vizsla: 112\n",
243
+ "Walker Hound: 113\n",
244
+ "Weimaraner: 114\n",
245
+ "Welsh Springer Spaniel: 115\n",
246
+ "West Highland White Terrier: 116\n",
247
+ "Whippet: 117\n",
248
+ "Wire Haired Fox Terrier: 118\n",
249
+ "Yorkshire Terrier: 119\n"
250
+ ]
251
+ }
252
+ ],
253
+ "source": [
254
+ "if CREATE_DATASET:\n",
255
+ " ds = datasets.load_dataset(ORIGINAL_DATASET, token=os.getenv(\"HF_TOKEN\"), split=\"full\", trust_remote_code=True)\n",
256
+ " ds = ds.remove_columns(REMOVE_COLUMNS).rename_columns(RENAME_COLUMNS)\n",
257
+ "\n",
258
+ " labels = ds.select_columns(\"label\").to_pandas().sort_values(\"label\").get(\"label\").unique().tolist()\n",
259
+ " numbers = range(len(labels))\n",
260
+ " label2int = dict(zip(labels, numbers))\n",
261
+ " int2label = dict(zip(numbers, labels))\n",
262
+ "\n",
263
+ " for key, val in label2int.items():\n",
264
+ " print(f\"{key}: {val}\")\n",
265
+ "\n",
266
+ " ds = ds.class_encode_column(\"label\")\n",
267
+ " ds = ds.align_labels_with_mapping(label2int, \"label\")\n",
268
+ "\n",
269
+ " ds = ds.train_test_split(test_size=SPLIT, stratify_by_column = \"label\")\n",
270
+ " #ds.push_to_hub(MODIFIED_DATASET, token=os.getenv(\"HF_TOKEN\"))\n",
271
+ "\n",
272
+ " CONFIG[\"label2int\"] = str(label2int)\n",
273
+ " CONFIG[\"int2label\"] = str(int2label)\n",
274
+ "\n",
275
+ " # with open(\"output.toml\", \"w\") as toml_file:\n",
276
+ " # toml.dump(toml.dumps(CONFIG), toml_file)\n",
277
+ "\n",
278
+ " #ds = datasets.load_dataset(MODIFIED_DATASET, token=os.getenv(\"HF_TOKEN\"), trust_remote_code=True, streaming=True)"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 3,
284
+ "metadata": {},
285
+ "outputs": [
286
+ {
287
+ "name": "stderr",
288
+ "output_type": "stream",
289
+ "text": [
290
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
291
+ " warnings.warn(\n",
292
+ "Some weights of SiglipForImageClassification were not initialized from the model checkpoint at google/siglip-base-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
293
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
294
+ "max_steps is given, it will override any value given in num_train_epochs\n"
295
+ ]
296
+ },
297
+ {
298
+ "data": {
299
+ "application/vnd.jupyter.widget-view+json": {
300
+ "model_id": "343b3d32fc774b0f9a2b0dee471ec262",
301
+ "version_major": 2,
302
+ "version_minor": 0
303
+ },
304
+ "text/plain": [
305
+ " 0%| | 0/1000 [00:00<?, ?it/s]"
306
+ ]
307
+ },
308
+ "metadata": {},
309
+ "output_type": "display_data"
310
+ },
311
+ {
312
+ "name": "stderr",
313
+ "output_type": "stream",
314
+ "text": [
315
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
316
+ " warnings.warn(\n"
317
+ ]
318
+ },
319
+ {
320
+ "name": "stdout",
321
+ "output_type": "stream",
322
+ "text": [
323
+ "{'loss': 4.822, 'grad_norm': 11.180054664611816, 'learning_rate': 4.9500000000000004e-05, 'epoch': 0.16}\n"
324
+ ]
325
+ },
326
+ {
327
+ "data": {
328
+ "application/vnd.jupyter.widget-view+json": {
329
+ "model_id": "48d9a7623928469c9bd9e61ad4a5573b",
330
+ "version_major": 2,
331
+ "version_minor": 0
332
+ },
333
+ "text/plain": [
334
+ " 0%| | 0/65 [00:00<?, ?it/s]"
335
+ ]
336
+ },
337
+ "metadata": {},
338
+ "output_type": "display_data"
339
+ },
340
+ {
341
+ "name": "stderr",
342
+ "output_type": "stream",
343
+ "text": [
344
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
345
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
346
+ ]
347
+ },
348
+ {
349
+ "name": "stdout",
350
+ "output_type": "stream",
351
+ "text": [
352
+ "{'eval_loss': 4.254875183105469, 'eval_accuracy': 0.0782312925170068, 'eval_f1': 0.04927852996179247, 'eval_precision': 0.09874043278607707, 'eval_recall': 0.07264375052644872, 'eval_runtime': 55.3923, 'eval_samples_per_second': 74.306, 'eval_steps_per_second': 1.173, 'epoch': 0.16}\n"
353
+ ]
354
+ },
355
+ {
356
+ "name": "stderr",
357
+ "output_type": "stream",
358
+ "text": [
359
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
360
+ " warnings.warn(\n"
361
+ ]
362
+ },
363
+ {
364
+ "name": "stdout",
365
+ "output_type": "stream",
366
+ "text": [
367
+ "{'loss': 4.236, 'grad_norm': 17.628389358520508, 'learning_rate': 4.9e-05, 'epoch': 0.31}\n"
368
+ ]
369
+ },
370
+ {
371
+ "data": {
372
+ "application/vnd.jupyter.widget-view+json": {
373
+ "model_id": "b7de237f61e54c158691316da90ae5a5",
374
+ "version_major": 2,
375
+ "version_minor": 0
376
+ },
377
+ "text/plain": [
378
+ " 0%| | 0/65 [00:00<?, ?it/s]"
379
+ ]
380
+ },
381
+ "metadata": {},
382
+ "output_type": "display_data"
383
+ },
384
+ {
385
+ "name": "stderr",
386
+ "output_type": "stream",
387
+ "text": [
388
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
389
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
390
+ ]
391
+ },
392
+ {
393
+ "name": "stdout",
394
+ "output_type": "stream",
395
+ "text": [
396
+ "{'eval_loss': 3.5278525352478027, 'eval_accuracy': 0.19071914480077745, 'eval_f1': 0.15072037668109098, 'eval_precision': 0.22011886017337598, 'eval_recall': 0.18300531751649823, 'eval_runtime': 55.5125, 'eval_samples_per_second': 74.145, 'eval_steps_per_second': 1.171, 'epoch': 0.31}\n"
397
+ ]
398
+ },
399
+ {
400
+ "name": "stderr",
401
+ "output_type": "stream",
402
+ "text": [
403
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
404
+ " warnings.warn(\n"
405
+ ]
406
+ },
407
+ {
408
+ "name": "stdout",
409
+ "output_type": "stream",
410
+ "text": [
411
+ "{'loss': 3.5066, 'grad_norm': 19.224912643432617, 'learning_rate': 4.85e-05, 'epoch': 0.47}\n"
412
+ ]
413
+ },
414
+ {
415
+ "data": {
416
+ "application/vnd.jupyter.widget-view+json": {
417
+ "model_id": "02737fd7a3554302b14c8fd8be6aad37",
418
+ "version_major": 2,
419
+ "version_minor": 0
420
+ },
421
+ "text/plain": [
422
+ " 0%| | 0/65 [00:00<?, ?it/s]"
423
+ ]
424
+ },
425
+ "metadata": {},
426
+ "output_type": "display_data"
427
+ },
428
+ {
429
+ "name": "stderr",
430
+ "output_type": "stream",
431
+ "text": [
432
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
433
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
434
+ ]
435
+ },
436
+ {
437
+ "name": "stdout",
438
+ "output_type": "stream",
439
+ "text": [
440
+ "{'eval_loss': 2.531590223312378, 'eval_accuracy': 0.33187560738581146, 'eval_f1': 0.2941424042839124, 'eval_precision': 0.4180298856360352, 'eval_recall': 0.320509389455932, 'eval_runtime': 56.5067, 'eval_samples_per_second': 72.841, 'eval_steps_per_second': 1.15, 'epoch': 0.47}\n"
441
+ ]
442
+ },
443
+ {
444
+ "name": "stderr",
445
+ "output_type": "stream",
446
+ "text": [
447
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
448
+ " warnings.warn(\n"
449
+ ]
450
+ },
451
+ {
452
+ "name": "stdout",
453
+ "output_type": "stream",
454
+ "text": [
455
+ "{'loss': 2.8064, 'grad_norm': 22.580602645874023, 'learning_rate': 4.8e-05, 'epoch': 0.62}\n"
456
+ ]
457
+ },
458
+ {
459
+ "data": {
460
+ "application/vnd.jupyter.widget-view+json": {
461
+ "model_id": "5a4a4e46c1ec4352a0d248e7f8d53a9e",
462
+ "version_major": 2,
463
+ "version_minor": 0
464
+ },
465
+ "text/plain": [
466
+ " 0%| | 0/65 [00:00<?, ?it/s]"
467
+ ]
468
+ },
469
+ "metadata": {},
470
+ "output_type": "display_data"
471
+ },
472
+ {
473
+ "name": "stderr",
474
+ "output_type": "stream",
475
+ "text": [
476
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
477
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
478
+ ]
479
+ },
480
+ {
481
+ "name": "stdout",
482
+ "output_type": "stream",
483
+ "text": [
484
+ "{'eval_loss': 2.1243278980255127, 'eval_accuracy': 0.4361030126336249, 'eval_f1': 0.409040351489377, 'eval_precision': 0.5324247354698377, 'eval_recall': 0.4282087854976091, 'eval_runtime': 56.8186, 'eval_samples_per_second': 72.441, 'eval_steps_per_second': 1.144, 'epoch': 0.62}\n"
485
+ ]
486
+ },
487
+ {
488
+ "name": "stderr",
489
+ "output_type": "stream",
490
+ "text": [
491
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
492
+ " warnings.warn(\n"
493
+ ]
494
+ },
495
+ {
496
+ "name": "stdout",
497
+ "output_type": "stream",
498
+ "text": [
499
+ "{'loss': 2.441, 'grad_norm': 17.738447189331055, 'learning_rate': 4.75e-05, 'epoch': 0.78}\n"
500
+ ]
501
+ },
502
+ {
503
+ "data": {
504
+ "application/vnd.jupyter.widget-view+json": {
505
+ "model_id": "45d70e5072f1485199b309a10b746045",
506
+ "version_major": 2,
507
+ "version_minor": 0
508
+ },
509
+ "text/plain": [
510
+ " 0%| | 0/65 [00:00<?, ?it/s]"
511
+ ]
512
+ },
513
+ "metadata": {},
514
+ "output_type": "display_data"
515
+ },
516
+ {
517
+ "name": "stderr",
518
+ "output_type": "stream",
519
+ "text": [
520
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
521
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
522
+ ]
523
+ },
524
+ {
525
+ "name": "stdout",
526
+ "output_type": "stream",
527
+ "text": [
528
+ "{'eval_loss': 1.5798275470733643, 'eval_accuracy': 0.5510204081632653, 'eval_f1': 0.5250154943185481, 'eval_precision': 0.6242284529813324, 'eval_recall': 0.5437767896591994, 'eval_runtime': 57.0171, 'eval_samples_per_second': 72.189, 'eval_steps_per_second': 1.14, 'epoch': 0.78}\n"
529
+ ]
530
+ },
531
+ {
532
+ "name": "stderr",
533
+ "output_type": "stream",
534
+ "text": [
535
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
536
+ " warnings.warn(\n"
537
+ ]
538
+ },
539
+ {
540
+ "name": "stdout",
541
+ "output_type": "stream",
542
+ "text": [
543
+ "{'loss': 2.0985, 'grad_norm': 18.94181251525879, 'learning_rate': 4.7e-05, 'epoch': 0.93}\n"
544
+ ]
545
+ },
546
+ {
547
+ "data": {
548
+ "application/vnd.jupyter.widget-view+json": {
549
+ "model_id": "83e2ce0b615d4939819bf0db71dd745a",
550
+ "version_major": 2,
551
+ "version_minor": 0
552
+ },
553
+ "text/plain": [
554
+ " 0%| | 0/65 [00:00<?, ?it/s]"
555
+ ]
556
+ },
557
+ "metadata": {},
558
+ "output_type": "display_data"
559
+ },
560
+ {
561
+ "name": "stderr",
562
+ "output_type": "stream",
563
+ "text": [
564
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
565
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
566
+ ]
567
+ },
568
+ {
569
+ "name": "stdout",
570
+ "output_type": "stream",
571
+ "text": [
572
+ "{'eval_loss': 1.42424476146698, 'eval_accuracy': 0.5843051506316812, 'eval_f1': 0.557705493250987, 'eval_precision': 0.6400162236362443, 'eval_recall': 0.5768419977409593, 'eval_runtime': 57.2333, 'eval_samples_per_second': 71.916, 'eval_steps_per_second': 1.136, 'epoch': 0.93}\n"
573
+ ]
574
+ },
575
+ {
576
+ "name": "stderr",
577
+ "output_type": "stream",
578
+ "text": [
579
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
580
+ " warnings.warn(\n"
581
+ ]
582
+ },
583
+ {
584
+ "name": "stdout",
585
+ "output_type": "stream",
586
+ "text": [
587
+ "{'loss': 1.8689, 'grad_norm': 15.593049049377441, 'learning_rate': 4.6500000000000005e-05, 'epoch': 1.09}\n"
588
+ ]
589
+ },
590
+ {
591
+ "data": {
592
+ "application/vnd.jupyter.widget-view+json": {
593
+ "model_id": "e0e5069d05484298906c53c0409ae3b4",
594
+ "version_major": 2,
595
+ "version_minor": 0
596
+ },
597
+ "text/plain": [
598
+ " 0%| | 0/65 [00:00<?, ?it/s]"
599
+ ]
600
+ },
601
+ "metadata": {},
602
+ "output_type": "display_data"
603
+ },
604
+ {
605
+ "name": "stderr",
606
+ "output_type": "stream",
607
+ "text": [
608
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
609
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
610
+ ]
611
+ },
612
+ {
613
+ "name": "stdout",
614
+ "output_type": "stream",
615
+ "text": [
616
+ "{'eval_loss': 1.1481006145477295, 'eval_accuracy': 0.6625364431486881, 'eval_f1': 0.6455514206859728, 'eval_precision': 0.7142859368225736, 'eval_recall': 0.6564757487305617, 'eval_runtime': 54.7373, 'eval_samples_per_second': 75.196, 'eval_steps_per_second': 1.187, 'epoch': 1.09}\n"
617
+ ]
618
+ },
619
+ {
620
+ "name": "stderr",
621
+ "output_type": "stream",
622
+ "text": [
623
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
624
+ " warnings.warn(\n"
625
+ ]
626
+ },
627
+ {
628
+ "name": "stdout",
629
+ "output_type": "stream",
630
+ "text": [
631
+ "{'loss': 1.6588, 'grad_norm': 18.39203453063965, 'learning_rate': 4.600000000000001e-05, 'epoch': 1.24}\n"
632
+ ]
633
+ },
634
+ {
635
+ "data": {
636
+ "application/vnd.jupyter.widget-view+json": {
637
+ "model_id": "f3c04a3fbe7149e89e25a26ce4d543a7",
638
+ "version_major": 2,
639
+ "version_minor": 0
640
+ },
641
+ "text/plain": [
642
+ " 0%| | 0/65 [00:00<?, ?it/s]"
643
+ ]
644
+ },
645
+ "metadata": {},
646
+ "output_type": "display_data"
647
+ },
648
+ {
649
+ "name": "stderr",
650
+ "output_type": "stream",
651
+ "text": [
652
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
653
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
654
+ ]
655
+ },
656
+ {
657
+ "name": "stdout",
658
+ "output_type": "stream",
659
+ "text": [
660
+ "{'eval_loss': 1.1937264204025269, 'eval_accuracy': 0.6465014577259475, 'eval_f1': 0.6361000380324133, 'eval_precision': 0.7061715448588218, 'eval_recall': 0.6438641849166267, 'eval_runtime': 55.5191, 'eval_samples_per_second': 74.137, 'eval_steps_per_second': 1.171, 'epoch': 1.24}\n"
661
+ ]
662
+ },
663
+ {
664
+ "name": "stderr",
665
+ "output_type": "stream",
666
+ "text": [
667
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
668
+ " warnings.warn(\n"
669
+ ]
670
+ },
671
+ {
672
+ "name": "stdout",
673
+ "output_type": "stream",
674
+ "text": [
675
+ "{'loss': 1.5807, 'grad_norm': 15.319233894348145, 'learning_rate': 4.55e-05, 'epoch': 1.4}\n"
676
+ ]
677
+ },
678
+ {
679
+ "data": {
680
+ "application/vnd.jupyter.widget-view+json": {
681
+ "model_id": "d8e25a97554844cb8d12799591b55105",
682
+ "version_major": 2,
683
+ "version_minor": 0
684
+ },
685
+ "text/plain": [
686
+ " 0%| | 0/65 [00:00<?, ?it/s]"
687
+ ]
688
+ },
689
+ "metadata": {},
690
+ "output_type": "display_data"
691
+ },
692
+ {
693
+ "name": "stdout",
694
+ "output_type": "stream",
695
+ "text": [
696
+ "{'eval_loss': 0.9817520976066589, 'eval_accuracy': 0.70578231292517, 'eval_f1': 0.6890227220667341, 'eval_precision': 0.7438497507413404, 'eval_recall': 0.6980582780473442, 'eval_runtime': 54.4988, 'eval_samples_per_second': 75.525, 'eval_steps_per_second': 1.193, 'epoch': 1.4}\n"
697
+ ]
698
+ },
699
+ {
700
+ "name": "stderr",
701
+ "output_type": "stream",
702
+ "text": [
703
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
704
+ " warnings.warn(\n"
705
+ ]
706
+ },
707
+ {
708
+ "name": "stdout",
709
+ "output_type": "stream",
710
+ "text": [
711
+ "{'loss': 1.4851, 'grad_norm': 15.890103340148926, 'learning_rate': 4.5e-05, 'epoch': 1.55}\n"
712
+ ]
713
+ },
714
+ {
715
+ "data": {
716
+ "application/vnd.jupyter.widget-view+json": {
717
+ "model_id": "9ccacdd73fc6461b9a3a11bbf3c91179",
718
+ "version_major": 2,
719
+ "version_minor": 0
720
+ },
721
+ "text/plain": [
722
+ " 0%| | 0/65 [00:00<?, ?it/s]"
723
+ ]
724
+ },
725
+ "metadata": {},
726
+ "output_type": "display_data"
727
+ },
728
+ {
729
+ "name": "stderr",
730
+ "output_type": "stream",
731
+ "text": [
732
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
733
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
734
+ ]
735
+ },
736
+ {
737
+ "name": "stdout",
738
+ "output_type": "stream",
739
+ "text": [
740
+ "{'eval_loss': 1.0180633068084717, 'eval_accuracy': 0.6999514091350826, 'eval_f1': 0.6838587523312375, 'eval_precision': 0.7373019639077568, 'eval_recall': 0.6959074888023662, 'eval_runtime': 55.2973, 'eval_samples_per_second': 74.434, 'eval_steps_per_second': 1.175, 'epoch': 1.55}\n"
741
+ ]
742
+ },
743
+ {
744
+ "name": "stderr",
745
+ "output_type": "stream",
746
+ "text": [
747
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
748
+ " warnings.warn(\n"
749
+ ]
750
+ },
751
+ {
752
+ "name": "stdout",
753
+ "output_type": "stream",
754
+ "text": [
755
+ "{'loss': 1.5033, 'grad_norm': 17.170801162719727, 'learning_rate': 4.4500000000000004e-05, 'epoch': 1.71}\n"
756
+ ]
757
+ },
758
+ {
759
+ "data": {
760
+ "application/vnd.jupyter.widget-view+json": {
761
+ "model_id": "5fbaecaf6e7e4a89b14a917a95a32c48",
762
+ "version_major": 2,
763
+ "version_minor": 0
764
+ },
765
+ "text/plain": [
766
+ " 0%| | 0/65 [00:00<?, ?it/s]"
767
+ ]
768
+ },
769
+ "metadata": {},
770
+ "output_type": "display_data"
771
+ },
772
+ {
773
+ "name": "stdout",
774
+ "output_type": "stream",
775
+ "text": [
776
+ "{'eval_loss': 1.0169070959091187, 'eval_accuracy': 0.6914480077745384, 'eval_f1': 0.6845415929736886, 'eval_precision': 0.7489788726612852, 'eval_recall': 0.6883375806361393, 'eval_runtime': 54.525, 'eval_samples_per_second': 75.488, 'eval_steps_per_second': 1.192, 'epoch': 1.71}\n"
777
+ ]
778
+ },
779
+ {
780
+ "name": "stderr",
781
+ "output_type": "stream",
782
+ "text": [
783
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
784
+ " warnings.warn(\n"
785
+ ]
786
+ },
787
+ {
788
+ "name": "stdout",
789
+ "output_type": "stream",
790
+ "text": [
791
+ "{'loss': 1.3022, 'grad_norm': 15.557647705078125, 'learning_rate': 4.4000000000000006e-05, 'epoch': 1.86}\n"
792
+ ]
793
+ },
794
+ {
795
+ "data": {
796
+ "application/vnd.jupyter.widget-view+json": {
797
+ "model_id": "f0ec4f69dabe4268a130a513f8556e8f",
798
+ "version_major": 2,
799
+ "version_minor": 0
800
+ },
801
+ "text/plain": [
802
+ " 0%| | 0/65 [00:00<?, ?it/s]"
803
+ ]
804
+ },
805
+ "metadata": {},
806
+ "output_type": "display_data"
807
+ },
808
+ {
809
+ "name": "stdout",
810
+ "output_type": "stream",
811
+ "text": [
812
+ "{'eval_loss': 0.9087187051773071, 'eval_accuracy': 0.7276482021379981, 'eval_f1': 0.7169827898093813, 'eval_precision': 0.7642639115410531, 'eval_recall': 0.722171618202087, 'eval_runtime': 54.7556, 'eval_samples_per_second': 75.17, 'eval_steps_per_second': 1.187, 'epoch': 1.86}\n"
813
+ ]
814
+ },
815
+ {
816
+ "name": "stderr",
817
+ "output_type": "stream",
818
+ "text": [
819
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
820
+ " warnings.warn(\n"
821
+ ]
822
+ },
823
+ {
824
+ "name": "stdout",
825
+ "output_type": "stream",
826
+ "text": [
827
+ "{'loss': 1.3106, 'grad_norm': 15.203620910644531, 'learning_rate': 4.35e-05, 'epoch': 2.02}\n"
828
+ ]
829
+ },
830
+ {
831
+ "data": {
832
+ "application/vnd.jupyter.widget-view+json": {
833
+ "model_id": "d4fdc66d719e477c91950d7c477e44a4",
834
+ "version_major": 2,
835
+ "version_minor": 0
836
+ },
837
+ "text/plain": [
838
+ " 0%| | 0/65 [00:00<?, ?it/s]"
839
+ ]
840
+ },
841
+ "metadata": {},
842
+ "output_type": "display_data"
843
+ },
844
+ {
845
+ "name": "stdout",
846
+ "output_type": "stream",
847
+ "text": [
848
+ "{'eval_loss': 0.8385488986968994, 'eval_accuracy': 0.7431972789115646, 'eval_f1': 0.7352483871752059, 'eval_precision': 0.7666810806987456, 'eval_recall': 0.7363282855094594, 'eval_runtime': 57.5486, 'eval_samples_per_second': 71.522, 'eval_steps_per_second': 1.129, 'epoch': 2.02}\n"
849
+ ]
850
+ },
851
+ {
852
+ "name": "stderr",
853
+ "output_type": "stream",
854
+ "text": [
855
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
856
+ " warnings.warn(\n"
857
+ ]
858
+ },
859
+ {
860
+ "name": "stdout",
861
+ "output_type": "stream",
862
+ "text": [
863
+ "{'loss': 1.1721, 'grad_norm': 18.051284790039062, 'learning_rate': 4.3e-05, 'epoch': 2.17}\n"
864
+ ]
865
+ },
866
+ {
867
+ "data": {
868
+ "application/vnd.jupyter.widget-view+json": {
869
+ "model_id": "dc6976d80b7e40bbb3cc0031eef848b8",
870
+ "version_major": 2,
871
+ "version_minor": 0
872
+ },
873
+ "text/plain": [
874
+ " 0%| | 0/65 [00:00<?, ?it/s]"
875
+ ]
876
+ },
877
+ "metadata": {},
878
+ "output_type": "display_data"
879
+ },
880
+ {
881
+ "name": "stderr",
882
+ "output_type": "stream",
883
+ "text": [
884
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
885
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
886
+ ]
887
+ },
888
+ {
889
+ "name": "stdout",
890
+ "output_type": "stream",
891
+ "text": [
892
+ "{'eval_loss': 0.8956524133682251, 'eval_accuracy': 0.7128279883381924, 'eval_f1': 0.7025737877793609, 'eval_precision': 0.7591947211938203, 'eval_recall': 0.7074780847115492, 'eval_runtime': 58.6317, 'eval_samples_per_second': 70.201, 'eval_steps_per_second': 1.109, 'epoch': 2.17}\n"
893
+ ]
894
+ },
895
+ {
896
+ "name": "stderr",
897
+ "output_type": "stream",
898
+ "text": [
899
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
900
+ " warnings.warn(\n"
901
+ ]
902
+ },
903
+ {
904
+ "name": "stdout",
905
+ "output_type": "stream",
906
+ "text": [
907
+ "{'loss': 1.131, 'grad_norm': 16.522109985351562, 'learning_rate': 4.25e-05, 'epoch': 2.33}\n"
908
+ ]
909
+ },
910
+ {
911
+ "data": {
912
+ "application/vnd.jupyter.widget-view+json": {
913
+ "model_id": "f602bc78c97842f5b7a7221559cb52f7",
914
+ "version_major": 2,
915
+ "version_minor": 0
916
+ },
917
+ "text/plain": [
918
+ " 0%| | 0/65 [00:00<?, ?it/s]"
919
+ ]
920
+ },
921
+ "metadata": {},
922
+ "output_type": "display_data"
923
+ },
924
+ {
925
+ "name": "stdout",
926
+ "output_type": "stream",
927
+ "text": [
928
+ "{'eval_loss': 0.8729854226112366, 'eval_accuracy': 0.7259475218658892, 'eval_f1': 0.7148538617252097, 'eval_precision': 0.7687155689784482, 'eval_recall': 0.719605645045331, 'eval_runtime': 54.6147, 'eval_samples_per_second': 75.364, 'eval_steps_per_second': 1.19, 'epoch': 2.33}\n"
929
+ ]
930
+ },
931
+ {
932
+ "name": "stderr",
933
+ "output_type": "stream",
934
+ "text": [
935
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
936
+ " warnings.warn(\n"
937
+ ]
938
+ },
939
+ {
940
+ "name": "stdout",
941
+ "output_type": "stream",
942
+ "text": [
943
+ "{'loss': 1.1223, 'grad_norm': 16.727994918823242, 'learning_rate': 4.2e-05, 'epoch': 2.48}\n"
944
+ ]
945
+ },
946
+ {
947
+ "data": {
948
+ "application/vnd.jupyter.widget-view+json": {
949
+ "model_id": "e7b0d1b380a145b8bf58631628e8830e",
950
+ "version_major": 2,
951
+ "version_minor": 0
952
+ },
953
+ "text/plain": [
954
+ " 0%| | 0/65 [00:00<?, ?it/s]"
955
+ ]
956
+ },
957
+ "metadata": {},
958
+ "output_type": "display_data"
959
+ },
960
+ {
961
+ "name": "stdout",
962
+ "output_type": "stream",
963
+ "text": [
964
+ "{'eval_loss': 0.8132386803627014, 'eval_accuracy': 0.7546161321671526, 'eval_f1': 0.7457409451548778, 'eval_precision': 0.7855075954723149, 'eval_recall': 0.7482023394172116, 'eval_runtime': 54.8217, 'eval_samples_per_second': 75.08, 'eval_steps_per_second': 1.186, 'epoch': 2.48}\n"
965
+ ]
966
+ },
967
+ {
968
+ "name": "stderr",
969
+ "output_type": "stream",
970
+ "text": [
971
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
972
+ " warnings.warn(\n"
973
+ ]
974
+ },
975
+ {
976
+ "name": "stdout",
977
+ "output_type": "stream",
978
+ "text": [
979
+ "{'loss': 1.0688, 'grad_norm': 14.611897468566895, 'learning_rate': 4.15e-05, 'epoch': 2.64}\n"
980
+ ]
981
+ },
982
+ {
983
+ "data": {
984
+ "application/vnd.jupyter.widget-view+json": {
985
+ "model_id": "2f071a57e8b14c4a9b85178f57885e03",
986
+ "version_major": 2,
987
+ "version_minor": 0
988
+ },
989
+ "text/plain": [
990
+ " 0%| | 0/65 [00:00<?, ?it/s]"
991
+ ]
992
+ },
993
+ "metadata": {},
994
+ "output_type": "display_data"
995
+ },
996
+ {
997
+ "name": "stdout",
998
+ "output_type": "stream",
999
+ "text": [
1000
+ "{'eval_loss': 0.7485197186470032, 'eval_accuracy': 0.7704081632653061, 'eval_f1': 0.7600821180493263, 'eval_precision': 0.7863249503261968, 'eval_recall': 0.7631296317667979, 'eval_runtime': 56.4165, 'eval_samples_per_second': 72.957, 'eval_steps_per_second': 1.152, 'epoch': 2.64}\n"
1001
+ ]
1002
+ },
1003
+ {
1004
+ "name": "stderr",
1005
+ "output_type": "stream",
1006
+ "text": [
1007
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1008
+ " warnings.warn(\n"
1009
+ ]
1010
+ },
1011
+ {
1012
+ "name": "stdout",
1013
+ "output_type": "stream",
1014
+ "text": [
1015
+ "{'loss': 1.0686, 'grad_norm': 17.756242752075195, 'learning_rate': 4.1e-05, 'epoch': 2.79}\n"
1016
+ ]
1017
+ },
1018
+ {
1019
+ "data": {
1020
+ "application/vnd.jupyter.widget-view+json": {
1021
+ "model_id": "f4d3dc2d893047508b51c00948d8057c",
1022
+ "version_major": 2,
1023
+ "version_minor": 0
1024
+ },
1025
+ "text/plain": [
1026
+ " 0%| | 0/65 [00:00<?, ?it/s]"
1027
+ ]
1028
+ },
1029
+ "metadata": {},
1030
+ "output_type": "display_data"
1031
+ },
1032
+ {
1033
+ "name": "stdout",
1034
+ "output_type": "stream",
1035
+ "text": [
1036
+ "{'eval_loss': 0.7559003233909607, 'eval_accuracy': 0.7650631681243926, 'eval_f1': 0.7586751052497263, 'eval_precision': 0.7920018070685718, 'eval_recall': 0.7609226412898984, 'eval_runtime': 53.1768, 'eval_samples_per_second': 77.402, 'eval_steps_per_second': 1.222, 'epoch': 2.79}\n"
1037
+ ]
1038
+ },
1039
+ {
1040
+ "name": "stderr",
1041
+ "output_type": "stream",
1042
+ "text": [
1043
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1044
+ " warnings.warn(\n"
1045
+ ]
1046
+ },
1047
+ {
1048
+ "name": "stdout",
1049
+ "output_type": "stream",
1050
+ "text": [
1051
+ "{'loss': 0.9733, 'grad_norm': 14.432697296142578, 'learning_rate': 4.05e-05, 'epoch': 2.95}\n"
1052
+ ]
1053
+ },
1054
+ {
1055
+ "data": {
1056
+ "application/vnd.jupyter.widget-view+json": {
1057
+ "model_id": "35ef518721bb4ff29f1d1d6286fc5f75",
1058
+ "version_major": 2,
1059
+ "version_minor": 0
1060
+ },
1061
+ "text/plain": [
1062
+ " 0%| | 0/65 [00:00<?, ?it/s]"
1063
+ ]
1064
+ },
1065
+ "metadata": {},
1066
+ "output_type": "display_data"
1067
+ },
1068
+ {
1069
+ "name": "stderr",
1070
+ "output_type": "stream",
1071
+ "text": [
1072
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
1073
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
1074
+ ]
1075
+ },
1076
+ {
1077
+ "name": "stdout",
1078
+ "output_type": "stream",
1079
+ "text": [
1080
+ "{'eval_loss': 0.7778576612472534, 'eval_accuracy': 0.7553449951409135, 'eval_f1': 0.7458481776644565, 'eval_precision': 0.7797152587168623, 'eval_recall': 0.7521461869682271, 'eval_runtime': 54.2043, 'eval_samples_per_second': 75.935, 'eval_steps_per_second': 1.199, 'epoch': 2.95}\n"
1081
+ ]
1082
+ },
1083
+ {
1084
+ "name": "stderr",
1085
+ "output_type": "stream",
1086
+ "text": [
1087
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1088
+ " warnings.warn(\n"
1089
+ ]
1090
+ }
1091
+ ],
1092
+ "source": [
1093
+ "metrics = {metric: evaluate.load(metric) for metric in METRICS}\n",
1094
+ "\n",
1095
+ "\n",
1096
+ "# for lr in [5e-3, 5e-4, 5e-5]: # 5e-5\n",
1097
+ "# for batch in [64]: # 32\n",
1098
+ "# for model_name in [\"google/vit-base-patch16-224\", \"microsoft/swinv2-base-patch4-window16-256\", \"google/siglip-base-patch16-224\"]: # \"facebook/dinov2-base\"\n",
1099
+ "\n",
1100
+ "lr = 5e-5\n",
1101
+ "batch = 64\n",
1102
+ "model_name = \"google/siglip-base-patch16-224\"\n",
1103
+ "\n",
1104
+ "image_processor = AutoImageProcessor.from_pretrained(model_name)\n",
1105
+ "model = AutoModelForImageClassification.from_pretrained(\n",
1106
+ "model_name,\n",
1107
+ "num_labels=len(label2int),\n",
1108
+ "id2label=int2label,\n",
1109
+ "label2id=label2int,\n",
1110
+ "ignore_mismatched_sizes=True,\n",
1111
+ ")\n",
1112
+ "\n",
1113
+ "# Then, in your transformations:\n",
1114
+ "def train_transform(examples, num_ops=10, magnitude=9, num_magnitude_bins=31):\n",
1115
+ "\n",
1116
+ " transformation = v2.Compose(\n",
1117
+ " [\n",
1118
+ " v2.RandAugment(\n",
1119
+ " num_ops=num_ops,\n",
1120
+ " magnitude=magnitude,\n",
1121
+ " num_magnitude_bins=num_magnitude_bins,\n",
1122
+ " )\n",
1123
+ " ]\n",
1124
+ " )\n",
1125
+ " # Ensure each image has three dimensions (in this case, ensure it's RGB)\n",
1126
+ " examples[\"pixel_values\"] = [\n",
1127
+ " image.convert(\"RGB\") for image in examples[\"pixel_values\"]\n",
1128
+ " ]\n",
1129
+ " # Apply transformations\n",
1130
+ " examples[\"pixel_values\"] = [\n",
1131
+ " image_processor(transformation(image), return_tensors=\"pt\")[\n",
1132
+ " \"pixel_values\"\n",
1133
+ " ].squeeze()\n",
1134
+ " for image in examples[\"pixel_values\"]\n",
1135
+ " ]\n",
1136
+ " return examples\n",
1137
+ "\n",
1138
+ "\n",
1139
+ "def test_transform(examples):\n",
1140
+ " # Ensure each image is RGB\n",
1141
+ " examples[\"pixel_values\"] = [\n",
1142
+ " image.convert(\"RGB\") for image in examples[\"pixel_values\"]\n",
1143
+ " ]\n",
1144
+ " # Apply processing\n",
1145
+ " examples[\"pixel_values\"] = [\n",
1146
+ " image_processor(image, return_tensors=\"pt\")[\"pixel_values\"].squeeze()\n",
1147
+ " for image in examples[\"pixel_values\"]\n",
1148
+ " ]\n",
1149
+ " return examples\n",
1150
+ "\n",
1151
+ "\n",
1152
+ "def compute_metrics(eval_pred):\n",
1153
+ " predictions, labels = eval_pred\n",
1154
+ " # predictions = np.argmax(logits, axis=-1)\n",
1155
+ " results = {}\n",
1156
+ " for key, val in metrics.items():\n",
1157
+ " if \"accuracy\" == key:\n",
1158
+ " result = next(\n",
1159
+ " iter(val.compute(predictions=predictions, references=labels).items())\n",
1160
+ " )\n",
1161
+ " if \"accuracy\" != key:\n",
1162
+ " result = next(\n",
1163
+ " iter(\n",
1164
+ " val.compute(\n",
1165
+ " predictions=predictions, references=labels, average=\"macro\"\n",
1166
+ " ).items()\n",
1167
+ " )\n",
1168
+ " )\n",
1169
+ " results[result[0]] = result[1]\n",
1170
+ " return results\n",
1171
+ "\n",
1172
+ "\n",
1173
+ "def collate_fn(examples):\n",
1174
+ " pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n",
1175
+ " labels = torch.tensor([example[\"label\"] for example in examples])\n",
1176
+ " return {\"pixel_values\": pixel_values, \"labels\": labels}\n",
1177
+ "\n",
1178
+ "\n",
1179
+ "def preprocess_logits_for_metrics(logits, labels):\n",
1180
+ " \"\"\"\n",
1181
+ " Original Trainer may have a memory leak.\n",
1182
+ " This is a workaround to avoid storing too many tensors that are not needed.\n",
1183
+ " \"\"\"\n",
1184
+ " pred_ids = torch.argmax(logits, dim=-1)\n",
1185
+ " return pred_ids\n",
1186
+ "\n",
1187
+ "ds[\"train\"].set_transform(train_transform)\n",
1188
+ "ds[\"test\"].set_transform(test_transform)\n",
1189
+ "\n",
1190
+ "training_args = TrainingArguments(**CONFIG[\"training_args\"])\n",
1191
+ "training_args.per_device_train_batch_size = batch\n",
1192
+ "training_args.per_device_eval_batch_size = batch\n",
1193
+ "training_args.hub_model_id = f\"amaye15/{model_name.replace('/','-')}-batch{batch}-lr{lr}-standford-dogs\"\n",
1194
+ "\n",
1195
+ "mlflow.start_run(run_name=f\"{model_name.replace('/','-')}-batch{batch}-lr{lr}\")\n",
1196
+ "\n",
1197
+ "trainer = Trainer(\n",
1198
+ " model=model,\n",
1199
+ " args=training_args,\n",
1200
+ " train_dataset=ds[\"train\"],\n",
1201
+ " eval_dataset=ds[\"test\"],\n",
1202
+ " tokenizer=image_processor,\n",
1203
+ " data_collator=collate_fn,\n",
1204
+ " compute_metrics=compute_metrics,\n",
1205
+ " # callbacks=[early_stopping_callback],\n",
1206
+ " preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n",
1207
+ ")\n",
1208
+ "\n",
1209
+ "# Train the model\n",
1210
+ "trainer.train()\n",
1211
+ "\n",
1212
+ "trainer.push_to_hub()\n",
1213
+ "\n",
1214
+ "mlflow.end_run()"
1215
+ ]
1216
+ },
1217
+ {
1218
+ "cell_type": "code",
1219
+ "execution_count": null,
1220
+ "metadata": {},
1221
+ "outputs": [
1222
+ {
1223
+ "ename": "NameError",
1224
+ "evalue": "name 'mlflow' is not defined",
1225
+ "output_type": "error",
1226
+ "traceback": [
1227
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1228
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
1229
+ "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmlflow\u001b[49m\u001b[38;5;241m.\u001b[39mend_run()\n",
1230
+ "\u001b[0;31mNameError\u001b[0m: name 'mlflow' is not defined"
1231
+ ]
1232
+ }
1233
+ ],
1234
+ "source": [
1235
+ "mlflow.end_run()"
1236
+ ]
1237
+ },
1238
+ {
1239
+ "cell_type": "code",
1240
+ "execution_count": null,
1241
+ "metadata": {},
1242
+ "outputs": [],
1243
+ "source": [
1244
+ "# training_args = TrainingArguments(**CONFIG[\"training_args\"])\n",
1245
+ "\n",
1246
+ "# image_processor = AutoImageProcessor.from_pretrained(MODELS)\n",
1247
+ "# model = AutoModelForImageClassification.from_pretrained(\n",
1248
+ "# MODELS,\n",
1249
+ "# num_labels=len(CONFIG[\"label2int\"]),\n",
1250
+ "# id2label=CONFIG[\"label2int\"],\n",
1251
+ "# label2id=CONFIG[\"int2label\"],\n",
1252
+ "# ignore_mismatched_sizes=True,\n",
1253
+ "# )\n",
1254
+ "\n",
1255
+ "\n",
1256
+ "# training_args = TrainingArguments(**CONFIG[\"training_args\"])\n",
1257
+ "\n",
1258
+ "# trainer = Trainer(\n",
1259
+ "# model=model,\n",
1260
+ "# args=training_args,\n",
1261
+ "# train_dataset=ds[\"train\"],\n",
1262
+ "# eval_dataset=ds[\"test\"],\n",
1263
+ "# tokenizer=image_processor,\n",
1264
+ "# data_collator=collate_fn,\n",
1265
+ "# compute_metrics=compute_metrics,\n",
1266
+ "# # callbacks=[early_stopping_callback],\n",
1267
+ "# preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n",
1268
+ "# )\n",
1269
+ "\n",
1270
+ "# # Train the model\n",
1271
+ "# trainer.train()\n",
1272
+ "\n",
1273
+ "# mlflow.end_run()"
1274
+ ]
1275
+ }
1276
+ ],
1277
+ "metadata": {
1278
+ "kernelspec": {
1279
+ "display_name": "env",
1280
+ "language": "python",
1281
+ "name": "python3"
1282
+ },
1283
+ "language_info": {
1284
+ "codemirror_mode": {
1285
+ "name": "ipython",
1286
+ "version": 3
1287
+ },
1288
+ "file_extension": ".py",
1289
+ "mimetype": "text/x-python",
1290
+ "name": "python",
1291
+ "nbconvert_exporter": "python",
1292
+ "pygments_lexer": "ipython3",
1293
+ "version": "3.12.3"
1294
+ }
1295
+ },
1296
+ "nbformat": 4,
1297
+ "nbformat_minor": 2
1298
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1e2adea2ad13837e172adc6cc3a084bf72cd54a57bc6b1cc0cd252cea258e77
3
+ size 5112