""" A model worker using Apple MLX https://github.com/ml-explore/mlx-examples/tree/main/llms Code based on vllm_worker https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py You must install MLX python: pip install mlx-lm """ import argparse import asyncio import atexit import json from typing import List import uuid from fastapi import FastAPI, Request, BackgroundTasks from fastapi.concurrency import run_in_threadpool from fastapi.responses import StreamingResponse, JSONResponse import uvicorn from src.serve.base_model_worker import BaseModelWorker from src.serve.model_worker import ( logger, worker_id, ) from src.utils import get_context_length, is_partial_stop import mlx.core as mx from mlx_lm import load, generate from mlx_lm.utils import generate_step app = FastAPI() class MLXWorker(BaseModelWorker): def __init__( self, controller_addr: str, worker_addr: str, worker_id: str, model_path: str, model_names: List[str], limit_worker_concurrency: int, no_register: bool, llm_engine: "MLX", conv_template: str, ): super().__init__( controller_addr, worker_addr, worker_id, model_path, model_names, limit_worker_concurrency, conv_template, ) logger.info( f"Loading the model {self.model_names} on worker {worker_id}, worker type: MLX worker..." ) self.model_name = model_path self.mlx_model, self.mlx_tokenizer = load(model_path) self.tokenizer = self.mlx_tokenizer # self.context_len = get_context_length( # llm_engine.engine.model_config.hf_config) self.context_len = 2048 # hard code for now -- not sure how to get in MLX if not no_register: self.init_heart_beat() async def generate_stream(self, params): self.call_ct += 1 context = params.pop("prompt") request_id = params.pop("request_id") temperature = float(params.get("temperature", 1.0)) top_p = float(params.get("top_p", 1.0)) top_k = params.get("top_k", -1.0) presence_penalty = float(params.get("presence_penalty", 0.0)) frequency_penalty = float(params.get("frequency_penalty", 0.0)) max_new_tokens = params.get("max_new_tokens", 256) stop_str = params.get("stop", None) stop_token_ids = params.get("stop_token_ids", None) or [] if self.tokenizer.eos_token_id is not None: stop_token_ids.append(self.tokenizer.eos_token_id) echo = params.get("echo", True) use_beam_search = params.get("use_beam_search", False) best_of = params.get("best_of", None) # Handle stop_str stop = set() if isinstance(stop_str, str) and stop_str != "": stop.add(stop_str) elif isinstance(stop_str, list) and stop_str != []: stop.update(stop_str) for tid in stop_token_ids: if tid is not None: s = self.tokenizer.decode(tid) if s != "": stop.add(s) print("Stop patterns: ", stop) top_p = max(top_p, 1e-5) if temperature <= 1e-5: top_p = 1.0 tokens = [] skip = 0 context_mlx = mx.array(self.tokenizer.encode(context)) finish_reason = "length" iterator = await run_in_threadpool( generate_step, context_mlx, self.mlx_model, temperature ) for i in range(max_new_tokens): (token, _) = await run_in_threadpool(next, iterator) if token == self.mlx_tokenizer.eos_token_id: finish_reason = "stop" break tokens.append(token.item()) tokens_decoded = self.mlx_tokenizer.decode(tokens) last_token_decoded = self.mlx_tokenizer.decode([token.item()]) skip = len(tokens_decoded) partial_stop = any(is_partial_stop(tokens_decoded, i) for i in stop) if partial_stop: finish_reason = "stop" break ret = { "text": tokens_decoded, "error_code": 0, "usage": { "prompt_tokens": len(context), "completion_tokens": len(tokens), "total_tokens": len(context) + len(tokens), }, "cumulative_logprob": [], "finish_reason": None, # hard code for now } # print(ret) yield (json.dumps(ret) + "\0").encode() ret = { "text": self.mlx_tokenizer.decode(tokens), "error_code": 0, "usage": {}, "cumulative_logprob": [], "finish_reason": finish_reason, } yield (json.dumps(obj={**ret, **{"finish_reason": None}}) + "\0").encode() yield (json.dumps(ret) + "\0").encode() async def generate(self, params): async for x in self.generate_stream(params): pass return json.loads(x[:-1].decode()) def release_worker_semaphore(): worker.semaphore.release() def acquire_worker_semaphore(): if worker.semaphore is None: worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) return worker.semaphore.acquire() def create_background_tasks(request_id): async def abort_request() -> None: print("trying to abort but not implemented") background_tasks = BackgroundTasks() background_tasks.add_task(release_worker_semaphore) background_tasks.add_task(abort_request) return background_tasks @app.post("/worker_generate_stream") async def api_generate_stream(request: Request): params = await request.json() await acquire_worker_semaphore() request_id = uuid.uuid4() params["request_id"] = str(request_id) generator = worker.generate_stream(params) background_tasks = create_background_tasks(request_id) return StreamingResponse(generator, background=background_tasks) @app.post("/worker_generate") async def api_generate(request: Request): params = await request.json() await acquire_worker_semaphore() request_id = uuid.uuid4() params["request_id"] = str(request_id) output = await worker.generate(params) release_worker_semaphore() # await engine.abort(request_id) print("Trying to abort but not implemented") return JSONResponse(output) @app.post("/worker_get_status") async def api_get_status(request: Request): return worker.get_status() @app.post("/count_token") async def api_count_token(request: Request): params = await request.json() return worker.count_token(params) @app.post("/worker_get_conv_template") async def api_get_conv(request: Request): return worker.get_conv_template() @app.post("/model_details") async def api_model_details(request: Request): return {"context_length": worker.context_len} worker = None def cleanup_at_exit(): global worker print("Cleaning up...") del worker atexit.register(cleanup_at_exit) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=21002) parser.add_argument("--worker-address", type=str, default="http://localhost:21002") parser.add_argument( "--controller-address", type=str, default="http://localhost:21001" ) parser.add_argument("--model-path", type=str, default="microsoft/phi-2") parser.add_argument( "--model-names", type=lambda s: s.split(","), help="Optional display comma separated names", ) parser.add_argument( "--conv-template", type=str, default=None, help="Conversation prompt template." ) parser.add_argument( "--trust_remote_code", action="store_false", default=True, help="Trust remote code (e.g., from HuggingFace) when" "downloading the model and tokenizer.", ) args, unknown = parser.parse_known_args() if args.model_path: args.model = args.model_path worker = MLXWorker( args.controller_address, args.worker_address, worker_id, args.model_path, args.model_names, 1024, False, "MLX", args.conv_template, ) uvicorn.run(app, host=args.host, port=args.port, log_level="info")