# DCVAI-Example-1 / script.py
# harpreetsahota's picture
# Update script.py
# ffbdd76 verified
import os
import yaml
import fiftyone as fo
import fiftyone.utils.random as four
import fiftyone.utils.huggingface as fouh
#IMPLEMENT YOUR FUNCTIONS FOR DATA CURATION HERE, BELOW ARE JUST DUMMY FUNCTIONS AS EXAMPLES
def shuffle_data(dataset, seed=51):
    """Shuffle the dataset.

    Dummy curation step provided as an example.

    Args:
        dataset: Object exposing a ``shuffle(seed=...)`` method (in this
            script, the dataset loaded inside ``prepare_dataset``).
        seed: Random seed forwarded to ``dataset.shuffle``. Defaults to 51,
            the value this example previously hard-coded, so existing
            callers get the same result.

    Returns:
        Whatever ``dataset.shuffle(seed=seed)`` returns.
    """
    return dataset.shuffle(seed=seed)
def take_random_sample(dataset, size=10, seed=51):
    """Take a random sample from the dataset.

    Dummy curation step provided as an example.

    Args:
        dataset: Object exposing a ``take(size=..., seed=...)`` method (in
            this script, the dataset loaded inside ``prepare_dataset``).
        size: Number of samples requested from ``dataset.take``. Defaults
            to 10, the value this example previously hard-coded.
        seed: Random seed forwarded to ``dataset.take``. Defaults to 51,
            the value this example previously hard-coded.

    Returns:
        Whatever ``dataset.take(size=size, seed=seed)`` returns.
    """
    return dataset.take(size=size, seed=seed)
# DEFINE YOUR TRAINING HYPERPARAMETERS IN THIS DICTIONARY
training_config = dict(
    # Fractions handed to the random train/val split in train_model()
    train_split=0.9,
    val_split=0.1,
    # Keyword arguments unpacked into model.train() by train_model()
    train_params=dict(
        epochs=1,
        batch=16,
        imgsz=640,
        lr0=0.01,
        lrf=0.01,
    ),
)
# WRAP YOUR DATASET CURATION FUNCTIONS IN THIS FUNCTION
def prepare_dataset():
    """
    Prepare the dataset for model training.

    Loads the competition training data, runs the curation functions on it,
    clones the result into a persistent dataset named "curated_dataset",
    and returns that clone.

    NOTE: There are lines you must not modify in this function. They are
    marked with "DO NOT MODIFY".

    This function takes no arguments; the dataset location is fixed by the
    loading line below.

    Returns:
        fiftyone.core.dataset.Dataset: The curated dataset.

    Note:
        The dataset-loading line MUST NOT be removed from your submission.
        It ensures that only the approved dataset is used for the
        competition.
    """
    # DO NOT MODIFY THIS LINE
    dataset = fouh.load_from_hub("/tmp/data/train")

    # WRAP YOUR DATA CURATION FUNCTIONS HERE
    dataset = shuffle_data(dataset)
    dataset = take_random_sample(dataset)

    # DO NOT MODIFY BELOW THIS LINE
    curated_dataset = dataset.clone(name="curated_dataset")
    curated_dataset.persistent = True

    # train_model() splits and exports the value returned here; without this
    # return the caller would receive None.
    return curated_dataset
# DO NOT MODIFY THIS FUNCTION
def export_to_yolo_format(
    samples,
    classes,
    label_field="ground_truth",
    export_dir=".",
    splits=["train", "val"]
):
    """
    Export samples to YOLO format, optionally handling multiple data splits.

    NOTE: DO NOT MODIFY THIS FUNCTION.

    Each requested split is exported with ``fo.types.YOLOv5Dataset`` as the
    dataset type. Samples are selected per split by tag
    (``samples.match_tags(split)``), except when the only requested split is
    "val", in which case the whole collection is exported as "val".

    Args:
        samples (fiftyone.core.collections.SampleCollection): The dataset or samples to export.
        classes (list): A list of class names for the YOLO format.
        label_field (str, optional): The field in the samples that contains the labels.
            Defaults to "ground_truth".
        export_dir (str, optional): The directory where the exported data will be saved.
            Defaults to "." (the current working directory).
        splits (str, list, optional): The split(s) to export. Can be a single split name (str)
            or a list of split names. If None, all samples are exported as "val" split.
            Defaults to ["train", "val"].

    Returns:
        None
    """
    # NOTE(review): the list default is shared between calls; that is harmless
    # here because ``splits`` is only rebound below, never mutated in place.
    # Normalize ``splits`` to a list of split names.
    if splits is None:
        splits = ["val"]
    elif isinstance(splits, str):
        splits = [splits]
    for split in splits:
        # A lone "val" split exports every sample untagged-or-not; otherwise
        # only the samples carrying the split's tag are exported.
        split_view = samples if split == "val" and splits == ["val"] else samples.match_tags(split)
        split_view.export(
            export_dir=export_dir,
            dataset_type=fo.types.YOLOv5Dataset,
            label_field=label_field,
            classes=classes,
            split=split
        )
# DO NOT MODIFY THIS FUNCTION
def train_model(training_config=training_config):
    """
    Train the YOLO model on the given dataset using the provided configuration.

    NOTE: DO NOT MODIFY THIS FUNCTION AT ALL OR YOUR SCRIPT WILL FAIL.

    Steps, in order: build the curated dataset with ``prepare_dataset()``,
    randomly split it into "train" and "val", export it with
    ``export_to_yolo_format``, load the YOLO weights from
    "/tmp/data/yolo11m.pt", train, and print the path of the best weights.

    Args:
        training_config (dict): Must provide "train_split" and "val_split"
            (passed to the random split) and "train_params" (unpacked as
            keyword arguments into ``model.train``). Defaults to the
            module-level ``training_config``.

    Returns:
        None
    """
    training_dataset = prepare_dataset()
    print("Splitting the dataset...")
    # Assign the "train"/"val" split names using the configured fractions;
    # export_to_yolo_format later selects samples by these names as tags.
    four.random_split(training_dataset, {"train": training_config['train_split'], "val": training_config['val_split']})
    print("Dataset split completed.")
    print("Exporting dataset to YOLO format...")
    # Uses the exporter's defaults: label field "ground_truth", export
    # directory ".", and splits ["train", "val"].
    export_to_yolo_format(
        samples=training_dataset,
        classes=training_dataset.default_classes,
    )
    print("Dataset export completed.")
    print("Initializing the YOLO model...")
    # NOTE(review): ``YOLO`` is not imported in the visible part of this file;
    # presumably ``from ultralytics import YOLO`` is required - confirm.
    #DO NOT MODIFY THIS LINE
    model = YOLO(
        model="/tmp/data/yolo11m.pt",
    )
    print("Model initialized.")
    print("Starting model training...")
    # NOTE(review): "dataset.yaml" is presumably written into the current
    # directory by the YOLO export above - confirm.
    results = model.train(
        data="dataset.yaml",
        **training_config['train_params']
    )
    print("Model training completed.")
    best_model_path = str(results.save_dir / "weights/best.pt")
    print(f"Best model saved to: {best_model_path}")
# DO NOT MODIFY THE BELOW
# Run the training pipeline only when this file is executed as a script.
if __name__=="__main__":
    train_model()