Bastien Dechamps commited on
Commit
fe2f12a
1 Parent(s): ed8157d

[ADD] Kaggle submission

Browse files
configs/kaggle_submission.yml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ command:
2
+ (): geoguessr_bot.commands.KaggleSubmissionCommand
3
+ guessr:
4
+ (): geoguessr_bot.guessr.DinoV2Embedder
5
+ embedder:
6
+ (): geoguessr_bot.retriever.DinoV2Embedder
7
+ device: "cpu"
8
+ retriever:
9
+ (): geoguessr_bot.retriever.Retriever
10
+ embeddings_path: !path "../resources/embeddings.npy"
11
+ metadata_path: !path "../resources/metadata3.csv"
12
+ image_folder_path: !path "../data/kaggle/images"
13
+ output_path: !path "../data/kaggle/submission.csv"
geoguessr_bot/commands/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
  from .abstract_command import AbstractCommand
2
  from .embed_command import EmbedCommand
 
 
1
  from .abstract_command import AbstractCommand
2
  from .embed_command import EmbedCommand
3
+ from .kaggle_submission_command import KaggleSubmissionCommand
geoguessr_bot/commands/kaggle_submission_command.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+ import pandas as pd
5
+
6
+ from geoguessr_bot.commands import AbstractCommand
7
+ from geoguessr_bot.guessr import AbstractGuessr
8
+
9
+
10
+ @dataclass
11
+ class KaggleSubmissionCommand(AbstractCommand):
12
+ """Submit a prediction to Kaggle
13
+ """
14
+ image_folder_path: str
15
+ output_path: str
16
+ guessr: AbstractGuessr
17
+
18
+ def run(self) -> None:
19
+ images_ids, latitudes, longitudes = [], [], []
20
+ for image_name in os.listdir(self.image_folder_path):
21
+ image_path = os.path.join(self.image_folder_path, image_name)
22
+ coordinate = self.guessr.guess_from_path(image_path)
23
+ images_ids.append(image_name.split(".")[0])
24
+ latitudes.append(coordinate.latitude)
25
+ longitudes.append(coordinate.longitude)
26
+ pd.DataFrame(dict(
27
+ image_id=images_ids,
28
+ latitude=latitudes,
29
+ longitude=longitudes,
30
+ )).to_csv(self.output_path, index=False)