| """Inference for FastChat models.""" | |
| import abc | |
| import gc | |
| import json | |
| import math | |
| import os | |
| import sys | |
| import time | |
| from typing import Iterable, Optional, Dict | |
| import warnings | |
| import psutil | |
| import torch | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| LlamaTokenizer, | |
| LlamaForCausalLM, | |
| AutoModel, | |
| AutoModelForSeq2SeqLM, | |
| T5Tokenizer, | |
| AutoConfig, | |
| ) | |
| from transformers.generation.logits_process import ( | |
| LogitsProcessorList, | |
| RepetitionPenaltyLogitsProcessor, | |
| TemperatureLogitsWarper, | |
| TopKLogitsWarper, | |
| TopPLogitsWarper, | |
| ) | |
| from src.conversation import get_conv_template, SeparatorStyle | |
| from src.model.model_adapter import ( | |
| load_model, | |
| get_conversation_template, | |
| get_generate_stream_function, | |
| ) | |
| from src.modules.awq import AWQConfig | |
| from src.modules.gptq import GptqConfig | |
| from src.modules.exllama import ExllamaConfig | |
| from src.modules.xfastertransformer import XftConfig | |
| from src.utils import is_partial_stop, is_sentence_complete, get_context_length | |


def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0; 1.0 makes it a no-op, so we skip those two cases.
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processor_list.append(TopKLogitsWarper(top_k))
    return processor_list
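
# Example (illustrative values): prepare_logits_processor(0.7, 1.0, 0.9, 40) returns
# a LogitsProcessorList containing a temperature warper, a top-p warper, and a
# top-k warper; the repetition-penalty processor is skipped because 1.0 is a no-op.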


# Run the whole generation loop without autograd tracking.
@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    if hasattr(model, "device"):
        device = model.device

    # Read parameters
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 256))
    logprobs = params.get("logprobs", None)  # FIXME: Support logprobs>1.
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop", None)
    stop_token_ids = params.get("stop_token_ids", None) or []
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )
    input_ids = tokenizer(prompt).input_ids

    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:  # truncate
        max_src_len = context_len - max_new_tokens - 1
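
    # Left-truncate the prompt to the most recent max_src_len tokens (for
    # decoder-only models this reserves room for max_new_tokens new tokens).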
    input_ids = input_ids[-max_src_len:]
    output_ids = list(input_ids)
    input_echo_len = len(input_ids)

    if model.config.is_encoder_decoder:
        if logprobs is not None:  # FIXME: Support logprobs for encoder-decoder models.
            raise NotImplementedError
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )
    else:
        start_ids = torch.as_tensor([input_ids], device=device)

    past_key_values = out = None
    token_logprobs = [None]  # The first token has no logprobs.
    sent_interrupt = False
    finish_reason = None
    stopped = False
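
    # Decoding scheme: step 0 prefills the entire prompt and caches key/value
    # states; every later step feeds only the previously sampled token and reuses
    # past_key_values. Two candidate tokens are kept at each step so that, when
    # judge_sent_end is enabled, an early stop on an incomplete sentence can be
    # retried with the runner-up token.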
    for i in range(max_new_tokens):
        if i == 0:  # prefill
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(input_ids=start_ids, use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values

            if logprobs is not None:
                # Prefill logprobs for the prompt.
                shift_input_ids = start_ids[..., 1:].contiguous()
                shift_logits = logits[..., :-1, :].contiguous()
                shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
                for label_id, logit in zip(
                    shift_input_ids[0].tolist(), shift_logits[0]
                ):
                    token_logprobs.append(logit[label_id])
        else:  # decoding
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False

                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False
                logits = out.logits
            past_key_values = out.past_key_values

        if logits_processor:
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        if device == "mps":
            # Move to CPU to avoid some bugs in the mps backend.
            last_token_logits = last_token_logits.float().to("cpu")
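
        # Sample two candidates: tokens[0] is used normally; tokens[1] is the
        # fallback consumed by the judge_sent_end logic below.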
        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            _, indices = torch.topk(last_token_logits, 2)
            tokens = [int(index) for index in indices.tolist()]
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            indices = torch.multinomial(probs, num_samples=2)
            tokens = [int(token) for token in indices.tolist()]
        token = tokens[0]
        output_ids.append(token)
        if logprobs is not None:
            # Cannot use last_token_logits because logprobs is based on raw logits.
            token_logprobs.append(
                torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
            )

        if token in stop_token_ids:
            stopped = True
        else:
            stopped = False

        # Yield the output tokens
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                rfind_start = len_prompt
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )
            ret_logprobs = None
            if logprobs is not None:
                ret_logprobs = {
                    "text_offset": [],
                    "tokens": [
                        tokenizer.decode(token)
                        for token in (
                            output_ids if echo else output_ids[input_echo_len:]
                        )
                    ],
                    "token_logprobs": token_logprobs
                    if echo
                    else token_logprobs[input_echo_len:],
                    "top_logprobs": [{}]
                    * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
                }
                # Compute text_offset
                curr_pos = 0
                for text in ret_logprobs["tokens"]:
                    ret_logprobs["text_offset"].append(curr_pos)
                    curr_pos += len(text)

            # TODO: Patch for incomplete sentences interrupting the output; this
            # could be handled more elegantly.
            if judge_sent_end and stopped and not is_sentence_complete(output):
                if len(tokens) > 1:
                    token = tokens[1]
                    output_ids[-1] = token
                else:
                    output_ids.pop()
                stopped = False
                sent_interrupt = True
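                # sent_interrupt makes the next decode step re-run the model on the
                # full output_ids from scratch (past_key_values is not reused); see
                # the sent_interrupt branches above.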

            partially_stopped = False
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                        stopped = True
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            # Prevent yielding a partial stop sequence
            if not partially_stopped:
                yield {
                    "text": output,
                    "logprobs": ret_logprobs,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                    "finish_reason": None,
                }

        if stopped:
            break

    # Finish stream event, which carries the finish reason
    else:
        finish_reason = "length"

    if stopped:
        finish_reason = "stop"

    yield {
        "text": output,
        "logprobs": ret_logprobs,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }

    # Clean up.
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
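

# Illustrative usage sketch for generate_stream (kept as a comment; the model
# path and sampling values below are examples, not defaults of this module):
#
#     model, tokenizer = load_model("lmsys/vicuna-7b-v1.5", device="cuda", num_gpus=1)
#     params = {"prompt": "USER: Hello ASSISTANT:", "temperature": 0.7,
#               "max_new_tokens": 64, "echo": False}
#     for chunk in generate_stream(model, tokenizer, params, "cuda", context_len=4096):
#         text = chunk["text"]  # cumulative generated text so far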


class ChatIO(abc.ABC):
    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Prompt for input from a role."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str):
        """Prompt for output from a role."""

    @abc.abstractmethod
    def stream_output(self, output_stream):
        """Stream output."""

    @abc.abstractmethod
    def print_output(self, text: str):
        """Print output."""


def chat_loop(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str,
    dtype: Optional[torch.dtype],
    load_8bit: bool,
    cpu_offloading: bool,
    conv_template: Optional[str],
    conv_system_msg: Optional[str],
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int,
    chatio: ChatIO,
    gptq_config: Optional[GptqConfig] = None,
    awq_config: Optional[AWQConfig] = None,
    exllama_config: Optional[ExllamaConfig] = None,
    xft_config: Optional[XftConfig] = None,
    revision: str = "main",
    judge_sent_end: bool = True,
    debug: bool = True,
    history: bool = True,
):
    # Model
    model, tokenizer = load_model(
        model_path,
        device=device,
        num_gpus=num_gpus,
        max_gpu_memory=max_gpu_memory,
        dtype=dtype,
        load_8bit=load_8bit,
        cpu_offloading=cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        exllama_config=exllama_config,
        xft_config=xft_config,
        revision=revision,
        debug=debug,
    )
    generate_stream_func = get_generate_stream_function(model, model_path)

    model_type = str(type(model)).lower()
    is_t5 = "t5" in model_type
    is_codet5p = "codet5p" in model_type
    is_xft = "xft" in model_type

    # Hardcode T5's default repetition penalty to be 1.2
    if is_t5 and repetition_penalty == 1.0:
        repetition_penalty = 1.2

    # Set context length
    context_len = get_context_length(model.config)

    # Chat
    def new_chat():
        if conv_template:
            conv = get_conv_template(conv_template)
        else:
            conv = get_conversation_template(model_path)
        if conv_system_msg is not None:
            conv.set_system_message(conv_system_msg)
        return conv

    def reload_conv(conv):
        """
        Reprints the conversation from the start.
        """
        for message in conv.messages[conv.offset :]:
            chatio.prompt_for_output(message[0])
            chatio.print_output(message[1])

    conv = None
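
    # Main chat loop. Special commands recognized below:
    #   !!exit (or empty input)  - quit
    #   !!reset                  - start a new conversation
    #   !!remove                 - remove the last message(s)
    #   !!regen                  - regenerate the last assistant reply
    #   !!save <filename>        - save the conversation to a JSON file
    #   !!load <filename>        - load a conversation from a JSON file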
    while True:
        if not history or not conv:
            conv = new_chat()

        try:
            inp = chatio.prompt_for_input(conv.roles[0])
        except EOFError:
            inp = ""

        if inp == "!!exit" or not inp:
            print("exit...")
            break
        elif inp == "!!reset":
            print("resetting...")
            conv = new_chat()
            continue
        elif inp == "!!remove":
            print("removing last message...")
            if len(conv.messages) > conv.offset:
                # Assistant
                if conv.messages[-1][0] == conv.roles[1]:
                    conv.messages.pop()
                # User
                if conv.messages[-1][0] == conv.roles[0]:
                    conv.messages.pop()
                reload_conv(conv)
            else:
                print("No messages to remove.")
            continue
        elif inp == "!!regen":
            print("regenerating last message...")
            if len(conv.messages) > conv.offset:
                # Assistant
                if conv.messages[-1][0] == conv.roles[1]:
                    conv.messages.pop()
                # User
                if conv.messages[-1][0] == conv.roles[0]:
                    reload_conv(conv)
                    # Set inp to previous message
                    inp = conv.messages.pop()[1]
                else:
                    # Shouldn't happen in normal circumstances
                    print("No user message to regenerate from.")
                    continue
            else:
                print("No messages to regenerate.")
                continue
| elif inp.startswith("!!save"): | |
| args = inp.split(" ", 1) | |
| if len(args) != 2: | |
| print("usage: !!save <filename>") | |
| continue | |
| else: | |
| filename = args[1] | |
| # Add .json if extension not present | |
| if not "." in filename: | |
| filename += ".json" | |
| print("saving...", filename) | |
| with open(filename, "w") as outfile: | |
| json.dump(conv.dict(), outfile) | |
| continue | |
| elif inp.startswith("!!load"): | |
| args = inp.split(" ", 1) | |
| if len(args) != 2: | |
| print("usage: !!load <filename>") | |
| continue | |
| else: | |
| filename = args[1] | |
| # Check if file exists and add .json if needed | |
| if not os.path.exists(filename): | |
| if (not filename.endswith(".json")) and os.path.exists( | |
| filename + ".json" | |
| ): | |
| filename += ".json" | |
| else: | |
| print("file not found:", filename) | |
| continue | |
| print("loading...", filename) | |
| with open(filename, "r") as infile: | |
| new_conv = json.load(infile) | |
| conv = get_conv_template(new_conv["template_name"]) | |
| conv.set_system_message(new_conv["system_message"]) | |
| conv.messages = new_conv["messages"] | |
| reload_conv(conv) | |
| continue | |

        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        if is_codet5p:  # codet5p is a code completion model.
            prompt = inp

        gen_params = {
            "model": model_path,
            "prompt": prompt,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens,
            "stop": conv.stop_str,
            "stop_token_ids": conv.stop_token_ids,
            "echo": False,
        }
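        # echo=False above means the stream contains only newly generated text,
        # not the prompt.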

        try:
            chatio.prompt_for_output(conv.roles[1])
            output_stream = generate_stream_func(
                model,
                tokenizer,
                gen_params,
                device,
                context_len=context_len,
                judge_sent_end=judge_sent_end,
            )
            t = time.time()
            outputs = chatio.stream_output(output_stream)
            duration = time.time() - t
            conv.update_last_message(outputs.strip())

            if debug:
                num_tokens = len(tokenizer.encode(outputs))
                msg = {
                    "conv_template": conv.name,
                    "prompt": prompt,
                    "outputs": outputs,
                    "speed (token/s)": round(num_tokens / duration, 2),
                }
                print(f"\n{msg}\n")
        except KeyboardInterrupt:
            print("stopped generation.")
            # If generation didn't finish
            if conv.messages[-1][1] is None:
                conv.messages.pop()
                # Remove the last user message too, so there isn't a duplicate
                if conv.messages[-1][0] == conv.roles[0]:
                    conv.messages.pop()
                reload_conv(conv)
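

# Illustrative invocation sketch for chat_loop (kept as a comment; every value
# below is an example, not a default taken from this module):
#
#     chat_loop(
#         model_path="lmsys/vicuna-7b-v1.5",
#         device="cuda",
#         num_gpus=1,
#         max_gpu_memory="13GiB",
#         dtype=None,
#         load_8bit=False,
#         cpu_offloading=False,
#         conv_template=None,
#         conv_system_msg=None,
#         temperature=0.7,
#         repetition_penalty=1.0,
#         max_new_tokens=512,
#         chatio=SimpleConsoleChatIO(),
#     )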