dean commited on
Commit
9c03436
1 Parent(s): 068408a

train model on colab after fixing normalization bug

Browse files
Files changed (2) hide show
  1. dvc.lock +8 -3
  2. src/code/training.py +3 -3
dvc.lock CHANGED
@@ -20,9 +20,14 @@ train:
20
  cmd: python3 src/code/training.py src/data/processed
21
  deps:
22
  - path: src/code/training.py
23
- md5: 1d5f2b07b208bf062526e5ebfddca043
 
24
  - path: src/data/processed/
25
- md5: 77adb8603dbf31f3b272e0f51b6c2c29.dir
 
 
26
  outs:
27
  - path: src/models/
28
- md5: e6f3667c5e3ff28faaf9172adab28107.dir
 
 
 
20
  cmd: python3 src/code/training.py src/data/processed
21
  deps:
22
  - path: src/code/training.py
23
+ md5: 9634e85cffa3cf72d3d3d7739e40969e
24
+ size: 1645
25
  - path: src/data/processed/
26
+ md5: d98a9647a37ab431bfa35815eb4afda0.dir
27
+ size: 232903470
28
+ nfiles: 2898
29
  outs:
30
  - path: src/models/
31
+ md5: 18d26ed378b1b5ac61425afe153fc076.dir
32
+ size: 494926829
33
+ nfiles: 1
src/code/training.py CHANGED
@@ -3,7 +3,6 @@ import sys
3
  from fastai.vision.all import *
4
  from torchvision.utils import save_image
5
 
6
-
7
  class ImageImageDataLoaders(DataLoaders):
8
  "Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"
9
  @classmethod
@@ -38,7 +37,8 @@ if __name__ == "__main__":
38
  sys.exit(0)
39
 
40
  data = create_data(Path(sys.argv[1]))
41
- learner = unet_learner(data, resnet34, metrics=rmse, wd=1e-2, n_out=3, loss_func=MSELossFlat(), path='src/')
 
42
  learner.fine_tune(1)
43
-
44
  learner.save('model')
 
3
  from fastai.vision.all import *
4
  from torchvision.utils import save_image
5
 
 
6
  class ImageImageDataLoaders(DataLoaders):
7
  "Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"
8
  @classmethod
 
37
  sys.exit(0)
38
 
39
  data = create_data(Path(sys.argv[1]))
40
+ learner = unet_learner(data, resnet34, metrics=rmse, wd=1e-2, n_out=3, loss_func=MSELossFlat(), path='src/test/')
41
+ print("Training model...")
42
  learner.fine_tune(1)
43
+ print("Saving model...")
44
  learner.save('model')