Upload 2 files
Browse files- dofa_base_patch16_224-7cc0f413.pth +3 -0
- extract.py +30 -0
dofa_base_patch16_224-7cc0f413.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7cc0f41371c349b958688c99d024d2db1c1a947cc0989329e6a5c104b2f23d40
|
3 |
+
size 445295690
|
extract.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3

"""Extract the DOFA backbone weights from a pre-training checkpoint.

Loads the checkpoint on CPU, strips keys that belong only to the
pre-training setup (mask token, final norm, projector), validates the
remaining state dict against ``dofa_base_patch16_224``, and saves the
cleaned weights.

Usage:
    extract.py [in_filename] [out_filename]

Both arguments are optional; the defaults match the original
hard-coded filenames, so running with no arguments is unchanged.
"""

import sys

import torch

from torchgeo.models import dofa_base_patch16_224

# Input/output paths: optional CLI overrides, defaults preserve
# the original hard-coded behavior.
in_filename = sys.argv[1] if len(sys.argv) > 1 else "ofa_base_checkpoint_e99.pth"
out_filename = sys.argv[2] if len(sys.argv) > 2 else "dofa_base_patch16_224.pth"

# Load the checkpoint on CPU so no GPU is required.
# NOTE(review): torch.load unpickles arbitrary objects -- only run this
# on checkpoints from a trusted source.
checkpoint = torch.load(in_filename, map_location=torch.device("cpu"))

# Remove keys not part of the backbone (pre-training-only parameters).
weights = checkpoint["model"]
for key in (
    "mask_token",
    "norm.weight",
    "norm.bias",
    "projector.weight",
    "projector.bias",
):
    del weights[key]

# Load the weights into a fresh model to ensure they are valid.
# fc_norm and head are generated dynamically, so they are allowed
# to be missing from the checkpoint.
allowed_missing_keys = {"fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"}
model = dofa_base_patch16_224()
missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)
# Explicit raises instead of assert: assert is stripped under `python -O`.
if not set(missing_keys) <= allowed_missing_keys:
    raise RuntimeError(f"unexpected missing keys: {sorted(missing_keys)}")
if unexpected_keys:
    raise RuntimeError(f"unexpected extra keys: {sorted(unexpected_keys)}")

# Save the cleaned checkpoint
# Should be manually renamed later, add first 8 digits of sha256 to suffix
torch.save(weights, out_filename)
|