import os
from types import SimpleNamespace
import warnings

import torch

# These flags are read by the rwkv package at import time, so they must be
# set before `rwkv.model` is imported.
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "1"

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

class RwkvModel:
    """Wrap an RWKV checkpoint behind a minimal HuggingFace-style interface."""

    def __init__(self, model_path):
        warnings.warn(
            "Experimental support. Please use ChatRWKV if you want to chat with RWKV"
        )
        self.config = SimpleNamespace(is_encoder_decoder=False)
        # Strategy format: "<device> <dtype>"; multi-device splits are
        # possible, as in the two-GPU example below.
        self.model = RWKV(model=model_path, strategy="cuda fp16")
        # two GPUs
        # self.model = RWKV(model=model_path, strategy="cuda:0 fp16 *20 -> cuda:1 fp16")
        self.tokenizer = None
        self.model_path = model_path
    def to(self, target):
        # No-op: the strategy string above already places the weights on GPU.
        assert target == "cuda"
    def __call__(self, input_ids, use_cache, past_key_values=None):
        assert use_cache
        input_ids = input_ids[0].detach().cpu().numpy()
        # `state` is RWKV's recurrent state, returned in place of a KV cache.
        logits, state = self.model.forward(input_ids, past_key_values)
        # Reshape to (batch, seq, vocab) so callers can treat the output like
        # that of a HuggingFace causal LM.
        logits = logits.unsqueeze(0).unsqueeze(0)
        out = SimpleNamespace(logits=logits, past_key_values=state)
        return out
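
    # A sketch (not in the original file) of the stateful decode loop this
    # interface enables; `prompt_ids` and `max_new_tokens` are assumed inputs:
    #   state, ids = None, prompt_ids[:]
    #   for _ in range(max_new_tokens):
    #       out = model(torch.tensor([ids if state is None else ids[-1:]]),
    #                   use_cache=True, past_key_values=state)
    #       ids.append(int(out.logits[0, -1].argmax()))
    #       state = out.past_key_values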
    def generate(
        self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0
    ):
        # This function is used by fastchat.llm_judge.
        # Because RWKV does not support the huggingface generation API,
        # we reuse fastchat.serve.inference.generate_stream as a workaround.
        from transformers import AutoTokenizer

        from fastchat.serve.inference import generate_stream
        from fastchat.conversation import get_conv_template

        if self.tokenizer is None:
            # Pythia shares the GPT-NeoX tokenizer used by the RWKV-4 Pile models.
            self.tokenizer = AutoTokenizer.from_pretrained(
                "EleutherAI/pythia-160m", use_fast=True
            )
        prompt = self.tokenizer.decode(input_ids[0].tolist())
        conv = get_conv_template("rwkv")

        gen_params = {
            "model": self.model_path,
            "prompt": prompt,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens,
            "stop": conv.stop_str,
            "stop_token_ids": conv.stop_token_ids,
            "echo": False,
        }
        res_iter = generate_stream(self, self.tokenizer, gen_params, "cuda")

        # Drain the stream; the final item holds the full generated text.
        for res in res_iter:
            pass

        output = res["text"]
        output_ids = self.tokenizer.encode(output)

        return [input_ids[0].tolist() + output_ids]
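

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). The checkpoint path is a
# placeholder, and the tokenizer mirrors the one generate() loads internally.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m", use_fast=True)
    model = RwkvModel("path/to/rwkv-4-checkpoint.pth")  # hypothetical path
    model.to("cuda")

    input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
    output_ids = model.generate(
        input_ids,
        do_sample=True,
        temperature=0.7,
        max_new_tokens=64,
    )[0]
    print(tokenizer.decode(output_ids))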