File size: 3,472 Bytes
64c3915
 
ffe4d51
506d239
64c3915
 
ffe4d51
 
 
 
64c3915
ffe4d51
64c3915
 
 
 
 
 
 
 
ffe4d51
 
 
 
 
 
64c3915
 
ffe4d51
64c3915
 
 
 
 
 
ffe4d51
64c3915
 
 
 
 
ffe4d51
 
 
 
64c3915
 
ffe4d51
 
64c3915
 
 
 
 
 
 
 
 
 
 
ffe4d51
64c3915
 
 
ffe4d51
 
 
 
64c3915
ffe4d51
7dd405e
 
86102e5
66621a9
ffe4d51
 
 
 
 
 
506d239
64c3915
 
ffe4d51
 
 
 
 
 
 
 
 
 
64c3915
ffe4d51
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import pprint
import re

from huggingface_hub import snapshot_download, delete_inference_endpoint

from src.backend.inference_endpoint import create_endpoint
from src.backend.manage_requests import check_completed_evals, \
    get_eval_requests, set_eval_request, PENDING_STATUS, FINISHED_STATUS, \
    FAILED_STATUS, RUNNING_STATUS
from src.backend.run_toxicity_eval import compute_results
from src.backend.sort_queue import sort_models_by_priority
from src.envs import (REQUESTS_REPO, EVAL_REQUESTS_PATH_BACKEND, RESULTS_REPO,
                      EVAL_RESULTS_PATH_BACKEND, API, TOKEN)
from src.logging import setup_logger

logger = setup_logger(__name__)

pp = pprint.PrettyPrinter(width=80)


# Module-level side effect: sync local working copies of the results and
# requests datasets from the Hub before any evaluation runs, so run_auto_eval
# reads and writes against up-to-date files.
snapshot_download(repo_id=RESULTS_REPO, revision="main",
                  local_dir=EVAL_RESULTS_PATH_BACKEND, repo_type="dataset",
                  max_workers=60, token=TOKEN)
snapshot_download(repo_id=REQUESTS_REPO, revision="main",
                  local_dir=EVAL_REQUESTS_PATH_BACKEND, repo_type="dataset",
                  max_workers=60, token=TOKEN)


def run_auto_eval():
    """Process at most one pending evaluation request end-to-end.

    Steps: mark already-completed evals as finished, pick the highest-priority
    PENDING request, flag it RUNNING, spin up an inference endpoint, run the
    toxicity evaluation against it, flag the request FINISHED, and tear the
    endpoint down. Returns None; all state is communicated via the Hub repos.
    """
    # Reconcile requests whose results already landed: flip them to
    # FINISHED (or FAILED) so they are not picked up again below.
    check_completed_evals(
        api=API,
        completed_status=FINISHED_STATUS,
        failed_status=FAILED_STATUS,
        hf_repo=REQUESTS_REPO,
        local_dir=EVAL_REQUESTS_PATH_BACKEND,
        hf_repo_results=RESULTS_REPO,
        local_dir_results=EVAL_RESULTS_PATH_BACKEND
    )

    # Get all eval requests that are PENDING
    eval_requests = get_eval_requests(hf_repo=REQUESTS_REPO,
                                      local_dir=EVAL_REQUESTS_PATH_BACKEND)
    # Sort the evals by priority (first submitted, first run)
    eval_requests = sort_models_by_priority(api=API, models=eval_requests)

    logger.info(
        f"Found {len(eval_requests)} {PENDING_STATUS} eval requests")

    if not eval_requests:
        return

    # Only the single highest-priority request is handled per invocation.
    eval_request = eval_requests[0]
    logger.info(pp.pformat(eval_request))

    # Mark RUNNING before starting so concurrent runners skip this request.
    set_eval_request(
        api=API,
        eval_request=eval_request,
        set_to_status=RUNNING_STATUS,
        hf_repo=REQUESTS_REPO,
        local_dir=EVAL_REQUESTS_PATH_BACKEND,
    )

    logger.info(
        f'Starting Evaluation of {eval_request.json_filepath} on Inference endpoints')
    endpoint_name = _make_endpoint_name(eval_request)
    endpoint_url = create_endpoint(endpoint_name, eval_request.model)
    logger.info("Created an endpoint url at %s", endpoint_url)
    try:
        results = compute_results(endpoint_url, eval_request)
        logger.info("FINISHED!")
        logger.info(results)
        logger.info(f'Completed Evaluation of {eval_request.json_filepath}')
        set_eval_request(api=API,
                         eval_request=eval_request,
                         set_to_status=FINISHED_STATUS,
                         hf_repo=REQUESTS_REPO,
                         local_dir=EVAL_REQUESTS_PATH_BACKEND,
                         )
    finally:
        # BUGFIX: the endpoint must be deleted even when compute_results or
        # the status update raises — otherwise the (billed) inference
        # endpoint is leaked and keeps running.
        delete_inference_endpoint(endpoint_name)


def _make_endpoint_name(eval_request):
    model_repository = eval_request.model
    # Naming convention for endpoints
    endpoint_name_tmp = re.sub("[/.]", "-",
                               model_repository.lower()) + "-toxicity-eval"
    # Endpoints apparently can't have more than 32 characters.
    endpoint_name = endpoint_name_tmp[:32]
    return endpoint_name


# Script entry point: runs one pass of the eval loop (at most one pending
# request is processed per invocation).
if __name__ == "__main__":
    run_auto_eval()