gaeunseo commited on
Commit
8570037
·
verified ·
1 Parent(s): f355cf7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import threading
4
+ import pandas as pd
5
+ import gradio as gr
6
+ from datasets import load_dataset
7
+
8
+ # 데이터 파일 및 스레드 락 설정
9
+ DATA_FILE = "global_data.csv"
10
+ data_lock = threading.Lock()
11
+
12
+ def initialize_global_data():
13
+ """
14
+ DATA_FILE이 존재하지 않으면, gaeunseo/Taskmaster_sample_data 데이터셋을 로드하여 DataFrame으로 변환한 후 CSV 파일로 저장합니다.
15
+ 이미 파일이 있으면 파일에서 데이터를 읽어 DataFrame을 반환합니다.
16
+ """
17
+ if not os.path.exists(DATA_FILE):
18
+ ds = load_dataset("gaeunseo/Taskmaster_sample_data", split="train")
19
+ data = ds.to_pandas()
20
+ # 필요한 컬럼이 없으면 추가합니다.
21
+ if "used" not in data.columns:
22
+ data["used"] = False
23
+ if "overlapping" not in data.columns:
24
+ data["overlapping"] = ""
25
+ if "text" not in data.columns:
26
+ data["text"] = ""
27
+ data.to_csv(DATA_FILE, index=False)
28
+ return data
29
+ else:
30
+ with data_lock:
31
+ df = pd.read_csv(DATA_FILE)
32
+ return df
33
+
34
+ def load_global_data():
35
+ """CSV 파일에서 global_data DataFrame을 읽어옵니다."""
36
+ with data_lock:
37
+ df = pd.read_csv(DATA_FILE)
38
+ return df
39
+
40
+ def save_global_data(df):
41
+ """DataFrame을 CSV 파일에 저장합니다."""
42
+ with data_lock:
43
+ df.to_csv(DATA_FILE, index=False)
44
+
45
+ # CSV 파일에 저장된 global_data 초기화
46
+ global_data = initialize_global_data()
47
+
48
+ def get_random_row_from_dataset():
49
+ """
50
+ CSV 파일에 저장된 global_data에서,
51
+ 1. conversation_id별로 그룹화하고,
52
+ 2. 각 그룹에서 모든 행의 used 컬럼이 False이며, 그룹 내에 overlapping 컬럼이 "TT"인 행이 존재하는 그룹만 valid로 간주합니다.
53
+ valid한 그룹들 중 랜덤하게 하나의 그룹을 선택한 후,
54
+ - 해당 그룹의 모든 행의 used 값을 True로 업데이트(즉, 전체 그룹을 할당)하고 CSV 파일에 저장합니다.
55
+ - 선택된 그룹 내에서 overlapping 컬럼이 "TT", "GT"가 아닌 대화들 중에서 대화 2개를 랜덤하게 선택하여,
56
+ 두 턴의 대화를 결합한 문자열을 반환합니다.
57
+ """
58
+ global global_data
59
+ global_data = load_global_data() # 최신 데이터 로드
60
+ groups = global_data.groupby('conversation_id')
61
+ valid_groups = []
62
+ for cid, group in groups:
63
+ # 모든 행의 used 값이 False이고, 그룹 내에 overlapping 값이 "TT"인 행이 존재하는 그룹 필터링
64
+ if group['used'].apply(lambda x: bool(x) == False).all() and (group['overlapping'] == "TT").any():
65
+ valid_groups.append((cid, group))
66
+ if not valid_groups:
67
+ return None
68
+ chosen_cid, chosen_group = random.choice(valid_groups)
69
+ # 선택된 그룹의 모든 행의 used 값을 True로 업데이트
70
+ global_data.loc[global_data['conversation_id'] == chosen_cid, 'used'] = True
71
+ save_global_data(global_data)
72
+
73
+ # 선택된 그룹에서 overlapping 값이 "TT" 또는 "GT"가 아닌 행들만 필터링
74
+ valid_rows = chosen_group[~chosen_group['overlapping'].isin(["TT", "GT"])]
75
+ # 유효한 행이 2개 미만이면 None 반환
76
+ if valid_rows.shape[0] < 2:
77
+ return None
78
+ # 유효한 행들 중 2개를 랜덤하게 선택
79
+ chosen_rows = valid_rows.sample(2)
80
+
81
+ # 두 행의 text를 결합하여 하나의 대화 텍스트로 만듭니다.
82
+ combined_text = f"{chosen_rows.iloc[0]['text'].strip()} [turn] {chosen_rows.iloc[1]['text'].strip()}"
83
+
84
+ return {"text": combined_text}
85
+
86
+ def get_conversation():
87
+ """
88
+ get_random_row_from_dataset()를 호출하여 대화 문자열을 가져오고,
89
+ "[turn]" 구분자를 기준으로 인간 메시지와 AI 메시지를 분리하여 반환합니다.
90
+ """
91
+ row = get_random_row_from_dataset()
92
+ if row is None:
93
+ return "No valid conversation available.", "No valid conversation available."
94
+ else:
95
+ raw_text = row['text']
96
+ parts = raw_text.split("[turn]")
97
+ if len(parts) < 2:
98
+ return "Invalid conversation format", "Invalid conversation format"
99
+ human_message = parts[0].strip()
100
+ ai_message = parts[1].strip()
101
+ return human_message, ai_message
102
+
103
+ # Gradio 인터페이스 생성 (왼쪽: Human Message, 오른쪽: AI Message)
104
+ with gr.Blocks() as demo:
105
+ gr.Markdown("## Random Conversation Generator")
106
+ with gr.Row():
107
+ human_text = gr.Textbox(label="Human Message", lines=10, interactive=False)
108
+ ai_text = gr.Textbox(label="AI Message", lines=10, interactive=False)
109
+ generate_btn = gr.Button("Generate Conversation")
110
+ generate_btn.click(fn=get_conversation, inputs=[], outputs=[human_text, ai_text])
111
+
112
+ demo.launch()