import torch
import nltk
from scipy.io.wavfile import write
import librosa
import hashlib
from typing import List


def embed_questions(
    question_model, question_tokenizer, questions, max_length=128, device="cpu"
):
    """Encode a batch of questions with a DPR question encoder and return pooled embeddings."""
    query = question_tokenizer(
        questions,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        q_reps = question_model(
            query["input_ids"].to(device), query["attention_mask"].to(device)
        ).pooler_output
    return q_reps.cpu().numpy()
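

# Illustrative usage sketch, not part of the original app wiring: `embed_questions`
# is written against a Hugging Face Transformers DPR question encoder. The checkpoint
# name below is an assumed default, not necessarily the one this Space uses.
def _example_embed_questions():
    from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

    name = "facebook/dpr-question_encoder-single-nq-base"  # assumed checkpoint
    q_model = DPRQuestionEncoder.from_pretrained(name)
    q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(name)
    # Returns a (num_questions, 768) numpy array of pooled question embeddings.
    return embed_questions(q_model, q_tokenizer, ["¿Qué es la fotosíntesis?"])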


def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cpu"):
    """Encode a batch of passages (a dict with a "text" list) with a DPR context encoder."""
    p = ctx_tokenizer(
        passages["text"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        a_reps = ctx_model(
            p["input_ids"].to(device), p["attention_mask"].to(device)
        ).pooler_output
    return {"embeddings": a_reps.cpu().numpy()}


class Document:
    """Minimal document wrapper (content, metadata, id) in the shape Haystack expects."""

    def __init__(self, meta=None, content: str = "", id_: str = ""):
        # Avoid a shared mutable default for the metadata dict.
        self.meta = meta if meta is not None else {}
        self.content = content
        self.id = id_


def _alter_docs_for_haystack(passages):
    return [Document(content=passage, id_=str(i)) for i, passage in enumerate(passages)]


def embed_passages_haystack(
    dpr_model,
    passages,
):
    passages = _alter_docs_for_haystack(passages["text"])
    embeddings = dpr_model.embed_documents(passages)
    return {"embeddings": embeddings}


def correct_casing(input_sentence):
    """Capitalize the first character of each sentence in the transcribed text."""
    sentences = nltk.sent_tokenize(input_sentence)
    return " ".join(s[:1].upper() + s[1:] for s in sentences)


def clean_transcript(text):
    return text.replace("[PAD]", "")


def add_question_symbols(text):
    if not text:
        return text
    if text[0] != "¿":
        text = "¿" + text
    if text[-1] != "?":
        text = text + "?"
    return text


def remove_chars_to_tts(text):
    return text.replace(",", " ")


def transcript(input_file, audio_array, processor, model):
    """Transcribe an audio file (or an in-memory (rate, samples) tuple) with a Wav2Vec2 CTC model."""
    if audio_array is not None:
        # Microphone-style input: a (sample_rate, samples) tuple written to a temporary file.
        rate, sample = audio_array
        write("temp.wav", rate, sample)
        input_file = "temp.wav"
    transcript_text = ""
    # Read the native sample rate; chunks are resampled to 16 kHz below if needed.
    sample_rate = librosa.get_samplerate(input_file)
    # Stream over 20-second chunks rather than loading the full file into memory.
    stream = librosa.stream(
        input_file,
        block_length=20,  # number of one-second frames per block
        frame_length=sample_rate,  # one second of samples per frame
        hop_length=sample_rate,  # non-overlapping frames
    )
    for speech in stream:
        if len(speech.shape) > 1:
            # Downmix stereo to mono by summing the two channels.
            speech = speech[:, 0] + speech[:, 1]
        if sample_rate != 16000:
            speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
        input_values = processor(speech, return_tensors="pt").input_values
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.decode(
            predicted_ids[0],
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        transcription = clean_transcript(transcription)
        transcript_text += correct_casing(transcription.lower()) + ". "
    # Keep the text within the downstream model's input limit and phrase it as a question.
    whole_text = transcript_text[:3800]
    whole_text = add_question_symbols(whole_text)
    return whole_text
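

# Illustrative usage sketch, not part of the original app wiring: `transcript` expects
# a CTC-style Wav2Vec2 processor/model pair. The Spanish checkpoint named below is an
# assumption; any Spanish wav2vec2 CTC checkpoint from the Hub should work the same way.
def _example_transcription():
    from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

    name = "facebook/wav2vec2-large-xlsr-53-spanish"  # assumed checkpoint
    processor = Wav2Vec2Processor.from_pretrained(name)
    model = Wav2Vec2ForCTC.from_pretrained(name)
    # Transcribe from a WAV file path; audio_array=None means no in-memory audio.
    return transcript("question.wav", None, processor, model)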


def parse_final_answer(answer_text: str, contexts: List):
    """Parse the final answer and its supporting contexts into HTML for display."""
    answer = f"<p><b>{answer_text}</b></p> \n\n\n"
    docs = "\n".join(
        [
            ("""<p style="text-align: justify;">""" + context)[:250] + "[...]</p>"
            for context in contexts[:5]
        ]
    )
    return answer, docs
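

# Illustrative usage sketch, not part of the original app wiring: the answer is wrapped
# in bold HTML and up to five supporting passages are truncated for display.
def _example_answer_formatting():
    answer_html, docs_html = parse_final_answer(
        "La fotosíntesis convierte la luz en energía química.",
        [
            "La fotosíntesis es el proceso por el cual las plantas ...",
            "Otro pasaje de contexto recuperado ...",
        ],
    )
    return answer_html, docs_html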