hash-map committed on
Commit
5b01b4f
·
verified ·
1 Parent(s): f131a92

Upload 3 files

discriminator_final.h5 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be1a74166546ba9b56d53a6a00fcbb0b2ccf1c6efc2b11d5f79cf0b786562621
3
+ size 11140160
generator_final.h5 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:629fef04ed3743eee1185e94a738c6dd9609f01a31bb94dd469130ee7c8cf823
3
+ size 66766872
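
Both .h5 weight files above are stored as Git LFS pointers: the repository keeps only the "version" line, the SHA-256 oid of the real payload, and its size in bytes, while the binary itself is fetched from LFS storage on checkout. As a rough sketch of how a downloaded copy could be checked against its pointer (the local path is a placeholder; oid and size are copied from the generator_final.h5 pointer above):

import hashlib
import os

def verify_lfs_pointer(local_path, expected_oid, expected_size):
    # Hash the file in 1 MiB chunks and compare against the pointer's oid and size.
    digest = hashlib.sha256()
    with open(local_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_oid and os.path.getsize(local_path) == expected_size

print(verify_lfs_pointer(
    "generator_final.h5",  # placeholder local path
    "629fef04ed3743eee1185e94a738c6dd9609f01a31bb94dd469130ee7c8cf823",
    66766872,
))
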
visible-to-thermal-f7984f.ipynb ADDED
@@ -0,0 +1,1050 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "2f21528f",
7
+ "metadata": {
8
+ "execution": {
9
+ "iopub.execute_input": "2025-10-20T07:03:19.866733Z",
10
+ "iopub.status.busy": "2025-10-20T07:03:19.866103Z",
11
+ "iopub.status.idle": "2025-10-20T07:03:34.800621Z",
12
+ "shell.execute_reply": "2025-10-20T07:03:34.799988Z"
13
+ },
14
+ "papermill": {
15
+ "duration": 14.941555,
16
+ "end_time": "2025-10-20T07:03:34.802099",
17
+ "exception": false,
18
+ "start_time": "2025-10-20T07:03:19.860544",
19
+ "status": "completed"
20
+ },
21
+ "tags": []
22
+ },
23
+ "outputs": [],
24
+ "source": [
25
+ "import tensorflow as tf\n",
26
+ "from tensorflow.keras import Model, Input\n",
27
+ "from tensorflow.keras.layers import Conv2D, Conv2DTranspose, LeakyReLU\n",
28
+ "from tensorflow.keras.optimizers import Adam\n",
29
+ "import numpy as np\n",
30
+ "import os, glob, time, matplotlib.pyplot as plt\n",
31
+ "\n",
32
+ "# -------------------- SETTINGS --------------------\n",
33
+ "IMG_SIZE = 256\n",
34
+ "BATCH_SIZE = 16\n",
35
+ "EPOCHS = 100\n",
36
+ "PRINT_INTERVAL = 100\n",
37
+ "SAVE_INTERVAL_EPOCHS = 5\n",
38
+ "OUTPUT_DIR = \"new/output\"\n",
39
+ "CKPT_DIR = \"new/ckpt\"\n",
40
+ "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
41
+ "os.makedirs(CKPT_DIR, exist_ok=True)\n",
42
+ "\n",
43
+ "# -------------------- DATASET --------------------\n",
44
+ "BASE_DIR = \"data\"\n",
45
+ "def load_image_pair(v_path, i_path):\n",
46
+ " vis = tf.image.decode_png(tf.io.read_file(v_path), channels=3)\n",
47
+ " ir = tf.image.decode_png(tf.io.read_file(i_path), channels=3)\n",
48
+ " vis = tf.image.resize(vis, (IMG_SIZE, IMG_SIZE))\n",
49
+ " ir = tf.image.resize(ir, (IMG_SIZE, IMG_SIZE))\n",
50
+ " vis = tf.cast(vis, tf.float32) / 127.5 - 1.0\n",
51
+ " ir = tf.cast(ir, tf.float32) / 127.5 - 1.0\n",
52
+ " return vis, ir\n",
53
+ "def augment(vis, ir):\n",
54
+ " vis = tf.image.random_brightness(vis, 0.1)\n",
55
+ " vis = tf.image.random_contrast(vis, 0.8, 1.2)\n",
56
+ " return vis,ir\n",
57
+ "\n",
58
+ "def make_dataset(v_dir, i_dir,train=True):\n",
59
+ " vis_files = sorted(glob.glob(os.path.join(v_dir, \"*\")))\n",
60
+ " ir_files = sorted(glob.glob(os.path.join(i_dir, \"*\")))\n",
61
+ " ds = tf.data.Dataset.from_tensor_slices((vis_files, ir_files))\n",
62
+ " ds = ds.map(load_image_pair, num_parallel_calls=tf.data.AUTOTUNE)\n",
63
+ " ds = ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)\n",
64
+ " ds = ds.shuffle(500, reshuffle_each_iteration=True)\n",
65
+ " ds = ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)\n",
66
+ " return ds"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 2,
72
+ "id": "d05541d3",
73
+ "metadata": {
74
+ "execution": {
75
+ "iopub.execute_input": "2025-10-20T07:03:34.810666Z",
76
+ "iopub.status.busy": "2025-10-20T07:03:34.809984Z",
77
+ "iopub.status.idle": "2025-10-20T07:03:36.136795Z",
78
+ "shell.execute_reply": "2025-10-20T07:03:36.135868Z"
79
+ },
80
+ "papermill": {
81
+ "duration": 1.332245,
82
+ "end_time": "2025-10-20T07:03:36.138219",
83
+ "exception": false,
84
+ "start_time": "2025-10-20T07:03:34.805974",
85
+ "status": "completed"
86
+ },
87
+ "tags": []
88
+ },
89
+ "outputs": [],
90
+ "source": [
91
+ "#train_ds = make_dataset(f\"{BASE_DIR}/train/visible\", f\"{BASE_DIR}/train/infrared\")"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 3,
97
+ "id": "c5fe4e05",
98
+ "metadata": {
99
+ "execution": {
100
+ "iopub.execute_input": "2025-10-20T07:03:36.146619Z",
101
+ "iopub.status.busy": "2025-10-20T07:03:36.146201Z",
102
+ "iopub.status.idle": "2025-10-20T07:03:37.528302Z",
103
+ "shell.execute_reply": "2025-10-20T07:03:37.527486Z"
104
+ },
105
+ "papermill": {
106
+ "duration": 1.387878,
107
+ "end_time": "2025-10-20T07:03:37.529837",
108
+ "exception": false,
109
+ "start_time": "2025-10-20T07:03:36.141959",
110
+ "status": "completed"
111
+ },
112
+ "tags": []
113
+ },
114
+ "outputs": [],
115
+ "source": [
116
+ "from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, LeakyReLU, BatchNormalization, Concatenate, Activation, Dropout\n",
117
+ "from tensorflow.keras.models import Model\n",
118
+ "from tensorflow.keras.optimizers import Adam\n",
119
+ "\n",
120
+ "def build_generator(img_size=256, dropout_rate=0.05):\n",
121
+ " inp = Input(shape=(img_size, img_size, 3))\n",
122
+ "\n",
123
+ " # ---- Encoder ----\n",
124
+ " e1 = Conv2D(64, 4, strides=2, padding='same')(inp) # 128x128\n",
125
+ " e1 = LeakyReLU(0.2)(e1)\n",
126
+ "\n",
127
+ " e2 = Conv2D(128, 4, strides=2, padding='same')(e1) # 64x64\n",
128
+ " e2 = BatchNormalization()(e2)\n",
129
+ " e2 = LeakyReLU(0.2)(e2)\n",
130
+ "\n",
131
+ " e3 = Conv2D(256, 4, strides=2, padding='same')(e2) # 32x32\n",
132
+ " e3 = BatchNormalization()(e3)\n",
133
+ " e3 = LeakyReLU(0.2)(e3)\n",
134
+ "\n",
135
+ " e4 = Conv2D(512, 4, strides=2, padding='same')(e3) # 16x16\n",
136
+ " e4 = BatchNormalization()(e4)\n",
137
+ " e4 = LeakyReLU(0.2)(e4)\n",
138
+ "\n",
139
+ " # ---- Bottleneck ----\n",
140
+ " b = Conv2D(512, 4, strides=2, padding='same')(e4) # 8x8\n",
141
+ " b = Activation('relu')(b)\n",
142
+ " b = Dropout(dropout_rate)(b) # dropout in bottleneck\n",
143
+ "\n",
144
+ " # ---- Decoder ----\n",
145
+ " d1 = Conv2DTranspose(512, 4, strides=2, padding='same')(b) # 16x16\n",
146
+ " d1 = BatchNormalization()(d1)\n",
147
+ " d1 = Activation('relu')(d1)\n",
148
+ " d1 = Dropout(dropout_rate)(d1) # optional decoder dropout\n",
149
+ " d1 = Concatenate()([d1, e4])\n",
150
+ "\n",
151
+ " d2 = Conv2DTranspose(256, 4, strides=2, padding='same')(d1) # 32x32\n",
152
+ " d2 = BatchNormalization()(d2)\n",
153
+ " d2 = Activation('relu')(d2)\n",
154
+ " d2 = Dropout(dropout_rate)(d2)\n",
155
+ " d2 = Concatenate()([d2, e3])\n",
156
+ "\n",
157
+ " d3 = Conv2DTranspose(128, 4, strides=2, padding='same')(d2) # 64x64\n",
158
+ " d3 = BatchNormalization()(d3)\n",
159
+ " d3 = Activation('relu')(d3)\n",
160
+ " d3 = Dropout(dropout_rate)(d3)\n",
161
+ " d3 = Concatenate()([d3, e2])\n",
162
+ "\n",
163
+ " d4 = Conv2DTranspose(64, 4, strides=2, padding='same')(d3) # 128x128\n",
164
+ " d4 = BatchNormalization()(d4)\n",
165
+ " d4 = Activation('relu')(d4)\n",
166
+ " d4 = Dropout(dropout_rate)(d4)\n",
167
+ " d4 = Concatenate()([d4, e1])\n",
168
+ "\n",
169
+ " out = Conv2DTranspose(3, 4, strides=2, padding='same', activation='tanh')(d4) # 256x256\n",
170
+ "\n",
171
+ " return Model(inp, out, name=\"UNet_Generator\")\n",
172
+ "\n",
173
+ "\n",
174
+ "def build_small_discriminator(img_size=256, dropout_rate=0.2):\n",
175
+ " vis_inp = Input(shape=(img_size, img_size, 3))\n",
176
+ " ir_inp = Input(shape=(img_size, img_size, 3))\n",
177
+ " x = Concatenate()([vis_inp, ir_inp])\n",
178
+ " \n",
179
+ " x = Conv2D(64, 4, strides=2, padding='same')(x)\n",
180
+ " x = LeakyReLU(0.2)(x)\n",
181
+ "\n",
182
+ " x = Conv2D(128, 4, strides=2, padding='same')(x)\n",
183
+ " x = BatchNormalization()(x)\n",
184
+ " x = LeakyReLU(0.2)(x)\n",
185
+ " x = Dropout(dropout_rate)(x) # optional\n",
186
+ "\n",
187
+ " x = Conv2D(256, 4, strides=2, padding='same')(x)\n",
188
+ " x = BatchNormalization()(x)\n",
189
+ " x = LeakyReLU(0.2)(x)\n",
190
+ " x = Dropout(dropout_rate)(x)\n",
191
+ "\n",
192
+ " x = Conv2D(512, 4, strides=1, padding='same')(x)\n",
193
+ " x = BatchNormalization()(x)\n",
194
+ " x = LeakyReLU(0.2)(x)\n",
195
+ " x = Dropout(dropout_rate)(x)\n",
196
+ "\n",
197
+ " out = Conv2D(1, 4, strides=1, padding='same')(x)\n",
198
+ " return Model([vis_inp, ir_inp], out, name=\"CondPatchGAN_Discriminator\")\n",
199
+ "\n",
200
+ "\n",
201
+ "# Instantiate models\n",
202
+ "generator = build_generator(IMG_SIZE)\n",
203
+ "discriminator = build_small_discriminator(IMG_SIZE)\n",
204
+ "initial_lr = 2e-4\n",
205
+ "lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(\n",
206
+ " initial_learning_rate=initial_lr,\n",
207
+ " decay_steps=500, # number of steps before decay\n",
208
+ " decay_rate=0.96, # decay factor\n",
209
+ " staircase=True # True -> discrete steps\n",
210
+ ")\n",
211
+ "\n",
212
+ "gen_opt = Adam(learning_rate=lr_schedule, beta_1=0.5)\n",
213
+ "disc_opt = Adam(learning_rate=lr_schedule, beta_1=0.5)"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": 4,
219
+ "id": "1ffb77ec",
220
+ "metadata": {
221
+ "execution": {
222
+ "iopub.execute_input": "2025-10-20T07:03:38.403886Z",
223
+ "iopub.status.busy": "2025-10-20T07:03:38.403676Z",
224
+ "iopub.status.idle": "2025-10-20T07:03:38.410298Z",
225
+ "shell.execute_reply": "2025-10-20T07:03:38.409568Z"
226
+ },
227
+ "papermill": {
228
+ "duration": 0.011794,
229
+ "end_time": "2025-10-20T07:03:38.411366",
230
+ "exception": false,
231
+ "start_time": "2025-10-20T07:03:38.399572",
232
+ "status": "completed"
233
+ },
234
+ "tags": []
235
+ },
236
+ "outputs": [],
237
+ "source": [
238
+ "# -------------------- LOSSES --------------------\n",
239
+ "bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)\n",
240
+ "\n",
241
+ "def perceptual_loss(y_true, y_pred):\n",
242
+ " return tf.reduce_mean(tf.abs(y_true - y_pred))\n",
243
+ "\n",
244
+ "def brightness_loss(y_true, y_pred):\n",
245
+ " true_b = tf.image.rgb_to_grayscale(y_true)\n",
246
+ " pred_b = tf.image.rgb_to_grayscale(y_pred)\n",
247
+ " return tf.reduce_mean(tf.abs(true_b - pred_b))\n",
248
+ "\n",
249
+ "def intensity_weighted_l1(y_true, y_pred): \n",
250
+ " weights = tf.abs(y_true) + 1 \n",
251
+ " return tf.reduce_mean(weights * tf.abs(y_true - y_pred))\n",
252
+ "\n",
253
+ "def ssim_loss(y_true, y_pred): \n",
254
+ " return -tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=2.0)) \n",
255
+ "\n",
256
+ "def generator_adv_loss(fake_pred):\n",
257
+ " return bce(tf.ones_like(fake_pred), fake_pred)\n",
258
+ "\n",
259
+ "def discriminator_loss(real, fake):\n",
260
+ " real_loss = bce(tf.ones_like(real) * 0.9, real)\n",
261
+ " fake_loss = bce(tf.zeros_like(fake) + 0.1, fake)\n",
262
+ " return (real_loss + fake_loss) * 0.5\n",
263
+ "def generator_adv_loss(fake_pred):\n",
264
+ " return bce(tf.ones_like(fake_pred), fake_pred)\n",
265
+ "\n",
266
+ "def discriminator_loss(real, fake):\n",
267
+ "\n",
268
+ " real_loss = bce(tf.ones_like(real) * 0.9, real)\n",
269
+ " fake_loss = bce(tf.zeros_like(fake) + 0.1, fake)\n",
270
+ "\n",
271
+ " return (real_loss + fake_loss) * 0.5"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": 5,
277
+ "id": "e7645bd3",
278
+ "metadata": {
279
+ "execution": {
280
+ "iopub.execute_input": "2025-10-20T07:03:38.419250Z",
281
+ "iopub.status.busy": "2025-10-20T07:03:38.418628Z",
282
+ "iopub.status.idle": "2025-10-20T07:03:38.425897Z",
283
+ "shell.execute_reply": "2025-10-20T07:03:38.425349Z"
284
+ },
285
+ "papermill": {
286
+ "duration": 0.012037,
287
+ "end_time": "2025-10-20T07:03:38.426854",
288
+ "exception": false,
289
+ "start_time": "2025-10-20T07:03:38.414817",
290
+ "status": "completed"
291
+ },
292
+ "tags": []
293
+ },
294
+ "outputs": [],
295
+ "source": [
296
+ "@tf.function\n",
297
+ "def train_step(input_vis, target_ir, adv_weight=1.0, noise_std=0.05):\n",
298
+ " with tf.GradientTape(persistent=True) as tape:\n",
299
+ " gen_out = generator(input_vis, training=True)\n",
300
+ " \n",
301
+ " # Add noise to D inputs (helps stabilize D)\n",
302
+ " noisy_real = target_ir \n",
303
+ " noisy_fake = gen_out \n",
304
+ " \n",
305
+ " real_pred = discriminator([input_vis, noisy_real], training=True)\n",
306
+ " fake_pred = discriminator([input_vis, noisy_fake], training=True)\n",
307
+ "\n",
308
+ " # Compute losses\n",
309
+ " p = perceptual_loss(target_ir, gen_out)\n",
310
+ " b = brightness_loss(target_ir, gen_out)\n",
311
+ " w = intensity_weighted_l1(target_ir, gen_out)\n",
312
+ " s = ssim_loss(target_ir, gen_out)\n",
313
+ " adv = generator_adv_loss(fake_pred)\n",
314
+ " gen_total = p*20.0 + b*2 + w*10.0 + s*1 + adv*adv_weight\n",
315
+ "\n",
316
+ " disc_total = discriminator_loss(real_pred, fake_pred)\n",
317
+ "\n",
318
+ " gen_grads = tape.gradient(gen_total, generator.trainable_variables)\n",
319
+ " disc_grads = tape.gradient(disc_total, discriminator.trainable_variables)\n",
320
+ " gen_grads, _ = tf.clip_by_global_norm(gen_grads, 5.0)\n",
321
+ " disc_grads, _ = tf.clip_by_global_norm(disc_grads, 5.0)\n",
322
+ " gen_opt.apply_gradients(zip(gen_grads, generator.trainable_variables))\n",
323
+ " disc_opt.apply_gradients(zip(disc_grads, discriminator.trainable_variables))\n",
324
+ " return gen_total, disc_total\n",
325
+ "\n",
326
+ "@tf.function\n",
327
+ "def val_step(input_vis, target_ir, adv_weight=1.0):\n",
328
+ " gen_out = generator(input_vis, training=False)\n",
329
+ " fake_pred = discriminator([input_vis, gen_out], training=False)\n",
330
+ " p = perceptual_loss(target_ir, gen_out)\n",
331
+ " b = brightness_loss(target_ir, gen_out)\n",
332
+ " w = intensity_weighted_l1(target_ir, gen_out)\n",
333
+ " s = ssim_loss(target_ir, gen_out)\n",
334
+ " adv = generator_adv_loss(fake_pred)\n",
335
+ " gen_total = p*20.0 + b*2 + w*10.0 + s*1 + adv*adv_weight\n",
336
+ " return gen_total"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "code",
341
+ "execution_count": 6,
342
+ "id": "f9166e0b",
343
+ "metadata": {
344
+ "execution": {
345
+ "iopub.execute_input": "2025-10-20T07:03:38.434108Z",
346
+ "iopub.status.busy": "2025-10-20T07:03:38.433926Z",
347
+ "iopub.status.idle": "2025-10-20T07:03:38.439152Z",
348
+ "shell.execute_reply": "2025-10-20T07:03:38.438642Z"
349
+ },
350
+ "papermill": {
351
+ "duration": 0.010044,
352
+ "end_time": "2025-10-20T07:03:38.440202",
353
+ "exception": false,
354
+ "start_time": "2025-10-20T07:03:38.430158",
355
+ "status": "completed"
356
+ },
357
+ "tags": []
358
+ },
359
+ "outputs": [],
360
+ "source": [
361
+ "# -------------------- UTILITIES --------------------\n",
362
+ "def to_uint8(x):\n",
363
+ " x = (x + 1.0) * 127.5\n",
364
+ " return tf.cast(tf.clip_by_value(x, 0, 255), tf.uint8)\n",
365
+ "\n",
366
+ "def save_sample_images(model, val_ds, epoch):\n",
367
+ " os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
368
+ " rows = []\n",
369
+ " for i, (v_inp, v_tar) in enumerate(val_ds.take(5)):\n",
370
+ " pred = model(v_inp, training=False)\n",
371
+ " vis = to_uint8(v_inp[0])\n",
372
+ " targ = to_uint8(v_tar[0])\n",
373
+ " gen = to_uint8(pred[0])\n",
374
+ " row = tf.concat([vis, targ, gen], axis=1)\n",
375
+ " rows.append(row)\n",
376
+ " grid = tf.concat(rows, axis=0)\n",
377
+ " out_path = os.path.join(OUTPUT_DIR, f\"epoch_{epoch:03d}.png\")\n",
378
+ " tf.keras.preprocessing.image.save_img(out_path, grid.numpy())\n",
379
+ " print(f\"πŸ–Ό Saved sample images to {out_path}\")"
380
+ ]
381
+ },
382
+ {
383
+ "cell_type": "code",
384
+ "execution_count": 7,
385
+ "id": "c09e0933",
386
+ "metadata": {
387
+ "execution": {
388
+ "iopub.execute_input": "2025-10-20T07:03:38.447653Z",
389
+ "iopub.status.busy": "2025-10-20T07:03:38.447428Z",
390
+ "iopub.status.idle": "2025-10-20T07:03:38.451198Z",
391
+ "shell.execute_reply": "2025-10-20T07:03:38.450576Z"
392
+ },
393
+ "papermill": {
394
+ "duration": 0.008507,
395
+ "end_time": "2025-10-20T07:03:38.452143",
396
+ "exception": false,
397
+ "start_time": "2025-10-20T07:03:38.443636",
398
+ "status": "completed"
399
+ },
400
+ "tags": []
401
+ },
402
+ "outputs": [],
403
+ "source": [
404
+ "import os\n",
405
+ "import tensorflow as tf\n",
406
+ "\n",
407
+ "# Paths\n",
408
+ "WORKING_CKPT_DIR = \"new/ckpt\"\n",
409
+ "os.makedirs(WORKING_CKPT_DIR, exist_ok=True)\n"
410
+ ]
411
+ },
412
+ {
413
+ "cell_type": "code",
414
+ "execution_count": 8,
415
+ "id": "de535ce2",
416
+ "metadata": {},
417
+ "outputs": [],
418
+ "source": [
419
+ "ckpt = tf.train.Checkpoint(generator=generator, discriminator=discriminator, gen_opt=gen_opt, disc_opt=disc_opt)\n",
420
+ "manager = tf.train.CheckpointManager(ckpt, CKPT_DIR, max_to_keep=5)"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": 9,
426
+ "id": "90b6be9b",
427
+ "metadata": {
428
+ "execution": {
429
+ "iopub.execute_input": "2025-10-20T07:03:38.472856Z",
430
+ "iopub.status.busy": "2025-10-20T07:03:38.472285Z",
431
+ "iopub.status.idle": "2025-10-20T07:03:39.135324Z",
432
+ "shell.execute_reply": "2025-10-20T07:03:39.134544Z"
433
+ },
434
+ "papermill": {
435
+ "duration": 0.668272,
436
+ "end_time": "2025-10-20T07:03:39.136644",
437
+ "exception": false,
438
+ "start_time": "2025-10-20T07:03:38.468372",
439
+ "status": "completed"
440
+ },
441
+ "tags": []
442
+ },
443
+ "outputs": [
444
+ {
445
+ "name": "stdout",
446
+ "output_type": "stream",
447
+ "text": [
448
+ "βœ… Loaded checkpoint from new/ckpt\\best_val.ckpt-42\n"
449
+ ]
450
+ }
451
+ ],
452
+ "source": [
453
+ "input_ckpt = tf.train.latest_checkpoint(WORKING_CKPT_DIR)\n",
454
+ "if input_ckpt:\n",
455
+ " ckpt.restore(input_ckpt).expect_partial()\n",
456
+ " print(f\"βœ… Loaded checkpoint from {input_ckpt}\")\n",
457
+ "else:\n",
458
+ " print(\"⚠️ No checkpoint found in /kaggle/input, starting fresh.\")"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "code",
463
+ "execution_count": 10,
464
+ "id": "9df742f2",
465
+ "metadata": {
466
+ "execution": {
467
+ "iopub.execute_input": "2025-10-20T07:03:39.144718Z",
468
+ "iopub.status.busy": "2025-10-20T07:03:39.144465Z",
469
+ "iopub.status.idle": "2025-10-20T07:03:39.151276Z",
470
+ "shell.execute_reply": "2025-10-20T07:03:39.150570Z"
471
+ },
472
+ "papermill": {
473
+ "duration": 0.01204,
474
+ "end_time": "2025-10-20T07:03:39.152328",
475
+ "exception": false,
476
+ "start_time": "2025-10-20T07:03:39.140288",
477
+ "status": "completed"
478
+ },
479
+ "tags": []
480
+ },
481
+ "outputs": [],
482
+ "source": [
483
+ "# -------------------- TRAIN LOOP --------------------\n",
484
+ "\n",
485
+ "from time import time \n",
486
+ "def train(train_ds, val_ds, epochs=EPOCHS):\n",
487
+ " best_val_loss = 50.0\n",
488
+ " step = 0\n",
489
+ " ckpt = tf.train.Checkpoint(generator=generator, discriminator=discriminator, gen_opt=gen_opt, disc_opt=disc_opt)\n",
490
+ " manager = tf.train.CheckpointManager(ckpt, CKPT_DIR, max_to_keep=5)\n",
491
+ "\n",
492
+ " for epoch in range(1, epochs + 1):\n",
493
+ " start = time()\n",
494
+ " g_losses, d_losses = [], []\n",
495
+ " for vis, ir in train_ds:\n",
496
+ " g_loss, d_loss = train_step(vis, ir)\n",
497
+ " g_losses.append(g_loss)\n",
498
+ " d_losses.append(d_loss)\n",
499
+ " step += 1\n",
500
+ " print(f\"time {time()-start} | G={tf.reduce_mean(g_losses):.4f} | D={tf.reduce_mean(d_losses):.4f}\")\n",
501
+ " \n",
502
+ " # Validation\n",
503
+ " val_losses = [val_step(v, i) for v, i in val_ds]\n",
504
+ " val_mean = tf.reduce_mean(val_losses)\n",
505
+ " print(f\"Epoch {epoch}/{epochs} | Val_loss={val_mean:.4f}\")\n",
506
+ "\n",
507
+ " # Save samples and checkpoints\n",
508
+ " save_sample_images(generator, val_ds, epoch)\n",
509
+ " if val_mean < best_val_loss:\n",
510
+ " best_val_loss = val_mean\n",
511
+ " ckpt_save_path = os.path.join(WORKING_CKPT_DIR, \"best_val.ckpt\")\n",
512
+ " ckpt.save(ckpt_save_path)\n",
513
+ " print(f\"πŸ† Best checkpoint updated at {ckpt_save_path} | val_loss={val_mean:.4f}\")\n",
514
+ " if epoch % SAVE_INTERVAL_EPOCHS == 0:\n",
515
+ " manager.save()\n",
516
+ " print(f\"πŸ’Ύ Checkpoint saved at epoch {epoch}\")\n",
517
+ "\n",
518
+ " generator.save(os.path.join(OUTPUT_DIR, \"generator_final.h5\"))\n",
519
+ " discriminator.save(os.path.join(OUTPUT_DIR, \"discriminator_final.h5\"))\n",
520
+ " print(\"βœ… Training complete and models saved!\")"
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "code",
525
+ "execution_count": 11,
526
+ "id": "a331caff",
527
+ "metadata": {
528
+ "execution": {
529
+ "iopub.execute_input": "2025-10-20T07:03:39.160218Z",
530
+ "iopub.status.busy": "2025-10-20T07:03:39.159814Z",
531
+ "iopub.status.idle": "2025-10-20T07:03:39.287611Z",
532
+ "shell.execute_reply": "2025-10-20T07:03:39.286778Z"
533
+ },
534
+ "papermill": {
535
+ "duration": 0.133149,
536
+ "end_time": "2025-10-20T07:03:39.289009",
537
+ "exception": false,
538
+ "start_time": "2025-10-20T07:03:39.155860",
539
+ "status": "completed"
540
+ },
541
+ "tags": []
542
+ },
543
+ "outputs": [],
544
+ "source": [
545
+ "#test_ds = make_dataset(f\"{BASE_DIR}/train/visible\", f\"{BASE_DIR}/train/infrared\")"
546
+ ]
547
+ },
548
+ {
549
+ "cell_type": "code",
550
+ "execution_count": 12,
551
+ "id": "8c10ff7d",
552
+ "metadata": {
553
+ "execution": {
554
+ "iopub.execute_input": "2025-10-20T07:03:39.297249Z",
555
+ "iopub.status.busy": "2025-10-20T07:03:39.296775Z",
556
+ "iopub.status.idle": "2025-10-20T07:03:39.301287Z",
557
+ "shell.execute_reply": "2025-10-20T07:03:39.300656Z"
558
+ },
559
+ "papermill": {
560
+ "duration": 0.009607,
561
+ "end_time": "2025-10-20T07:03:39.302268",
562
+ "exception": false,
563
+ "start_time": "2025-10-20T07:03:39.292661",
564
+ "status": "completed"
565
+ },
566
+ "tags": []
567
+ },
568
+ "outputs": [],
569
+ "source": [
570
+ "# import tensorflow as tf\n",
571
+ "\n",
572
+ "# num_elements = tf.data.experimental.cardinality(train_ds).numpy()\n",
573
+ "# print(num_elements)"
574
+ ]
575
+ },
576
+ {
577
+ "cell_type": "code",
578
+ "execution_count": 13,
579
+ "id": "9690af89",
580
+ "metadata": {
581
+ "execution": {
582
+ "iopub.execute_input": "2025-10-20T07:03:39.310045Z",
583
+ "iopub.status.busy": "2025-10-20T07:03:39.309644Z",
584
+ "iopub.status.idle": "2025-10-20T07:03:39.319137Z",
585
+ "shell.execute_reply": "2025-10-20T07:03:39.318632Z"
586
+ },
587
+ "papermill": {
588
+ "duration": 0.014508,
589
+ "end_time": "2025-10-20T07:03:39.320245",
590
+ "exception": false,
591
+ "start_time": "2025-10-20T07:03:39.305737",
592
+ "status": "completed"
593
+ },
594
+ "tags": []
595
+ },
596
+ "outputs": [],
597
+ "source": [
598
+ "# val_ds = train_ds.take(num_elements * 0.1)\n",
599
+ "# train_ds = train_ds.skip(num_elements * 0.1)"
600
+ ]
601
+ },
602
+ {
603
+ "cell_type": "code",
604
+ "execution_count": null,
605
+ "id": "6a74164a",
606
+ "metadata": {
607
+ "papermill": {
608
+ "duration": 0.003313,
609
+ "end_time": "2025-10-20T07:03:39.327070",
610
+ "exception": false,
611
+ "start_time": "2025-10-20T07:03:39.323757",
612
+ "status": "completed"
613
+ },
614
+ "tags": []
615
+ },
616
+ "outputs": [],
617
+ "source": []
618
+ },
619
+ {
620
+ "cell_type": "code",
621
+ "execution_count": 14,
622
+ "id": "ddf341ff",
623
+ "metadata": {
624
+ "execution": {
625
+ "iopub.execute_input": "2025-10-20T07:03:39.334784Z",
626
+ "iopub.status.busy": "2025-10-20T07:03:39.334320Z",
627
+ "iopub.status.idle": "2025-10-20T07:05:21.060089Z",
628
+ "shell.execute_reply": "2025-10-20T07:05:21.059479Z"
629
+ },
630
+ "papermill": {
631
+ "duration": 101.734509,
632
+ "end_time": "2025-10-20T07:05:21.064884",
633
+ "exception": false,
634
+ "start_time": "2025-10-20T07:03:39.330375",
635
+ "status": "completed"
636
+ },
637
+ "tags": []
638
+ },
639
+ "outputs": [],
640
+ "source": [
641
+ "# import matplotlib.pyplot as plt\n",
642
+ "# import tensorflow as tf\n",
643
+ "# import numpy as np\n",
644
+ "\n",
645
+ "# def to_uint8(x):\n",
646
+ "# \"\"\"Convert tensor from [-1,1] β†’ uint8 [0,255].\"\"\"\n",
647
+ "# x = (x + 1.0) * 127.5\n",
648
+ "# x = tf.clip_by_value(x, 0, 255)\n",
649
+ "# return tf.cast(x, tf.uint8)\n",
650
+ "\n",
651
+ "# # Shuffle the dataset to take random samples\n",
652
+ "# train_ds_shuffled = train_ds.shuffle(buffer_size=1000, reshuffle_each_iteration=True)\n",
653
+ "\n",
654
+ "# # Take 10 images\n",
655
+ "# sample_ds = train_ds_shuffled.take(10)\n",
656
+ "\n",
657
+ "# # Plot RGB and IR pairs side by side\n",
658
+ "# fig, axes = plt.subplots(10, 2, figsize=(6, 30))\n",
659
+ "\n",
660
+ "# for i, (vis, ir) in enumerate(sample_ds):\n",
661
+ "# # If batch size > 1, take the first image in the batch\n",
662
+ "# vis_img = vis[0].numpy() if vis.shape[0] > 1 else vis.numpy()[0]\n",
663
+ "# ir_img = ir[0].numpy() if ir.shape[0] > 1 else ir.numpy()[0]\n",
664
+ "\n",
665
+ "# # Denormalize\n",
666
+ "# vis_img = to_uint8(vis_img)\n",
667
+ "# ir_img = to_uint8(ir_img)\n",
668
+ "# # print(vis_img)\n",
669
+ "# # print(ir_img)\n",
670
+ "\n",
671
+ "# # Plot RGB input\n",
672
+ "# axes[i, 0].imshow(vis_img.numpy())\n",
673
+ "# axes[i, 0].set_title('RGB Input')\n",
674
+ "# axes[i, 0].axis('off')\n",
675
+ "\n",
676
+ "# # Plot IR output\n",
677
+ "# if ir_img.shape[-1] == 1: # single-channel IR\n",
678
+ "# axes[i, 1].imshow(ir_img.numpy().squeeze(), cmap='gray')\n",
679
+ "# else:\n",
680
+ "# axes[i, 1].imshow(ir_img.numpy())\n",
681
+ "# axes[i, 1].set_title('IR Output')\n",
682
+ "# axes[i, 1].axis('off')\n",
683
+ "# os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
684
+ "# save_path = os.path.join(OUTPUT_DIR, 'rgb_ir_pairs.png')\n",
685
+ "# plt.savefig(save_path)\n",
686
+ "# plt.close() \n",
687
+ "# print(f\"Saved visualization to {save_path}\")"
688
+ ]
689
+ },
690
+ {
691
+ "cell_type": "code",
692
+ "execution_count": 15,
693
+ "id": "ba0c423b",
694
+ "metadata": {
695
+ "execution": {
696
+ "iopub.execute_input": "2025-10-20T07:05:21.073518Z",
697
+ "iopub.status.busy": "2025-10-20T07:05:21.072865Z",
698
+ "iopub.status.idle": "2025-10-20T13:26:23.243918Z",
699
+ "shell.execute_reply": "2025-10-20T13:26:23.243151Z"
700
+ },
701
+ "papermill": {
702
+ "duration": 22862.17716,
703
+ "end_time": "2025-10-20T13:26:23.245759",
704
+ "exception": false,
705
+ "start_time": "2025-10-20T07:05:21.068599",
706
+ "status": "completed"
707
+ },
708
+ "tags": []
709
+ },
710
+ "outputs": [],
711
+ "source": [
712
+ "#train(train_ds, val_ds, epochs=EPOCHS)"
713
+ ]
714
+ },
715
+ {
716
+ "cell_type": "code",
717
+ "execution_count": 16,
718
+ "id": "7a020d03",
719
+ "metadata": {
720
+ "execution": {
721
+ "iopub.execute_input": "2025-10-20T13:26:23.281045Z",
722
+ "iopub.status.busy": "2025-10-20T13:26:23.280518Z",
723
+ "iopub.status.idle": "2025-10-20T13:26:55.305284Z",
724
+ "shell.execute_reply": "2025-10-20T13:26:55.304379Z"
725
+ },
726
+ "papermill": {
727
+ "duration": 32.043319,
728
+ "end_time": "2025-10-20T13:26:55.306642",
729
+ "exception": false,
730
+ "start_time": "2025-10-20T13:26:23.263323",
731
+ "status": "completed"
732
+ },
733
+ "tags": []
734
+ },
735
+ "outputs": [],
736
+ "source": [
737
+ "import tensorflow as tf\n",
738
+ "from tqdm import tqdm\n",
739
+ "import numpy as np\n",
740
+ "\n",
741
+ "model = tf.keras.models.load_model(\"generator_final.h5\",compile=False)\n",
742
+ "\n",
743
+ "# Load test dataset\n",
744
+ "test_ds = make_dataset(f\"{BASE_DIR}/train/visible\", f\"{BASE_DIR}/train/infrared\",train=False)\n",
745
+ "def l1_loss(y_true, y_pred):\n",
746
+ " return tf.reduce_mean(tf.abs(y_true - y_pred))\n",
747
+ "def evaluate(test_ds):\n",
748
+ " l1_list, psnr_list, ssim_list = [], [], []\n",
749
+ " for vis, ir in tqdm(test_ds):\n",
750
+ " pred = generator(vis, training=False)\n",
751
+ " l1 = l1_loss(ir, pred).numpy()\n",
752
+ " psnr = tf.image.psnr(ir, pred, max_val=2.0).numpy()\n",
753
+ " ssim = tf.image.ssim(ir, pred, max_val=2.0).numpy()\n",
754
+ "\n",
755
+ " l1_list.append(l1)\n",
756
+ " psnr_list.append(np.mean(psnr))\n",
757
+ " ssim_list.append(np.mean(ssim))\n",
758
+ "\n",
759
+ " print(\"==== Test Dataset Metrics ====\")\n",
760
+ " print(f\"L1 Loss : {np.mean(l1_list):.4f}\")\n",
761
+ " print(f\"PSNR : {np.mean(psnr_list):.4f}\")\n",
762
+ " print(f\"SSIM : {np.mean(ssim_list):.4f}\")\n",
763
+ "\n",
764
+ "\n",
765
+ "# -------------------- TEST INFERENCE SIDE BY SIDE --------------------\n",
766
+ "def test_and_save_predictions(test_ds, save_dir=\"output/test_results\"):\n",
767
+ " os.makedirs(save_dir, exist_ok=True)\n",
768
+ " l1_list, psnr_list, ssim_list = [], [], []\n",
769
+ "\n",
770
+ " for idx, (vis, ir) in enumerate(tqdm(test_ds)):\n",
771
+ " pred = generator(vis, training=False)\n",
772
+ "\n",
773
+ " # Metrics\n",
774
+ " l1 = tf.reduce_mean(tf.abs(ir - pred)).numpy()\n",
775
+ " psnr = tf.reduce_mean(tf.image.psnr(ir, pred, max_val=2.0)).numpy()\n",
776
+ " ssim = tf.reduce_mean(tf.image.ssim(ir, pred, max_val=2.0)).numpy()\n",
777
+ " l1_list.append(l1)\n",
778
+ " psnr_list.append(psnr)\n",
779
+ " ssim_list.append(ssim)\n",
780
+ "\n",
781
+ " # Side-by-side image saving\n",
782
+ " for i in range(vis.shape[0]):\n",
783
+ " vis_img = to_uint8(vis[i])\n",
784
+ " ir_img = to_uint8(ir[i])\n",
785
+ " gen_img = to_uint8(pred[i])\n",
786
+ " row = tf.concat([vis_img, ir_img, gen_img], axis=1)\n",
787
+ " save_path = os.path.join(save_dir, f\"test_{idx*vis.shape[0]+i:03d}.png\")\n",
788
+ " tf.keras.preprocessing.image.save_img(save_path, row.numpy())\n",
789
+ "\n",
790
+ " print(\"==== Test Dataset Metrics ====\")\n",
791
+ " print(f\"L1 Loss : {np.mean(l1_list):.4f}\")\n",
792
+ " print(f\"PSNR : {np.mean(psnr_list):.4f}\")\n",
793
+ " print(f\"SSIM : {np.mean(ssim_list):.4f}\")\n",
794
+ " print(f\"All predictions saved to {save_dir}\")\n",
795
+ "\n",
796
+ "#evaluate(test_ds)\n"
797
+ ]
798
+ },
799
+ {
800
+ "cell_type": "code",
801
+ "execution_count": 17,
802
+ "id": "b0a9832b",
803
+ "metadata": {
804
+ "execution": {
805
+ "iopub.execute_input": "2025-10-20T13:26:55.355553Z",
806
+ "iopub.status.busy": "2025-10-20T13:26:55.354912Z",
807
+ "iopub.status.idle": "2025-10-20T13:31:04.873369Z",
808
+ "shell.execute_reply": "2025-10-20T13:31:04.872355Z"
809
+ },
810
+ "papermill": {
811
+ "duration": 249.544369,
812
+ "end_time": "2025-10-20T13:31:04.874842",
813
+ "exception": false,
814
+ "start_time": "2025-10-20T13:26:55.330473",
815
+ "status": "completed"
816
+ },
817
+ "tags": []
818
+ },
819
+ "outputs": [],
820
+ "source": [
821
+ "#test_and_save_predictions(test_ds)"
822
+ ]
823
+ },
824
+ {
825
+ "cell_type": "code",
826
+ "execution_count": 18,
827
+ "id": "b3b09c3c",
828
+ "metadata": {
829
+ "execution": {
830
+ "iopub.execute_input": "2025-10-20T13:31:04.941670Z",
831
+ "iopub.status.busy": "2025-10-20T13:31:04.941353Z",
832
+ "iopub.status.idle": "2025-10-20T13:44:19.143047Z",
833
+ "shell.execute_reply": "2025-10-20T13:44:19.142083Z"
834
+ },
835
+ "papermill": {
836
+ "duration": 794.236367,
837
+ "end_time": "2025-10-20T13:44:19.144324",
838
+ "exception": false,
839
+ "start_time": "2025-10-20T13:31:04.907957",
840
+ "status": "completed"
841
+ },
842
+ "tags": []
843
+ },
844
+ "outputs": [],
845
+ "source": [
846
+ "#test_and_save_predictions(train_ds,save_dir=\"output/train_results\")"
847
+ ]
848
+ },
849
+ {
850
+ "cell_type": "code",
851
+ "execution_count": null,
852
+ "id": "e627ddf6",
853
+ "metadata": {},
854
+ "outputs": [
855
+ {
856
+ "name": "stdout",
857
+ "output_type": "stream",
858
+ "text": [
859
+ "\n",
860
+ "=== Evaluating checkpoint: new/ckpt\\ckpt-45 ===\n"
861
+ ]
862
+ },
863
+ {
864
+ "name": "stderr",
865
+ "output_type": "stream",
866
+ "text": [
867
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 752/752 [21:46<00:00, 1.74s/it]\n"
868
+ ]
869
+ },
870
+ {
871
+ "name": "stdout",
872
+ "output_type": "stream",
873
+ "text": [
874
+ "==== Test Dataset Metrics ====\n",
875
+ "L1 Loss : 0.0611\n",
876
+ "PSNR : 24.3096\n",
877
+ "SSIM : 0.8386\n",
878
+ "All predictions saved to op/ckpt-45\n"
879
+ ]
880
+ },
881
+ {
882
+ "ename": "TypeError",
883
+ "evalue": "cannot unpack non-iterable NoneType object",
884
+ "output_type": "error",
885
+ "traceback": [
886
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
887
+ "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
888
+ "Cell \u001b[1;32mIn[19], line 68\u001b[0m\n\u001b[0;32m 66\u001b[0m ckpt\u001b[38;5;241m.\u001b[39mrestore(ckpt_path)\u001b[38;5;241m.\u001b[39mexpect_partial()\n\u001b[0;32m 67\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m=== Evaluating checkpoint: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mckpt_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m ===\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m---> 68\u001b[0m l1, psnr, ssim \u001b[38;5;241m=\u001b[39m test_and_save_predictions(test_ds,save_dir\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mop/\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m+\u001b[39mckpt_file\u001b[38;5;241m.\u001b[39mreplace(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.index\u001b[39m\u001b[38;5;124m\"\u001b[39m,\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[0;32m 69\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mL1 Loss : \u001b[39m\u001b[38;5;132;01m{\u001b[39;00ml1\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m | PSNR : \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpsnr\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m | SSIM : \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mssim\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 70\u001b[0m results\u001b[38;5;241m.\u001b[39mappend((ckpt_path, l1, psnr, ssim))\n",
889
+ "\u001b[1;31mTypeError\u001b[0m: cannot unpack non-iterable NoneType object"
890
+ ]
891
+ }
892
+ ],
893
+ "source": [
894
+ "import tensorflow as tf\n",
895
+ "from tqdm import tqdm\n",
896
+ "import numpy as np\n",
897
+ "import os\n",
898
+ "\n",
899
+ "# Assuming generator, discriminator, gen_opt, disc_opt are defined\n",
900
+ "ckpt = tf.train.Checkpoint(generator=generator,\n",
901
+ " discriminator=discriminator,\n",
902
+ " gen_opt=gen_opt,\n",
903
+ " disc_opt=disc_opt)\n",
904
+ "manager = tf.train.CheckpointManager(ckpt, CKPT_DIR, max_to_keep=5)\n",
905
+ "\n",
906
+ "\n",
907
+ "def l1_loss(y_true, y_pred):\n",
908
+ " return tf.reduce_mean(tf.abs(y_true - y_pred))\n",
909
+ "\n",
910
+ "def evaluate(test_ds):\n",
911
+ " l1_list, psnr_list, ssim_list = [], [], []\n",
912
+ " for vis, ir in tqdm(test_ds):\n",
913
+ " pred = generator(vis, training=False)\n",
914
+ " l1 = l1_loss(ir, pred).numpy()\n",
915
+ " psnr = tf.image.psnr(ir, pred, max_val=2.0).numpy()\n",
916
+ " ssim = tf.image.ssim(ir, pred, max_val=2.0).numpy()\n",
917
+ "\n",
918
+ " l1_list.append(l1)\n",
919
+ " psnr_list.append(np.mean(psnr))\n",
920
+ " ssim_list.append(np.mean(ssim))\n",
921
+ "\n",
922
+ " return np.mean(l1_list), np.mean(psnr_list), np.mean(ssim_list)\n",
923
+ "\n",
924
+ "# -------------------- TEST INFERENCE SIDE BY SIDE --------------------\n",
925
+ "def test_and_save_predictions(test_ds, save_dir=\"output/test_results\"):\n",
926
+ " os.makedirs(save_dir, exist_ok=True)\n",
927
+ " l1_list, psnr_list, ssim_list = [], [], []\n",
928
+ "\n",
929
+ " for idx, (vis, ir) in enumerate(tqdm(test_ds)):\n",
930
+ " pred = generator(vis, training=False)\n",
931
+ "\n",
932
+ " # Metrics\n",
933
+ " l1 = tf.reduce_mean(tf.abs(ir - pred)).numpy()\n",
934
+ " psnr = tf.reduce_mean(tf.image.psnr(ir, pred, max_val=2.0)).numpy()\n",
935
+ " ssim = tf.reduce_mean(tf.image.ssim(ir, pred, max_val=2.0)).numpy()\n",
936
+ " l1_list.append(l1)\n",
937
+ " psnr_list.append(psnr)\n",
938
+ " ssim_list.append(ssim)\n",
939
+ "\n",
940
+ " # Side-by-side image saving\n",
941
+ " for i in range(vis.shape[0]):\n",
942
+ " vis_img = to_uint8(vis[i])\n",
943
+ " ir_img = to_uint8(ir[i])\n",
944
+ " gen_img = to_uint8(pred[i])\n",
945
+ " row = tf.concat([vis_img, ir_img, gen_img], axis=1)\n",
946
+ " save_path = os.path.join(save_dir, f\"test_{idx*vis.shape[0]+i:03d}.png\")\n",
947
+ " tf.keras.preprocessing.image.save_img(save_path, row.numpy())\n",
948
+ "\n",
949
+ " print(\"==== Test Dataset Metrics ====\")\n",
950
+ " print(f\"L1 Loss : {np.mean(l1_list):.4f}\")\n",
951
+ " print(f\"PSNR : {np.mean(psnr_list):.4f}\")\n",
952
+ " print(f\"SSIM : {np.mean(ssim_list):.4f}\")\n",
953
+ " print(f\"All predictions saved to {save_dir}\")"
954
+ ]
955
+ },
956
+ {
957
+ "cell_type": "code",
958
+ "execution_count": null,
959
+ "id": "e363024c",
960
+ "metadata": {},
961
+ "outputs": [],
962
+ "source": [
963
+ "# === Evaluating checkpoint: new/ckpt\\best_val.ckpt-42 ===\n",
964
+ "# 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 752/752 [21:29<00:00, 1.71s/it]\n",
965
+ "# ==== Test Dataset Metrics ====\n",
966
+ "# L1 Loss : 0.0613\n",
967
+ "# PSNR : 24.3060\n",
968
+ "# SSIM : 0.8382\n",
969
+ "# All predictions saved to op/best_val.ckpt-42"
970
+ ]
971
+ },
972
+ {
973
+ "cell_type": "code",
974
+ "execution_count": null,
975
+ "id": "220843eb",
976
+ "metadata": {},
977
+ "outputs": [],
978
+ "source": [
979
+ "# === Evaluating checkpoint: new/ckpt\\ckpt-45 ===\n",
980
+ "# 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 752/752 [21:46<00:00, 1.74s/it]\n",
981
+ "# ==== Test Dataset Metrics ====\n",
982
+ "# L1 Loss : 0.0611\n",
983
+ "# PSNR : 24.3096\n",
984
+ "# SSIM : 0.8386\n",
985
+ "# All predictions saved to op/ckpt-45"
986
+ ]
987
+ }
988
+ ],
989
+ "metadata": {
990
+ "kaggle": {
991
+ "accelerator": "gpu",
992
+ "dataSources": [
993
+ {
994
+ "datasetId": 8436032,
995
+ "sourceId": 13308561,
996
+ "sourceType": "datasetVersion"
997
+ },
998
+ {
999
+ "modelId": 472226,
1000
+ "modelInstanceId": 456192,
1001
+ "sourceId": 607810,
1002
+ "sourceType": "modelInstanceVersion"
1003
+ },
1004
+ {
1005
+ "isSourceIdPinned": true,
1006
+ "modelId": 477033,
1007
+ "modelInstanceId": 461278,
1008
+ "sourceId": 613905,
1009
+ "sourceType": "modelInstanceVersion"
1010
+ }
1011
+ ],
1012
+ "dockerImageVersionId": 31154,
1013
+ "isGpuEnabled": true,
1014
+ "isInternetEnabled": false,
1015
+ "language": "python",
1016
+ "sourceType": "notebook"
1017
+ },
1018
+ "kernelspec": {
1019
+ "display_name": "base",
1020
+ "language": "python",
1021
+ "name": "python3"
1022
+ },
1023
+ "language_info": {
1024
+ "codemirror_mode": {
1025
+ "name": "ipython",
1026
+ "version": 3
1027
+ },
1028
+ "file_extension": ".py",
1029
+ "mimetype": "text/x-python",
1030
+ "name": "python",
1031
+ "nbconvert_exporter": "python",
1032
+ "pygments_lexer": "ipython3",
1033
+ "version": "3.12.11"
1034
+ },
1035
+ "papermill": {
1036
+ "default_parameters": {},
1037
+ "duration": 24066.997657,
1038
+ "end_time": "2025-10-20T13:44:23.261563",
1039
+ "environment_variables": {},
1040
+ "exception": null,
1041
+ "input_path": "__notebook__.ipynb",
1042
+ "output_path": "__notebook__.ipynb",
1043
+ "parameters": {},
1044
+ "start_time": "2025-10-20T07:03:16.263906",
1045
+ "version": "2.6.0"
1046
+ }
1047
+ },
1048
+ "nbformat": 4,
1049
+ "nbformat_minor": 5
1050
+ }
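
For context, the notebook above trains a pix2pix-style visible-to-thermal GAN (U-Net generator, conditional PatchGAN discriminator) and saves the generator as generator_final.h5. Below is a minimal inference sketch that mirrors the notebook's own preprocessing (256x256 resize, scaling to [-1, 1], tanh output mapped back to uint8); the input and output paths are placeholders:

import tensorflow as tf

IMG_SIZE = 256

# Load the uploaded generator; compile=False since only inference is needed.
generator = tf.keras.models.load_model("generator_final.h5", compile=False)

# Preprocess a visible-light image the same way as the training pipeline.
vis = tf.image.decode_png(tf.io.read_file("example_visible.png"), channels=3)  # placeholder input
vis = tf.image.resize(vis, (IMG_SIZE, IMG_SIZE))
vis = tf.cast(vis, tf.float32) / 127.5 - 1.0

# Predict and map the tanh output from [-1, 1] back to [0, 255].
pred = generator(tf.expand_dims(vis, 0), training=False)[0]
pred = tf.cast(tf.clip_by_value((pred + 1.0) * 127.5, 0, 255), tf.uint8)
tf.keras.preprocessing.image.save_img("example_thermal.png", pred.numpy())  # placeholder output
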