Spaces:
Sleeping
Sleeping
| import os | |
| import threading | |
| import random | |
| import pandas as pd | |
| from datasets import load_dataset | |
| import gradio as gr | |
| ############################################# | |
# CSV 파일 관련 함수 및 전역 데이터 초기화 (CSV file helpers and global-data initialization)
| ############################################# | |
# CSV file (relative path, resolved against the working directory) that holds
# the dataset plus the used / overlapping / text columns.
DATA_FILE: str = "Interface1.csv"
# Threading lock used to serialise access to DATA_FILE within this process.
data_lock = threading.Lock()
def initialize_global_data():
    """Return the working DataFrame, creating DATA_FILE on the first run.

    If DATA_FILE already exists, it is read (while holding ``data_lock``) and
    returned as-is.

    Otherwise the ``train`` split of the Hugging Face dataset
    ``gaeunseo/Taskmaster_sample_data`` is obtained via ``load_dataset`` and
    converted to pandas. Missing columns are added with these defaults:
    ``used`` = False, ``overlapping`` = "", ``text`` = "". The result is
    written to DATA_FILE without the index and returned.
    """
    if os.path.exists(DATA_FILE):
        with data_lock:
            existing = pd.read_csv(DATA_FILE)
        return existing

    frame = load_dataset("gaeunseo/Taskmaster_sample_data", split="train").to_pandas()
    # Add the extra columns in a fixed order, leaving any column the dataset
    # already provides untouched.
    for column, default in (("used", False), ("overlapping", ""), ("text", "")):
        if column not in frame.columns:
            frame[column] = default
    frame.to_csv(DATA_FILE, index=False)
    return frame
def load_global_data():
    """Return the current contents of DATA_FILE as a new DataFrame.

    The read happens while holding ``data_lock``, the same lock that
    ``save_global_data`` takes, so the two cannot interleave within this
    process.
    """
    with data_lock:
        snapshot = pd.read_csv(DATA_FILE)
    return snapshot
def save_global_data(df):
    """Overwrite DATA_FILE with ``df`` (index omitted) while holding ``data_lock``."""
    with data_lock:
        df.to_csv(path_or_buf=DATA_FILE, index=False)
# Module-level DataFrame, built (or reloaded from DATA_FILE) at import time.
global_data: pd.DataFrame = initialize_global_data()
| ############################################# | |
# 데이터셋에서 랜덤 대화 행 선택 함수 (pick a random conversation row from the dataset)
| ############################################# | |
def get_random_row_from_dataset():
    """Claim a random unused conversation and return its first "TT" row.

    The latest data is read from DATA_FILE and rows are grouped by
    ``conversation_id``. A conversation is eligible when:

    * none of its rows has a truthy ``used`` value (NaN counts as truthy),
      and
    * at least one of its rows has ``overlapping == "TT"``.

    One eligible conversation is picked with ``random.choice``. All of its
    rows get ``used = True`` and the CSV is rewritten.

    Returns:
        dict | None: the first row (in file order) of the chosen conversation
        whose ``overlapping`` is "TT", with values as read (i.e. before
        ``used`` was flipped). Returns ``None`` when no conversation is
        eligible.

    Side effects:
        Rebinds the module-level ``global_data`` to the freshly loaded (and,
        when a conversation is claimed, updated) DataFrame.
    """
    global global_data
    # Hold the lock across the whole read-modify-write so two concurrent
    # requests cannot both claim the same conversation. The CSV is read and
    # written directly here because load_global_data / save_global_data
    # acquire the same non-reentrant lock themselves.
    with data_lock:
        df = pd.read_csv(DATA_FILE)
        global_data = df

        conv_ids = df['conversation_id']
        # astype(bool) applies Python truthiness per value, so NaN counts as
        # "used", the same as bool(x) in the per-row check this replaces.
        any_used = df['used'].astype(bool).groupby(conv_ids).any()
        is_tt = df['overlapping'] == "TT"
        any_tt = is_tt.groupby(conv_ids).any()
        eligible = any_tt & ~any_used
        # groupby sorts its keys, so candidates are in sorted-id order.
        valid_cids = list(eligible[eligible].index)
        if not valid_cids:
            return None

        chosen_cid = random.choice(valid_cids)
        in_chosen = conv_ids == chosen_cid
        # Eligibility guarantees at least one "TT" row; capture it before
        # marking the conversation as used.
        chosen_row = df[in_chosen & is_tt].iloc[0].to_dict()

        df.loc[in_chosen, 'used'] = True
        df.to_csv(DATA_FILE, index=False)
    return chosen_row
