Spaces:
Running
Running
Update app.py (#1)
Browse files- Update app.py (b7159a07f2b7b4eb2ba22113df0d4c04de90a6bc)
app.py
CHANGED
|
@@ -7,6 +7,7 @@ import base64
|
|
| 7 |
import math
|
| 8 |
import ast
|
| 9 |
import logging
|
|
|
|
| 10 |
|
| 11 |
# Set up logging
|
| 12 |
logging.basicConfig(level=logging.DEBUG)
|
|
@@ -55,7 +56,7 @@ def ensure_float(value):
|
|
| 55 |
return float(value)
|
| 56 |
return None
|
| 57 |
|
| 58 |
-
# Function to process and visualize log probs
|
| 59 |
def visualize_logprobs(json_input):
|
| 60 |
try:
|
| 61 |
# Parse the input (handles both JSON and Python dictionaries)
|
|
@@ -69,30 +70,82 @@ def visualize_logprobs(json_input):
|
|
| 69 |
else:
|
| 70 |
raise ValueError("Input must be a list or dictionary with 'content' key")
|
| 71 |
|
| 72 |
-
# Extract tokens
|
| 73 |
tokens = []
|
| 74 |
logprobs = []
|
|
|
|
| 75 |
for entry in content:
|
| 76 |
logprob = ensure_float(entry.get("logprob", None))
|
| 77 |
if logprob is not None and math.isfinite(logprob):
|
| 78 |
tokens.append(entry["token"])
|
| 79 |
logprobs.append(logprob)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
else:
|
| 81 |
logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))
|
| 82 |
|
| 83 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
table_data = []
|
| 85 |
-
for entry in content:
|
| 86 |
logprob = ensure_float(entry.get("logprob", None))
|
| 87 |
-
|
| 88 |
-
if (
|
| 89 |
-
logprob is not None
|
| 90 |
-
and math.isfinite(logprob)
|
| 91 |
-
and "top_logprobs" in entry
|
| 92 |
-
and entry["top_logprobs"] is not None
|
| 93 |
-
):
|
| 94 |
token = entry["token"]
|
| 95 |
-
logger.debug("Processing token: %s, logprob: %s (type: %s)", token, logprob, type(logprob))
|
| 96 |
top_logprobs = entry["top_logprobs"]
|
| 97 |
# Ensure all values in top_logprobs are floats
|
| 98 |
finite_top_logprobs = {}
|
|
@@ -100,44 +153,15 @@ def visualize_logprobs(json_input):
|
|
| 100 |
float_value = ensure_float(value)
|
| 101 |
if float_value is not None and math.isfinite(float_value):
|
| 102 |
finite_top_logprobs[key] = float_value
|
| 103 |
-
|
| 104 |
# Extract top 3 alternatives from top_logprobs
|
| 105 |
-
top_3 = sorted(
|
| 106 |
-
finite_top_logprobs.items(), key=lambda x: x[1], reverse=True
|
| 107 |
-
)[:3]
|
| 108 |
row = [token, f"{logprob:.4f}"]
|
| 109 |
for alt_token, alt_logprob in top_3:
|
| 110 |
row.append(f"{alt_token}: {alt_logprob:.4f}")
|
| 111 |
-
# Pad with empty strings if fewer than 3 alternatives
|
| 112 |
while len(row) < 5:
|
| 113 |
row.append("")
|
| 114 |
table_data.append(row)
|
| 115 |
|
| 116 |
-
# Create the plot
|
| 117 |
-
if logprobs:
|
| 118 |
-
plt.figure(figsize=(10, 5))
|
| 119 |
-
plt.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b")
|
| 120 |
-
plt.title("Log Probabilities of Generated Tokens")
|
| 121 |
-
plt.xlabel("Token Position")
|
| 122 |
-
plt.ylabel("Log Probability")
|
| 123 |
-
plt.grid(True)
|
| 124 |
-
plt.xticks(range(len(logprobs)), tokens, rotation=45, ha="right")
|
| 125 |
-
plt.tight_layout()
|
| 126 |
-
|
| 127 |
-
# Save plot to a bytes buffer
|
| 128 |
-
buf = io.BytesIO()
|
| 129 |
-
plt.savefig(buf, format="png", bbox_inches="tight")
|
| 130 |
-
buf.seek(0)
|
| 131 |
-
plt.close()
|
| 132 |
-
|
| 133 |
-
# Convert to base64 for Gradio
|
| 134 |
-
img_bytes = buf.getvalue()
|
| 135 |
-
img_base64 = base64.b64encode(img_bytes).decode("utf-8")
|
| 136 |
-
img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
|
| 137 |
-
else:
|
| 138 |
-
img_html = "No finite log probabilities to plot."
|
| 139 |
-
|
| 140 |
-
# Create DataFrame for the table
|
| 141 |
df = (
|
| 142 |
pd.DataFrame(
|
| 143 |
table_data,
|
|
@@ -177,11 +201,22 @@ def visualize_logprobs(json_input):
|
|
| 177 |
else:
|
| 178 |
colored_text_html = "No finite log probabilities to display."
|
| 179 |
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
except Exception as e:
|
| 183 |
logger.error("Visualization failed: %s", str(e))
|
| 184 |
-
return f"Error: {str(e)}", None, None
|
| 185 |
|
| 186 |
# Gradio interface
|
| 187 |
with gr.Blocks(title="Log Probability Visualizer") as app:
|
|
@@ -196,15 +231,16 @@ with gr.Blocks(title="Log Probability Visualizer") as app:
|
|
| 196 |
placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
|
| 197 |
)
|
| 198 |
|
| 199 |
-
plot_output = gr.HTML(label="Log Probability Plot")
|
| 200 |
table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
|
| 201 |
text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
|
|
|
|
| 202 |
|
| 203 |
btn = gr.Button("Visualize")
|
| 204 |
btn.click(
|
| 205 |
fn=visualize_logprobs,
|
| 206 |
inputs=json_input,
|
| 207 |
-
outputs=[plot_output, table_output, text_output],
|
| 208 |
)
|
| 209 |
|
| 210 |
app.launch()
|
|
|
|
| 7 |
import math
|
| 8 |
import ast
|
| 9 |
import logging
|
| 10 |
+
from matplotlib.widgets import Cursor
|
| 11 |
|
| 12 |
# Set up logging
|
| 13 |
logging.basicConfig(level=logging.DEBUG)
|
|
|
|
| 56 |
return float(value)
|
| 57 |
return None
|
| 58 |
|
| 59 |
+
# Function to process and visualize log probs with hover and alternatives
|
| 60 |
def visualize_logprobs(json_input):
|
| 61 |
try:
|
| 62 |
# Parse the input (handles both JSON and Python dictionaries)
|
|
|
|
| 70 |
else:
|
| 71 |
raise ValueError("Input must be a list or dictionary with 'content' key")
|
| 72 |
|
| 73 |
+
# Extract tokens, log probs, and top alternatives, skipping None or non-finite values
|
| 74 |
tokens = []
|
| 75 |
logprobs = []
|
| 76 |
+
top_alternatives = [] # List to store top 3 log probs (selected token + 2 alternatives)
|
| 77 |
for entry in content:
|
| 78 |
logprob = ensure_float(entry.get("logprob", None))
|
| 79 |
if logprob is not None and math.isfinite(logprob):
|
| 80 |
tokens.append(entry["token"])
|
| 81 |
logprobs.append(logprob)
|
| 82 |
+
# Get top_logprobs, default to empty dict if None
|
| 83 |
+
top_probs = entry.get("top_logprobs", {})
|
| 84 |
+
# Ensure all values in top_logprobs are floats
|
| 85 |
+
finite_top_probs = {}
|
| 86 |
+
for key, value in top_probs.items():
|
| 87 |
+
float_value = ensure_float(value)
|
| 88 |
+
if float_value is not None and math.isfinite(float_value):
|
| 89 |
+
finite_top_probs[key] = float_value
|
| 90 |
+
# Get the top 3 log probs (including the selected token)
|
| 91 |
+
all_probs = {entry["token"]: logprob} # Add the selected token's logprob
|
| 92 |
+
all_probs.update(finite_top_probs) # Add alternatives
|
| 93 |
+
sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
|
| 94 |
+
top_3 = sorted_probs[:3] # Top 3 log probs (highest to lowest)
|
| 95 |
+
top_alternatives.append(top_3)
|
| 96 |
else:
|
| 97 |
logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))
|
| 98 |
|
| 99 |
+
# Create the plot with hover functionality
|
| 100 |
+
if logprobs:
|
| 101 |
+
fig, ax = plt.subplots(figsize=(10, 5))
|
| 102 |
+
scatter = ax.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Selected Token")[0]
|
| 103 |
+
ax.set_title("Log Probabilities of Generated Tokens")
|
| 104 |
+
ax.set_xlabel("Token Position")
|
| 105 |
+
ax.set_ylabel("Log Probability")
|
| 106 |
+
ax.grid(True)
|
| 107 |
+
ax.set_xticks([]) # Hide X-axis labels by default
|
| 108 |
+
|
| 109 |
+
# Add hover functionality using Matplotlib's Cursor for tooltips
|
| 110 |
+
cursor = Cursor(ax, useblit=True, color='red', linewidth=1)
|
| 111 |
+
token_annotations = []
|
| 112 |
+
for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
|
| 113 |
+
annotation = ax.annotate('', (x, y), xytext=(10, 10), textcoords='offset points', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), visible=False)
|
| 114 |
+
token_annotations.append(annotation)
|
| 115 |
+
|
| 116 |
+
def on_hover(event):
|
| 117 |
+
if event.inaxes == ax:
|
| 118 |
+
for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
|
| 119 |
+
contains, _ = scatter.contains(event)
|
| 120 |
+
if contains and abs(event.xdata - x) < 0.5 and abs(event.ydata - y) < 0.5:
|
| 121 |
+
token_annotations[i].set_text(tokens[i])
|
| 122 |
+
token_annotations[i].set_visible(True)
|
| 123 |
+
fig.canvas.draw_idle()
|
| 124 |
+
else:
|
| 125 |
+
token_annotations[i].set_visible(False)
|
| 126 |
+
fig.canvas.draw_idle()
|
| 127 |
+
|
| 128 |
+
fig.canvas.mpl_connect('motion_notify_event', on_hover)
|
| 129 |
+
|
| 130 |
+
# Save plot to a bytes buffer
|
| 131 |
+
buf = io.BytesIO()
|
| 132 |
+
plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
|
| 133 |
+
buf.seek(0)
|
| 134 |
+
plt.close()
|
| 135 |
+
|
| 136 |
+
# Convert to base64 for Gradio
|
| 137 |
+
img_bytes = buf.getvalue()
|
| 138 |
+
img_base64 = base64.b64encode(img_bytes).decode("utf-8")
|
| 139 |
+
img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
|
| 140 |
+
else:
|
| 141 |
+
img_html = "No finite log probabilities to plot."
|
| 142 |
+
|
| 143 |
+
# Create DataFrame for the table
|
| 144 |
table_data = []
|
| 145 |
+
for i, entry in enumerate(content):
|
| 146 |
logprob = ensure_float(entry.get("logprob", None))
|
| 147 |
+
if logprob is not None and math.isfinite(logprob) and "top_logprobs" in entry and entry["top_logprobs"] is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
token = entry["token"]
|
|
|
|
| 149 |
top_logprobs = entry["top_logprobs"]
|
| 150 |
# Ensure all values in top_logprobs are floats
|
| 151 |
finite_top_logprobs = {}
|
|
|
|
| 153 |
float_value = ensure_float(value)
|
| 154 |
if float_value is not None and math.isfinite(float_value):
|
| 155 |
finite_top_logprobs[key] = float_value
|
|
|
|
| 156 |
# Extract top 3 alternatives from top_logprobs
|
| 157 |
+
top_3 = sorted(finite_top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
|
|
|
|
|
|
|
| 158 |
row = [token, f"{logprob:.4f}"]
|
| 159 |
for alt_token, alt_logprob in top_3:
|
| 160 |
row.append(f"{alt_token}: {alt_logprob:.4f}")
|
|
|
|
| 161 |
while len(row) < 5:
|
| 162 |
row.append("")
|
| 163 |
table_data.append(row)
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
df = (
|
| 166 |
pd.DataFrame(
|
| 167 |
table_data,
|
|
|
|
| 201 |
else:
|
| 202 |
colored_text_html = "No finite log probabilities to display."
|
| 203 |
|
| 204 |
+
# Create an alternative visualization for top 3 tokens
|
| 205 |
+
alt_viz_html = ""
|
| 206 |
+
if logprobs and top_alternatives:
|
| 207 |
+
alt_viz_html = "<h3>Top 3 Token Log Probabilities</h3><ul>"
|
| 208 |
+
for i, (token, probs) in enumerate(zip(tokens, top_alternatives)):
|
| 209 |
+
alt_viz_html += f"<li>Position {i} (Token: {token}):<br>"
|
| 210 |
+
for tok, prob in probs:
|
| 211 |
+
alt_viz_html += f"{tok}: {prob:.4f}<br>"
|
| 212 |
+
alt_viz_html += "</li>"
|
| 213 |
+
alt_viz_html += "</ul>"
|
| 214 |
+
|
| 215 |
+
return img_html, df, colored_text_html, alt_viz_html
|
| 216 |
|
| 217 |
except Exception as e:
|
| 218 |
logger.error("Visualization failed: %s", str(e))
|
| 219 |
+
return f"Error: {str(e)}", None, None, None
|
| 220 |
|
| 221 |
# Gradio interface
|
| 222 |
with gr.Blocks(title="Log Probability Visualizer") as app:
|
|
|
|
| 231 |
placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
|
| 232 |
)
|
| 233 |
|
| 234 |
+
plot_output = gr.HTML(label="Log Probability Plot (Hover for Tokens)")
|
| 235 |
table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
|
| 236 |
text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
|
| 237 |
+
alt_viz_output = gr.HTML(label="Top 3 Token Log Probabilities")
|
| 238 |
|
| 239 |
btn = gr.Button("Visualize")
|
| 240 |
btn.click(
|
| 241 |
fn=visualize_logprobs,
|
| 242 |
inputs=json_input,
|
| 243 |
+
outputs=[plot_output, table_output, text_output, alt_viz_output],
|
| 244 |
)
|
| 245 |
|
| 246 |
app.launch()
|