Spaces:
Runtime error
Runtime error
| import plotly.express as px | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| import seaborn as sns | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from matplotlib.ticker import PercentFormatter | |
def plot_wow_retention_by_type(wow_retention: pd.DataFrame):
    """Build a line chart of weekly retention rate, one line per trader type.

    Args:
        wow_retention: Frame with the columns ``"week"`` (parsed with the
            ``"%b-%d-%Y"`` format), ``"retention_rate"`` and
            ``"trader_type"``. The frame is not modified.

    Returns:
        A ``gr.Plot`` component wrapping the Plotly figure.
    """
    # Work on a copy so the caller's frame keeps its original "week" column.
    retention = wow_retention.copy()
    retention["week"] = pd.to_datetime(retention["week"], format="%b-%d-%Y")
    retention = retention.sort_values(["trader_type", "week"])

    fig = px.line(
        retention,
        x="week",
        y="retention_rate",
        color="trader_type",
        markers=True,
        title="Weekly Retention Rate by Trader Type",
        labels={
            "week": "Week",
            "retention_rate": "Retention Rate (%)",
            "trader_type": "Trader Type",
        },
        color_discrete_sequence=["purple", "goldenrod", "green"],
    )

    # Series.max() skips NaN and returns NaN for an empty column, whereas the
    # builtin max() raises on empty input and is order-dependent with NaN.
    y_axis = dict(ticksuffix="%")
    peak_rate = retention["retention_rate"].max()
    if pd.notna(peak_rate) and peak_rate > 0:
        # Pin the axis at zero and leave 10% headroom above the highest point.
        y_axis["range"] = [0, peak_rate * 1.1]
    # Otherwise no explicit range is set and Plotly autoranges the axis.

    fig.update_layout(
        hovermode="x unified",
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=0.99,
            orientation="v",
        ),
        yaxis=y_axis,
        xaxis=dict(tickformat="%Y-%m-%d"),
        margin=dict(r=200),  # right margin of 200 (legend is anchored at x=0.99)
        width=600,
        height=500,
    )
    # Same hover text for every trace: the rate, then the week as a date.
    fig.update_traces(
        hovertemplate="<b>%{y:.1f}%</b><br>Week: %{x|%Y-%m-%d}<extra></extra>"
    )
    return gr.Plot(value=fig)
def plot_cohort_retention_heatmap(retention_matrix: pd.DataFrame, cmap: str):
    """Render a cohort retention matrix as an annotated heatmap.

    Args:
        retention_matrix: One row per cohort and one column per week offset.
            The index must be convertible by ``pd.to_datetime``; NaN cells
            are masked out. The colour scale is fixed to 0-100, so values are
            presumably percentages (confirm against the caller). The frame is
            not modified.
        cmap: Colormap passed through to ``sns.heatmap``.

    Returns:
        A ``gr.Plot`` component wrapping the Matplotlib figure.
    """
    # Copy so relabelling the index does not touch the caller's frame.
    matrix = retention_matrix.copy()
    matrix.index = pd.to_datetime(matrix.index).strftime("%a-%b %d")

    # Draw on an explicit Figure/Axes instead of pyplot's implicit "current
    # figure", which is shared process-wide.
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.heatmap(
        data=matrix,
        annot=True,  # print the value inside each cell
        fmt=".1f",  # one decimal place
        cmap=cmap,
        vmin=0,
        vmax=100,
        center=50,
        cbar_kws={"label": "Retention Rate (%)", "format": PercentFormatter()},
        mask=matrix.isna(),  # hide cells with no data
        annot_kws={"size": 8},
        ax=ax,
    )

    ax.set_title("Cohort Retention Analysis", pad=20, size=14)
    ax.set_xlabel("Weeks Since First Activity", size=12)
    ax.set_ylabel("Cohort First Day of the Week", size=12)

    # Pin one tick at the centre of every column before labelling. seaborn may
    # thin its automatic ticks, and set_xticklabels needs the label count to
    # match the tick count.
    ax.set_xticks([position + 0.5 for position in range(len(matrix.columns))])
    ax.set_xticklabels(
        [f"Week {column}" for column in matrix.columns],
        rotation=45,
        ha="right",
    )
    ax.tick_params(axis="y", labelrotation=0)

    # Keep ticks and any grid lines behind the plotted artists.
    ax.set_axisbelow(True)

    # Adjust layout to prevent label cutoff.
    fig.tight_layout()

    # Deregister the figure from pyplot so repeated calls do not accumulate
    # open figures; the Figure object itself can still be rendered.
    plt.close(fig)
    return gr.Plot(value=fig)