| """ | |
| Chat with a model with command line interface. | |
| Usage: | |
| python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5 | |
| python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0 | |
| Other commands: | |
| - Type "!!exit" or an empty line to exit. | |
| - Type "!!reset" to start a new conversation. | |
| - Type "!!remove" to remove the last prompt. | |
| - Type "!!regen" to regenerate the last message. | |
| - Type "!!save <filename>" to save the conversation history to a json file. | |
| - Type "!!load <filename>" to load a conversation history from a json file. | |
| """ | |
import argparse
import os
import re
import sys

from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory
from prompt_toolkit.key_binding import KeyBindings
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
import torch

from fastchat.model.model_adapter import add_model_args
from fastchat.modules.awq import AWQConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.modules.xfastertransformer import XftConfig
from fastchat.modules.gptq import GptqConfig
from fastchat.serve.inference import ChatIO, chat_loop
from fastchat.utils import str_to_torch_dtype


class SimpleChatIO(ChatIO):
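    """Plain-text chat I/O on stdin/stdout.

    `prefix` is displayed after the role name when prompting and inserted
    between the lines of a multiline prompt.
    """
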
    def __init__(self, multiline: bool = False, prefix: str = ""):
        self._multiline = multiline
        self.prefix = prefix

    def prompt_for_input(self, role) -> str:
        if not self._multiline:
            return input(f"{role}: {self.prefix}")

        prompt_data = []
        line = input(f"{role} [ctrl-d/z on empty line to end]: ")
        while True:
            prompt_data.append(line.strip())
            try:
                line = input()
            except EOFError:
                break
        return f"\n{self.prefix}".join(prompt_data)
    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="", flush=True)
    def stream_output(self, output_stream):
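        # `pre` counts the complete (space-delimited) words already printed.
        # Each chunk from the stream carries the full text so far; only the
        # newly completed words are printed, and the last word is held back
        # until the stream ends, since it may still be growing.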
        pre = 0
        for outputs in output_stream:
            output_text = outputs["text"]
            output_text = output_text.strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)
    def print_output(self, text: str):
        print(text)


class RichChatIO(ChatIO):
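    """Chat I/O rendered with rich: prompt_toolkit input (history,
    autosuggestions, command completion) and live Markdown output."""
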
    bindings = KeyBindings()

    @bindings.add("escape", "enter")
    def _(event):
        event.app.current_buffer.newline()
    def __init__(self, multiline: bool = False, mouse: bool = False):
        self._prompt_session = PromptSession(history=InMemoryHistory())
        self._completer = WordCompleter(
            words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"],
            pattern=re.compile("$"),
        )
        self._console = Console()
        self._multiline = multiline
        self._mouse = mouse

    def prompt_for_input(self, role) -> str:
        self._console.print(f"[bold]{role}:")
        # TODO(suquark): multiline input has some issues. fix it later.
        prompt_input = self._prompt_session.prompt(
            completer=self._completer,
            multiline=False,
            mouse_support=self._mouse,
            auto_suggest=AutoSuggestFromHistory(),
            key_bindings=self.bindings if self._multiline else None,
        )
        self._console.print()
        return prompt_input

    def prompt_for_output(self, role: str):
        self._console.print(f"[bold]{role.replace('/', '|')}:")
    def stream_output(self, output_stream):
        """Stream output from a role."""
        # TODO(suquark): the console flickers when there is a code block
        #  above it. We need to cut off "live" when a code block is done.

        # Create a Live context for updating the console output
        with Live(console=self._console, refresh_per_second=4) as live:
            # Read lines from the stream
            for outputs in output_stream:
                if not outputs:
                    continue
                text = outputs["text"]
                # Render the accumulated text as Markdown
                # NOTE: this is a workaround for rendering non-standard markdown
                #  in rich. Chatbot output treats "\n" as a new line for better
                #  compatibility with real-world text, but standard markdown
                #  treats a single "\n" in normal text as a space, so rendering
                #  it as markdown would break the format. Our workaround is to
                #  add two spaces at the end of each line. This is not a perfect
                #  solution, as it introduces trailing spaces (only) in code
                #  blocks, but it works well for console output, which generally
                #  does not care about trailing spaces.
                lines = []
                for line in text.splitlines():
                    lines.append(line)
                    if line.startswith("```"):
                        # Code block marker - do not add trailing spaces, as it
                        #  would break the syntax highlighting
                        lines.append("\n")
                    else:
                        lines.append("  \n")
                markdown = Markdown("".join(lines))
                # Update the Live console output
                live.update(markdown)
        self._console.print()
        return text
    def print_output(self, text: str):
        self.stream_output([{"text": text}])


class ProgrammaticChatIO(ChatIO):
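    """Chat I/O for driving the CLI from another process.

    A message on stdin is terminated by the sentinel
    " __END_OF_A_MESSAGE_47582648__" followed by a newline, so multi-line
    prompts pass through unchanged. A hypothetical driver would write, e.g.,
    "Tell me a joke. __END_OF_A_MESSAGE_47582648__" plus a newline to this
    process's stdin, then read the reply (prefixed with "[!OP:<role>]:")
    from its stdout.
    """
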
    def prompt_for_input(self, role) -> str:
        contents = ""
        # `end_sequence` signals the end of a message. It is unlikely to occur in
        #  message content.
        end_sequence = " __END_OF_A_MESSAGE_47582648__\n"
        len_end = len(end_sequence)
        while True:
            if len(contents) >= len_end:
                last_chars = contents[-len_end:]
                if last_chars == end_sequence:
                    break
            try:
                char = sys.stdin.read(1)
                contents = contents + char
            except EOFError:
                continue
        contents = contents[:-len_end]
        print(f"[!OP:{role}]: {contents}", flush=True)
        return contents

    def prompt_for_output(self, role: str):
        print(f"[!OP:{role}]: ", end="", flush=True)

    def stream_output(self, output_stream):
        # Same word-at-a-time streaming as SimpleChatIO.stream_output.
        pre = 0
        for outputs in output_stream:
            output_text = outputs["text"]
            output_text = output_text.strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)

    def print_output(self, text: str):
        print(text)


def main(args):
    if args.gpus:
        if len(args.gpus.split(",")) < args.num_gpus:
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
    if args.enable_exllama:
        exllama_config = ExllamaConfig(
            max_seq_len=args.exllama_max_seq_len,
            gpu_split=args.exllama_gpu_split,
        )
    else:
        exllama_config = None
    if args.enable_xft:
        xft_config = XftConfig(
            max_seq_len=args.xft_max_seq_len,
            data_type=args.xft_dtype,
        )
        if args.device != "cpu":
            print("xFasterTransformer currently only supports CPU. Resetting device to CPU.")
            args.device = "cpu"
    else:
        xft_config = None
| if args.style == "simple": | |
| chatio = SimpleChatIO(args.multiline) | |
| elif args.style == "rich": | |
| chatio = RichChatIO(args.multiline, args.mouse) | |
| elif args.style == "programmatic": | |
| chatio = ProgrammaticChatIO() | |
| else: | |
| raise ValueError(f"Invalid style for console: {args.style}") | |
    try:
        if args.upload_file_path:
            # Use the uploaded file's contents (truncated to the first
            # 20000 characters) as the conversation system message.
            with open(args.upload_file_path, "r") as f:
                args.conv_system_msg = f.read()[:20000]
        chat_loop(
            args.model_path,
            args.device,
            args.num_gpus,
            args.max_gpu_memory,
            str_to_torch_dtype(args.dtype),
            args.load_8bit,
            args.cpu_offloading,
            args.conv_template,
            args.conv_system_msg,
            args.temperature,
            args.repetition_penalty,
            args.max_new_tokens,
            chatio,
            gptq_config=GptqConfig(
                ckpt=args.gptq_ckpt or args.model_path,
                wbits=args.gptq_wbits,
                groupsize=args.gptq_groupsize,
                act_order=args.gptq_act_order,
            ),
            awq_config=AWQConfig(
                ckpt=args.awq_ckpt or args.model_path,
                wbits=args.awq_wbits,
                groupsize=args.awq_groupsize,
            ),
            exllama_config=exllama_config,
            xft_config=xft_config,
            revision=args.revision,
            judge_sent_end=args.judge_sent_end,
            debug=args.debug,
            history=not args.no_history,
        )
    except KeyboardInterrupt:
        print("exit...")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_model_args(parser)
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument(
        "--conv-system-msg", type=str, default=None, help="Conversation system message."
    )
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--no-history", action="store_true")
    parser.add_argument(
        "--style",
        type=str,
        default="simple",
        choices=["simple", "rich", "programmatic"],
        help="Display style.",
    )
    parser.add_argument(
        "--multiline",
        action="store_true",
        help="Enable multiline input. Use ESC+Enter for newline.",
    )
    parser.add_argument(
        "--mouse",
        action="store_true",
        help="[Rich Style]: Enable mouse support for cursor positioning.",
    )
    parser.add_argument(
        "--judge-sent-end",
        action="store_true",
        help="Whether to enable the correction logic that interrupts the output of sentences due to EOS.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print useful debug information (e.g., prompts)",
    )
    parser.add_argument(
        "--upload-file-path",
        type=str,
        default="",
        help="Path to a long text file whose contents are used as the conversation system message (e.g., for summarization).",
    )
    args = parser.parse_args()
    main(args)
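
# For illustration: to drive this CLI from another process, one could run
#   python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --style programmatic
# and exchange messages via the sentinel protocol described in ProgrammaticChatIO.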