import argparse
import glob
import json
import logging
import multiprocessing as mp
import os
import time
import uuid
from typing import List, Tuple

import aegis
import gradio as gr
import requests
from huggingface_hub import HfApi
from optimum.onnxruntime import ORTModelForSequenceClassification
from rebuff import Rebuff
from transformers import AutoTokenizer, Pipeline, pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
hf_api = HfApi()
num_processes = 2  # kept small for the Space; switch to mp.cpu_count() to use all cores

lakera_api_key = os.getenv("LAKERA_API_KEY")
automorphic_api_key = os.getenv("AUTOMORPHIC_API_KEY")
rebuff_api_key = os.getenv("REBUFF_API_KEY")
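
# Builds a CPU text-classification pipeline from an ONNX checkpoint via
# optimum/onnxruntime, truncating inputs to the model's 512-token limit.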
def init_prompt_injection_model(prompt_injection_ort_model: str, subfolder: str = "") -> Pipeline:
    hf_model = ORTModelForSequenceClassification.from_pretrained(
        prompt_injection_ort_model,
        export=False,
        subfolder=subfolder,
    )
    hf_tokenizer = AutoTokenizer.from_pretrained(prompt_injection_ort_model, subfolder=subfolder)
    hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]
    logger.info(f"Initialized classification ONNX model {prompt_injection_ort_model} on CPU")

    return pipeline(
        "text-classification",
        model=hf_model,
        tokenizer=hf_tokenizer,
        device="cpu",
        batch_size=1,
        truncation=True,
        max_length=512,
    )
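
# Rounds a monotonic-clock delta to two decimals for the latency column.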
def convert_elapsed_time(diff_time: float) -> float:
    return round(diff_time, 2)
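
# Both classifiers are built eagerly at import time so the first request does not
# pay model-initialization latency.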
deepset_classifier = init_prompt_injection_model(
    "laiyer/deberta-v3-base-injection-onnx"
)  # ONNX version of deepset/deberta-v3-base-injection
laiyer_classifier = init_prompt_injection_model("laiyer/deberta-v3-base-prompt-injection", "onnx")
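
# Local HF detector: flags the prompt when the INJECTION score clears the threshold.
# Every detector below returns the same (request_succeeded, is_injection) pair.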
def detect_hf(prompt: str, threshold: float = 0.5, classifier=laiyer_classifier) -> Tuple[bool, bool]:
    try:
        pi_result = classifier(prompt)
        injection_score = round(
            pi_result[0]["score"]
            if pi_result[0]["label"] == "INJECTION"
            else 1 - pi_result[0]["score"],
            2,
        )
        logger.info(f"Prompt injection result from the HF model: {pi_result}")

        return True, injection_score > threshold
    except Exception as err:
        logger.error(f"Failed to call HF model: {err}")
        return False, False
def detect_hf_laiyer(prompt: str) -> Tuple[bool, bool]:
    return detect_hf(prompt, classifier=laiyer_classifier)


def detect_hf_deepset(prompt: str) -> Tuple[bool, bool]:
    return detect_hf(prompt, classifier=deepset_classifier)
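
# Lakera Guard is queried over its REST API; the verdict is the "flagged" field of
# the first entry in the response's results list.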
def detect_lakera(prompt: str) -> Tuple[bool, bool]:
    try:
        response = requests.post(
            "https://api.lakera.ai/v1/prompt_injection",
            json={"input": prompt},
            headers={"Authorization": f"Bearer {lakera_api_key}"},
        )
        response_json = response.json()
        logger.info(f"Prompt injection result from Lakera: {response_json}")

        return True, response_json["results"][0]["flagged"]
    except Exception as err:
        logger.error(f"Failed to call Lakera API: {err}")
        return False, False
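
# Automorphic Aegis: a fresh client per call; ingress() screens the incoming prompt
# and returns a dict carrying a "detected" flag.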
def detect_automorphic(prompt: str) -> Tuple[bool, bool]:
    ag = aegis.Aegis(automorphic_api_key)
    try:
        ingress_attack_detected = ag.ingress(prompt, "")
        logger.info(f"Prompt injection result from Automorphic: {ingress_attack_detected}")
        return True, ingress_attack_detected["detected"]
    except Exception as err:
        logger.error(f"Failed to call Automorphic API: {err}")
        return False, False  # assume the prompt is not an attack
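
# Rebuff is queried through its hosted API; detect_injection() returns an object
# whose injectionDetected attribute carries the verdict.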
def detect_rebuff(prompt: str) -> Tuple[bool, bool]:
    try:
        rb = Rebuff(api_token=rebuff_api_key, api_url="https://www.rebuff.ai")
        result = rb.detect_injection(prompt)
        logger.info(f"Prompt injection result from Rebuff: {result}")
        return True, result.injectionDetected
    except Exception as err:
        logger.error(f"Failed to call Rebuff API: {err}")
        return False, False
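
# Display name -> detector mapping; the shared return contract keeps results
# comparable across providers.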
detection_providers = {
    "Laiyer (HF model)": detect_hf_laiyer,
    "Deepset (HF model)": detect_hf_deepset,
    "Lakera Guard": detect_lakera,
    "Automorphic Aegis": detect_automorphic,
    "Rebuff": detect_rebuff,
}
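
# Wraps a single provider call with wall-clock timing; the returned row matches the
# Dataframe headers declared in the Gradio interface below.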
def is_detected(provider: str, prompt: str) -> Tuple[str, bool, bool, float]:
    if provider not in detection_providers:
        logger.warning(f"Provider {provider} is not supported")
        return provider, False, False, 0.0

    start_time = time.monotonic()
    request_result, is_injection = detection_providers[provider](prompt)
    end_time = time.monotonic()

    return provider, request_result, is_injection, convert_elapsed_time(end_time - start_time)
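
# Fans the prompt out to all providers in a small worker pool and optionally uploads
# the prompt plus per-provider results to the public benchmark dataset.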
def execute(prompt: str, store_to_dataset: bool = True) -> List[Tuple[str, bool, bool, float]]:
    with mp.Pool(processes=num_processes) as pool:
        results = pool.starmap(
            is_detected, [(provider, prompt) for provider in detection_providers]
        )

    # Save the prompt and per-provider results
    if store_to_dataset:
        fileobj = json.dumps({"prompt": prompt, "results": results}, indent=2).encode("utf-8")
        result_path = f"/prompts/train/{str(uuid.uuid4())}.json"
        hf_api.upload_file(
            path_or_fileobj=fileobj,
            path_in_repo=result_path,
            repo_id="laiyer/prompt-injection-benchmark",
            repo_type="dataset",
        )
        logger.info(f"Stored prompt: {prompt}")

    return results
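
# Entry point: wires the detectors into a Gradio interface. Examples run with the
# store checkbox set to False so cached runs do not pollute the benchmark dataset.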
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--url", type=str, default="0.0.0.0")
    args, left_argv = parser.parse_known_args()

    example_files = glob.glob(os.path.join(os.path.dirname(__file__), "examples", "*.txt"))
    examples = [open(file).read() for file in example_files]

    gr.Interface(
        fn=execute,
        inputs=[
            gr.Textbox(label="Prompt"),
            gr.Checkbox(
                label="Store prompt and results to the public dataset `laiyer/prompt-injection-benchmark`",
                value=True,
            ),
        ],
        outputs=[
            gr.Dataframe(
                headers=[
                    "Provider",
                    "Is request successful?",
                    "Is prompt injection?",
                    "Latency (seconds)",
                ],
                datatype=["str", "bool", "bool", "number"],
                label="Results",
            ),
        ],
        title="Prompt Injection Benchmark",
        description="This interface benchmarks prompt injection detection providers. "
        "Results are stored in a public dataset so all providers are compared on the same prompts.",
        examples=[
            [
                example,
                False,
            ]
            for example in examples
        ],
        cache_examples=True,
        allow_flagging="never",
    ).queue(1).launch(server_name=args.url, server_port=args.port)