Spaces:
Running
Running
| import gradio as gr | |
| import json | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import io | |
| import base64 | |
| import ast # For safely evaluating Python literals | |
def _parse_logprob_input(json_input):
    """Parse the pasted text and return its sequence of token entries.

    The text is tried as strict JSON first and then as a Python literal
    (which allows single-quoted strings).

    Raises:
        ValueError: if the text is empty, or the parsed value is not a
            mapping holding a list/tuple under the ``'content'`` key.
        SyntaxError: if the text is neither valid JSON nor a valid literal.
    """
    if not json_input or not json_input.strip():
        raise ValueError("input is empty")
    try:
        data = json.loads(json_input)
    except json.JSONDecodeError:
        # literal_eval only builds literals; it never executes code.
        data = ast.literal_eval(json_input)
    content = data.get('content') if isinstance(data, dict) else None
    if not isinstance(content, (list, tuple)):
        raise ValueError("expected an object with a 'content' list of token entries")
    return content


def _top_alternatives(top_logprobs, limit=3):
    """Return up to *limit* ``(token, logprob)`` pairs, most probable first.

    Accepts either a ``{token: logprob}`` mapping or a list of
    ``{'token': ..., 'logprob': ...}`` records; a missing or empty value
    yields an empty list. Alternatives whose log prob is ``None`` are dropped
    because they can be neither ranked nor formatted.
    """
    if not top_logprobs:
        return []
    if isinstance(top_logprobs, dict):
        pairs = top_logprobs.items()
    else:
        pairs = ((alt['token'], alt['logprob']) for alt in top_logprobs)
    ranked = sorted(
        (pair for pair in pairs if pair[1] is not None),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:limit]


def _render_plot_html(tokens, logprobs):
    """Plot *logprobs* against token position and return an ``<img>`` tag.

    The PNG is embedded as a base64 data URI. The figure is always closed,
    even when drawing or saving fails, so a long-running server does not
    accumulate open figures.
    """
    positions = list(range(len(logprobs)))
    # An unescaped '$' makes matplotlib treat the label as mathtext, which
    # can raise on ordinary model tokens; '\$' renders as a literal dollar.
    tick_labels = [str(token).replace('$', r'\$') for token in tokens]

    # Draw on an explicit figure/axes pair instead of pyplot's implicit
    # "current figure".
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(positions, logprobs, marker='o', linestyle='-', color='b')
        ax.set_title("Log Probabilities of Generated Tokens")
        ax.set_xlabel("Token Position")
        ax.set_ylabel("Log Probability")
        ax.grid(True)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)

    img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'


def visualize_logprobs(json_input):
    """Turn pasted log-prob data into a plot and a table of alternatives.

    Args:
        json_input: Text containing JSON or a Python dict literal with a
            ``'content'`` list. Each entry is read for ``'token'``,
            ``'logprob'`` and (optionally) ``'top_logprobs'``. Entries whose
            log prob is missing or ``None`` are skipped.

    Returns:
        A ``(html, dataframe)`` pair. On success ``html`` is an ``<img>`` tag
        with the line plot and ``dataframe`` has the columns Token, Log Prob
        and Top 1-3 Alternative. On any failure the pair is
        ``("Error: <message>", None)``; the message is HTML-escaped because
        the first output is rendered as HTML.
    """
    import html  # local import: only needed to escape the error message

    try:
        tokens = []
        logprobs = []
        table_data = []
        # Single pass: the plot series and the table rows share one filter.
        for entry in _parse_logprob_input(json_input):
            logprob = entry.get('logprob')
            if logprob is None:
                continue
            token = entry['token']
            tokens.append(token)
            logprobs.append(logprob)

            row = [token, f"{logprob:.4f}"]
            row.extend(
                f"{alt_token}: {alt_logprob:.4f}"
                for alt_token, alt_logprob in _top_alternatives(entry.get('top_logprobs'))
            )
            # Pad with empty strings if fewer than 3 alternatives.
            row.extend([""] * (5 - len(row)))
            table_data.append(row)

        img_html = _render_plot_html(tokens, logprobs)
        df = pd.DataFrame(
            table_data,
            columns=["Token", "Log Prob", "Top 1 Alternative", "Top 2 Alternative", "Top 3 Alternative"],
        )
        return img_html, df
    except Exception as e:
        # Top-level UI boundary: report the problem instead of crashing the
        # app, and escape it so exception text cannot inject markup.
        return f"Error: {html.escape(str(e), quote=False)}", None
# Build the Gradio page: a text input, a "Visualize" button, and two outputs
# (the rendered plot as HTML and the table of tokens with alternatives).
with gr.Blocks(title="Log Probability Visualizer") as app:
    # Heading followed by one line of usage instructions.
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown(
        "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities."
    )

    # Multi-line box where the user pastes the log-prob payload.
    json_input = gr.Textbox(
        label="JSON Input",
        lines=10,
        placeholder="Paste your JSON or Python dict here...",
    )

    # Output components, filled in by visualize_logprobs.
    plot_output = gr.HTML(label="Log Probability Plot")
    table_output = gr.Dataframe(
        label="Token Log Probabilities and Top Alternatives",
    )

    # Clicking the button runs the visualization on the pasted text.
    btn = gr.Button("Visualize")
    btn.click(fn=visualize_logprobs, inputs=json_input, outputs=[plot_output, table_output])

# Start the web server for the interface.
app.launch()