comment-categorization-app / streamlit_app.py
kerokero-keroppi's picture
Update streamlit_app.py
66743bd verified
import streamlit as st
import pandas as pd
from sentence_transformers import SentenceTransformer
import torch
import torch.nn.functional as F
import io
import json
# --- 1. デフォルトの階層カテゴリ定義 ---
# ユーザーが編集可能なデフォルトの定義
DEFAULT_HIERARCHICAL_CATEGORIES = {
"UI/UX (使いやすさ)": {
"description": "デザインや操作性など、サービスの使いやすさに関する意見。",
"sub_categories": {
"デザインが良い・悪い": "見た目のデザイン、レイアウト、色使いなどに関する意見。",
"操作が分かりにくい": "ボタンの場所が分からない、設定方法が難しいなど、操作方法に関する意見。",
"動作が重い・遅い": "ページの読み込みが遅い、アプリがフリーズするなど、システムの応答速度に関する意見。",
"その他": "上記以外の使いやすさに関する意見。"
}
},
"機能": {
"description": "サービスが提供する個別の機能に関する意見。",
"sub_categories": {
"機能への要望": "「〇〇という機能が欲しい」といった、新しい機能の追加を求める意見。",
"機能が便利・役に立った": "特定の機能が使いやすかった、役に立ったというポジティブな意見。",
"不具合・バグ報告": "「ボタンが押せない」「エラーが出る」など、機能が正常に動作しない問題の報告。",
"その他": "上記以外の機能に関する意見。"
}
},
"サポート": {
"description": "問い合わせ対応など、カスタマーサポートに関する意見。",
"sub_categories": {
"対応が良かった・悪かった": "サポート担当者の対応が丁寧だった、あるいは不親切だったという意見。",
"返信が来ない・遅い": "問い合わせても返信がない、または返信が非常に遅いことへの不満。",
"その他": "上記以外のサポートに関する意見。"
}
},
"価格・料金": {
"description": "サービスの料金プランや価格設定に関する意見。",
"sub_categories": {
"価格が高い・安い": "価格設定が内容に見合って高い、あるいは安いと感じる意見。",
"料金体系が分かりにくい": "料金プランが複雑で理解しにくいという意見。",
"その他": "上記以外の価格に関する意見。"
}
},
"その他・要望": {
"description": "上記のいずれにも当てはまらない意見や、サービス全体への要望。",
"sub_categories": {
"感謝・応援": "サービス全体への感謝や、応援するポジティブな意見。",
"アイデア提案": "具体的な改善案や、新しいサービスのアイデアに関する提案。",
"その他": "上記のいずれにも明確に当てはまらない、その他の意見。"
}
}
}
# --- 2. AIモデルの読み込み ---
@st.cache_resource
def load_model():
"""SentenceTransformerモデルをHugging Face Hubから直接ロードする"""
try:
model = SentenceTransformer('cl-nagoya/ruri-v3-310m')
st.success("✅ AIモデルの読み込みに成功しました。")
return model
except Exception as e:
st.error(f"モデルの読み込み中にエラーが発生しました: {e}")
st.stop()
# --- 3. 分類ロジック ---
# 3-1. 単一カテゴリ分類(既存のロジック)
def classify_text(text, model, definitions_dict):
"""与えられたテキストを、定義辞書に基づいて最も類似度の高いカテゴリに分類する"""
if not text or not isinstance(text, str) or not text.strip():
return "テキストが空です"
try:
category_definitions = list(definitions_dict.values())
texts_to_encode = category_definitions + [text]
embeddings = model.encode(texts_to_encode, convert_to_tensor=True)
text_embedding = embeddings[-1]
definition_embeddings = embeddings[:-1]
similarities = F.cosine_similarity(text_embedding, definition_embeddings)
most_similar_index = torch.argmax(similarities).item()
result_category = list(definitions_dict.keys())[most_similar_index]
return result_category
except Exception as e:
st.error(f"分類中にエラーが発生しました: {e}")
return "分類エラー"
# 3-2. 階層カテゴリ分類(新規追加)
def classify_subcategory(comment, main_category, model, hierarchical_defs, threshold=0.6):
"""埋め込みモデルの類似度計算によってサブカテゴリを分類する"""
if main_category not in hierarchical_defs or "sub_categories" not in hierarchical_defs[main_category]:
return "サブカテゴリなし"
sub_category_dict = hierarchical_defs[main_category]["sub_categories"]
if not sub_category_dict:
return "サブカテゴリなし"
definitions = [f"{name}: {desc}" for name, desc in sub_category_dict.items()]
embeddings = model.encode([comment] + definitions, convert_to_tensor=True)
comment_embedding = embeddings[0]
definition_embeddings = embeddings[1:]
similarities = F.cosine_similarity(comment_embedding, definition_embeddings)
best_match_index = torch.argmax(similarities).item()
best_match_score = similarities[best_match_index].item()
if best_match_score < threshold:
return "その他"
sub_category_names = list(sub_category_dict.keys())
return sub_category_names[best_match_index]
def analyze_hierarchically(comment, model, hierarchical_defs, sentiment_labels):
"""センチメント分析、大カテゴリ分類、サブカテゴリ分類を順番に行う"""
if not comment or not isinstance(comment, str) or not comment.strip():
return {"sentiment": "エラー", "category": "コメントが空", "sub_category": ""}
try:
# 1. センチメント分析
sentiment_texts = sentiment_labels + [comment]
sentiment_embeddings = model.encode(sentiment_texts, convert_to_tensor=True)
sentiment_similarities = F.cosine_similarity(sentiment_embeddings[-1], sentiment_embeddings[:-1])
sentiment = "ポジティブ" if torch.argmax(sentiment_similarities) == 0 else "ネガティブ"
# 2. 大カテゴリ分類
category_descriptions = [v["description"] for v in hierarchical_defs.values()]
category_texts = category_descriptions + [comment]
category_embeddings = model.encode(category_texts, convert_to_tensor=True)
category_similarities = F.cosine_similarity(category_embeddings[-1], category_embeddings[:-1])
main_category = list(hierarchical_defs.keys())[torch.argmax(category_similarities)]
# 3. サブカテゴリ分類
sub_category = classify_subcategory(comment, main_category, model, hierarchical_defs, threshold=0.6)
return {"sentiment": sentiment, "category": main_category, "sub_category": sub_category}
except Exception as e:
st.error(f"階層分析中にエラーが発生しました: {e}")
return {"sentiment": "分析エラー", "category": "分析エラー", "sub_category": str(e)}
# --- 4. Streamlit アプリケーションのUIとメイン処理 ---
st.set_page_config(layout="wide")
st.title("📝 テキスト分類ツール (単一/階層)")
st.markdown("アップロードしたファイルのテキストを、定義に基づいて分類します。")
# --- モデル読み込み ---
with st.spinner('AIモデルを読み込んでいます...'):
model = load_model()
# --- 分析タイプ選択 ---
analysis_type = st.radio(
"分析の種類を選択してください",
('単一カテゴリ分析', '階層カテゴリ分析'),
horizontal=True,
help="「単一カテゴリ分析」は複数の観点で分類します。「階層カテゴリ分析」は大カテゴリ→サブカテゴリの2段階で詳細に分類します。"
)
st.header("Step 1: ファイルをアップロード")
uploaded_file = st.file_uploader("分類したいテキストデータが含まれるファイルを選択してください", type=["csv", "xlsx"])
# --- 列選択 ---
selected_column = None
if uploaded_file is not None:
try:
# ファイルをメモリに読み込んでから列名を取得
uploaded_file.seek(0)
if uploaded_file.name.endswith('.csv'):
df_peek = pd.read_csv(uploaded_file, nrows=0)
else:
df_peek = pd.read_excel(uploaded_file, nrows=0)
column_options = df_peek.columns.tolist()
selected_column = st.selectbox("分類したいテキストが含まれている列を選択してください", options=column_options)
except Exception as e:
st.error(f"ファイルの列読み込みに失敗しました: {e}")
st.header("Step 2: カテゴリを定義")
# --- 分析タイプに応じたUIの表示 ---
if analysis_type == '単一カテゴリ分析':
st.markdown("分類したい観点ごとにカテゴリ定義を追加・編集してください。(最大5つまで)")
if 'definition_sets' not in st.session_state:
st.session_state.definition_sets = [
"""ポジティブ: 肯定的、好意的、賞賛、感謝の意見
ネガティブ: 否定的、批判的、不満、改善要望の意見
質問: 何かに対する疑問や問いかけ"""
]
def add_definition_set():
if len(st.session_state.definition_sets) < 5:
st.session_state.definition_sets.append("カテゴリ名1: 説明1\nカテゴリ名2: 説明2")
def remove_definition_set(index):
if len(st.session_state.definition_sets) > 1:
st.session_state.definition_sets.pop(index)
for i, def_text in enumerate(st.session_state.definition_sets):
st.subheader(f"分類セット {i+1}")
st.session_state.definition_sets[i] = st.text_area(
f"カテゴリ定義 {i+1}", value=def_text, height=120, key=f"def_area_{i}"
)
if len(st.session_state.definition_sets) > 1:
st.button(f"分類セット {i+1} を削除", key=f"remove_btn_{i}", on_click=remove_definition_set, args=(i,))
if len(st.session_state.definition_sets) < 5:
st.button("+ カテゴリ定義を追加", on_click=add_definition_set)
else: # 階層カテゴリ分析
st.markdown("大カテゴリとサブカテゴリの関係をJSON形式で定義してください。")
default_defs_str = json.dumps(DEFAULT_HIERARCHICAL_CATEGORIES, indent=4, ensure_ascii=False)
hierarchical_defs_text = st.text_area(
"階層カテゴリ定義(JSON形式)",
value=default_defs_str,
height=500
)
st.header("Step 3: 分類を実行")
if st.button("分類を実行する", type="primary"):
if uploaded_file is None or selected_column is None:
st.warning("Step 1でファイルと列を正しく設定してください。")
else:
# --- 分類処理の実行 ---
try:
uploaded_file.seek(0)
if uploaded_file.name.endswith('.csv'):
df = pd.read_csv(uploaded_file)
else:
df = pd.read_excel(uploaded_file)
progress_bar = st.progress(0, text="分類処理を開始します...")
total_rows = len(df)
if analysis_type == '単一カテゴリ分析':
# --- 単一カテゴリ分析の処理 ---
definition_dicts = []
for def_text in st.session_state.definition_sets:
temp_dict = {parts[0].strip(): parts[1].strip() for line in def_text.strip().split('\n') if ':' in line and (parts := line.split(':', 1))}
if temp_dict:
definition_dicts.append(temp_dict)
if not definition_dicts:
st.error("有効なカテゴリ定義がありません。")
st.stop()
for i, row in df.iterrows():
text_to_classify = str(row[selected_column]) if pd.notna(row[selected_column]) else ""
for j, def_dict in enumerate(definition_dicts):
result_col_name = f"分類結果_{j+1}"
result = classify_text(text_to_classify, model, def_dict)
df.loc[i, result_col_name] = result
progress_bar.progress((i + 1) / total_rows, text=f"{i+1}/{total_rows} 件処理完了")
else:
# --- 階層カテゴリ分析の処理 ---
hierarchical_defs = json.loads(hierarchical_defs_text)
SENTIMENT_LABELS = ["ポジティブな意見", "ネガティブな意見"]
for i, row in df.iterrows():
comment = str(row[selected_column]) if pd.notna(row[selected_column]) else ""
result = analyze_hierarchically(comment, model, hierarchical_defs, SENTIMENT_LABELS)
df.loc[i, 'センチメント'] = result['sentiment']
df.loc[i, '大カテゴリ'] = result['category']
df.loc[i, 'サブカテゴリ'] = result['sub_category']
progress_bar.progress((i + 1) / total_rows, text=f"{i+1}/{total_rows} 件処理完了")
st.success("🎉 分類が完了しました!")
st.subheader("分類結果プレビュー")
st.dataframe(df.head())
# --- ダウンロード処理 ---
output = io.BytesIO()
with pd.ExcelWriter(output, engine='openpyxl') as writer:
df.to_excel(writer, index=False, sheet_name='classified_data')
processed_data = output.getvalue()
base_filename = uploaded_file.name.rsplit('.', 1)[0]
st.download_button(
label="📁 分類済みExcelをダウンロード",
data=processed_data,
file_name=f"classified_{base_filename}.xlsx",
mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
except json.JSONDecodeError:
st.error("JSON形式の階層カテゴリ定義が正しくありません。構文を確認してください。")
except Exception as e:
st.error(f"処理中に予期せぬエラーが発生しました: {e}")