# Spaces: Build error
# Build error
| import os | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import DirectoryLoader,UnstructuredFileLoader | |
| from langchain_community.vectorstores import Qdrant | |
| from langchain_community.embeddings import SentenceTransformerEmbeddings | |
| from langchain_community.retrievers import BM25Retriever | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http.exceptions import ResponseHandlingException | |
| from glob import glob | |
| from llama_index.vector_stores.qdrant import QdrantVectorStore | |
| from transformers import AutoTokenizer, AutoModel | |
| from sentence_transformers import models, SentenceTransformer | |
| from langchain.embeddings.base import Embeddings | |
| from qdrant_client.models import VectorParams | |
| import torch | |
| # from llama_index import SimpleDirectoryReader, StorageContext | |
class ClinicalBertEmbeddings(Embeddings):
    """LangChain-compatible embeddings backed by ClinicalBERT.

    Tokenizes each input (truncated to 512 tokens), runs the model without
    gradients, and mean-pools the token embeddings weighted by the attention
    mask so padding tokens do not contribute.
    """

    def __init__(self, model_name: str = "medicalai/ClinicalBERT"):
        """Load tokenizer and model weights from the Hugging Face hub.

        Args:
            model_name: Hub identifier of the BERT checkpoint to use.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()  # inference only: disables dropout etc.

    def embed(self, text: str) -> list:
        """Embed a single string and return a plain list of floats.

        Fix: previously returned ``embeddings.squeeze().numpy()``. The
        LangChain ``Embeddings`` contract is ``list[float]``, and some vector
        stores JSON-serialize the vectors and choke on numpy arrays, so we
        convert with ``.tolist()``.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
        return embeddings.squeeze().tolist()

    def mean_pooling(self, model_output, attention_mask):
        """Mask-weighted mean over token embeddings (ignores padding)."""
        token_embeddings = model_output[0]  # first element: all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # clamp guards against division by zero on an all-padding input
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def embed_documents(self, texts) -> list:
        """Embed a batch of documents; one vector per input string."""
        return [self.embed(text) for text in texts]

    def embed_query(self, text: str) -> list:
        """Embed a single query string."""
        return self.embed(text)
# Instantiate the ClinicalBERT-backed embedding model defined above.
# (An earlier SentenceTransformer-based approach was abandoned in favor of
# the custom Embeddings subclass.)
embeddings = ClinicalBertEmbeddings()
print(embeddings)
# Ingest every PDF under Data/ and split it into overlapping chunks
# sized for retrieval (700 chars with 70-char overlap).
pdf_loader = DirectoryLoader(
    "Data/", glob="**/*.pdf", show_progress=True, loader_cls=UnstructuredFileLoader
)
raw_docs = pdf_loader.load()
splitter = RecursiveCharacterTextSplitter(chunk_size=700, chunk_overlap=70)
texts = splitter.split_documents(raw_docs)
print("Create Vector Database")

# Connect to the local Qdrant instance over HTTP.
url = "http://localhost:6333"
client = QdrantClient(url=url, prefer_grpc=False)
collection_name = "vector_db"

# Reuse the collection when it already exists; otherwise create it with
# 768-dim cosine vectors (768 matches ClinicalBERT's hidden size).
try:
    collection_info = client.get_collection(collection_name=collection_name)
    print(f"Collection '{collection_name}' already exists.")
except Exception as e:
    # get_collection raises when the collection is missing (it also raises
    # on connection problems, which land here too).
    print(f"Collection '{collection_name}' does not exist. Creating a new one.")
    print(f"Error: {e}")
    client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=768, distance="Cosine"),
    )
# Embed the chunks with ClinicalBERT and upsert them into the Qdrant
# collection; also build a BM25 keyword retriever over the same chunks
# so dense and sparse retrieval can be combined downstream.
qdrant = Qdrant.from_documents(
    texts,
    embeddings,
    url=url,
    prefer_grpc=False,
    collection_name=collection_name,
)

keyword_retriever = BM25Retriever.from_documents(texts)
keyword_retriever.k = 3  # top-3 keyword matches per query

print("vector database created.................")