Interface_1 / app.py
gaeunseo's picture
Update app.py
938e4cf verified
raw
history blame
12.1 kB
import html
import os
import random
import threading
import time

import gradio as gr
import pandas as pd
from datasets import load_dataset
# Path of the shared CSV file and the Lock used to serialize concurrent access to it
DATA_FILE = "global_data.csv"
data_lock = threading.Lock()
def initialize_global_data():
    """Return the conversation DataFrame, creating ``DATA_FILE`` on first run.

    If ``DATA_FILE`` does not exist, the ``gaeunseo/Taskmaster_sample_data``
    dataset ("train" split) is loaded, converted to a DataFrame, given any
    missing bookkeeping columns and written to ``DATA_FILE``. If the file
    already exists, it is simply read back.

    Returns:
        The DataFrame that was just created, or the one read from the CSV.
    """
    # Hold the lock across the existence check AND the write so the file is
    # never created/overwritten while another thread reads or writes it; all
    # other accessors of DATA_FILE take the same lock.
    with data_lock:
        if os.path.exists(DATA_FILE):
            return pd.read_csv(DATA_FILE)

        ds = load_dataset("gaeunseo/Taskmaster_sample_data", split="train")
        data = ds.to_pandas()

        # Add the columns the app relies on when the dataset lacks them.
        column_defaults = {"used": False, "overlapping": "", "text": ""}
        for column, default in column_defaults.items():
            if column not in data.columns:
                data[column] = default

        data.to_csv(DATA_FILE, index=False)
        return data
def load_global_data():
    """Read ``DATA_FILE`` into a fresh DataFrame while holding ``data_lock``."""
    with data_lock:
        return pd.read_csv(DATA_FILE)
def save_global_data(df):
    """Write *df* to ``DATA_FILE`` (index omitted) while holding ``data_lock``."""
    with data_lock:
        df.to_csv(path_or_buf=DATA_FILE, index=False)
# Initialize global_data from the CSV file (the file is created on first run)
global_data = initialize_global_data()
def get_random_row_from_dataset():
    """Assign a random unused conversation and return its "TT" row.

    The CSV is re-read, rows are grouped by ``conversation_id`` and a group is
    considered valid when

    1. every row in it has a falsy ``used`` value, and
    2. at least one row has ``overlapping == "TT"``.

    One valid group is picked at random, all of its rows are marked
    ``used = True`` (the whole conversation is assigned) and the CSV is
    rewritten. The module-level ``global_data`` is refreshed as a side effect.

    Returns:
        A dict for the first ``overlapping == "TT"`` row of the chosen group
        (captured before the ``used`` flag is flipped), or ``None`` when no
        valid group is left.
    """
    global global_data

    # The read, the selection and the write happen under ONE lock acquisition.
    # Locking the read and the write separately would let two concurrent
    # sessions both see the same conversation as unused and both take it.
    with data_lock:
        data = pd.read_csv(DATA_FILE)
        global_data = data  # keep the module-level snapshot current

        is_tt = data["overlapping"] == "TT"
        valid_ids = [
            cid
            for cid, group in data.groupby("conversation_id")
            if not group["used"].apply(bool).any()
            and (group["overlapping"] == "TT").any()
        ]
        if not valid_ids:
            return None

        chosen_cid = random.choice(valid_ids)
        in_chosen = data["conversation_id"] == chosen_cid

        # A valid group always contains a "TT" row, so iloc[0] is safe here.
        chosen_row = data[in_chosen & is_tt].iloc[0].to_dict()

        # Mark the whole conversation as used and persist it.
        data.loc[in_chosen, "used"] = True
        data.to_csv(DATA_FILE, index=False)

    return chosen_row
# --- Load the initial conversation ---
# The dataset's text column is assumed to separate turns with "[turn]".
row = get_random_row_from_dataset()
if row is None:
    human_message = "No valid conversation available."
    ai_message = "No valid conversation available."
else:
    raw_text = row['text']
    # An empty CSV cell is read back by pandas as NaN (a float), so guard
    # against non-string values instead of crashing at import time.
    turns = raw_text.split("[turn]") if isinstance(raw_text, str) else [""]
    human_message = turns[0].strip()
    # As before, only the segment right after the first "[turn]" is used; a
    # text without the marker now gives an empty AI message, not IndexError.
    ai_message = turns[1].strip() if len(turns) > 1 else ""
