Charles Chan
commited on
Commit
·
edfb894
1
Parent(s):
c2aa18b
coding
Browse files
app.py
CHANGED
|
@@ -15,13 +15,13 @@ if "data_list" not in st.session_state:
|
|
| 15 |
if not st.session_state.data_list:
|
| 16 |
try:
|
| 17 |
with st.spinner("正在读取数据库..."):
|
| 18 |
-
|
| 19 |
dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
|
| 20 |
data_list = []
|
| 21 |
answer_list = []
|
| 22 |
for example in dataset["train"]:
|
| 23 |
-
converted_answer =
|
| 24 |
-
converted_question =
|
| 25 |
answer_list.append(converted_answer)
|
| 26 |
data_list.append({"Question": converted_question, "Answer": converted_answer})
|
| 27 |
st.session_state.answer_list = answer_list
|
|
@@ -63,7 +63,7 @@ def answer_question(repo_id, temperature, max_length, question):
|
|
| 63 |
if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
|
| 64 |
try:
|
| 65 |
with st.spinner("正在初始化 Gemma 模型..."):
|
| 66 |
-
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
|
| 67 |
st.success("Gemma 模型初始化完成!")
|
| 68 |
print("Gemma 模型初始化完成!")
|
| 69 |
st.session_state.repo_id = repo_id
|
|
@@ -91,7 +91,7 @@ def answer_question(repo_id, temperature, max_length, question):
|
|
| 91 |
print("本地数据集筛选完成!")
|
| 92 |
|
| 93 |
with st.spinner("正在生成答案..."):
|
| 94 |
-
answer = llm.invoke(prompt)
|
| 95 |
# 去掉 prompt 的内容
|
| 96 |
answer = answer.replace(prompt, "").strip()
|
| 97 |
st.success("答案已经生成!")
|
|
@@ -113,6 +113,13 @@ with col2:
|
|
| 113 |
|
| 114 |
st.divider()
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
col3, col4 = st.columns(2)
|
| 117 |
with col3:
|
| 118 |
if st.button("使用原数据集中的随机问题"):
|
|
@@ -120,9 +127,7 @@ with col3:
|
|
| 120 |
random_index = random.randint(0, dataset_size - 1)
|
| 121 |
# 读取随机问题
|
| 122 |
random_question = st.session_state.data_list[random_index]["Question"]
|
| 123 |
-
random_question = st.session_state.converter.convert(random_question)
|
| 124 |
origin_answer = st.session_state.data_list[random_index]["Answer"]
|
| 125 |
-
origin_answer = st.session_state.converter.convert(origin_answer)
|
| 126 |
print('[]' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
|
| 127 |
print('origin_answer: ' + origin_answer)
|
| 128 |
|
|
@@ -130,20 +135,12 @@ with col3:
|
|
| 130 |
st.write(random_question)
|
| 131 |
st.write("原始答案:")
|
| 132 |
st.write(origin_answer)
|
| 133 |
-
|
| 134 |
-
print('prompt: ' + result["prompt"])
|
| 135 |
-
print('answer: ' + result["answer"])
|
| 136 |
-
st.write("生成答案:")
|
| 137 |
-
st.write(result["answer"])
|
| 138 |
|
| 139 |
with col4:
|
| 140 |
-
question = st.text_area("请输入问题", "
|
| 141 |
if st.button("提交输入的问题"):
|
| 142 |
if not question:
|
| 143 |
st.warning("请输入问题!")
|
| 144 |
else:
|
| 145 |
-
|
| 146 |
-
print('prompt: ' + result["prompt"])
|
| 147 |
-
print('answer: ' + result["answer"])
|
| 148 |
-
st.write("生成答案:")
|
| 149 |
-
st.write(result["answer"])
|
|
|
|
| 15 |
if not st.session_state.data_list:
|
| 16 |
try:
|
| 17 |
with st.spinner("正在读取数据库..."):
|
| 18 |
+
converter = OpenCC('tw2s') # 'tw2s.json' 表示繁体中文到简体中文的转换
|
| 19 |
dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
|
| 20 |
data_list = []
|
| 21 |
answer_list = []
|
| 22 |
for example in dataset["train"]:
|
| 23 |
+
converted_answer = converter.convert(example["Answer"])
|
| 24 |
+
converted_question = converter.convert(example["Question"])
|
| 25 |
answer_list.append(converted_answer)
|
| 26 |
data_list.append({"Question": converted_question, "Answer": converted_answer})
|
| 27 |
st.session_state.answer_list = answer_list
|
|
|
|
| 63 |
if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
|
| 64 |
try:
|
| 65 |
with st.spinner("正在初始化 Gemma 模型..."):
|
| 66 |
+
st.session_state.llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
|
| 67 |
st.success("Gemma 模型初始化完成!")
|
| 68 |
print("Gemma 模型初始化完成!")
|
| 69 |
st.session_state.repo_id = repo_id
|
|
|
|
| 91 |
print("本地数据集筛选完成!")
|
| 92 |
|
| 93 |
with st.spinner("正在生成答案..."):
|
| 94 |
+
answer = st.session_state.llm.invoke(prompt)
|
| 95 |
# 去掉 prompt 的内容
|
| 96 |
answer = answer.replace(prompt, "").strip()
|
| 97 |
st.success("答案已经生成!")
|
|
|
|
| 113 |
|
| 114 |
st.divider()
|
| 115 |
|
| 116 |
+
def generate_answer(repo_id, temperature, max_length, question):
|
| 117 |
+
result = answer_question(repo_id, float(temperature), int(max_length), question)
|
| 118 |
+
print('prompt: ' + result["prompt"])
|
| 119 |
+
print('answer: ' + result["answer"])
|
| 120 |
+
st.write("生成答案:")
|
| 121 |
+
st.write(result["answer"])
|
| 122 |
+
|
| 123 |
col3, col4 = st.columns(2)
|
| 124 |
with col3:
|
| 125 |
if st.button("使用原数据集中的随机问题"):
|
|
|
|
| 127 |
random_index = random.randint(0, dataset_size - 1)
|
| 128 |
# 读取随机问题
|
| 129 |
random_question = st.session_state.data_list[random_index]["Question"]
|
|
|
|
| 130 |
origin_answer = st.session_state.data_list[random_index]["Answer"]
|
|
|
|
| 131 |
print('[]' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
|
| 132 |
print('origin_answer: ' + origin_answer)
|
| 133 |
|
|
|
|
| 135 |
st.write(random_question)
|
| 136 |
st.write("原始答案:")
|
| 137 |
st.write(origin_answer)
|
| 138 |
+
generate_answer(gemma, float(temperature), int(max_length), random_question)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
with col4:
|
| 141 |
+
question = st.text_area("请输入问题", "《进击的巨人》中都有哪些主要角色?")
|
| 142 |
if st.button("提交输入的问题"):
|
| 143 |
if not question:
|
| 144 |
st.warning("请输入问题!")
|
| 145 |
else:
|
| 146 |
+
generate_answer(gemma, float(temperature), int(max_length), question)
|
|
|
|
|
|
|
|
|
|
|
|