transiteration commited on
Commit
f8cd84d
1 Parent(s): 0cce9b6

Update evaluate.py

Browse files
Files changed (1) hide show
  1. evaluate.py +11 -3
evaluate.py CHANGED
@@ -6,7 +6,11 @@ import torch
6
  from omegaconf import open_dict
7
 
8
 
9
- def evaluate_model(model_path: str, test_manifest: str, batch_size: int = 1) -> Dict:
 
 
 
 
10
 
11
  # Determine the device (CPU or GPU)
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -53,8 +57,12 @@ if __name__ == "__main__":
53
  # Parse command line arguments
54
  parser = argparse.ArgumentParser()
55
  parser.add_argument("--model_path", default=None, help="Path to a model to evaluate.")
56
- parser.add_argument("--test_manifest", help="Path for train manifest JSON file.")
57
  parser.add_argument("--batch_size", type=int, default=1, help="Batch size of the dataset to train.")
58
  args = parser.parse_args()
59
 
60
- evaluate_model(model_path=args.model_path, test_manifest=args.test_manifest, batch_size=args.batch_size)
 
 
 
 
 
6
  from omegaconf import open_dict
7
 
8
 
9
+ def evaluate_model(
10
+ model_path: str = None,
11
+ test_manifest: str = None,
12
+ batch_size: int = 1,
13
+ ) -> Dict:
14
 
15
  # Determine the device (CPU or GPU)
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
57
  # Parse command line arguments
58
  parser = argparse.ArgumentParser()
59
  parser.add_argument("--model_path", default=None, help="Path to a model to evaluate.")
60
+ parser.add_argument("--test_manifest", default=None, help="Path for train manifest JSON file.")
61
  parser.add_argument("--batch_size", type=int, default=1, help="Batch size of the dataset to train.")
62
  args = parser.parse_args()
63
 
64
+ evaluate_model(
65
+ model_path=args.model_path,
66
+ test_manifest=args.test_manifest,
67
+ batch_size=args.batch_size,
68
+ )