Spaces:
Running
Running
Update overlapped time.
Browse files- src/execution_model.py +14 -9
src/execution_model.py
CHANGED
|
@@ -90,7 +90,6 @@ class ScheduleConfig:
|
|
| 90 |
self.p2p_latency = p2p_latency
|
| 91 |
self.placement_strategy = placement_strategy
|
| 92 |
self.split_backward = split_backward
|
| 93 |
-
self.overlapped_op_times = {}
|
| 94 |
|
| 95 |
# Initialize default operation times
|
| 96 |
if self.split_backward:
|
|
@@ -152,14 +151,13 @@ class ScheduleConfig:
|
|
| 152 |
def get_op_time(self, op_type: str, stage_id: int):
|
| 153 |
# For overlapped operations, extract the original operation types
|
| 154 |
if op_type.startswith("overlapped_"):
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
if (
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
return max(self.get_op_time(op_type1, stage_id), self.get_op_time(op_type2, stage_id))
|
| 163 |
|
| 164 |
if op_type not in self.op_times:
|
| 165 |
raise ValueError(f"Invalid operation type: {op_type}")
|
|
@@ -332,3 +330,10 @@ class Schedule:
|
|
| 332 |
ideal_time = ideal_time * self.config.num_batches / self.config.num_devices
|
| 333 |
|
| 334 |
return (actual_time - ideal_time) / ideal_time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
self.p2p_latency = p2p_latency
|
| 91 |
self.placement_strategy = placement_strategy
|
| 92 |
self.split_backward = split_backward
|
|
|
|
| 93 |
|
| 94 |
# Initialize default operation times
|
| 95 |
if self.split_backward:
|
|
|
|
| 151 |
def get_op_time(self, op_type: str, stage_id: int):
|
| 152 |
# For overlapped operations, extract the original operation types
|
| 153 |
if op_type.startswith("overlapped_"):
|
| 154 |
+
if op_type in self.op_times and self.op_times[op_type][stage_id]:
|
| 155 |
+
return self.op_times[op_type][stage_id]
|
| 156 |
+
else:
|
| 157 |
+
op_parts = op_type.split("_")[1:]
|
| 158 |
+
if len(op_parts) >= 2:
|
| 159 |
+
op_type1, op_type2 = op_parts[0], op_parts[1]
|
| 160 |
+
return self.get_op_time(op_type1, stage_id) + self.get_op_time(op_type2, stage_id)
|
|
|
|
| 161 |
|
| 162 |
if op_type not in self.op_times:
|
| 163 |
raise ValueError(f"Invalid operation type: {op_type}")
|
|
|
|
| 330 |
ideal_time = ideal_time * self.config.num_batches / self.config.num_devices
|
| 331 |
|
| 332 |
return (actual_time - ideal_time) / ideal_time
|
| 333 |
+
|
| 334 |
+
def get_device_running_time(self):
|
| 335 |
+
device_time = [0] * self.config.num_devices
|
| 336 |
+
for dev_id in range(self.config.num_devices):
|
| 337 |
+
for op in self.device_queues[dev_id].ops:
|
| 338 |
+
device_time[dev_id] += op.end_time - op.start_time
|
| 339 |
+
return device_time
|