import time
import json
from pathlib import Path
from typing import Optional
import logging
logging.basicConfig(level=logging.INFO)

import numpy as np
import torch
from transformers import AutoTokenizer, T5Tokenizer
import re

import tensorrt_llm
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner

if PYTHON_BINDINGS:
    from tensorrt_llm.runtime import ModelRunnerCpp


def read_model_name(engine_dir: str):
    engine_version = tensorrt_llm.runtime.engine.get_engine_version(engine_dir)

    with open(Path(engine_dir) / "config.json", 'r') as f:
        config = json.load(f)

    if engine_version is None:
        return config['builder_config']['name']

    return config['pretrained_config']['architecture']
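
# throttle_generator yields only every `stream_interval`-th item of a streaming
# generator, plus the final item if it was not already emitted, so downstream
# consumers are not flooded with per-token updates.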
def throttle_generator(generator, stream_interval):
    for i, out in enumerate(generator):
        if not i % stream_interval:
            yield out

    if i % stream_interval:
        yield out
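
# load_tokenizer returns (tokenizer, pad_id, end_id). Most models fall through
# to the generic branch that reuses eos_token_id as the pad token; qwen and
# glm_10b need model-specific ids.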
def load_tokenizer(tokenizer_dir: Optional[str] = None,
                   vocab_file: Optional[str] = None,
                   model_name: str = 'gpt',
                   tokenizer_type: Optional[str] = None):
    if vocab_file is None:
        use_fast = True
        if tokenizer_type is not None and tokenizer_type == "llama":
            use_fast = False
        # Should set both padding_side and truncation_side to be 'left'
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir,
                                                  legacy=False,
                                                  padding_side='left',
                                                  truncation_side='left',
                                                  trust_remote_code=True,
                                                  tokenizer_type=tokenizer_type,
                                                  use_fast=use_fast)
    else:
        # For gpt-next, directly load from tokenizer.model
        assert model_name == 'gpt'
        tokenizer = T5Tokenizer(vocab_file=vocab_file,
                                padding_side='left',
                                truncation_side='left')

    if model_name == 'qwen':
        with open(Path(tokenizer_dir) / "generation_config.json") as f:
            gen_config = json.load(f)
        chat_format = gen_config['chat_format']
        if chat_format == 'raw':
            pad_id = gen_config['pad_token_id']
            end_id = gen_config['eos_token_id']
        elif chat_format == 'chatml':
            pad_id = tokenizer.im_end_id
            end_id = tokenizer.im_end_id
        else:
            raise Exception(f"unknown chat format: {chat_format}")
    elif model_name == 'glm_10b':
        pad_id = tokenizer.pad_token_id
        end_id = tokenizer.eop_token_id
    else:
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        pad_id = tokenizer.pad_token_id
        end_id = tokenizer.eos_token_id

    return tokenizer, pad_id, end_id
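
# TensorRTLLMEngine wraps a TensorRT-LLM ModelRunner in a blocking loop: it
# pulls transcriptions from transcription_queue, runs generation, and publishes
# each response (with its latency and EOS flag) to llm_queue and audio_queue.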
class TensorRTLLMEngine:
    def __init__(self):
        pass

    def initialize_model(self, engine_dir, tokenizer_dir):
        self.log_level = 'error'
        self.runtime_rank = tensorrt_llm.mpi_rank()
        logger.set_level(self.log_level)
        model_name = read_model_name(engine_dir)
        self.tokenizer, self.pad_id, self.end_id = load_tokenizer(
            tokenizer_dir=tokenizer_dir,
            vocab_file=None,
            model_name=model_name,
            tokenizer_type=None,
        )

        self.prompt_template = None
        self.runner_cls = ModelRunner
        self.runner_kwargs = dict(engine_dir=engine_dir,
                                  lora_dir=None,
                                  rank=self.runtime_rank,
                                  debug_mode=False,
                                  lora_ckpt_source='hf')
        self.runner = self.runner_cls.from_dir(**self.runner_kwargs)
        self.last_prompt = None
        self.last_output = None
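
    # Tokenize a batch of prompts; each prompt is truncated to max_input_length
    # tokens (from the left, per the tokenizer's truncation_side) and returned
    # as an int32 torch tensor.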
    def parse_input(
        self,
        input_text=None,
        add_special_tokens=True,
        max_input_length=923,
        pad_id=None,
    ):
        if self.pad_id is None:
            self.pad_id = self.tokenizer.pad_token_id

        batch_input_ids = []
        for curr_text in input_text:
            if self.prompt_template is not None:
                curr_text = self.prompt_template.format(input_text=curr_text)
            input_ids = self.tokenizer.encode(
                curr_text,
                add_special_tokens=add_special_tokens,
                truncation=True,
                max_length=max_input_length
            )
            batch_input_ids.append(input_ids)

        batch_input_ids = [
            torch.tensor(x, dtype=torch.int32) for x in batch_input_ids
        ]
        return batch_input_ids
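
    # Decode generated ids back to text for every beam. Returns None as soon as
    # new work appears on transcription_queue so the caller can abandon the
    # current, now stale, generation.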
    def decode_tokens(
        self,
        output_ids,
        input_lengths,
        sequence_lengths,
        transcription_queue
    ):
        batch_size, num_beams, _ = output_ids.size()
        for batch_idx in range(batch_size):
            if transcription_queue.qsize() != 0:
                return None

            inputs = output_ids[batch_idx][0][:input_lengths[batch_idx]].tolist()
            input_text = self.tokenizer.decode(inputs)
            output = []
            for beam in range(num_beams):
                if transcription_queue.qsize() != 0:
                    return None

                output_begin = input_lengths[batch_idx]
                output_end = sequence_lengths[batch_idx][beam]
                outputs = output_ids[batch_idx][beam][
                    output_begin:output_end].tolist()
                output_text = self.tokenizer.decode(outputs)
                output.append(output_text)
        return output
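
    # Prompt builders for the supported templates: instruct-style QA, a plain
    # two-speaker chat, and ChatML (used below with the Dolphin system prompt).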
    def format_prompt_qa(self, prompt, conversation_history):
        formatted_prompt = ""
        for user_prompt, llm_response in conversation_history:
            formatted_prompt += f"Instruct: {user_prompt}\nOutput:{llm_response}\n"
        return f"{formatted_prompt}Instruct: {prompt}\nOutput:"

    def format_prompt_chat(self, prompt, conversation_history):
        formatted_prompt = ""
        for user_prompt, llm_response in conversation_history:
            formatted_prompt += f"Alice: {user_prompt}\nBob:{llm_response}\n"
        return f"{formatted_prompt}Alice: {prompt}\nBob:"

    def format_prompt_chatml(self, prompt, conversation_history, system_prompt=""):
        formatted_prompt = (
            "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"
        )
        for user_prompt, llm_response in conversation_history:
            formatted_prompt += f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
            formatted_prompt += f"<|im_start|>assistant\n{llm_response}<|im_end|>\n"
        formatted_prompt += f"<|im_start|>user\n{prompt}<|im_end|>\n"
        return formatted_prompt
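
    # Main service loop: blocks on transcription_queue, rebuilds the ChatML
    # prompt from the per-uid conversation history, generates a reply, and
    # publishes it (with latency and EOS flag) to llm_queue and audio_queue.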
    def run(
        self,
        model_path,
        tokenizer_path,
        transcription_queue=None,
        llm_queue=None,
        audio_queue=None,
        input_text=None,
        max_output_len=50,
        max_attention_window_size=4096,
        num_beams=1,
        streaming=False,
        streaming_interval=4,
        debug=False,
    ):
        self.initialize_model(
            model_path,
            tokenizer_path,
        )

        logging.info("[LLM INFO:] Loaded LLM TensorRT Engine.")
        conversation_history = {}

        while True:
            # Get the latest transcription output from the queue; if newer
            # items are already waiting, skip this one.
            transcription_output = transcription_queue.get()
            if transcription_queue.qsize() != 0:
                continue

            if transcription_output["uid"] not in conversation_history:
                conversation_history[transcription_output["uid"]] = []

            prompt = transcription_output['prompt'].strip()

            # If the prompt is unchanged but EOS is now True, resend the cached
            # output so the websocket clients still receive a final message.
            if self.last_prompt == prompt:
                if self.last_output is not None and transcription_output["eos"]:
                    self.eos = transcription_output["eos"]
                    llm_queue.put({
                        "uid": transcription_output["uid"],
                        "llm_output": self.last_output,
                        "eos": self.eos,
                        "latency": self.infer_time
                    })
                    audio_queue.put({"llm_output": self.last_output, "eos": self.eos})
                    conversation_history[transcription_output["uid"]].append(
                        (transcription_output['prompt'].strip(), self.last_output[0].strip())
                    )
                    continue

            # input_text = [self.format_prompt_qa(prompt, conversation_history[transcription_output["uid"]])]
            input_text = [self.format_prompt_chatml(
                prompt,
                conversation_history[transcription_output["uid"]],
                system_prompt="You are Dolphin, a helpful AI assistant"
            )]

            self.eos = transcription_output["eos"]

            batch_input_ids = self.parse_input(
                input_text=input_text,
                add_special_tokens=True,
                max_input_length=923,
                pad_id=None,
            )

            input_lengths = [x.size(0) for x in batch_input_ids]

            logging.info(f"[LLM INFO:] Running LLM Inference with WhisperLive prompt: {prompt}, eos: {self.eos}")
            start = time.time()
            with torch.no_grad():
                outputs = self.runner.generate(
                    batch_input_ids,
                    max_new_tokens=max_output_len,
                    max_attention_window_size=max_attention_window_size,
                    end_id=self.end_id,
                    pad_id=self.pad_id,
                    temperature=1.0,
                    top_k=1,
                    top_p=0.0,
                    num_beams=num_beams,
                    length_penalty=1.0,
                    repetition_penalty=1.0,
                    stop_words_list=None,
                    bad_words_list=None,
                    lora_uids=None,
                    prompt_table_path=None,
                    prompt_tasks=None,
                    streaming=streaming,
                    output_sequence_lengths=True,
                    return_dict=True)
                torch.cuda.synchronize()

            if streaming:
                for curr_outputs in throttle_generator(outputs, streaming_interval):
                    output_ids = curr_outputs['output_ids']
                    sequence_lengths = curr_outputs['sequence_lengths']
                    output = self.decode_tokens(
                        output_ids,
                        input_lengths,
                        sequence_lengths,
                        transcription_queue
                    )
                    if output is None:
                        break

                # Interrupted by the transcription queue
                if output is None:
                    continue
            else:
                output_ids = outputs['output_ids']
                sequence_lengths = outputs['sequence_lengths']
                context_logits = None
                generation_logits = None
                if self.runner.gather_context_logits:
                    context_logits = outputs['context_logits']
                if self.runner.gather_generation_logits:
                    generation_logits = outputs['generation_logits']
                output = self.decode_tokens(
                    output_ids,
                    input_lengths,
                    sequence_lengths,
                    transcription_queue
                )
            self.infer_time = time.time() - start

            # if self.eos:
            if output is not None:
                output[0] = clean_llm_output(output[0])
                self.last_output = output
                self.last_prompt = prompt
                llm_queue.put({
                    "uid": transcription_output["uid"],
                    "llm_output": output,
                    "eos": self.eos,
                    "latency": self.infer_time
                })
                audio_queue.put({"llm_output": output, "eos": self.eos})
                logging.info(f"[LLM INFO:] Output: {output[0]}\nLLM inference done in {self.infer_time:.3f} s\n\n")

            if self.eos:
                conversation_history[transcription_output["uid"]].append(
                    (transcription_output['prompt'].strip(), output[0].strip())
                )
                self.last_prompt = None
                self.last_output = None
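

# Strip speaker prefixes the model sometimes emits and trim any trailing,
# unfinished sentence (everything after the last '.', '?' or '!').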
def clean_llm_output(output):
    output = output.replace("\n\nDolphin\n\n", "")
    output = output.replace("\nDolphin\n\n", "")
    output = output.replace("Dolphin: ", "")
    output = output.replace("Assistant: ", "")

    if not output.endswith('.') and not output.endswith('?') and not output.endswith('!'):
        last_punct = output.rfind('.')
        last_q = output.rfind('?')
        if last_q > last_punct:
            last_punct = last_q

        last_ex = output.rfind('!')
        if last_ex > last_punct:
            last_punct = last_ex

        if last_punct > 0:
            output = output[:last_punct + 1]

    return output
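

# Minimal wiring sketch (not part of the original service): it shows how this
# engine is expected to be driven with multiprocessing queues. The engine and
# tokenizer directories below are placeholders; they must point to a built
# TensorRT-LLM engine and its matching Hugging Face tokenizer, and a real
# deployment would feed transcription_queue from WhisperLive instead of a
# hard-coded message.
if __name__ == "__main__":
    from multiprocessing import Queue

    transcription_queue = Queue()
    llm_queue = Queue()
    audio_queue = Queue()

    # The transcription producer is expected to send dicts with these keys.
    transcription_queue.put({
        "uid": "demo-client",
        "prompt": "Hello, who are you?",
        "eos": True,
    })

    TensorRTLLMEngine().run(
        model_path="/path/to/llm_trt_engine_dir",    # placeholder
        tokenizer_path="/path/to/hf_tokenizer_dir",  # placeholder
        transcription_queue=transcription_queue,
        llm_queue=llm_queue,
        audio_queue=audio_queue,
        streaming=True,
    )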