import time
import uuid
from functools import partial
from typing import (
    List,
    Dict,
    Any,
    AsyncIterator,
    Optional,
)

import anyio
from fastapi import APIRouter, Depends
from fastapi import HTTPException, Request
from loguru import logger
from openai.types.completion import Completion
from openai.types.completion_choice import CompletionChoice, Logprobs
from openai.types.completion_usage import CompletionUsage
from sse_starlette import EventSourceResponse
from vllm.outputs import RequestOutput

from api.core.vllm_engine import VllmEngine
from api.models import GENERATE_ENGINE
from api.utils.compat import model_dump
from api.utils.protocol import CompletionCreateParams
from api.utils.request import (
    handle_request,
    get_event_publisher,
    check_api_key,
)

completion_router = APIRouter()


def get_engine():
    yield GENERATE_ENGINE
# Register the route on the router; the path and API-key dependency follow the
# usual OpenAI-compatible layout (completion_router and check_api_key are unused otherwise).
@completion_router.post("/completions", dependencies=[Depends(check_api_key)])
async def create_completion(
    request: CompletionCreateParams,
    raw_request: Request,
    engine: VllmEngine = Depends(get_engine),
):
    """Completion API similar to OpenAI's API.

    See https://platform.openai.com/docs/api-reference/completions/create
    for the API specification.
    """
    if request.echo:
        # We do not support echo since the vLLM engine does not
        # currently support getting the logprobs of prompt tokens.
        raise HTTPException(status_code=400, detail="echo is not currently supported")

    if request.suffix:
        # The language models we currently support do not support suffix.
        raise HTTPException(status_code=400, detail="suffix is not currently supported")

    request.max_tokens = request.max_tokens or 128
    request = await handle_request(request, engine.prompt_adapter.stop, chat=False)

    if isinstance(request.prompt, list):
        request.prompt = request.prompt[0]

    params = model_dump(request, exclude={"prompt"})
    params.update(dict(prompt_or_messages=request.prompt))
    logger.debug(f"==== request ====\n{params}")
    request_id: str = f"cmpl-{str(uuid.uuid4())}"
    generator = engine.generate(params, request_id)

    if request.stream:
        iterator = create_completion_stream(generator, params, request_id, engine.tokenizer)
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=raw_request,
                inner_send_chan=send_chan,
                iterator=iterator,
            ),
        )
    else:
        # Non-streaming response
        final_res: Optional[RequestOutput] = None
        async for res in generator:
            if raw_request is not None and await raw_request.is_disconnected():
                # Abort the request in the engine if the client disconnects.
                await engine.model.abort(request_id)
                return
            final_res = res

        assert final_res is not None

        choices = []
        for output in final_res.outputs:
            output.text = output.text.replace("�", "")

            logprobs = None
            if params.get("logprobs", None) is not None:
                logprobs = create_logprobs(engine.tokenizer, output.token_ids, output.logprobs)

            choice = CompletionChoice(
                index=output.index,
                text=output.text,
                finish_reason=output.finish_reason,
                logprobs=logprobs,
            )
            choices.append(choice)

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs)
        usage = CompletionUsage(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        return Completion(
            id=request_id,
            choices=choices,
            created=int(time.time()),
            model=params.get("model", "llm"),
            object="text_completion",
            usage=usage,
        )
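

# Illustrative only: a minimal client-side sketch of calling this route with the
# official openai Python client. The base_url, API key, and model name below are
# assumptions for local testing, not part of this module.
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="xxx")
#   completion = client.completions.create(
#       model="llm",
#       prompt="Hello, my name is",
#       max_tokens=32,
#   )
#   print(completion.choices[0].text)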
def create_logprobs(
    tokenizer,
    token_ids: List[int],
    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
    num_output_top_logprobs: Optional[int] = None,
    initial_text_offset: int = 0,
) -> Logprobs:
    """Build an OpenAI-style Logprobs object from per-step vLLM logprob dicts."""
    logprobs = Logprobs(text_offset=[], token_logprobs=[], tokens=[], top_logprobs=None)
    last_token_len = 0
    if num_output_top_logprobs:
        logprobs.top_logprobs = []

    for i, token_id in enumerate(token_ids):
        step_top_logprobs = top_logprobs[i]
        if step_top_logprobs is not None:
            token_logprob = step_top_logprobs[token_id]
        else:
            token_logprob = None

        token = tokenizer.convert_ids_to_tokens(token_id)
        logprobs.tokens.append(token)
        logprobs.token_logprobs.append(token_logprob)

        # Text offsets are cumulative: each token starts where the previous one ended.
        if len(logprobs.text_offset) == 0:
            logprobs.text_offset.append(initial_text_offset)
        else:
            logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
        last_token_len = len(token)

        if num_output_top_logprobs:
            logprobs.top_logprobs.append(
                {
                    tokenizer.convert_ids_to_tokens(i): p
                    for i, p in step_top_logprobs.items()
                }
                if step_top_logprobs else None
            )
    return logprobs
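

# Worked example of the offset bookkeeping above, using a hypothetical toy_tokenizer
# whose convert_ids_to_tokens maps 1 -> "He", 2 -> "llo", 3 -> "!":
#
#   lp = create_logprobs(toy_tokenizer, [1, 2, 3], [{1: -0.1}, {2: -0.2}, {3: -0.3}])
#   assert lp.tokens == ["He", "llo", "!"]
#   assert lp.text_offset == [0, 2, 5]      # each token starts where the last ended
#   assert lp.token_logprobs == [-0.1, -0.2, -0.3]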
async def create_completion_stream(
    generator: AsyncIterator, params: Dict[str, Any], request_id: str, tokenizer,
) -> AsyncIterator:
    n = params.get("n", 1)
    previous_texts = [""] * n
    previous_num_tokens = [0] * n

    async for res in generator:
        res: RequestOutput
        for output in res.outputs:
            i = output.index
            output.text = output.text.replace("�", "")
            delta_text = output.text[len(previous_texts[i]):]

            if params.get("logprobs") is not None:
                logprobs = create_logprobs(
                    tokenizer,
                    output.token_ids[previous_num_tokens[i]:],
                    output.logprobs[previous_num_tokens[i]:],
                    initial_text_offset=len(previous_texts[i]),
                )
            else:
                logprobs = None

            previous_texts[i] = output.text
            previous_num_tokens[i] = len(output.token_ids)

            # openai's CompletionChoice expects a non-null literal finish_reason,
            # so "stop" is used as a placeholder on intermediate chunks.
            choice = CompletionChoice(
                index=i,
                text=delta_text,
                finish_reason="stop",
                logprobs=logprobs,
            )
            yield Completion(
                id=request_id,
                choices=[choice],
                created=int(time.time()),
                model=params.get("model", "llm"),
                object="text_completion",
            )

            if output.finish_reason is not None:
                # Emit a final chunk once the engine reports a finish reason.
                if params.get("logprobs") is not None:
                    logprobs = Logprobs(
                        text_offset=[], token_logprobs=[], tokens=[], top_logprobs=[]
                    )
                else:
                    logprobs = None

                choice = CompletionChoice(
                    index=i,
                    text=delta_text,
                    finish_reason="stop",
                    logprobs=logprobs,
                )
                yield Completion(
                    id=request_id,
                    choices=[choice],
                    created=int(time.time()),
                    model=params.get("model", "llm"),
                    object="text_completion",
                )
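

# Illustrative only: consuming the stream from the client side, assuming the
# events published by get_event_publisher are OpenAI-compatible SSE chunks and
# that the server runs locally under /v1 with model name "llm" (both assumptions).
# Each chunk's choices[0].text is the delta produced above.
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="xxx")
#   for chunk in client.completions.create(
#       model="llm",
#       prompt="Once upon a time",
#       max_tokens=32,
#       stream=True,
#   ):
#       print(chunk.choices[0].text, end="", flush=True)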