| """ | |
| A model worker using Apple MLX | |
| https://github.com/ml-explore/mlx-examples/tree/main/llms | |
| Code based on vllm_worker https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py | |
| You must install MLX python: | |
| pip install mlx-lm | |
| """ | |

import argparse
import asyncio
import atexit
import json
from typing import List
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.serve.model_worker import (
    logger,
    worker_id,
)
from fastchat.utils import get_context_length, is_partial_stop

import mlx.core as mx
from mlx_lm import load, generate
from mlx_lm.utils import generate_step

app = FastAPI()


class MLXWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        llm_engine: "MLX",
        conv_template: str,
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
        )

        logger.info(
            f"Loading the model {self.model_names} on worker {worker_id}, worker type: MLX worker..."
        )

        self.model_name = model_path
        self.mlx_model, self.mlx_tokenizer = load(model_path)

        self.tokenizer = self.mlx_tokenizer
        # self.context_len = get_context_length(
        #     llm_engine.engine.model_config.hf_config)
        self.context_len = 2048  # hard code for now -- not sure how to get in MLX
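        # A possible replacement for the hard-coded value above (a sketch, not
        # verified): read the context window from the model's Hugging Face
        # config, which mlx-lm downloads alongside the weights:
        #
        #   from transformers import AutoConfig
        #   self.context_len = get_context_length(AutoConfig.from_pretrained(model_path))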

        if not no_register:
            self.init_heart_beat()

    async def generate_stream(self, params):
        self.call_ct += 1

        context = params.pop("prompt")
        request_id = params.pop("request_id")
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        top_k = params.get("top_k", -1.0)
        presence_penalty = float(params.get("presence_penalty", 0.0))
        frequency_penalty = float(params.get("frequency_penalty", 0.0))
        max_new_tokens = params.get("max_new_tokens", 256)
        stop_str = params.get("stop", None)
        stop_token_ids = params.get("stop_token_ids", None) or []
        if self.tokenizer.eos_token_id is not None:
            stop_token_ids.append(self.tokenizer.eos_token_id)
        echo = params.get("echo", True)
        use_beam_search = params.get("use_beam_search", False)
        best_of = params.get("best_of", None)

        # Handle stop_str
        stop = set()
        if isinstance(stop_str, str) and stop_str != "":
            stop.add(stop_str)
        elif isinstance(stop_str, list) and stop_str != []:
            stop.update(stop_str)

        for tid in stop_token_ids:
            if tid is not None:
                s = self.tokenizer.decode(tid)
                if s != "":
                    stop.add(s)

        print("Stop patterns: ", stop)

        top_p = max(top_p, 1e-5)
        if temperature <= 1e-5:
            top_p = 1.0

        tokens = []
        skip = 0

        context_mlx = mx.array(self.tokenizer.encode(context))

        finish_reason = "length"

        iterator = await run_in_threadpool(
            generate_step, context_mlx, self.mlx_model, temperature
        )
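
        # generate_step is a generator over sampled tokens: each blocking
        # next() call below is pushed to the threadpool so the event loop
        # stays responsive while MLX computes the next token.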
        for i in range(max_new_tokens):
            (token, _) = await run_in_threadpool(next, iterator)
            if token == self.mlx_tokenizer.eos_token_id:
                finish_reason = "stop"
                break
            tokens.append(token.item())
            tokens_decoded = self.mlx_tokenizer.decode(tokens)
            last_token_decoded = self.mlx_tokenizer.decode([token.item()])
            skip = len(tokens_decoded)

            partial_stop = any(is_partial_stop(tokens_decoded, i) for i in stop)

            if partial_stop:
                finish_reason = "stop"
                break
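
            # NOTE: "prompt_tokens" below is len(context), i.e. the number of
            # characters in the prompt string, not the number of prompt tokens.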
            ret = {
                "text": tokens_decoded,
                "error_code": 0,
                "usage": {
                    "prompt_tokens": len(context),
                    "completion_tokens": len(tokens),
                    "total_tokens": len(context) + len(tokens),
                },
                "cumulative_logprob": [],
                "finish_reason": None,  # hard code for now
            }
            # print(ret)
            yield (json.dumps(ret) + "\0").encode()

        ret = {
            "text": self.mlx_tokenizer.decode(tokens),
            "error_code": 0,
            "usage": {},
            "cumulative_logprob": [],
            "finish_reason": finish_reason,
        }
        yield (json.dumps(obj={**ret, **{"finish_reason": None}}) + "\0").encode()
        yield (json.dumps(ret) + "\0").encode()
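
    # Non-streaming path: drain the stream and return only the final chunk,
    # stripping the trailing "\0" delimiter before JSON decoding.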
    async def generate(self, params):
        async for x in self.generate_stream(params):
            pass
        return json.loads(x[:-1].decode())
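

# Module-level concurrency control: a single asyncio.Semaphore, created lazily
# on the first request, caps in-flight generations at limit_worker_concurrency.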
def release_worker_semaphore():
    worker.semaphore.release()


def acquire_worker_semaphore():
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()


def create_background_tasks(request_id):
    async def abort_request() -> None:
        print("trying to abort but not implemented")

    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_worker_semaphore)
    background_tasks.add_task(abort_request)
    return background_tasks


@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    request_id = uuid.uuid4()
    params["request_id"] = str(request_id)
    generator = worker.generate_stream(params)
    background_tasks = create_background_tasks(request_id)
    return StreamingResponse(generator, background=background_tasks)
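

# A minimal client sketch for the streaming endpoint above (an illustration,
# assuming the worker listens on its default http://localhost:21002). Each
# chunk in the response body is a JSON object terminated by a NUL byte, so the
# client splits on b"\0":
#
#   import json
#   import requests
#
#   resp = requests.post(
#       "http://localhost:21002/worker_generate_stream",
#       json={"prompt": "Hello", "temperature": 0.7, "max_new_tokens": 32},
#       stream=True,
#   )
#   for chunk in resp.iter_lines(delimiter=b"\0"):
#       if chunk:
#           print(json.loads(chunk)["text"])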


@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    request_id = uuid.uuid4()
    params["request_id"] = str(request_id)
    output = await worker.generate(params)
    release_worker_semaphore()
    # await engine.abort(request_id)
    print("Trying to abort but not implemented")
    return JSONResponse(output)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    return worker.get_status()


@app.post("/count_token")
async def api_count_token(request: Request):
    params = await request.json()
    return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
    return {"context_length": worker.context_len}


worker = None


def cleanup_at_exit():
    global worker
    print("Cleaning up...")
    del worker


atexit.register(cleanup_at_exit)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument("--model-path", type=str, default="microsoft/phi-2")
    parser.add_argument(
        "--model-names",
        type=lambda s: s.split(","),
        help="Optional comma-separated display names for the model.",
    )
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_false",
        default=True,
        help="Trust remote code (e.g., from HuggingFace) when "
        "downloading the model and tokenizer.",
    )

    args, unknown = parser.parse_known_args()

    if args.model_path:
        args.model = args.model_path

    worker = MLXWorker(
        args.controller_address,
        args.worker_address,
        worker_id,
        args.model_path,
        args.model_names,
        1024,
        False,
        "MLX",
        args.conv_template,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")