Interface2 / app.py
gaeunseo's picture
Update app.py
7b078cc verified
import html
import os
import random
import threading

import gradio as gr
import pandas as pd
from datasets import load_dataset
#############################################
# CSV file helpers and global data initialization
#############################################
# Path of the CSV file holding the app's working copy of the dataset,
# including the `used` bookkeeping column the app updates.
DATA_FILE: str = "Interface1.csv"
# Held by load_global_data()/save_global_data() (and by the read path of
# initialize_global_data()) while they access DATA_FILE.
data_lock = threading.Lock()
def initialize_global_data():
    """Return the working DataFrame, creating DATA_FILE on first use.

    If DATA_FILE already exists, it is read (while holding ``data_lock``) and
    returned.

    Otherwise the ``train`` split of the Hugging Face dataset
    ``gaeunseo/Taskmaster_sample_data`` is converted to a DataFrame. Any
    missing bookkeeping columns are added: ``used`` (default False),
    ``overlapping`` (default "") and ``text`` (default ""). The frame is then
    written to DATA_FILE (without the index) and returned.
    """
    if os.path.exists(DATA_FILE):
        with data_lock:
            return pd.read_csv(DATA_FILE)

    frame = load_dataset("gaeunseo/Taskmaster_sample_data", split="train").to_pandas()
    # Only add the columns the dataset does not already provide.
    for column, default in (("used", False), ("overlapping", ""), ("text", "")):
        if column not in frame.columns:
            frame[column] = default
    frame.to_csv(DATA_FILE, index=False)
    return frame
def load_global_data():
    """Read DATA_FILE into a fresh DataFrame while holding ``data_lock``."""
    with data_lock:
        return pd.read_csv(DATA_FILE)
def save_global_data(df):
    """Overwrite DATA_FILE with *df* (index omitted) while holding ``data_lock``."""
    with data_lock:
        df.to_csv(path_or_buf=DATA_FILE, index=False)
# Module-level DataFrame loaded at import time from DATA_FILE (the file is
# created from the Hugging Face dataset on first run).
# get_random_row_from_dataset() re-reads the CSV and rebinds this name.
global_data = initialize_global_data()
#############################################
# Random conversation-row selection from the dataset
#############################################
# Serializes the whole load -> pick -> mark-used -> save sequence below so
# that two concurrent requests cannot both claim the same conversation.
# A separate lock is needed: data_lock is a non-reentrant threading.Lock that
# load_global_data()/save_global_data() already acquire internally.
_selection_lock = threading.Lock()


def get_random_row_from_dataset():
    """Claim a random unused conversation and return its first "TT" row.

    Reads the latest data from the CSV via load_global_data() and groups the
    rows by ``conversation_id``. A group is valid when:
      1. no row in the group has a truthy ``used`` value, and
      2. at least one row has ``overlapping == "TT"``.

    One valid group is chosen with ``random.choice``. Every row of that group
    is marked ``used = True`` and the frame is saved via save_global_data().
    The module-level ``global_data`` is rebound to the loaded (and possibly
    updated) DataFrame.

    Returns:
        dict | None: the first ``overlapping == "TT"`` row of the chosen
        group, mapping column name to value. ``used`` in this dict holds its
        value from before claiming. Returns ``None`` when no valid group
        remains.
    """
    global global_data
    with _selection_lock:
        df = load_global_data()  # always work on the latest on-disk state
        global_data = df

        valid_ids = [
            cid
            for cid, group in df.groupby("conversation_id")
            # bool() per value mirrors the original truthiness test on `used`
            if not group["used"].map(bool).any()
            and (group["overlapping"] == "TT").any()
        ]
        if not valid_ids:
            return None

        chosen_cid = random.choice(valid_ids)
        in_group = df["conversation_id"] == chosen_cid
        # Capture the row to return before flipping `used`, so the returned
        # dict reflects the row as it was when selected. A valid group always
        # contains at least one "TT" row.
        chosen_row = df[in_group & (df["overlapping"] == "TT")].iloc[0].to_dict()

        df.loc[in_group, "used"] = True
        save_global_data(df)
        return chosen_row
