Spaces:
Runtime error
Runtime error
| import json | |
| import traceback | |
| from queue import Queue | |
| from threading import Thread | |
| from typing import List | |
| import argilla as rg | |
| import gradio as gr | |
| from gradio_client import Client | |
# Argilla client shared by the whole script; it is created with default
# arguments, so connection details are presumably taken from the
# environment -- confirm against the deployment configuration.
client = rg.Argilla()

# Hand-off queue: the webhook handler puts records here and the background
# worker thread takes them out.
completed_record_events: Queue = Queue()
def build_dataset(client: rg.Argilla) -> rg.Dataset:
    """Return the ``stanfordnlp_imdb`` dataset, importing it from the Hub if needed.

    If ``client.datasets("stanfordnlp_imdb")`` yields a dataset, that one is
    returned as-is. Otherwise the first 1000 training rows of
    ``stanfordnlp/imdb`` are imported with the Hub-derived settings plus an
    extra ``sentiment`` label question.

    Args:
        client: Argilla client used both for the lookup and for the import.

    Returns:
        The existing or newly imported dataset.
    """
    dataset_name = "stanfordnlp_imdb"

    # Same truthiness test the ``existing or from_hub(...)`` form relied on.
    existing_dataset = client.datasets(dataset_name)
    if existing_dataset:
        return existing_dataset

    # Only fetch and extend the Hub settings when a dataset really has to be
    # created; they would be thrown away otherwise.
    settings = rg.Settings.from_hub("stanfordnlp/imdb")
    settings.questions.add(
        rg.LabelQuestion(name="sentiment", labels=["negative", "positive"])
    )

    return rg.Dataset.from_hub(
        "stanfordnlp/imdb",
        name=dataset_name,
        settings=settings,
        client=client,
        split="train[:1000]",
    )
# Static Gradio page that only describes the demo; the actual work happens
# in the webhook handler and the background thread.
with gr.Blocks() as demo:
    argilla_server = client.http_client.base_url

    page_sections = (
        "## Argilla Events",
        f"This demo shows the incoming events from the [Argilla Server]({argilla_server}).",
        "### Record Events",
        "#### Records are processed in background and suggestions are added.",
    )
    # One Markdown component per section, rendered in the order listed.
    for section_markdown in page_sections:
        gr.Markdown(section_markdown)
# Launch without blocking this thread (the script blocks explicitly at the
# very end). Only the first element of the returned triple is used.
server, _, _ = demo.launch(
    prevent_thread_lock=True,
    app_kwargs={"docs_url": "/docs"},
)

# Set up the webhook listeners
rg.set_webhook_server(server)

# Switch off every webhook the client currently reports before any new
# listener is set up.
for existing_webhook in client.webhooks:
    existing_webhook.enabled = False
    existing_webhook.update()
# Create a webhook for record events
# NOTE(review): the handler was previously defined but never registered, so
# nothing ever reached the queue. "record.completed" is chosen to match the
# queue name and the worker's "completed" filter -- confirm the event name.
@rg.webhook_listener(events=["record.completed"])
async def record_events(record: rg.Record, type: str, **kwargs):
    """Queue the record carried by an incoming webhook event.

    Args:
        record: Record attached to the event.
        type: Event type string; only printed. The parameter name is kept
            as-is because the caller presumably passes it by keyword.
        **kwargs: Any further event data; ignored.
    """
    print("Received event", type)

    # The queue is unbounded, so this put() returns immediately.
    completed_record_events.put(record)
# Dataset whose pending records the background worker labels.
dataset = build_dataset(client)


def add_record_suggestions_on_response_created():
    """Consume queued record events and add suggestions to pending records.

    Runs forever. For each record taken from ``completed_record_events``:
    events for a different dataset are skipped; otherwise up to five pending
    records are fetched, passed through ``parse_pending_records`` together
    with up to five completed records as examples, and logged back to
    ``dataset``. Any exception is printed and the loop continues with the
    next event.
    """
    print("Starting thread")

    completed_records_filter = rg.Filter(("status", "==", "completed"))
    pending_records_filter = rg.Filter(("status", "==", "pending"))

    while True:
        try:
            # Blocks until the webhook handler enqueues a record.
            record: rg.Record = completed_record_events.get()
            if dataset.id != record.dataset.id:
                continue

            # Check for work first: when nothing is pending there is no
            # point in querying examples or resolving settings.
            some_pending_records = list(
                dataset.records(
                    query=rg.Query(filter=pending_records_filter),
                    limit=5,
                )
            )
            if not some_pending_records:
                continue

            # Prepare predict data
            field = dataset.settings.fields["text"]
            question = dataset.settings.questions["sentiment"]
            examples = list(
                dataset.records(
                    query=rg.Query(filter=completed_records_filter),
                    limit=5,
                )
            )

            some_pending_records = parse_pending_records(
                some_pending_records, field, question, examples
            )
            dataset.records.log(some_pending_records)
        except Exception:
            # Deliberate catch-all at the thread boundary: one failing event
            # must not stop the worker.
            print(traceback.format_exc())
def parse_pending_records(
    records: List[rg.Record],
    field: rg.Field,
    question,
    example_records: List[rg.Record],
) -> List[rg.Record]:
    """Ask the remote labeller for suggestions and attach them to ``records``.

    The records, the field, the question and the example records are
    serialized and sent to the ``/predict`` endpoint of the
    ``davidberenstein1957/distilabel-argilla-labeller`` Gradio app. Each
    entry of the response's ``"results"`` list is paired positionally with a
    record and added to it as a suggestion for ``question``.

    This is best-effort: any exception is printed and the records are
    returned in whatever state they were in at that point.

    Args:
        records: Records to annotate; mutated in place.
        field: Field definition sent along with the request.
        question: Question the suggestions answer; needs ``serialize()`` and
            ``name``.
        example_records: Records sent as examples.

    Returns:
        The same ``records`` list.
    """
    try:
        labeller = Client("davidberenstein1957/distilabel-argilla-labeller")

        # Serialize in a fixed order: records, field, question, examples.
        serialized_records = [pending.to_dict() for pending in records]
        serialized_fields = [field.serialize()]
        serialized_question = question.serialize()
        serialized_examples = [example.to_dict() for example in example_records]

        raw_response = labeller.predict(
            records=serialized_records,
            fields=serialized_fields,
            question=serialized_question,
            example_records=serialized_examples,
            api_name="/predict",
        )

        # The endpoint may answer with a JSON string or an already-decoded
        # object; normalise to the decoded form.
        if isinstance(raw_response, str):
            raw_response = json.loads(raw_response)

        # zip() stops at the shorter side, so surplus records or results are
        # ignored.
        for pending, suggestion in zip(records, raw_response["results"]):
            new_suggestion = rg.Suggestion(
                question_name=question.name,
                value=suggestion["value"],
                score=suggestion["score"],
                agent=suggestion["agent"],
            )
            pending.suggestions.add(new_suggestion)
    except Exception:
        print(traceback.format_exc())

    return records
# The worker loops forever and has no stop signal. Running it as a daemon
# lets the interpreter exit once the main thread is done instead of hanging
# on the worker's blocking queue read.
thread = Thread(target=add_record_suggestions_on_response_created, daemon=True)
thread.start()

# The app was launched with prevent_thread_lock=True, so block here to keep
# the process (and the daemon worker) alive while the server runs.
demo.block_thread()