Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import threading
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
|
| 8 |
+
# 데이터 파일 및 스레드 락 설정
|
| 9 |
+
DATA_FILE = "global_data.csv"
|
| 10 |
+
data_lock = threading.Lock()
|
| 11 |
+
|
| 12 |
+
def initialize_global_data():
|
| 13 |
+
"""
|
| 14 |
+
DATA_FILE이 존재하지 않으면, gaeunseo/Taskmaster_sample_data 데이터셋을 로드하여 DataFrame으로 변환한 후 CSV 파일로 저장합니다.
|
| 15 |
+
이미 파일이 있으면 파일에서 데이터를 읽어 DataFrame을 반환합니다.
|
| 16 |
+
"""
|
| 17 |
+
if not os.path.exists(DATA_FILE):
|
| 18 |
+
ds = load_dataset("gaeunseo/Taskmaster_sample_data", split="train")
|
| 19 |
+
data = ds.to_pandas()
|
| 20 |
+
# 필요한 컬럼이 없으면 추가합니다.
|
| 21 |
+
if "used" not in data.columns:
|
| 22 |
+
data["used"] = False
|
| 23 |
+
if "overlapping" not in data.columns:
|
| 24 |
+
data["overlapping"] = ""
|
| 25 |
+
if "text" not in data.columns:
|
| 26 |
+
data["text"] = ""
|
| 27 |
+
data.to_csv(DATA_FILE, index=False)
|
| 28 |
+
return data
|
| 29 |
+
else:
|
| 30 |
+
with data_lock:
|
| 31 |
+
df = pd.read_csv(DATA_FILE)
|
| 32 |
+
return df
|
| 33 |
+
|
| 34 |
+
def load_global_data():
|
| 35 |
+
"""CSV 파일에서 global_data DataFrame을 읽어옵니다."""
|
| 36 |
+
with data_lock:
|
| 37 |
+
df = pd.read_csv(DATA_FILE)
|
| 38 |
+
return df
|
| 39 |
+
|
| 40 |
+
def save_global_data(df):
|
| 41 |
+
"""DataFrame을 CSV 파일에 저장합니다."""
|
| 42 |
+
with data_lock:
|
| 43 |
+
df.to_csv(DATA_FILE, index=False)
|
| 44 |
+
|
| 45 |
+
# CSV 파일에 저장된 global_data 초기화
|
| 46 |
+
global_data = initialize_global_data()
|
| 47 |
+
|
| 48 |
+
def get_random_row_from_dataset():
|
| 49 |
+
"""
|
| 50 |
+
CSV 파일에 저장된 global_data에서,
|
| 51 |
+
1. conversation_id별로 그룹화하고,
|
| 52 |
+
2. 각 그룹에서 모든 행의 used 컬럼이 False이며, 그룹 내에 overlapping 컬럼이 "TT"인 행이 존재하는 그룹만 valid로 간주합니다.
|
| 53 |
+
valid한 그룹들 중 랜덤하게 하나의 그룹을 선택한 후,
|
| 54 |
+
- 해당 그룹의 모든 행의 used 값을 True로 업데이트(즉, 전체 그룹을 할당)하고 CSV 파일에 저장합니다.
|
| 55 |
+
- 선택된 그룹 내에서 overlapping 컬럼이 "TT", "GT"가 아닌 대화들 중에서 대화 2개를 랜덤하게 선택하여,
|
| 56 |
+
두 턴의 대화를 결합한 문자열을 반환합니다.
|
| 57 |
+
"""
|
| 58 |
+
global global_data
|
| 59 |
+
global_data = load_global_data() # 최신 데이터 로드
|
| 60 |
+
groups = global_data.groupby('conversation_id')
|
| 61 |
+
valid_groups = []
|
| 62 |
+
for cid, group in groups:
|
| 63 |
+
# 모든 행의 used 값이 False이고, 그룹 내에 overlapping 값이 "TT"인 행이 존재하는 그룹 필터링
|
| 64 |
+
if group['used'].apply(lambda x: bool(x) == False).all() and (group['overlapping'] == "TT").any():
|
| 65 |
+
valid_groups.append((cid, group))
|
| 66 |
+
if not valid_groups:
|
| 67 |
+
return None
|
| 68 |
+
chosen_cid, chosen_group = random.choice(valid_groups)
|
| 69 |
+
# 선택된 그룹의 모든 행의 used 값을 True로 업데이트
|
| 70 |
+
global_data.loc[global_data['conversation_id'] == chosen_cid, 'used'] = True
|
| 71 |
+
save_global_data(global_data)
|
| 72 |
+
|
| 73 |
+
# 선택된 그룹에서 overlapping 값이 "TT" 또는 "GT"가 아닌 행들만 필터링
|
| 74 |
+
valid_rows = chosen_group[~chosen_group['overlapping'].isin(["TT", "GT"])]
|
| 75 |
+
# 유효한 행이 2개 미만이면 None 반환
|
| 76 |
+
if valid_rows.shape[0] < 2:
|
| 77 |
+
return None
|
| 78 |
+
# 유효한 행들 중 2개를 랜덤하게 선택
|
| 79 |
+
chosen_rows = valid_rows.sample(2)
|
| 80 |
+
|
| 81 |
+
# 두 행의 text를 결합하여 하나의 대화 텍스트로 만듭니다.
|
| 82 |
+
combined_text = f"{chosen_rows.iloc[0]['text'].strip()} [turn] {chosen_rows.iloc[1]['text'].strip()}"
|
| 83 |
+
|
| 84 |
+
return {"text": combined_text}
|
| 85 |
+
|
| 86 |
+
def get_conversation():
|
| 87 |
+
"""
|
| 88 |
+
get_random_row_from_dataset()를 호출하여 대화 문자열을 가져오고,
|
| 89 |
+
"[turn]" 구분자를 기준으로 인간 메시지와 AI 메시지를 분리하여 반환합니다.
|
| 90 |
+
"""
|
| 91 |
+
row = get_random_row_from_dataset()
|
| 92 |
+
if row is None:
|
| 93 |
+
return "No valid conversation available.", "No valid conversation available."
|
| 94 |
+
else:
|
| 95 |
+
raw_text = row['text']
|
| 96 |
+
parts = raw_text.split("[turn]")
|
| 97 |
+
if len(parts) < 2:
|
| 98 |
+
return "Invalid conversation format", "Invalid conversation format"
|
| 99 |
+
human_message = parts[0].strip()
|
| 100 |
+
ai_message = parts[1].strip()
|
| 101 |
+
return human_message, ai_message
|
| 102 |
+
|
| 103 |
+
# Gradio 인터페이스 생성 (왼쪽: Human Message, 오른쪽: AI Message)
|
| 104 |
+
with gr.Blocks() as demo:
|
| 105 |
+
gr.Markdown("## Random Conversation Generator")
|
| 106 |
+
with gr.Row():
|
| 107 |
+
human_text = gr.Textbox(label="Human Message", lines=10, interactive=False)
|
| 108 |
+
ai_text = gr.Textbox(label="AI Message", lines=10, interactive=False)
|
| 109 |
+
generate_btn = gr.Button("Generate Conversation")
|
| 110 |
+
generate_btn.click(fn=get_conversation, inputs=[], outputs=[human_text, ai_text])
|
| 111 |
+
|
| 112 |
+
demo.launch()
|