````python
from typing import Optional

import weave

from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage
from medrag_multi_modal.assistant.llm_client import LLMClient
from medrag_multi_modal.assistant.schema import (
    MedQACitation,
    MedQAMCQResponse,
    MedQAResponse,
)
from medrag_multi_modal.retrieval.common import SimilarityMetric
from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever


class MedQAAssistant(weave.Model):
    """
    `MedQAAssistant` is a class designed to assist with medical queries by leveraging a
    language model client, a retriever model, and a figure annotator.

    !!! example "Usage Example"
        ```python
        import weave
        from dotenv import load_dotenv

        from medrag_multi_modal.assistant import (
            FigureAnnotatorFromPageImage,
            LLMClient,
            MedQAAssistant,
        )
        from medrag_multi_modal.retrieval import MedCPTRetriever

        load_dotenv()
        weave.init(project_name="ml-colabs/medrag-multi-modal")

        llm_client = LLMClient(model_name="gemini-1.5-flash")

        retriever = MedCPTRetriever.from_wandb_artifact(
            chunk_dataset_name="grays-anatomy-chunks:v0",
            index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
        )

        figure_annotator = FigureAnnotatorFromPageImage(
            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
        )

        medqa_assistant = MedQAAssistant(
            llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
        )
        medqa_assistant.predict(query="What is ribosome?")
        ```

    Args:
        llm_client (LLMClient): The language model client used to generate responses.
        retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
        figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
        top_k_chunks_for_query (int): The number of top chunks to retrieve, based on the similarity metric, for the query.
        top_k_chunks_for_options (int): The number of top chunks to retrieve, based on the similarity metric, for each option.
        rely_only_on_context (bool): Whether the model must answer strictly from the retrieved context.
        retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
    """

    llm_client: LLMClient
    retriever: weave.Model
    figure_annotator: Optional[FigureAnnotatorFromPageImage] = None
    top_k_chunks_for_query: int = 2
    top_k_chunks_for_options: int = 2
    rely_only_on_context: bool = True
    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE

    def retrieve_chunks_for_query(self, query: str) -> list[dict]:
        retriever_kwargs = {"top_k": self.top_k_chunks_for_query}
        # BM25 is a lexical retriever, so no similarity metric is passed to it.
        if not isinstance(self.retriever, BM25sRetriever):
            retriever_kwargs["metric"] = self.retrieval_similarity_metric
        return self.retriever.predict(query, **retriever_kwargs)

    def retrieve_chunks_for_options(self, options: list[str]) -> list[dict]:
        retriever_kwargs = {"top_k": self.top_k_chunks_for_options}
        if not isinstance(self.retriever, BM25sRetriever):
            retriever_kwargs["metric"] = self.retrieval_similarity_metric
        retrieved_chunks = []
        for option in options:
            retrieved_chunks += self.retriever.predict(query=option, **retriever_kwargs)
        return retrieved_chunks

    def predict(self, query: str, options: Optional[list[str]] = None) -> MedQAResponse:
        """
        Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
        from a medical document and using a language model to generate the final response.

        This function performs the following steps:

        1. Retrieves relevant text chunks from the medical document based on the query and any provided options
           using the retriever model.
        2. Extracts the text and page indices from the retrieved chunks.
        3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
        4. Constructs a system prompt and user prompt combining the query, options (if provided), retrieved text chunks,
           and figure descriptions.
        5. Uses the language model client to generate a response based on the constructed prompts, either choosing
           from provided options or generating a free-form response.
        6. Returns the generated response, which includes the answer and explanation if options were provided.

        The function can operate in two modes:

        - Multiple choice: When options are provided, it selects the best answer from the options and explains the choice.
        - Free response: When no options are provided, it generates a comprehensive response based on the context.

        In both modes, the class attribute `rely_only_on_context` controls whether the model is instructed to
        answer strictly from the retrieved context or may draw on external knowledge.

        Args:
            query (str): The medical query to be answered.
            options (Optional[list[str]]): The list of options to choose from.

        Returns:
            MedQAResponse: The generated response to the query, including source information.
        """
        # Retrieve chunks for the query and, when options are given, for each option as well.
        retrieved_chunks = self.retrieve_chunks_for_query(query)

        options = options or []
        retrieved_chunks += self.retrieve_chunks_for_options(options)

        # Collect the chunk texts and the page indices they were drawn from.
        retrieved_chunk_texts = []
        page_indices = set()
        for chunk in retrieved_chunks:
            retrieved_chunk_texts.append(chunk["text"])
            page_indices.add(int(chunk["page_idx"]))

        # Ask the figure annotator (if configured) for figure descriptions on those pages.
        figure_descriptions = []
        if self.figure_annotator is not None:
            for page_idx in page_indices:
                figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
                    page_idx
                ]
                figure_descriptions += [
                    item["figure_description"] for item in figure_annotations
                ]

        # Build the system and user prompts, adapting them to multiple-choice or free-response mode.
        system_prompt = """You are an expert in medical science. You are given a question
        and a list of excerpts from various medical documents.
        """
        query = f"""# Question
        {query}
        """

        if len(options) > 0:
            system_prompt += """\nYou are also given a list of options to choose your answer from.
            You are supposed to choose the best possible option based on the context provided. You should also
            explain your answer to justify why you chose that option.
            """
            query += "## Options\n"
            for option in options:
                query += f"- {option}\n"
        else:
            system_prompt += "\nYou are supposed to answer the question based on the context provided."

        if self.rely_only_on_context:
            system_prompt += """\n\nYou are only allowed to use the context provided to answer the question.
            You are not allowed to use any external knowledge to answer the question.
            """

        response = self.llm_client.predict(
            system_prompt=system_prompt,
            user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
            schema=MedQAMCQResponse if len(options) > 0 else None,
        )

        # TODO: Add figure citations
        # TODO: Add source document name from retrieved chunks as citations
        # Cite the (1-indexed) pages the retrieved chunks came from.
        citations = []
        for page_idx in page_indices:
            citations.append(
                MedQACitation(page_number=page_idx + 1, document_name="Gray's Anatomy")
            )

        return MedQAResponse(response=response, citations=citations)
````
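For the multiple-choice mode, `predict` also takes an `options` list and returns a `MedQAResponse` whose `response` field follows the `MedQAMCQResponse` schema. Below is a minimal sketch, assuming the `medqa_assistant` instance from the usage example in the class docstring; the question and answer options are illustrative and not part of the library:

```python
# Minimal sketch of multiple-choice usage; the question and options are illustrative.
# Assumes `medqa_assistant` has been constructed as in the docstring example above.
response = medqa_assistant.predict(
    query="Which bone articulates with the femur at the knee joint?",
    options=[
        "Tibia",
        "Fibula",
        "Humerus",
        "Radius",
    ],
)
print(response.response)   # MedQAMCQResponse: the chosen option with an explanation
print(response.citations)  # MedQACitation entries: page numbers in Gray's Anatomy
```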