Spaces:
Runtime error
Runtime error
Add visualization for 1F1B overlap.
Browse files- main.py +1 -2
- src/execution_model.py +20 -3
- src/strategies.py +3 -1
- src/visualizer.py +100 -13
main.py
CHANGED
|
@@ -105,8 +105,7 @@ def run_1f1b_overlap(cfg: DictConfig) -> None:
|
|
| 105 |
)
|
| 106 |
schedule = generate_1f1b_overlap_schedule(schedule_config)
|
| 107 |
schedule.execute()
|
| 108 |
-
schedule.
|
| 109 |
-
# visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
| 110 |
|
| 111 |
|
| 112 |
if __name__ == "__main__":
|
|
|
|
| 105 |
)
|
| 106 |
schedule = generate_1f1b_overlap_schedule(schedule_config)
|
| 107 |
schedule.execute()
|
| 108 |
+
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
if __name__ == "__main__":
|
src/execution_model.py
CHANGED
|
@@ -158,9 +158,8 @@ class ScheduleConfig:
|
|
| 158 |
# Check if we have a specific time for this combination
|
| 159 |
if (op_type1, op_type2) in self.overlapped_op_times:
|
| 160 |
return self.overlapped_op_times[(op_type1, op_type2)]
|
| 161 |
-
# Otherwise, use the
|
| 162 |
-
return (self.get_op_time(op_type1, stage_id) +
|
| 163 |
-
self.get_op_time(op_type2, stage_id))
|
| 164 |
|
| 165 |
if op_type not in self.op_times:
|
| 166 |
raise ValueError(f"Invalid operation type: {op_type}")
|
|
@@ -184,6 +183,12 @@ class Schedule:
|
|
| 184 |
self.config = config
|
| 185 |
|
| 186 |
self.init_operations()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
def init_operations(self):
|
| 189 |
op_types = ["forward", "backward"]
|
|
@@ -197,10 +202,21 @@ class Schedule:
|
|
| 197 |
)
|
| 198 |
|
| 199 |
def get_op(self, batch_id: int, stage_id: int, op_type: str):
|
|
|
|
|
|
|
| 200 |
return self.ops[(batch_id, stage_id, op_type)]
|
| 201 |
|
| 202 |
def get_dependencies(self, op: Operation, include_device_dependency=True):
|
| 203 |
deps = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
if op.op_type == "forward":
|
| 205 |
if op.stage_id > 0:
|
| 206 |
deps.append(
|
|
@@ -272,6 +288,7 @@ class Schedule:
|
|
| 272 |
print(f"\nTotal execution time: {total_time:.2f}")
|
| 273 |
|
| 274 |
def execute(self):
|
|
|
|
| 275 |
def execute_op(op: Operation):
|
| 276 |
if op.end_time is not None:
|
| 277 |
return
|
|
|
|
| 158 |
# Check if we have a specific time for this combination
|
| 159 |
if (op_type1, op_type2) in self.overlapped_op_times:
|
| 160 |
return self.overlapped_op_times[(op_type1, op_type2)]
|
| 161 |
+
# Otherwise, use the max of individual times plus a small overhead
|
| 162 |
+
return max(self.get_op_time(op_type1, stage_id), self.get_op_time(op_type2, stage_id)) + 0.2
|
|
|
|
| 163 |
|
| 164 |
if op_type not in self.op_times:
|
| 165 |
raise ValueError(f"Invalid operation type: {op_type}")
|
|
|
|
| 183 |
self.config = config
|
| 184 |
|
| 185 |
self.init_operations()
|
| 186 |
+
self.op_to_overlapped = {}
|
| 187 |
+
|
| 188 |
+
def register_overlapped_operation(self, overlapped_op: OverlappedOperation):
|
| 189 |
+
for op in overlapped_op.operations:
|
| 190 |
+
self.op_to_overlapped[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
|
| 191 |
+
self.ops[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
|
| 192 |
|
| 193 |
def init_operations(self):
|
| 194 |
op_types = ["forward", "backward"]
|
|
|
|
| 202 |
)
|
| 203 |
|
| 204 |
def get_op(self, batch_id: int, stage_id: int, op_type: str):
|
| 205 |
+
if (batch_id, stage_id, op_type) in self.op_to_overlapped:
|
| 206 |
+
return self.op_to_overlapped[(batch_id, stage_id, op_type)]
|
| 207 |
return self.ops[(batch_id, stage_id, op_type)]
|
| 208 |
|
| 209 |
def get_dependencies(self, op: Operation, include_device_dependency=True):
|
| 210 |
deps = []
|
| 211 |
+
if isinstance(op, OverlappedOperation):
|
| 212 |
+
for sub_op in op.operations:
|
| 213 |
+
deps.extend(self.get_dependencies(sub_op, include_device_dependency=False))
|
| 214 |
+
|
| 215 |
+
if include_device_dependency:
|
| 216 |
+
device_index = self.device_queues[op.device_id].ops.index(op)
|
| 217 |
+
if device_index > 0:
|
| 218 |
+
deps.append((self.device_queues[op.device_id].ops[device_index - 1], 0.0))
|
| 219 |
+
return deps
|
| 220 |
if op.op_type == "forward":
|
| 221 |
if op.stage_id > 0:
|
| 222 |
deps.append(
|
|
|
|
| 288 |
print(f"\nTotal execution time: {total_time:.2f}")
|
| 289 |
|
| 290 |
def execute(self):
|
| 291 |
+
# TODO: change the execution order to topological order via DAG
|
| 292 |
def execute_op(op: Operation):
|
| 293 |
if op.end_time is not None:
|
| 294 |
return
|
src/strategies.py
CHANGED
|
@@ -114,7 +114,9 @@ def generate_1f1b_overlap_schedule(config: ScheduleConfig):
|
|
| 114 |
for _ in range(steady_batches):
|
| 115 |
fwd_op = schedule.get_op(fwd_batch_id, i, "forward")
|
| 116 |
bwd_op = schedule.get_op(bwd_batch_id, i, "backward")
|
| 117 |
-
|
|
|
|
|
|
|
| 118 |
|
| 119 |
fwd_batch_id += 1
|
| 120 |
bwd_batch_id += 1
|
|
|
|
| 114 |
for _ in range(steady_batches):
|
| 115 |
fwd_op = schedule.get_op(fwd_batch_id, i, "forward")
|
| 116 |
bwd_op = schedule.get_op(bwd_batch_id, i, "backward")
|
| 117 |
+
overlapped_op = OverlappedOperation([fwd_op, bwd_op])
|
| 118 |
+
schedule.register_overlapped_operation(overlapped_op)
|
| 119 |
+
schedule.device_queues[i].add_operation(overlapped_op)
|
| 120 |
|
| 121 |
fwd_batch_id += 1
|
| 122 |
bwd_batch_id += 1
|
src/visualizer.py
CHANGED
|
@@ -8,7 +8,7 @@ from functools import lru_cache
|
|
| 8 |
import webbrowser
|
| 9 |
from threading import Timer
|
| 10 |
|
| 11 |
-
from src.execution_model import Schedule
|
| 12 |
|
| 13 |
|
| 14 |
def convert_schedule_to_visualization_format(schedule: Schedule):
|
|
@@ -32,15 +32,37 @@ def convert_schedule_to_visualization_format(schedule: Schedule):
|
|
| 32 |
visualization_data[device_id] = []
|
| 33 |
|
| 34 |
for op in device_queue.ops:
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
return visualization_data
|
| 46 |
|
|
@@ -103,13 +125,30 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
|
|
| 103 |
"#99cc99", # Pale green
|
| 104 |
"#c6e6c6", # Pastel green
|
| 105 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
virtual_stage = stage_id // num_devices
|
| 108 |
|
| 109 |
# If virtual_stage is beyond our color list, cycle through the colors
|
| 110 |
color_index = virtual_stage % len(forward_colors)
|
| 111 |
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
| 113 |
return forward_colors[color_index]
|
| 114 |
elif op_type == "backward":
|
| 115 |
return backward_colors[color_index]
|
|
@@ -191,6 +230,14 @@ def create_pipeline_figure(
|
|
| 191 |
color = get_color(task["type"], task["stage"], num_devices)
|
| 192 |
text_color = "black"
|
| 193 |
name = "Backward (Weight)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
else:
|
| 195 |
color = empty_color
|
| 196 |
text_color = "black"
|
|
@@ -222,14 +269,34 @@ def create_pipeline_figure(
|
|
| 222 |
dict(
|
| 223 |
x=start_time + duration / 2,
|
| 224 |
y=y_pos,
|
| 225 |
-
text=f"{task['batch']}",
|
| 226 |
showarrow=False,
|
| 227 |
font=dict(color=text_color, size=12, family="Arial, bold"),
|
| 228 |
)
|
| 229 |
)
|
| 230 |
|
| 231 |
# Prepare hover data (add traces in batches later)
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
hover_traces.append(
|
| 235 |
dict(
|
|
@@ -268,6 +335,13 @@ def create_pipeline_figure(
|
|
| 268 |
virtual_stage = task["stage"] // num_devices
|
| 269 |
max_virtual_stage = max(max_virtual_stage, virtual_stage)
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
# Add forward and backward items for each virtual stage
|
| 272 |
for vs in range(max_virtual_stage + 1):
|
| 273 |
legend_items.append(
|
|
@@ -300,6 +374,15 @@ def create_pipeline_figure(
|
|
| 300 |
color=get_color("backward_W", vs * num_devices, num_devices),
|
| 301 |
)
|
| 302 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
# If no tasks found, add default legend items
|
| 305 |
if not legend_items:
|
|
@@ -314,6 +397,10 @@ def create_pipeline_figure(
|
|
| 314 |
name="Backward Weight (VS 0)",
|
| 315 |
color=get_color("backward_W", 0, num_devices),
|
| 316 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
]
|
| 318 |
|
| 319 |
for i, item in enumerate(legend_items):
|
|
|
|
| 8 |
import webbrowser
|
| 9 |
from threading import Timer
|
| 10 |
|
| 11 |
+
from src.execution_model import Schedule, OverlappedOperation
|
| 12 |
|
| 13 |
|
| 14 |
def convert_schedule_to_visualization_format(schedule: Schedule):
|
|
|
|
| 32 |
visualization_data[device_id] = []
|
| 33 |
|
| 34 |
for op in device_queue.ops:
|
| 35 |
+
# Handle both regular Operations and OverlappedOperations
|
| 36 |
+
if isinstance(op, OverlappedOperation):
|
| 37 |
+
visualization_data[device_id].append(
|
| 38 |
+
{
|
| 39 |
+
"type": op.op_type,
|
| 40 |
+
"batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
|
| 41 |
+
"stage": op.stage_id,
|
| 42 |
+
"start_time": op.start_time,
|
| 43 |
+
"duration": op.end_time - op.start_time,
|
| 44 |
+
"is_overlapped": True,
|
| 45 |
+
"operations": [
|
| 46 |
+
{
|
| 47 |
+
"type": nested_op.op_type,
|
| 48 |
+
"batch": nested_op.batch_id + 1,
|
| 49 |
+
"stage": nested_op.stage_id
|
| 50 |
+
}
|
| 51 |
+
for nested_op in op.operations
|
| 52 |
+
]
|
| 53 |
+
}
|
| 54 |
+
)
|
| 55 |
+
else:
|
| 56 |
+
visualization_data[device_id].append(
|
| 57 |
+
{
|
| 58 |
+
"type": op.op_type,
|
| 59 |
+
"batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
|
| 60 |
+
"stage": op.stage_id,
|
| 61 |
+
"start_time": op.start_time,
|
| 62 |
+
"duration": op.end_time - op.start_time,
|
| 63 |
+
"is_overlapped": False
|
| 64 |
+
}
|
| 65 |
+
)
|
| 66 |
|
| 67 |
return visualization_data
|
| 68 |
|
|
|
|
| 125 |
"#99cc99", # Pale green
|
| 126 |
"#c6e6c6", # Pastel green
|
| 127 |
]
|
| 128 |
+
|
| 129 |
+
# Purple palette for overlapped operations
|
| 130 |
+
overlapped_colors = [
|
| 131 |
+
"#9966cc", # Medium purple
|
| 132 |
+
"#8a2be2", # Blue violet
|
| 133 |
+
"#9370db", # Medium purple
|
| 134 |
+
"#6a5acd", # Slate blue
|
| 135 |
+
"#7b68ee", # Medium slate blue
|
| 136 |
+
"#ba55d3", # Medium orchid
|
| 137 |
+
"#9932cc", # Dark orchid
|
| 138 |
+
"#d8bfd8", # Thistle
|
| 139 |
+
"#e6e6fa", # Lavender
|
| 140 |
+
"#dda0dd", # Plum
|
| 141 |
+
]
|
| 142 |
|
| 143 |
virtual_stage = stage_id // num_devices
|
| 144 |
|
| 145 |
# If virtual_stage is beyond our color list, cycle through the colors
|
| 146 |
color_index = virtual_stage % len(forward_colors)
|
| 147 |
|
| 148 |
+
# Handle overlapped operations
|
| 149 |
+
if op_type.startswith("overlapped_"):
|
| 150 |
+
return overlapped_colors[color_index]
|
| 151 |
+
elif op_type == "forward":
|
| 152 |
return forward_colors[color_index]
|
| 153 |
elif op_type == "backward":
|
| 154 |
return backward_colors[color_index]
|
|
|
|
| 230 |
color = get_color(task["type"], task["stage"], num_devices)
|
| 231 |
text_color = "black"
|
| 232 |
name = "Backward (Weight)"
|
| 233 |
+
elif task["type"].startswith("overlapped_"):
|
| 234 |
+
color = get_color(task["type"], task["stage"], num_devices)
|
| 235 |
+
text_color = "white"
|
| 236 |
+
name = "Overlapped"
|
| 237 |
+
# Create a more descriptive name for the hover text
|
| 238 |
+
if "is_overlapped" in task and task["is_overlapped"]:
|
| 239 |
+
op_types = [op["type"] for op in task["operations"]]
|
| 240 |
+
name = f"Overlapped ({', '.join(op_types)})"
|
| 241 |
else:
|
| 242 |
color = empty_color
|
| 243 |
text_color = "black"
|
|
|
|
| 269 |
dict(
|
| 270 |
x=start_time + duration / 2,
|
| 271 |
y=y_pos,
|
| 272 |
+
text=f"{task['batch']}" + ("*" if task.get("is_overlapped", False) else ""),
|
| 273 |
showarrow=False,
|
| 274 |
font=dict(color=text_color, size=12, family="Arial, bold"),
|
| 275 |
)
|
| 276 |
)
|
| 277 |
|
| 278 |
# Prepare hover data (add traces in batches later)
|
| 279 |
+
if task.get("is_overlapped", False):
|
| 280 |
+
# Enhanced hover text for overlapped operations
|
| 281 |
+
op_details = "<br>".join([
|
| 282 |
+
f"- {op['type']} (Batch {op['batch']}, Stage {op['stage']})"
|
| 283 |
+
for op in task["operations"]
|
| 284 |
+
])
|
| 285 |
+
hover_text = (
|
| 286 |
+
f"Overlapped Operations:<br>{op_details}<br>"
|
| 287 |
+
f"Start: {task['start_time']:.2f}<br>"
|
| 288 |
+
f"End: {task['start_time'] + task['duration']:.2f}<br>"
|
| 289 |
+
f"Duration: {task['duration']:.2f}"
|
| 290 |
+
)
|
| 291 |
+
else:
|
| 292 |
+
hover_text = (
|
| 293 |
+
f"Batch: {task['batch']}<br>"
|
| 294 |
+
f"Stage: {task['stage']}<br>"
|
| 295 |
+
f"Type: {name}<br>"
|
| 296 |
+
f"Start: {task['start_time']:.2f}<br>"
|
| 297 |
+
f"End: {task['start_time'] + task['duration']:.2f}<br>"
|
| 298 |
+
f"Duration: {task['duration']:.2f}"
|
| 299 |
+
)
|
| 300 |
|
| 301 |
hover_traces.append(
|
| 302 |
dict(
|
|
|
|
| 335 |
virtual_stage = task["stage"] // num_devices
|
| 336 |
max_virtual_stage = max(max_virtual_stage, virtual_stage)
|
| 337 |
|
| 338 |
+
# Check if overlapped operations exist
|
| 339 |
+
has_overlapped = any(
|
| 340 |
+
task.get("is_overlapped", False)
|
| 341 |
+
for device in schedule_data
|
| 342 |
+
for task in schedule_data[device]
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
# Add forward and backward items for each virtual stage
|
| 346 |
for vs in range(max_virtual_stage + 1):
|
| 347 |
legend_items.append(
|
|
|
|
| 374 |
color=get_color("backward_W", vs * num_devices, num_devices),
|
| 375 |
)
|
| 376 |
)
|
| 377 |
+
|
| 378 |
+
# Add entry for overlapped operations if they exist
|
| 379 |
+
if has_overlapped:
|
| 380 |
+
legend_items.append(
|
| 381 |
+
dict(
|
| 382 |
+
name=f"Overlapped (VS {vs})",
|
| 383 |
+
color=get_color("overlapped_", vs * num_devices, num_devices),
|
| 384 |
+
)
|
| 385 |
+
)
|
| 386 |
|
| 387 |
# If no tasks found, add default legend items
|
| 388 |
if not legend_items:
|
|
|
|
| 397 |
name="Backward Weight (VS 0)",
|
| 398 |
color=get_color("backward_W", 0, num_devices),
|
| 399 |
),
|
| 400 |
+
dict(
|
| 401 |
+
name="Overlapped (VS 0)",
|
| 402 |
+
color=get_color("overlapped_", 0, num_devices),
|
| 403 |
+
),
|
| 404 |
]
|
| 405 |
|
| 406 |
for i, item in enumerate(legend_items):
|