| ############################################# | |
# 대화 HTML 생성 함수 (conversation HTML rendering)
| ############################################# | |
def format_conversation_html(row):
    """Render one conversation row as a pair of chat bubbles in HTML.

    The row's ``text`` value is split on the "[turn]" marker:

    * The first segment is the human message. It is right-aligned, with a 🧑
      icon to the right of the bubble.
    * The second segment is the AI message. It is left-aligned, with a 🤖
      icon to the left of the bubble.
    * Any further segments are ignored.

    Without a "[turn]" marker, the whole text becomes the human message and
    the AI bubble is left empty. A missing value (None/NaN, which is what
    pandas yields for an empty CSV cell) is treated as empty text. Message
    text is HTML-escaped before it is inserted into the markup.

    Args:
        row: a mapping with an optional ``text`` entry, or ``None`` to show a
            placeholder message in both bubbles.

    Returns:
        str: an HTML fragment wrapped in ``<div class="chat-container">``.
    """
    import html  # local import keeps this block self-contained

    if row is None:
        human_message = "No valid conversation available."
        ai_message = "No valid conversation available."
    else:
        raw_text = row.get('text', '')
        # pd.read_csv turns empty cells into NaN (a float with no .split());
        # treat missing values as empty and coerce any other non-string.
        if not isinstance(raw_text, str):
            raw_text = "" if pd.isna(raw_text) else str(raw_text)
        parts = raw_text.split("[turn]")
        if len(parts) >= 2:
            human_message = parts[0].strip()
            ai_message = parts[1].strip()
        else:
            human_message = raw_text.strip()
            ai_message = ""

    # Dataset text is interpolated into markup; escape it so characters such
    # as "<" and "&" render literally instead of being parsed as HTML.
    human_message = html.escape(human_message)
    ai_message = html.escape(ai_message)

    # Human bubble: right-aligned, icon on the right.
    human_html = f"""
<div class="human-wrapper" style="display: flex; align-items: flex-end; justify-content: flex-end; gap: 5px; width: 100%;">
  <div class="speech-bubble human" style="background: #d0f0d0; padding: 10px 15px; border-radius: 15px; max-width: 70%; text-align: right;">
    {human_message}
  </div>
  <div class="emoji" style="font-size: 24px; line-height: 1;">🧑</div>
</div>
"""
    # AI bubble: left-aligned, icon on the left.
    ai_html = f"""
<div class="ai-wrapper" style="display: flex; align-items: flex-end; justify-content: flex-start; gap: 5px; width: 100%;">
  <div class="emoji" style="font-size: 24px; line-height: 1;">🤖</div>
  <div class="speech-bubble ai" style="background: #e0e0e0; padding: 10px 15px; border-radius: 15px; max-width: 70%; text-align: left;">
    {ai_message}
  </div>
</div>
"""
    conversation_html = f"""
<div class="chat-container" style="display: flex; flex-direction: column; gap: 10px;">
{human_html}
{ai_html}
</div>
"""
    return conversation_html
def load_two_conversations_html():
    """Claim two random conversations and render each one as HTML.

    ``get_random_row_from_dataset`` is called twice, first for Conversation A
    and then for Conversation B. Each result (possibly ``None``) is passed
    through ``format_conversation_html``.

    Returns:
        tuple[str, str]: HTML for Conversation A and Conversation B.
    """
    rows = (get_random_row_from_dataset(), get_random_row_from_dataset())
    return tuple(format_conversation_html(r) for r in rows)
| ############################################# | |
# 평가 버튼 관련 함수 (rating-button helpers)
| ############################################# | |
# Module-level holder for the code sent by the most recently clicked rating
# button. NOTE(review): as a plain global it is shared by every session.
statement = ""


def update_statement(val):
    """Store ``val`` as the module-level ``statement`` and return it unchanged."""
    global statement
    statement = val
    return val
| ############################################# | |
# Gradio 인터페이스 구성 (Gradio interface layout)
| ############################################# | |
with gr.Blocks() as demo:
    # (A) CSS for the chat bubbles produced by format_conversation_html.
    gr.HTML(
        """
<style>
.chat-container {
    display: flex;
    flex-direction: column;
    gap: 10px;
    width: 100%;
}
.speech-bubble {
    position: relative;
    padding: 10px 15px;
    border-radius: 15px;
    max-width: 70%;
    font-family: sans-serif;
    font-size: 16px;
    line-height: 1.4;
}
.human {
    background: #d0f0d0;
}
.ai {
    background: #e0e0e0;
}
.emoji {
    font-size: 24px;
    line-height: 1;
}
</style>
"""
    )
    gr.Markdown("## Conversation Comparison")

    # Left: Conversation A, right: Conversation B.
    with gr.Row():
        conv_A = gr.HTML(label="Conversation A")
        conv_B = gr.HTML(label="Conversation B")

    # Clicking the button claims and renders two conversations.
    load_btn = gr.Button("Load Random Conversations")
    load_btn.click(fn=load_two_conversations_html, inputs=[], outputs=[conv_A, conv_B])

    # Rating buttons (bottom row). Each button must send a distinct code.
    # The original sent "BG" for both "Both good" and "B is better", so those
    # two choices could not be told apart; "Both good" now sends "GG".
    # NOTE(review): confirm this code scheme with whatever consumes it.
    with gr.Row():
        btn_both_good = gr.Button("Both good")      # -> "GG"
        btn_a_better = gr.Button("A is better")     # -> "AG"
        btn_b_better = gr.Button("B is better")     # -> "BG"
        btn_both_bad = gr.Button("Both not good")   # -> "BB"

    # Shows the code stored by the most recent rating click.
    statement_output = gr.Textbox(label="Selected Statement", interactive=False)

    # Each click stores its code in the module-level `statement` (shared by
    # all sessions) and echoes it in the textbox.
    btn_both_good.click(fn=lambda: update_statement("GG"), inputs=[], outputs=statement_output)
    btn_a_better.click(fn=lambda: update_statement("AG"), inputs=[], outputs=statement_output)
    btn_b_better.click(fn=lambda: update_statement("BG"), inputs=[], outputs=statement_output)
    btn_both_bad.click(fn=lambda: update_statement("BB"), inputs=[], outputs=statement_output)

demo.launch()