Muennighoff commited on
Commit
ae00acd
1 Parent(s): 9faf6e5

Add cargo_string kwarg

Browse files
Files changed (2) hide show
  1. code_eval.py +16 -2
  2. execute.py +5 -25
code_eval.py CHANGED
@@ -131,6 +131,20 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
131
  THE SOFTWARE."""
132
 
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
135
  class CodeEval(evaluate.Metric):
136
  def _info(self):
@@ -152,7 +166,7 @@ class CodeEval(evaluate.Metric):
152
  license=_LICENSE,
153
  )
154
 
155
- def _compute(self, predictions, references, k=[1, 10, 100], num_workers=4, timeout=3.0, language="python"):
156
  """Returns the scores"""
157
 
158
  if os.getenv("HF_ALLOW_CODE_EVAL", 0) != "1":
@@ -170,7 +184,7 @@ class CodeEval(evaluate.Metric):
170
  for task_id, (candidates, test_case) in enumerate(zip(predictions, references)):
171
  for candidate in candidates:
172
  test_program = candidate + "\n" + test_case
173
- args = (test_program, timeout, task_id, completion_id[task_id], language)
174
  future = executor.submit(check_correctness, *args)
175
  futures.append(future)
176
  completion_id[task_id] += 1
 
131
  THE SOFTWARE."""
132
 
133
 
134
+ # https://github.com/THUDM/CodeGeeX/blob/ebeb850f227a90c79de39f7e26b1302f374f3240/codegeex/benchmark/rust/Cargo.toml
135
+ BASE_CARGO = '''[package]
136
+ name = "rust"
137
+ version = "0.1.0"
138
+ edition = "2021"
139
+
140
+ # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
141
+
142
+ [dependencies]
143
+ rand = "0.4"
144
+ regex = "1"
145
+ md5 = "0.7.0"
146
+ '''
147
+
148
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
149
  class CodeEval(evaluate.Metric):
150
  def _info(self):
 
166
  license=_LICENSE,
167
  )
168
 
169
+ def _compute(self, predictions, references, k=[1, 10, 100], num_workers=4, timeout=3.0, language="python", cargo_string=BASE_CARGO):
170
  """Returns the scores"""
171
 
172
  if os.getenv("HF_ALLOW_CODE_EVAL", 0) != "1":
 
184
  for task_id, (candidates, test_case) in enumerate(zip(predictions, references)):
185
  for candidate in candidates:
186
  test_program = candidate + "\n" + test_case
187
+ args = (test_program, timeout, task_id, completion_id[task_id], language, cargo_string)
188
  future = executor.submit(check_correctness, *args)
189
  futures.append(future)
190
  completion_id[task_id] += 1
execute.py CHANGED
@@ -27,22 +27,8 @@ import signal
27
  import subprocess
28
  import tempfile
29
 
30
- # https://github.com/THUDM/CodeGeeX/blob/ebeb850f227a90c79de39f7e26b1302f374f3240/codegeex/benchmark/rust/Cargo.toml
31
- BASE_CARGO = '''[package]
32
- name = "rust"
33
- version = "0.1.0"
34
- edition = "2021"
35
 
36
- # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
37
-
38
- [dependencies]
39
- rand = "0.4"
40
- regex = "1"
41
- md5 = "0.7.0"
42
- '''
43
-
44
-
45
- def check_correctness(check_program, timeout, task_id, completion_id, language):
46
  """
47
  Evaluates the functional correctness of a completion by running the test
48
  suite provided in the problem.
@@ -64,7 +50,7 @@ def check_correctness(check_program, timeout, task_id, completion_id, language):
64
  elif language == "javascript":
65
  p = multiprocessing.Process(target=unsafe_execute_js, args=(check_program, result, timeout))
66
  elif language == "rust":
67
- p = multiprocessing.Process(target=unsafe_execute_rust, args=(check_program, result, timeout))
68
  else:
69
  raise ValueError(f"Language {language} not supported. Feel free to add it :)")
70
 
@@ -237,7 +223,7 @@ def unsafe_execute_js(check_program, result, timeout):
237
  except subprocess.TimeoutExpired as e:
238
  result.append("timed out")
239
 
240
- def unsafe_execute_rust(check_program, result, timeout):
241
 
242
  with create_tempdir():
243
 
@@ -255,14 +241,8 @@ def unsafe_execute_rust(check_program, result, timeout):
255
  os.makedirs(RUST_SRC, exist_ok=True)
256
  os.makedirs(RUST_BIN, exist_ok=True)
257
 
258
- # Check if Cargo exists, if so copy it here
259
- if os.path.exists("/Cargo.toml"):
260
- pass
261
- else:
262
- # Warn that no Cargo was found in the parent directory
263
- logging.warning(f"Cargo.toml not found in root directory ({os.path.abspath('/')}). Creating a new one. Timeout of >300 is recommended.")
264
- # Create Cargo.toml
265
- open(f"{RUST_DIR}/Cargo.toml", 'w').write(BASE_CARGO)
266
 
267
  with tempfile.NamedTemporaryFile(dir = RUST_BIN, delete=False) as f:
268
  file_name: str = "test" + RUST_EXT
 
27
  import subprocess
28
  import tempfile
29
 
 
 
 
 
 
30
 
31
+ def check_correctness(check_program, timeout, task_id, completion_id, language, cargo_string=""):
 
 
 
 
 
 
 
 
 
32
  """
33
  Evaluates the functional correctness of a completion by running the test
34
  suite provided in the problem.
 
50
  elif language == "javascript":
51
  p = multiprocessing.Process(target=unsafe_execute_js, args=(check_program, result, timeout))
52
  elif language == "rust":
53
+ p = multiprocessing.Process(target=unsafe_execute_rust, args=(check_program, result, timeout, cargo_string))
54
  else:
55
  raise ValueError(f"Language {language} not supported. Feel free to add it :)")
56
 
 
223
  except subprocess.TimeoutExpired as e:
224
  result.append("timed out")
225
 
226
+ def unsafe_execute_rust(check_program, result, timeout, cargo_string):
227
 
228
  with create_tempdir():
229
 
 
241
  os.makedirs(RUST_SRC, exist_ok=True)
242
  os.makedirs(RUST_BIN, exist_ok=True)
243
 
244
+ # Create Cargo.toml file
245
+ open(f"{RUST_DIR}/Cargo.toml", 'w').write(cargo_string)
 
 
 
 
 
 
246
 
247
  with tempfile.NamedTemporaryFile(dir = RUST_BIN, delete=False) as f:
248
  file_name: str = "test" + RUST_EXT