Spaces:
Build error
Build error
import json | |
import os | |
import gradio as gr | |
from distilabel.llms import InferenceEndpointsLLM, LlamaCppLLM | |
from distilabel.steps.tasks.argillalabeller import ArgillaLabeller | |
# NOTE(review): the local file name says "f16", but ``download_url`` fetches the
# Q5_K_S quantisation; the name is kept so an already-downloaded file is reused.
file_path = os.path.join(os.path.dirname(__file__), "Qwen2-5-0.5B-Instruct-f16.gguf")
download_url = "https://huggingface.co/gaianet/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/Qwen2.5-0.5B-Instruct-Q5_K_S.gguf?download=true"


def _download_file(url, destination, chunk_size=1024 * 1024, timeout=60):
    """Stream ``url`` into ``destination``, moving it into place only when complete.

    The payload is written to ``destination + ".part"`` and renamed with
    ``os.replace`` after the last chunk. An interrupted or failed download
    therefore never leaves a truncated file at ``destination``; the existence
    check that guards this call would otherwise accept such a file forever.

    Args:
        url: HTTP(S) URL to fetch.
        destination: Final path of the downloaded file.
        chunk_size: Bytes requested per ``iter_content`` chunk.
        timeout: Seconds passed to ``requests.get`` as the connect/read timeout.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # Imported lazily: only needed on the first run, when the model is missing.
    import requests
    import tqdm

    partial_path = destination + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an HTML/JSON error page would be saved as the model.
            response.raise_for_status()
            content_length = response.headers.get("content-length")
            # The header may be missing (e.g. chunked transfer); tqdm accepts total=None.
            total_bytes = int(content_length) if content_length else None
            with open(partial_path, "wb") as f, tqdm.tqdm(
                total=total_bytes, unit="B", unit_scale=True, unit_divisor=1024
            ) as progress:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    progress.update(len(chunk))
        os.replace(partial_path, destination)
    except BaseException:
        # Drop the incomplete file so the next attempt starts cleanly.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise


if not os.path.exists(file_path):
    _download_file(download_url, file_path)
def _loaded_labeller(llm):
    """Wrap ``llm`` in an ``ArgillaLabeller``, call its ``load()``, and return it."""
    labeller = ArgillaLabeller(llm=llm)
    labeller.load()
    return labeller


# Local model served through llama.cpp from the GGUF file at ``file_path``; the
# request handler falls back to it when the hosted endpoint call raises.
llm_cpp = LlamaCppLLM(
    model_path=file_path,
    n_gpu_layers=-1,  # llama-cpp-python convention: -1 offloads every layer to the GPU
    # NOTE(review): Qwen2.5-0.5B-Instruct is published with a 32K-token context
    # window; 114_000 exceeds it and enlarges the KV-cache allocation -- confirm
    # this value is intended.
    n_ctx=114_000,
    generation_kwargs={"max_new_tokens": 14_000},
)
task_cpp = _loaded_labeller(llm_cpp)

# Hosted Llama 3.1 8B Instruct, tried first by the request handler.
llm_ep = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    generation_kwargs={"max_new_tokens": 1000},
)
task_ep = _loaded_labeller(llm_ep)
def load_examples(path=None):
    """Load the Gradio example rows from a JSON file.

    Args:
        path: JSON file to read. Defaults to ``examples.json`` next to this
            module. It is resolved the same way as the model file, so the app
            works no matter which directory it is launched from.

    Returns:
        The decoded JSON document (expected to be a list of example rows).
    """
    if path is None:
        path = os.path.join(os.path.dirname(__file__), "examples.json")
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
# Sample inputs offered by the interface; only the first stored example is used.
_stored_examples = load_examples()
examples = _stored_examples[0:1]
def process_fields(fields):
    """Normalise Argilla field definitions into a list of dicts.

    ``fields`` may be a JSON string, a single dict, or an iterable whose items
    are dicts or JSON strings. A single dict (decoded or given directly) is
    wrapped in a list. Every string item is decoded with ``json.loads``.
    """

    def _as_dict(item):
        return item if isinstance(item, dict) else json.loads(item)

    decoded = json.loads(fields) if isinstance(fields, str) else fields
    items = [decoded] if isinstance(decoded, dict) else decoded
    return list(map(_as_dict, items))
def _parse_json_input(value):
    """Return ``value`` decoded from JSON if it is a non-empty string, else unchanged.

    ``gr.Code`` inputs arrive as strings, while API callers may send already
    decoded objects; empty strings (blank editors) pass through untouched.
    """
    if isinstance(value, str) and value:
        return json.loads(value)
    return value


def process_records_gradio(records, fields, question, example_records=None):
    """Label Argilla ``records`` with an LLM and return the suggestions as JSON.

    Each argument may be a JSON string (as sent by the ``gr.Code`` inputs) or an
    already decoded Python object.

    Args:
        records: Record dict(s) to label; a single dict is treated as one record.
        fields: Argilla field definition(s); a single dict is wrapped in a list.
        question: Argilla question definition.
        example_records: Optional records passed to the labeller as examples.

    Returns:
        A JSON string ``{"results": [...]}`` with the non-empty ``suggestions``
        produced for the records, in input order. Empty ``records`` yield an
        empty list without calling any model.

    Raises:
        gr.Error: If neither ``fields`` nor ``question`` is given, an input is
            not valid JSON, or labelling fails on both the endpoint and the
            local model.
    """
    try:
        records = _parse_json_input(records)
        example_records = _parse_json_input(example_records)
        fields = _parse_json_input(fields)
        question = _parse_json_input(question)
        if not fields and not question:
            raise ValueError("Either fields or question must be provided")

        # A single JSON object would otherwise be iterated key by key.
        if isinstance(records, dict):
            records = [records]
        if fields:
            # Accept a single field dict (or JSON-string items) and pass a list on.
            fields = process_fields(fields)
        if not records:
            # Nothing to label: skip both model calls.
            return json.dumps({"results": []}, indent=2)

        runtime_parameters = {"fields": fields, "question": question}
        if example_records:
            runtime_parameters["example_records"] = example_records
        # NOTE(review): both tasks are module-level and shared between requests;
        # concurrent calls could overwrite each other's runtime parameters --
        # confirm the Gradio concurrency settings serialise this handler.
        task_ep.set_runtime_parameters(runtime_parameters)
        task_cpp.set_runtime_parameters(runtime_parameters)

        try:
            output = next(
                task_ep.process(inputs=[{"record": record} for record in records])
            )
        except Exception as endpoint_error:
            # Fall back to the local llama.cpp model, but record why the
            # endpoint was abandoned instead of swallowing the error.
            import logging

            logging.getLogger(__name__).warning(
                "Inference endpoint failed (%r); falling back to the local model.",
                endpoint_error,
            )
            output = next(
                task_cpp.process(inputs=[{"record": record} for record in records])
            )

        entries = [output[idx] for idx in range(len(records))]
        # NOTE(review): records without suggestions are dropped, so ``results``
        # is not positionally aligned with ``records`` when labelling partly
        # fails; kept as-is for compatibility with existing clients.
        results = [entry["suggestions"] for entry in entries if entry["suggestions"]]
        return json.dumps({"results": results}, indent=2)
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
description = """ | |
An example workflow for JSON payload. | |
```python | |
import json | |
import os | |
from gradio_client import Client | |
import argilla as rg | |
# Initialize Argilla client | |
gradio_client = Client("davidberenstein1957/distilabel-argilla-labeller") | |
argilla_client = rg.Argilla( | |
api_key=os.environ["ARGILLA_API_KEY"], api_url=os.environ["ARGILLA_API_URL"] | |
) | |
# Load the dataset | |
dataset = argilla_client.datasets(name="my_dataset", workspace="my_workspace") | |
# Get the field and question | |
field = dataset.settings.fields["text"] | |
question = dataset.settings.questions["sentiment"] | |
# Get completed and pending records | |
completed_records_filter = rg.Filter(("status", "==", "completed")) | |
pending_records_filter = rg.Filter(("status", "==", "pending")) | |
example_records = list( | |
dataset.records( | |
query=rg.Query(filter=completed_records_filter), | |
limit=5, | |
) | |
) | |
some_pending_records = list( | |
dataset.records( | |
query=rg.Query(filter=pending_records_filter), | |
limit=5, | |
) | |
) | |
# Process the records | |
payload = { | |
"records": [record.to_dict() for record in some_pending_records], | |
"fields": [field.serialize()], | |
"question": question.serialize(), | |
"example_records": [record.to_dict() for record in example_records], | |
"api_name": "/predict", | |
} | |
response = gradio_client.predict(**payload) | |
``` | |
""" | |
# The input components are listed in the same order as the parameters of
# ``process_records_gradio`` (records, fields, question, example_records).
# Gradio passes component values to ``fn`` positionally, so a mismatched order
# would feed e.g. the "Example Records" editor into ``fields``.
# NOTE(review): rows in examples.json are also matched to these components
# positionally -- confirm they follow this same order.
interface = gr.Interface(
    fn=process_records_gradio,
    inputs=[
        gr.Code(label="Records (JSON)", language="json", lines=5),
        gr.Code(label="Fields (JSON, optional)", language="json"),
        gr.Code(label="Question (JSON, optional)", language="json"),
        gr.Code(label="Example Records (JSON, optional)", language="json", lines=5),
    ],
    examples=examples,
    cache_examples=False,
    outputs=gr.Code(label="Suggestions", language="json", lines=10),
    title="Distilabel - ArgillaLabeller - Record Processing Interface",
    description=description,
)

if __name__ == "__main__":
    interface.launch()