Spaces:
Sleeping
Sleeping
Update visualizer for 1F1B overlap.
Browse files
- assets/1f1b_overlap.png +2 -2
- src/visualizer.py +122 -86
assets/1f1b_overlap.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
src/visualizer.py
CHANGED
|
@@ -213,71 +213,14 @@ def create_pipeline_figure(
|
|
| 213 |
sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
|
| 214 |
|
| 215 |
for task in sorted_tasks:
|
| 216 |
-
# Determine task color and text color
|
| 217 |
-
if task["type"] == "forward":
|
| 218 |
-
color = get_color(task["type"], task["stage"], num_devices)
|
| 219 |
-
text_color = "white"
|
| 220 |
-
name = "Forward"
|
| 221 |
-
elif task["type"] == "backward":
|
| 222 |
-
color = get_color(task["type"], task["stage"], num_devices)
|
| 223 |
-
text_color = "black"
|
| 224 |
-
name = "Backward"
|
| 225 |
-
elif task["type"] == "backward_D":
|
| 226 |
-
color = get_color(task["type"], task["stage"], num_devices)
|
| 227 |
-
text_color = "black"
|
| 228 |
-
name = "Backward (Grad)"
|
| 229 |
-
elif task["type"] == "backward_W":
|
| 230 |
-
color = get_color(task["type"], task["stage"], num_devices)
|
| 231 |
-
text_color = "black"
|
| 232 |
-
name = "Backward (Weight)"
|
| 233 |
-
elif task["type"].startswith("overlapped_"):
|
| 234 |
-
color = get_color(task["type"], task["stage"], num_devices)
|
| 235 |
-
text_color = "white"
|
| 236 |
-
name = "Overlapped"
|
| 237 |
-
# Create a more descriptive name for the hover text
|
| 238 |
-
if "is_overlapped" in task and task["is_overlapped"]:
|
| 239 |
-
op_types = [op["type"] for op in task["operations"]]
|
| 240 |
-
name = f"Overlapped ({', '.join(op_types)})"
|
| 241 |
-
else:
|
| 242 |
-
color = empty_color
|
| 243 |
-
text_color = "black"
|
| 244 |
-
name = "Unknown"
|
| 245 |
-
|
| 246 |
-
# Add rectangle for the task
|
| 247 |
-
start_time = task["start_time"]
|
| 248 |
-
duration = task["duration"]
|
| 249 |
-
|
| 250 |
# Calculate y positions with no gaps
|
| 251 |
y_pos = device_idx_reversed * y_spacing
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
y0=y_pos - 0.5,
|
| 259 |
-
x1=start_time + duration,
|
| 260 |
-
y1=y_pos + 0.5,
|
| 261 |
-
line=dict(color="black", width=0.5),
|
| 262 |
-
fillcolor=color,
|
| 263 |
-
layer="above",
|
| 264 |
-
)
|
| 265 |
-
)
|
| 266 |
-
|
| 267 |
-
# Add batch number text (batch-add later)
|
| 268 |
-
annotations.append(
|
| 269 |
-
dict(
|
| 270 |
-
x=start_time + duration / 2,
|
| 271 |
-
y=y_pos,
|
| 272 |
-
text=f"{task['batch']}" + ("*" if task.get("is_overlapped", False) else ""),
|
| 273 |
-
showarrow=False,
|
| 274 |
-
font=dict(color=text_color, size=12, family="Arial, bold"),
|
| 275 |
-
)
|
| 276 |
-
)
|
| 277 |
-
|
| 278 |
-
# Prepare hover data (add traces in batches later)
|
| 279 |
-
if task.get("is_overlapped", False):
|
| 280 |
-
# Enhanced hover text for overlapped operations
|
| 281 |
op_details = "<br>".join([
|
| 282 |
f"- {op['type']} (Batch {op['batch']}, Stage {op['stage']})"
|
| 283 |
for op in task["operations"]
|
|
@@ -288,7 +231,113 @@ def create_pipeline_figure(
|
|
| 288 |
f"End: {task['start_time'] + task['duration']:.2f}<br>"
|
| 289 |
f"Duration: {task['duration']:.2f}"
|
| 290 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
hover_text = (
|
| 293 |
f"Batch: {task['batch']}<br>"
|
| 294 |
f"Stage: {task['stage']}<br>"
|
|
@@ -298,17 +347,17 @@ def create_pipeline_figure(
|
|
| 298 |
f"Duration: {task['duration']:.2f}"
|
| 299 |
)
|
| 300 |
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
|
|
|
| 310 |
)
|
| 311 |
-
)
|
| 312 |
|
| 313 |
# Update progress
|
| 314 |
if show_progress:
|
|
@@ -374,15 +423,6 @@ def create_pipeline_figure(
|
|
| 374 |
color=get_color("backward_W", vs * num_devices, num_devices),
|
| 375 |
)
|
| 376 |
)
|
| 377 |
-
|
| 378 |
-
# Add entry for overlapped operations if they exist
|
| 379 |
-
if has_overlapped:
|
| 380 |
-
legend_items.append(
|
| 381 |
-
dict(
|
| 382 |
-
name=f"Overlapped (VS {vs})",
|
| 383 |
-
color=get_color("overlapped_", vs * num_devices, num_devices),
|
| 384 |
-
)
|
| 385 |
-
)
|
| 386 |
|
| 387 |
# If no tasks found, add default legend items
|
| 388 |
if not legend_items:
|
|
@@ -397,10 +437,6 @@ def create_pipeline_figure(
|
|
| 397 |
name="Backward Weight (VS 0)",
|
| 398 |
color=get_color("backward_W", 0, num_devices),
|
| 399 |
),
|
| 400 |
-
dict(
|
| 401 |
-
name="Overlapped (VS 0)",
|
| 402 |
-
color=get_color("overlapped_", 0, num_devices),
|
| 403 |
-
),
|
| 404 |
]
|
| 405 |
|
| 406 |
for i, item in enumerate(legend_items):
|
|
|
|
| 213 |
sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
|
| 214 |
|
| 215 |
for task in sorted_tasks:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
# Calculate y positions with no gaps
|
| 217 |
y_pos = device_idx_reversed * y_spacing
|
| 218 |
+
start_time = task["start_time"]
|
| 219 |
+
duration = task["duration"]
|
| 220 |
+
|
| 221 |
+
# Special handling for overlapped operations
|
| 222 |
+
if task.get("is_overlapped", False) and "operations" in task:
|
| 223 |
+
# Prepare hover text for the entire overlapped operation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
op_details = "<br>".join([
|
| 225 |
f"- {op['type']} (Batch {op['batch']}, Stage {op['stage']})"
|
| 226 |
for op in task["operations"]
|
|
|
|
| 231 |
f"End: {task['start_time'] + task['duration']:.2f}<br>"
|
| 232 |
f"Duration: {task['duration']:.2f}"
|
| 233 |
)
|
| 234 |
+
|
| 235 |
+
# Add invisible marker for hover info
|
| 236 |
+
hover_traces.append(
|
| 237 |
+
dict(
|
| 238 |
+
x=[start_time + duration / 2],
|
| 239 |
+
y=[y_pos],
|
| 240 |
+
mode="markers",
|
| 241 |
+
marker=dict(opacity=0), # Invisible marker
|
| 242 |
+
hoverinfo="text",
|
| 243 |
+
text=hover_text,
|
| 244 |
+
showlegend=False,
|
| 245 |
+
)
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Calculate height of each sub-operation
|
| 249 |
+
sub_height = 1.0 / len(task["operations"])
|
| 250 |
+
|
| 251 |
+
# Add rectangles and annotations for each sub-operation
|
| 252 |
+
for i, sub_op in enumerate(task["operations"]):
|
| 253 |
+
# Determine color for this sub-operation
|
| 254 |
+
color = get_color(sub_op["type"], sub_op["stage"], num_devices)
|
| 255 |
+
|
| 256 |
+
# Calculate y position for this sub-operation
|
| 257 |
+
sub_y_pos_bottom = y_pos - 0.5 + (i * sub_height)
|
| 258 |
+
sub_y_pos_top = sub_y_pos_bottom + sub_height
|
| 259 |
+
sub_y_center = (sub_y_pos_bottom + sub_y_pos_top) / 2
|
| 260 |
+
|
| 261 |
+
# Add rectangle for this sub-operation
|
| 262 |
+
shapes.append(
|
| 263 |
+
dict(
|
| 264 |
+
type="rect",
|
| 265 |
+
x0=start_time,
|
| 266 |
+
y0=sub_y_pos_bottom,
|
| 267 |
+
x1=start_time + duration,
|
| 268 |
+
y1=sub_y_pos_top,
|
| 269 |
+
line=dict(color="black", width=0.5),
|
| 270 |
+
fillcolor=color,
|
| 271 |
+
layer="above",
|
| 272 |
+
)
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
# Add batch number text for this sub-operation
|
| 276 |
+
# Determine text color based on background color
|
| 277 |
+
if sub_op["type"] in ["backward", "backward_D", "backward_W"]:
|
| 278 |
+
text_color = "black"
|
| 279 |
+
else:
|
| 280 |
+
text_color = "white"
|
| 281 |
+
|
| 282 |
+
annotations.append(
|
| 283 |
+
dict(
|
| 284 |
+
x=start_time + duration / 2,
|
| 285 |
+
y=sub_y_center,
|
| 286 |
+
text=f"{sub_op['batch']}",
|
| 287 |
+
showarrow=False,
|
| 288 |
+
font=dict(color=text_color, size=12, family="Arial, bold"),
|
| 289 |
+
)
|
| 290 |
+
)
|
| 291 |
else:
|
| 292 |
+
# Regular (non-overlapped) operation
|
| 293 |
+
# Determine task color and text color
|
| 294 |
+
if task["type"] == "forward":
|
| 295 |
+
color = get_color(task["type"], task["stage"], num_devices)
|
| 296 |
+
text_color = "white"
|
| 297 |
+
name = "Forward"
|
| 298 |
+
elif task["type"] == "backward":
|
| 299 |
+
color = get_color(task["type"], task["stage"], num_devices)
|
| 300 |
+
text_color = "black"
|
| 301 |
+
name = "Backward"
|
| 302 |
+
elif task["type"] == "backward_D":
|
| 303 |
+
color = get_color(task["type"], task["stage"], num_devices)
|
| 304 |
+
text_color = "black"
|
| 305 |
+
name = "Backward (Grad)"
|
| 306 |
+
elif task["type"] == "backward_W":
|
| 307 |
+
color = get_color(task["type"], task["stage"], num_devices)
|
| 308 |
+
text_color = "black"
|
| 309 |
+
name = "Backward (Weight)"
|
| 310 |
+
else:
|
| 311 |
+
color = empty_color
|
| 312 |
+
text_color = "black"
|
| 313 |
+
name = "Unknown"
|
| 314 |
+
|
| 315 |
+
# Add rectangle for the task
|
| 316 |
+
shapes.append(
|
| 317 |
+
dict(
|
| 318 |
+
type="rect",
|
| 319 |
+
x0=start_time,
|
| 320 |
+
y0=y_pos - 0.5,
|
| 321 |
+
x1=start_time + duration,
|
| 322 |
+
y1=y_pos + 0.5,
|
| 323 |
+
line=dict(color="black", width=0.5),
|
| 324 |
+
fillcolor=color,
|
| 325 |
+
layer="above",
|
| 326 |
+
)
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
# Add batch number text
|
| 330 |
+
annotations.append(
|
| 331 |
+
dict(
|
| 332 |
+
x=start_time + duration / 2,
|
| 333 |
+
y=y_pos,
|
| 334 |
+
text=f"{task['batch']}",
|
| 335 |
+
showarrow=False,
|
| 336 |
+
font=dict(color=text_color, size=12, family="Arial, bold"),
|
| 337 |
+
)
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
# Prepare hover data
|
| 341 |
hover_text = (
|
| 342 |
f"Batch: {task['batch']}<br>"
|
| 343 |
f"Stage: {task['stage']}<br>"
|
|
|
|
| 347 |
f"Duration: {task['duration']:.2f}"
|
| 348 |
)
|
| 349 |
|
| 350 |
+
hover_traces.append(
|
| 351 |
+
dict(
|
| 352 |
+
x=[start_time + duration / 2],
|
| 353 |
+
y=[y_pos],
|
| 354 |
+
mode="markers",
|
| 355 |
+
marker=dict(opacity=0), # Invisible marker
|
| 356 |
+
hoverinfo="text",
|
| 357 |
+
text=hover_text,
|
| 358 |
+
showlegend=False,
|
| 359 |
+
)
|
| 360 |
)
|
|
|
|
| 361 |
|
| 362 |
# Update progress
|
| 363 |
if show_progress:
|
|
|
|
| 423 |
color=get_color("backward_W", vs * num_devices, num_devices),
|
| 424 |
)
|
| 425 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
|
| 427 |
# If no tasks found, add default legend items
|
| 428 |
if not legend_items:
|
|
|
|
| 437 |
name="Backward Weight (VS 0)",
|
| 438 |
color=get_color("backward_W", 0, num_devices),
|
| 439 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
]
|
| 441 |
|
| 442 |
for i, item in enumerate(legend_items):
|