Spaces:
Sleeping
Sleeping
File size: 9,501 Bytes
25d2eb7 2827b8a 7a1cd7a 2827b8a 892ceeb 2827b8a 7a1cd7a 2827b8a 892ceeb 2827b8a 892ceeb 2827b8a 7a1cd7a 2827b8a 892ceeb 2827b8a 25d2eb7 7a1cd7a 892ceeb 7a1cd7a 2827b8a 7a1cd7a 2827b8a 7a1cd7a 892ceeb 2827b8a 7a1cd7a 892ceeb 7a1cd7a 892ceeb 7a1cd7a 892ceeb 7a1cd7a 25d2eb7 2827b8a 7a1cd7a 2827b8a 7a1cd7a 2827b8a 7a1cd7a 2827b8a 7a1cd7a 2827b8a 7a1cd7a 2827b8a 7a1cd7a 2827b8a 7a1cd7a 2827b8a 7a1cd7a 2827b8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
from difflib import ndiff
from functools import lru_cache
from itertools import islice

import gradio as gr
import numpy as np
from datasets import load_dataset
from model2vec import StaticModel
from reach import Reach
def deduplicate(
    embedding_matrix: np.ndarray,
    threshold: float,
    batch_size: int = 1024,
    progress=gr.Progress(track_tqdm=True),
) -> tuple[np.ndarray, dict[int, int]]:
    """
    Greedily deduplicate the rows of an embedding matrix using a similarity threshold.

    Rows are visited in index order. A row that has not already been claimed as a
    duplicate is kept, and every *later* row that Reach reports as a neighbour at or
    above ``threshold`` is marked as a duplicate of it.

    Args:
        embedding_matrix: One embedding per row (assumed n_docs x dim — TODO confirm).
        threshold: Similarity threshold forwarded to ``Reach.nearest_neighbor_threshold``.
        batch_size: Query batch size forwarded to Reach.
        progress: Gradio progress tracker used to report the post-processing loop.

    Returns:
        ``(kept_indices, duplicate_to_original)``: ``kept_indices`` is a sorted integer
        array of the row indices that survive deduplication. ``duplicate_to_original``
        maps every removed row index to the kept row index it was matched to.
    """
    num_rows = len(embedding_matrix)
    if num_rows == 0:
        # Nothing to index; avoid building an empty Reach instance.
        return np.array([], dtype=np.int64), {}

    # Reach item names are the row indices as strings, so matches map back to rows.
    reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(num_rows)])
    kept = set(range(num_rows))  # every row starts as kept
    duplicate_to_original_mapping: dict[int, int] = {}

    results = reach.nearest_neighbor_threshold(
        embedding_matrix,
        threshold=threshold,
        batch_size=batch_size,
        show_progressbar=True,
    )

    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates")):
        if i not in kept:
            continue  # already claimed as a duplicate of an earlier row
        for item in similar_items:
            # Each match is (item_name, score); only the index is needed here.
            sim_idx = int(item[0])
            # Only later rows may be claimed. Earlier rows have already been decided,
            # so every original in the mapping is guaranteed to remain kept.
            if sim_idx > i and sim_idx in kept:
                kept.remove(sim_idx)
                duplicate_to_original_mapping[sim_idx] = i

    # Sort so the result does not depend on set iteration order.
    return np.array(sorted(kept), dtype=np.int64), duplicate_to_original_mapping
