Spaces:
Runtime error
Runtime error
| import datasets | |
| import faiss | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| from datasets import Dataset | |
| from transformers import FeatureExtractionPipeline, pipeline | |
| def load_encoder_pipeline(encoder_path: str) -> FeatureExtractionPipeline: | |
| """訓練済みの教師なしSimCSEのエンコーダを読み込む""" | |
| encoder_pipeline = pipeline("feature-extraction", model=encoder_path) | |
| return encoder_pipeline | |
| def load_dataset(dataset_dir: str) -> Dataset: | |
| """文埋め込み適用済みのデータセットを読み込み、Faissのインデックスを構築""" | |
| # ディスクに保存されたデータセットを読み込む | |
| dataset = datasets.load_from_disk(dataset_dir) | |
| # データセットの"embeddings"フィールドの値からFaissのインデックスを構築する | |
| emb_dim = len(dataset[0]["embeddings"]) | |
| index = faiss.IndexFlatIP(emb_dim) | |
| dataset.add_faiss_index("embeddings", custom_index=index) | |
| return dataset | |
| def embed_text( | |
| text: str, encoder_pipeline: FeatureExtractionPipeline | |
| ) -> np.ndarray: | |
| """教師なしSimCSEのエンコーダを用いてテキストの埋め込みを計算""" | |
| with torch.inference_mode(): | |
| # encoder_pipelineが返すTensorのsizeは(1, トークン数, 埋め込みの次元数) | |
| encoded_text = encoder_pipeline(text, return_tensors="pt")[0][0] | |
| # ベクトルをNumPyのarrayに変換 | |
| emb = encoded_text.cpu().numpy().astype(np.float32) | |
| # ベクトルのノルムが1になるように正規化 | |
| emb = emb / np.linalg.norm(emb) | |
| return emb | |
| def search_similar_texts( | |
| query_text: str, | |
| dataset: Dataset, | |
| encoder_pipeline: FeatureExtractionPipeline, | |
| k: int = 5, | |
| ) -> list[dict[str, float | str]]: | |
| """モデルとデータセットを用いてクエリの類似文検索を実行""" | |
| # クエリに対して類似テキストをk件取得する | |
| scores, retrieved_examples = dataset.get_nearest_examples( | |
| "embeddings", embed_text(query_text, encoder_pipeline), k=k | |
| ) | |
| titles = retrieved_examples["title"] | |
| texts = retrieved_examples["text"] | |
| # 検索された類似テキストをdictのlistにして返す | |
| results = [ | |
| {"score": score, "title": title, "text": text} | |
| for score, title, text in zip(scores, titles, texts) | |
| ] | |
| return results | |
| # 訓練済みの教師なしSimCSEのモデルを読み込む | |
| encoder_pipeline = load_encoder_pipeline("outputs_unsup_simcse/encoder") | |
| # 文埋め込み適用済みのデータセットを読み込む | |
| dataset = load_dataset("outputs_unsup_simcse/embedded_paragraphs") | |
| # デモページのタイトルを表示する | |
| st.title(":mag: Wikipedia Paragraph Search") | |
| # デモページのフォームを表示する | |
| with st.form("input_form"): | |
| # クエリの入力欄を表示し、入力された値を受け取る | |
| query_text = st.text_input( | |
| "クエリを入力:", value="日本語は、主に日本で話されている言語である。", max_chars=150 | |
| ) | |
| # 検索する段落数のスライダーを表示し、設定された値を受け取る | |
| k = st.slider("検索する段落数:", min_value=1, max_value=100, value=10) | |
| # 検索を実行するボタンを表示し、押下されたらTrueを受け取る | |
| is_submitted = st.form_submit_button("Search") | |
| # 検索結果を表示する | |
| if is_submitted and len(query_text) > 0: | |
| # クエリに対して類似文検索を実行し、検索結果を受け取る | |
| serach_results = search_similar_texts( | |
| query_text, dataset, encoder_pipeline, k=k | |
| ) | |
| # 検索結果を表示する | |
| st.subheader("検索結果") | |
| st.dataframe(serach_results, use_container_width=True) | |
| st.caption("セルのダブルクリックで全体が表示されます") | |