Spaces:
Runtime error
Runtime error
| """Send a test message.""" | |
| import argparse | |
| import json | |
| import requests | |
| from fastchat.model.model_adapter import get_conversation_template | |
def main():
    """Send one prompt to a FastChat model worker and stream the reply to stdout.

    Configuration comes from the module-level ``args`` namespace built by the
    ``__main__`` block: ``model_name``, ``worker_address``,
    ``controller_address``, ``message``, ``temperature`` and
    ``max_new_tokens``.

    Worker selection:
        If ``args.worker_address`` is set, it is used directly.
        Otherwise the controller at ``args.controller_address`` is asked to
        refresh its workers. It is then asked to list its models, and finally
        to return a worker address for ``args.model_name``.

    If no worker address is obtained, a message is printed and the function
    returns without generating anything.

    Raises:
        requests.HTTPError: if the controller's ``/list_models`` or
            ``/get_worker_address`` call fails, or if the worker's
            ``/worker_generate_stream`` call returns an HTTP error status.
    """
    model_name = args.model_name

    if args.worker_address:
        worker_addr = args.worker_address
    else:
        controller_addr = args.controller_address
        # Fire-and-forget, as before: the result of the refresh is not inspected.
        requests.post(controller_addr + "/refresh_all_workers")

        ret = requests.post(controller_addr + "/list_models")
        ret.raise_for_status()
        models = sorted(ret.json()["models"])
        print(f"Models: {models}")

        ret = requests.post(
            controller_addr + "/get_worker_address", json={"model": model_name}
        )
        ret.raise_for_status()
        worker_addr = ret.json()["address"]
        print(f"worker_addr: {worker_addr}")

    # An empty (or missing) address means no worker can serve this model.
    if not worker_addr:
        print(f"No available workers for {model_name}")
        return

    conv = get_conversation_template(model_name)
    conv.append_message(conv.roles[0], args.message)
    # Empty assistant slot; presumably get_prompt() renders it as the open turn
    # the model should complete — confirm against fastchat's Conversation.
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    headers = {"User-Agent": "FastChat Client"}
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": args.temperature,
        "max_new_tokens": args.max_new_tokens,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
    }

    # Use the streamed response as a context manager so the connection is
    # released even if decoding or printing raises mid-stream.
    with requests.post(
        worker_addr + "/worker_generate_stream",
        headers=headers,
        json=gen_params,
        stream=True,
    ) as response:
        response.raise_for_status()

        print(f"{conv.roles[0]}: {args.message}")
        print(f"{conv.roles[1]}: ", end="")

        # Chunks are NUL-delimited JSON objects. Each "text" is treated as the
        # cumulative output so far, so only the suffix past `prev` is printed.
        prev = 0
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if not chunk:
                continue
            data = json.loads(chunk.decode())
            output = data["text"].strip()
            print(output[prev:], end="", flush=True)
            prev = len(output)
        print("")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Command-line options, registered in a fixed order (the order only
    # affects how they appear in --help).
    for _flag, _opts in (
        ("--controller-address", {"type": str, "default": "http://localhost:21001"}),
        ("--worker-address", {"type": str}),
        ("--model-name", {"type": str, "required": True}),
        ("--temperature", {"type": float, "default": 0.0}),
        ("--max-new-tokens", {"type": int, "default": 32}),
        (
            "--message",
            {"type": str, "default": "Tell me a story with more than 1000 words."},
        ),
    ):
        parser.add_argument(_flag, **_opts)

    # main() reads the module-global name `args`, so the parsed namespace must
    # be bound under exactly that name before calling it.
    args = parser.parse_args()
    main()