import os
import uuid
from typing import List, Dict, Optional

import pandas as pd

from autorag.deploy import GradioRunner
from autorag.deploy.api import RetrievedPassage
from autorag.nodes.generator.base import BaseGenerator
from autorag.utils import fetch_contents

empty_retrieved_passage = RetrievedPassage(
    content="", doc_id="", filepath=None, file_page=None, start_idx=None, end_idx=None
)


class GradioStreamRunner(GradioRunner):
    """GradioRunner variant that streams the generator output instead of returning one final answer."""

    def __init__(self, config: Dict, project_dir: Optional[str] = None):
        super().__init__(config, project_dir)

        # Load the corpus so retrieved doc ids can be resolved back to passages.
        data_dir = os.path.join(project_dir, "data")
        self.corpus_df = pd.read_parquet(
            os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
        )

    def stream_run(self, query: str):
        # Pseudo QA data so the pipeline modules can be executed on a single live query.
        previous_result = pd.DataFrame(
            {
                "qid": str(uuid.uuid4()),
                "query": [query],
                "retrieval_gt": [[]],
                "generation_gt": [""],
            }
        )
        for module_instance, module_param in zip(
            self.module_instances, self.module_params
        ):
            if not isinstance(module_instance, BaseGenerator):
                # Run every non-generator node (retrieval, reranking, prompt making, ...)
                # and merge its output columns into the running result.
                new_result = module_instance.pure(
                    previous_result=previous_result, **module_param
                )
                duplicated_columns = previous_result.columns.intersection(
                    new_result.columns
                )
                drop_previous_result = previous_result.drop(columns=duplicated_columns)
                previous_result = pd.concat([drop_previous_result, new_result], axis=1)
            else:
                # retrieved_passages = self.extract_retrieve_passage(previous_result)
                # yield "", retrieved_passages

                # Start streaming the generator result; the retrieved passages are
                # replaced with an empty placeholder for now.
                assert len(previous_result) == 1
                prompt: str = previous_result["prompts"].tolist()[0]
                for delta in module_instance.stream(prompt=prompt, **module_param):
                    yield delta, [empty_retrieved_passage]

    def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
        retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
        contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]
        if "path" in self.corpus_df.columns:
            paths = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="path"
            )[0]
        else:
            paths = [None] * len(retrieved_ids)
        metadatas = fetch_contents(
            self.corpus_df, [retrieved_ids], column_name="metadata"
        )[0]
        if "start_end_idx" in self.corpus_df.columns:
            start_end_indices = fetch_contents(
                self.corpus_df, [retrieved_ids], column_name="start_end_idx"
            )[0]
        else:
            start_end_indices = [None] * len(retrieved_ids)
        return list(
            map(
                lambda content, doc_id, path, metadata, start_end_idx: RetrievedPassage(
                    content=content,
                    doc_id=doc_id,
                    filepath=path,
                    file_page=metadata.get("page", None),
                    start_idx=start_end_idx[0] if start_end_idx else None,
                    end_idx=start_end_idx[1] if start_end_idx else None,
                ),
                contents,
                retrieved_ids,
                paths,
                metadatas,
                start_end_indices,
            )
        )
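

# --- Usage sketch (illustrative only, not part of the runner itself) ---
# Assumptions: an AutoRAG deployment config at ./config.yaml and a project
# directory at ./project containing data/corpus.parquet; both paths are
# placeholders, adjust them for your own Space or deployment.
if __name__ == "__main__":
    import yaml

    with open("config.yaml", "r") as f:
        config = yaml.safe_load(f)

    runner = GradioStreamRunner(config, project_dir="./project")

    # stream_run is a generator of (text_chunk, retrieved_passages) pairs, so the
    # answer can be printed (or pushed to a Gradio UI) as it is produced.
    for chunk, _passages in runner.stream_run("What is AutoRAG?"):
        print(chunk, end="", flush=True)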