# Build and train a MIRNet model for low-light image enhancement.
import cv2
import random
import numpy as np
from glob import glob
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Mount Google Drive (Colab) to access the LOL dataset.
from google.colab import drive

drive.mount('/content/gdrive')
random.seed(10)

IMAGE_SIZE = 128
BATCH_SIZE = 4
MAX_TRAIN_IMAGES = 300


def read_image(image_path):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image.set_shape([None, None, 3])
    image = tf.cast(image, dtype=tf.float32) / 255.0
    return image


def random_crop(low_image, enhanced_image):
    low_image_shape = tf.shape(low_image)[:2]
    low_w = tf.random.uniform(
        shape=(), maxval=low_image_shape[1] - IMAGE_SIZE + 1, dtype=tf.int32
    )
    low_h = tf.random.uniform(
        shape=(), maxval=low_image_shape[0] - IMAGE_SIZE + 1, dtype=tf.int32
    )
    enhanced_w = low_w
    enhanced_h = low_h
    low_image_cropped = low_image[
        low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
    ]
    enhanced_image_cropped = enhanced_image[
        enhanced_h : enhanced_h + IMAGE_SIZE, enhanced_w : enhanced_w + IMAGE_SIZE
    ]
    return low_image_cropped, enhanced_image_cropped


def load_data(low_light_image_path, enhanced_image_path):
    low_light_image = read_image(low_light_image_path)
    enhanced_image = read_image(enhanced_image_path)
    low_light_image, enhanced_image = random_crop(low_light_image, enhanced_image)
    return low_light_image, enhanced_image


def get_dataset(low_light_images, enhanced_images):
    dataset = tf.data.Dataset.from_tensor_slices((low_light_images, enhanced_images))
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    return dataset
train_low_light_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/low/*"))[:MAX_TRAIN_IMAGES]
train_enhanced_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/high/*"))[:MAX_TRAIN_IMAGES]
val_low_light_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/low/*"))[MAX_TRAIN_IMAGES:]
val_enhanced_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/high/*"))[MAX_TRAIN_IMAGES:]
test_low_light_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/eval15/low/*"))
test_enhanced_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/eval15/high/*"))

train_dataset = get_dataset(train_low_light_images, train_enhanced_images)
val_dataset = get_dataset(val_low_light_images, val_enhanced_images)

print("Train Dataset:", train_dataset)
print("Val Dataset:", val_dataset)
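# Optional sanity check (not in the original script): pull one batch to confirm
# the crop/batch pipeline yields paired (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3) tensors.
for low_batch, enhanced_batch in train_dataset.take(1):
    print("low batch:", low_batch.shape, "enhanced batch:", enhanced_batch.shape)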
# Selective Kernel Feature Fusion (SKFF): fuse three multi-scale feature maps
# using learned per-channel attention weights.
def selective_kernel_feature_fusion(
    multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
):
    channels = list(multi_scale_feature_1.shape)[-1]
    combined_feature = layers.Add()(
        [multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]
    )
    gap = layers.GlobalAveragePooling2D()(combined_feature)
    channel_wise_statistics = tf.reshape(gap, shape=(-1, 1, 1, channels))
    compact_feature_representation = layers.Conv2D(
        filters=channels // 8, kernel_size=(1, 1), activation="relu"
    )(channel_wise_statistics)
    feature_descriptor_1 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_descriptor_2 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_descriptor_3 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_1 = multi_scale_feature_1 * feature_descriptor_1
    feature_2 = multi_scale_feature_2 * feature_descriptor_2
    feature_3 = multi_scale_feature_3 * feature_descriptor_3
    aggregated_feature = layers.Add()([feature_1, feature_2, feature_3])
    return aggregated_feature
def spatial_attention_block(input_tensor):
    # Squeeze the channel axis with average- and max-pooling, then learn a
    # single-channel spatial gate.
    average_pooling = tf.reduce_mean(input_tensor, axis=-1)
    average_pooling = tf.expand_dims(average_pooling, axis=-1)
    max_pooling = tf.reduce_max(input_tensor, axis=-1)
    max_pooling = tf.expand_dims(max_pooling, axis=-1)
    concatenated = layers.Concatenate(axis=-1)([average_pooling, max_pooling])
    feature_map = layers.Conv2D(1, kernel_size=(1, 1))(concatenated)
    feature_map = tf.nn.sigmoid(feature_map)
    return input_tensor * feature_map
def channel_attention_block(input_tensor):
    channels = list(input_tensor.shape)[-1]
    average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
    feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels))
    feature_activations = layers.Conv2D(
        filters=channels // 8, kernel_size=(1, 1), activation="relu"
    )(feature_descriptor)
    feature_activations = layers.Conv2D(
        filters=channels, kernel_size=(1, 1), activation="sigmoid"
    )(feature_activations)
    return input_tensor * feature_activations
def dual_attention_unit_block(input_tensor):
    channels = list(input_tensor.shape)[-1]
    feature_map = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(input_tensor)
    feature_map = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
        feature_map
    )
    channel_attention = channel_attention_block(feature_map)
    spatial_attention = spatial_attention_block(feature_map)
    concatenation = layers.Concatenate(axis=-1)([channel_attention, spatial_attention])
    concatenation = layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
    return layers.Add()([input_tensor, concatenation])
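# Optional sanity check (not in the original script): the DAU block is
# shape-preserving. Building it on a dummy symbolic input shows the output keeps
# the same spatial size and channel count (this relies on tf.keras accepting raw
# TF ops on symbolic tensors, as the rest of this script already does).
_probe = keras.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 64))
print("DAU output shape:", dual_attention_unit_block(_probe).shape)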
# Recursive Residual Modules
def down_sampling_module(input_tensor):
    channels = list(input_tensor.shape)[-1]
    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
        input_tensor
    )
    main_branch = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(main_branch)
    main_branch = layers.MaxPooling2D()(main_branch)
    main_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(main_branch)
    skip_branch = layers.MaxPooling2D()(input_tensor)
    skip_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(skip_branch)
    return layers.Add()([skip_branch, main_branch])


def up_sampling_module(input_tensor):
    channels = list(input_tensor.shape)[-1]
    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
        input_tensor
    )
    main_branch = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(main_branch)
    main_branch = layers.UpSampling2D()(main_branch)
    main_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(main_branch)
    skip_branch = layers.UpSampling2D()(input_tensor)
    skip_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(skip_branch)
    return layers.Add()([skip_branch, main_branch])
# Multi-Scale Residual Block (MRB)
def multi_scale_residual_block(input_tensor, channels):
    # features
    level1 = input_tensor
    level2 = down_sampling_module(input_tensor)
    level3 = down_sampling_module(level2)
    # DAU
    level1_dau = dual_attention_unit_block(level1)
    level2_dau = dual_attention_unit_block(level2)
    level3_dau = dual_attention_unit_block(level3)
    # SKFF
    level1_skff = selective_kernel_feature_fusion(
        level1_dau,
        up_sampling_module(level2_dau),
        up_sampling_module(up_sampling_module(level3_dau)),
    )
    level2_skff = selective_kernel_feature_fusion(
        down_sampling_module(level1_dau), level2_dau, up_sampling_module(level3_dau)
    )
    level3_skff = selective_kernel_feature_fusion(
        down_sampling_module(down_sampling_module(level1_dau)),
        down_sampling_module(level2_dau),
        level3_dau,
    )
    # DAU 2
    level1_dau_2 = dual_attention_unit_block(level1_skff)
    level2_dau_2 = up_sampling_module(dual_attention_unit_block(level2_skff))
    level3_dau_2 = up_sampling_module(
        up_sampling_module(dual_attention_unit_block(level3_skff))
    )
    # SKFF 2
    skff_ = selective_kernel_feature_fusion(level1_dau_2, level2_dau_2, level3_dau_2)
    conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
    return layers.Add()([input_tensor, conv])


def recursive_residual_group(input_tensor, num_mrb, channels):
    conv1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
    for _ in range(num_mrb):
        conv1 = multi_scale_residual_block(conv1, channels)
    conv2 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(conv1)
    return layers.Add()([conv2, input_tensor])
def mirnet_model(num_rrg, num_mrb, channels):
    input_tensor = keras.Input(shape=[None, None, 3])
    x1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
    for _ in range(num_rrg):
        x1 = recursive_residual_group(x1, num_mrb, channels)
    conv = layers.Conv2D(3, kernel_size=(3, 3), padding="same")(x1)
    output_tensor = layers.Add()([input_tensor, conv])
    return keras.Model(input_tensor, output_tensor)


model = mirnet_model(num_rrg=3, num_mrb=2, channels=64)
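# Optional (not in the original script): report the model size; the full
# model.summary() is very long for MIRNet, so only the parameter count is printed.
print("Total parameters:", model.count_params())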
# Charbonnier loss: a smooth L1 variant with epsilon = 1e-3.
def charbonnier_loss(y_true, y_pred):
    return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + tf.square(1e-3)))


# PSNR metric. Note: inputs are scaled to [0, 1] in read_image, so max_val=255.0
# only shifts the metric by a constant (~48 dB); max_val=1.0 gives the conventional scale.
def peak_signal_noise_ratio(y_true, y_pred):
    return tf.image.psnr(y_pred, y_true, max_val=255.0)
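# Optional sanity check (not in the original script): the Charbonnier loss of a
# tensor against itself reduces to the epsilon term, sqrt((1e-3)**2) = 1e-3.
_dummy = tf.zeros((1, IMAGE_SIZE, IMAGE_SIZE, 3))
print("charbonnier_loss(x, x) =", float(charbonnier_loss(_dummy, _dummy)))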
optimizer = keras.optimizers.Adam(learning_rate=1e-4)
model.compile(
    optimizer=optimizer, loss=charbonnier_loss, metrics=[peak_signal_noise_ratio]
)

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    # Set the number of training epochs (training cycles) here.
    epochs=1,
    callbacks=[
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_peak_signal_noise_ratio",
            factor=0.5,
            patience=5,
            verbose=1,
            min_delta=1e-7,
            mode="max",
        )
    ],
)
plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()

plt.plot(history.history["peak_signal_noise_ratio"], label="train_psnr")
plt.plot(history.history["val_peak_signal_noise_ratio"], label="val_psnr")
plt.xlabel("Epochs")
plt.ylabel("PSNR")
plt.title("Train and Validation PSNR Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()


def plot_results(images, titles, figure_size=(12, 12)):
    fig = plt.figure(figsize=figure_size)
    for i in range(len(images)):
        fig.add_subplot(1, len(images), i + 1).set_title(titles[i])
        _ = plt.imshow(images[i])
        plt.axis("off")
    plt.show()
def infer(original_image):
    image = keras.preprocessing.image.img_to_array(original_image)
    # Normalize to [0, 1] in float32 to match the training pipeline.
    image = image.astype("float32") / 255.0
    image = np.expand_dims(image, axis=0)
    output = model.predict(image)
    output_image = output[0] * 255.0
    output_image = output_image.clip(0, 255)
    output_image = output_image.reshape(
        (np.shape(output_image)[0], np.shape(output_image)[1], 3)
    )
    output_image = Image.fromarray(np.uint8(output_image))
    original_image = Image.fromarray(np.uint8(original_image))
    return output_image
for low_light_image in random.sample(test_low_light_images, 2):
    original_image = Image.open(low_light_image)
    enhanced_image = infer(original_image)
    plot_results(
        [original_image, ImageOps.autocontrast(original_image), enhanced_image],
        ["Original", "PIL Autocontrast", "MIRNet Enhanced"],
        (20, 12),
    )
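# Optional (not in the original script): persist the trained weights so the model
# can be reloaded without retraining; the Drive path below is an assumption.
model.save_weights("/content/gdrive/MyDrive/mirnet_weights.h5")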