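"""HybridRAG demo Space: recommend training programs and YouTube content to employees.

Employee skills, program skills, and YouTube video descriptions are embedded with a
Korean Sentence-BERT model and matched by cosine similarity. Expected CSV columns:
employee data (employee_id, employee_name, current_skills), program data
(program_name, skills_acquired, duration), and YouTube data (title, description,
url, upload_date). Results are shown in a table, offered as a CSV download, and
can be queried per employee through a simple chatbot.
"""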
import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import csv
import tempfile
import os
import atexit

# Load the KoSentence-BERT model for Korean text processing
model = SentenceTransformer('jhgan/ko-sbert-sts')

# Global state shared between the analysis step, the chatbot, and the CSV download
global_recommendations = None
global_csv_file = None
youtube_columns = None
# Write the recommendation rows to a temporary CSV file and return its path
def create_csv_file(recommendations):
    with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.csv',
                                     encoding='utf-8', newline='') as temp_file:
        writer = csv.writer(temp_file)
        writer.writerow(["Employee ID", "Employee Name", "Recommended Programs", "Recommended YouTube Content"])
        for rec in recommendations:
            writer.writerow(rec)
    return temp_file.name
# Fuzzy-match required column names against the columns present in a DataFrame
def auto_match_columns(df, required_cols):
    matched_cols = {}
    for req_col in required_cols:
        matched_col = None
        for col in df.columns:
            if req_col.lower() in col.lower():
                matched_col = col
                break
        matched_cols[req_col] = matched_col
    return matched_cols
# Verify that every required employee and program column could be matched
def validate_and_get_columns(employee_df, program_df):
    required_employee_cols = ["employee_id", "employee_name", "current_skills"]
    required_program_cols = ["program_name", "skills_acquired", "duration"]
    employee_cols = auto_match_columns(employee_df, required_employee_cols)
    program_cols = auto_match_columns(program_df, required_program_cols)
    for key, value in employee_cols.items():
        if value is None:
            return f"Could not find the '{key}' column in the employee data. Please select the correct column.", None, None
    for key, value in program_cols.items():
        if value is None:
            return f"Could not find the '{key}' column in the program data. Please select the correct column.", None, None
    return None, employee_cols, program_cols
# Populate the column dropdowns when a YouTube CSV is uploaded
def select_youtube_columns(youtube_file):
    global youtube_columns
    if youtube_file is None:
        return [gr.Dropdown(choices=[], value="") for _ in range(4)]
    youtube_df = pd.read_csv(youtube_file.name)
    required_youtube_cols = ["title", "description", "url", "upload_date"]
    youtube_columns = auto_match_columns(youtube_df, required_youtube_cols)
    column_options = youtube_df.columns.tolist()
    return [
        gr.Dropdown(choices=column_options, value=youtube_columns.get("title", "")),
        gr.Dropdown(choices=column_options, value=youtube_columns.get("description", "")),
        gr.Dropdown(choices=column_options, value=youtube_columns.get("url", "")),
        gr.Dropdown(choices=column_options, value=youtube_columns.get("upload_date", "")),
    ]
# Load the YouTube content CSV and normalize the user-selected columns
def load_youtube_content(file_path, title_col, description_col, url_col, upload_date_col):
    youtube_df = pd.read_csv(file_path)
    selected_columns = [col for col in [title_col, description_col, url_col, upload_date_col] if col]
    youtube_df = youtube_df[selected_columns]
    column_mapping = {
        title_col: 'title',
        description_col: 'description',
        url_col: 'url',
        upload_date_col: 'upload_date',
    }
    youtube_df.rename(columns=column_mapping, inplace=True)
    if 'upload_date' in youtube_df.columns:
        youtube_df['upload_date'] = pd.to_datetime(youtube_df['upload_date'], errors='coerce')
    return youtube_df
# Compute similarities between program skill descriptions and YouTube video descriptions
def match_youtube_content(program_skills, youtube_df, model):
    if 'description' not in youtube_df.columns:
        return None
    youtube_embeddings = model.encode(youtube_df['description'].tolist())
    program_embeddings = model.encode(program_skills)
    similarities = cosine_similarity(program_embeddings, youtube_embeddings)
    return similarities
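# Note: the matrix returned above has shape (n_programs, n_videos); row j holds the
# cosine similarity of program j against every video, which is how hybrid_rag picks
# the top videos for each matched program below.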
# Analyze the employee data, recommend training programs and YouTube content, and build the result table
def hybrid_rag(employee_file, program_file, youtube_file, title_col, description_col, url_col, upload_date_col):
    global global_recommendations
    global global_csv_file

    # Load employee and program data
    employee_df = pd.read_csv(employee_file.name)
    program_df = pd.read_csv(program_file.name)
    error_msg, employee_cols, program_cols = validate_and_get_columns(employee_df, program_df)
    if error_msg:
        return error_msg, None, None

    # Embed employee and program skills and compute pairwise similarities
    employee_skills = employee_df[employee_cols["current_skills"]].tolist()
    program_skills = program_df[program_cols["skills_acquired"]].tolist()
    employee_embeddings = model.encode(employee_skills)
    program_embeddings = model.encode(program_skills)
    similarities = cosine_similarity(employee_embeddings, program_embeddings)

    # Load and process the YouTube content
    youtube_df = load_youtube_content(youtube_file.name, title_col, description_col, url_col, upload_date_col)

    # Match YouTube content to training programs
    youtube_similarities = match_youtube_content(program_df[program_cols['skills_acquired']].tolist(), youtube_df, model)

    recommendations = []
    recommendation_rows = []
    for i, employee in employee_df.iterrows():
        recommended_programs = []
        recommended_youtube = []
        for j, program in program_df.iterrows():
            if similarities[i][j] > 0.5:
                recommended_programs.append(f"{program[program_cols['program_name']]} ({program[program_cols['duration']]})")
                if youtube_similarities is not None:
                    top_youtube_indices = youtube_similarities[j].argsort()[-3:][::-1]  # top 3 videos for this program
                    for idx in top_youtube_indices:
                        if 'title' in youtube_df.columns and 'url' in youtube_df.columns:
                            recommended_youtube.append(f"{youtube_df.iloc[idx]['title']} (URL: {youtube_df.iloc[idx]['url']})")

        # Cap the number of recommended programs and YouTube items
        recommended_programs = recommended_programs[:5]  # at most 5 programs
        recommended_youtube = recommended_youtube[:3]  # at most 3 YouTube items

        if recommended_programs:
            recommendation = f"Recommended programs for {employee[employee_cols['employee_name']]}: {', '.join(recommended_programs)}"
            youtube_recommendation = f"Recommended YouTube content: {', '.join(recommended_youtube)}" if recommended_youtube else "No YouTube content to recommend."
            recommendation_rows.append([employee[employee_cols['employee_id']], employee[employee_cols['employee_name']],
                                        ", ".join(recommended_programs), ", ".join(recommended_youtube)])
        else:
            recommendation = f"No suitable programs found for {employee[employee_cols['employee_name']]}."
            youtube_recommendation = "No YouTube content to recommend."
            recommendation_rows.append([employee[employee_cols['employee_id']], employee[employee_cols['employee_name']],
                                        "No suitable programs", "No recommended content"])
        recommendations.append(recommendation + "\n" + youtube_recommendation)

    global_recommendations = recommendation_rows

    # Write the CSV file for download
    global_csv_file = create_csv_file(recommendation_rows)

    # Build the result table DataFrame
    result_df = pd.DataFrame(recommendation_rows, columns=["Employee ID", "Employee Name", "Recommended Programs", "Recommended YouTube Content"])
    return result_df, gr.File(value=global_csv_file, visible=True), gr.Button(value="Download CSV", visible=True)
# Chatbot response: look up the stored recommendations for the employee named in the message
def chat_response(message, history):
    global global_recommendations
    history = history or []
    if global_recommendations is None:
        history.append((message, "Please click the 'Start Analysis' button first to analyze the data."))
        return history
    for employee in global_recommendations:
        if str(employee[1]).lower() in message.lower():
            reply = (f"The programs recommended for {employee[1]} are: {employee[2]}\n\n"
                     f"Recommended YouTube content: {employee[3]}")
            history.append((message, reply))
            return history
    history.append((message, "Sorry, I could not find that employee. Please enter another employee name."))
    return history
# Return the generated CSV file so Gradio shows it for download
def download_csv():
    global global_csv_file
    return gr.File(value=global_csv_file, visible=True)
# Gradio UI
with gr.Blocks(css=".gradio-button {background-color: #007bff; color: white;} .gradio-textbox {border-color: #6c757d;}") as demo:
    gr.Markdown("<h1 style='text-align: center; color: #2c3e50;'>💼 HybridRAG System (with YouTube Content)</h1>")

    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("<h3 style='color: #34495e;'>1. Upload your data</h3>")
            employee_file = gr.File(label="Upload employee data", interactive=True)
            program_file = gr.File(label="Upload training program data", interactive=True)
            youtube_file = gr.File(label="Upload YouTube content data", interactive=True)

            gr.Markdown("<h4 style='color: #34495e;'>Select YouTube data columns</h4>")
            title_col = gr.Dropdown(label="Title column")
            description_col = gr.Dropdown(label="Description column")
            url_col = gr.Dropdown(label="URL column")
            upload_date_col = gr.Dropdown(label="Upload date column")
            youtube_file.change(select_youtube_columns, inputs=[youtube_file],
                                outputs=[title_col, description_col, url_col, upload_date_col])

            analyze_button = gr.Button("Start Analysis", elem_classes="gradio-button")
            output_table = gr.DataFrame(label="Analysis results (table)")
            csv_download = gr.File(label="Download recommendation results", visible=False)
            download_button = gr.Button("Download CSV", visible=False)

    gr.Markdown("<h3 style='color: #34495e;'>2. View recommended programs and YouTube content per employee</h3>")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Enter an employee name")
    clear = gr.Button("Clear conversation")

    # Update the result table and file download when the analysis button is clicked
    analyze_button.click(
        hybrid_rag,
        inputs=[employee_file, program_file, youtube_file, title_col, description_col, url_col, upload_date_col],
        outputs=[output_table, csv_download, download_button]
    )

    # CSV download button
    download_button.click(download_csv, inputs=[], outputs=[csv_download])

    # Chat wiring
    msg.submit(chat_response, [msg, chatbot], [chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)
# Remove the temporary CSV file when the program exits
def cleanup():
    global global_csv_file
    if global_csv_file and os.path.exists(global_csv_file):
        os.remove(global_csv_file)

atexit.register(cleanup)

# Launch the Gradio interface
demo.launch()