import gc
import time
import uuid
from threading import Thread
from types import MethodType
from typing import Iterable, Dict, Any

import torch
from transformers import (
    TextIteratorStreamer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from api.generation.qwen import check_is_qwen
from api.generation.utils import (
    prepare_logits_processor,
    is_partial_stop,
    apply_stopping_strings,
)


@torch.inference_mode()  # generation never needs gradients; avoids autograd memory overhead
def generate_stream(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    params: Dict[str, Any],
):
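    """
    Stream completion chunks for a pre-tokenized prompt using manual
    token-by-token decoding.

    ``params`` carries ``inputs`` (prompt token ids) plus the sampling options
    read below. Yields OpenAI-style ``text_completion`` dicts containing both
    the incremental ``delta`` and the full ``text`` generated so far.
    """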
    # Read parameters
    input_ids = params.get("inputs")
    prompt = params.get("prompt")
    model_name = params.get("model", "llm")
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_tokens", 256))
    logprobs = params.get("logprobs")
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop")

    stop_token_ids = params.get("stop_token_ids") or []
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )
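    # prepare_logits_processor (from api.generation.utils) is assumed to return
    # a transformers LogitsProcessorList configured for the temperature,
    # repetition-penalty, top-p and top-k settings above; it is applied by hand
    # to the last-position logits inside the decoding loop below.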
    output_ids = list(input_ids)
    input_echo_len = len(input_ids)

    device = model.device
    if model.config.is_encoder_decoder:
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )
    else:
        start_ids = torch.as_tensor([input_ids], device=device)
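    # Encoder-decoder models encode the prompt once and start decoding from
    # decoder_start_token_id; decoder-only models are prefilled with the full
    # prompt on the first loop iteration below.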
    past_key_values, sent_interrupt = None, False
    token_logprobs = [None]  # The first token has no logprobs.
    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())
    previous_text = ""
    for i in range(max_new_tokens):
        if i == 0:  # prefill
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values

            if logprobs is not None:
                # Prefill logprobs for the prompt.
                shift_input_ids = start_ids[..., 1:].contiguous()
                shift_logits = logits[..., :-1, :].contiguous()
                shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
                for label_id, logit in zip(
                    shift_input_ids[0].tolist(), shift_logits[0]
                ):
                    token_logprobs.append(logit[label_id])
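                # These prompt logprobs are teacher-forced: the logits at
                # position t score the actual prompt token at position t + 1,
                # which is what the `echo` + `logprobs` combination is expected
                # to report.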
        else:  # decoding
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=torch.as_tensor(
                        [output_ids if sent_interrupt else [token]], device=device
                    ),
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=None if sent_interrupt else past_key_values,
                )
                sent_interrupt = False

                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=torch.as_tensor(
                        [output_ids if sent_interrupt else [token]], device=device
                    ),
                    use_cache=True,
                    past_key_values=None if sent_interrupt else past_key_values,
                )
                sent_interrupt = False
                logits = out.logits
            past_key_values = out.past_key_values
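        # After prefill, each step feeds only the newly sampled token and reuses
        # past_key_values. sent_interrupt appears to be a hook for upstream
        # interrupt handling: if it were set, the whole output_ids sequence
        # would be re-fed with a fresh cache.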
        if logits_processor:
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        if device.type == "mps":
            # Move to CPU to avoid some bugs in the mps backend.
            last_token_logits = last_token_logits.float().to("cpu")
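        # Near-zero temperature or top_p is treated as greedy decoding;
        # otherwise sample from the processed distribution. Two candidates are
        # drawn in both branches, but only the first is consumed below.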
        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            _, indices = torch.topk(last_token_logits, 2)
            tokens = [int(index) for index in indices.tolist()]
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            indices = torch.multinomial(probs, num_samples=2)
            tokens = [int(t) for t in indices.tolist()]
        token = tokens[0]
        output_ids.append(token)

        if logprobs is not None:
            # Cannot use last_token_logits because logprobs is based on the raw logits.
            token_logprobs.append(
                torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
            )
        stopped = token in stop_token_ids

        # Yield the output tokens
        if i % 2 == 0 or i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                rfind_start = len(prompt)
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=not check_is_qwen(model),  # keep special tokens for qwen ReAct
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )
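            # The whole kept sequence is re-decoded on every emission; the
            # chunk's `delta` is produced later by diffing against
            # previous_text, and chunks ending in "�" (an incomplete multi-byte
            # sequence) are held back until the character completes.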
            ret_logprobs = None
            if logprobs is not None:
                ret_logprobs = {
                    "text_offset": [],
                    "tokens": [
                        tokenizer.decode(tok)
                        for tok in (
                            output_ids if echo else output_ids[input_echo_len:]
                        )
                    ],
                    "token_logprobs": token_logprobs if echo else token_logprobs[input_echo_len:],
                    "top_logprobs": [{}] * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
                }
                # Compute text_offset
                curr_pos = 0
                for text in ret_logprobs["tokens"]:
                    ret_logprobs["text_offset"].append(curr_pos)
                    curr_pos += len(text)
            partially_stopped, finish_reason = False, None
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                        stopped = True
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                            if each_stop == "Observation:":
                                finish_reason = "function_call"
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")
            # Prevent yielding a partial stop sequence
            if (not partially_stopped) and output and output[-1] != "�":
                delta_text = output[len(previous_text):]
                previous_text = output

                yield {
                    "id": completion_id,
                    "object": "text_completion",
                    "created": created,
                    "model": model_name,
                    "delta": delta_text,
                    "text": output,
                    "logprobs": ret_logprobs,
                    "finish_reason": finish_reason,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                }

        if stopped:
            break
    yield {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "delta": "",
        "text": output,
        "logprobs": ret_logprobs,
        "finish_reason": "stop",
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
    }

    # Clean up the KV cache and free GPU memory.
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()


def generate_stream_v2(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    params: Dict[str, Any],
):
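    """
    Stream completion chunks by delegating decoding to ``model.generate`` and
    reading decoded text from a ``TextIteratorStreamer`` fed by a background
    thread. Yields the same OpenAI-style chunk dicts as ``generate_stream``.
    """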
    input_ids = params.get("inputs")
    functions = params.get("functions")
    model_name = params.get("model", "llm")
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", 40))
    max_new_tokens = int(params.get("max_tokens", 256))

    stop_token_ids = params.get("stop_token_ids") or []
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)
    stop_strings = params.get("stop") or []

    input_echo_len = len(input_ids)
    device = model.device

    generation_kwargs = dict(
        input_ids=torch.tensor([input_ids], device=device),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.pad_token_id,
    )
    if temperature <= 1e-5:
        # Near-zero temperature means greedy decoding; drop sampling-only kwargs.
        generation_kwargs["do_sample"] = False
        generation_kwargs.pop("top_k")

    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs["streamer"] = streamer
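    # Some checkpoints loaded with trust_remote_code override `generate`; if so,
    # rebind the stock transformers GenerationMixin.generate so the kwargs
    # above, including the streamer, are honoured.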
| if "GenerationMixin" not in str(model.generate.__func__): | |
| model.generate = MethodType(PreTrainedModel.generate, model) | |
| thread = Thread(target=model.generate, kwargs=generation_kwargs) | |
| thread.start() | |
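    # Generation runs in the background thread and pushes decoded text into the
    # streamer; the loop below consumes it and re-emits OpenAI-style chunks.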
    generated_text, func_call_found = "", False
    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())
    previous_text = ""
    for i, new_text in enumerate(streamer):
        generated_text += new_text
        if functions:
            _, func_call_found = apply_stopping_strings(generated_text, ["Observation:"])
        generated_text, stop_found = apply_stopping_strings(generated_text, stop_strings)

        if generated_text and generated_text[-1] != "�":
            delta_text = generated_text[len(previous_text):]
            previous_text = generated_text

            yield {
                "id": completion_id,
                "object": "text_completion",
                "created": created,
                "model": model_name,
                "delta": delta_text,
                "text": generated_text,
                "logprobs": None,
                "finish_reason": "function_call" if func_call_found else None,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": i,
                    "total_tokens": input_echo_len + i,
                },
            }

        if stop_found:
            break
    yield {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "delta": "",
        "text": generated_text,
        "logprobs": None,
        "finish_reason": "stop",
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
    }
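

# A minimal usage sketch (not part of the original module). The checkpoint
# name, prompt and sampling settings are illustrative placeholders; in the
# real server, `params` is built from an incoming completion request.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    _name = "Qwen/Qwen-7B-Chat"  # placeholder checkpoint
    _tok = AutoTokenizer.from_pretrained(_name, trust_remote_code=True)
    _mdl = AutoModelForCausalLM.from_pretrained(
        _name, torch_dtype="auto", device_map="auto", trust_remote_code=True
    )

    _prompt = "Briefly explain what a KV cache is."
    _params = {
        "inputs": _tok(_prompt).input_ids,
        "prompt": _prompt,
        "model": _name,
        "temperature": 0.7,
        "top_p": 0.9,
        "max_tokens": 128,
        "echo": False,
    }
    for _chunk in generate_stream(_mdl, _tok, _params):
        print(_chunk["delta"], end="", flush=True)
    print()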