"""Gradio app for discovering Hugging Face datasets.

Offers similarity search by dataset ID or free-text query, a RAGatouille-based
search, and a search over dataset viewer previews.
"""

import asyncio
import re
from typing import Dict, List

import gradio as gr
import httpx
from cashews import cache
from huggingface_hub import ModelCard

from ragatouille_search import create_ragatouille_interface

# In-memory cache backend for cashews.
cache.setup("mem://")

# Similarity API backend and Hugging Face Hub endpoints.
API_URL = "https://davanstrien-huggingface-datasets-search-v2.hf.space"
HF_API_URL = "https://huggingface.co/api/datasets"
README_URL_TEMPLATE = "https://huggingface.co/datasets/{}/raw/main/README.md"

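# Both similarity endpoints are expected to return JSON shaped like
#   {"results": [{"dataset_id": ..., "similarity": ...}, ...]};
# only these fields are read by the helpers below.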
async def fetch_similar_datasets(dataset_id: str, limit: int = 10) -> List[Dict]:
    """Fetch datasets similar to `dataset_id` from the similarity API."""
    async with httpx.AsyncClient() as client:
        response = await client.get(
            f"{API_URL}/similar?dataset_id={dataset_id}&n={limit + 1}"
        )
        if response.status_code == 200:
            results = response.json()["results"]
            # Drop the queried dataset itself, which the API may include.
            return [r for r in results if r["dataset_id"] != dataset_id][:limit]
        return []

async def fetch_similar_datasets_by_text(query: str, limit: int = 10) -> List[Dict]:
    """Fetch datasets matching a free-text query from the similarity API."""
    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.get(
            f"{API_URL}/similar-text", params={"query": query, "n": limit + 1}
        )
        if response.status_code == 200:
            results = response.json()["results"]
            return results[:limit]
        return []

async def search_similar_datasets_by_text(query: str, limit: int = 10):
    """Search by text query and render the results as Markdown."""
    results = await fetch_similar_datasets_by_text(query, limit)

    if not results:
        return "No similar datasets found."

    # Fetch dataset cards and Hub metadata for all results concurrently.
    dataset_cards = await asyncio.gather(
        *[fetch_dataset_card(result["dataset_id"]) for result in results]
    )
    dataset_infos = await asyncio.gather(
        *[fetch_dataset_info(result["dataset_id"]) for result in results]
    )

    return format_results(results, dataset_cards, dataset_infos)

async def fetch_dataset_card(dataset_id: str) -> str:
    """Download a dataset's README and return its card text."""
    url = README_URL_TEMPLATE.format(dataset_id)
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        return ModelCard(response.text).text if response.status_code == 200 else ""

async def fetch_dataset_info(dataset_id: str) -> Dict:
    """Fetch dataset metadata (downloads, likes, last modified) from the Hub API."""
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{HF_API_URL}/{dataset_id}")
        return response.json() if response.status_code == 200 else {}

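# Usage sketch (illustrative only, not called by the app): the async helpers
# above can be driven directly, e.g.
#   asyncio.run(fetch_dataset_info("airtrain-ai/fineweb-edu-fortified"))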
def format_results(
    results: List[Dict], dataset_cards: List[str], dataset_infos: List[Dict]
) -> str:
    """Render similarity results, card previews, and Hub metadata as Markdown."""
    markdown = "<h1 style='text-align: center;'>✨ Similar Datasets ✨</h1>\n\n"
    for result, card, info in zip(results, dataset_cards, dataset_infos):
        hub_id = result["dataset_id"]
        similarity = result["similarity"]
        url = f"https://huggingface.co/datasets/{hub_id}"

        header = f"## [{hub_id}]({url})"
        markdown += header + "\n"
        markdown += f"**Similarity Score:** {similarity:.4f}\n\n"

        if info:
            downloads = info.get("downloads", 0)
            likes = info.get("likes", 0)
            last_modified = info.get("lastModified", "N/A")
            markdown += f"**Downloads:** {downloads} | **Likes:** {likes} | **Last Modified:** {last_modified}\n\n"

        if card:
            # Drop the card's own H1 title; the linked header above replaces it.
            card_without_title = re.sub(
                r"^#.*\n", "", card, count=1, flags=re.MULTILINE
            )

            # Use the first non-empty, non-image paragraph as a short preview.
            paragraphs = card_without_title.split("\n\n")
            preview = next(
                (
                    p
                    for p in paragraphs
                    if p.strip()
                    and not p.strip().startswith("![")
                    and not p.strip().startswith("<img")
                ),
                "No preview available.",
            )
            preview = f"{preview[:300]}..." if len(preview) > 300 else preview
            markdown += f"{preview}\n\n"

            # Constrain image sizes in the full card (HTML and Markdown images).
            full_card = re.sub(
                r'<img src="([^"]+)"',
                r'<img src="\1" style="max-width: 300px; max-height: 300px;"',
                card_without_title,
            )
            full_card = re.sub(
                r"!\[([^\]]*)\]\(([^\)]+)\)",
                r'<img src="\2" alt="\1" style="max-width: 300px; max-height: 300px;">',
                full_card,
            )
            markdown += f"<details><summary>Full Dataset Card</summary>\n\n{full_card}\n\n</details>\n\n"

        markdown += "---\n\n"

    return markdown

async def search_similar_datasets(dataset_id: str, limit: int = 10):
    """Search by dataset ID and render the results as Markdown."""
    results = await fetch_similar_datasets(dataset_id, limit)

    if not results:
        return "No similar datasets found."

    # Fetch dataset cards and Hub metadata for all results concurrently.
    dataset_cards = await asyncio.gather(
        *[fetch_dataset_card(result["dataset_id"]) for result in results]
    )
    dataset_infos = await asyncio.gather(
        *[fetch_dataset_info(result["dataset_id"]) for result in results]
    )

    return format_results(results, dataset_cards, dataset_infos)

async def search_viewer(query: str, limit: int = 10):
    """Search using dataset viewer content and render embedded previews as HTML."""
    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.get(
            f"{API_URL}/search-viewer", params={"query": query, "n": limit}
        )
        if response.status_code == 200:
            results = response.json()["results"]
            return format_viewer_results(results)
        return "No results found."

def format_viewer_results(results: List[Dict]) -> str:
    """Render results as HTML cards, each embedding the dataset viewer in an iframe."""
    html = "<div style='height: 600px; overflow-y: auto;'>"
    for result in results:
        dataset_id = result["dataset_id"]
        html += f"""
        <div style='margin-bottom: 20px; border: 1px solid #ddd; padding: 10px;'>
            <h3>{dataset_id}</h3>
            <p><strong>Similarity Score:</strong> {result['similarity']:.4f}</p>
            <iframe
                src="https://huggingface.co/datasets/{dataset_id}/embed/viewer/default/train"
                frameborder="0"
                width="100%"
                height="560px"
            ></iframe>
        </div>
        """
    html += "</div>"
    return html

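# Gradio UI: a "Similar Datasets" tab, a "RAGatouille Search" tab, and a
# "Search Viewer" tab that embeds dataset viewer previews.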
with gr.Blocks() as demo:
    gr.Markdown("## 🤗 Dataset Search and Similarity")

    with gr.Tabs():
        with gr.TabItem("Similar Datasets"):
            gr.Markdown("## 🤗 Dataset Similarity Search")
            with gr.Row():
                gr.Markdown(
                    "This Gradio app allows you to find similar datasets based on a given dataset ID or a text query. "
                    "Choose the search type and enter either a dataset ID or a text query to find similar datasets with previews of their dataset cards.\n\n"
                    "For a seamless experience on the Hugging Face website, check out the "
                    "[Hugging Face Similar Chrome extension](https://chromewebstore.google.com/detail/hugging-face-similar/aijelnjllajooinkcpkpbhckbghghpnl?authuser=0&hl=en). "
                    "This extension adds a 'Similar Datasets' section directly to Hugging Face dataset pages, "
                    "making it even easier to discover related datasets for your projects."
                )

            with gr.Row():
                search_type = gr.Radio(
                    ["Dataset ID", "Text Query"],
                    label="Search Type",
                    value="Dataset ID",
                )

            with gr.Row():
                dataset_id = gr.Textbox(
                    value="airtrain-ai/fineweb-edu-fortified",
                    label="Dataset ID (e.g., airtrain-ai/fineweb-edu-fortified)",
                )
                text_query = gr.Textbox(
                    label="Text Query (e.g., 'natural language processing dataset')",
                    visible=False,
                )

            with gr.Row():
                search_btn = gr.Button("Search Similar Datasets")
                max_results = gr.Slider(
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=10,
                    label="Maximum number of results",
                )

            results = gr.Markdown()

            def toggle_input_visibility(choice):
                # Show the dataset ID box or the text query box, depending on the search type.
                return gr.update(visible=choice == "Dataset ID"), gr.update(
                    visible=choice == "Text Query"
                )

            search_type.change(
                toggle_input_visibility,
                inputs=[search_type],
                outputs=[dataset_id, text_query],
            )

            # The click handler is synchronous, so it drives the async search
            # helpers with asyncio.run.
            search_btn.click(
                lambda search_type, dataset_id, text_query, limit: asyncio.run(
                    search_similar_datasets(dataset_id, limit)
                    if search_type == "Dataset ID"
                    else search_similar_datasets_by_text(text_query, limit)
                ),
                inputs=[search_type, dataset_id, text_query, max_results],
                outputs=results,
            )

        with gr.TabItem("RAGatouille Search"):
            ragatouille_interface = create_ragatouille_interface()

        with gr.TabItem("Search Viewer"):
            gr.Markdown("## 🔍 Search Viewer")
            with gr.Row():
                gr.Markdown(
                    "This tab allows you to search for datasets using their dataset viewer preview! "
                    "Unlike the other search methods, this search uses the dataset viewer embedded in most datasets to match your query. "
                    "This means it doesn't rely on the dataset card for matching!\n\n"
                    "Enter a query to find relevant datasets and preview them directly using the dataset viewer.\n\n"
                    "Currently, this search covers a subset of datasets and uses a very early version of an embedding model to match natural language queries to datasets.\n\n"
                    "**Help us improve!** Contribute to query quality improvement by participating in our "
                    "[Argilla annotation task](https://huggingface.co/spaces/davanstrien/my-argilla). Your feedback helps refine search results for everyone."
                )

            with gr.Row():
                viewer_query = gr.Textbox(
                    label="Search Query", placeholder="Enter your search query here"
                )

            with gr.Row():
                viewer_search_btn = gr.Button("Search")
                viewer_max_results = gr.Slider(
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=10,
                    label="Maximum number of results",
                )

            viewer_results = gr.HTML()

            # As above, wrap the async search in asyncio.run for the sync callback.
            viewer_search_btn.click(
                lambda query, limit: asyncio.run(search_viewer(query, limit)),
                inputs=[viewer_query, viewer_max_results],
                outputs=viewer_results,
            )


demo.launch()