Spaces:
Sleeping
Sleeping
Security added for query reqs via LLM
Browse files- app.py +13 -6
- schemas.py +3 -0
app.py
CHANGED
|
@@ -294,6 +294,8 @@ def download_tdocs(req: DownloadRequest):
|
|
| 294 |
async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
|
| 295 |
documents = req.documents
|
| 296 |
n_docs = len(documents)
|
|
|
|
|
|
|
| 297 |
|
| 298 |
async def process_document(doc):
|
| 299 |
doc_id = doc.document
|
|
@@ -309,7 +311,7 @@ async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
|
|
| 309 |
async with limiter_mapping[model_used]:
|
| 310 |
resp_ai = await llm_router.acompletion(
|
| 311 |
model=model_used,
|
| 312 |
-
messages=[{"role":"user","content":
|
| 313 |
response_format=RequirementsResponse
|
| 314 |
)
|
| 315 |
return RequirementsResponse.model_validate_json(resp_ai.choices[0].message.content).requirements
|
|
@@ -320,7 +322,7 @@ async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
|
|
| 320 |
async with limiter_mapping[model_used]:
|
| 321 |
resp_ai = await llm_router.acompletion(
|
| 322 |
model=model_used,
|
| 323 |
-
messages=[{"role":"user","content":
|
| 324 |
response_format=RequirementsResponse
|
| 325 |
)
|
| 326 |
return RequirementsResponse.model_validate_json(resp_ai.choices[0].message.content).requirements
|
|
@@ -357,14 +359,19 @@ def find_requirements_from_problem_description(req: ReqSearchRequest):
|
|
| 357 |
requirements = req.requirements
|
| 358 |
query = req.query
|
| 359 |
|
| 360 |
-
requirements_text = "\n".join([f"[Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]" for r in requirements])
|
| 361 |
|
| 362 |
print("Called the LLM")
|
| 363 |
resp_ai = llm_router.completion(
|
| 364 |
model="gemini-v2",
|
| 365 |
-
messages=[{"role":"user","content": f"Given all the requirements : \n {requirements_text} \n and the problem description \"{query}\", return a list of
|
| 366 |
-
response_format=
|
| 367 |
)
|
| 368 |
print("Answered")
|
|
|
|
| 369 |
|
| 370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
|
| 295 |
documents = req.documents
|
| 296 |
n_docs = len(documents)
|
| 297 |
+
def prompt(doc_id, full):
|
| 298 |
+
return f"Here's the document whose ID is {doc_id} : {full}\n\nExtract all requirements and group them by context, returning a list of objects where each object includes a document ID, a concise description of the context where the requirements apply (not a chapter title or copied text), and a list of associated requirements; always return the result as a list, even if only one context is found. Remove the errors"
|
| 299 |
|
| 300 |
async def process_document(doc):
|
| 301 |
doc_id = doc.document
|
|
|
|
| 311 |
async with limiter_mapping[model_used]:
|
| 312 |
resp_ai = await llm_router.acompletion(
|
| 313 |
model=model_used,
|
| 314 |
+
messages=[{"role":"user","content": prompt(doc_id, full)}],
|
| 315 |
response_format=RequirementsResponse
|
| 316 |
)
|
| 317 |
return RequirementsResponse.model_validate_json(resp_ai.choices[0].message.content).requirements
|
|
|
|
| 322 |
async with limiter_mapping[model_used]:
|
| 323 |
resp_ai = await llm_router.acompletion(
|
| 324 |
model=model_used,
|
| 325 |
+
messages=[{"role":"user","content": prompt(doc_id, full)}],
|
| 326 |
response_format=RequirementsResponse
|
| 327 |
)
|
| 328 |
return RequirementsResponse.model_validate_json(resp_ai.choices[0].message.content).requirements
|
|
|
|
| 359 |
requirements = req.requirements
|
| 360 |
query = req.query
|
| 361 |
|
| 362 |
+
requirements_text = "\n".join([f"[Selection ID: {x} | Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]" for x, r in enumerate(requirements)])
|
| 363 |
|
| 364 |
print("Called the LLM")
|
| 365 |
resp_ai = llm_router.completion(
|
| 366 |
model="gemini-v2",
|
| 367 |
+
messages=[{"role":"user","content": f"Given all the requirements : \n {requirements_text} \n and the problem description \"{query}\", return a list of 'Selection ID' for the most relevant corresponding requirements that reference or best cover the problem. If none of the requirements covers the problem, simply return an empty list"}],
|
| 368 |
+
response_format=ReqSearchLLMResponse
|
| 369 |
)
|
| 370 |
print("Answered")
|
| 371 |
+
print(resp_ai.choices[0].message.content)
|
| 372 |
|
| 373 |
+
out_llm = ReqSearchLLMResponse.model_validate_json(resp_ai.choices[0].message.content).selected
|
| 374 |
+
if max(out_llm) > len(out_llm) - 1:
|
| 375 |
+
raise HTTPException(status_code=500, detail="LLM error : Generated a wrong index, please try again.")
|
| 376 |
+
|
| 377 |
+
return ReqSearchResponse(requirements=[requirements[i] for i in out_llm])
|
schemas.py
CHANGED
|
@@ -37,6 +37,9 @@ class SingleRequirement(BaseModel):
|
|
| 37 |
context: str
|
| 38 |
requirement: str
|
| 39 |
|
|
|
|
|
|
|
|
|
|
| 40 |
class ReqSearchRequest(BaseModel):
|
| 41 |
query: str
|
| 42 |
requirements: List[SingleRequirement]
|
|
|
|
| 37 |
context: str
|
| 38 |
requirement: str
|
| 39 |
|
| 40 |
+
class ReqSearchLLMResponse(BaseModel):
|
| 41 |
+
selected: List[int]
|
| 42 |
+
|
| 43 |
class ReqSearchRequest(BaseModel):
|
| 44 |
query: str
|
| 45 |
requirements: List[SingleRequirement]
|