import datetime
import string
import time

import arxiv
import colorlog
import nltk
import torch
from nltk.corpus import stopwords

nltk.download('stopwords')
stop_words = stopwords.words('english')

fmt_string = '%(log_color)s %(asctime)s - %(levelname)s - %(message)s'
log_colors = {
    'DEBUG': 'white',
    'INFO': 'green',
    'WARNING': 'yellow',
    'ERROR': 'red',
    'CRITICAL': 'purple'
}
colorlog.basicConfig(log_colors=log_colors, format=fmt_string, level=colorlog.INFO)
logger = colorlog.getLogger(__name__)
logger.setLevel(colorlog.INFO)

def get_md_text_abstract(rag_answer, source='Semantic Search', return_prompt_formatting=False):
    """Format a retrieved paper (semantic-search hit or live arXiv result) as Markdown."""
    if 'Semantic Search' in source:
        title = rag_answer['document_metadata']['title'].replace('\n', '')
        # score = round(rag_answer['score'], 2)
        date = rag_answer['document_metadata']['_time']
        paper_abs = rag_answer['content']
        authors = rag_answer['document_metadata']['authors'].replace('\n', '')
        doc_id = rag_answer['document_id']
        paper_link = f'https://arxiv.org/abs/{doc_id}'
        download_link = f'https://arxiv.org/pdf/{doc_id}'
    elif 'Arxiv' in source:
        title = rag_answer.title
        date = rag_answer.updated.strftime('%d %b %Y')
        paper_abs = rag_answer.summary.replace('\n', ' ') + '\n'
        authors = ', '.join([author.name for author in rag_answer.authors])
        paper_link = rag_answer.links[0].href
        download_link = rag_answer.links[1].href
    else:
        raise ValueError(f"Unknown source: {source!r}. Expected 'Semantic Search' or 'Arxiv Search'.")

    paper_title = f'### {date} | [{title}]({paper_link}) | [⬇️]({download_link})\n'
    authors_formatted = f'*{authors}*' + ' \n\n'
    md_text_formatted = paper_title + authors_formatted + paper_abs + '\n---------------\n' + '\n'

    if return_prompt_formatting:
        doc = {
            'title': title,
            'text': paper_abs
        }
        return md_text_formatted, doc
    return md_text_formatted

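# Usage sketch (illustrative only; the exact rag_answer shape depends on the retriever
# this Space pairs it with). A semantic-search hit is assumed to be a dict exposing
# 'content', 'document_id', and 'document_metadata' (with 'title', 'authors', '_time'),
# while an arXiv hit is an arxiv.Result object:
#
#   md = get_md_text_abstract(semantic_hit, source='Semantic Search')
#   md, doc = get_md_text_abstract(arxiv_result, source='Arxiv Search',
#                                  return_prompt_formatting=True)
#   # doc == {'title': ..., 'text': ...} is the shape make_doc_prompt() below expects.
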
def remove_punctuation(text):
    punct_str = string.punctuation
    punct_str = punct_str.replace("'", "")
    return text.translate(str.maketrans("", "", punct_str))


def remove_stopwords(text):
    text = ' '.join(word for word in text.split(' ') if word not in stop_words)
    return text


def search_cleaner(text):
    new_text = text.lower()
    new_text = remove_stopwords(new_text)
    new_text = remove_punctuation(new_text)
    return new_text

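# Quick sanity check (sketch): search_cleaner() lowercases, drops NLTK English stopwords,
# then strips punctuation (apostrophes are kept), e.g.
#
#   search_cleaner("What are the best methods for video segmentation?")
#   # -> "best methods video segmentation"
#   # (the exact output depends on the installed NLTK stopword list)
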
# Restrict live arXiv searches to the core AI/ML categories.
q = '(cat:cs.CV OR cat:cs.LG OR cat:cs.CL OR cat:cs.AI OR cat:cs.NE OR cat:cs.RO)'


def get_arxiv_live_search(query, client, max_results=10):
    clean_text = search_cleaner(query)
    search = arxiv.Search(
        query=clean_text + " AND " + q,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance
    )
    results = client.results(search)
    all_results = list(results)
    return all_results

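# Usage sketch (assumes the arxiv package's Client API, which the caller is expected
# to construct and pass in):
#
#   client = arxiv.Client()
#   papers = get_arxiv_live_search("best methods for video segmentation", client, max_results=5)
#   print(get_md_text_abstract(papers[0], source='Arxiv Search'))
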
def make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=None):
    # For doc prompt:
    # - {ID}: doc id (starting from 1)
    # - {T}: title
    # - {P}: text
    # use_shorter: None, "summary", or "extraction"
    text = doc['text']
    if use_shorter is not None:
        text = doc[use_shorter]
    return doc_prompt.replace("{T}", doc["title"]).replace("{P}", text).replace("{ID}", str(doc_id + 1))

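# Illustrative only (the template string is an assumption, not taken from this repo's
# prompt files): with doc_prompt = "Document [{ID}] (Title: {T}): {P}\n",
#   make_doc_prompt({'title': 'Some Paper', 'text': 'Abstract...'}, 0, doc_prompt)
# renders "Document [1] (Title: Some Paper): Abstract...\n".
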
def get_shorter_text(item, docs, ndoc, key):
    doc_list = []
    for doc_idx, doc in enumerate(docs):
        if key not in doc:
            if len(doc_list) == 0:
                # If there is no usable document at all, at least provide one (using the full text).
                doc[key] = doc['text']
                doc_list.append(doc)
            logger.warning(f"No {key} found in document. Either this data does not contain {key} or the previous documents are not relevant. This is document {doc_idx}. This question will only have {len(doc_list)} documents.")
            break
        if "irrelevant" in doc[key] or "Irrelevant" in doc[key]:
            continue
        doc_list.append(doc)
        if len(doc_list) >= ndoc:
            break
    return doc_list

def make_demo(item, prompt, ndoc=None, doc_prompt=None, instruction=None, use_shorter=None, test=False):
    # For demo prompt
    # - {INST}: the instruction
    # - {D}: the documents
    # - {Q}: the question
    # - {A}: the answers
    # ndoc: number of documents to put in context
    # use_shorter: None, "summary", or "extraction"
    prompt = prompt.replace("{INST}", instruction).replace("{Q}", item['question'])
    if "{D}" in prompt:
        if ndoc == 0:
            prompt = prompt.replace("{D}\n", "")  # if there is no doc we also delete the empty line
        else:
            doc_list = get_shorter_text(item, item["docs"], ndoc, use_shorter) if use_shorter is not None else item["docs"][:ndoc]
            text = "".join([make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=use_shorter) for doc_id, doc in enumerate(doc_list)])
            prompt = prompt.replace("{D}", text)

    if not test:
        answer = "\n" + "\n".join(item["answer"]) if isinstance(item["answer"], list) else item["answer"]
        prompt = prompt.replace("{A}", "").rstrip() + answer
    else:
        prompt = prompt.replace("{A}", "").rstrip()  # remove any space or \n

    return prompt

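# Sketch of how the pieces fit together (the template strings below are assumptions,
# not taken from this repo's prompt files):
#
#   instruction = "Answer the question using the provided documents and cite them as [1], [2], ..."
#   prompt_template = "{INST}\n\n{D}\n\nQuestion: {Q}\nAnswer: {A}"
#   doc_template = "Document [{ID}] (Title: {T}): {P}\n"
#   item = {"question": "What is retrieval-augmented generation?", "docs": [doc], "answer": ""}
#   llm_prompt = make_demo(item, prompt_template, ndoc=1, doc_prompt=doc_template,
#                          instruction=instruction, test=True)
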
def load_llama_guard(model_id="meta-llama/Llama-Guard-3-1B"):
    from transformers import AutoTokenizer, AutoModelForCausalLM

    dtype = torch.bfloat16
    logger.info("Loading Llama Guard")
    llama_guard_tokenizer = AutoTokenizer.from_pretrained(model_id)
    llama_guard = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map="cuda")
    # Get the id of the "unsafe" token; it is later used to extract that token's probability.
    UNSAFE_TOKEN_ID = llama_guard_tokenizer.convert_tokens_to_ids("unsafe")
    return llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID

def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):
    prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    # Append the whitespace the model would otherwise generate first,
    # so the next predicted token is directly "safe" or "unsafe".
    prompt += "\n\n"
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        output_logits=True,  # also return per-step logits
    )

    ######
    # Get the generated text
    ######
    # Number of tokens that correspond to the input prompt
    input_length = inputs.input_ids.shape[1]
    # Ignore the input tokens to keep only the tokens generated by the model
    generated_token_ids = outputs.sequences[:, input_length:].cpu()
    generated_text = tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)

    ######
    # Get the probability of the "unsafe" token
    ######
    # The first generated token is either "safe" or "unsafe";
    # use its logits to compute the probabilities.
    first_token_logits = outputs.logits[0]
    first_token_probs = torch.softmax(first_token_logits, dim=-1)
    # From the probabilities over all tokens, extract the one for the "unsafe" token.
    unsafe_probability = first_token_probs[0, UNSAFE_TOKEN_ID]
    unsafe_probability = unsafe_probability.item()

    ######
    # Result
    ######
    return {
        "unsafe_score": unsafe_probability,
        "generated_text": generated_text
    }

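# Usage sketch (the plain role/content chat format is an assumption; adjust it to
# whatever structure your Llama Guard tokenizer's chat template expects):
#
#   llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID = load_llama_guard()
#   result = moderate(
#       [{"role": "user", "content": "How do I make a cake?"}],
#       llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID,
#   )
#   # result["generated_text"] starts with "safe" or "unsafe";
#   # result["unsafe_score"] is the probability assigned to the "unsafe" token.
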
def get_max_memory():
    """Get the maximum memory available on each GPU for loading models."""
    free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
    max_memory = f'{free_in_GB - 1}GB'
    n_gpus = torch.cuda.device_count()
    max_memory = {i: max_memory for i in range(n_gpus)}
    return max_memory

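# Example of the returned mapping on a hypothetical 2-GPU machine with ~24 GB free
# per device (1 GB held back as headroom): {0: '23GB', 1: '23GB'}. Note that the
# free-memory probe only queries the current device, so every GPU gets the same budget.
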
def load_model(model_name_or_path, dtype=torch.bfloat16, int8=False):
    # Load a Hugging Face model and tokenizer.
    # dtype: torch.float16 or torch.bfloat16
    # int8: whether to use int8 quantization
    # (get_max_memory() reserves ~1 GB of headroom per GPU.)
    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info(f"Loading {model_name_or_path} in {dtype}...")
    if int8:
        logger.warning("Using LLM.int8")
    start_time = time.time()
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map='auto',
        torch_dtype=dtype,
        max_memory=get_max_memory(),
        load_in_8bit=int8,
    )
    logger.info("Finished loading in %.2f sec." % (time.time() - start_time))

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    tokenizer.padding_side = "left"

    return model, tokenizer

def load_vllm(model_name_or_path, dtype=torch.bfloat16):
    from vllm import LLM, SamplingParams

    logger.info(f"Loading {model_name_or_path} in {dtype}...")
    start_time = time.time()
    model = LLM(
        model_name_or_path,
        dtype=dtype,
        gpu_memory_utilization=0.9,
        max_seq_len_to_capture=2048,
        max_model_len=8192,
    )
    sampling_params = SamplingParams(temperature=0.1, top_p=1.00, max_tokens=300)
    logger.info("Finished loading in %.2f sec." % (time.time() - start_time))

    # Load the tokenizer
    tokenizer = model.get_tokenizer()
    tokenizer.padding_side = "left"

    return model, tokenizer, sampling_params

class LLM:
    def __init__(self, model_name_or_path, use_vllm=True):
        self.use_vllm = use_vllm
        if use_vllm:
            self.chat_llm, self.tokenizer, self.sampling_params = load_vllm(model_name_or_path)
        else:
            self.chat_llm, self.tokenizer = load_model(model_name_or_path)
        self.prompt_exceed_max_length = 0
        self.fewer_than_50 = 0

    def generate(self, prompt, max_tokens=300, stop=None):
        if max_tokens <= 0:
            self.prompt_exceed_max_length += 1
            logger.warning("Prompt exceeds max length; returning an empty string as the answer. If this happens too often, consider making the prompt shorter.")
            return ""
        if max_tokens < 50:
            self.fewer_than_50 += 1
            logger.warning("The model can generate fewer than 50 tokens for this prompt. If this happens too often, consider making the prompt shorter.")

        if self.use_vllm:
            inputs = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False)
            self.sampling_params.n = 1  # number of output sequences to return for the given prompt
            self.sampling_params.stop_token_ids = [self.chat_llm.llm_engine.get_model_config().hf_config.eos_token_id]
            self.sampling_params.max_tokens = max_tokens
            output = self.chat_llm.generate(
                inputs,
                self.sampling_params,
                use_tqdm=True,
            )
            generation = output[0].outputs[0].text.strip()
        else:
            inputs = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, return_dict=True, return_tensors="pt").to(self.chat_llm.device)
            outputs = self.chat_llm.generate(
                **inputs,
                do_sample=True, temperature=0.1, top_p=1.0,
                max_new_tokens=max_tokens,
                num_return_sequences=1,
                eos_token_id=[self.chat_llm.config.eos_token_id]
            )
            generation = self.tokenizer.decode(outputs[0][inputs['input_ids'].size(1):], skip_special_tokens=True).strip()

        return generation
