import gradio as gr
import json
import matplotlib.pyplot as plt
import pandas as pd
import io
import base64
import math
import ast
import logging
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy import stats
from scipy.stats import entropy
from scipy.signal import correlate
import networkx as nx
from matplotlib.widgets import Cursor

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
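
# This app takes model token log-probability output (a "content" list of per-token
# entries) pasted as JSON or a Python dict and renders a series of exploratory
# plots and tables in a Gradio UI.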

# Function to safely parse JSON or Python dictionary input
def parse_input(json_input):
    logger.debug("Attempting to parse input: %s", json_input)
    try:
        # Try to parse as JSON first
        data = json.loads(json_input)
        logger.debug("Successfully parsed as JSON")
        return data
    except json.JSONDecodeError as e:
        logger.error("JSON parsing failed: %s", str(e))
        try:
            # If JSON fails, try to parse as a Python literal (e.g., single-quoted dicts)
            data = ast.literal_eval(json_input)
            logger.debug("Successfully parsed as Python literal")

            # Recursively normalize to a JSON-compatible structure (string keys, nested dicts/lists)
            def dict_to_json(obj):
                if isinstance(obj, dict):
                    return {str(k): dict_to_json(v) for k, v in obj.items()}
                elif isinstance(obj, list):
                    return [dict_to_json(item) for item in obj]
                else:
                    return obj

            converted_data = dict_to_json(data)
            logger.debug("Converted to JSON-compatible format")
            return converted_data
        except (SyntaxError, ValueError) as e:
            logger.error("Python literal parsing failed: %s", str(e))
            raise ValueError(f"Malformed input: {str(e)}. Ensure property names are in double quotes (e.g., \"content\") or use valid Python dictionary syntax.")
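
# Example of the expected input shape (hypothetical tokens and values):
# {"content": [{"token": "Hello", "logprob": -0.12,
#               "top_logprobs": {"Hello": -0.12, "Hi": -2.4, "Hey": -3.1}},
#              ...]}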

# Function to ensure a value is a float, converting from string if necessary
def ensure_float(value):
    if value is None:
        return None
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            logger.error("Failed to convert string '%s' to float", value)
            return None
    if isinstance(value, (int, float)):
        return float(value)
    return None

# Function to process and visualize log probs with multiple analyses
def visualize_logprobs(json_input, prob_filter=-float('inf')):
    try:
        # Parse the input (handles both JSON and Python dictionaries)
        data = parse_input(json_input)

        # Ensure data is a list or a dictionary with a 'content' key
        if isinstance(data, dict) and "content" in data:
            content = data["content"]
        elif isinstance(data, list):
            content = data
        else:
            raise ValueError("Input must be a list or dictionary with 'content' key")

        # Extract tokens, log probs, and top alternatives, skipping None or non-finite values
        tokens = []
        logprobs = []
        top_alternatives = []  # Top 3 log probs per position (selected token + 2 alternatives)
        token_types = []  # Simplified token type categorization
        for entry in content:
            logprob = ensure_float(entry.get("logprob", None))
            if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter:
                tokens.append(entry["token"])
                logprobs.append(logprob)

                # Categorize token type (simple heuristic)
                token = entry["token"].lower().strip()
                if token in ["the", "a", "an"]:
                    token_types.append("article")
                elif token in ["is", "are", "was", "were"]:
                    token_types.append("verb")
                elif token in ["top", "so", "need", "figure"]:
                    token_types.append("noun")
                else:
                    token_types.append("other")

                # Get top_logprobs, falling back to an empty dict when missing or None
                top_probs = entry.get("top_logprobs") or {}
                # Keep only finite float values
                finite_top_probs = {}
                for key, value in top_probs.items():
                    float_value = ensure_float(value)
                    if float_value is not None and math.isfinite(float_value):
                        finite_top_probs[key] = float_value

                # Get the top 3 log probs (including the selected token)
                all_probs = {entry["token"]: logprob}  # Selected token's logprob
                all_probs.update(finite_top_probs)  # Alternatives
                sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
                top_3 = sorted_probs[:3]  # Top 3 log probs (highest to lowest)
                top_alternatives.append(top_3)
            else:
                logger.debug(
                    "Skipping entry with logprob: %s (type: %s)",
                    entry.get("logprob"), type(entry.get("logprob", None)),
                )

        # If no valid data after filtering, return an error message for every output
        if not logprobs:
            return ("No finite log probabilities to visualize after filtering.",) + (None,) * 19

        # 1. Main Log Probability Plot (with click for tokens)
        # Note: the click annotations only respond in an interactive Matplotlib backend;
        # the figure is exported below as a static PNG for the Gradio HTML output.
        fig_main, ax_main = plt.subplots(figsize=(10, 5))
        scatter = ax_main.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Selected Token")[0]
        ax_main.set_title("Log Probabilities of Generated Tokens")
        ax_main.set_xlabel("Token Position")
        ax_main.set_ylabel("Log Probability")
        ax_main.grid(True)
        ax_main.set_xticks([])  # Hide x-axis labels by default

        # Add click functionality to show the token at a point
        token_annotations = []
        for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
            annotation = ax_main.annotate(
                '', (x, y), xytext=(10, 10), textcoords='offset points',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), visible=False,
            )
            token_annotations.append(annotation)

        def on_click(event):
            if event.inaxes == ax_main:
                for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
                    contains, _ = scatter.contains(event)
                    if contains and abs(event.xdata - x) < 0.5 and abs(event.ydata - y) < 0.5:
                        token_annotations[i].set_text(tokens[i])
                        token_annotations[i].set_visible(True)
                        fig_main.canvas.draw_idle()
                    else:
                        token_annotations[i].set_visible(False)
                        fig_main.canvas.draw_idle()

        fig_main.canvas.mpl_connect('button_press_event', on_click)

        # Save main plot
        buf_main = io.BytesIO()
        plt.savefig(buf_main, format="png", bbox_inches="tight", dpi=100)
        buf_main.seek(0)
        plt.close(fig_main)
        img_main_bytes = buf_main.getvalue()
        img_main_base64 = base64.b64encode(img_main_bytes).decode("utf-8")
        img_main_html = f'<img src="data:image/png;base64,{img_main_base64}" style="max-width: 100%; height: auto;">'

        # 2. K-Means Clustering of Log Probabilities
        n_clusters = min(3, len(logprobs))  # KMeans needs n_clusters <= number of samples
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        cluster_labels = kmeans.fit_predict(np.array(logprobs).reshape(-1, 1))
        fig_cluster, ax_cluster = plt.subplots(figsize=(10, 5))
        scatter = ax_cluster.scatter(range(len(logprobs)), logprobs, c=cluster_labels, cmap='viridis')
        ax_cluster.set_title("K-Means Clustering of Log Probabilities")
        ax_cluster.set_xlabel("Token Position")
        ax_cluster.set_ylabel("Log Probability")
        ax_cluster.grid(True)
        plt.colorbar(scatter, ax=ax_cluster, label="Cluster")
        buf_cluster = io.BytesIO()
        plt.savefig(buf_cluster, format="png", bbox_inches="tight", dpi=100)
        buf_cluster.seek(0)
        plt.close(fig_cluster)
        img_cluster_bytes = buf_cluster.getvalue()
        img_cluster_base64 = base64.b64encode(img_cluster_bytes).decode("utf-8")
        img_cluster_html = f'<img src="data:image/png;base64,{img_cluster_base64}" style="max-width: 100%; height: auto;">'

        # 3. Probability Drop Analysis
        drops = [logprobs[i + 1] - logprobs[i] if i < len(logprobs) - 1 else 0 for i in range(len(logprobs))]
        fig_drops, ax_drops = plt.subplots(figsize=(10, 5))
        ax_drops.bar(range(len(drops)), drops, color='red', alpha=0.5)
        ax_drops.set_title("Significant Probability Drops")
        ax_drops.set_xlabel("Token Position")
        ax_drops.set_ylabel("Log Probability Drop")
        ax_drops.grid(True)
        buf_drops = io.BytesIO()
        plt.savefig(buf_drops, format="png", bbox_inches="tight", dpi=100)
        buf_drops.seek(0)
        plt.close(fig_drops)
        img_drops_bytes = buf_drops.getvalue()
        img_drops_base64 = base64.b64encode(img_drops_bytes).decode("utf-8")
        img_drops_html = f'<img src="data:image/png;base64,{img_drops_base64}" style="max-width: 100%; height: auto;">'

        # 4. N-Gram Analysis (bigrams for simplicity)
        bigrams = [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
        bigram_probs = [logprobs[i] + logprobs[i + 1] for i in range(len(tokens) - 1)]
        fig_ngram, ax_ngram = plt.subplots(figsize=(10, 5))
        ax_ngram.bar(range(len(bigrams)), bigram_probs, color='green')
        ax_ngram.set_title("N-Gram (Bigrams) Probability Sum")
        ax_ngram.set_xlabel("Bigram Position")
        ax_ngram.set_ylabel("Sum of Log Probabilities")
        ax_ngram.set_xticks(range(len(bigrams)))
        ax_ngram.set_xticklabels([f"{b[0]}->{b[1]}" for b in bigrams], rotation=45, ha="right")
        ax_ngram.grid(True)
        buf_ngram = io.BytesIO()
        plt.savefig(buf_ngram, format="png", bbox_inches="tight", dpi=100)
        buf_ngram.seek(0)
        plt.close(fig_ngram)
        img_ngram_bytes = buf_ngram.getvalue()
        img_ngram_base64 = base64.b64encode(img_ngram_bytes).decode("utf-8")
        img_ngram_html = f'<img src="data:image/png;base64,{img_ngram_base64}" style="max-width: 100%; height: auto;">'

        # 5. Markov Chain Modeling (simple transition graph)
        G = nx.DiGraph()
        for i in range(len(tokens) - 1):
            G.add_edge(tokens[i], tokens[i + 1], weight=logprobs[i + 1] - logprobs[i])
        fig_markov, ax_markov = plt.subplots(figsize=(10, 5))
        pos = nx.spring_layout(G)
        nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, edge_color='gray', width=1, ax=ax_markov)
        ax_markov.set_title("Markov Chain of Token Transitions")
        buf_markov = io.BytesIO()
        plt.savefig(buf_markov, format="png", bbox_inches="tight", dpi=100)
        buf_markov.seek(0)
        plt.close(fig_markov)
        img_markov_bytes = buf_markov.getvalue()
        img_markov_base64 = base64.b64encode(img_markov_bytes).decode("utf-8")
        img_markov_html = f'<img src="data:image/png;base64,{img_markov_base64}" style="max-width: 100%; height: auto;">'

        # 6. Anomaly Detection (outlier detection with z-score)
        z_scores = np.abs(stats.zscore(logprobs))
        outliers = z_scores > 2  # Threshold for outliers
        fig_anomaly, ax_anomaly = plt.subplots(figsize=(10, 5))
        ax_anomaly.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b")
        ax_anomaly.plot(np.where(outliers)[0], [logprobs[i] for i in np.where(outliers)[0]], "ro", label="Outliers")
        ax_anomaly.set_title("Log Probabilities with Outliers")
        ax_anomaly.set_xlabel("Token Position")
        ax_anomaly.set_ylabel("Log Probability")
        ax_anomaly.grid(True)
        ax_anomaly.legend()
        ax_anomaly.set_xticks([])  # Hide x-axis labels
        buf_anomaly = io.BytesIO()
        plt.savefig(buf_anomaly, format="png", bbox_inches="tight", dpi=100)
        buf_anomaly.seek(0)
        plt.close(fig_anomaly)
        img_anomaly_bytes = buf_anomaly.getvalue()
        img_anomaly_base64 = base64.b64encode(img_anomaly_bytes).decode("utf-8")
        img_anomaly_html = f'<img src="data:image/png;base64,{img_anomaly_base64}" style="max-width: 100%; height: auto;">'

        # 7. Autocorrelation
        autocorr = correlate(logprobs, logprobs, mode='full')
        autocorr = autocorr[len(autocorr) // 2:] / len(logprobs)  # Keep non-negative lags, normalize
        fig_autocorr, ax_autocorr = plt.subplots(figsize=(10, 5))
        ax_autocorr.plot(range(len(autocorr)), autocorr, color='purple')
        ax_autocorr.set_title("Autocorrelation of Log Probabilities")
        ax_autocorr.set_xlabel("Lag")
        ax_autocorr.set_ylabel("Autocorrelation")
        ax_autocorr.grid(True)
        buf_autocorr = io.BytesIO()
        plt.savefig(buf_autocorr, format="png", bbox_inches="tight", dpi=100)
        buf_autocorr.seek(0)
        plt.close(fig_autocorr)
        img_autocorr_bytes = buf_autocorr.getvalue()
        img_autocorr_base64 = base64.b64encode(img_autocorr_bytes).decode("utf-8")
        img_autocorr_html = f'<img src="data:image/png;base64,{img_autocorr_base64}" style="max-width: 100%; height: auto;">'

        # 8. Smoothing (Moving Average)
        window_size = 3
        moving_avg = np.convolve(logprobs, np.ones(window_size) / window_size, mode='valid')
        fig_smoothing, ax_smoothing = plt.subplots(figsize=(10, 5))
        ax_smoothing.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Original")
        ax_smoothing.plot(range(window_size - 1, len(logprobs)), moving_avg, color="orange", label="Moving Average")
        ax_smoothing.set_title("Log Probabilities with Moving Average")
        ax_smoothing.set_xlabel("Token Position")
        ax_smoothing.set_ylabel("Log Probability")
        ax_smoothing.grid(True)
        ax_smoothing.legend()
        ax_smoothing.set_xticks([])  # Hide x-axis labels
        buf_smoothing = io.BytesIO()
        plt.savefig(buf_smoothing, format="png", bbox_inches="tight", dpi=100)
        buf_smoothing.seek(0)
        plt.close(fig_smoothing)
        img_smoothing_bytes = buf_smoothing.getvalue()
        img_smoothing_base64 = base64.b64encode(img_smoothing_bytes).decode("utf-8")
        img_smoothing_html = f'<img src="data:image/png;base64,{img_smoothing_base64}" style="max-width: 100%; height: auto;">'

        # 9. Uncertainty Propagation (variance of top log probs as an error band)
        variances = []
        for probs in top_alternatives:
            if len(probs) > 1:
                values = [p[1] for p in probs]
                variances.append(np.var(values))
            else:
                variances.append(0)
        fig_uncertainty, ax_uncertainty = plt.subplots(figsize=(10, 5))
        ax_uncertainty.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Log Prob")
        ax_uncertainty.fill_between(
            range(len(logprobs)),
            [lp - v for lp, v in zip(logprobs, variances)],
            [lp + v for lp, v in zip(logprobs, variances)],
            color='gray', alpha=0.3, label="Uncertainty",
        )
        ax_uncertainty.set_title("Log Probabilities with Uncertainty Propagation")
        ax_uncertainty.set_xlabel("Token Position")
        ax_uncertainty.set_ylabel("Log Probability")
        ax_uncertainty.grid(True)
        ax_uncertainty.legend()
        ax_uncertainty.set_xticks([])  # Hide x-axis labels
        buf_uncertainty = io.BytesIO()
        plt.savefig(buf_uncertainty, format="png", bbox_inches="tight", dpi=100)
        buf_uncertainty.seek(0)
        plt.close(fig_uncertainty)
        img_uncertainty_bytes = buf_uncertainty.getvalue()
        img_uncertainty_base64 = base64.b64encode(img_uncertainty_bytes).decode("utf-8")
        img_uncertainty_html = f'<img src="data:image/png;base64,{img_uncertainty_base64}" style="max-width: 100%; height: auto;">'

        # 10. Correlation Heatmap
        # np.corrcoef on the raw 1-D series collapses to a scalar, so correlate
        # overlapping windows of log probs across positions instead.
        window = min(5, len(logprobs))
        segments = np.array([logprobs[i:i + window] for i in range(len(logprobs) - window + 1)])
        corr_matrix = np.corrcoef(segments) if len(segments) > 1 else np.ones((1, 1))
        fig_corr, ax_corr = plt.subplots(figsize=(10, 5))
        im = ax_corr.imshow(corr_matrix, cmap='coolwarm', interpolation='nearest')
        ax_corr.set_title("Correlation of Log Probabilities Across Positions")
        ax_corr.set_xlabel("Token Position (Window Start)")
        ax_corr.set_ylabel("Token Position (Window Start)")
        plt.colorbar(im, ax=ax_corr, label="Correlation")
        buf_corr = io.BytesIO()
        plt.savefig(buf_corr, format="png", bbox_inches="tight", dpi=100)
        buf_corr.seek(0)
        plt.close(fig_corr)
        img_corr_bytes = buf_corr.getvalue()
        img_corr_base64 = base64.b64encode(img_corr_bytes).decode("utf-8")
        img_corr_html = f'<img src="data:image/png;base64,{img_corr_base64}" style="max-width: 100%; height: auto;">'

        # 11. Token Type Correlation
        type_probs = {t: [] for t in set(token_types)}
        for t, p in zip(token_types, logprobs):
            type_probs[t].append(p)
        fig_type, ax_type = plt.subplots(figsize=(10, 5))
        for t in type_probs:
            ax_type.bar(t, np.mean(type_probs[t]), yerr=np.std(type_probs[t]), capsize=5, label=t)
        ax_type.set_title("Average Log Probability by Token Type")
        ax_type.set_xlabel("Token Type")
        ax_type.set_ylabel("Average Log Probability")
        ax_type.grid(True)
        ax_type.legend()
        buf_type = io.BytesIO()
        plt.savefig(buf_type, format="png", bbox_inches="tight", dpi=100)
        buf_type.seek(0)
        plt.close(fig_type)
        img_type_bytes = buf_type.getvalue()
        img_type_base64 = base64.b64encode(img_type_bytes).decode("utf-8")
        img_type_html = f'<img src="data:image/png;base64,{img_type_base64}" style="max-width: 100%; height: auto;">'

        # 12. Token Embedding Similarity vs. Probability (simulated)
        # Simulate 2-D token embeddings with random values for demonstration purposes
        simulated_embeddings = np.random.rand(len(tokens), 2)
        fig_embed, ax_embed = plt.subplots(figsize=(10, 5))
        ax_embed.scatter(simulated_embeddings[:, 0], simulated_embeddings[:, 1], c=logprobs, cmap='viridis')
        ax_embed.set_title("Token Embedding Similarity vs. Log Probability")
        ax_embed.set_xlabel("Embedding Dimension 1")
        ax_embed.set_ylabel("Embedding Dimension 2")
        plt.colorbar(ax_embed.collections[0], ax=ax_embed, label="Log Probability")
        buf_embed = io.BytesIO()
        plt.savefig(buf_embed, format="png", bbox_inches="tight", dpi=100)
        buf_embed.seek(0)
        plt.close(fig_embed)
        img_embed_bytes = buf_embed.getvalue()
        img_embed_base64 = base64.b64encode(img_embed_bytes).decode("utf-8")
        img_embed_html = f'<img src="data:image/png;base64,{img_embed_base64}" style="max-width: 100%; height: auto;">'

        # 13. Bayesian Inference (simplified as inferred uncertainty)
        # Convert each position's top log probs to probabilities (entropy expects
        # non-negative weights; scipy normalizes them) and compute the entropy.
        entropies = [
            entropy(np.exp([p[1] for p in probs]), base=2)
            for probs in top_alternatives if len(probs) > 1
        ]
        fig_bayesian, ax_bayesian = plt.subplots(figsize=(10, 5))
        ax_bayesian.bar(range(len(entropies)), entropies, color='orange')
        ax_bayesian.set_title("Bayesian Inferred Uncertainty (Entropy)")
        ax_bayesian.set_xlabel("Token Position")
        ax_bayesian.set_ylabel("Entropy")
        ax_bayesian.grid(True)
        buf_bayesian = io.BytesIO()
        plt.savefig(buf_bayesian, format="png", bbox_inches="tight", dpi=100)
        buf_bayesian.seek(0)
        plt.close(fig_bayesian)
        img_bayesian_bytes = buf_bayesian.getvalue()
        img_bayesian_base64 = base64.b64encode(img_bayesian_bytes).decode("utf-8")
        img_bayesian_html = f'<img src="data:image/png;base64,{img_bayesian_base64}" style="max-width: 100%; height: auto;">'

        # 14. Graph-Based Analysis
        G = nx.DiGraph()
        for i in range(len(tokens) - 1):
            G.add_edge(tokens[i], tokens[i + 1], weight=logprobs[i + 1] - logprobs[i])
        fig_graph, ax_graph = plt.subplots(figsize=(10, 5))
        pos = nx.spring_layout(G)
        nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, edge_color='gray', width=1, ax=ax_graph)
        ax_graph.set_title("Graph of Token Transitions")
        buf_graph = io.BytesIO()
        plt.savefig(buf_graph, format="png", bbox_inches="tight", dpi=100)
        buf_graph.seek(0)
        plt.close(fig_graph)
        img_graph_bytes = buf_graph.getvalue()
        img_graph_base64 = base64.b64encode(img_graph_bytes).decode("utf-8")
        img_graph_html = f'<img src="data:image/png;base64,{img_graph_base64}" style="max-width: 100%; height: auto;">'

        # 15. Dimensionality Reduction (t-SNE)
        # Build one feature row per token: [selected logprob, alt 1, alt 2],
        # padding with the selected logprob when fewer alternatives are available.
        features = []
        for lp, alts in zip(logprobs, top_alternatives):
            alt_vals = [p[1] for p in alts[1:3]]
            alt_vals += [lp] * (2 - len(alt_vals))
            features.append([lp] + alt_vals)
        features = np.array(features)
        # Perplexity must be smaller than the number of samples
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, max(1, len(features) - 1)))
        tsne_result = tsne.fit_transform(features)
        fig_tsne, ax_tsne = plt.subplots(figsize=(10, 5))
        scatter = ax_tsne.scatter(tsne_result[:, 0], tsne_result[:, 1], c=logprobs, cmap='viridis')
        ax_tsne.set_title("t-SNE of Log Probabilities and Top Alternatives")
        ax_tsne.set_xlabel("t-SNE Dimension 1")
        ax_tsne.set_ylabel("t-SNE Dimension 2")
        plt.colorbar(scatter, ax=ax_tsne, label="Log Probability")
        buf_tsne = io.BytesIO()
        plt.savefig(buf_tsne, format="png", bbox_inches="tight", dpi=100)
        buf_tsne.seek(0)
        plt.close(fig_tsne)
        img_tsne_bytes = buf_tsne.getvalue()
        img_tsne_base64 = base64.b64encode(img_tsne_bytes).decode("utf-8")
        img_tsne_html = f'<img src="data:image/png;base64,{img_tsne_base64}" style="max-width: 100%; height: auto;">'

        # 16. Interactive Heatmap
        fig_heatmap, ax_heatmap = plt.subplots(figsize=(10, 5))
        im = ax_heatmap.imshow([logprobs], cmap='viridis', aspect='auto')
        ax_heatmap.set_title("Interactive Heatmap of Log Probabilities")
        ax_heatmap.set_xlabel("Token Position")
        ax_heatmap.set_ylabel("Probability Level")
        plt.colorbar(im, ax=ax_heatmap, label="Log Probability")
        buf_heatmap = io.BytesIO()
        plt.savefig(buf_heatmap, format="png", bbox_inches="tight", dpi=100)
        buf_heatmap.seek(0)
        plt.close(fig_heatmap)
        img_heatmap_bytes = buf_heatmap.getvalue()
        img_heatmap_base64 = base64.b64encode(img_heatmap_bytes).decode("utf-8")
        img_heatmap_html = f'<img src="data:image/png;base64,{img_heatmap_base64}" style="max-width: 100%; height: auto;">'

        # 17. Probability Distribution Plots (box plots for selected token and top alternatives)
        # alts[0] is the selected token; alts[1] and alts[2] (when present) are alternatives.
        alt1_probs = [alts[1][1] for alts in top_alternatives if len(alts) > 1]
        alt2_probs = [alts[2][1] for alts in top_alternatives if len(alts) > 2]
        fig_dist, ax_dist = plt.subplots(figsize=(10, 5))
        ax_dist.boxplot([logprobs, alt1_probs, alt2_probs], labels=["Selected", "Alt1", "Alt2"])
        ax_dist.set_title("Probability Distribution of Top Tokens")
        ax_dist.set_xlabel("Token Type")
        ax_dist.set_ylabel("Log Probability")
        ax_dist.grid(True)
        buf_dist = io.BytesIO()
        plt.savefig(buf_dist, format="png", bbox_inches="tight", dpi=100)
        buf_dist.seek(0)
        plt.close(fig_dist)
        img_dist_bytes = buf_dist.getvalue()
        img_dist_base64 = base64.b64encode(img_dist_bytes).decode("utf-8")
        img_dist_html = f'<img src="data:image/png;base64,{img_dist_base64}" style="max-width: 100%; height: auto;">'

        # Create DataFrame for the table
        table_data = []
        for i, entry in enumerate(content):
            logprob = ensure_float(entry.get("logprob", None))
            if (
                logprob is not None
                and math.isfinite(logprob)
                and logprob >= prob_filter
                and "top_logprobs" in entry
                and entry["top_logprobs"] is not None
            ):
                token = entry["token"]
                top_logprobs = entry["top_logprobs"]
                # Keep only finite float values
                finite_top_logprobs = {}
                for key, value in top_logprobs.items():
                    float_value = ensure_float(value)
                    if float_value is not None and math.isfinite(float_value):
                        finite_top_logprobs[key] = float_value
                # Extract top 3 alternatives from top_logprobs
                top_3 = sorted(finite_top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
                row = [token, f"{logprob:.4f}"]
                for alt_token, alt_logprob in top_3:
                    row.append(f"{alt_token}: {alt_logprob:.4f}")
                while len(row) < 5:
                    row.append("")
                table_data.append(row)

        df = (
            pd.DataFrame(
                table_data,
                columns=[
                    "Token",
                    "Log Prob",
                    "Top 1 Alternative",
                    "Top 2 Alternative",
                    "Top 3 Alternative",
                ],
            )
            if table_data
            else None
        )

        # Generate colored text (red = low confidence, green = high confidence)
        if logprobs:
            min_logprob = min(logprobs)
            max_logprob = max(logprobs)
            if max_logprob == min_logprob:
                normalized_probs = [0.5] * len(logprobs)
            else:
                normalized_probs = [
                    (lp - min_logprob) / (max_logprob - min_logprob) for lp in logprobs
                ]
            colored_text = ""
            for i, (token, norm_prob) in enumerate(zip(tokens, normalized_probs)):
                r = int(255 * (1 - norm_prob))  # Red for low confidence
                g = int(255 * norm_prob)  # Green for high confidence
                b = 0
                color = f"rgb({r}, {g}, {b})"
                colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
                if i < len(tokens) - 1:
                    colored_text += " "
            colored_text_html = f"<p>{colored_text}</p>"
        else:
            colored_text_html = "No finite log probabilities to display."

        # Top 3 Token Log Probabilities
        alt_viz_html = ""
        if logprobs and top_alternatives:
            alt_viz_html = "<h3>Top 3 Token Log Probabilities</h3><ul>"
            for i, (token, probs) in enumerate(zip(tokens, top_alternatives)):
                alt_viz_html += f"<li>Position {i} (Token: {token}):<br>"
                for tok, prob in probs:
                    alt_viz_html += f"{tok}: {prob:.4f}<br>"
                alt_viz_html += "</li>"
            alt_viz_html += "</ul>"

        return (img_main_html, df, colored_text_html, alt_viz_html, img_cluster_html, img_drops_html,
                img_ngram_html, img_markov_html, img_anomaly_html, img_autocorr_html, img_smoothing_html,
                img_uncertainty_html, img_corr_html, img_type_html, img_embed_html, img_bayesian_html,
                img_graph_html, img_tsne_html, img_heatmap_html, img_dist_html)
    except Exception as e:
        logger.error("Visualization failed: %s", str(e))
        return (f"Error: {str(e)}",) + (None,) * 19

# Gradio interface with dynamic filtering
with gr.Blocks(title="Log Probability Visualizer") as app:
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown(
        "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their "
        "probabilities. Use the filter to focus on specific log probability ranges."
    )
    with gr.Row():
        json_input = gr.Textbox(
            label="JSON Input",
            lines=10,
            placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
        )
        # Slider bounds must be finite; -100 effectively disables filtering for typical log probs
        prob_filter = gr.Slider(minimum=-100, maximum=0, value=-100, label="Log Probability Filter (≥)")
    with gr.Row():
        plot_output = gr.HTML(label="Log Probability Plot (Click for Tokens)")
        cluster_output = gr.HTML(label="K-Means Clustering")
        drops_output = gr.HTML(label="Probability Drops")
    with gr.Row():
        ngram_output = gr.HTML(label="N-Gram Analysis")
        markov_output = gr.HTML(label="Markov Chain")
    with gr.Row():
        anomaly_output = gr.HTML(label="Anomaly Detection")
        autocorr_output = gr.HTML(label="Autocorrelation")
    with gr.Row():
        smoothing_output = gr.HTML(label="Smoothing (Moving Average)")
        uncertainty_output = gr.HTML(label="Uncertainty Propagation")
    with gr.Row():
        corr_output = gr.HTML(label="Correlation Heatmap")
        type_output = gr.HTML(label="Token Type Correlation")
    with gr.Row():
        embed_output = gr.HTML(label="Embedding Similarity vs. Probability")
        bayesian_output = gr.HTML(label="Bayesian Inference (Entropy)")
    with gr.Row():
        graph_output = gr.HTML(label="Graph of Token Transitions")
        tsne_output = gr.HTML(label="t-SNE of Log Probabilities")
    with gr.Row():
        heatmap_output = gr.HTML(label="Interactive Heatmap")
        dist_output = gr.HTML(label="Probability Distribution")
    table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
    text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
    alt_viz_output = gr.HTML(label="Top 3 Token Log Probabilities")
    btn = gr.Button("Visualize")
    btn.click(
        fn=visualize_logprobs,
        inputs=[json_input, prob_filter],
        outputs=[
            plot_output, table_output, text_output, alt_viz_output,
            cluster_output, drops_output, ngram_output, markov_output,
            anomaly_output, autocorr_output, smoothing_output, uncertainty_output,
            corr_output, type_output, embed_output, bayesian_output,
            graph_output, tsne_output, heatmap_output, dist_output,
        ],
    )

app.launch()
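
# To run locally, install the imported packages (gradio, matplotlib, pandas, numpy,
# scikit-learn, scipy, networkx) and execute this file with Python; launch() starts
# the Gradio server.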