Benjamin Bossan commited on
Commit
eb10d0a
1 Parent(s): 0598e08

Use joblib to save the model

Browse files
Files changed (1) hide show
  1. train.py +4 -5
train.py CHANGED
@@ -5,14 +5,14 @@ stores the model in a pickle file.
5
 
6
  """
7
 
8
- import pickle
9
-
10
  from sklearn.datasets import make_classification
11
  from sklearn.linear_model import SGDClassifier
12
  from sklearn.model_selection import GridSearchCV
13
 
14
 
15
  SEED = 0
 
16
 
17
 
18
  def get_data():
@@ -48,8 +48,7 @@ def train(model, X, y, hparams):
48
 
49
 
50
  def save_model(model, filename):
51
- with open(filename, 'wb') as f:
52
- pickle.dump(model, f)
53
  print(f"Stored model in '{filename}'")
54
 
55
 
@@ -58,7 +57,7 @@ def main():
58
  model = get_model()
59
  hparams = get_hparams()
60
  model_trained = train(model, X, y, hparams=hparams)
61
- save_model(model_trained, 'model.pickle')
62
 
63
 
64
  if __name__ == '__main__':
 
5
 
6
  """
7
 
8
+ import joblib
 
9
  from sklearn.datasets import make_classification
10
  from sklearn.linear_model import SGDClassifier
11
  from sklearn.model_selection import GridSearchCV
12
 
13
 
14
  SEED = 0
15
+ FILENAME = 'sklearn_model.joblib'
16
 
17
 
18
  def get_data():
 
48
 
49
 
50
  def save_model(model, filename):
51
+ joblib.dump(model, filename)
 
52
  print(f"Stored model in '{filename}'")
53
 
54
 
 
57
  model = get_model()
58
  hparams = get_hparams()
59
  model_trained = train(model, X, y, hparams=hparams)
60
+ save_model(model_trained, FILENAME)
61
 
62
 
63
  if __name__ == '__main__':