def deduplicate_across_datasets(
    embedding_matrix_1: np.ndarray,
    embedding_matrix_2: np.ndarray,
    threshold: float,
    batch_size: int = 1024,
    progress=gr.Progress(track_tqdm=True),
) -> tuple[list[int], dict[int, int]]:
    """
    Find rows of the second matrix that have a near-duplicate in the first matrix.

    The first matrix is indexed with Reach and every row of the second matrix is
    queried against it with ``nearest_neighbor_threshold``.

    Args:
        embedding_matrix_1: Embeddings of the reference dataset (e.g. train).
        embedding_matrix_2: Embeddings of the dataset checked for duplicates (e.g. test).
        threshold: Minimum similarity score for a match.
        batch_size: Query batch size forwarded to Reach.
        progress: Gradio progress tracker used to report the post-processing loop.

    Returns:
        ``(duplicate_indices, duplicate_to_original)``:
        - ``duplicate_indices`` lists, in ascending order, the row indices of
          ``embedding_matrix_2`` that matched at least one reference row.
        - ``duplicate_to_original`` maps each of those indices to the index of its
          highest-scoring match in ``embedding_matrix_1``.
    """
    duplicate_indices_in_test: list[int] = []
    duplicate_to_original_mapping: dict[int, int] = {}
    if len(embedding_matrix_1) == 0 or len(embedding_matrix_2) == 0:
        # With nothing to index or nothing to query, no duplicates are possible.
        return duplicate_indices_in_test, duplicate_to_original_mapping

    # Reach item names are row indices as strings, so matches map back to rows.
    reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])

    results = reach.nearest_neighbor_threshold(
        embedding_matrix_2,
        threshold=threshold,
        batch_size=batch_size,
        show_progressbar=True,
    )

    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets")):
        # Each match is (item_name, score). The score filter is defensive: it keeps the
        # threshold semantics explicit regardless of Reach's own comparison.
        matches = [(int(item[0]), item[1]) for item in similar_items if item[1] >= threshold]
        if not matches:
            continue
        # Select the best match explicitly instead of relying on the order of Reach's results.
        best_idx, _best_score = max(matches, key=lambda match: match[1])
        duplicate_indices_in_test.append(i)
        duplicate_to_original_mapping[i] = best_idx

    return duplicate_indices_in_test, duplicate_to_original_mapping
def display_word_differences(x: str, y: str) -> str:
    """
    Summarise the word-level edits that turn ``x`` into ``y``.

    Both strings are split on whitespace and compared with ``difflib.ndiff``. Only the
    removed (``"- word"``) and added (``"+ word"``) entries are kept. Unchanged words
    (``"  word"``) and ndiff's ``"?"`` hint lines are dropped. The surviving entries
    are joined with single spaces; identical inputs yield an empty string.
    """
    delta = ndiff(x.split(), y.split())
    edits = (entry for entry in delta if entry[:1] in ("+", "-"))
    return " ".join(edits)
# Hugging Face model id of the static embedding model used for every request.
_MODEL_NAME = "minishlab/M2V_base_output"
# Maximum number of duplicate pairs rendered in a report.
_MAX_EXAMPLES = 5
# Separator between examples. It must be preceded by a blank line; otherwise Markdown
# reads a run of dashes under a paragraph as a setext heading underline.
_SEPARATOR = "-" * 50


@lru_cache(maxsize=1)
def _load_model(model_name: str = _MODEL_NAME) -> StaticModel:
    """Load the embedding model once and reuse it for all subsequent requests."""
    return StaticModel.from_pretrained(model_name)


def _load_texts(dataset_name: str, split: str, text_column: str) -> list:
    """Load one dataset split via ``load_dataset`` and return its ``text_column`` values."""
    ds = load_dataset(dataset_name, split=split)
    return [example[text_column] for example in ds]


def _format_example(original_label: str, original_text: str, duplicate_label: str, duplicate_text: str) -> str:
    """Render one original/duplicate pair and their word-level differences as Markdown."""
    differences = display_word_differences(original_text, duplicate_text)
    return (
        f"**{original_label}:**\n{original_text}\n\n"
        f"**{duplicate_label}:**\n{duplicate_text}\n\n"
        # Fence the diff so its leading "-"/"+" markers are not parsed as list bullets.
        f"**Differences:**\n```\n{differences}\n```\n\n"
        f"{_SEPARATOR}\n\n"
    )