#############################################
# Conversation HTML rendering
#############################################
def format_conversation_html(row):
    """Render one conversation row as a two-bubble chat snippet in HTML.

    The ``text`` field holds utterances separated by the literal marker
    ``"[turn]"``. Only the first two utterances are shown:
      - the first (human) in a right-aligned bubble with a 🧑 icon on its right;
      - the second (AI) in a left-aligned bubble with a 🤖 icon on its left.
    Any further utterances are ignored.

    Args:
        row: a mapping with an optional ``"text"`` key (e.g. the dict returned
            by get_random_row_from_dataset()), or ``None``.

    Returns:
        str: an HTML fragment. When ``row`` is ``None``, both bubbles show
        "No valid conversation available.". Message text is HTML-escaped.
    """
    if row is None:
        human_message = "No valid conversation available."
        ai_message = "No valid conversation available."
    else:
        raw_text = row.get("text", "")
        # pandas reads empty CSV cells back as NaN (a float, which has no
        # .split), so treat any missing value as an empty conversation.
        if pd.isna(raw_text):
            raw_text = ""
        parts = str(raw_text).split("[turn]")
        human_message = parts[0].strip()
        ai_message = parts[1].strip() if len(parts) >= 2 else ""

    # The dataset text is interpolated into markup; escape it so characters
    # such as "<" or "&" are shown literally instead of being parsed as HTML.
    human_message = html.escape(human_message)
    ai_message = html.escape(ai_message)

    # Human bubble (right-aligned, 🧑 icon)
    human_html = f"""
<div class="human-wrapper" style="display: flex; align-items: flex-end; justify-content: flex-end; gap: 5px; width: 100%;">
<div class="speech-bubble human" style="background: #d0f0d0; padding: 10px 15px; border-radius: 15px; max-width: 70%; text-align: right;">
{human_message}
</div>
<div class="emoji" style="font-size: 24px; line-height: 1;">🧑</div>
</div>
"""
    # AI bubble (left-aligned, 🤖 icon)
    ai_html = f"""
<div class="ai-wrapper" style="display: flex; align-items: flex-end; justify-content: flex-start; gap: 5px; width: 100%;">
<div class="emoji" style="font-size: 24px; line-height: 1;">🤖</div>
<div class="speech-bubble ai" style="background: #e0e0e0; padding: 10px 15px; border-radius: 15px; max-width: 70%; text-align: left;">
{ai_message}
</div>
</div>
"""
    conversation_html = f"""
<div class="chat-container" style="display: flex; flex-direction: column; gap: 10px;">
{human_html}
{ai_html}
</div>
"""
    return conversation_html
def load_two_conversations_html():
    """Pick two conversations and render them for the A/B comparison panes.

    Calls get_random_row_from_dataset() twice (A first, then B) and renders
    each result with format_conversation_html().

    Returns:
        tuple[str, str]: HTML for conversation A and conversation B.
    """
    picked_a, picked_b = (
        get_random_row_from_dataset(),
        get_random_row_from_dataset(),
    )
    return format_conversation_html(picked_a), format_conversation_html(picked_b)
#############################################
# Rating-button handlers
#############################################
# Rating code most recently chosen via a rating button ("" until a click).
# It is a single module-level value, not one per user session.
statement = ""


def update_statement(val):
    """Store *val* as the module-level ``statement`` and return it for display."""
    global statement
    statement = val
    return val
#############################################
# Gradio interface
#############################################
with gr.Blocks() as demo:
    # (A) CSS for the chat speech bubbles
    gr.HTML(
        """
<style>
.chat-container {
display: flex;
flex-direction: column;
gap: 10px;
width: 100%;
}
.speech-bubble {
position: relative;
padding: 10px 15px;
border-radius: 15px;
max-width: 70%;
font-family: sans-serif;
font-size: 16px;
line-height: 1.4;
}
.human {
background: #d0f0d0;
}
.ai {
background: #e0e0e0;
}
.emoji {
font-size: 24px;
line-height: 1;
}
</style>
"""
    )
    gr.Markdown("## Conversation Comparison")
    # Left: Conversation A, right: Conversation B
    with gr.Row():
        conv_A = gr.HTML(label="Conversation A")
        conv_B = gr.HTML(label="Conversation B")
    # "Load Random Conversations" fills both panes with freshly claimed conversations.
    load_btn = gr.Button("Load Random Conversations")
    load_btn.click(fn=load_two_conversations_html, inputs=[], outputs=[conv_A, conv_B])
    # Rating buttons (bottom of the page)
    with gr.Row():
        btn_both_good = gr.Button("Both good")  # -> "GG"
        btn_a_better = gr.Button("A is better")  # -> "AG"
        btn_b_better = gr.Button("B is better")  # -> "BG"
        btn_both_bad = gr.Button("Both not good")  # -> "BB"
    # Read-only textbox showing the rating code that was last selected
    statement_output = gr.Textbox(label="Selected Statement", interactive=False)
    # Each button stores its rating code in the global `statement` and displays it.
    # Codes must be distinct: "Both good" uses "GG" so it cannot be confused with
    # "B is better" ("BG"). Verify any downstream consumer of these codes.
    btn_both_good.click(fn=lambda: update_statement("GG"), inputs=[], outputs=statement_output)
    btn_a_better.click(fn=lambda: update_statement("AG"), inputs=[], outputs=statement_output)
    btn_b_better.click(fn=lambda: update_statement("BG"), inputs=[], outputs=statement_output)
    btn_both_bad.click(fn=lambda: update_statement("BB"), inputs=[], outputs=statement_output)
demo.launch()