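# Interface2 / app.py
# Source: Hugging Face Space file view (uploaded by gaeunseo).
# Latest commit: "Update app.py" (e3a3cf7, verified); file size 4.78 kB.
# (Page navigation links "raw" and "history blame" omitted from this header.)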
import gradio as gr
from datasets import load_dataset
import random
import re
def load_random_conversations():
    """Pick two random, not-yet-used conversations from the sample dataset.

    Loads the ``train`` split of ``gaeunseo/Taskmaster_sample_data``, buckets
    its rows by ``conversation_id``, and keeps only the conversations in which
    every row has a falsy ``used`` value. Two of those are drawn at random.

    Returns:
        A pair of strings, each holding one conversation's ``utterance``
        values joined with newlines. When fewer than two unused
        conversations exist, both elements are the message
        ``"Not enough unused conversations"``.
    """
    train_rows = load_dataset("gaeunseo/Taskmaster_sample_data")["train"]

    # Bucket rows by conversation, keeping dataset order inside each bucket.
    by_conversation = {}
    for record in train_rows:
        conversation_id = record["conversation_id"]
        if conversation_id not in by_conversation:
            by_conversation[conversation_id] = []
        by_conversation[conversation_id].append(record)

    # A conversation is eligible only when none of its rows is marked used.
    unused = []
    for records in by_conversation.values():
        if not any(record["used"] for record in records):
            unused.append(records)

    if len(unused) < 2:
        message = "Not enough unused conversations"
        return message, message

    # Draw two distinct conversations and flatten each into one string.
    first, second = random.sample(unused, 2)
    return (
        "\n".join(record["utterance"] for record in first),
        "\n".join(record["utterance"] for record in second),
    )
def format_chat(conv_text):
    """Convert a raw conversation string into ``gr.Chatbot`` message pairs.

    The text is split on the literal ``[turn]`` and ``[BC]`` tokens; each
    piece is stripped and empty pieces are dropped. The remaining utterances
    are assumed to alternate strictly: human, AI, human, AI, ...

    Args:
        conv_text: Conversation text containing ``[turn]`` / ``[BC]``
            separators.

    Returns:
        A list of ``(human_message, ai_message)`` tuples. Human messages get
        a trailing person emoji, AI messages get a leading robot emoji. When
        the number of utterances is odd, the final tuple's AI message is the
        empty string.
    """
    # The emoji are written as escapes so the literals survive any
    # re-encoding of this file (the previous literals were mojibake).
    human_suffix = " \U0001F9D1"  # person emoji, appended to human turns
    ai_prefix = "\U0001F916 "  # robot emoji, prepended to AI turns

    stripped = (piece.strip() for piece in re.split(r'\[turn\]|\[BC\]', conv_text))
    utterances = [piece for piece in stripped if piece]

    # Even positions are human turns, odd positions are the AI replies.
    human_turns = utterances[0::2]
    ai_turns = utterances[1::2]

    chat = []
    for index, human_text in enumerate(human_turns):
        if index < len(ai_turns):
            ai_message = ai_prefix + ai_turns[index]
        else:
            # Trailing human turn without a reply.
            ai_message = ""
        chat.append((human_text + human_suffix, ai_message))
    return chat
def load_and_format_conversations():
    """Load two random conversations and shape them for the two chat panes.

    Calls ``load_random_conversations()`` and passes each conversation
    through ``format_chat()``.

    Returns:
        A pair of lists of ``(user_message, ai_message)`` tuples, one per
        ``gr.Chatbot`` output. When the loader reports that there are not
        enough unused conversations, each list contains a single AI-side
        entry carrying that message, so the value is still a valid chat
        history rather than a bare string.
    """
    conv_A, conv_B = load_random_conversations()

    # The loader signals failure by returning its error message in place of
    # the conversation text.
    # NOTE(review): a real conversation that begins with "Not enough" would
    # also match this check; confirm that cannot occur in the dataset.
    if conv_A.startswith("Not enough"):
        # Wrap the message as a chat history; a None user slot leaves the
        # user side of the pair empty.
        return [(None, conv_A)], [(None, conv_B)]

    return format_chat(conv_A), format_chat(conv_B)
# ํ‰๊ฐ€ ๋ฒ„ํŠผ ํด๋ฆญ ์‹œ ์—…๋ฐ์ดํŠธํ•  ์ „์—ญ ๋ณ€์ˆ˜
# Module-level holder for the most recently selected rating code.
# NOTE(review): as a module global this looks shared by every session served
# by this process; confirm that is acceptable for the deployment.
statement = ""


def update_statement(val):
    """Store ``val`` as the current rating code and hand it back.

    Args:
        val: The rating code chosen by the evaluator.

    Returns:
        The value that was just stored in the module-level ``statement``.
    """
    global statement
    statement = val
    return val
with gr.Blocks() as demo:
    # Top row: the two conversations shown side by side.
    with gr.Row():
        chat_A = gr.Chatbot(label="Conversation A")
        chat_B = gr.Chatbot(label="Conversation B")

    # Fetch two random conversations from the dataset and show them above.
    load_btn = gr.Button("Load Random Conversations")
    load_btn.click(fn=load_and_format_conversations, inputs=[], outputs=[chat_A, chat_B])

    # Bottom row: the four rating buttons.
    with gr.Row():
        btn_both_good = gr.Button("Both good")      # -> "GG"
        btn_a_better = gr.Button("A is better")     # -> "AG"
        btn_b_better = gr.Button("B is better")     # -> "BG"
        btn_both_bad = gr.Button("Both not good")   # -> "BB"

    # Read-only box echoing the rating code that was selected.
    statement_output = gr.Textbox(label="Selected Statement", value="", interactive=False)

    # Each button stores its own code. "Both good" previously emitted "BG",
    # the same code as "B is better", which made the two ratings
    # indistinguishable; it now emits the distinct code "GG".
    btn_both_good.click(fn=lambda: update_statement("GG"), inputs=[], outputs=statement_output)
    btn_a_better.click(fn=lambda: update_statement("AG"), inputs=[], outputs=statement_output)
    btn_b_better.click(fn=lambda: update_statement("BG"), inputs=[], outputs=statement_output)
    btn_both_bad.click(fn=lambda: update_statement("BB"), inputs=[], outputs=statement_output)

demo.launch()