lvwerra HF staff commited on
Commit
20fed5e
1 Parent(s): 7059b46

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. code_eval.py +21 -5
  2. requirements.txt +1 -1
code_eval.py CHANGED
@@ -20,6 +20,8 @@ import itertools
20
  import os
21
  from collections import Counter, defaultdict
22
  from concurrent.futures import ThreadPoolExecutor, as_completed
 
 
23
 
24
  import datasets
25
  import numpy as np
@@ -131,14 +133,28 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
131
  THE SOFTWARE."""
132
 
133
 
 
 
 
 
 
 
 
 
 
 
134
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
135
  class CodeEval(evaluate.Metric):
136
- def _info(self):
 
 
 
137
  return evaluate.MetricInfo(
138
  # This is the description that will appear on the metrics page.
139
  description=_DESCRIPTION,
140
  citation=_CITATION,
141
  inputs_description=_KWARGS_DESCRIPTION,
 
142
  # This defines the format of each prediction and reference
143
  features=datasets.Features(
144
  {
@@ -152,7 +168,7 @@ class CodeEval(evaluate.Metric):
152
  license=_LICENSE,
153
  )
154
 
155
- def _compute(self, predictions, references, k=[1, 10, 100], num_workers=4, timeout=3.0):
156
  """Returns the scores"""
157
 
158
  if os.getenv("HF_ALLOW_CODE_EVAL", 0) != "1":
@@ -161,7 +177,7 @@ class CodeEval(evaluate.Metric):
161
  if os.name == "nt":
162
  raise NotImplementedError("This metric is currently not supported on Windows.")
163
 
164
- with ThreadPoolExecutor(max_workers=num_workers) as executor:
165
  futures = []
166
  completion_id = Counter()
167
  n_samples = 0
@@ -170,7 +186,7 @@ class CodeEval(evaluate.Metric):
170
  for task_id, (candidates, test_case) in enumerate(zip(predictions, references)):
171
  for candidate in candidates:
172
  test_program = candidate + "\n" + test_case
173
- args = (test_program, timeout, task_id, completion_id[task_id])
174
  future = executor.submit(check_correctness, *args)
175
  futures.append(future)
176
  completion_id[task_id] += 1
@@ -189,7 +205,7 @@ class CodeEval(evaluate.Metric):
189
  total = np.array(total)
190
  correct = np.array(correct)
191
 
192
- ks = k
193
  pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() for k in ks if (total >= k).all()}
194
 
195
  return pass_at_k, results
 
20
  import os
21
  from collections import Counter, defaultdict
22
  from concurrent.futures import ThreadPoolExecutor, as_completed
23
+ from dataclasses import dataclass, field
24
+ from typing import List
25
 
26
  import datasets
27
  import numpy as np
 
133
  THE SOFTWARE."""
134
 
135
 
136
+ @dataclass
137
+ class CodeEvalConfig(evaluate.info.Config):
138
+
139
+ name: str = "default"
140
+
141
+ k: List[int] = field(default_factory=lambda: [1, 10, 100])
142
+ num_workers: int = 4
143
+ timeout: float = 3.0
144
+
145
+
146
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
147
  class CodeEval(evaluate.Metric):
148
+ CONFIG_CLASS = CodeEvalConfig
149
+ ALLOWED_CONFIG_NAMES = ["default"]
150
+
151
+ def _info(self, config):
152
  return evaluate.MetricInfo(
153
  # This is the description that will appear on the metrics page.
154
  description=_DESCRIPTION,
155
  citation=_CITATION,
156
  inputs_description=_KWARGS_DESCRIPTION,
157
+ config=config,
158
  # This defines the format of each prediction and reference
159
  features=datasets.Features(
160
  {
 
168
  license=_LICENSE,
169
  )
170
 
171
+ def _compute(self, predictions, references):
172
  """Returns the scores"""
173
 
174
  if os.getenv("HF_ALLOW_CODE_EVAL", 0) != "1":
 
177
  if os.name == "nt":
178
  raise NotImplementedError("This metric is currently not supported on Windows.")
179
 
180
+ with ThreadPoolExecutor(max_workers=self.config.num_workers) as executor:
181
  futures = []
182
  completion_id = Counter()
183
  n_samples = 0
 
186
  for task_id, (candidates, test_case) in enumerate(zip(predictions, references)):
187
  for candidate in candidates:
188
  test_program = candidate + "\n" + test_case
189
+ args = (test_program, self.config.timeout, task_id, completion_id[task_id])
190
  future = executor.submit(check_correctness, *args)
191
  futures.append(future)
192
  completion_id[task_id] += 1
 
205
  total = np.array(total)
206
  correct = np.array(correct)
207
 
208
+ ks = self.config.k
209
  pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() for k in ks if (total >= k).all()}
210
 
211
  return pass_at_k, results
requirements.txt CHANGED
@@ -1 +1 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
 
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39