# GecLM annotation Space (scraped Hugging Face Spaces page; reported status: "Runtime error")
| import json | |
| import math | |
| import os | |
| import random | |
| import uuid | |
| from datetime import datetime | |
| import gradio as gr | |
| import jsonlines | |
| import pyarrow as pa | |
| import s3fs | |
| from datasets import Dataset | |
| from huggingface_hub import HfApi | |
# S3 client authenticated with credentials taken from the environment (anon=False).
S3 = s3fs.S3FileSystem(anon=False, key=os.getenv("AWS_ACCESS_KEY_ID"), secret=os.getenv("AWS_SECRET_ACCESS_KEY"))

# Root prefix under which each dataset's parquet sample shards live.
BASE_S3_DIR = "s3://geclm-datasets/samples/"

# Shown in place of a datapoint once a dataset's sample generator is exhausted.
LABELLING_COMPLETE_TEXT = (
    "Completed the labelling the sample for the {} dataset. Please consider labelling other datasets."
)

# Datasets offered for annotation in the UI dropdown; each must have a
# matching prefix under BASE_S3_DIR.
DATASETS = [
    "c4",
    "bigcode_python_code",
    "bigcode_python_github_issues",
    "bigcode_python_jupyter_markdowned_clean_dedup",
    "books3",
    "gutenberg_raw",
    "reddit_threaded",
    "enwiki_data",
    "s2orc_dedup",
    "stackexchange2",
    "commoncrawl",
]
def get_parquet_lines(dataset, sample_size=1000):
    """Load up to `sample_size` shuffled rows from a random parquet shard of `dataset`.

    Parameters
    ----------
    dataset : str
        Dataset name; shards are globbed under ``BASE_S3_DIR + dataset + "/*"``.
    sample_size : int
        Maximum number of rows returned (default 1000).

    Returns
    -------
    list[dict]
        Row dicts from one randomly chosen shard, shuffled and truncated.

    Raises
    ------
    FileNotFoundError
        If no shard exists for `dataset` on S3.
    """
    # `import pyarrow as pa` does not guarantee the `pyarrow.parquet` submodule
    # is loaded; import it explicitly so `ParquetFile` is always available.
    import pyarrow.parquet as pq

    pattern = BASE_S3_DIR + dataset + "/*"
    s3_paths = S3.glob(pattern)
    if not s3_paths:
        # Bug fix: the original referenced an undefined name `path` here,
        # which raised NameError instead of the intended FileNotFoundError.
        raise FileNotFoundError(f"Nothing found at {pattern}")
    print("Number of parquet files", len(s3_paths))
    s3_path = random.choice(s3_paths)
    print("Reading", s3_path)
    lines = []
    with S3.open(s3_path) as f:
        pf = pq.ParquetFile(f)
        for ix_row_group in range(pf.metadata.num_row_groups):
            # We load dataset by row group - 1000 rows at a time
            # using open_input_stream would return bytes per bytes not row per row
            table = pf.read_row_group(ix_row_group)
            lines.extend(table.to_pylist())
    random.shuffle(lines)
    return lines[:sample_size]
def get_local_lines(dataset):
    """Return every record from the local per-dataset stats JSONL file."""
    stats_path = "data/{}_examples_with_stats.json".format(dataset)
    with jsonlines.open(stats_path, "r") as reader:
        return list(reader)
def line_generator(lines_dict, dataset):
    """Yield the pre-fetched samples of `dataset` one at a time."""
    yield from lines_dict[dataset]
# local_lines = {dataset: get_local_lines(dataset) for dataset in DATASETS}
# line_generators_local = {dataset: line_generator(local_lines, dataset) for dataset in DATASETS}

# Parallelize the below ?
# Eagerly fetch a sample of rows for every dataset at startup (sequential S3
# reads), then wrap each sample in a one-shot generator that the UI callbacks
# consume; once a generator is exhausted the UI shows LABELLING_COMPLETE_TEXT.
s3_lines = {dataset: get_parquet_lines(dataset) for dataset in DATASETS}
line_generators_s3 = {dataset: line_generator(s3_lines, dataset) for dataset in DATASETS}
def send_report(sample, dataset, reason, annotator, campaign):
    """Write a one-record JSONL report for a flagged sample and upload it to the Hub.

    Parameters
    ----------
    sample : dict
        The flagged datapoint. Its text column ("text", or "content" for code
        datasets) and an id/title are split out; the remainder is stored as a
        JSON blob in the "metadata" field.
    dataset : str
        Name of the dataset the sample came from.
    reason : str
        Free-text flagging reason entered by the annotator.
    annotator : str
        Optional annotator name, recorded verbatim.
    campaign : str
        Optional campaign name, recorded verbatim.
    """
    # Bug fix: work on a copy — the original popped keys from the caller's
    # dict in place, mutating UI state that outlives this call.
    sample = dict(sample)
    text_col = "text"
    if text_col not in sample:
        text_col = "content"
    text = sample.pop(text_col)
    # Drop the timestamp field (not JSON-serializable / not useful in reports).
    sample.pop("record_timestamp", None)
    # Prefer an explicit id; fall back to the title when there is no id.
    sample_id = ""
    if "id" not in sample:
        if "title" in sample:
            sample_id = sample["title"]
    else:
        sample_id = sample["id"]

    with jsonlines.open("report.jsonl", "w") as f:
        f.write(
            {
                "dataset": dataset,
                "docid": sample_id,
                "text": text,
                "metadata": json.dumps(sample),
                "reason": reason,
                "annotator": annotator,
                "campaign": campaign,
                "timestamp": str(datetime.now()),
            }
        )

    # Each report gets a unique filename so uploads never overwrite each other.
    api = HfApi()
    api.upload_file(
        path_or_fileobj="report.jsonl",
        path_in_repo="report-{}.jsonl".format(uuid.uuid4()),
        repo_id="HuggingFaceGECLM/data_feedback",
        repo_type="dataset",
        token=os.environ.get("geclm_token"),
    )
