Spaces:
Sleeping
Sleeping
File size: 4,775 Bytes
8570037 b63078f e3a3cf7 43769dc b63078f e3a3cf7 b63078f e3a3cf7 b63078f e3a3cf7 b63078f e3a3cf7 b63078f e3a3cf7 b63078f e3a3cf7 b63078f e3a3cf7 8570037 e3a3cf7 0d9db4d e3a3cf7 b63078f e3a3cf7 8570037 43769dc e3a3cf7 0d9db4d e3a3cf7 0d9db4d e3a3cf7 0d9db4d 8570037 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import gradio as gr
from datasets import load_dataset
import random
import re
def load_random_conversations():
    """
    Pick two random, completely unused conversations from the dataset.

    Loads the train split of 'gaeunseo/Taskmaster_sample_data', buckets its
    rows by ``conversation_id`` and drops every conversation in which any row
    has a truthy ``used`` value. Two of the remaining conversations are drawn
    at random, and each one is flattened into a single string by joining its
    ``utterance`` fields with newlines.

    Returns a pair of strings. When fewer than two unused conversations are
    available, both elements are the notice "Not enough unused conversations".
    """
    train_rows = load_dataset("gaeunseo/Taskmaster_sample_data")["train"]

    # Bucket rows per conversation, keeping dataset order inside each bucket.
    by_conversation = {}
    for record in train_rows:
        key = record["conversation_id"]
        if key not in by_conversation:
            by_conversation[key] = []
        by_conversation[key].append(record)

    # A conversation qualifies only when none of its rows is flagged as used.
    unused = []
    for rows in by_conversation.values():
        if not any(record["used"] for record in rows):
            unused.append(rows)

    if len(unused) < 2:
        notice = "Not enough unused conversations"
        return notice, notice

    # Draw two distinct conversations and flatten each into one string.
    picked = random.sample(unused, 2)
    conv_A, conv_B = (
        "\n".join(record["utterance"] for record in rows) for rows in picked
    )
    return conv_A, conv_B
def format_chat(conv_text):
    """
    Convert a raw conversation string into gr.Chatbot message pairs.

    ``conv_text`` is split on the literal ``[turn]`` and ``[BC]`` tokens.
    Segments are stripped and empty ones are dropped. The remaining utterances
    are assumed to alternate strictly: first human, then AI, and so on.

    Returns a list of ``(human_message, ai_message)`` tuples:
    - the human message gets a trailing person emoji (U+1F9D1);
    - the AI message gets a leading robot emoji (U+1F916);
    - if the conversation ends on a human utterance, the AI side of the last
      pair is the empty string.
    """
    # Written as escapes so the markers survive any re-encoding of this file.
    human_marker = "\U0001F9D1"
    ai_marker = "\U0001F916"

    # Split on either delimiter, then strip once and discard blank segments.
    segments = re.split(r'\[turn\]|\[BC\]', conv_text)
    utterances = [text for text in map(str.strip, segments) if text]

    chat = []
    # Walk the utterances two at a time: even index = human, odd index = AI.
    for i in range(0, len(utterances), 2):
        human = utterances[i] + " " + human_marker
        if i + 1 < len(utterances):
            ai = ai_marker + " " + utterances[i + 1]
        else:
            # Unpaired trailing human turn: no AI reply to show.
            ai = ""
        chat.append((human, ai))
    return chat
def load_and_format_conversations():
    """
    Fetch two random conversations and convert them for the chat widgets.

    Calls load_random_conversations() and passes each returned string through
    format_chat().

    Returns a pair of lists of ``(user_message, bot_message)`` tuples, one per
    chat component. If load_random_conversations() reports that there are not
    enough unused conversations, each list holds a single bot-only entry
    carrying that notice.
    """
    conv_A, conv_B = load_random_conversations()

    if conv_A.startswith("Not enough"):
        # The outputs are gr.Chatbot components, which take a list of
        # (user, bot) pairs rather than a bare string. Wrap the notice as a
        # bot-only message; None leaves the user side empty.
        return [(None, conv_A)], [(None, conv_B)]

    return format_chat(conv_A), format_chat(conv_B)
# Module-level holder for the rating code chosen by the latest button click.
statement = ""


def update_statement(val):
    """Store ``val`` in the global ``statement`` and return the stored value."""
    global statement
    statement = val
    return val
with gr.Blocks() as demo:
    # Top: two chat panes side by side (Conversation A and Conversation B).
    with gr.Row():
        chat_A = gr.Chatbot(label="Conversation A")
        chat_B = gr.Chatbot(label="Conversation B")

    # Clicking this button pulls two conversations from the dataset and
    # displays them in the chat panes.
    load_btn = gr.Button("Load Random Conversations")
    load_btn.click(
        fn=load_and_format_conversations,
        inputs=[],
        outputs=[chat_A, chat_B],
    )

    # Bottom: the four rating buttons.
    with gr.Row():
        btn_both_good = gr.Button("Both good")     # -> "GG"
        btn_a_better = gr.Button("A is better")    # -> "AG"
        btn_b_better = gr.Button("B is better")    # -> "BG"
        btn_both_bad = gr.Button("Both not good")  # -> "BB"

    # Read-only textbox showing the currently selected rating code.
    statement_output = gr.Textbox(
        label="Selected Statement", value="", interactive=False
    )

    # Each button stores its own rating code. The codes must be unique per
    # button: "Both good" uses "GG" so that it cannot be confused with "BG"
    # ("B is better").
    # NOTE(review): confirm "GG" with whatever consumes these codes.
    btn_both_good.click(
        fn=lambda: update_statement("GG"), inputs=[], outputs=statement_output
    )
    btn_a_better.click(
        fn=lambda: update_statement("AG"), inputs=[], outputs=statement_output
    )
    btn_b_better.click(
        fn=lambda: update_statement("BG"), inputs=[], outputs=statement_output
    )
    btn_both_bad.click(
        fn=lambda: update_statement("BB"), inputs=[], outputs=statement_output
    )

# Start the server only when run as a script, so importing this module
# (e.g. for tests) does not launch the app.
if __name__ == "__main__":
    demo.launch()
|