import os import threading import random import pandas as pd from datasets import load_dataset import gradio as gr ############################################# # CSV 파일 관련 함수 및 전역 데이터 초기화 ############################################# DATA_FILE = "Interface1.csv" data_lock = threading.Lock() def initialize_global_data(): """ CSV 파일(DATA_FILE)이 없으면, Hugging Face 데이터셋(gaeunseo/Taskmaster_sample_data)의 train split을 DataFrame으로 변환한 후 필요한 컬럼(used, overlapping, text)을 추가하고 CSV로 저장합니다. 이미 파일이 있으면 파일에서 데이터를 읽어 DataFrame을 반환합니다. """ if not os.path.exists(DATA_FILE): ds = load_dataset("gaeunseo/Taskmaster_sample_data", split="train") data = ds.to_pandas() # 필요한 컬럼이 없으면 추가 if "used" not in data.columns: data["used"] = False if "overlapping" not in data.columns: data["overlapping"] = "" if "text" not in data.columns: data["text"] = "" data.to_csv(DATA_FILE, index=False) return data else: with data_lock: df = pd.read_csv(DATA_FILE) return df def load_global_data(): """CSV 파일에서 global_data DataFrame을 읽어옵니다.""" with data_lock: df = pd.read_csv(DATA_FILE) return df def save_global_data(df): """DataFrame을 CSV 파일에 저장합니다.""" with data_lock: df.to_csv(DATA_FILE, index=False) # CSV 파일에 저장된 global_data 초기화 global_data = initialize_global_data() ############################################# # 데이터셋에서 랜덤 대화 행 선택 함수 ############################################# def get_random_row_from_dataset(): """ CSV 파일에 저장된 global_data에서, 1. conversation_id별로 그룹화하고, 2. 각 그룹에서 모든 행의 used 컬럼이 False이며, 그룹 내에 overlapping 컬럼이 "TT"인 행이 존재하는 그룹만 valid로 간주합니다. valid한 그룹들 중 랜덤하게 하나의 그룹을 선택한 후, - 해당 그룹의 모든 행의 used 값을 True로 업데이트하고 CSV 파일에 저장합니다. - 선택된 그룹 내에서 overlapping 컬럼이 "TT"인 행(여러 개라면 첫 번째)을 dict로 반환합니다. """ global global_data global_data = load_global_data() # 최신 데이터 로드 groups = global_data.groupby('conversation_id') valid_groups = [] for cid, group in groups: # 모든 행의 used가 False이고, 그룹 내에 overlapping이 "TT"인 행이 존재하는 그룹 선택 if group['used'].apply(lambda x: bool(x) == False).all() and (group['overlapping'] == "TT").any(): valid_groups.append((cid, group)) if not valid_groups: return None chosen_cid, chosen_group = random.choice(valid_groups) # 선택된 그룹의 모든 행의 used 값을 True로 업데이트 global_data.loc[global_data['conversation_id'] == chosen_cid, 'used'] = True save_global_data(global_data) # 선택된 그룹 내에서 overlapping이 "TT"인 행(여러 개일 경우 첫 번째) 선택 chosen_rows = chosen_group[chosen_group['overlapping'] == "TT"] if chosen_rows.empty: return None chosen_row = chosen_rows.iloc[0] return chosen_row.to_dict() ############################################# # 대화 HTML 생성 함수 ############################################# def format_conversation_html(row): """ 전달받은 row(dict)를 기반으로 대화 내용을 HTML로 포맷합니다. text 컬럼은 "[turn]"을 기준으로 발화가 구분되어 있으며, - 첫 번째 발화(인간)는 오른쪽 정렬과 말풍선 오른쪽의 🧑 아이콘으로 표시, - 두 번째 발화(AI)는 왼쪽 정렬과 말풍선 왼쪽의 🤖 아이콘으로 표시합니다. """ if row is None: human_message = "No valid conversation available." ai_message = "No valid conversation available." else: raw_text = row.get('text', '') parts = raw_text.split("[turn]") if len(parts) >= 2: human_message = parts[0].strip() ai_message = parts[1].strip() else: human_message = raw_text.strip() ai_message = "" # 인간 말풍선 (오른쪽 정렬, 🧑 아이콘) human_html = f"""