def get_title_and_text_for_line(next_line):
    """Return ``(text, label)`` for a sample dict.

    The text comes from the "text" column, or "content" when "text" is absent.
    The label is derived, in order of preference, from the sample's title
    (optionally suffixed with its url), its metadata (first list element or
    the "document_url" of a JSON-encoded string), or a bare "url" field.
    """
    key = "text" if "text" in next_line else "content"
    text = next_line[key]

    label = ""
    if "title" in next_line:
        label = next_line["title"]
        if "url" in next_line:
            label += " | " + next_line["url"]
    elif "metadata" in next_line:
        meta = next_line["metadata"]
        if meta is not None:
            print(meta)
            if isinstance(meta, list) and len(meta) > 0:
                label = meta[0]
            elif isinstance(meta, str):
                parsed = json.loads(meta)
                if "document_url" in parsed:
                    label = parsed["document_url"]
    elif "url" in next_line:
        label = next_line["url"]
    return text, label
if __name__ == "__main__":
    demo = gr.Blocks()

    with demo:
        # Holds the raw sample dict currently on screen so the "Bad" handler
        # can include it in the report.
        current_sample_state = gr.State(dict())

        description = gr.Markdown(
            value="""GecLM annotations. All annotations are recorded in the [data_feedback](https://huggingface.co/datasets/HuggingFaceGECLM/data_feedback) dataset.
""",
        )

        with gr.Row():
            # Optional annotator identification, stored verbatim in each report.
            annotator = gr.Textbox(
                lines=1,
                max_lines=1,
                placeholder="Optionally provide your name here if you'd like it to be recorded.",
                label="Annotator",
            )
            campaign = gr.Textbox(
                lines=1,
                max_lines=1,
                placeholder="Optionally provide the name of the annotation campagin for ease of filtering the reports.",
                label="Annotation campaign",
            )
        with gr.Row():
            # Choosing a dataset fires `dataset.change` (wired below) which
            # loads the first sample and reveals the hidden widgets.
            dataset = gr.Dropdown(
                choices=DATASETS,
                value="Pick a dataset below",
                label="Dataset",
            )
        with gr.Row():
            # Hidden until a dataset is selected.
            reason_txt = gr.Textbox(
                label="Flagging reason",
                placeholder="Provide the reason for flagging if you think the sample is bad.",
                visible=False,
            )
        with gr.Row():
            bad_btn = gr.Button("Bad β", visible=False)
            good_btn = gr.Button("Next β ", visible=False)
        with gr.Row():
            # Large read-only viewer for the current datapoint.
            text = gr.Textbox(visible=False, label="Datapoint", lines=500, max_lines=500)
| def get_next_line(dataset): | |
| try: | |
| next_line = next(line_generators_s3[dataset]) | |
| text, label = get_title_and_text_for_line(next_line) | |
| except StopIteration: | |
| text = LABELLING_COMPLETE_TEXT.format(dataset) | |
| next_line = text | |
| return [ | |
| gr.update( | |
| value=text, | |
| visible=True, | |
| label=label, | |
| ), | |
| next_line, | |
| gr.update(visible=True), | |
| gr.update(visible=True), | |
| gr.update(visible=True), | |
| ] | |
| def report_bad_line_and_next(current_sample, dataset, reason, annotator, campaign): | |
| if current_sample != LABELLING_COMPLETE_TEXT.format(dataset): | |
| send_report(current_sample, dataset, reason, annotator, campaign) | |
| try: | |
| next_line = next(line_generators_s3[dataset]) | |
| text, label = get_title_and_text_for_line(next_line) | |
| except StopIteration: | |
| text = LABELLING_COMPLETE_TEXT.format(dataset) | |
| next_line = text | |
| return [ | |
| gr.update( | |
| value=text, | |
| visible=True, | |
| label=label, | |
| ), | |
| gr.update( | |
| value="", | |
| placeholder="Provide the reason for flagging if you think the sample is bad.", | |
| ), | |
| next_line, | |
| ] | |
        # "Next": advance without reporting.
        good_btn.click(
            get_next_line,
            inputs=dataset,
            outputs=[text, current_sample_state, reason_txt, good_btn, bad_btn],
        )
        # Picking a dataset loads its first sample and reveals the widgets.
        dataset.change(
            get_next_line,
            inputs=dataset,
            outputs=[text, current_sample_state, reason_txt, good_btn, bad_btn],
        )
        # "Bad": upload a report for the current sample, then advance.
        bad_btn.click(
            report_bad_line_and_next,
            inputs=[current_sample_state, dataset, reason_txt, annotator, campaign],
            outputs=[text, reason_txt, current_sample_state],
        )

    demo.launch(enable_queue=False, debug=True)