Charles Chan committed
Commit: 1e21aa9
Parent(s): a054c10
coding
app.py CHANGED
@@ -7,34 +7,55 @@ from datasets import load_dataset
 from opencc import OpenCC

 # Use the Attack on Titan (進擊的巨人) dataset
-
-
-
-
-
-
-
+# The original dataset is in Traditional Chinese; convert it to Simplified Chinese up front to make debugging easier
+if "dataset_loaded" not in st.session_state:
+    st.session_state.dataset_loaded = False
+if not st.session_state.dataset_loaded:
+    try:
+        with st.spinner("正在读取数据库..."):
+            converter = OpenCC('tw2s')  # 'tw2s.json' denotes Traditional-to-Simplified Chinese conversion
+            dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
+            answer_list = [converter.convert(example["Answer"]) for example in dataset["train"]]
+        st.success("数据库读取完成!")
+    except Exception as e:
+        st.error(f"读取数据集失败:{e}")
+        st.stop()
+    st.session_state.dataset_loaded = True

 # Build the vector database (only once, if needed)
-
-
-
-
-st.
-
-
-
+if "vector_created" not in st.session_state:
+    st.session_state.vector_created = False
+if not st.session_state.vector_created:
+    try:
+        with st.spinner("正在构建向量数据库..."):
+            embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
+            db = FAISS.from_texts(answer_list, embeddings)
+        st.success("向量数据库构建完成!")
+    except Exception as e:
+        st.error(f"向量数据库构建失败:{e}")
+        st.stop()
+    st.session_state.vector_created = True

 # Question-answering function
+if "repo_id" not in st.session_state:
+    st.session_state.repo_id = ''
+if "temperature" not in st.session_state:
+    st.session_state.temperature = ''
+if "max_length" not in st.session_state:
+    st.session_state.max_length = ''
 def answer_question(repo_id, temperature, max_length, question):
     # Initialize the Gemma model
-
-
-
-
-
-
-
+    if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
+        try:
+            with st.spinner("正在初始化 Gemma 模型..."):
+                llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
+            st.success("Gemma 模型初始化完成!")
+            st.session_state.repo_id = repo_id
+            st.session_state.temperature = temperature
+            st.session_state.max_length = max_length
+        except Exception as e:
+            st.error(f"Gemma 模型加载失败:{e}")
+            st.stop()

     # Get the answer
     try:
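For readers who want to exercise the dataset step on its own, here is a minimal sketch of the loading-and-conversion logic from the hunk, runnable outside Streamlit. It assumes the `datasets` and OpenCC Python packages are installed; the dataset name and the "Answer" field come from the diff above.

# Sketch: load the Attack on Titan wiki dataset and convert the answers
# from Traditional to Simplified Chinese.
# Assumes: pip install datasets opencc-python-reimplemented
from datasets import load_dataset
from opencc import OpenCC

converter = OpenCC('tw2s')  # Traditional Chinese (Taiwan) -> Simplified Chinese
dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
answer_list = [converter.convert(example["Answer"]) for example in dataset["train"]]
print(f"{len(answer_list)} answers converted; first one: {answer_list[0][:40]}")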
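The vector-database block relies on LangChain's FAISS wrapper. Below is a standalone sketch of the same call plus a similarity query; the import statements sit above this hunk and are not shown, so the langchain_community module paths and the faiss-cpu / sentence-transformers dependencies are assumptions.

# Sketch: build a FAISS index over the converted answers and query it.
# Assumes: pip install langchain-community sentence-transformers faiss-cpu
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
db = FAISS.from_texts(answer_list, embeddings)  # answer_list from the previous sketch

docs = db.similarity_search("Eren Yeager", k=3)  # illustrative query; returns the 3 nearest answers
for doc in docs:
    print(doc.page_content)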
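The commit caches the Gemma client in st.session_state and rebuilds it only when repo_id, temperature, or max_length changes. A hypothetical alternative sketch of the same idea uses Streamlit's st.cache_resource, which keys the cached object on the function arguments; the HuggingFaceHub import path, the example repo id, and the parameter values are assumptions, and a HUGGINGFACEHUB_API_TOKEN environment variable is required for the hub call.

# Sketch: one cached LLM client per (repo_id, temperature, max_length) combination.
import streamlit as st
from langchain_community.llms import HuggingFaceHub  # assumed import path

@st.cache_resource  # rebuilt only when the arguments change; otherwise the cached client is returned
def get_llm(repo_id: str, temperature: float, max_length: int):
    return HuggingFaceHub(
        repo_id=repo_id,
        model_kwargs={"temperature": temperature, "max_length": max_length},
    )

llm = get_llm("google/gemma-7b-it", 0.7, 512)  # illustrative values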