Charles Chan
commited on
Commit
·
10b5e55
1
Parent(s):
4b8fc6b
coding
Browse files
app.py
CHANGED
|
@@ -51,7 +51,6 @@ if not st.session_state.vector_created:
|
|
| 51 |
st.stop()
|
| 52 |
st.session_state.vector_created = True
|
| 53 |
|
| 54 |
-
# 问答函数
|
| 55 |
if "repo_id" not in st.session_state:
|
| 56 |
st.session_state.repo_id = ''
|
| 57 |
if "temperature" not in st.session_state:
|
|
@@ -64,10 +63,9 @@ def get_answer(prompt):
|
|
| 64 |
# 去掉 prompt 的内容
|
| 65 |
answer = answer.replace(prompt, "").strip()
|
| 66 |
print(answer)
|
| 67 |
-
st.success("答案已经生成!")
|
| 68 |
-
print("答案已经生成!")
|
| 69 |
return answer
|
| 70 |
|
|
|
|
| 71 |
def answer_question(repo_id, temperature, max_length, question):
|
| 72 |
# 初始化 Gemma 模型
|
| 73 |
if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
|
|
@@ -85,6 +83,10 @@ def answer_question(repo_id, temperature, max_length, question):
|
|
| 85 |
|
| 86 |
# 获取答案
|
| 87 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
with st.spinner("正在筛选本地数据集..."):
|
| 89 |
question_embedding = st.session_state.embeddings.embed_query(question)
|
| 90 |
question_embedding_str = " ".join(map(str, question_embedding))
|
|
@@ -101,8 +103,9 @@ def answer_question(repo_id, temperature, max_length, question):
|
|
| 101 |
print("本地数据集筛选完成!")
|
| 102 |
|
| 103 |
with st.spinner("正在生成答案..."):
|
| 104 |
-
pure_answer = get_answer(question)
|
| 105 |
answer = get_answer(prompt)
|
|
|
|
|
|
|
| 106 |
return {"prompt": prompt, "answer": answer, "pure_answer": pure_answer}
|
| 107 |
except Exception as e:
|
| 108 |
st.error(f"问答过程出错:{e}")
|
|
|
|
| 51 |
st.stop()
|
| 52 |
st.session_state.vector_created = True
|
| 53 |
|
|
|
|
| 54 |
if "repo_id" not in st.session_state:
|
| 55 |
st.session_state.repo_id = ''
|
| 56 |
if "temperature" not in st.session_state:
|
|
|
|
| 63 |
# 去掉 prompt 的内容
|
| 64 |
answer = answer.replace(prompt, "").strip()
|
| 65 |
print(answer)
|
|
|
|
|
|
|
| 66 |
return answer
|
| 67 |
|
| 68 |
+
# 问答函数
|
| 69 |
def answer_question(repo_id, temperature, max_length, question):
|
| 70 |
# 初始化 Gemma 模型
|
| 71 |
if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
|
|
|
|
| 83 |
|
| 84 |
# 获取答案
|
| 85 |
try:
|
| 86 |
+
with st.spinner("正在生成答案(基于模型自身)..."):
|
| 87 |
+
pure_answer = get_answer(question)
|
| 88 |
+
st.success("答案生成完毕(基于模型自身)!")
|
| 89 |
+
print("答案生成完毕(基于模型自身)!")
|
| 90 |
with st.spinner("正在筛选本地数据集..."):
|
| 91 |
question_embedding = st.session_state.embeddings.embed_query(question)
|
| 92 |
question_embedding_str = " ".join(map(str, question_embedding))
|
|
|
|
| 103 |
print("本地数据集筛选完成!")
|
| 104 |
|
| 105 |
with st.spinner("正在生成答案..."):
|
|
|
|
| 106 |
answer = get_answer(prompt)
|
| 107 |
+
st.success("答案生成完毕!")
|
| 108 |
+
print("答案生成完毕!")
|
| 109 |
return {"prompt": prompt, "answer": answer, "pure_answer": pure_answer}
|
| 110 |
except Exception as e:
|
| 111 |
st.error(f"问答过程出错:{e}")
|