squirelmail committed
Commit ca1e5ec · verified · 1 Parent(s): 76a4609

Create train_model.py

Files changed (1): train_model.py +535 -0
train_model.py ADDED
@@ -0,0 +1,535 @@
# -*- coding: utf-8 -*-

# Install libraries
# pip install -U tensorflow[and-cuda] torch torchvision pandas scikit-learn pillow numpy
# pip install -U tf-nightly[and-cuda] torch torchvision pandas scikit-learn pillow numpy
# pip install -U tensorflow torch torchvision pandas scikit-learn pillow numpy
# pip install -U "tensorflow[and-cuda]==2.17.0"
# pip install torch==2.8.0 torchvision==0.23.0

# pip uninstall -y tensorflow tensorflow-cpu tensorflow-intel tensorflow-gpu
# pip cache purge
# # option A: nightly build bundling CUDA
# pip install -U "tf-nightly[and-cuda]"
# # or option B (if A is not available on your package index):
# pip install -U tf-nightly

import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
print(gpus)
if gpus:
    try:
        # for gpu in gpus:
        #     tf.config.experimental.set_memory_growth(gpu, True)  # no full prealloc
        print(f"GPU active: {gpus}")
    except Exception as e:
        print("Failed to set memory growth:", e)
else:
    print("No GPU detected.")

# Clean up the dataset: make sure every style folder contains the same images
import os

BASE_DIR = "/workspace/dataset"  # change to your dataset path
START, END = 0, 59               # style0..style59
DRY_RUN = False                  # set to True to preview without actually deleting

def main():
    base = os.path.abspath(BASE_DIR)
    ref_dir = os.path.join(base, f"style{START}")
    if not os.path.isdir(ref_dir):
        print(f"❌ Folder {ref_dir} not found.")
        return

    files_ref = sorted([f for f in os.listdir(ref_dir) if f.lower().endswith(".png")])
    print(f"🔍 Total references from style{START}: {len(files_ref)} files")

    # Find files that are present in every style
    complete = []
    missing = {}

    for fname in files_ref:
        ok = True
        for i in range(START, END + 1):
            style_path = os.path.join(base, f"style{i}", fname)
            if not os.path.isfile(style_path):
                ok = False
                missing.setdefault(fname, []).append(f"style{i}")
        if ok:
            complete.append(fname)

    print(f"✅ Complete in all styles: {len(complete)} files")
    print(f"❌ Incomplete: {len(missing)} files")

    # Delete incomplete files from every style folder
    if missing:
        for fname, styles in missing.items():
            for i in range(START, END + 1):
                path = os.path.join(base, f"style{i}", fname)
                if os.path.isfile(path):
                    if not DRY_RUN:
                        os.remove(path)
                    print(f"🗑️ Deleted {path}")
        print(f"\n🔥 Done! {len(missing)} files cleaned from all style folders.")
    else:
        print("All files are complete in every style — nothing to delete.")

if __name__ == "__main__":
    main()

import os
from glob import glob
import pandas as pd

data = []

root_dir = "/workspace/dataset"

for style_id in range(60):
    folder_path = os.path.join(root_dir, f"style{style_id}")
    image_paths = glob(os.path.join(folder_path, "*.png"))

    for path in image_paths:
        label = os.path.splitext(os.path.basename(path))[0]  # filename without extension
        data.append((path, label, f"style{style_id}"))

df = pd.DataFrame(data, columns=["filepath", "label", "style"])

print(df)

import re
import pandas as pd
from collections import Counter

# --- strict rule: exactly 5 characters, A-Z or 0-9 only ---
ALLOWED_REGEX_STRICT = r'^[A-Z0-9]{5}$'
ALLOWED_REGEX_LEN5_ALNUM = r'^[A-Za-z0-9]{5}$'  # tolerate lowercase, for detection only

# make sure the label column is clean before checking
df['label'] = df['label'].astype(str).str.strip()

# 1) MASK OFFENDERS (strict); na=False so missing labels also count as invalid
invalid_mask = ~df['label'].str.match(ALLOWED_REGEX_STRICT, na=False)
invalid_df = df[invalid_mask].copy()

# 2) CATEGORIZE THE CAUSES
df['len'] = df['label'].str.len()
too_short = df[df['len'] < 5]
too_long = df[df['len'] > 5]
has_non_alnum = df[df['label'].str.contains(r'[^A-Za-z0-9]', na=True)]
has_lower = df[df['label'].str.contains(r'[a-z]', na=True)]  # any lowercase letters left?

# 3) OFFENDING (non-alphanumeric) CHARACTERS THAT APPEAR
def extract_bad_chars(s: str):
    return re.findall(r'[^A-Za-z0-9]', s)

bad_chars_counter = Counter()
for lab in has_non_alnum['label'].dropna().tolist():
    bad_chars_counter.update(extract_bad_chars(lab))
bad_chars_list = sorted(bad_chars_counter.items(), key=lambda x: -x[1])

# 4) SUMMARY
print("=== LABEL VALIDATION ===")
print(f"Total rows : {len(df)}")
print(f"Invalid (strict): {len(invalid_df)}")
print(f"- Length < 5 : {len(too_short)}")
print(f"- Length > 5 : {len(too_long)}")
print(f"- Has non-alnum : {len(has_non_alnum)}")
print(f"- Has lowercase : {len(has_lower)}")

# a few sample problem labels
if len(invalid_df) > 0:
    sample = invalid_df['label'].head(20).tolist()
    print("\nSample invalid labels (max 20):", sample)

# non-alphanumeric characters and their frequencies
if bad_chars_list:
    print("\nNon-alnum characters found (char, count):", bad_chars_list[:20])

# 5) SAVE THE OFFENDER LIST TO CSV (so files can be fixed/renamed manually)
if len(invalid_df) > 0:
    invalid_df.to_csv("invalid_labels.csv", index=False)
    print("\n>> Saved: invalid_labels.csv")

# 6) OPTIONAL: STOP TRAINING IF OFFENDERS REMAIN
if len(invalid_df) > 0:
    raise ValueError(
        f"Found {len(invalid_df)} invalid labels. Fix them first (see invalid_labels.csv)."
    )

# Example: keep only labels of length 5, alphanumeric only
# df = df[df['label'].str.match(r'^[a-zA-Z0-9]{5}$')]

print(df)

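# Quick illustration of the strict rule on a few hypothetical labels
# (sketch only; these strings are made up for demonstration):
for s in ["AB12Z", "ab12z", "AB1-2", "ABCD"]:
    print(s, bool(re.match(ALLOWED_REGEX_STRICT, s)))
# AB12Z True | ab12z False | AB1-2 False | ABCD False
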
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(df, test_size=0.1, random_state=42, stratify=df['style'])
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42, stratify=train_df['style'])

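# Quick check on the resulting split sizes: test is 10% of the data, val is
# 10% of the remaining 90% (i.e. 9% overall), leaving ~81% for training.
print("train/val/test:", len(train_df), len(val_df), len(test_df))
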
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Resize((50, 250)),          # common CAPTCHA size
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))   # normalize to [-1, 1]
])

def load_image(path):
    img = Image.open(path).convert("L")  # convert to grayscale
    return transform(img)

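# Minimal self-check of the transform using a synthetic grayscale image, so it
# runs without the dataset (assumption: purely illustrative input):
import numpy as np
demo = Image.fromarray(np.random.randint(0, 256, (60, 300), dtype=np.uint8), mode="L")
t = transform(demo)
print(t.shape, float(t.min()), float(t.max()))  # torch.Size([1, 50, 250]), values in [-1, 1]
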
from torch.utils.data import Dataset

class CaptchaDataset(Dataset):
    def __init__(self, dataframe, transform):
        self.dataframe = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image = Image.open(row.filepath).convert("L")
        image = self.transform(image)
        label = row.label
        return image, label


from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')  # enable AMP (mixed precision)

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Reshape, Bidirectional, LSTM, Dense, Dropout, Activation, BatchNormalization
from tensorflow.keras import backend as K

# Define the character set (based on your label data)
# You need to create a character set based on the unique characters in your 'label' column
# For example:
# char_set = sorted(list(set("".join(df['label'].unique()))))
# num_classes = len(char_set) + 1  # +1 for the blank label for CTC

# Placeholder for the actual character set - replace with your data's character set
# char_set = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
char_set = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
num_classes = len(char_set) + 1  # +1 for the blank label for CTC

# Model parameters
# input_shape = (60, 160, 1)  # (height, width, channels)
input_shape = (50, 250, 1)  # (height, width, channels)
lstm_units = 128

# Input layer
input_tensor = Input(shape=input_shape, name='input')

# Convolutional layers (CNN)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_tensor)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)

x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)

x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)

# Reshape for RNN
# The output shape of the last pooling layer is (batch_size, height, width, filters)
# We need to reshape it to (batch_size, time_steps, features) for the RNN
# time_steps will be the width of the feature maps after pooling
# features will be height * filters
shape_before_rnn = K.int_shape(x)
x = Reshape(target_shape=(shape_before_rnn[2], shape_before_rnn[1] * shape_before_rnn[3]))(x)

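# A quick walk-through of the shapes (matching the layers above): the input
# (50, 250, 1) passes three MaxPool(2,2) stages, halving H and W each time:
#   50 -> 25 -> 12 -> 6   and   250 -> 125 -> 62 -> 31
# so the reshape yields (time_steps, features) = (31, 6 * 128) = (31, 768).
print(shape_before_rnn)  # (None, 6, 31, 128)
print(K.int_shape(x))    # (None, 31, 768)
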
# Recurrent layers (RNN - Bidirectional LSTM)
# x = Bidirectional(LSTM(lstm_units, return_sequences=True, dropout=0.25))(x)
# x = Bidirectional(LSTM(lstm_units, return_sequences=True, dropout=0.25))(x)
# dropout > 0 disables the cuDNN kernel. To get the most out of the GPU:
#   set dropout=0.0 and recurrent_dropout=0.0
#   keep activation='tanh' & recurrent_activation='sigmoid' (the defaults)
#   unroll=False (default)
x = Bidirectional(tf.keras.layers.LSTM(
    128, return_sequences=True,
    dropout=0.0, recurrent_dropout=0.0
))(x)
x = Bidirectional(tf.keras.layers.LSTM(
    128, return_sequences=True,
    dropout=0.0, recurrent_dropout=0.0
))(x)


# Output layer
x = Dense(num_classes, activation='softmax', name='predictions')(x)

# Model definition
model = Model(inputs=input_tensor, outputs=x)


# CTC loss function - no slicing of y_pred
# dtypes switched to int32
labels = tf.keras.Input(name='labels', shape=(None,), dtype='int32')
input_length = tf.keras.Input(name='input_length', shape=(1,), dtype='int32')
label_length = tf.keras.Input(name='label_length', shape=(1,), dtype='int32')

def ctc_lambda_func(args):
    y_pred, labels_t, in_len, lab_len = args
    # do not slice y_pred
    return tf.keras.backend.ctc_batch_cost(labels_t, y_pred, in_len, lab_len)

ctc_loss_output = tf.keras.layers.Lambda(
    ctc_lambda_func, output_shape=(1,), name='ctc_loss', dtype='float32'  # keep the loss in float32
)([x, labels, input_length, label_length])

# Model with CTC loss
model_with_ctc = Model(inputs=[input_tensor, labels, input_length, label_length], outputs=ctc_loss_output)

# Compile the model. The Lambda layer already outputs the per-sample CTC loss,
# so the loss function simply passes it through and the dummy targets are ignored.
model_with_ctc.compile(loss={'ctc_loss': lambda y_true, y_pred: y_pred}, optimizer='adam')
# opt = tf.keras.optimizers.Adam(1e-3, clipnorm=5.0)
# model_with_ctc.compile(
#     loss={'ctc_loss': lambda y_true, y_pred: y_pred},
#     optimizer=opt,
#     # jit_compile=True,  # <<- enable XLA (TF >= 2.9 / Keras 3)
#     jit_compile=False,   # <<- XLA kept off here
# )

model.summary()

from torchvision import transforms
from torchvision.transforms import InterpolationMode
import tensorflow as tf

# 1) Transform to 50x250 (no distortion)
transform = transforms.Compose([
    transforms.Resize((50, 250), interpolation=InterpolationMode.BILINEAR, antialias=True),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

CHARSET = list("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ")

# forward mapping: no UNK, no mask
char_to_num = tf.keras.layers.StringLookup(
    vocabulary=CHARSET,
    oov_token=None,
    mask_token=None,     # no mask
    num_oov_indices=0    # no UNK
)


# inverse mapping: do NOT set oov_token
num_to_char = tf.keras.layers.StringLookup(
    vocabulary=CHARSET,  # use CHARSET directly
    invert=True,
    num_oov_indices=0,   # important
    mask_token=None,
)

print("vocab size:", len(char_to_num.get_vocabulary()))  # -> 36
print(char_to_num.get_vocabulary())  # -> ['0','1',...,'Z']
print(num_to_char.get_vocabulary())  # -> ['0','1',...,'Z']

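# Round-trip check of the lookup layers (sketch; "AB12Z" is a made-up label).
# With num_oov_indices=0 and no mask token, ids are 0-based over CHARSET:
ids = char_to_num(tf.constant(list("AB12Z")))
print(ids.numpy())  # [10 11  1  2 35]
back = tf.strings.reduce_join(num_to_char(ids)).numpy().decode()
print(back)  # 'AB12Z'
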
import numpy as np  # used below for batch assembly

class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, dataframe, char_to_num,
                 batch_size=32, img_width=250, img_height=50, max_label_length=5):
        super().__init__()
        self.dataframe = dataframe.reset_index(drop=True)
        self.char_to_num = char_to_num
        self.batch_size = batch_size
        self.img_width = img_width
        self.img_height = img_height
        self.max_label_length = max_label_length
        # time steps after 3x MaxPool(2,2) along the width axis
        self.time_steps = self.img_width // 8  # 250 // 8 = 31
        self.on_epoch_end()

    def __len__(self):
        return len(self.dataframe) // self.batch_size  # drop last

    def __getitem__(self, index):
        start_index = index * self.batch_size
        end_index = (index + 1) * self.batch_size
        batch_df = self.dataframe.iloc[start_index:end_index]

        images = []
        labels = []
        input_lengths = np.full((len(batch_df), 1), self.time_steps, dtype=np.int64)
        label_lengths = []

        for _, row in batch_df.iterrows():
            # 1) Load & preprocess image -> (H,W,1) float32
            img = Image.open(row.filepath).convert("L")
            t = transform(img)                # torch tensor (1,H,W), normalized [-1,1]
            arr = t.permute(1, 2, 0).numpy()  # -> (H,W,1)
            images.append(arr)

            # 2) Encode label (UPPERCASE), pad with -1, dtype int32
            lab = row.label.upper()
            lab_ids = self.char_to_num(tf.constant(list(lab))).numpy().astype(np.int32)
            pad_len = self.max_label_length - len(lab_ids)
            if pad_len < 0:
                lab_ids = lab_ids[:self.max_label_length]
                pad_len = 0
            lab_ids = np.pad(lab_ids, (0, pad_len), mode="constant", constant_values=-1)
            labels.append(lab_ids)

            # 3) true label_length (without padding)
            label_lengths.append([len(lab)])

        images = np.asarray(images, dtype=np.float32)              # (B,H,W,1)
        labels = np.asarray(labels, dtype=np.int32)                # (B,L)
        label_lengths = np.asarray(label_lengths, dtype=np.int64)  # (B,1)

        inputs = {
            'input': images,
            'labels': labels,
            'input_length': input_lengths,
            'label_length': label_lengths
        }
        # dummy target; the loss is computed in the Lambda layer
        outputs = np.zeros((images.shape[0],), dtype=np.float32)

        return inputs, outputs

    def on_epoch_end(self):
        self.dataframe = self.dataframe.sample(frac=1.0).reset_index(drop=True)

# Instantiate the data generators
train_generator = DataGenerator(train_df, char_to_num, batch_size=32, max_label_length=5)
val_generator = DataGenerator(val_df, char_to_num, batch_size=32, max_label_length=5)

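# With drop-last batching, steps per epoch = len(df) // batch_size:
print("train steps/epoch:", len(train_generator), "| val steps/epoch:", len(val_generator))
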
import numpy as np
# inspect the contents: grab the first batch
(inputs, outputs) = train_generator[0]

x = inputs['input']              # (B, 50, 250, 1), float32, ~[-1,1]
y = inputs['labels']             # (B, 5), int32, pad = -1
inlen = inputs['input_length']   # (B, 1) == 31
lablen = inputs['label_length']  # (B, 1) == 5

print("x:", x.shape, x.dtype, x.min(), x.max())
print("labels:", y.shape, y.dtype, "unique pads:", sorted(set(y.flatten()) - set(range(0, 36)))[:5])
print("input_length uniq:", set(inlen.flatten().tolist()))
print("label_length uniq:", set(lablen.flatten().tolist()))
print("outputs (dummy):", outputs.shape, outputs.dtype)

# sanity asserts
assert x.shape[1:] == (50, 250, 1)
assert y.shape[1] == 5
assert inlen.min() == inlen.max() == 31
assert lablen.min() >= 1 and lablen.max() <= 5
assert y.dtype == np.int32

# CHECK CTC DECODING
# 1) make sure every label id is within the range 0..35
assert y.min() >= 0 and y.max() <= 35, f"Label outside range 0..35: min={y.min()}, max={y.max()}"

# 2) quick CTC loss test (must be finite, not NaN/Inf)
yp = model.predict(x[:4], verbose=0)  # (4, 31, 37)
loss = tf.keras.backend.ctc_batch_cost(y[:4], yp, inlen[:4], lablen[:4]).numpy()
print("CTC sample loss:", loss)  # every entry should satisfy np.isfinite
assert np.all(np.isfinite(loss)), f"CTC loss non-finite: {loss}"

# 3) (optional) decode 3 ground-truth labels back to sanity-check the mapping
CHARSET = np.array(list("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
def decode_ids_row_np(ids_1d):
    ids_1d = [int(t) for t in ids_1d if int(t) >= 0]  # drop padding
    return "".join(CHARSET[ids_1d]) if ids_1d else ""

for i in range(3):
    print(i, "GT:", decode_ids_row_np(y[i]))

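# Decoding the model's own predictions works the same way once trained; a
# greedy-decode sketch (blank = last class index, as Keras CTC assumes):
decoded, _ = tf.keras.backend.ctc_decode(yp, input_length=inlen[:4].flatten(), greedy=True)
for row in decoded[0].numpy():
    print("PRED:", decode_ids_row_np(row))  # garbage before training, real strings after
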


"""SAVE EVERY EPOCH"""

import os, re, glob
from pathlib import Path
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# ====== Paths ======
CKPT_DIR = Path("/workspace")
CKPT_DIR.mkdir(parents=True, exist_ok=True)

BEST_PATH = CKPT_DIR / "captcha_best.weights.h5"
EPOCH_PATH = CKPT_DIR / "captcha_ep{epoch:03d}.weights.h5"  # <-- one file per epoch

# ====== Callbacks ======
# 1) Save the "best" weights based on val_loss
ckpt_best = ModelCheckpoint(
    filepath=str(BEST_PATH),
    monitor="val_loss",
    save_best_only=True,
    save_weights_only=True,
    save_freq="epoch",
    verbose=1,
)

# 2) Save EVERY EPOCH
ckpt_every_epoch = ModelCheckpoint(
    filepath=str(EPOCH_PATH),
    save_best_only=False,   # <-- must be False to save every epoch
    save_weights_only=True,
    save_freq="epoch",      # 'epoch' is already the default; explicit for clarity
    verbose=0,
)

early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=15,
    restore_best_weights=True,
    verbose=1,
)

# ====== Resume logic ======
def find_latest_epoch_ckpt(dir_path: Path):
    files = glob.glob(str(dir_path / "captcha_ep*.weights.h5"))
    if not files:
        return None, None
    pairs = []
    for f in files:
        m = re.search(r"captcha_ep(\d{3})\.weights\.h5$", os.path.basename(f))
        if m:
            pairs.append((int(m.group(1)), f))
    if not pairs:
        return None, None
    pairs.sort(key=lambda x: x[0])
    return pairs[-1]  # (epoch, path)

initial_epoch = 0
ep, last_path = find_latest_epoch_ckpt(CKPT_DIR)
if last_path:
    print(f"[RESUME] Loading weights from {last_path}")
    model_with_ctc.load_weights(last_path)
    initial_epoch = ep
    print(f"[RESUME] initial_epoch set to {initial_epoch}")
elif BEST_PATH.exists():
    print(f"[RESUME] Loading BEST weights from {BEST_PATH}")
    model_with_ctc.load_weights(str(BEST_PATH))
    initial_epoch = 0
else:
    print("[RESUME] No checkpoint found. Starting from scratch.")

# ====== Fit ======
history = model_with_ctc.fit(
    train_generator,
    validation_data=val_generator,
    epochs=100,  # set back to your target
    # epochs=10,
    initial_epoch=initial_epoch,
    callbacks=[ckpt_best, ckpt_every_epoch, early_stopping],
    verbose=1,
)

# (Optional) save final weights & an inference model
model_with_ctc.save_weights(str(CKPT_DIR / "captcha_final.weights.h5"))
model.save(str(CKPT_DIR / "captcha_final_model_base.h5"))  # inference model (without the CTC Lambda)
model.save(str(CKPT_DIR / "captcha_final_model_base.keras"))
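
# A minimal post-training inference sketch on one held-out test image
# (assumption: reuses transform, decode_ids_row_np, and test_df from above):
reloaded = tf.keras.models.load_model(str(CKPT_DIR / "captcha_final_model_base.keras"))
sample = test_df.iloc[0]
img = transform(Image.open(sample.filepath).convert("L"))
batch = img.permute(1, 2, 0).numpy()[None, ...]  # (1, 50, 250, 1)
yp = reloaded.predict(batch, verbose=0)          # (1, 31, 37)
decoded, _ = tf.keras.backend.ctc_decode(yp, input_length=[31], greedy=True)
print("GT:", sample.label, "| PRED:", decode_ids_row_np(decoded[0].numpy()[0]))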