Pringled committed on
Commit
225d3fb
1 Parent(s): 2f9e086
Files changed (1) hide show
  1. app.py +259 -8
app.py CHANGED
@@ -10,7 +10,8 @@ model = StaticModel.from_pretrained("minishlab/M2V_base_output")
10
 
11
  # Default parameters
12
  default_dataset_name = "sst2"
13
- default_dataset_split = "train"
 
14
  default_text_column = "sentence"
15
  default_threshold = 0.9
16
 
@@ -75,7 +76,7 @@ def load_dataset_texts(dataset_name: str, dataset_split: str, text_column: str)
75
  Load texts from a specified dataset and split.
76
 
77
  :param dataset_name: Name of the dataset.
78
- :param dataset_split: Split of the dataset (e.g., 'train', 'validation').
79
  :param text_column: Name of the text column.
80
  :return: A list of texts from the dataset.
81
  """
@@ -206,7 +207,7 @@ with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
206
 
207
  with gr.Row():
208
  dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name")
209
- dataset1_split = gr.Textbox(value=default_dataset_split, label="Dataset 1 Split")
210
  dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
211
 
212
  dataset2_inputs = gr.Column(visible=True) # Make dataset2_inputs visible by default
@@ -214,7 +215,7 @@ with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
214
  gr.Markdown("### Dataset 2")
215
  with gr.Row():
216
  dataset2_name = gr.Textbox(value=default_dataset_name, label="Dataset 2 Name")
217
- dataset2_split = gr.Textbox(value=default_dataset_split, label="Dataset 2 Split")
218
  dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
219
 
220
  threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")
@@ -449,7 +450,7 @@ demo.launch()
449
  # """)
450
 
451
  # deduplication_type = gr.Radio(
452
- # choices=["Single dataset", "Cross-dataset"],
453
  # label="Deduplication Type",
454
  # value="Cross-dataset", # Set "Cross-dataset" as the default value
455
  # )
@@ -468,7 +469,10 @@ demo.launch()
468
  # dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
469
 
470
  # threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")
471
- # compute_button = gr.Button("Deduplicate")
 
 
 
472
  # status_output = gr.Markdown(elem_id="status_output")
473
  # result_output = gr.Markdown()
474
 
@@ -698,7 +702,7 @@ demo.launch()
698
  # # deduplication_type = gr.Radio(
699
  # # choices=["Single dataset", "Cross-dataset"],
700
  # # label="Deduplication Type",
701
- # # value="Single dataset",
702
  # # )
703
 
704
  # # with gr.Row():
@@ -706,7 +710,7 @@ demo.launch()
706
  # # dataset1_split = gr.Textbox(value=default_dataset_split, label="Dataset 1 Split")
707
  # # dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
708
 
709
- # # dataset2_inputs = gr.Column(visible=False)
710
  # # with dataset2_inputs:
711
  # # gr.Markdown("### Dataset 2")
712
  # # with gr.Row():
@@ -741,3 +745,250 @@ demo.launch()
741
 
742
 
743
  # # demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # Default parameters
12
  default_dataset_name = "sst2"
13
+ default_dataset1_split = "train" # Default for the first dataset is "train"
14
+ default_dataset2_split = "test" # Default for the second dataset is "test"
15
  default_text_column = "sentence"
16
  default_threshold = 0.9
17
 
 
76
  Load texts from a specified dataset and split.
77
 
78
  :param dataset_name: Name of the dataset.
79
+ :param dataset_split: Split of the dataset (e.g., 'train', 'validation', 'test').
80
  :param text_column: Name of the text column.
81
  :return: A list of texts from the dataset.
82
  """
 
207
 
208
  with gr.Row():
209
  dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name")
210
+ dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split") # Default split is "train"
211
  dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
212
 
213
  dataset2_inputs = gr.Column(visible=True) # Make dataset2_inputs visible by default
 
215
  gr.Markdown("### Dataset 2")
216
  with gr.Row():
217
  dataset2_name = gr.Textbox(value=default_dataset_name, label="Dataset 2 Name")
218
+ dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split") # Default split is "test"
219
  dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
220
 
221
  threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")
 
450
  # """)
451
 
452
  # deduplication_type = gr.Radio(
453
+ # choices=["Cross-dataset", "Single dataset"], # Swapped "Cross-dataset" to the left
454
  # label="Deduplication Type",
455
  # value="Cross-dataset", # Set "Cross-dataset" as the default value
456
  # )
 
469
  # dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
470
 
471
  # threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")
