David Tang committed on
Commit e0baf14 · 1 Parent(s): 0d34271

add modal backend

Files changed (1)
  1. modal-medgemma.py +269 -0
modal-medgemma.py ADDED
@@ -0,0 +1,269 @@
+ """
+ modal-medgemma.py
+ -----------------
+
+ This module defines a Modal app and FastAPI endpoint for running the MedGemma agent, a multimodal LLM that can process text and images. It provides a streaming API for inference, including Wikipedia tool-calling capabilities, and handles model download, loading, and inference with GPU support.
+
+ N.B. the app must be deployed with the following command:
+ `modal deploy modal-medgemma.py`
+
+ Key components:
+ - Modal app and volume setup for model weights
+ - MedGemmaAgent class for model loading and inference
+ - FastAPI endpoint for streaming responses
+ - Utility for processing base64-encoded images
+ """
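+
+ # A hypothetical example client call (the exact URL is printed by `modal deploy`
+ # and depends on your workspace; the hostname below is a placeholder):
+ #
+ #   curl -N -X POST "https://<workspace>--example-medgemma-agent-run-medgemma.modal.run" \
+ #     -H "Content-Type: application/json" \
+ #     -H "X-API-Key: $FASTAPI_KEY" \
+ #     -d '{"prompt": "What are the symptoms of hypertension?"}'
+ #
+ # The reply is a text/event-stream of `data: {...}` chunks (see StreamResponse below).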
+ import modal
+ from typing import Optional, Generator, Dict, Any, List
+ import os
+ from fastapi import Security, HTTPException, Depends
+ from fastapi.security.api_key import APIKeyHeader
+ from fastapi.responses import StreamingResponse
+ from pydantic import BaseModel
+ import json
+ import base64
+ from PIL import Image
+ import io
+
+ app = modal.App("example-medgemma-agent")
+ volume = modal.Volume.from_name("model-weights-vol", create_if_missing=True)
+ MODEL_DIR = "/models"
+ # MODEL_ID = "google/medgemma-4b-it"
+ MODEL_ID = "unsloth/medgemma-4b-it-unsloth-bnb-4bit"
+ MINUTES = 60
+
+ API_KEY_NAME = "X-API-Key"
+ api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=True)
+
+ async def get_api_key(api_key_header: str = Security(api_key_header)):
+     """
+     Validates the provided API key against the environment variable.
+
+     Args:
+         api_key_header (str): The API key provided in the request header.
+
+     Raises:
+         HTTPException: If the API key is invalid.
+
+     Returns:
+         str: The validated API key.
+     """
+     if api_key_header != os.environ["FASTAPI_KEY"]:
+         raise HTTPException(
+             status_code=403, detail="Invalid API Key"
+         )
+     return api_key_header
+
+ image = (
+     modal.Image.debian_slim()
+     .pip_install(
+         "smolagents[vllm]",
+         "fastapi[standard]",
+         "wikipedia-api",
+         "accelerate",
+         "bitsandbytes",
+         "huggingface-hub[hf_transfer]",
+         "Pillow",
+     )
+     .env({
+         "HF_HUB_ENABLE_HF_TRANSFER": "1",
+         "HF_HUB_CACHE": MODEL_DIR,
+     })
+ )
+
+ with image.imports():
+     from smolagents import VLLMModel, ToolCallingAgent, tool
+     import wikipediaapi
+
+ @app.function(
+     image=image,
+     secrets=[modal.Secret.from_name("access_medgemma_hf")],
+     volumes={MODEL_DIR: volume}
+ )
+ def download_model():
+     """
+     Downloads the model weights from Hugging Face Hub using the provided token.
+
+     Returns:
+         dict: Status message indicating success.
+     """
+     from huggingface_hub import snapshot_download
+     snapshot_download(
+         repo_id=MODEL_ID,
+         token=os.environ["HF_TOKEN"]
+     )
+     return {"status": "Model downloaded successfully"}
+
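+ # Optional: warm the weights cache before deploying (standard Modal CLI usage,
+ # not something this file requires):
+ #   modal run modal-medgemma.py::download_model
+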
+ @app.cls(
+     image=image,
+     gpu="L4:1",
+     volumes={MODEL_DIR: volume},
+     min_containers=1,
+     max_containers=1,
+     timeout=15 * MINUTES,
+     secrets=[modal.Secret.from_name("access_medgemma_hf")],
+ )
+ class MedGemmaAgent:
+     """
+     Modal class for managing the MedGemma model and running inference with optional tool-calling.
+     """
+     @modal.enter()
+     def load_models(self):
+         """
+         Loads the MedGemma model into memory and prepares it for inference.
+         Downloads the model weights if not already present.
+         """
+         download_model.remote()
+         model_kwargs = {
+             "max_model_len": 8192,
+             "dtype": "bfloat16",
+             "gpu_memory_utilization": 0.95,
+             "tensor_parallel_size": 1,
+             "trust_remote_code": True
+         }
+         self.model = VLLMModel(
+             model_id=MODEL_ID,
+             model_kwargs=model_kwargs
+         )
+         print(f"Model: {MODEL_ID} loaded successfully")
+
+     @modal.method()
+     def run(self, prompt: str, images: Optional[List[Image.Image]] = None) -> Generator[Dict[str, Any], None, None]:
+         """
+         Runs the MedGemma agent on the provided prompt and optional images, yielding streaming responses.
+
+         Args:
+             prompt (str): The user prompt to process.
+             images (Optional[List[Image.Image]]): List of PIL Images to provide as context (optional).
+
+         Yields:
+             Dict[str, Any]: Streaming response chunks, including 'thinking' and 'final' messages.
+         """
+         @tool
+         def wiki(query: str) -> str:
+             """
+             Fetches a summary of a Wikipedia page based on a given search query (a single word or short phrase).
+
+             Args:
+                 query: The search term for the Wikipedia page (a single word or short phrase).
+             """
+             # Use a distinct name so the client does not shadow the tool function itself.
+             wiki_client = wikipediaapi.Wikipedia(language="en", user_agent="MinimalAgent/1.0")
+             page = wiki_client.page(query)
+             if not page.exists():
+                 return "No Wikipedia page found."
+             return page.summary[:1000]
+
+         self.agent = ToolCallingAgent(
+             tools=[wiki],
+             model=self.model,
+             max_steps=3
+         )
+
+         # Yield thinking step
+         yield {
+             "type": "thinking",
+             "content": {"message": "Starting to process your request..."}
+         }
+
+         # Run the agent and capture the result
+         result = self.agent.run(
+             task=prompt,
+             stream=False,
+             reset=True,
+             images=images if images else None,
+             additional_args={"flatten_messages_as_text": False},
+             max_steps=3
+         )
+
+         # Yield the final response
+         yield {
+             "type": "final",
+             "content": {"response": result}
+         }
+
+ class PromptRequest(BaseModel):
+     """
+     Request model for the /run_medgemma endpoint.
+
+     Attributes:
+         prompt (str): The user prompt to process.
+         image (Optional[str]): Base64-encoded image string (optional).
+         history (Optional[list]): Conversation history (optional).
+     """
+     prompt: str
+     image: Optional[str] = None  # Base64-encoded image
+     history: Optional[list] = None
+
+ class GenerationResponse(BaseModel):
+     """
+     Response model for non-streaming generation (not used in this file).
+
+     Attributes:
+         response (str): The generated response.
+     """
+     response: str
+
+ class StreamResponse(BaseModel):
+     """
+     Response model for streaming responses.
+
+     Attributes:
+         type (str): The type of message ('thinking', 'tool_call', 'tool_result', 'final').
+         content (Dict[str, Any]): The content of the message.
+     """
+     type: str  # 'thinking', 'tool_call', 'tool_result', 'final'
+     content: Dict[str, Any]
+
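+ # Each streamed chunk is one SSE event whose `data:` payload matches StreamResponse;
+ # illustrative examples of what MedGemmaAgent.run actually yields:
+ #   data: {"type": "thinking", "content": {"message": "Starting to process your request..."}}
+ #   data: {"type": "final", "content": {"response": "..."}}
+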
+ def process_image(image_base64: Optional[str]) -> Optional[Image.Image]:
+     """
+     Decodes a base64-encoded image string into a PIL Image.
+
+     Args:
+         image_base64 (Optional[str]): Base64-encoded image string.
+
+     Returns:
+         Optional[Image.Image]: The decoded PIL Image, or None if decoding fails or input is None.
+     """
+     if not image_base64:
+         return None
+     try:
+         # Tolerate data-URL payloads (e.g. "data:image/png;base64,....") by
+         # keeping only the part after the comma.
+         if image_base64.startswith("data:") and "," in image_base64:
+             image_base64 = image_base64.split(",", 1)[1]
+         image_data = base64.b64decode(image_base64)
+         return Image.open(io.BytesIO(image_data))
+     except Exception as e:
+         print(f"Error processing image: {e}")
+         return None
+
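+ # Client-side sketch (hypothetical file name) for producing the `image` field:
+ #   with open("xray.png", "rb") as f:
+ #       image_b64 = base64.b64encode(f.read()).decode("utf-8")
+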
+ @app.function(
+     image=image,
+     secrets=[
+         modal.Secret.from_name("access_medgemma_hf"),
+         modal.Secret.from_name("FASTAPI_KEY")
+     ]
+ )
+ @modal.fastapi_endpoint(method="POST")
+ async def run_medgemma(request: PromptRequest, api_key: str = Depends(get_api_key)):
+     """
+     FastAPI endpoint for running the MedGemma agent with streaming responses.
+
+     Args:
+         request (PromptRequest): The request payload containing prompt and optional image.
+         api_key (str): The validated API key (injected by Depends).
+
+     Returns:
+         StreamingResponse: An event-stream response yielding model output chunks.
+     """
+     model_handler = MedGemmaAgent()
+
+     # Process the image if provided; a distinct local name avoids shadowing the
+     # module-level Modal `image`. Note: request.history is accepted but not used yet.
+     pil_image = process_image(request.image)
+     images = [pil_image] if pil_image else None
+
+     async def generate():
+         stream = model_handler.run.remote_gen.aio(request.prompt, images=images)
+         async for chunk in stream:
+             yield f"data: {json.dumps(chunk)}\n\n"
+
+     return StreamingResponse(
+         generate(),
+         media_type="text/event-stream"
+     )
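+
+ # For local iteration, `modal serve modal-medgemma.py` serves the endpoint at a
+ # temporary dev URL with hot reload; the request format matches the example at the
+ # top of this file.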