transiteration
commited on
Commit
•
f8cd84d
1
Parent(s):
0cce9b6
Update evaluate.py
Browse files- evaluate.py +11 -3
evaluate.py
CHANGED
@@ -6,7 +6,11 @@ import torch
|
|
6 |
from omegaconf import open_dict
|
7 |
|
8 |
|
9 |
-
def evaluate_model(
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# Determine the device (CPU or GPU)
|
12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -53,8 +57,12 @@ if __name__ == "__main__":
|
|
53 |
# Parse command line arguments
|
54 |
parser = argparse.ArgumentParser()
|
55 |
parser.add_argument("--model_path", default=None, help="Path to a model to evaluate.")
|
56 |
-
parser.add_argument("--test_manifest", help="Path for train manifest JSON file.")
|
57 |
parser.add_argument("--batch_size", type=int, default=1, help="Batch size of the dataset to train.")
|
58 |
args = parser.parse_args()
|
59 |
|
60 |
-
evaluate_model(
|
|
|
|
|
|
|
|
|
|
6 |
from omegaconf import open_dict
|
7 |
|
8 |
|
9 |
+
def evaluate_model(
|
10 |
+
model_path: str = None,
|
11 |
+
test_manifest: str = None,
|
12 |
+
batch_size: int = 1,
|
13 |
+
) -> Dict:
|
14 |
|
15 |
# Determine the device (CPU or GPU)
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
57 |
# Parse command line arguments
|
58 |
parser = argparse.ArgumentParser()
|
59 |
parser.add_argument("--model_path", default=None, help="Path to a model to evaluate.")
|
60 |
+
parser.add_argument("--test_manifest", default=None, help="Path for train manifest JSON file.")
|
61 |
parser.add_argument("--batch_size", type=int, default=1, help="Batch size of the dataset to train.")
|
62 |
args = parser.parse_args()
|
63 |
|
64 |
+
evaluate_model(
|
65 |
+
model_path=args.model_path,
|
66 |
+
test_manifest=args.test_manifest,
|
67 |
+
batch_size=args.batch_size,
|
68 |
+
)
|