Spaces:
Running
Running
Update visualizer.
Browse files- src/visualizer.py +68 -66
src/visualizer.py
CHANGED
|
@@ -70,56 +70,47 @@ def convert_schedule_to_visualization_format(schedule: Schedule):
|
|
| 70 |
# Cache the color calculation as it's repeatedly called with the same parameters
|
| 71 |
@lru_cache(maxsize=128)
|
| 72 |
def get_color(op_type: str, stage_id: int, num_devices: int):
|
| 73 |
-
# A more harmonious blue palette with
|
| 74 |
forward_colors = [
|
| 75 |
-
"#
|
| 76 |
-
"#
|
| 77 |
-
"#
|
| 78 |
-
"#
|
| 79 |
-
"#
|
| 80 |
-
"#809fff", # Medium blue
|
| 81 |
-
"#99b3e6", # Pale blue
|
| 82 |
-
"#b3c6ff", # Light blue
|
| 83 |
]
|
| 84 |
|
| 85 |
-
# Orange palette for backward operations
|
| 86 |
backward_colors = [
|
| 87 |
-
"#
|
| 88 |
-
"#
|
| 89 |
-
"#
|
| 90 |
-
"#
|
| 91 |
-
"#ffd699", # Light amber
|
| 92 |
-
"#ffd6ad", # Pale orange
|
| 93 |
-
"#ffe0c2", # Very pale orange
|
| 94 |
-
"#fff0e0", # Lightest orange
|
| 95 |
]
|
| 96 |
|
| 97 |
-
# Improved teal/turquoise palette with
|
| 98 |
backward_d_colors = [
|
| 99 |
-
"#
|
| 100 |
-
"#
|
| 101 |
-
"#
|
| 102 |
-
"#
|
| 103 |
-
"#
|
|
|
|
|
|
|
|
|
|
| 104 |
"#008080", # Dark teal
|
| 105 |
-
"#00e6cc", # Turquoise
|
| 106 |
-
"#4ddbbd", # Aqua
|
| 107 |
-
"#80d4c8", # Pale teal
|
| 108 |
-
"#b3e6e0", # Ice
|
| 109 |
]
|
| 110 |
|
| 111 |
-
# Improved green palette with
|
| 112 |
backward_w_colors = [
|
| 113 |
-
"#
|
| 114 |
-
"#
|
| 115 |
-
"#
|
| 116 |
-
"#80ffbf", #
|
| 117 |
-
"#
|
| 118 |
-
"#
|
| 119 |
-
"#
|
| 120 |
-
"#
|
| 121 |
-
"#
|
| 122 |
-
"#c6e6c6", # Pastel green
|
| 123 |
]
|
| 124 |
|
| 125 |
virtual_stage = stage_id // num_devices
|
|
@@ -130,11 +121,11 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
|
|
| 130 |
if op_type == "forward":
|
| 131 |
return forward_colors[color_index]
|
| 132 |
elif op_type == "backward":
|
| 133 |
-
return backward_colors[color_index]
|
| 134 |
elif op_type == "backward_D":
|
| 135 |
-
return backward_d_colors[color_index]
|
| 136 |
elif op_type == "backward_W":
|
| 137 |
-
return backward_w_colors[color_index]
|
| 138 |
else:
|
| 139 |
raise ValueError(f"Invalid operation type: {op_type}")
|
| 140 |
|
|
@@ -163,6 +154,15 @@ def create_pipeline_figure(
|
|
| 163 |
end_time = task["start_time"] + task["duration"]
|
| 164 |
if end_time > max_time:
|
| 165 |
max_time = end_time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
# Create a figure
|
| 168 |
fig = go.Figure()
|
|
@@ -251,22 +251,23 @@ def create_pipeline_figure(
|
|
| 251 |
)
|
| 252 |
)
|
| 253 |
|
| 254 |
-
# Add batch number text for this sub-operation
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
|
|
|
|
|
|
| 268 |
)
|
| 269 |
-
)
|
| 270 |
else:
|
| 271 |
# Regular (non-overlapped) operation
|
| 272 |
# Determine task color and text color
|
|
@@ -305,16 +306,17 @@ def create_pipeline_figure(
|
|
| 305 |
)
|
| 306 |
)
|
| 307 |
|
| 308 |
-
# Add batch number text
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
|
|
|
|
|
|
| 316 |
)
|
| 317 |
-
)
|
| 318 |
|
| 319 |
# Prepare hover data
|
| 320 |
hover_text = (
|
|
|
|
| 70 |
# Cache the color calculation as it's repeatedly called with the same parameters
|
| 71 |
@lru_cache(maxsize=128)
|
| 72 |
def get_color(op_type: str, stage_id: int, num_devices: int):
|
| 73 |
+
# A more harmonious blue palette with low saturation and high brightness
|
| 74 |
forward_colors = [
|
| 75 |
+
"#0a5aff", # Intense blue
|
| 76 |
+
"#4c88ff", # Blue (deeper)
|
| 77 |
+
"#7aa7ff", # Medium blue
|
| 78 |
+
"#a8c5ff", # Soft blue
|
| 79 |
+
"#d6e4ff", # Very light blue
|
|
|
|
|
|
|
|
|
|
| 80 |
]
|
| 81 |
|
| 82 |
+
# Orange palette for backward operations with low saturation and high brightness
|
| 83 |
backward_colors = [
|
| 84 |
+
"#f47b00", # Intense orange
|
| 85 |
+
"#ffa952", # Orange
|
| 86 |
+
"#ffc78e", # Light orange
|
| 87 |
+
"#ffe6cc", # Very light orange
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
]
|
| 89 |
|
| 90 |
+
# Improved teal/turquoise palette with low saturation and high brightness
|
| 91 |
backward_d_colors = [
|
| 92 |
+
"#ccffff", # Very light cyan
|
| 93 |
+
"#b3ffff", # Pale cyan
|
| 94 |
+
"#99ffff", # Light cyan
|
| 95 |
+
"#80ffff", # Cyan
|
| 96 |
+
"#66e6e6", # Soft teal
|
| 97 |
+
"#4dcccc", # Light teal
|
| 98 |
+
"#33b3b3", # Teal
|
| 99 |
+
"#009999", # Medium teal
|
| 100 |
"#008080", # Dark teal
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
]
|
| 102 |
|
| 103 |
+
# Improved green palette with low saturation and high brightness
|
| 104 |
backward_w_colors = [
|
| 105 |
+
"#ccffe6", # Very light mint
|
| 106 |
+
"#b3ffd9", # Pale mint
|
| 107 |
+
"#99ffcc", # Light mint
|
| 108 |
+
"#80ffbf", # Mint green
|
| 109 |
+
"#66e6a6", # Soft green
|
| 110 |
+
"#4dcc8c", # Light green
|
| 111 |
+
"#33b373", # Medium green
|
| 112 |
+
"#009959", # Forest green
|
| 113 |
+
"#008040", # Dark green
|
|
|
|
| 114 |
]
|
| 115 |
|
| 116 |
virtual_stage = stage_id // num_devices
|
|
|
|
| 121 |
if op_type == "forward":
|
| 122 |
return forward_colors[color_index]
|
| 123 |
elif op_type == "backward":
|
| 124 |
+
return backward_colors[color_index % len(backward_colors)]
|
| 125 |
elif op_type == "backward_D":
|
| 126 |
+
return backward_d_colors[color_index % len(backward_d_colors)]
|
| 127 |
elif op_type == "backward_W":
|
| 128 |
+
return backward_w_colors[color_index % len(backward_w_colors)]
|
| 129 |
else:
|
| 130 |
raise ValueError(f"Invalid operation type: {op_type}")
|
| 131 |
|
|
|
|
| 154 |
end_time = task["start_time"] + task["duration"]
|
| 155 |
if end_time > max_time:
|
| 156 |
max_time = end_time
|
| 157 |
+
|
| 158 |
+
# Determine maximum batch number to decide whether to show text labels
|
| 159 |
+
max_batch = 0
|
| 160 |
+
for device in schedule_data:
|
| 161 |
+
for task in schedule_data[device]:
|
| 162 |
+
max_batch = max(max_batch, task["batch"])
|
| 163 |
+
|
| 164 |
+
# Flag to determine whether to show text labels
|
| 165 |
+
show_text_labels = max_batch <= 16
|
| 166 |
|
| 167 |
# Create a figure
|
| 168 |
fig = go.Figure()
|
|
|
|
| 251 |
)
|
| 252 |
)
|
| 253 |
|
| 254 |
+
# Add batch number text for this sub-operation only if show_text_labels is True
|
| 255 |
+
if show_text_labels:
|
| 256 |
+
# Determine text color based on background color
|
| 257 |
+
if sub_op["type"] in ["backward", "backward_D", "backward_W"]:
|
| 258 |
+
text_color = "black"
|
| 259 |
+
else:
|
| 260 |
+
text_color = "white"
|
| 261 |
+
|
| 262 |
+
annotations.append(
|
| 263 |
+
dict(
|
| 264 |
+
x=start_time + duration / 2,
|
| 265 |
+
y=sub_y_center,
|
| 266 |
+
text=f"{sub_op['batch']}",
|
| 267 |
+
showarrow=False,
|
| 268 |
+
font=dict(color=text_color, size=12, family="Arial, bold"),
|
| 269 |
+
)
|
| 270 |
)
|
|
|
|
| 271 |
else:
|
| 272 |
# Regular (non-overlapped) operation
|
| 273 |
# Determine task color and text color
|
|
|
|
| 306 |
)
|
| 307 |
)
|
| 308 |
|
| 309 |
+
# Add batch number text only if show_text_labels is True
|
| 310 |
+
if show_text_labels:
|
| 311 |
+
annotations.append(
|
| 312 |
+
dict(
|
| 313 |
+
x=start_time + duration / 2,
|
| 314 |
+
y=y_pos,
|
| 315 |
+
text=f"{task['batch']}",
|
| 316 |
+
showarrow=False,
|
| 317 |
+
font=dict(color=text_color, size=12, family="Arial, bold"),
|
| 318 |
+
)
|
| 319 |
)
|
|
|
|
| 320 |
|
| 321 |
# Prepare hover data
|
| 322 |
hover_text = (
|