Charles Chan
commited on
Commit
·
86b4310
1
Parent(s):
908d31d
coding
Browse files
app.py
CHANGED
|
@@ -7,8 +7,9 @@ import random
|
|
| 7 |
|
| 8 |
# 使用 進擊的巨人 数据集
|
| 9 |
try:
|
|
|
|
| 10 |
dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
|
| 11 |
-
answer_list = [example["Answer"] for example in dataset["train"]]
|
| 12 |
|
| 13 |
except Exception as e:
|
| 14 |
st.error(f"读取数据集失败:{e}")
|
|
@@ -67,14 +68,18 @@ with col2:
|
|
| 67 |
temperature = st.number_input("temperature", value=1.0)
|
| 68 |
max_length = st.number_input("max_length", value=1024)
|
| 69 |
|
|
|
|
|
|
|
| 70 |
col3, col4 = st.columns(2)
|
| 71 |
with col3:
|
| 72 |
-
if st.button("
|
| 73 |
dataset_size = len(dataset["train"])
|
| 74 |
random_index = random.randint(0, dataset_size - 1)
|
| 75 |
# 读取随机问题
|
| 76 |
random_question = dataset["train"][random_index]["Question"]
|
|
|
|
| 77 |
origin_answer = dataset["train"][random_index]["Answer"]
|
|
|
|
| 78 |
print('[]' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
|
| 79 |
print('origin_answer: ' + origin_answer)
|
| 80 |
|
|
@@ -90,7 +95,7 @@ with col3:
|
|
| 90 |
|
| 91 |
with col4:
|
| 92 |
question = st.text_area("请输入问题", "Gemma 有哪些特点?")
|
| 93 |
-
if st.button("
|
| 94 |
if not question:
|
| 95 |
st.warning("请输入问题!")
|
| 96 |
else:
|
|
|
|
| 7 |
|
| 8 |
# 使用 進擊的巨人 数据集
|
| 9 |
try:
|
| 10 |
+
converter = pipeline("translation_zh_tw_zh_cn")
|
| 11 |
dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
|
| 12 |
+
answer_list = [converter(example["Answer"])[0]["translation_text"] for example in dataset["train"]]
|
| 13 |
|
| 14 |
except Exception as e:
|
| 15 |
st.error(f"读取数据集失败:{e}")
|
|
|
|
| 68 |
temperature = st.number_input("temperature", value=1.0)
|
| 69 |
max_length = st.number_input("max_length", value=1024)
|
| 70 |
|
| 71 |
+
st.divider()
|
| 72 |
+
|
| 73 |
col3, col4 = st.columns(2)
|
| 74 |
with col3:
|
| 75 |
+
if st.button("使用原数据集中的随机问题"):
|
| 76 |
dataset_size = len(dataset["train"])
|
| 77 |
random_index = random.randint(0, dataset_size - 1)
|
| 78 |
# 读取随机问题
|
| 79 |
random_question = dataset["train"][random_index]["Question"]
|
| 80 |
+
random_question = converter(random_question)[0]["translation_text"]
|
| 81 |
origin_answer = dataset["train"][random_index]["Answer"]
|
| 82 |
+
origin_answer = converter(origin_answer)[0]["translation_text"]
|
| 83 |
print('[]' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
|
| 84 |
print('origin_answer: ' + origin_answer)
|
| 85 |
|
|
|
|
| 95 |
|
| 96 |
with col4:
|
| 97 |
question = st.text_area("请输入问题", "Gemma 有哪些特点?")
|
| 98 |
+
if st.button("提交输入的问题"):
|
| 99 |
if not question:
|
| 100 |
st.warning("请输入问题!")
|
| 101 |
else:
|