Updates
app.py CHANGED
@@ -10,7 +10,7 @@ from contextlib import contextmanager
 # Load the model at startup
 model = StaticModel.from_pretrained("minishlab/M2V_base_output")
 
-#
+# Default dataset parameters
 default_dataset1_name = "sst2"
 default_dataset1_split = "train"
 default_dataset2_name = "sst2"
@@ -47,7 +47,6 @@ def deduplicate(
     batch_size: int = 1024,
     progress=None
 ) -> tuple[np.ndarray, dict[int, int]]:
-    # Existing deduplication code remains unchanged
     # Building the index
     progress(0, desc="Building search index...")
     reach = Reach(
@@ -171,18 +170,137 @@ def perform_deduplication(
 
         elif deduplication_type == "Cross-dataset":
             # Similar code for cross-dataset deduplication
-            #
-
+            # Load Dataset 1
+            status = "Loading Dataset 1..."
+            yield status, ""
+            if (
+                dataset1_name == default_dataset1_name
+                and dataset1_split == default_dataset1_split
+            ):
+                ds1 = ds_default1
+            else:
+                ds1 = load_dataset(dataset1_name, split=dataset1_split)
+
+            # Load Dataset 2
+            status = "Loading Dataset 2..."
+            yield status, ""
+            if (
+                dataset2_name == default_dataset2_name
+                and dataset2_split == default_dataset2_split
+            ):
+                ds2 = ds_default2
+            else:
+                ds2 = load_dataset(dataset2_name, split=dataset2_split)
+
+            # Extract texts from Dataset 1
+            status = "Extracting texts from Dataset 1..."
+            yield status, ""
+            texts1 = [example[dataset1_text_column] for example in ds1]
+
+            # Extract texts from Dataset 2
+            status = "Extracting texts from Dataset 2..."
+            yield status, ""
+            texts2 = [example[dataset2_text_column] for example in ds2]
+
+            # Compute embeddings for Dataset 1
+            status = "Computing embeddings for Dataset 1..."
+            yield status, ""
+            embedding_matrix1 = compute_embeddings(
+                texts1,
+                batch_size=64,
+                progress=progress,
+                desc="Computing embeddings for Dataset 1",
+            )
+
+            # Compute embeddings for Dataset 2
+            status = "Computing embeddings for Dataset 2..."
+            yield status, ""
+            embedding_matrix2 = compute_embeddings(
+                texts2,
+                batch_size=64,
+                progress=progress,
+                desc="Computing embeddings for Dataset 2",
+            )
+
+            # Deduplicate across datasets
+            status = "Deduplicating embeddings across datasets..."
+            yield status, ""
+            duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
+                embedding_matrix1, embedding_matrix2, threshold, progress=progress
+            )
+
+            num_duplicates = len(duplicate_indices_in_ds2)
+            num_total_ds2 = len(texts2)
+            num_unique_ds2 = num_total_ds2 - num_duplicates
+
+            result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
+            result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
+            result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
+
+            # Show deduplicated examples
+            if num_duplicates > 0:
+                result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
+                num_examples = min(5, num_duplicates)
+                for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
+                    original_idx = duplicate_to_original_mapping[duplicate_idx]
+                    original_text = texts1[original_idx]
+                    duplicate_text = texts2[duplicate_idx]
+                    differences = display_word_differences(original_text, duplicate_text)
+                    result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
+                    result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
+                    result_text += f"**Differences:**\n{differences}\n"
+                    result_text += "-" * 50 + "\n\n"
+            else:
+                result_text += "No duplicates found."
+
+            # Final status
+            status = "Deduplication completed."
+            yield status, result_text
 
     except Exception as e:
         yield f"An error occurred: {e}", ""
         raise e
 
-
-
-
-
-
+def deduplicate_across_datasets(
+    embedding_matrix_1: np.ndarray,
+    embedding_matrix_2: np.ndarray,
+    threshold: float,
+    batch_size: int = 1024,
+    progress=None
+) -> tuple[list[int], dict[int, int]]:
+    # Building the index from Dataset 1
+    progress(0, desc="Building search index from Dataset 1...")
+    reach = Reach(
+        vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))]
+    )
+
+    duplicate_indices_in_test = []
+    duplicate_to_original_mapping = {}
+
+    # Finding nearest neighbors between datasets
+    progress(0, desc="Finding nearest neighbors between datasets...")
+    results = reach.nearest_neighbor_threshold(
+        embedding_matrix_2,
+        threshold=threshold,
+        batch_size=batch_size,
+        show_progressbar=False,  # Disable internal progress bar
+    )
+
+    total_items = len(embedding_matrix_2)
+    # Processing duplicates with a progress bar
+    for i, similar_items in enumerate(
+        progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)
+    ):
+        similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
+
+        if similar_indices:
+            duplicate_indices_in_test.append(i)
+            duplicate_to_original_mapping[i] = similar_indices[0]
+
+    return duplicate_indices_in_test, duplicate_to_original_mapping
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Semantic Deduplication")
 
     deduplication_type = gr.Radio(
         choices=["Single dataset", "Cross-dataset"],
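
For intuition, the thresholded nearest-neighbor search that `deduplicate_across_datasets` delegates to `reach` can be sketched with plain numpy. This is a brute-force illustration only, not the code in app.py; `dedup_across_sketch` and its test vectors are made up for the example:

import numpy as np

def dedup_across_sketch(emb1: np.ndarray, emb2: np.ndarray, threshold: float = 0.9):
    """Find rows of emb2 whose nearest row in emb1 has cosine similarity >= threshold."""
    # Normalize rows so a plain dot product equals cosine similarity.
    a = emb1 / np.linalg.norm(emb1, axis=1, keepdims=True)
    b = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True)
    sims = b @ a.T  # shape (len(emb2), len(emb1))
    best = sims.argmax(axis=1)
    duplicate_indices = [i for i in range(len(b)) if sims[i, best[i]] >= threshold]
    # Map each duplicate in dataset 2 to its closest original in dataset 1.
    mapping = {i: int(best[i]) for i in duplicate_indices}
    return duplicate_indices, mapping

# Self-check: row 0 of emb2 is a noisy copy of row 2 of emb1, row 1 is unrelated.
rng = np.random.default_rng(0)
emb1 = rng.normal(size=(4, 64))
emb2 = np.vstack([emb1[2] + 0.01 * rng.normal(size=64), rng.normal(size=64)])
print(dedup_across_sketch(emb1, emb2, threshold=0.95))  # expected: ([0], {0: 2})

Materializing the full similarity matrix costs O(len(ds1) * len(ds2)) memory, which is presumably why the app uses reach's batched, thresholded query instead of computing it outright.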
@@ -209,8 +327,8 @@ with gr.Blocks() as demo:
 
     compute_button = gr.Button("Compute")
 
-    #
-    status_output = gr.
+    # Use 'lines' parameter to set the height
+    status_output = gr.Textbox(lines=10, label="Status")
     result_output = gr.Markdown()
 
     # Function to update the visibility of dataset2_inputs
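
As a usage note, the repeated `yield status, ""` pattern in `perform_deduplication` works because Gradio streams each tuple yielded by a generator handler into the bound output components. A minimal sketch of that wiring, with illustrative names rather than the actual app.py components:

import time
import gradio as gr

def long_task():
    # Each yield streams an intermediate (status, result) pair to the outputs.
    for step in ("Loading data...", "Computing embeddings...", "Deduplicating..."):
        yield step, ""
        time.sleep(1)  # stand-in for real work
    yield "Done.", "**Example result**"

with gr.Blocks() as demo:
    button = gr.Button("Compute")
    status_box = gr.Textbox(lines=10, label="Status")  # 'lines' sets the visible height
    result_box = gr.Markdown()
    button.click(long_task, inputs=[], outputs=[status_box, result_box])

demo.launch()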