Muennighoff commited on
Commit
116e267
1 Parent(s): f877f8e
Files changed (1) hide show
  1. execute.py +44 -12
execute.py CHANGED
@@ -18,6 +18,7 @@
18
  import contextlib
19
  import faulthandler
20
  import io
 
21
  import multiprocessing
22
  import os
23
  import platform
@@ -25,6 +26,19 @@ import signal
25
  import subprocess
26
  import tempfile
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
 
30
  def check_correctness(check_program, timeout, task_id, completion_id, language):
@@ -209,7 +223,7 @@ def unsafe_execute_js(check_program, result, timeout):
209
 
210
  # Run program.
211
  try:
212
- exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)
213
  if exec_result.stderr.decode():
214
  err = exec_result.stderr.decode()
215
  result.append(f"failed: {err}")
@@ -225,23 +239,41 @@ def unsafe_execute_js(check_program, result, timeout):
225
  def unsafe_execute_rust(check_program, result, timeout):
226
 
227
  with create_tempdir():
228
- open(f"test.rs", 'w').write(check_program)
229
 
230
- log_path = "test.jsonl"
231
- cargo_check: str = "cargo check --bin test --message-format json >> " + log_path
232
- returned_val_compilation = os.system(cargo_check)
233
- if returned_val_compilation == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
- cargo_test: str = "cargo test --bin test --message-format json >> " + log_path
236
- returned_val_execution = os.system(cargo_test)
237
 
238
- if returned_val_execution == 0:
 
 
 
239
  result.append("passed")
240
  else:
241
- result.append("failed: execution error")
242
  else:
243
- result.append("failed: compilation error")
244
-
245
 
246
 
247
  @contextlib.contextmanager
 
18
  import contextlib
19
  import faulthandler
20
  import io
21
+ import logging
22
  import multiprocessing
23
  import os
24
  import platform
 
26
  import subprocess
27
  import tempfile
28
 
29
+ # https://github.com/THUDM/CodeGeeX/blob/ebeb850f227a90c79de39f7e26b1302f374f3240/codegeex/benchmark/rust/Cargo.toml
30
+ BASE_CARGO = """[package]
31
+ name = "rust"
32
+ version = "0.1.0"
33
+ edition = "2021"
34
+
35
+ # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
36
+
37
+ [dependencies]
38
+ rand = "0.4"
39
+ regex = "1"
40
+ md5 = "0.7.0
41
+ """
42
 
43
 
44
  def check_correctness(check_program, timeout, task_id, completion_id, language):
 
223
 
224
  # Run program.
225
  try:
226
+ exec_result = ["node", "test.js"], timeout=timeout, capture_output=True)
227
  if exec_result.stderr.decode():
228
  err = exec_result.stderr.decode()
229
  result.append(f"failed: {err}")
 
239
  def unsafe_execute_rust(check_program, result, timeout):
240
 
241
  with create_tempdir():
 
242
 
243
+ WD: str = os.getcwd()
244
+ RUST_DIR: str = os.path.join(WD, "rust")
245
+ RUST_SRC: str = os.path.join(RUST_DIR, "src")
246
+ RUST_BIN: str = os.path.join(RUST_SRC, "bin")
247
+ RUST_TMP_DIR: str = os.path.join(RUST_DIR, "tmp")
248
+ RUST_LOGS: str = os.path.join(RUST_TMP_DIR, "logs")
249
+ RUST_EXT: str = ".rs"
250
+
251
+ # Create mandatory tmp directories
252
+ os.makedirs(RUST_TMP_DIR, exist_ok=True)
253
+ os.makedirs(RUST_LOGS, exist_ok=True)
254
+ os.makedirs(RUST_SRC, exist_ok=True)
255
+ os.makedirs(RUST_BIN, exist_ok=True)
256
+
257
+ # Check if Cargo exists, if so copy it here
258
+ if os.path.exists("../Cargo.toml"):
259
+ shutil.copy("../Cargo.toml", RUST_DIR)
260
+ else:
261
+ # Warn that no Cargo was found in the parent directory
262
+ logging.warning(f"Cargo.toml not found in {os.path.abspath('../')}. Creating a new one. Timeout of >300 is recommended.")
263
+ # Create Cargo.toml
264
+ open(f"{RUST_DIR}/Cargo.toml", 'w').write(BASE_CARGO)
265
 
266
+ open(f"test.rs", 'w').write(check_program)
 
267
 
268
+ compilation_result = subprocess.run(["cargo", "check", "--bin", "test", "--message-format", "json"], capture_output=True)
269
+ if compilation_result.returncode == 0:
270
+ exec_result = subprocess.run(["cargo", "test", "--bin", "test", "--message-format", "json"], capture_output=True)
271
+ if exec_result.returncode == 0:
272
  result.append("passed")
273
  else:
274
+ result.append("failed: execution error: " + exec_result.stderr.decode())
275
  else:
276
+ result.append("failed: compilation error: " + compilation_result.stderr.decode())
 
277
 
278
 
279
  @contextlib.contextmanager