def perform_deduplication(
    deduplication_type,
    dataset1_name,
    dataset1_split,
    dataset1_text_column,
    dataset2_name,
    dataset2_split,
    dataset2_text_column,
    threshold,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Run semantic deduplication and return a Markdown report for the Gradio UI.

    Args:
        deduplication_type: ``"Single dataset"`` deduplicates dataset 1 against itself.
            ``"Cross-dataset"`` finds rows of dataset 2 that duplicate rows of dataset 1.
        dataset1_name, dataset1_split, dataset1_text_column: Which dataset split to
            load for dataset 1, and which column holds the text.
        dataset2_name, dataset2_split, dataset2_text_column: The same for dataset 2;
            only used in cross-dataset mode.
        threshold: Similarity threshold; coerced to ``float``.
        progress: Gradio progress tracker forwarded to the deduplication helpers.

    Returns:
        A Markdown string containing summary counts and up to ``_MAX_EXAMPLES``
        example duplicate pairs. For an unrecognised mode, a short message instead.
    """
    threshold = float(threshold)

    if deduplication_type == "Single dataset":
        texts = _load_texts(dataset1_name, dataset1_split, dataset1_text_column)
        if texts:
            embedding_matrix = _load_model().encode(texts, show_progressbar=True)
            deduplicated_indices, duplicate_to_original_mapping = deduplicate(
                embedding_matrix, threshold, progress=progress
            )
        else:
            # An empty split has nothing to embed or compare.
            deduplicated_indices, duplicate_to_original_mapping = [], {}

        # Blank lines between stats: a single "\n" is only a soft break in Markdown.
        parts = [
            f"**Total documents:** {len(texts)}\n\n",
            f"**Number of duplicates found:** {len(duplicate_to_original_mapping)}\n\n",
            f"**Number of unique documents after deduplication:** {len(deduplicated_indices)}\n\n",
            "**Examples of duplicates found:**\n\n",
        ]
        for duplicate_idx, original_idx in islice(duplicate_to_original_mapping.items(), _MAX_EXAMPLES):
            parts.append(
                _format_example("Original text", texts[original_idx], "Duplicate text", texts[duplicate_idx])
            )
        return "".join(parts)

    if deduplication_type == "Cross-dataset":
        texts1 = _load_texts(dataset1_name, dataset1_split, dataset1_text_column)
        texts2 = _load_texts(dataset2_name, dataset2_split, dataset2_text_column)
        if texts1 and texts2:
            model = _load_model()
            embedding_matrix1 = model.encode(texts1, show_progressbar=True)
            embedding_matrix2 = model.encode(texts2, show_progressbar=True)
            duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
                embedding_matrix1, embedding_matrix2, threshold, progress=progress
            )
        else:
            # Either nothing to compare against or nothing to check: no duplicates possible.
            duplicate_indices_in_ds2, duplicate_to_original_mapping = [], {}

        target = f"{dataset2_name}/{dataset2_split}"
        num_duplicates = len(duplicate_indices_in_ds2)
        num_total_ds2 = len(texts2)
        parts = [
            f"**Total documents in {target}:** {num_total_ds2}\n\n",
            f"**Number of duplicates found in {target}:** {num_duplicates}\n\n",
            f"**Number of unique documents in {target} after deduplication:** {num_total_ds2 - num_duplicates}\n\n",
            "**Examples of duplicates found in Dataset 2:**\n\n",
        ]
        for duplicate_idx in duplicate_indices_in_ds2[:_MAX_EXAMPLES]:
            original_idx = duplicate_to_original_mapping[duplicate_idx]
            parts.append(
                _format_example(
                    "Original text (Dataset 1)",
                    texts1[original_idx],
                    "Duplicate text (Dataset 2)",
                    texts2[duplicate_idx],
                )
            )
        return "".join(parts)

    # Previously an unknown mode fell through and returned None (a blank output).
    return f"Unknown deduplication type: {deduplication_type!r}"
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Deduplication")

    deduplication_type = gr.Radio(
        choices=["Single dataset", "Cross-dataset"],
        label="Deduplication Type",
        value="Single dataset",
    )

    # First dataset: the only one used in "Single dataset" mode.
    with gr.Tab("Dataset 1"):
        with gr.Row():
            dataset1_name = gr.Textbox(value="ag_news", label="Dataset Name")
            dataset1_split = gr.Textbox(value="train", label="Split")
            dataset1_text_column = gr.Textbox(value="text", label="Text Column Name")

    # Second dataset: only used in "Cross-dataset" mode, so it starts hidden.
    with gr.Tab("Dataset 2", visible=False) as dataset2_tab:
        with gr.Row():
            dataset2_name = gr.Textbox(value="ag_news", label="Dataset Name")
            dataset2_split = gr.Textbox(value="test", label="Split")
            dataset2_text_column = gr.Textbox(value="text", label="Text Column Name")

    threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, label="Similarity Threshold")
    compute_button = gr.Button("Compute")
    output = gr.Markdown()

    def update_visibility(selected_type):
        """Show the "Dataset 2" tab only while cross-dataset deduplication is selected."""
        return {dataset2_tab: gr.update(visible=selected_type == "Cross-dataset")}

    deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=[dataset2_tab])

    # Order must match perform_deduplication's positional parameters.
    deduplication_inputs = [
        deduplication_type,
        dataset1_name,
        dataset1_split,
        dataset1_text_column,
        dataset2_name,
        dataset2_split,
        dataset2_text_column,
        threshold,
    ]
    compute_button.click(fn=perform_deduplication, inputs=deduplication_inputs, outputs=output)

demo.launch()
|