Upload 7 files
- llava/serve/__init__.py +0 -0
- llava/serve/cli.py +125 -0
- llava/serve/controller.py +298 -0
- llava/serve/gradio_web_server.py +420 -0
- llava/serve/model_worker.py +285 -0
- llava/serve/register_worker.py +26 -0
- llava/serve/test_message.py +62 -0
    	
llava/serve/__init__.py
ADDED

File without changes
    	
llava/serve/cli.py
ADDED

@@ -0,0 +1,125 @@
import argparse
import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer


def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def main(args):
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    # Infer the conversation template from the model name.
    if 'llama-2' in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto-inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    if "mpt" in model_name.lower():
        roles = ('user', 'assistant')
    else:
        roles = conv.roles

    image = load_image(args.image_file)
    # Similar operation in model_worker.py
    image_tensor = process_images([image], image_processor, args)
    if type(image_tensor) is list:
        image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    # Interactive chat loop: an empty line (or EOF) exits.
    while True:
        try:
            inp = input(f"{roles[0]}: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        print(f"{roles[1]}: ", end="")

        if image is not None:
            # First message: prepend the image token(s) to the user prompt.
            if model.config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else:
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            conv.append_message(conv.roles[0], inp)
            image = None
        else:
            # Later messages
            conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=args.temperature,
                max_new_tokens=args.max_new_tokens,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        conv.messages[-1][-1] = outputs

        if args.debug:
            print("\n", {"prompt": prompt, "outputs": outputs}, "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--image-aspect-ratio", type=str, default='pad')
    args = parser.parse_args()
    main(args)
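As a usage note, the CLI above is normally launched from a shell (python -m llava.serve.cli --image-file ...); a minimal sketch of driving it programmatically follows, assuming the llava package containing this file is importable. The checkpoint name and image path are illustrative, not part of this commit.

# Minimal sketch, assuming llava is installed; model path and image are illustrative.
from argparse import Namespace
from llava.serve.cli import main

main(Namespace(
    model_path="liuhaotian/llava-v1.5-7b",  # assumption: any LLaVA checkpoint
    model_base=None,
    image_file="view.jpg",                  # assumption: a local RGB image
    device="cuda",
    conv_mode=None,                         # let main() infer the template
    temperature=0.2,
    max_new_tokens=512,
    load_8bit=False,
    load_4bit=False,
    debug=False,
    image_aspect_ratio="pad",               # consumed by process_images()
))  # starts the interactive chat loop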
    	
llava/serve/controller.py
ADDED

@@ -0,0 +1,298 @@
"""
A controller manages distributed workers.
It sends worker addresses to clients.
"""
import argparse
import asyncio
import dataclasses
from enum import Enum, auto
import json
import logging
import time
from typing import List, Union
import threading

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import numpy as np
import requests
import uvicorn

from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
from llava.utils import build_logger, server_error_msg


logger = build_logger("controller", "controller.log")


class DispatchMethod(Enum):
    LOTTERY = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, name):
        if name == "lottery":
            return cls.LOTTERY
        elif name == "shortest_queue":
            return cls.SHORTEST_QUEUE
        else:
            raise ValueError(f"Invalid dispatch method: {name}")


@dataclasses.dataclass
class WorkerInfo:
    model_names: List[str]
    speed: int
    queue_length: int
    check_heart_beat: bool
    last_heart_beat: float


def heart_beat_controller(controller):
    # Periodically drop workers whose heart beat has expired.
    while True:
        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
        controller.remove_stable_workers_by_expiration()


class Controller:
    def __init__(self, dispatch_method: str):
        # Dict[str -> WorkerInfo]
        self.worker_info = {}
        self.dispatch_method = DispatchMethod.from_str(dispatch_method)

        self.heart_beat_thread = threading.Thread(
            target=heart_beat_controller, args=(self,))
        self.heart_beat_thread.start()

        logger.info("Init controller")

    def register_worker(self, worker_name: str, check_heart_beat: bool,
                        worker_status: dict):
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
        else:
            logger.info(f"Register an existing worker: {worker_name}")

        if not worker_status:
            worker_status = self.get_worker_status(worker_name)
        if not worker_status:
            return False

        self.worker_info[worker_name] = WorkerInfo(
            worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
            check_heart_beat, time.time())

        logger.info(f"Register done: {worker_name}, {worker_status}")
        return True

    def get_worker_status(self, worker_name: str):
        try:
            r = requests.post(worker_name + "/worker_get_status", timeout=5)
        except requests.exceptions.RequestException as e:
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

        if r.status_code != 200:
            logger.error(f"Get status fails: {worker_name}, {r}")
            return None

        return r.json()

    def remove_worker(self, worker_name: str):
        del self.worker_info[worker_name]

    def refresh_all_workers(self):
        old_info = dict(self.worker_info)
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            if not self.register_worker(w_name, w_info.check_heart_beat, None):
                logger.info(f"Remove stale worker: {w_name}")

    def list_models(self):
        model_names = set()

        for w_name, w_info in self.worker_info.items():
            model_names.update(w_info.model_names)

        return list(model_names)

    def get_worker_address(self, model_name: str):
        if self.dispatch_method == DispatchMethod.LOTTERY:
            # Sample a worker with probability proportional to its speed.
            worker_names = []
            worker_speeds = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_speeds.append(w_info.speed)
            worker_speeds = np.array(worker_speeds, dtype=np.float32)
            norm = np.sum(worker_speeds)
            if norm < 1e-4:
                return ""
            worker_speeds = worker_speeds / norm
            if True:  # Directly return address
                pt = np.random.choice(np.arange(len(worker_names)),
                    p=worker_speeds)
                worker_name = worker_names[pt]
                return worker_name

            # Check status before returning (unreachable with the branch above)
            while True:
                pt = np.random.choice(np.arange(len(worker_names)),
                    p=worker_speeds)
                worker_name = worker_names[pt]

                if self.get_worker_status(worker_name):
                    break
                else:
                    self.remove_worker(worker_name)
                    worker_speeds[pt] = 0
                    norm = np.sum(worker_speeds)
                    if norm < 1e-4:
                        return ""
                    worker_speeds = worker_speeds / norm
                    continue
            return worker_name
        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
            # Pick the worker with the shortest speed-normalized queue.
            worker_names = []
            worker_qlen = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_qlen.append(w_info.queue_length / w_info.speed)
            if len(worker_names) == 0:
                return ""
            min_index = np.argmin(worker_qlen)
            w_name = worker_names[min_index]
            self.worker_info[w_name].queue_length += 1
            logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
            return w_name
        else:
            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")

    def receive_heart_beat(self, worker_name: str, queue_length: int):
        if worker_name not in self.worker_info:
            logger.info(f"Receive unknown heart beat. {worker_name}")
            return False

        self.worker_info[worker_name].queue_length = queue_length
        self.worker_info[worker_name].last_heart_beat = time.time()
        logger.info(f"Receive heart beat. {worker_name}")
        return True

    def remove_stable_workers_by_expiration(self):
        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
        to_delete = []
        for worker_name, w_info in self.worker_info.items():
            if w_info.check_heart_beat and w_info.last_heart_beat < expire:
                to_delete.append(worker_name)

        for worker_name in to_delete:
            self.remove_worker(worker_name)

    def worker_api_generate_stream(self, params):
        worker_addr = self.get_worker_address(params["model"])
        if not worker_addr:
            logger.info(f"no worker: {params['model']}")
            ret = {
                "text": server_error_msg,
                "error_code": 2,
            }
            yield json.dumps(ret).encode() + b"\0"
            return

        try:
            response = requests.post(worker_addr + "/worker_generate_stream",
                json=params, stream=True, timeout=5)
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    yield chunk + b"\0"
        except requests.exceptions.RequestException as e:
            logger.info(f"worker timeout: {worker_addr}")
            ret = {
                "text": server_error_msg,
                "error_code": 3,
            }
            yield json.dumps(ret).encode() + b"\0"


    # Let the controller act as a worker to achieve hierarchical
    # management. This can be used to connect isolated sub networks.
    def worker_api_get_status(self):
        model_names = set()
        speed = 0
        queue_length = 0

        for w_name in self.worker_info:
            worker_status = self.get_worker_status(w_name)
            if worker_status is not None:
                model_names.update(worker_status["model_names"])
                speed += worker_status["speed"]
                queue_length += worker_status["queue_length"]

        return {
            "model_names": list(model_names),
            "speed": speed,
            "queue_length": queue_length,
        }


app = FastAPI()


@app.post("/register_worker")
async def register_worker(request: Request):
    data = await request.json()
    controller.register_worker(
        data["worker_name"], data["check_heart_beat"],
        data.get("worker_status", None))


@app.post("/refresh_all_workers")
async def refresh_all_workers():
    controller.refresh_all_workers()


@app.post("/list_models")
async def list_models():
    models = controller.list_models()
    return {"models": models}


@app.post("/get_worker_address")
async def get_worker_address(request: Request):
    data = await request.json()
    addr = controller.get_worker_address(data["model"])
    return {"address": addr}


@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
    data = await request.json()
    exist = controller.receive_heart_beat(
        data["worker_name"], data["queue_length"])
    return {"exist": exist}


@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
    params = await request.json()
    generator = controller.worker_api_generate_stream(params)
    return StreamingResponse(generator)


@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
    return controller.worker_api_get_status()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21001)
    parser.add_argument("--dispatch-method", type=str, choices=[
        "lottery", "shortest_queue"], default="shortest_queue")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    controller = Controller(args.dispatch_method)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
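A minimal client-side sketch of the registration and dispatch flow above, using only endpoints defined in this file; it assumes the controller is running on its default host/port, and the worker address and model name below are hypothetical.

# Minimal sketch, assuming the controller runs at localhost:21001 (its defaults).
import requests

base = "http://localhost:21001"

# Register a hypothetical worker; worker_status mirrors /worker_get_status.
requests.post(base + "/register_worker", json={
    "worker_name": "http://localhost:40000",  # assumption: a worker address
    "check_heart_beat": True,
    "worker_status": {
        "model_names": ["llava-v1"], "speed": 1, "queue_length": 0,
    },
})

# Ask which models are served, then where to send a request for one of them.
models = requests.post(base + "/list_models").json()["models"]
addr = requests.post(base + "/get_worker_address",
                     json={"model": models[0]}).json()["address"]
print(models, addr)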
    	
llava/serve/gradio_web_server.py
ADDED

@@ -0,0 +1,420 @@
import argparse
import datetime
import json
import os
import time

import gradio as gr
import requests

from llava.conversation import (default_conversation, conv_templates,
                                SeparatorStyle)
from llava.constants import LOGDIR
from llava.utils import (build_logger, server_error_msg,
    violates_moderation, moderation_msg)
import hashlib


logger = build_logger("gradio_web_server", "gradio_web_server.log")

headers = {"User-Agent": "LLaVA Client"}

no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)

priority = {
    "vicuna-13b": "aaaaaaa",
    "koala-13b": "aaaaaab",
}


def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


def get_model_list():
    ret = requests.post(args.controller_url + "/refresh_all_workers")
    assert ret.status_code == 200
    ret = requests.post(args.controller_url + "/list_models")
    models = ret.json()["models"]
    models.sort(key=lambda x: priority.get(x, x))
    logger.info(f"Models: {models}")
    return models


get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
    }
"""


def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    dropdown_update = gr.Dropdown.update(visible=True)
    if "model" in url_params:
        model = url_params["model"]
        if model in models:
            dropdown_update = gr.Dropdown.update(
                value=model, visible=True)

    state = default_conversation.copy()
    return state, dropdown_update


def load_demo_refresh_model_list(request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}")
    models = get_model_list()
    state = default_conversation.copy()
    dropdown_update = gr.Dropdown.update(
        choices=models,
        value=models[0] if len(models) > 0 else ""
    )
    return state, dropdown_update


def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


def upvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def downvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def flag_last_response(state, model_selector, request: gr.Request):
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return ("",) + (disable_btn,) * 3


def regenerate(state, image_process_mode, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    prev_human_msg = state.messages[-2]
    if type(prev_human_msg[1]) in (tuple, list):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def add_text(state, text, image, image_process_mode, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0 and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
                no_change_btn,) * 5

    text = text[:1536]  # Hard cut-off
    if image is not None:
        text = text[:1200]  # Hard cut-off for images
        if '<image>' not in text:
            # text = '<Image><image></Image>' + text
            text = text + '\n<image>'
        text = (text, image, image_process_mode)
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation: pick a template from the model name.
        if "llava" in model_name.lower():
            if 'llama-2' in model_name.lower():
                template_name = "llava_llama_2"
            elif "v1" in model_name.lower():
                if 'mmtag' in model_name.lower():
                    template_name = "v1_mmtag"
                elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
                    template_name = "v1_mmtag"
                else:
                    template_name = "llava_v1"
            elif "mpt" in model_name.lower():
                template_name = "mpt"
            else:
                if 'mmtag' in model_name.lower():
                    template_name = "v0_mmtag"
                elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
                    template_name = "v0_mmtag"
                else:
                    template_name = "llava_v0"
        elif "mpt" in model_name:
            template_name = "mpt_text"
        elif "llama-2" in model_name:
            template_name = "llama_2"
        else:
            template_name = "vicuna_v1"
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
            json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    # Construct prompt
    prompt = state.get_prompt()

    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    for image, hash in zip(all_images, all_image_hash):
        t = datetime.datetime.now()
        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
    }
    logger.info(f"==== request ====\n{pload}")

    pload['images'] = state.get_images()

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        # Stream output
        response = requests.post(worker_addr + "/worker_generate_stream",
            headers=headers, json=pload, stream=True, timeout=10)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][len(prompt):].strip()
                    state.messages[-1][-1] = output + "▌"
            +
                                yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
         | 
| 251 | 
            +
                            else:
         | 
| 252 | 
            +
                                output = data["text"] + f" (error_code: {data['error_code']})"
         | 
| 253 | 
            +
                                state.messages[-1][-1] = output
         | 
| 254 | 
            +
                                yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
         | 
| 255 | 
            +
                                return
         | 
| 256 | 
            +
                            time.sleep(0.03)
         | 
| 257 | 
            +
                except requests.exceptions.RequestException as e:
         | 
| 258 | 
            +
                    state.messages[-1][-1] = server_error_msg
         | 
| 259 | 
            +
                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
         | 
| 260 | 
            +
                    return
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                state.messages[-1][-1] = state.messages[-1][-1][:-1]
         | 
| 263 | 
            +
                yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                finish_tstamp = time.time()
         | 
| 266 | 
            +
                logger.info(f"{output}")
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                with open(get_conv_log_filename(), "a") as fout:
         | 
| 269 | 
            +
                    data = {
         | 
| 270 | 
            +
                        "tstamp": round(finish_tstamp, 4),
         | 
| 271 | 
            +
                        "type": "chat",
         | 
| 272 | 
            +
                        "model": model_name,
         | 
| 273 | 
            +
                        "start": round(start_tstamp, 4),
         | 
| 274 | 
            +
                        "finish": round(start_tstamp, 4),
         | 
| 275 | 
            +
                        "state": state.dict(),
         | 
| 276 | 
            +
                        "images": all_image_hash,
         | 
| 277 | 
            +
                        "ip": request.client.host,
         | 
| 278 | 
            +
                    }
         | 
| 279 | 
            +
                    fout.write(json.dumps(data) + "\n")
         | 
| 280 | 
            +
             | 
| 281 | 
            +
            title_markdown = ("""
         | 
| 282 | 
            +
            # 🌋 LLaVA: Large Language and Vision Assistant
         | 
| 283 | 
            +
            [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)]
         | 
| 284 | 
            +
            """)
         | 
| 285 | 
            +
             | 
| 286 | 
            +
            tos_markdown = ("""
         | 
| 287 | 
            +
            ### Terms of use
         | 
| 288 | 
            +
            By using this service, users are required to agree to the following terms:
         | 
| 289 | 
            +
            The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
         | 
| 290 | 
            +
            Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
         | 
| 291 | 
            +
            For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
         | 
| 292 | 
            +
            """)
         | 
| 293 | 
            +
             | 
| 294 | 
            +
             | 
| 295 | 
            +
            learn_more_markdown = ("""
         | 
| 296 | 
            +
            ### License
         | 
| 297 | 
            +
            The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
         | 
| 298 | 
            +
            """)
         | 
| 299 | 
            +
             | 
| 300 | 
            +
            block_css = """
         | 
| 301 | 
            +
             | 
| 302 | 
            +
            #buttons button {
         | 
| 303 | 
            +
                min-width: min(120px,100%);
         | 
| 304 | 
            +
            }
         | 
| 305 | 
            +
             | 
| 306 | 
            +
            """
         | 
| 307 | 
            +
             | 
| 308 | 
            +
            def build_demo(embed_mode):
         | 
| 309 | 
            +
                textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
         | 
| 310 | 
            +
                with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
         | 
| 311 | 
            +
                    state = gr.State()
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                    if not embed_mode:
         | 
| 314 | 
            +
                        gr.Markdown(title_markdown)
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                    with gr.Row():
         | 
| 317 | 
            +
                        with gr.Column(scale=3):
         | 
| 318 | 
            +
                            with gr.Row(elem_id="model_selector_row"):
         | 
| 319 | 
            +
                                model_selector = gr.Dropdown(
         | 
| 320 | 
            +
                                    choices=models,
         | 
| 321 | 
            +
                                    value=models[0] if len(models) > 0 else "",
         | 
| 322 | 
            +
                                    interactive=True,
         | 
| 323 | 
            +
                                    show_label=False,
         | 
| 324 | 
            +
                                    container=False)
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                            imagebox = gr.Image(type="pil")
         | 
| 327 | 
            +
                            image_process_mode = gr.Radio(
         | 
| 328 | 
            +
                                ["Crop", "Resize", "Pad", "Default"],
         | 
| 329 | 
            +
                                value="Default",
         | 
| 330 | 
            +
                                label="Preprocess for non-square image", visible=False)
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                            cur_dir = os.path.dirname(os.path.abspath(__file__))
         | 
| 333 | 
            +
                            gr.Examples(examples=[
         | 
| 334 | 
            +
                                [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
         | 
| 335 | 
            +
                                [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
         | 
| 336 | 
            +
                            ], inputs=[imagebox, textbox])
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                            with gr.Accordion("Parameters", open=False) as parameter_row:
         | 
| 339 | 
            +
                                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
         | 
| 340 | 
            +
                                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
         | 
| 341 | 
            +
                                max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                        with gr.Column(scale=8):
         | 
| 344 | 
            +
                            chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
         | 
| 345 | 
            +
                            with gr.Row():
         | 
| 346 | 
            +
                                with gr.Column(scale=8):
         | 
| 347 | 
            +
                                    textbox.render()
         | 
| 348 | 
            +
                                with gr.Column(scale=1, min_width=50):
         | 
| 349 | 
            +
                                    submit_btn = gr.Button(value="Send", variant="primary")
         | 
| 350 | 
            +
                            with gr.Row(elem_id="buttons") as button_row:
         | 
| 351 | 
            +
                                upvote_btn = gr.Button(value="👍  Upvote", interactive=False)
         | 
| 352 | 
            +
                                downvote_btn = gr.Button(value="👎  Downvote", interactive=False)
         | 
| 353 | 
            +
                                flag_btn = gr.Button(value="⚠️  Flag", interactive=False)
         | 
| 354 | 
            +
                                #stop_btn = gr.Button(value="⏹️  Stop Generation", interactive=False)
         | 
| 355 | 
            +
                                regenerate_btn = gr.Button(value="🔄  Regenerate", interactive=False)
         | 
| 356 | 
            +
                                clear_btn = gr.Button(value="🗑️  Clear", interactive=False)
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                    if not embed_mode:
         | 
| 359 | 
            +
                        gr.Markdown(tos_markdown)
         | 
| 360 | 
            +
                        gr.Markdown(learn_more_markdown)
         | 
| 361 | 
            +
                    url_params = gr.JSON(visible=False)
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                    # Register listeners
         | 
| 364 | 
            +
                    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
         | 
| 365 | 
            +
                    upvote_btn.click(upvote_last_response,
         | 
| 366 | 
            +
                        [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
         | 
| 367 | 
            +
                    downvote_btn.click(downvote_last_response,
         | 
| 368 | 
            +
                        [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
         | 
| 369 | 
            +
                    flag_btn.click(flag_last_response,
         | 
| 370 | 
            +
                        [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
         | 
| 371 | 
            +
                    regenerate_btn.click(regenerate, [state, image_process_mode],
         | 
| 372 | 
            +
                        [state, chatbot, textbox, imagebox] + btn_list).then(
         | 
| 373 | 
            +
                        http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
         | 
| 374 | 
            +
                        [state, chatbot] + btn_list)
         | 
| 375 | 
            +
                    clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                    textbox.submit(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
         | 
| 378 | 
            +
                        ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
         | 
| 379 | 
            +
                               [state, chatbot] + btn_list)
         | 
| 380 | 
            +
                    submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
         | 
| 381 | 
            +
                        ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
         | 
| 382 | 
            +
                               [state, chatbot] + btn_list)
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                    if args.model_list_mode == "once":
         | 
| 385 | 
            +
                        demo.load(load_demo, [url_params], [state, model_selector],
         | 
| 386 | 
            +
                            _js=get_window_url_params)
         | 
| 387 | 
            +
                    elif args.model_list_mode == "reload":
         | 
| 388 | 
            +
                        demo.load(load_demo_refresh_model_list, None, [state, model_selector])
         | 
| 389 | 
            +
                    else:
         | 
| 390 | 
            +
                        raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                return demo
         | 
| 393 | 
            +
             | 
| 394 | 
            +
             | 
| 395 | 
            +
            if __name__ == "__main__":
         | 
| 396 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 397 | 
            +
                parser.add_argument("--host", type=str, default="0.0.0.0")
         | 
| 398 | 
            +
                parser.add_argument("--port", type=int)
         | 
| 399 | 
            +
                parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
         | 
| 400 | 
            +
                parser.add_argument("--concurrency-count", type=int, default=10)
         | 
| 401 | 
            +
                parser.add_argument("--model-list-mode", type=str, default="once",
         | 
| 402 | 
            +
                    choices=["once", "reload"])
         | 
| 403 | 
            +
                parser.add_argument("--share", action="store_true")
         | 
| 404 | 
            +
                parser.add_argument("--moderate", action="store_true")
         | 
| 405 | 
            +
                parser.add_argument("--embed", action="store_true")
         | 
| 406 | 
            +
                args = parser.parse_args()
         | 
| 407 | 
            +
                logger.info(f"args: {args}")
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                models = get_model_list()
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                logger.info(args)
         | 
| 412 | 
            +
                demo = build_demo(args.embed)
         | 
| 413 | 
            +
                demo.queue(
         | 
| 414 | 
            +
                    concurrency_count=args.concurrency_count,
         | 
| 415 | 
            +
                    api_open=False
         | 
| 416 | 
            +
                ).launch(
         | 
| 417 | 
            +
                    server_name=args.host,
         | 
| 418 | 
            +
                    server_port=args.port,
         | 
| 419 | 
            +
                    share=args.share
         | 
| 420 | 
            +
                )
         | 
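For reference, a minimal standalone client for the NUL-delimited streaming protocol that http_bot consumes (and that model_worker.py below produces) might look like the following sketch. The worker address, model name, and prompt here are illustrative placeholders, not values taken from this repo:

import json
import requests

worker_addr = "http://localhost:21002"   # assumed default worker address
prompt = "USER: Hello! ASSISTANT:"       # illustrative; the demo builds prompts via state.get_prompt()
pload = {
    "model": "llava-v1.5-13b",           # hypothetical model name
    "prompt": prompt,
    "temperature": 0.2,
    "top_p": 0.7,
    "max_new_tokens": 64,
    "stop": "</s>",
    "images": [],
}
response = requests.post(worker_addr + "/worker_generate_stream",
                         json=pload, stream=True, timeout=10)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            # Each chunk carries the full decoded text so far, prompt included,
            # so strip the prompt prefix the same way http_bot does.
            print(data["text"][len(prompt):].strip())
        else:
            print(f"error_code: {data['error_code']}")
            break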
    	
llava/serve/model_worker.py
ADDED
@@ -0,0 +1,285 @@
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
from functools import partial

from llava.constants import WORKER_HEART_BEAT_INTERVAL
from llava.utils import (build_logger, server_error_msg,
    pretty_print_semaphore)
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from transformers import TextIteratorStreamer
from threading import Thread


GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0

model_semaphore = None


def heart_beat_worker(controller):

    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()


class ModelWorker:
    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 load_8bit, load_4bit, device):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
        self.is_multimodal = 'llava' in self.model_name.lower()

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        if model_semaphore is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @torch.inference_mode()
    def generate_stream(self, params):
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        num_image_tokens = 0
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) > 0:
                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                    raise ValueError("Number of images does not match number of <image> tokens in prompt")

                images = [load_image_from_base64(image) for image in images]
                images = process_images(images, image_processor, model.config)

                if type(images) is list:
                    images = [image.to(self.model.device, dtype=torch.float16) for image in images]
                else:
                    images = images.to(self.model.device, dtype=torch.float16)

                replace_token = DEFAULT_IMAGE_TOKEN
                if getattr(self.model.config, 'mm_use_im_start_end', False):
                    replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

                num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
            else:
                images = None
            image_args = {"images": images}
        else:
            images = None
            image_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = True if temperature > 0.001 else False

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        if max_new_tokens < 1:
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
            return

        thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            stopping_criteria=[stopping_criteria],
            use_cache=True,
            **image_args
        ))
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            if generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"


app = FastAPI()


def release_model_semaphore(fn=None):
    model_semaphore.release()
    if fn is not None:
        fn()


@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_get_status")
async def get_status(request: Request):
    return worker.get_status()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
        default="http://localhost:21001")
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    if args.multi_modal:
        logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")

    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_base,
                         args.model_name,
                         args.load_8bit,
                         args.load_4bit,
                         args.device)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
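As a quick sanity check, the /worker_get_status endpoint registered above can be probed directly. A minimal sketch, assuming a worker is listening on the default port 21002:

import requests

# Returns the same payload as ModelWorker.get_status():
# model_names, speed, and the current queue_length.
ret = requests.post("http://localhost:21002/worker_get_status", timeout=5)
status = ret.json()
print(status["model_names"], status["speed"], status["queue_length"])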
    	
llava/serve/register_worker.py
ADDED
@@ -0,0 +1,26 @@
"""
Manually register workers.

Usage:
python3 -m llava.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
"""

import argparse

import requests

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-address", type=str)
    parser.add_argument("--worker-name", type=str)
    parser.add_argument("--check-heart-beat", action="store_true")
    args = parser.parse_args()

    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": args.worker_name,
        "check_heart_beat": args.check_heart_beat,
        "worker_status": None,
    }
    r = requests.post(url, json=data)
    assert r.status_code == 200
    	
llava/serve/test_message.py
ADDED
@@ -0,0 +1,62 @@
import argparse
import json

import requests

from llava.conversation import default_conversation


def main():
    if args.worker_address:
        worker_addr = args.worker_address
    else:
        controller_addr = args.controller_address
        ret = requests.post(controller_addr + "/refresh_all_workers")
        ret = requests.post(controller_addr + "/list_models")
        models = ret.json()["models"]
        models.sort()
        print(f"Models: {models}")

        ret = requests.post(controller_addr + "/get_worker_address",
            json={"model": args.model_name})
        worker_addr = ret.json()["address"]
        print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        return

    conv = default_conversation.copy()
    conv.append_message(conv.roles[0], args.message)
    prompt = conv.get_prompt()

    headers = {"User-Agent": "LLaVA Client"}
    pload = {
        "model": args.model_name,
        "prompt": prompt,
        "max_new_tokens": args.max_new_tokens,
        "temperature": 0.7,
        "stop": conv.sep,
    }
    response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
            json=pload, stream=True)

    print(prompt.replace(conv.sep, "\n"), end="")
    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"].split(conv.sep)[-1]
            print(output, end="\r")
    print("")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
    parser.add_argument("--worker-address", type=str)
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--max-new-tokens", type=int, default=32)
    parser.add_argument("--message", type=str, default=
        "Tell me a story with more than 1000 words.")
    args = parser.parse_args()

    main()
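Assuming a controller and a worker are already running on their default ports, this client can be exercised end-to-end with, for example, `python3 -m llava.serve.test_message --model-name <your-model> --message "Hello"`, where the model name must match one registered with the controller. Note that `print(output, end="\r")` rewrites the same console line on every streamed chunk, so only the final state of the answer remains visible when streaming ends.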
