gaeunseo commited on
Commit
0d9db4d
·
verified ·
1 Parent(s): 1363d3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -97
app.py CHANGED
@@ -1,105 +1,33 @@
1
- import os
2
- import random
3
- import threading
4
- import pandas as pd
5
  import gradio as gr
6
- from datasets import load_dataset
7
 
8
- # 데이터 파일 및 스레드 락 설정
9
- DATA_FILE = "global_data.csv"
10
- data_lock = threading.Lock()
11
 
12
- def initialize_global_data():
13
- """
14
- DATA_FILE이 존재하지 않으면, gaeunseo/Taskmaster_sample_data 데이터셋을 로드하여 DataFrame으로 변환한 후 CSV 파일로 저장합니다.
15
- 이미 파일이 있으면 파일에서 데이터를 읽어 DataFrame을 반환합니다.
16
- """
17
- if not os.path.exists(DATA_FILE):
18
- ds = load_dataset("gaeunseo/Taskmaster_sample_data", split="train")
19
- data = ds.to_pandas()
20
- # 필요한 컬럼이 없으면 추가합니다.
21
- if "used" not in data.columns:
22
- data["used"] = False
23
- if "overlapping" not in data.columns:
24
- data["overlapping"] = ""
25
- if "text" not in data.columns:
26
- data["text"] = ""
27
- data.to_csv(DATA_FILE, index=False)
28
- return data
29
- else:
30
- with data_lock:
31
- df = pd.read_csv(DATA_FILE)
32
- return df
33
 
34
- def load_global_data():
35
- """CSV 파일에서 global_data DataFrame을 읽어옵니다."""
36
- with data_lock:
37
- df = pd.read_csv(DATA_FILE)
38
- return df
39
-
40
- def save_global_data(df):
41
- """DataFrame을 CSV 파일에 저장합니다."""
42
- with data_lock:
43
- df.to_csv(DATA_FILE, index=False)
44
-
45
- # CSV 파일에 저장된 global_data 초기화
46
- global_data = initialize_global_data()
47
-
48
- def get_random_row_from_dataset():
49
- """
50
- CSV 파일에 저장된 global_data에서,
51
- 1. conversation_id별로 그룹화하고,
52
- 2. 각 그룹에서 모든 행의 used 컬럼이 False이며, 그룹 내에 overlapping 컬럼이 "TT"인 행이 존재하는 그룹만 valid로 간주합니다.
53
- valid한 그룹들 중 랜덤하게 하나의 그룹을 선택한 후,
54
- - 해당 그룹의 모든 행의 used 값을 True로 업데이트(즉, 전체 그룹을 할당)하고 CSV 파일에 저장합니다.
55
- - 선택된 그룹 내에서 overlapping 컬럼이 "TT", "GT"가 아닌 대화들 중에서 대화 2개를 랜덤하게 선택하여,
56
- 두 턴의 대화를 결합한 문자열을 반환합니다.
57
- """
58
- global global_data
59
- global_data = load_global_data() # 최신 데이터 로드
60
- groups = global_data.groupby('conversation_id')
61
- valid_groups = []
62
- for cid, group in groups:
63
- # 모든 행의 used 값이 False이고, 그룹 내에 overlapping 값이 "TT"인 행이 존재하는 그룹 필터링
64
- if group['used'].apply(lambda x: bool(x) == False).all() and (group['overlapping'] == "TT").any():
65
- valid_groups.append((cid, group))
66
- if not valid_groups:
67
- return None
68
- chosen_cid, chosen_group = random.choice(valid_groups)
69
- # 선택된 그룹의 모든 행의 used 값을 True로 업데이트
70
- global_data.loc[global_data['conversation_id'] == chosen_cid, 'used'] = True
71
- save_global_data(global_data)
72
-
73
- # 선택된 그룹에서 overlapping 값이 "TT" 또는 "GT"가 아닌 행들만 필터링
74
- valid_rows = chosen_group[~chosen_group['overlapping'].isin(["TT", "GT"])]
75
- # 유효한 행이 2개 미만이면 None 반환
76
- if valid_rows.shape[0] < 2:
77
- return None
78
- # 유효한 행들 중 2개를 랜덤하게 선택
79
- chosen_rows = valid_rows.sample(2)
80
-
81
- return chosen_rows[0], chosen_rows[1]
82
-
83
- def get_conversation():
84
- """
85
- get_random_row_from_dataset()를 호출하여 대화 문자열을 가져오고,
86
- "[turn]" 구분자를 기준으로 인간 메시지와 AI 메시지를 분리하여 반환합니다.
87
- """
88
- row1, row2 = get_random_row_from_dataset()
89
- if row is None:
90
- return "No valid conversation available.", "No valid conversation available."
91
- else:
92
- conversation_1 = row1['text']
93
- conversation_2 = row2['text']
94
- return conversation_1, conversation_2
95
-
96
- # Gradio 인터페이스 생성 (왼쪽: Human Message, 오른쪽: AI Message)
97
  with gr.Blocks() as demo:
98
- gr.Markdown("## Random Conversation Generator")
 
 
 
 
 
99
  with gr.Row():
100
- conversation_1 = gr.Textbox(label="Conversation1", lines=10, interactive=False)
101
- conversation_2 = gr.Textbox(label="Conversation2", lines=10, interactive=False)
102
- generate_btn = gr.Button("Generate Conversation")
103
- generate_btn.click(fn=get_conversation, inputs=[], outputs=[conversation_1, conversation_2])
 
 
 
 
 
 
 
 
 
104
 
105
  demo.launch()
 
 
 
 
 
1
  import gradio as gr
 
2
 
3
+ # 전역 변수 선언
4
+ statement = ""
 
5
 
6
+ def update_statement(value):
7
+ global statement
8
+ statement = value
9
+ return statement
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  with gr.Blocks() as demo:
12
+ # 개의 대화창을 좌우에 배치합니다.
13
+ with gr.Row():
14
+ conversation_A = gr.Textbox(label="Conversation A", placeholder="Enter conversation A here...", lines=10)
15
+ conversation_B = gr.Textbox(label="Conversation B", placeholder="Enter conversation B here...", lines=10)
16
+
17
+ # 4개의 버튼을 한 행에 배치합니다.
18
  with gr.Row():
19
+ btn_both_good = gr.Button("Both good") # "둘 다 좋음" → "BG"
20
+ btn_a_better = gr.Button("A is better") # "A가 더 좋음" → "AG"
21
+ btn_b_better = gr.Button("B is better") # "B가 더 좋음" → "BG"
22
+ btn_both_bad = gr.Button("Both not good") # "둘 다 별로임" → "BB"
23
+
24
+ # 선택된 statement 값을 보여주기 위한 출력 텍스트박스 (옵션)
25
+ statement_output = gr.Textbox(label="Selected Statement", value="", interactive=False)
26
+
27
+ # 각 버튼 클릭 시 update_statement 함수를 호출하여 statement 값을 업데이트합니다.
28
+ btn_both_good.click(fn=lambda: update_statement("BG"), inputs=[], outputs=statement_output)
29
+ btn_a_better.click(fn=lambda: update_statement("AG"), inputs=[], outputs=statement_output)
30
+ btn_b_better.click(fn=lambda: update_statement("BG"), inputs=[], outputs=statement_output)
31
+ btn_both_bad.click(fn=lambda: update_statement("BB"), inputs=[], outputs=statement_output)
32
 
33
  demo.launch()