#!/usr/bin/env python3
from concurrent import futures
import time
import argparse
import signal
import sys
import os

import backend_pb2
import backend_pb2_grpc

import grpc

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS is specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
MAMBA_CHAT = os.environ.get('MAMBA_CHAT', '1') == '1'


# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.
    """
    def generate(self, prompt, max_new_tokens):
        """
        Generates text based on the given prompt and maximum number of new tokens.

        Args:
            prompt (str): The prompt to generate text from.
            max_new_tokens (int): The maximum number of new tokens to generate.

        Returns:
            str: The generated text.
        """
        # NOTE: this helper relies on self.generator, which is never initialized
        # anywhere in this servicer; the gRPC methods below do not call it.
        self.generator.end_beam_search()

        # Tokenizing the input
        ids = self.generator.tokenizer.encode(prompt)

        self.generator.gen_begin_reuse(ids)
        initial_len = self.generator.sequence[0].shape[0]
        has_leading_space = False
        decoded_text = ''
        for i in range(max_new_tokens):
            token = self.generator.gen_single_token()
            if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
                has_leading_space = True

            decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
            if has_leading_space:
                decoded_text = ' ' + decoded_text

            if token.item() == self.generator.tokenizer.eos_token_id:
                break
        return decoded_text

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        Loads a language model.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result.
        """
        try:
            tokenizerModel = request.Tokenizer
            if tokenizerModel == "":
                tokenizerModel = request.Model

            tokenizer = AutoTokenizer.from_pretrained(tokenizerModel)
            if MAMBA_CHAT:
                tokenizer.eos_token = "<|endoftext|>"
                tokenizer.pad_token = tokenizer.eos_token
            self.tokenizer = tokenizer
            self.model = MambaLMHeadModel.from_pretrained(request.Model, device="cuda", dtype=torch.float16)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def Predict(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The generated text.
        """
        if request.TopP == 0:
            request.TopP = 0.9

        max_tokens = request.Tokens
        if request.Tokens == 0:
            max_tokens = 2000

        # encoded_input = self.tokenizer(request.Prompt)
        tokens = self.tokenizer(request.Prompt, return_tensors="pt")
        input_ids = tokens.input_ids.to(device="cuda")
        # max_length is the total sequence length (prompt plus generated tokens)
        out = self.model.generate(input_ids=input_ids, max_length=max_tokens, temperature=request.Temperature,
                                  top_p=request.TopP, eos_token_id=self.tokenizer.eos_token_id)

        decoded = self.tokenizer.batch_decode(out)
        generated_text = decoded[0]

        # Remove the prompt from the response if it is echoed back
        if request.Prompt in generated_text:
            generated_text = generated_text.replace(request.Prompt, "")

        return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))

    def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results.

        Args:
            request: The predict stream request.
            context: The gRPC context.

        Yields:
            backend_pb2.Reply: The generated text, emitted as a single chunk.
        """
        yield self.Predict(request, context)


def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Keep the main thread alive until interrupted
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
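

# Illustrative sketch only (not called by this backend): a minimal client that
# exercises the service above. The request message names used here
# (HealthMessage, ModelOptions, PredictOptions) and the BackendStub class are
# assumptions based on typical gRPC codegen for backend.proto, which is not
# part of this file; the field names (Model, Prompt, Tokens, TopP, Temperature)
# are taken from the servicer code above. Adjust to the actual proto before use.
def _example_client(address="localhost:50051"):
    with grpc.insecure_channel(address) as channel:
        stub = backend_pb2_grpc.BackendStub(channel)  # assumed stub name
        print(stub.Health(backend_pb2.HealthMessage()).message)  # assumed message type
        stub.LoadModel(backend_pb2.ModelOptions(Model="state-spaces/mamba-130m"))  # assumed message type; example model id
        reply = stub.Predict(backend_pb2.PredictOptions(  # assumed message type
            Prompt="Hello", Tokens=64, TopP=0.9, Temperature=0.7))
        print(reply.message.decode("utf-8"))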


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)