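"""Streamlit app for exploring attention-analysis results.

Results are downloaded from the ACMCMC/attention GitHub repository and
cached in a temporary directory. Assuming this file is saved as app.py,
launch it with:

    streamlit run app.py

Dependencies: streamlit, pandas, plotly, numpy, requests.
"""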
import streamlit as st
import pandas as pd
import json
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
from pathlib import Path
import requests
import tempfile
import shutil
# Set page config
st.set_page_config(
    page_title="Attention Analysis Results Explorer",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Custom CSS for better styling
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: bold;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .section-header {
        font-size: 1.5rem;
        font-weight: bold;
        color: #ff7f0e;
        margin-top: 2rem;
        margin-bottom: 1rem;
    }
    .metric-container {
        background-color: #f0f2f6;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 0.5rem 0;
    }
    .stSelectbox > div > div {
        background-color: white;
    }
</style>
""", unsafe_allow_html=True)
class AttentionResultsExplorer:
    def __init__(self, github_repo="ACMCMC/attention", use_cache=True):
        self.github_repo = github_repo
        self.use_cache = use_cache
        self.cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
        self.base_path = self.cache_dir

        # Initialize cache directory
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Download and cache data if needed
        if not self._cache_exists() or not use_cache:
            self._download_repository()

        self.languages = self._get_available_languages()
        self.relation_types = None

    def _cache_exists(self):
        """Check if cached data exists."""
        return (self.cache_dir / "results_en").exists()
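    # NOTE: the download helpers below use the unauthenticated GitHub API,
    # which is rate-limited (60 requests/hour per IP), so very large result
    # trees may hit that limit partway through a download.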
    def _download_repository(self):
        """Download repository data from GitHub."""
        st.info("🔄 Downloading results data from GitHub... This may take a moment.")

        # GitHub API to get the repository contents
        api_url = f"https://api.github.com/repos/{self.github_repo}/contents"

        try:
            # Get list of result directories
            response = requests.get(api_url, timeout=30)
            response.raise_for_status()
            contents = response.json()

            result_dirs = [item['name'] for item in contents
                           if item['type'] == 'dir' and item['name'].startswith('results_')]

            st.write(f"Found {len(result_dirs)} result directories: {', '.join(result_dirs)}")

            # Download each result directory
            progress_bar = st.progress(0)
            for i, result_dir in enumerate(result_dirs):
                st.write(f"Downloading {result_dir}...")
                self._download_directory(result_dir)
                progress_bar.progress((i + 1) / len(result_dirs))

            st.success("✅ Download completed!")
        except Exception as e:
            st.error(f"❌ Error downloading repository: {str(e)}")
            st.error("Please check the repository URL and your internet connection.")
            raise
    def _download_directory(self, dir_name, path=""):
        """Recursively download a directory from GitHub."""
        url = f"https://api.github.com/repos/{self.github_repo}/contents/{path}{dir_name}"

        try:
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            contents = response.json()

            local_dir = self.cache_dir / path / dir_name
            local_dir.mkdir(parents=True, exist_ok=True)

            for item in contents:
                if item['type'] == 'file':
                    self._download_file(item, local_dir)
                elif item['type'] == 'dir':
                    self._download_directory(item['name'], f"{path}{dir_name}/")
        except Exception as e:
            st.warning(f"Could not download {dir_name}: {str(e)}")
    def _download_file(self, file_info, local_dir):
        """Download a single file from GitHub."""
        try:
            # Download file content
            response = requests.get(file_info['download_url'], timeout=30)
            response.raise_for_status()

            # Save to local cache
            local_file = local_dir / file_info['name']

            # Handle different file types
            if file_info['name'].endswith(('.csv', '.json')):
                with open(local_file, 'w', encoding='utf-8') as f:
                    f.write(response.text)
            else:  # Binary files like PDFs
                with open(local_file, 'wb') as f:
                    f.write(response.content)
        except Exception as e:
            st.warning(f"Could not download file {file_info['name']}: {str(e)}")
    def _get_available_languages(self):
        """Get all available language directories."""
        if not self.base_path.exists():
            return []
        result_dirs = [d.name for d in self.base_path.iterdir()
                       if d.is_dir() and d.name.startswith("results_")]
        languages = [d.replace("results_", "") for d in result_dirs]
        return sorted(languages)

    def _get_experimental_configs(self, language):
        """Get all experimental configurations for a language."""
        lang_dir = self.base_path / f"results_{language}"
        if not lang_dir.exists():
            return []
        configs = [d.name for d in lang_dir.iterdir() if d.is_dir()]
        return sorted(configs)

    def _get_models(self, language, config):
        """Get all models for a language and configuration."""
        config_dir = self.base_path / f"results_{language}" / config
        if not config_dir.exists():
            return []
        models = [d.name for d in config_dir.iterdir() if d.is_dir()]
        return sorted(models)
    def _parse_config_name(self, config_name):
        """Parse configuration name into readable format."""
        parts = config_name.split('+')
        config_dict = {}
        for part in parts:
            if '_' in part:
                key, value = part.split('_', 1)
                config_dict[key.title()] = value
        return config_dict
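    # Example with a hypothetical config name:
    #   _parse_config_name("mode_greedy+punct_keep") -> {"Mode": "greedy", "Punct": "keep"}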
    def _load_metadata(self, language, config, model):
        """Load metadata for a specific combination."""
        metadata_path = self.base_path / f"results_{language}" / config / model / "metadata" / "metadata.json"
        if metadata_path.exists():
            with open(metadata_path, 'r') as f:
                return json.load(f)
        return None

    def _load_uas_scores(self, language, config, model):
        """Load UAS scores data."""
        uas_dir = self.base_path / f"results_{language}" / config / model / "uas_scores"
        if not uas_dir.exists():
            return {}

        uas_data = {}
        csv_files = list(uas_dir.glob("uas_*.csv"))

        if csv_files:
            progress_bar = st.progress(0)
            status_text = st.empty()
            for i, csv_file in enumerate(csv_files):
                relation = csv_file.stem.replace("uas_", "")
                status_text.text(f"Loading UAS data: {relation}")
                try:
                    df = pd.read_csv(csv_file, index_col=0)
                    uas_data[relation] = df
                except Exception as e:
                    st.warning(f"Could not load {csv_file.name}: {e}")
                progress_bar.progress((i + 1) / len(csv_files))
            progress_bar.empty()
            status_text.empty()

        return uas_data

    def _load_head_matching(self, language, config, model):
        """Load head matching data."""
        heads_dir = self.base_path / f"results_{language}" / config / model / "number_of_heads_matching"
        if not heads_dir.exists():
            return {}

        heads_data = {}
        csv_files = list(heads_dir.glob("heads_matching_*.csv"))

        if csv_files:
            progress_bar = st.progress(0)
            status_text = st.empty()
            for i, csv_file in enumerate(csv_files):
                relation = csv_file.stem.replace("heads_matching_", "").replace(f"_{model}", "")
                status_text.text(f"Loading head matching data: {relation}")
                try:
                    df = pd.read_csv(csv_file, index_col=0)
                    heads_data[relation] = df
                except Exception as e:
                    st.warning(f"Could not load {csv_file.name}: {e}")
                progress_bar.progress((i + 1) / len(csv_files))
            progress_bar.empty()
            status_text.empty()

        return heads_data

    def _load_variability(self, language, config, model):
        """Load variability data."""
        var_path = self.base_path / f"results_{language}" / config / model / "variability" / "variability_list.csv"
        if var_path.exists():
            try:
                return pd.read_csv(var_path, index_col=0)
            except Exception as e:
                st.warning(f"Could not load variability data: {e}")
        return None

    def _get_available_figures(self, language, config, model):
        """Get all available figure files."""
        figures_dir = self.base_path / f"results_{language}" / config / model / "figures"
        if not figures_dir.exists():
            return []
        return list(figures_dir.glob("*.pdf"))
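# Expected cache layout, inferred from the loaders above:
#   results_<lang>/<config>/<model>/metadata/metadata.json
#   results_<lang>/<config>/<model>/uas_scores/uas_<relation>.csv
#   results_<lang>/<config>/<model>/number_of_heads_matching/heads_matching_<relation>_<model>.csv
#   results_<lang>/<config>/<model>/variability/variability_list.csv
#   results_<lang>/<config>/<model>/figures/*.pdf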
def main():
    # Title
    st.markdown('<div class="main-header">🔍 Attention Analysis Results Explorer</div>', unsafe_allow_html=True)

    # Sidebar for navigation
    st.sidebar.title("🔧 Configuration")

    # Cache management section
    st.sidebar.markdown("### 📁 Data Management")

    # Initialize explorer
    use_cache = st.sidebar.checkbox("Use cached data", value=True,
                                    help="Use previously downloaded data if available")

    if st.sidebar.button("🔄 Refresh Data", help="Download fresh data from GitHub"):
        # Clear cache and re-download
        cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
        if cache_dir.exists():
            shutil.rmtree(cache_dir)
        st.rerun()

    # Show cache status
    cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
    if cache_dir.exists():
        st.sidebar.success("✅ Data cached locally")
    else:
        st.sidebar.info("📥 Will download data from GitHub")

    st.sidebar.markdown("---")

    # Initialize explorer with error handling
    try:
        explorer = AttentionResultsExplorer(use_cache=use_cache)
    except Exception as e:
        st.error(f"❌ Failed to initialize data explorer: {str(e)}")
        st.error("Please check your internet connection and try again.")
        return

    # Check if any languages are available
    if not explorer.languages:
        st.error("❌ No result data found. Please check the GitHub repository.")
        return

    # Language selection
    selected_language = st.sidebar.selectbox(
        "Select Language",
        options=explorer.languages,
        help="Choose the language dataset to explore"
    )

    # Get configurations for selected language
    configs = explorer._get_experimental_configs(selected_language)
    if not configs:
        st.error(f"No configurations found for language: {selected_language}")
        return

    # Configuration selection
    selected_config = st.sidebar.selectbox(
        "Select Experimental Configuration",
        options=configs,
        help="Choose the experimental configuration"
    )

    # Parse and display configuration details
    config_details = explorer._parse_config_name(selected_config)
    st.sidebar.markdown("**Configuration Details:**")
    for key, value in config_details.items():
        st.sidebar.markdown(f"- **{key}**: {value}")

    # Get models for selected language and config
    models = explorer._get_models(selected_language, selected_config)
    if not models:
        st.error(f"No models found for {selected_language}/{selected_config}")
        return

    # Model selection
    selected_model = st.sidebar.selectbox(
        "Select Model",
        options=models,
        help="Choose the model to analyze"
    )

    # Main content area
    tab1, tab2, tab3, tab4, tab5 = st.tabs([
        "📊 Overview",
        "🎯 UAS Scores",
        "🧠 Head Matching",
        "📈 Variability",
        "🖼️ Figures"
    ])
    # Tab 1: Overview
    with tab1:
        st.markdown('<div class="section-header">Experiment Overview</div>', unsafe_allow_html=True)

        # Load metadata
        metadata = explorer._load_metadata(selected_language, selected_config, selected_model)

        if metadata:
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Total Samples", metadata.get('total_number', 'N/A'))
            with col2:
                st.metric("Processed Correctly", metadata.get('number_processed_correctly', 'N/A'))
            with col3:
                st.metric("Errors", metadata.get('number_errored', 'N/A'))
            with col4:
                success_rate = (metadata.get('number_processed_correctly', 0) /
                                metadata.get('total_number', 1)) * 100 if metadata.get('total_number') else 0
                st.metric("Success Rate", f"{success_rate:.1f}%")
| st.markdown("**Random Seed:**", metadata.get('random_seed', 'N/A')) | |
            if metadata.get('errored_phrases'):
                st.markdown("**Errored Phrase IDs:**")
                st.write(metadata['errored_phrases'])
        else:
            st.warning("No metadata available for this configuration.")

        # Quick stats about available data
        st.markdown('<div class="section-header">Available Data</div>', unsafe_allow_html=True)

        uas_data = explorer._load_uas_scores(selected_language, selected_config, selected_model)
        heads_data = explorer._load_head_matching(selected_language, selected_config, selected_model)
        variability_data = explorer._load_variability(selected_language, selected_config, selected_model)
        figures = explorer._get_available_figures(selected_language, selected_config, selected_model)

        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("UAS Relations", len(uas_data))
        with col2:
            st.metric("Head Matching Relations", len(heads_data))
        with col3:
            st.metric("Variability Data", "✓" if variability_data is not None else "✗")
        with col4:
            st.metric("Figure Files", len(figures))
    # Tab 2: UAS Scores
    with tab2:
        st.markdown('<div class="section-header">UAS (Unlabeled Attachment Score) Analysis</div>', unsafe_allow_html=True)

        uas_data = explorer._load_uas_scores(selected_language, selected_config, selected_model)

        if uas_data:
            # Relation selection
            selected_relation = st.selectbox(
                "Select Dependency Relation",
                options=list(uas_data.keys()),
                help="Choose a dependency relation to visualize UAS scores"
            )

            if selected_relation and selected_relation in uas_data:
                df = uas_data[selected_relation]

                # Display the data table
                st.markdown("**UAS Scores Matrix (Layer × Head)**")
                st.dataframe(df, use_container_width=True)

                # Create heatmap
                fig = px.imshow(
                    df.values,
                    x=[f"Head {i}" for i in df.columns],
                    y=[f"Layer {i}" for i in df.index],
                    color_continuous_scale="Viridis",
                    title=f"UAS Scores Heatmap - {selected_relation}",
                    labels=dict(color="UAS Score")
                )
                fig.update_layout(height=600)
                st.plotly_chart(fig, use_container_width=True)

                # Statistics
                st.markdown("**Statistics**")
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Max Score", f"{df.values.max():.4f}")
                with col2:
                    st.metric("Min Score", f"{df.values.min():.4f}")
                with col3:
                    st.metric("Mean Score", f"{df.values.mean():.4f}")
                with col4:
                    st.metric("Std Dev", f"{df.values.std():.4f}")
        else:
            st.warning("No UAS score data available for this configuration.")
    # Tab 3: Head Matching
    with tab3:
        st.markdown('<div class="section-header">Attention Head Matching Analysis</div>', unsafe_allow_html=True)

        heads_data = explorer._load_head_matching(selected_language, selected_config, selected_model)

        if heads_data:
            # Relation selection
            selected_relation = st.selectbox(
                "Select Dependency Relation",
                options=list(heads_data.keys()),
                help="Choose a dependency relation to visualize head matching patterns",
                key="heads_relation"
            )

            if selected_relation and selected_relation in heads_data:
                df = heads_data[selected_relation]

                # Display the data table
                st.markdown("**Head Matching Counts Matrix (Layer × Head)**")
                st.dataframe(df, use_container_width=True)

                # Create heatmap
                fig = px.imshow(
                    df.values,
                    x=[f"Head {i}" for i in df.columns],
                    y=[f"Layer {i}" for i in df.index],
                    color_continuous_scale="Blues",
                    title=f"Head Matching Counts - {selected_relation}",
                    labels=dict(color="Match Count")
                )
                fig.update_layout(height=600)
                st.plotly_chart(fig, use_container_width=True)

                # Create bar chart of total matches per layer
                layer_totals = df.sum(axis=1)
                fig_bar = px.bar(
                    x=layer_totals.index,
                    y=layer_totals.values,
                    title=f"Total Matches per Layer - {selected_relation}",
                    labels={"x": "Layer", "y": "Total Matches"}
                )
                fig_bar.update_layout(height=400)
                st.plotly_chart(fig_bar, use_container_width=True)

                # Statistics
                st.markdown("**Statistics**")
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Total Matches", int(df.values.sum()))
                with col2:
                    st.metric("Max per Cell", int(df.values.max()))
                with col3:
                    best_layer = layer_totals.idxmax()
                    st.metric("Best Layer", f"Layer {best_layer}")
                with col4:
                    best_head_idx = np.unravel_index(df.values.argmax(), df.values.shape)
                    st.metric("Best Head", f"L{best_head_idx[0]}-H{best_head_idx[1]}")
        else:
            st.warning("No head matching data available for this configuration.")
    # Tab 4: Variability
    with tab4:
        st.markdown('<div class="section-header">Attention Variability Analysis</div>', unsafe_allow_html=True)

        variability_data = explorer._load_variability(selected_language, selected_config, selected_model)

        if variability_data is not None:
            # Display the data table
            st.markdown("**Variability Matrix (Layer × Head)**")
            st.dataframe(variability_data, use_container_width=True)

            # Create heatmap
            fig = px.imshow(
                variability_data.values,
                x=[f"Head {i}" for i in variability_data.columns],
                y=[f"Layer {i}" for i in variability_data.index],
                color_continuous_scale="Reds",
                title="Attention Variability Heatmap",
                labels=dict(color="Variability Score")
            )
            fig.update_layout(height=600)
            st.plotly_chart(fig, use_container_width=True)

            # Create line plot for variability trends
            fig_line = go.Figure()
            for col in variability_data.columns:
                fig_line.add_trace(go.Scatter(
                    x=variability_data.index,
                    y=variability_data[col],
                    mode='lines+markers',
                    name=f'Head {col}',
                    line=dict(width=2)
                ))
            fig_line.update_layout(
                title="Variability Trends Across Layers",
                xaxis_title="Layer",
                yaxis_title="Variability Score",
                height=500
            )
            st.plotly_chart(fig_line, use_container_width=True)

            # Statistics
            st.markdown("**Statistics**")
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Max Variability", f"{variability_data.values.max():.4f}")
            with col2:
                st.metric("Min Variability", f"{variability_data.values.min():.4f}")
            with col3:
                st.metric("Mean Variability", f"{variability_data.values.mean():.4f}")
            with col4:
                most_variable_idx = np.unravel_index(variability_data.values.argmax(), variability_data.values.shape)
                st.metric("Most Variable", f"L{most_variable_idx[0]}-H{most_variable_idx[1]}")
        else:
            st.warning("No variability data available for this configuration.")
    # Tab 5: Figures
    with tab5:
        st.markdown('<div class="section-header">Generated Figures</div>', unsafe_allow_html=True)

        figures = explorer._get_available_figures(selected_language, selected_config, selected_model)

        if figures:
            st.markdown(f"**Available Figures: {len(figures)}**")

            # Group figures by relation type
            figure_groups = {}
            for fig_path in figures:
                # Extract relation from filename
                filename = fig_path.stem
                relation = filename.replace("heads_matching_", "").replace(f"_{selected_model}", "")
                if relation not in figure_groups:
                    figure_groups[relation] = []
                figure_groups[relation].append(fig_path)

            # Select relation to view
            selected_fig_relation = st.selectbox(
                "Select Relation for Figure View",
                options=list(figure_groups.keys()),
                help="Choose a dependency relation to view its figure"
            )

            if selected_fig_relation and selected_fig_relation in figure_groups:
                fig_path = figure_groups[selected_fig_relation][0]
                st.markdown(f"**Figure: {fig_path.name}**")
                st.markdown(f"**Path:** `{fig_path}`")

                # Note about PDF viewing
                st.info(
                    "📄 PDF figures are available in the results directory. "
                    "Due to Streamlit limitations, PDF files cannot be displayed directly in the browser. "
                    "You can download or view them locally."
                )

                # Provide download link
                try:
                    with open(fig_path, "rb") as file:
                        st.download_button(
                            label=f"📥 Download {fig_path.name}",
                            data=file.read(),
                            file_name=fig_path.name,
                            mime="application/pdf"
                        )
                except Exception as e:
                    st.error(f"Could not load figure: {e}")

            # List all available figures
            st.markdown("**All Available Figures:**")
            for relation, paths in figure_groups.items():
                with st.expander(f"📊 {relation} ({len(paths)} files)"):
                    for path in paths:
                        st.markdown(f"- `{path.name}`")
        else:
            st.warning("No figures available for this configuration.")
    # Footer
    st.markdown("---")

    # Data source information
    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown(
            "🔬 **Attention Analysis Results Explorer** | "
            f"Currently viewing: {selected_language.upper()} - {selected_model} | "
            "Built with Streamlit"
        )
    with col2:
        st.markdown(
            f"📊 **Data Source**: [GitHub Repository](https://github.com/{explorer.github_repo})"
        )


if __name__ == "__main__":
    main()