472
+
473
+ # with gr.Row(): # Placing the button in the same row for better alignment
474
+ # compute_button = gr.Button("Deduplicate")
475
+
476
  # status_output = gr.Markdown(elem_id="status_output")
477
  # result_output = gr.Markdown()
478
 
 
702
  # # deduplication_type = gr.Radio(
703
  # # choices=["Single dataset", "Cross-dataset"],
704
  # # label="Deduplication Type",
705
+ # # value="Cross-dataset", # Set "Cross-dataset" as the default value
706
  # # )
707
 
708
  # # with gr.Row():
 
710
  # # dataset1_split = gr.Textbox(value=default_dataset_split, label="Dataset 1 Split")
711
  # # dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
712
 
713
+ # # dataset2_inputs = gr.Column(visible=True) # Make dataset2_inputs visible by default
714
  # # with dataset2_inputs:
715
  # # gr.Markdown("### Dataset 2")
716
  # # with gr.Row():
 
745
 
746
 
747
  # # demo.launch()
748
+
749
+ # # # import gradio as gr
750
+ # # # from datasets import load_dataset
751
+ # # # import numpy as np
752
+ # # # from model2vec import StaticModel
753
+ # # # from reach import Reach
754
+ # # # from difflib import ndiff
755
+
756
+ # # # # Load the model
757
+ # # # model = StaticModel.from_pretrained("minishlab/M2V_base_output")
758
+
759
+ # # # # Default parameters
760
+ # # # default_dataset_name = "sst2"
761
+ # # # default_dataset_split = "train"
762
+ # # # default_text_column = "sentence"
763
+ # # # default_threshold = 0.9
764
+
765
+ # # # def deduplicate_embeddings(
766
+ # # # embeddings_a: np.ndarray,
767
+ # # # embeddings_b: np.ndarray = None,
768
+ # # # threshold: float = 0.9,
769
+ # # # batch_size: int = 1024,
770
+ # # # progress=None
771
+ # # # ) -> tuple[np.ndarray, dict[int, int]]:
772
+ # # # """
773
+ # # # Deduplicate embeddings within one dataset or across two datasets.
774
+
775
+ # # # :param embeddings_a: Embeddings of Dataset 1.
776
+ # # # :param embeddings_b: Optional, embeddings of Dataset 2.
777
+ # # # :param threshold: Similarity threshold for deduplication.
778
+ # # # :param batch_size: Batch size for similarity computation.
779
+ # # # :param progress: Gradio progress tracker for feedback.
780
+ # # # :return: Deduplicated indices and a mapping of removed indices to their original counterparts.
781
+ # # # """
782
+ # # # if embeddings_b is None:
783
+ # # # reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
784
+ # # # duplicate_to_original = {}
785
+ # # # results = reach.nearest_neighbor_threshold(
786
+ # # # embeddings_a, threshold=threshold, batch_size=batch_size, show_progressbar=False
787
+ # # # )
788
+ # # # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embeddings_a))):
789
+ # # # for sim_idx, _ in similar_items:
790
+ # # # sim_idx = int(sim_idx)
791
+ # # # if sim_idx != i and sim_idx not in duplicate_to_original:
792
+ # # # duplicate_to_original[sim_idx] = i
793
+ # # # deduplicated_indices = set(range(len(embeddings_a))) - set(duplicate_to_original.keys())
794
+ # # # return deduplicated_indices, duplicate_to_original
795
+ # # # else:
796
+ # # # reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
797
+ # # # duplicate_indices_in_b = []
798
+ # # # duplicate_to_original = {}
799
+ # # # results = reach.nearest_neighbor_threshold(
800
+ # # # embeddings_b, threshold=threshold, batch_size=batch_size, show_progressbar=False
801
+ # # # )
802
+ # # # for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embeddings_b))):
803
+ # # # if similar_items:
804
+ # # # duplicate_indices_in_b.append(i)
805
+ # # # duplicate_to_original[i] = int(similar_items[0][0])
806
+ # # # return duplicate_indices_in_b, duplicate_to_original
807
+
808
+ # # # def display_word_differences(x: str, y: str) -> str:
809
+ # # # """
810
+ # # # Display the word-level differences between two texts, formatted to avoid
811
+ # # # misinterpretation of Markdown syntax.
812
+
813
+ # # # :param x: First text.
814
+ # # # :param y: Second text.
815
+ # # # :return: A string showing word-level differences, wrapped in a code block.
816
+ # # # """
817
+ # # # diff = ndiff(x.split(), y.split())
818
+ # # # formatted_diff = "\n".join(word for word in diff if word.startswith(("+", "-")))
819
+ # # # return f"```\n{formatted_diff}\n```"
820
+
821
+ # # # def load_dataset_texts(dataset_name: str, dataset_split: str, text_column: str) -> list[str]:
822
+ # # # """
823
+ # # # Load texts from a specified dataset and split.
824
+
825
+ # # # :param dataset_name: Name of the dataset.
826
+ # # # :param dataset_split: Split of the dataset (e.g., 'train', 'validation').
827
+ # # # :param text_column: Name of the text column.
828
+ # # # :return: A list of texts from the dataset.
829
+ # # # """
830
+ # # # ds = load_dataset(dataset_name, split=dataset_split)
831
+ # # # return [example[text_column] for example in ds]
832
+
833
+ # # # def perform_deduplication(
834
+ # # # deduplication_type: str,
835
+ # # # dataset1_name: str,
836
+ # # # dataset1_split: str,
837
+ # # # dataset1_text_column: str,
838
+ # # # dataset2_name: str = "",
839
+ # # # dataset2_split: str = "",
840
+ # # # dataset2_text_column: str = "",
841
+ # # # threshold: float = default_threshold,
842
+ # # # progress: gr.Progress = gr.Progress(track_tqdm=True)
843
+ # # # ):
844
+ # # # """
845
+ # # # Perform deduplication on one or two datasets based on the deduplication type.
846
+
847
+ # # # :param deduplication_type: 'Single dataset' or 'Cross-dataset'.
848
+ # # # :param dataset1_name: Name of the first dataset.
849
+ # # # :param dataset1_split: Split of the first dataset.
850
+ # # # :param dataset1_text_column: Text column of the first dataset.
851
+ # # # :param dataset2_name: Optional, name of the second dataset (for cross-dataset deduplication).
852
+ # # # :param dataset2_split: Optional, split of the second dataset.
853
+ # # # :param dataset2_text_column: Optional, text column of the second dataset.
854
+ # # # :param threshold: Similarity threshold for deduplication.
855
+ # # # :param progress: Gradio progress tracker.
856
+ # # # :return: Status updates and result text for the Gradio interface.
857
+ # # # """
858
+ # # # try:
859
+ # # # threshold = float(threshold)
860
+
861
+ # # # # Load and process Dataset 1
862
+ # # # yield "Loading Dataset 1...", ""
863
+ # # # texts1 = load_dataset_texts(dataset1_name, dataset1_split, dataset1_text_column)
864
+ # # # yield "Computing embeddings for Dataset 1...", ""
865
+ # # # embeddings1 = model.encode(texts1, show_progressbar=True)
866
+
867
+ # # # if deduplication_type == "Single dataset":
868
+ # # # # Deduplicate within Dataset 1
869
+ # # # yield "Deduplicating within Dataset 1...", ""
870
+ # # # deduplicated_indices, duplicate_mapping = deduplicate_embeddings(
871
+ # # # embeddings1, threshold=threshold, progress=progress
872
+ # # # )
873
+
874
+ # # # num_duplicates = len(duplicate_mapping)
875
+ # # # result_text = (
876
+ # # # f"**Total documents:** {len(texts1)}\n\n"
877
+ # # # f"**Duplicates found:** {num_duplicates}\n\n"
878
+ # # # f"**Unique documents after deduplication:** {len(deduplicated_indices)}\n\n"
879
+ # # # )
880
+
881
+ # # # if num_duplicates > 0:
882
+ # # # result_text += "**Sample duplicates:**\n\n"
883
+ # # # for dup_idx, orig_idx in list(duplicate_mapping.items())[:5]:
884
+ # # # orig_text = texts1[orig_idx]
885
+ # # # dup_text = texts1[dup_idx]
886
+ # # # differences = display_word_differences(orig_text, dup_text)
887
+ # # # result_text += (
888
+ # # # f"**Original:**\n{orig_text}\n\n"
889
+ # # # f"**Duplicate:**\n{dup_text}\n\n"
890
+ # # # f"**Differences:**\n{differences}\n"
891
+ # # # + "-" * 50 + "\n\n"
892
+ # # # )
893
+ # # # else:
894
+ # # # result_text += "No duplicates found."
895
+
896
+ # # # yield "Deduplication completed.", result_text
897
+
898
+ # # # else:
899
+ # # # # Load and process Dataset 2
900
+ # # # yield "Loading Dataset 2...", ""
901
+ # # # texts2 = load_dataset_texts(dataset2_name, dataset2_split, dataset2_text_column)
902
+ # # # yield "Computing embeddings for Dataset 2...", ""
903
+ # # # embeddings2 = model.encode(texts2, show_progressbar=True)
904
+
905
+ # # # # Deduplicate Dataset 2 against Dataset 1
906
+ # # # yield "Deduplicating Dataset 2 against Dataset 1...", ""
907
+ # # # duplicate_indices, duplicate_mapping = deduplicate_embeddings(
908
+ # # # embeddings1, embeddings_b=embeddings2, threshold=threshold, progress=progress
909
+ # # # )
910
+
911
+ # # # num_duplicates = len(duplicate_indices)
912
+ # # # result_text = (
913
+ # # # f"**Total documents in {dataset2_name}/{dataset2_split}:** {len(texts2)}\n\n"
914
+ # # # f"**Duplicates found in Dataset 2:** {num_duplicates}\n\n"
915
+ # # # f"**Unique documents after deduplication:** {len(texts2) - num_duplicates}\n\n"
916
+ # # # )
917
+
918
+ # # # if num_duplicates > 0:
919
+ # # # result_text += "**Sample duplicates from Dataset 2:**\n\n"
920
+ # # # for idx in duplicate_indices[:5]:
921
+ # # # orig_text = texts1[duplicate_mapping[idx]]
922
+ # # # dup_text = texts2[idx]
923
+ # # # differences = display_word_differences(orig_text, dup_text)
924
+ # # # result_text += (
925
+ # # # f"**Original (Dataset 1):**\n{orig_text}\n\n"
926
+ # # # f"**Duplicate (Dataset 2):**\n{dup_text}\n\n"
927
+ # # # f"**Differences:**\n{differences}\n"
928
+ # # # + "-" * 50 + "\n\n"
929
+ # # # )
930
+ # # # else:
931
+ # # # result_text += "No duplicates found."
932
+
933
+ # # # yield "Deduplication completed.", result_text
934
+
935
+ # # # except Exception as e:
936
+ # # # yield f"An error occurred: {e}", ""
937
+ # # # raise e
938
+
939
+ # # # # Gradio app with stop button support
940
+ # # # with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
941
+ # # # gr.Markdown("# Semantic Deduplication")
942
+ # # # gr.Markdown("""
943
+ # # # This demo showcases semantic deduplication using Model2Vec for HuggingFace datasets.
944
+ # # # It can be used to identify duplicate texts within a single dataset or across two datasets.
945
+ # # # You can adjust the similarity threshold to control the strictness of the deduplication.\n
946
+ # # # NOTE: this demo runs on a free CPU backend, so it may be slow for large datasets. For faster results, please run the code locally.
947
+ # # # """)
948
+
949
+ # # # deduplication_type = gr.Radio(
950
+ # # # choices=["Single dataset", "Cross-dataset"],
951
+ # # # label="Deduplication Type",
952
+ # # # value="Single dataset",
953
+ # # # )
954
+
955
+ # # # with gr.Row():
956
+ # # # dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name")
957
+ # # # dataset1_split = gr.Textbox(value=default_dataset_split, label="Dataset 1 Split")
958
+ # # # dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
959
+
960
+ # # # dataset2_inputs = gr.Column(visible=False)
961
+ # # # with dataset2_inputs:
962
+ # # # gr.Markdown("### Dataset 2")
963
+ # # # with gr.Row():
964
+ # # # dataset2_name = gr.Textbox(value=default_dataset_name, label="Dataset 2 Name")
965
+ # # # dataset2_split = gr.Textbox(value=default_dataset_split, label="Dataset 2 Split")
966
+ # # # dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
967
+
968
+ # # # threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")
969
+ # # # compute_button = gr.Button("Deduplicate")
970
+ # # # status_output = gr.Markdown(elem_id="status_output")
971
+ # # # result_output = gr.Markdown()
972
+
973
+ # # # def update_visibility(choice: str):
974
+ # # # return gr.update(visible=choice == "Cross-dataset")
975
+
976
+ # # # deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=dataset2_inputs)
977
+
978
+ # # # compute_button.click(
979
+ # # # fn=perform_deduplication,
980
+ # # # inputs=[
981
+ # # # deduplication_type,
982
+ # # # dataset1_name,
983
+ # # # dataset1_split,
984
+ # # # dataset1_text_column,
985
+ # # # dataset2_name,
986
+ # # # dataset2_split,
987
+ # # # dataset2_text_column,
988
+ # # # threshold,
989
+ # # # ],
990
+ # # # outputs=[status_output, result_output],
991
+ # # # )
992
+
993
+
994
+ # # # demo.launch()