# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import uuid
from abc import ABC, abstractmethod

import requests
from google.auth.transport.requests import Request
from google.oauth2 import service_account

from case_util import get_json_from_model_response
from models import ClinicalMCQ
from prompts import mcq_prompt_all_questions_with_rag

logger = logging.getLogger(__name__)


class LLMClient(ABC):
    _api_key = None
    _endpoint_url = None

    def generate_all_questions(self, case_data: dict, guideline_context: str) -> list[ClinicalMCQ] | None:
        """
        Orchestrates the prompt creation and live LLM call to generate the list of all MCQs.
        Receives pre-fetched RAG context as a string.
        """
        # 1. Create the prompt messages payload
        messages = self._create_prompt_messages_for_all_questions(
            image_url=case_data.get('download_image_url'),
            ground_truth_labels=case_data.get('ground_truth_labels', {}),
            guideline_context=guideline_context  # Pass the pre-fetched context
        )
        try:
            # 2. Make the API call
            response_dict = self._make_chat_completion_request(
                model="tgi",  # Or your configured model
                messages=messages,
                temperature=0,
                max_tokens=8192
            )
            # 3. Safely access the list of questions from the parsed dictionary
            list_of_question_dicts = response_dict.get("questions", [])
            if not list_of_question_dicts:
                raise ValueError("LLM response did not contain a 'questions' key or the list was empty.")
            # 4. Loop through the extracted list and create ClinicalMCQ objects
            list_clinical_mcq = []
            for question_dict in list_of_question_dicts:
                if "question" not in question_dict:
                    logger.warning("Skipping malformed question object in response.")
                    continue
                mcq_uuid = str(uuid.uuid4())
                clinical_mcq = ClinicalMCQ(
                    id=mcq_uuid,
                    question=question_dict.get('question', ''),
                    choices=question_dict.get('choices', {}),
                    hint=question_dict.get('hint', ''),
                    answer=question_dict.get('answer', ''),
                    rationale=question_dict.get('rationale', '')
                )
                list_clinical_mcq.append(clinical_mcq)
            return list_clinical_mcq
        except Exception as e:
            logger.error(f"Failed to generate and parse learning module: {e}")
            return None

    @abstractmethod
    def _make_chat_completion_request(
        self,
        model: str,
        messages: list,
        temperature: float,
        max_tokens: int,
        top_p: float | None = None,
        seed: int | None = None,
        stop: list[str] | str | None = None,
        frequency_penalty: float | None = None,
        presence_penalty: float | None = None
    ) -> dict | None:
        """Makes the provider-specific chat-completion call and returns the parsed response dict."""

    def _create_prompt_messages_for_all_questions(self, image_url: str, ground_truth_labels: dict, guideline_context: str):
        """
        Creates the list of messages for the LLM prompt.
        Dynamically selects the prompt and constructs the payload based on whether RAG context is present.
        """
        # The system message sets the stage and provides all instructions/examples.
        system_message = {
            "role": "system",
            "content": [
                {"type": "text", "text": mcq_prompt_all_questions_with_rag},
            ]
        }
        user_content_text = (
            f"<significant_clinical_conditions>\n{json.dumps(ground_truth_labels, indent=2)}\n</significant_clinical_conditions>\n\n"
            f"<guideline_context>\n{guideline_context}\n</guideline_context>"
        )
        # The user message provides the specific data for THIS request and the image.
        user_message = {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": user_content_text}
            ]
        }
        messages = [system_message, user_message]
| logger.info("Messages being sent:-\n{}".format(json.dumps(messages, indent=2))) | |
        return messages
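
# The messages built above follow the OpenAI-style chat-completions shape, e.g.
# (illustrative values, not output captured from this module):
# [
#   {"role": "system", "content": [{"type": "text", "text": "<MCQ prompt>"}]},
#   {"role": "user", "content": [
#       {"type": "image_url", "image_url": {"url": "https://.../case.png"}},
#       {"type": "text", "text": "<conditions and guideline context>"}]}
# ]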


class HuggingFaceLLMClient(LLMClient):
    def __init__(self, api_key, endpoint_url):
        if not api_key:
            raise ValueError("No API key provided.")
        if not endpoint_url:
            raise ValueError("No endpoint URL provided.")
        self._api_key = api_key
        self._endpoint_url = endpoint_url

    def _make_chat_completion_request(
        self,
        model: str,
        messages: list,
        temperature: float,
        max_tokens: int,
        top_p: float | None = None,
        seed: int | None = None,
        stop: list[str] | str | None = None,
        frequency_penalty: float | None = None,
        presence_penalty: float | None = None
    ) -> dict | None:
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": True,
        }
        if top_p is not None:
            payload["top_p"] = top_p
        if seed is not None:
            payload["seed"] = seed
        if stop is not None:
            payload["stop"] = stop
        if frequency_penalty is not None:
            payload["frequency_penalty"] = frequency_penalty
        if presence_penalty is not None:
            payload["presence_penalty"] = presence_penalty
        temp_url = self._endpoint_url.rstrip('/')
        if temp_url.endswith("/v1/chat/completions"):
            full_url = temp_url
        elif temp_url.endswith("/v1"):
            full_url = temp_url + "/chat/completions"
        else:
            full_url = temp_url + "/v1/chat/completions"
        # stream=True is required so requests yields the SSE lines as they arrive.
        response = requests.post(full_url, headers=headers, json=payload, timeout=60, stream=True)
        logger.info(f"LLM call status code: {response.status_code}, status reason: {response.reason}")
        response.raise_for_status()
        explanation_parts = []
        for line in response.iter_lines():
            if not line:
                continue
            decoded_line = line.decode('utf-8')
            if decoded_line.startswith('data: '):
                json_data_str = decoded_line[len('data: '):].strip()
                if json_data_str == "[DONE]":
                    break
                try:
                    chunk = json.loads(json_data_str)
                    delta = (chunk.get("choices") or [{}])[0].get("delta") or {}
                    if delta.get("content"):
                        explanation_parts.append(delta["content"])
                except json.JSONDecodeError:
                    logger.warning(f"Could not decode JSON from stream chunk: {json_data_str}")
                    # Depending on the API, partial JSON or other errors may need handling here.
            elif decoded_line.strip() == "[DONE]":  # Some APIs send [DONE] without the "data: " prefix
                break
        explanation = "".join(explanation_parts).strip()
        if not explanation:
            logger.warning("Empty explanation from API")
        return get_json_from_model_response(explanation)


class VertexAILLMClient(LLMClient):
    def __init__(self, api_key, endpoint_url):
        # For this client, "api_key" holds the service-account JSON key string.
        if not api_key:
            raise ValueError("No API key provided.")
        if not endpoint_url:
            raise ValueError("No endpoint URL provided.")
        self._api_key = api_key
        self._endpoint_url = endpoint_url

    def _make_chat_completion_request(
        self,
        model: str,
        messages: list,
        temperature: float,
        max_tokens: int,
        top_p: float | None = None,
        seed: int | None = None,
        stop: list[str] | str | None = None,
        frequency_penalty: float | None = None,
        presence_penalty: float | None = None
    ) -> dict | None:
        # 1. Get credentials directly from the secret
        creds = self._get_credentials_from_secret()
        logger.info("Successfully loaded credentials from secret.")
        # 2. Get a valid access token
        token = self._get_access_token(creds)
        logger.info("Successfully obtained access token.")
        # 3. Use the token to make an authenticated call to the Vertex AI endpoint
        headers = {
            'Authorization': f'Bearer {token}',
            'Content-Type': 'application/json'
        }
        # Note: the optional sampling parameters (top_p, seed, stop,
        # frequency_penalty, presence_penalty) are not forwarded by this client.
        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        response = requests.post(self._endpoint_url, headers=headers, json=payload, timeout=60)
        logger.info(f"LLM call status code: {response.status_code}, status reason: {response.reason}")
        response_dict = response.json()
        final_response = response_dict["choices"][0]["message"]["content"]
        return get_json_from_model_response(final_response)

    def _get_credentials_from_secret(self):
        """Loads Google Cloud credentials from the configured secret."""
        if not self._api_key:
            raise ValueError(
                "No service-account key provided. Set the 'GCLOUD_SA_KEY' secret in your Hugging Face Space.")
        logger.info("Loading Google Cloud credentials...")
        # Parse the JSON string into a dictionary
        credentials_info = json.loads(self._api_key)
        logger.info("Google Cloud credentials loaded.")
        # Define the required scopes for the API you want to access
        scopes = ['https://www.googleapis.com/auth/cloud-platform']
        # Create credentials from the dictionary
        credentials = service_account.Credentials.from_service_account_info(
            credentials_info,
            scopes=scopes
        )
        return credentials
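
    # The secret is expected to hold a standard service-account JSON key, e.g.
    # (illustrative and heavily truncated):
    #   {"type": "service_account", "project_id": "...",
    #    "private_key": "...", "client_email": "..."}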
    def _get_access_token(self, credentials):
        """Refreshes the credentials to get a valid access token."""
        # Refresh the token to ensure it's not expired
        credentials.refresh(Request())
        return credentials.token
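

if __name__ == "__main__":
    # Minimal usage sketch. The secret names and the case payload below are
    # illustrative assumptions, not part of this module.
    import os

    client = HuggingFaceLLMClient(
        os.environ["HF_API_KEY"],       # hypothetical secret name
        os.environ["HF_ENDPOINT_URL"],  # hypothetical secret name
    )
    case_data = {
        "download_image_url": "https://example.com/case.png",
        "ground_truth_labels": {"finding": "example finding"},
    }
    mcqs = client.generate_all_questions(case_data, guideline_context="")
    if mcqs:
        for mcq in mcqs:
            print(mcq.question)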