#############################################
# Chat interface functions (speech bubbles, typing effect, editing)
#############################################
def get_initial_human_html():
    """Return the initial Human bubble HTML shown on page load.

    The markup is a right-aligned flex wrapper containing an empty
    ``#human_message`` speech bubble followed by the person emoji.
    """
    wrapper_start = (
        """<div class="human-wrapper" style="display: flex; align-items: flex-end; justify-content: flex-end; gap: 5px; width: 100%;">"""
    )
    bubble_start = """<div id="human_message" class="speech-bubble human">"""
    bubble_end = "</div>"
    # The literal here was mis-decoded UTF-8; restored to the intended emoji.
    emoji_html = "<div class='emoji'>🧑</div>"
    wrapper_end = "</div>"
    return wrapper_start + bubble_start + bubble_end + emoji_html + wrapper_end
def stream_human_message():
    """Stream the global ``human_message`` into the Human bubble, one character at a time.

    Called when "Start Typing" is clicked. The first yield is the empty
    bubble, which resets any previous state (scissors icon, greyed-out text).
    Every following yield adds one more character, wrapped in a
    ``<span data-index='i'>`` so the page script can locate the clicked
    position.

    Yields:
        The full HTML of the Human bubble after each step.
    """
    wrapper_start = (
        """<div class="human-wrapper" style="display: flex; align-items: flex-end; justify-content: flex-end; gap: 5px; width: 100%;">"""
    )
    bubble_start = """<div id="human_message" class="speech-bubble human">"""
    bubble_end = "</div>"
    # The literal here was mis-decoded UTF-8; restored to the intended emoji.
    emoji_html = "<div class='emoji'>🧑</div>"
    wrapper_end = "</div>"
    tail = bubble_end + emoji_html + wrapper_end

    # Initial state: empty bubble plus emoji.
    yield wrapper_start + bubble_start + tail

    bubble_content = ""
    for i, ch in enumerate(human_message):
        # Escape each character so "<", ">" or "&" in the message is shown as
        # text instead of being parsed as markup and breaking the bubble.
        bubble_content += f"<span data-index='{i}'>{html.escape(ch)}</span>"
        yield wrapper_start + bubble_start + bubble_content + tail
        time.sleep(0.05)  # typing-effect delay between characters
def submit_edit(edited_text):
    """Store the edited human message and move on to a new conversation.

    Called when "Submit" is clicked.

    1. The edited human message (the part before the scissors icon) is
       appended to the CSV as a new row.
    2. ``get_random_row_from_dataset()`` supplies a new conversation and the
       globals ``human_message`` / ``ai_message`` are updated from it.
    3. Fresh Human and AI bubble HTML is returned to reset the interface.

    Args:
        edited_text: Text taken from the hidden textbox.

    Returns:
        A ``(human_html, ai_html)`` tuple for the two bubble components.
    """
    # NOTE(review): these globals are shared by every browser session served
    # by this process - confirm that is acceptable for the deployment.
    global human_message, ai_message

    new_row = {
        "conversation_id": "edited_" + str(random.randint(1000, 9999)),
        "used": False,
        "overlapping": "",
        "text": edited_text,
        "human_message": edited_text,
        "ai_message": "",
    }
    # Read, append and write under ONE lock acquisition; locking the read and
    # the write separately lets a concurrent writer's changes be overwritten.
    with data_lock:
        data = pd.read_csv(DATA_FILE)
        data = pd.concat([data, pd.DataFrame([new_row])], ignore_index=True)
        data.to_csv(DATA_FILE, index=False)

    new_row_data = get_random_row_from_dataset()
    if new_row_data is None:
        human_message = "No valid conversation available."
        ai_message = "No valid conversation available."
    else:
        raw_text = new_row_data['text']
        # An empty CSV cell is read back by pandas as NaN (a float).
        turns = raw_text.split("[turn]") if isinstance(raw_text, str) else [""]
        human_message = turns[0].strip()
        # Missing "[turn]" marker -> empty AI message instead of IndexError.
        ai_message = turns[1].strip() if len(turns) > 1 else ""

    new_human_html = get_initial_human_html()
    # Escape the dataset text so it is rendered as text, not parsed as markup.
    # The robot emoji literal was mis-decoded UTF-8 and is restored here.
    new_ai_html = f"""
    <div class="ai-wrapper" style="display: flex; align-items: flex-end; justify-content: flex-start; gap: 5px; width: 100%;">
    <div class="emoji">🤖</div>
    <div id="ai_message" class="speech-bubble ai">{html.escape(ai_message)}</div>
    </div>
    """
    return new_human_html, new_ai_html
