Spaces:
Running
Running
Add microbatch_group_size_per_vp_stage to app config button.
Browse files
app.py
CHANGED
|
@@ -38,6 +38,7 @@ default_values = {
|
|
| 38 |
"op_time_backward": 2.0,
|
| 39 |
"strategy": ["1f1b_interleave"],
|
| 40 |
"op_time_overlapped_fwd_bwd": None,
|
|
|
|
| 41 |
}
|
| 42 |
|
| 43 |
# Define input groups using dbc components
|
|
@@ -186,6 +187,20 @@ timing_params_card = dbc.Card(
|
|
| 186 |
placement="right"
|
| 187 |
)
|
| 188 |
], className="mb-3"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
]
|
| 190 |
)
|
| 191 |
]),
|
|
@@ -249,6 +264,7 @@ app.layout = dbc.Container([
|
|
| 249 |
Output('op_time_backward_d', 'invalid'),
|
| 250 |
Output('op_time_backward_w', 'invalid'),
|
| 251 |
Output('op_time_overlapped_fwd_bwd', 'invalid'),
|
|
|
|
| 252 |
# Outputs to control the visibility/content of FormFeedback (can also just control Input's invalid state)
|
| 253 |
# We are primarily using the Input's `invalid` prop which automatically shows/hides associated FormFeedback
|
| 254 |
# Output('feedback-num_devices', 'children'), ... (Add if more specific messages needed per validation type)
|
|
@@ -263,12 +279,13 @@ app.layout = dbc.Container([
|
|
| 263 |
Input('op_time_backward_d', 'value'),
|
| 264 |
Input('op_time_backward_w', 'value'),
|
| 265 |
Input('op_time_overlapped_fwd_bwd', 'value'),
|
|
|
|
| 266 |
Input('selected-strategies-store', 'data'), # Validate strategy selection
|
| 267 |
prevent_initial_call=True # Prevent callback running on page load before user interaction
|
| 268 |
)
|
| 269 |
def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
|
| 270 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
| 271 |
-
op_time_overlapped_fwd_bwd, selected_strategies):
|
| 272 |
is_invalid = {
|
| 273 |
"num_devices": num_devices is None or num_devices < 1,
|
| 274 |
"num_stages": num_stages is None or num_stages < 1,
|
|
@@ -279,6 +296,7 @@ def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
|
|
| 279 |
"op_time_backward_d": op_time_backward_d is not None and op_time_backward_d <= 0,
|
| 280 |
"op_time_backward_w": op_time_backward_w is not None and op_time_backward_w <= 0,
|
| 281 |
"op_time_overlapped_fwd_bwd": op_time_overlapped_fwd_bwd is not None and op_time_overlapped_fwd_bwd <= 0,
|
|
|
|
| 282 |
}
|
| 283 |
|
| 284 |
# Validate strategy selection
|
|
@@ -318,6 +336,7 @@ def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
|
|
| 318 |
is_invalid["op_time_backward_d"],
|
| 319 |
is_invalid["op_time_backward_w"],
|
| 320 |
is_invalid["op_time_overlapped_fwd_bwd"],
|
|
|
|
| 321 |
strategy_feedback # Update strategy feedback based on validation
|
| 322 |
)
|
| 323 |
|
|
@@ -361,12 +380,13 @@ app.clientside_callback(
|
|
| 361 |
State('op_time_backward_d', 'value'),
|
| 362 |
State('op_time_backward_w', 'value'),
|
| 363 |
State('op_time_overlapped_fwd_bwd', 'value'),
|
|
|
|
| 364 |
State('selected-strategies-store', 'data'),
|
| 365 |
prevent_initial_call=True
|
| 366 |
)
|
| 367 |
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
| 368 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
| 369 |
-
op_time_overlapped_fwd_bwd,
|
| 370 |
selected_strategies):
|
| 371 |
|
| 372 |
strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
|
|
@@ -480,6 +500,7 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
| 480 |
placement_strategy=placement_strategy,
|
| 481 |
split_backward=split_backward,
|
| 482 |
op_times=op_times,
|
|
|
|
| 483 |
)
|
| 484 |
|
| 485 |
schedule_func = STRATEGIES.get(strategy)
|
|
|
|
| 38 |
"op_time_backward": 2.0,
|
| 39 |
"strategy": ["1f1b_interleave"],
|
| 40 |
"op_time_overlapped_fwd_bwd": None,
|
| 41 |
+
"microbatch_group_size_per_vp_stage": None,
|
| 42 |
}
|
| 43 |
|
| 44 |
# Define input groups using dbc components
|
|
|
|
| 187 |
placement="right"
|
| 188 |
)
|
| 189 |
], className="mb-3"),
|
| 190 |
+
html.Div([
|
| 191 |
+
html.Div([
|
| 192 |
+
dbc.Label("Microbatch Group Size per VP Stage", html_for='microbatch_group_size_per_vp_stage', className="form-label d-inline-block me-1"),
|
| 193 |
+
html.I(className="bi bi-info-circle", id="tooltip-target-microbatch-group", style={"cursor": "pointer"})
|
| 194 |
+
]),
|
| 195 |
+
dbc.Input(id='microbatch_group_size_per_vp_stage', type='number', placeholder=f"Defaults to num_devices", min=1, step=1, value=default_values["microbatch_group_size_per_vp_stage"]),
|
| 196 |
+
dbc.FormText("Used for interleave strategies (1f1b_interleave, 1f1b_interleave_overlap)."),
|
| 197 |
+
dbc.FormFeedback("Microbatch group size must be a positive integer if specified.", type="invalid", id="feedback-microbatch_group_size_per_vp_stage"),
|
| 198 |
+
dbc.Tooltip(
|
| 199 |
+
"Number of microbatches to process per virtual pipeline stage before switching to the next stage. Used primarily with interleave scheduling strategies. Defaults to the number of devices.",
|
| 200 |
+
target="tooltip-target-microbatch-group",
|
| 201 |
+
placement="right"
|
| 202 |
+
)
|
| 203 |
+
], className="mb-3"),
|
| 204 |
]
|
| 205 |
)
|
| 206 |
]),
|
|
|
|
| 264 |
Output('op_time_backward_d', 'invalid'),
|
| 265 |
Output('op_time_backward_w', 'invalid'),
|
| 266 |
Output('op_time_overlapped_fwd_bwd', 'invalid'),
|
| 267 |
+
Output('microbatch_group_size_per_vp_stage', 'invalid'),
|
| 268 |
# Outputs to control the visibility/content of FormFeedback (can also just control Input's invalid state)
|
| 269 |
# We are primarily using the Input's `invalid` prop which automatically shows/hides associated FormFeedback
|
| 270 |
# Output('feedback-num_devices', 'children'), ... (Add if more specific messages needed per validation type)
|
|
|
|
| 279 |
Input('op_time_backward_d', 'value'),
|
| 280 |
Input('op_time_backward_w', 'value'),
|
| 281 |
Input('op_time_overlapped_fwd_bwd', 'value'),
|
| 282 |
+
Input('microbatch_group_size_per_vp_stage', 'value'),
|
| 283 |
Input('selected-strategies-store', 'data'), # Validate strategy selection
|
| 284 |
prevent_initial_call=True # Prevent callback running on page load before user interaction
|
| 285 |
)
|
| 286 |
def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
|
| 287 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
| 288 |
+
op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage, selected_strategies):
|
| 289 |
is_invalid = {
|
| 290 |
"num_devices": num_devices is None or num_devices < 1,
|
| 291 |
"num_stages": num_stages is None or num_stages < 1,
|
|
|
|
| 296 |
"op_time_backward_d": op_time_backward_d is not None and op_time_backward_d <= 0,
|
| 297 |
"op_time_backward_w": op_time_backward_w is not None and op_time_backward_w <= 0,
|
| 298 |
"op_time_overlapped_fwd_bwd": op_time_overlapped_fwd_bwd is not None and op_time_overlapped_fwd_bwd <= 0,
|
| 299 |
+
"microbatch_group_size_per_vp_stage": microbatch_group_size_per_vp_stage is not None and (microbatch_group_size_per_vp_stage < 1 or microbatch_group_size_per_vp_stage % 1 != 0),
|
| 300 |
}
|
| 301 |
|
| 302 |
# Validate strategy selection
|
|
|
|
| 336 |
is_invalid["op_time_backward_d"],
|
| 337 |
is_invalid["op_time_backward_w"],
|
| 338 |
is_invalid["op_time_overlapped_fwd_bwd"],
|
| 339 |
+
is_invalid["microbatch_group_size_per_vp_stage"],
|
| 340 |
strategy_feedback # Update strategy feedback based on validation
|
| 341 |
)
|
| 342 |
|
|
|
|
| 380 |
State('op_time_backward_d', 'value'),
|
| 381 |
State('op_time_backward_w', 'value'),
|
| 382 |
State('op_time_overlapped_fwd_bwd', 'value'),
|
| 383 |
+
State('microbatch_group_size_per_vp_stage', 'value'),
|
| 384 |
State('selected-strategies-store', 'data'),
|
| 385 |
prevent_initial_call=True
|
| 386 |
)
|
| 387 |
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
| 388 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
| 389 |
+
op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage,
|
| 390 |
selected_strategies):
|
| 391 |
|
| 392 |
strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
|
|
|
|
| 500 |
placement_strategy=placement_strategy,
|
| 501 |
split_backward=split_backward,
|
| 502 |
op_times=op_times,
|
| 503 |
+
microbatch_group_size_per_vp_stage=int(microbatch_group_size_per_vp_stage) if microbatch_group_size_per_vp_stage is not None else None,
|
| 504 |
)
|
| 505 |
|
| 506 |
schedule_func = STRATEGIES.get(strategy)
|