import gradio as gr
from datasets import load_dataset
import numpy as np
from model2vec import StaticModel
from reach import Reach
from difflib import ndiff
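
# Assumed dependencies for running this demo locally (a sketch; versions not pinned,
# and not taken from the original file):
#   pip install gradio datasets numpy model2vec reach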

def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=gr.Progress(track_tqdm=True)) -> tuple[np.ndarray, dict[int, int]]:
    """
    Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
    """
    reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])

    # Use a set for deduplicated indices and keep track of duplicates
    deduplicated_indices = set(range(len(embedding_matrix)))  # Start with all indices as deduplicated
    duplicate_to_original_mapping = {}

    results = reach.nearest_neighbor_threshold(
        embedding_matrix, 
        threshold=threshold, 
        batch_size=batch_size, 
        show_progressbar=True
    )

    # Process duplicates
    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates")):
        if i not in deduplicated_indices:
            continue  # Skip already marked duplicates

        # Similar items are returned as (index, score), we are only interested in the index
        similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]

        # Mark similar documents as duplicates and map them to the original
        for sim_idx in similar_indices:
            if sim_idx in deduplicated_indices:
                deduplicated_indices.remove(sim_idx)
                duplicate_to_original_mapping[sim_idx] = i  # Map duplicate to original

    return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
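
# Illustrative usage of `deduplicate` (hypothetical variable names; `embeddings` stands
# for any 2D float array of shape [n_documents, dim]):
#   kept_indices, dup_map = deduplicate(embeddings, threshold=0.9)
#   kept_indices -> indices of documents to keep
#   dup_map      -> {duplicate_index: index_of_the_original_it_matches}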

def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=gr.Progress(track_tqdm=True)) -> tuple[list[int], dict[int, int]]:
    """
    Deduplicate embeddings across two datasets and return the indices of duplicates between them.
    """
    reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])

    # Keep track of duplicates in the second dataset
    duplicate_indices_in_test = []
    duplicate_to_original_mapping = {}

    # Find nearest neighbors from the test set in the train set
    results = reach.nearest_neighbor_threshold(
        embedding_matrix_2, 
        threshold=threshold, 
        batch_size=batch_size, 
        show_progressbar=True
    )

    # Process duplicates
    for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets")):
        # Similar items are returned as (index, score), we are only interested in the index
        similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]  # Keep those above the threshold

        # If we find a similar item in the train set, mark it as a duplicate
        if similar_indices:
            duplicate_indices_in_test.append(i)
            duplicate_to_original_mapping[i] = similar_indices[0]  # Map duplicate in test to original in train

    return duplicate_indices_in_test, duplicate_to_original_mapping

def display_word_differences(x: str, y: str) -> str:
    """Return a word-level diff of two texts, keeping only added ('+') and removed ('-') words."""
    diff = ndiff(x.split(), y.split())
    return " ".join([word for word in diff if word.startswith(('+', '-'))])

def perform_deduplication(
    deduplication_type,
    dataset1_name,
    dataset1_split,
    dataset1_text_column,
    dataset2_name,
    dataset2_split,
    dataset2_text_column,
    threshold,
    progress=gr.Progress(track_tqdm=True)
):
    """Run single-dataset or cross-dataset semantic deduplication and return a Markdown report."""
    # Convert threshold to float
    threshold = float(threshold)
    
    if deduplication_type == "Single dataset":
        # Load the dataset
        ds = load_dataset(dataset1_name, split=dataset1_split)
        
        # Extract texts
        texts = [example[dataset1_text_column] for example in ds]
        
        # Compute embeddings
        model = StaticModel.from_pretrained("minishlab/M2V_base_output")
        embedding_matrix = model.encode(texts, show_progressbar=True)
        
        # Deduplicate
        deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold, progress=progress)
        
        # Prepare the results
        num_duplicates = len(duplicate_to_original_mapping)
        num_total = len(texts)
        num_deduplicated = len(deduplicated_indices)
        
        result_text = f"**Total documents:** {num_total}\n"
        result_text += f"**Number of duplicates found:** {num_duplicates}\n"
        result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
        
        # Show a few example duplicate pairs
        result_text += "**Examples of duplicates found:**\n\n"
        num_examples = min(5, num_duplicates)
        for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
            original_text = texts[original_idx]
            duplicate_text = texts[duplicate_idx]
            differences = display_word_differences(original_text, duplicate_text)
            result_text += f"**Original text:**\n{original_text}\n\n"
            result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
            result_text += f"**Differences:**\n{differences}\n"
            result_text += "-" * 50 + "\n\n"
        
        return result_text
        
    elif deduplication_type == "Cross-dataset":
        # Load datasets
        ds1 = load_dataset(dataset1_name, split=dataset1_split)
        ds2 = load_dataset(dataset2_name, split=dataset2_split)
        
        # Extract texts
        texts1 = [example[dataset1_text_column] for example in ds1]
        texts2 = [example[dataset2_text_column] for example in ds2]
        
        # Compute embeddings
        model = StaticModel.from_pretrained("minishlab/M2V_base_output")
        embedding_matrix1 = model.encode(texts1, show_progressbar=True)
        embedding_matrix2 = model.encode(texts2, show_progressbar=True)
        
        # Deduplicate across datasets
        duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold, progress=progress)
        
        num_duplicates = len(duplicate_indices_in_ds2)
        num_total_ds2 = len(texts2)
        num_unique_ds2 = num_total_ds2 - num_duplicates
        
        result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
        result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
        result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
        
        # Show a few example duplicate pairs
        result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
        num_examples = min(5, num_duplicates)
        for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
            original_idx = duplicate_to_original_mapping[duplicate_idx]
            original_text = texts1[original_idx]
            duplicate_text = texts2[duplicate_idx]
            differences = display_word_differences(original_text, duplicate_text)
            result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
            result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
            result_text += f"**Differences:**\n{differences}\n"
            result_text += "-" * 50 + "\n\n"
        
        return result_text

with gr.Blocks() as demo:
    gr.Markdown("# Semantic Deduplication")
    
    deduplication_type = gr.Radio(choices=["Single dataset", "Cross-dataset"], label="Deduplication Type", value="Single dataset")
    
    with gr.Tab("Dataset 1"):
        with gr.Row():
            dataset1_name = gr.Textbox(value="ag_news", label="Dataset Name")
            dataset1_split = gr.Textbox(value="train", label="Split")
            dataset1_text_column = gr.Textbox(value="text", label="Text Column Name")
    
    dataset2_tab = gr.Tab("Dataset 2", visible=False)
    with dataset2_tab:
        with gr.Row():
            dataset2_name = gr.Textbox(value="ag_news", label="Dataset Name")
            dataset2_split = gr.Textbox(value="test", label="Split")
            dataset2_text_column = gr.Textbox(value="text", label="Text Column Name")
    
    threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, label="Similarity Threshold")
    
    compute_button = gr.Button("Compute")
    
    output = gr.Markdown()
    
    # Function to update the visibility of dataset2_tab
    def update_visibility(deduplication_type):
        if deduplication_type == "Cross-dataset":
            return {dataset2_tab: gr.update(visible=True)}
        else:
            return {dataset2_tab: gr.update(visible=False)}
    
    deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=[dataset2_tab])
    
    compute_button.click(
        fn=perform_deduplication,
        inputs=[
            deduplication_type, 
            dataset1_name, 
            dataset1_split, 
            dataset1_text_column,
            dataset2_name, 
            dataset2_split, 
            dataset2_text_column,
            threshold
        ],
        outputs=output
    )
    
demo.launch()