Spaces:
Paused
Paused
Dean
commited on
Commit
•
79fd7d0
1
Parent(s):
0b86a0a
Fixed a bug in the training stage where the model was not saved, commiting before training on colab
Browse files- dvc.yaml +1 -1
- src/code/training.py +5 -12
dvc.yaml
CHANGED
@@ -9,7 +9,7 @@ stages:
|
|
9 |
outs:
|
10 |
- src/data/processed/
|
11 |
train:
|
12 |
-
cmd: python3 src/code/training.py src/data/processed
|
13 |
deps:
|
14 |
- src/code/training.py
|
15 |
- src/data/processed/
|
|
|
9 |
outs:
|
10 |
- src/data/processed/
|
11 |
train:
|
12 |
+
cmd: python3 src/code/training.py src/data/processed
|
13 |
deps:
|
14 |
- src/code/training.py
|
15 |
- src/data/processed/
|
src/code/training.py
CHANGED
@@ -17,20 +17,13 @@ def create_data(data_path):
|
|
17 |
return data
|
18 |
|
19 |
|
20 |
-
def train(data):
|
21 |
-
learner = unet_learner(data, resnet34, metrics=rmse, wd=1e-2, n_out=1, loss_func=MSELossFlat())
|
22 |
-
learner.fine_tune(1)
|
23 |
-
|
24 |
-
|
25 |
if __name__ == "__main__":
|
26 |
-
if len(sys.argv) <
|
27 |
-
print("usage: %s <data_path>
|
28 |
sys.exit(0)
|
29 |
|
30 |
data = create_data(Path(sys.argv[1]))
|
31 |
-
data
|
32 |
-
|
33 |
-
learner = train(data)
|
34 |
|
35 |
-
learner.save(
|
36 |
-
learner.show_results()
|
|
|
17 |
return data
|
18 |
|
19 |
|
|
|
|
|
|
|
|
|
|
|
20 |
if __name__ == "__main__":
|
21 |
+
if len(sys.argv) < 2:
|
22 |
+
print("usage: %s <data_path>" % sys.argv[0], file=sys.stderr)
|
23 |
sys.exit(0)
|
24 |
|
25 |
data = create_data(Path(sys.argv[1]))
|
26 |
+
learner = unet_learner(data, resnet34, metrics=rmse, wd=1e-2, n_out=1, loss_func=MSELossFlat(), path='src/')
|
27 |
+
learner.fine_tune(1)
|
|
|
28 |
|
29 |
+
learner.save('model')
|
|