File size: 4,784 Bytes
dde1929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1713823
 
 
 
 
 
 
 
 
dde1929
1713823
 
 
 
 
dde1929
 
1713823
 
 
 
 
 
 
 
 
 
 
 
dde1929
1713823
 
dde1929
1713823
 
 
 
 
 
dde1929
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import time
import random
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr

# Global plot style for all figures produced by plot_timeline().
plt.style.use('seaborn-v0_8-darkgrid')

# Fix the global RNG seeds so repeated runs of the script reproduce the
# same jittered timelines. NOTE: seeding happens once at import time, so
# successive simulations within a single session still differ from each other.
random.seed(42)
np.random.seed(42)

def simulate_batches(num_workers=4,
                     batch_time=500,      # ms
                     network_latency=200,  # ms
                     mode='synchronous',
                     num_batches=10):
    """
    Simulate mini-batch scheduling under synchronous vs asynchronous
    update strategies.

    Parameters
    ----------
    num_workers : int
        Number of parallel workers.
    batch_time : float
        Nominal per-batch processing time in ms; actual times are
        jittered uniformly within +/-20% of this value.
    network_latency : float
        Communication overhead in ms. Synchronous mode adds it as a
        fixed aggregation cost after every barrier; asynchronous mode
        adds a random jitter of up to 30% of it per worker per round.
    mode : str
        'synchronous' (barrier + aggregation after each round); any
        other value means asynchronous (no barrier).
    num_batches : int
        Number of batch rounds per worker.

    Returns
    -------
    tuple
        (timelines, metrics): timelines is a list of
        (worker_id, start_ms, end_ms, 'active'|'idle') tuples; metrics
        is a dict with 'epoch_time_ms', 'idle_percent', 'throughput'.
    """
    timelines = []                    # (worker_id, start, end, flag)
    current_time = [0] * num_workers  # per-worker simulated clock (ms)

    for batch in range(num_batches):
        for w in range(num_workers):
            # Per-batch processing time with +/-20% uniform jitter.
            proc_time = random.uniform(batch_time * 0.8, batch_time * 1.2)
            start = current_time[w]
            end = start + proc_time
            timelines.append((w, start, end, 'active'))
            current_time[w] = end

        if mode == 'synchronous':
            # Barrier: fast workers sit idle until the slowest finishes.
            max_time = max(current_time)
            for w in range(num_workers):
                if current_time[w] < max_time:
                    timelines.append((w, current_time[w], max_time, 'idle'))
                    current_time[w] = max_time
            # Fixed sync overhead (e.g. gradient aggregation / all-reduce).
            current_time = [t + network_latency for t in current_time]
        else:
            # Asynchronous mode: no barrier, only random network jitter.
            current_time = [t + random.uniform(0, network_latency * 0.3)
                            for t in current_time]

    # Guard against an empty simulation (num_workers == 0 or
    # num_batches == 0), which previously raised ValueError on max([])
    # or ZeroDivisionError in the percentage/throughput math.
    total_time = max(current_time) if current_time else 0.0
    idle_time = sum(
        end - start for (w, start, end, flag) in timelines if flag == 'idle'
    )
    total_blocks = sum(end - start for (_, start, end, _) in timelines)
    idle_percent = (idle_time / total_blocks) * 100 if total_blocks else 0.0
    # Approximate throughput in batches per second (times are in ms).
    throughput = ((num_workers * num_batches * 1000) / total_time
                  if total_time else 0.0)

    metrics = {
        "epoch_time_ms": total_time,
        "idle_percent": round(idle_percent, 2),
        "throughput": round(throughput, 2)
    }

    return timelines, metrics


def plot_timeline(timelines, metrics, num_workers):
    """Render worker timelines as a horizontal Gantt-style chart.

    Active processing blocks are drawn green and idle (barrier-wait)
    blocks red, one row per worker, with a metrics summary box overlaid.
    Returns the matplotlib Figure.
    """
    palette = {'active': '#4CAF50', 'idle': '#E74C3C'}
    fig, ax = plt.subplots(figsize=(10, 5))

    # One horizontal bar per recorded (start, end) block.
    for worker_id, begin, finish, state in timelines:
        ax.barh(y=worker_id, width=finish - begin, left=begin,
                color=palette[state], edgecolor='black')

    ax.set_xlabel("Time (ms)")
    ax.set_ylabel("Worker ID")
    ax.set_title("Batch Scheduler Simulation")
    ax.set_yticks(range(num_workers))
    ax.set_yticklabels([f"W{i}" for i in range(num_workers)])
    ax.invert_yaxis()  # worker 0 on top

    summary_lines = [
        f"Epoch Duration: {metrics['epoch_time_ms']:.2f} ms",
        f"Idle Time: {metrics['idle_percent']}%",
        f"Throughput: {metrics['throughput']} batches/sec",
    ]
    plt.figtext(0.72, 0.35, "\n".join(summary_lines), fontsize=10,
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='gray'))
    plt.tight_layout()
    return fig


def run_simulation(num_workers, batch_time, network_latency, mode, num_batches):
    """Gradio callback: run one simulation and return (figure, summary text).

    Slider values may arrive as floats, so they are cast once up front
    to the types the simulator and plotter expect.
    """
    workers = int(num_workers)
    timelines, metrics = simulate_batches(
        num_workers=workers,
        batch_time=float(batch_time),
        network_latency=float(network_latency),
        mode=mode,
        num_batches=int(num_batches)
    )
    # Fix: pass the cast int, not the raw slider value — plot_timeline
    # calls range(num_workers), which raises TypeError on a float.
    fig = plot_timeline(timelines, metrics, workers)
    summary = (
        f"Mode: {mode.capitalize()}\n"
        f"Epoch Time: {metrics['epoch_time_ms']:.2f} ms\n"
        f"Idle Time: {metrics['idle_percent']} %\n"
        f"Throughput: {metrics['throughput']} batches/sec"
    )
    return fig, summary


# Gradio UI: each input maps positionally to a run_simulation() parameter.
interface = gr.Interface(
    fn=run_simulation,
    inputs=[
        gr.Slider(1, 8, value=4, step=1, label="Number of Workers"),
        gr.Slider(100, 1000, value=500, step=50, label="Batch Processing Time (ms)"),
        gr.Slider(50, 500, value=200, step=25, label="Network Latency (ms)"),
        gr.Radio(["synchronous", "asynchronous"], value="synchronous", label="Mode"),
        gr.Slider(5, 30, value=10, step=1, label="Number of Batches per Epoch"),
    ],
    outputs=[
        gr.Plot(label="Timeline Visualization"),
        gr.Textbox(label="Simulation Summary", lines=8, max_lines=12, show_copy_button=True)
    ],
    title="🧠 Batch Scheduler Simulator",
    description="Visualize how synchronous vs asynchronous batch scheduling affects throughput, idle time, and epoch duration.",
    # Preset example rows, ordered to match the inputs list above.
    examples=[
        [4, 500, 200, "synchronous", 10],
        [8, 400, 150, "asynchronous", 15]
    ]
)

# Launch the web app only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()