| """ | |
| A model worker that calls huggingface inference endpoint. | |
| Register models in a JSON file with the following format: | |
| { | |
| "falcon-180b-chat": { | |
| "model_path": "tiiuae/falcon-180B-chat", | |
| "api_base": "https://api-inference.huggingface.co/models", | |
| "token": "hf_xxx", | |
| "context_length": 2048, | |
| "model_names": "falcon-180b-chat", | |
| "conv_template": null | |
| } | |
| } | |
| "model_path", "api_base", "token", and "context_length" are necessary, while others are optional. | |
| """ | |
import argparse
import asyncio
import json
import uuid
from typing import List, Optional

import requests
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import InferenceClient

from fastchat.constants import SERVER_ERROR_MSG, ErrorCode
from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.utils import build_logger

worker_id = str(uuid.uuid4())[:8]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")

workers = []
worker_map = {}
app = FastAPI()


# reference to
# https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392
def get_gen_kwargs(
    params,
    seed: Optional[int] = None,
):
    """Map FastChat generation params onto huggingface_hub text_generation kwargs."""
    stop = params.get("stop", None)
    if isinstance(stop, list):
        stop_sequences = stop
    elif isinstance(stop, str):
        stop_sequences = [stop]
    else:
        stop_sequences = []
    gen_kwargs = {
        "do_sample": True,
        "return_full_text": bool(params.get("echo", False)),
        "max_new_tokens": int(params.get("max_new_tokens", 256)),
        "top_p": float(params.get("top_p", 1.0)),
        "temperature": float(params.get("temperature", 1.0)),
        "stop_sequences": stop_sequences,
        "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
        "top_k": params.get("top_k", None),
        "seed": seed,
    }
    # The inference endpoint requires 0 < top_p < 1, so clamp 1.0 to just below it
    # and drop the parameter entirely when it is 0.
    if gen_kwargs["top_p"] == 1:
        gen_kwargs["top_p"] = 0.9999999
    if gen_kwargs["top_p"] == 0:
        gen_kwargs.pop("top_p")
    # temperature == 0 means greedy decoding: drop it and disable sampling instead.
    if gen_kwargs["temperature"] == 0:
        gen_kwargs.pop("temperature")
        gen_kwargs["do_sample"] = False
    return gen_kwargs
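# Illustrative call (made-up values):
#   get_gen_kwargs({"temperature": 0.7, "top_p": 0.9, "max_new_tokens": 256, "stop": ["\nUser:"]})
#   -> {"do_sample": True, "return_full_text": False, "max_new_tokens": 256, "top_p": 0.9,
#       "temperature": 0.7, "stop_sequences": ["\nUser:"], "repetition_penalty": 1.0,
#       "top_k": None, "seed": None}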
def could_be_stop(text, stop):
    """Return True if `text` ends with a (possibly partial) prefix of any stop sequence."""
    for s in stop:
        if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)):
            return True
    return False
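# Example: with stop = ["\nUser:"], could_be_stop("Hi\nUs", stop) is True because "\nUs"
# is a prefix of "\nUser:", while could_be_stop("Hi!", stop) is False.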
class HuggingfaceApiWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        api_base: str,
        token: str,
        context_length: int,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        conv_template: Optional[str] = None,
        seed: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template=conv_template,
        )
        self.model_path = model_path
        self.api_base = api_base
        self.token = token
        self.context_len = context_length
        self.seed = seed

        logger.info(
            f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..."
        )

        if not no_register:
            self.init_heart_beat()

    def count_token(self, params):
        # No tokenizer here
        ret = {
            "count": 0,
            "error_code": 0,
        }
        return ret
    def generate_stream_gate(self, params):
        self.call_ct += 1

        prompt = params["prompt"]
        gen_kwargs = get_gen_kwargs(params, seed=self.seed)
        stop = gen_kwargs["stop_sequences"]
        # Falcon chat models tend to run past the turn, so add extra stop words for them.
        if "falcon" in self.model_path and "chat" in self.model_path:
            stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"])
            stop = list(set(stop))
            gen_kwargs["stop_sequences"] = stop

        logger.info(f"prompt: {prompt}")
        logger.info(f"gen_kwargs: {gen_kwargs}")

        try:
            if self.model_path == "":
                url = f"{self.api_base}"
            else:
                url = f"{self.api_base}/{self.model_path}"
            client = InferenceClient(url, token=self.token)
            res = client.text_generation(
                prompt, stream=True, details=True, **gen_kwargs
            )

            reason = None
            text = ""
            for chunk in res:
                if chunk.token.special:
                    continue
                text += chunk.token.text

                # If a stop sequence has fully appeared, trim it off and stop generating.
                s = next((x for x in stop if text.endswith(x)), None)
                if s is not None:
                    text = text[: -len(s)]
                    reason = "stop"
                    break

                # A stop sequence may still be forming; hold back this chunk.
                if could_be_stop(text, stop):
                    continue

                if (
                    chunk.details is not None
                    and chunk.details.finish_reason is not None
                ):
                    reason = chunk.details.finish_reason
                if reason not in ["stop", "length"]:
                    reason = None
                ret = {
                    "text": text,
                    "error_code": 0,
                    "finish_reason": reason,
                }
                yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"
    def generate_gate(self, params):
        # Non-streaming path: drain the stream and return the last chunk.
        for x in self.generate_stream_gate(params):
            pass
        return json.loads(x[:-1].decode())

    def get_embeddings(self, params):
        raise NotImplementedError()
def release_worker_semaphore(worker):
    worker.semaphore.release()


def acquire_worker_semaphore(worker):
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()


def create_background_tasks(worker):
    background_tasks = BackgroundTasks()
    background_tasks.add_task(lambda: release_worker_semaphore(worker))
    return background_tasks
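# The helpers above implement per-worker concurrency limiting: a semaphore caps in-flight
# generations, and for streaming responses it is released by a BackgroundTask once the
# StreamingResponse has finished sending.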
# FastAPI routes; the paths follow the standard FastChat model-worker API that the
# controller and the OpenAI-compatible API server call into.
@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    generator = worker.generate_stream_gate(params)
    background_tasks = create_background_tasks(worker)
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    output = worker.generate_gate(params)
    release_worker_semaphore(worker)
    return JSONResponse(output)


@app.post("/worker_get_embeddings")
async def api_get_embeddings(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    embedding = worker.get_embeddings(params)
    release_worker_semaphore(worker)
    return JSONResponse(content=embedding)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    return {
        "model_names": [m for w in workers for m in w.model_names],
        "speed": 1,
        "queue_length": sum([w.get_queue_length() for w in workers]),
    }


@app.post("/count_token")
async def api_count_token(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    return {"context_length": worker.context_len}
def create_huggingface_api_worker():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    # all model-related parameters are listed in --model-info-file
    parser.add_argument(
        "--model-info-file",
        type=str,
        required=True,
        help="Huggingface API model's info file path",
    )
    parser.add_argument(
        "--limit-worker-concurrency",
        type=int,
        default=5,
        help="Limit the model concurrency to prevent OOM.",
    )
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Overwrite the random seed for each generation.",
    )
    args = parser.parse_args()

    with open(args.model_info_file, "r", encoding="UTF-8") as f:
        model_info = json.load(f)

    logger.info(f"args: {args}")

    model_path_list = []
    api_base_list = []
    token_list = []
    context_length_list = []
    model_names_list = []
    conv_template_list = []

    for m in model_info:
        model_path_list.append(model_info[m]["model_path"])
        api_base_list.append(model_info[m]["api_base"])
        token_list.append(model_info[m]["token"])

        context_length = model_info[m]["context_length"]
        model_names = model_info[m].get("model_names", [m.split("/")[-1]])
        if isinstance(model_names, str):
            model_names = [model_names]
        conv_template = model_info[m].get("conv_template", None)

        context_length_list.append(context_length)
        model_names_list.append(model_names)
        conv_template_list.append(conv_template)

    logger.info(f"Model paths: {model_path_list}")
    logger.info(f"API bases: {api_base_list}")
    logger.info(f"Tokens: {token_list}")
    logger.info(f"Context lengths: {context_length_list}")
    logger.info(f"Model names: {model_names_list}")
    logger.info(f"Conv templates: {conv_template_list}")

    for (
        model_names,
        conv_template,
        model_path,
        api_base,
        token,
        context_length,
    ) in zip(
        model_names_list,
        conv_template_list,
        model_path_list,
        api_base_list,
        token_list,
        context_length_list,
    ):
        m = HuggingfaceApiWorker(
            args.controller_address,
            args.worker_address,
            worker_id,
            model_path,
            api_base,
            token,
            context_length,
            model_names,
            args.limit_worker_concurrency,
            no_register=args.no_register,
            conv_template=conv_template,
            seed=args.seed,
        )
        workers.append(m)
        for name in model_names:
            worker_map[name] = m

    # register all the models
    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": workers[0].worker_addr,
        "check_heart_beat": not args.no_register,
        "worker_status": {
            "model_names": [m for w in workers for m in w.model_names],
            "speed": 1,
            "queue_length": sum([w.get_queue_length() for w in workers]),
        },
    }
    r = requests.post(url, json=data)
    assert r.status_code == 200

    return args, workers
if __name__ == "__main__":
    args, workers = create_huggingface_api_worker()
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")