#############################################
# Gradio interface layout
#############################################
with gr.Blocks() as demo:
    # NOTE(review): confirm that <script> tags passed to gr.HTML are actually
    # executed by the Gradio version in use, that "submit_btn" already exists
    # when DOMContentLoaded fires, and that a Textbox created with
    # visible=False is present in the DOM and picks up a programmatic
    # `.value` assignment. If not, the scripts below need to move to the
    # mechanism Gradio provides for custom JavaScript.

    # (A) Page-level script: clicking any <span data-index="..."> inside the
    # Human bubble inserts a scissors icon right after it and greys out the
    # text that follows. (The scissors literal was mis-decoded UTF-8 and is
    # restored here; script (B) splits on the same literal.)
    gr.HTML(
        """
        <script>
        document.addEventListener("click", function(event) {
          if (event.target && event.target.matches("div.speech-bubble.human span[data-index]")) {
            var span = event.target;
            var container = span.closest("div.speech-bubble.human");
            var oldScissors = container.querySelectorAll("span.scissor");
            oldScissors.forEach(function(s) { s.remove(); });
            var spans = container.querySelectorAll("span[data-index]");
            spans.forEach(function(s) { s.style.color = ''; });
            var scissor = document.createElement('span');
            scissor.textContent = '✂️';
            scissor.classList.add("scissor");
            container.insertBefore(scissor, span.nextSibling);
            var cutIndex = parseInt(span.getAttribute("data-index"));
            spans.forEach(function(s) {
              var idx = parseInt(s.getAttribute("data-index"));
              if (idx > cutIndex) {
                s.style.color = "grey";
              }
            });
          }
        });
        </script>
        """
    )
    # (B) Extra script: when Submit is clicked, take the innerText of the
    # human_message div, keep the part before the scissors icon (the edited
    # text) and store it in the hidden textbox.
    gr.HTML(
        """
        <script>
        document.addEventListener("DOMContentLoaded", function() {
          var submitBtn = document.getElementById("submit_btn");
          if(submitBtn){
            submitBtn.addEventListener("click", function(){
              var humanDiv = document.getElementById("human_message");
              if(humanDiv){
                var edited_text = humanDiv.innerText.split("✂️")[0];
                document.getElementById("edited_text_input").value = edited_text;
              }
            });
          }
        });
        </script>
        """
    )
    # (C) CSS styles
    gr.HTML(
        """
        <style>
        .chat-container {
          display: flex;
          flex-direction: column;
          gap: 10px;
          width: 100%;
        }
        .speech-bubble {
          position: relative;
          padding: 10px 15px;
          border-radius: 15px;
          max-width: 70%;
          font-family: sans-serif;
          font-size: 16px;
          line-height: 1.4;
        }
        .human {
          background: #d0f0d0;
          margin-right: 10px;
        }
        .human::after {
          content: "";
          position: absolute;
          right: -10px;
          top: 10px;
          border-width: 10px 0 10px 10px;
          border-style: solid;
          border-color: transparent transparent transparent #d0f0d0;
        }
        .ai {
          background: #e0e0e0;
          margin-left: 10px;
        }
        .ai::after {
          content: "";
          position: absolute;
          left: -10px;
          top: 10px;
          border-width: 10px 10px 10px 0;
          border-style: solid;
          border-color: transparent #e0e0e0 transparent transparent;
        }
        .emoji {
          font-size: 24px;
          line-height: 1;
        }
        </style>
        """
    )
    gr.Markdown("## Chat Interface")
    with gr.Column(elem_classes="chat-container"):
        # Human bubble (initially: empty message + person emoji)
        human_bubble = gr.HTML(get_initial_human_html())
        # AI bubble (left side: robot emoji + message). The message is
        # escaped so dataset text is rendered as text, not parsed as markup;
        # the robot emoji literal was mis-decoded UTF-8 and is restored here.
        ai_html = f"""
        <div class="ai-wrapper" style="display: flex; align-items: flex-end; justify-content: flex-start; gap: 5px; width: 100%;">
        <div class="emoji">🤖</div>
        <div id="ai_message" class="speech-bubble ai">{html.escape(ai_message)}</div>
        </div>
        """
        ai_bubble = gr.HTML(ai_html)
    # Hidden textbox (holds the edited text)
    edited_text_input = gr.Textbox(visible=False, elem_id="edited_text_input")
    # Button area: Start Typing and Submit share one row
    with gr.Row():
        start_button = gr.Button("Start Typing")
        submit_button = gr.Button("Submit", elem_id="submit_btn")
    # Wire up the button events
    start_button.click(fn=stream_human_message, outputs=human_bubble)
    submit_button.click(fn=submit_edit, inputs=edited_text_input, outputs=[human_bubble, ai_bubble])
demo.launch()