Spaces:
Running
Running
Update visualizer.
Browse files- main.py +1 -1
- src/visualizer.py +67 -44
main.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from src.execution_model import ScheduleConfig, ScheduleExecutor
|
| 2 |
from src.strategies import generate_1f1b_interleave_schedule, generate_1f1b_schedule
|
| 3 |
-
from src.visualizer import visualize_pipeline_parallelism_dash
|
| 4 |
import hydra
|
| 5 |
from omegaconf import DictConfig, OmegaConf
|
| 6 |
|
|
|
|
| 1 |
from src.execution_model import ScheduleConfig, ScheduleExecutor
|
| 2 |
from src.strategies import generate_1f1b_interleave_schedule, generate_1f1b_schedule
|
| 3 |
+
from src.visualizer import visualize_pipeline_parallelism_dash
|
| 4 |
import hydra
|
| 5 |
from omegaconf import DictConfig, OmegaConf
|
| 6 |
|
src/visualizer.py
CHANGED
|
@@ -55,24 +55,42 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
| 55 |
empty_color = "whitesmoke"
|
| 56 |
# Colors for task types
|
| 57 |
def get_color(op_type: str, stage_id: int):
|
| 58 |
-
#
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
virtual_stage = stage_id // num_devices
|
| 63 |
|
|
|
|
|
|
|
|
|
|
| 64 |
if op_type == "forward":
|
| 65 |
-
|
| 66 |
-
return forward_base_color
|
| 67 |
-
else:
|
| 68 |
-
# Lighter shade for virtual_stage > 0
|
| 69 |
-
return "lightskyblue"
|
| 70 |
elif op_type == "backward":
|
| 71 |
-
|
| 72 |
-
return backward_base_color
|
| 73 |
-
else:
|
| 74 |
-
# Lighter shade for virtual_stage > 0
|
| 75 |
-
return "lightseagreen"
|
| 76 |
else:
|
| 77 |
raise ValueError(f"Invalid operation type: {op_type}")
|
| 78 |
|
|
@@ -165,10 +183,32 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
| 165 |
progress_bar.update(1)
|
| 166 |
|
| 167 |
# Add custom legend
|
| 168 |
-
legend_items = [
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
for i, item in enumerate(legend_items):
|
| 174 |
fig.add_trace(go.Scatter(
|
|
@@ -209,14 +249,17 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
| 209 |
font=dict(size=20)
|
| 210 |
),
|
| 211 |
legend=dict(
|
| 212 |
-
orientation="
|
| 213 |
yanchor="top",
|
| 214 |
-
y
|
| 215 |
-
xanchor="
|
| 216 |
-
x=
|
|
|
|
|
|
|
|
|
|
| 217 |
),
|
| 218 |
-
width=
|
| 219 |
-
height=400, #
|
| 220 |
bargap=0,
|
| 221 |
bargroupgap=0,
|
| 222 |
)
|
|
@@ -285,7 +328,7 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b"):
|
|
| 285 |
def load_graph(_):
|
| 286 |
# Create the figure when the app loads
|
| 287 |
return create_pipeline_figure(schedule_data, show_progress=True)
|
| 288 |
-
|
| 289 |
@app.callback(
|
| 290 |
Output("download-image", "data"),
|
| 291 |
Input("btn-download", "n_clicks"),
|
|
@@ -326,23 +369,3 @@ def visualize_pipeline_parallelism_dash(
|
|
| 326 |
app = create_dash_app(schedule)
|
| 327 |
print(f"Starting Dash app on http://localhost:{port}/")
|
| 328 |
app.run_server(debug=debug, port=port)
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
def save_pipeline_visualization_plotly(
|
| 332 |
-
schedule: Schedule,
|
| 333 |
-
output_file: str = "pipeline_visualization_plotly.png",
|
| 334 |
-
):
|
| 335 |
-
"""
|
| 336 |
-
Save a static image of the pipeline schedule visualization.
|
| 337 |
-
|
| 338 |
-
Args:
|
| 339 |
-
schedule: Schedule object to visualize
|
| 340 |
-
output_file: Path to save the image to
|
| 341 |
-
"""
|
| 342 |
-
schedule_data = convert_schedule_to_visualization_format(schedule)
|
| 343 |
-
fig = create_pipeline_figure(schedule_data, show_progress=True)
|
| 344 |
-
|
| 345 |
-
print(f"Saving visualization to {output_file}...")
|
| 346 |
-
fig.write_image(output_file, width=1600, height=400, scale=2)
|
| 347 |
-
print(f"Visualization saved to {output_file}")
|
| 348 |
-
|
|
|
|
| 55 |
empty_color = "whitesmoke"
|
| 56 |
# Colors for task types
|
| 57 |
def get_color(op_type: str, stage_id: int):
|
| 58 |
+
# Color palettes for different virtual stages
|
| 59 |
+
forward_colors = [
|
| 60 |
+
"royalblue", # Stage 0
|
| 61 |
+
"lightskyblue", # Stage 1
|
| 62 |
+
"cornflowerblue", # Stage 2
|
| 63 |
+
"steelblue", # Stage 3
|
| 64 |
+
"dodgerblue", # Stage 4
|
| 65 |
+
"deepskyblue", # Stage 5
|
| 66 |
+
"mediumblue", # Stage 6
|
| 67 |
+
"mediumslateblue",# Stage 7
|
| 68 |
+
"slateblue", # Stage 8
|
| 69 |
+
"darkslateblue" # Stage 9
|
| 70 |
+
]
|
| 71 |
|
| 72 |
+
backward_colors = [
|
| 73 |
+
"lightgreen", # Stage 0
|
| 74 |
+
"mediumseagreen", # Stage 1
|
| 75 |
+
"seagreen", # Stage 2
|
| 76 |
+
"lightseagreen", # Stage 3
|
| 77 |
+
"mediumaquamarine", # Stage 4
|
| 78 |
+
"mediumspringgreen", # Stage 5
|
| 79 |
+
"springgreen", # Stage 6
|
| 80 |
+
"palegreen", # Stage 7
|
| 81 |
+
"limegreen", # Stage 8
|
| 82 |
+
"forestgreen" # Stage 9
|
| 83 |
+
]
|
| 84 |
+
|
| 85 |
virtual_stage = stage_id // num_devices
|
| 86 |
|
| 87 |
+
# If virtual_stage is beyond our color list, cycle through the colors
|
| 88 |
+
color_index = virtual_stage % len(forward_colors)
|
| 89 |
+
|
| 90 |
if op_type == "forward":
|
| 91 |
+
return forward_colors[color_index]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
elif op_type == "backward":
|
| 93 |
+
return backward_colors[color_index]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
else:
|
| 95 |
raise ValueError(f"Invalid operation type: {op_type}")
|
| 96 |
|
|
|
|
| 183 |
progress_bar.update(1)
|
| 184 |
|
| 185 |
# Add custom legend
|
| 186 |
+
legend_items = []
|
| 187 |
+
|
| 188 |
+
# Find the maximum virtual stage in the data
|
| 189 |
+
max_virtual_stage = 0
|
| 190 |
+
for device in schedule_data:
|
| 191 |
+
for task in schedule_data[device]:
|
| 192 |
+
virtual_stage = task["stage"] // num_devices
|
| 193 |
+
max_virtual_stage = max(max_virtual_stage, virtual_stage)
|
| 194 |
+
|
| 195 |
+
# Add forward and backward items for each virtual stage
|
| 196 |
+
for vs in range(max_virtual_stage + 1):
|
| 197 |
+
legend_items.append(dict(
|
| 198 |
+
name=f"Forward (VS {vs})",
|
| 199 |
+
color=get_color("forward", vs * num_devices)
|
| 200 |
+
))
|
| 201 |
+
legend_items.append(dict(
|
| 202 |
+
name=f"Backward (VS {vs})",
|
| 203 |
+
color=get_color("backward", vs * num_devices)
|
| 204 |
+
))
|
| 205 |
+
|
| 206 |
+
# If no tasks found, add default legend items
|
| 207 |
+
if not legend_items:
|
| 208 |
+
legend_items = [
|
| 209 |
+
dict(name="Forward (VS 0)", color=get_color("forward", 0)),
|
| 210 |
+
dict(name="Backward (VS 0)", color=get_color("backward", 0)),
|
| 211 |
+
]
|
| 212 |
|
| 213 |
for i, item in enumerate(legend_items):
|
| 214 |
fig.add_trace(go.Scatter(
|
|
|
|
| 249 |
font=dict(size=20)
|
| 250 |
),
|
| 251 |
legend=dict(
|
| 252 |
+
orientation="v", # Changed from horizontal to vertical
|
| 253 |
yanchor="top",
|
| 254 |
+
y=1.02, # Position at the top
|
| 255 |
+
xanchor="right",
|
| 256 |
+
x=1.15, # Position to the right of the plot
|
| 257 |
+
title=dict(text="<b>Operation Types:</b>"),
|
| 258 |
+
itemsizing="constant",
|
| 259 |
+
tracegroupgap=0
|
| 260 |
),
|
| 261 |
+
width=1800, # Increase width to accommodate the legend
|
| 262 |
+
height=400, # Maintain current height
|
| 263 |
bargap=0,
|
| 264 |
bargroupgap=0,
|
| 265 |
)
|
|
|
|
| 328 |
def load_graph(_):
|
| 329 |
# Create the figure when the app loads
|
| 330 |
return create_pipeline_figure(schedule_data, show_progress=True)
|
| 331 |
+
|
| 332 |
@app.callback(
|
| 333 |
Output("download-image", "data"),
|
| 334 |
Input("btn-download", "n_clicks"),
|
|
|
|
| 369 |
app = create_dash_app(schedule)
|
| 370 |
print(f"Starting Dash app on http://localhost:{port}/")
|
| 371 |
app.run_server(debug=debug, port=port)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|