In [None]:

import os
import sys
import argparse
import numpy as np
from glob import glob
from skimage.io import imread
from skimage.transform import resize

from tqdm.keras import TqdmCallback

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# import tensorflow_addons as tfa
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from utils import *


In [None]:
# Example:
# python cyclegan.py --dataset planck/ --lx $lx --ly $lx --epochs 300 --BS 32 --filters 16 --ndb 2 --nrb 3 --nub 2 --nd 3 --model $lx-l1

if not in_notebook():
    # Script mode: every hyper-parameter comes from the command line.
    # (argparse is already imported in the first cell.)
    parser = argparse.ArgumentParser(
        description='Train a CycleGAN translating between healpix and sevem map patches.')
    parser.add_argument('--dataset', default='./dataset', type=str,
                        help='path to dataset directory (healpix.npy and sevem.npy are read from it)')
    parser.add_argument('--model', default='model file name', type=str, help='model file name')
    parser.add_argument('--lx', default=0, type=int,
                        help='patch side length in pixels (must be > 0)')
    parser.add_argument('--ly', default=0, type=int,
                        help='patch width (currently unused: patches are lx x lx)')
    parser.add_argument('--epochs', default=200, type=int, help='number of epochs')
    parser.add_argument('--BS', default=32, type=int, help='batch size')
    parser.add_argument('--filters', default=64, type=int,
                        help='base number of filters for the generators and discriminators')
    parser.add_argument('--ndb', default=2, type=int,
                        help='number of downsampling blocks in each generator')
    parser.add_argument('--nrb', default=9, type=int,
                        help='number of residual blocks in each generator')
    parser.add_argument('--nub', default=2, type=int,
                        help='number of upsampling blocks in each generator')
    parser.add_argument('--nd', default=3, type=int,
                        help='number of downsampling layers in each discriminator')
    parser.add_argument('--restart', action="store_true")

    args = parser.parse_args()
    # The data cell tiles maps using `<map side> // lx`, so the default lx=0 would
    # only fail later with a ZeroDivisionError; reject it up front instead.
    if args.lx <= 0:
        parser.error('--lx must be a positive integer (patch side length in pixels)')

    data_path = args.dataset
    lx, ly = args.lx, args.ly
    restart = args.restart
    epochs = args.epochs
    batch_size = args.BS
    filters = args.filters
    num_downsampling_blocks = args.ndb
    num_residual_blocks = args.nrb
    num_upsample_blocks = args.nub
    num_downsampling = args.nd

    mname = args.model
else:
    # Interactive defaults used when running inside Jupyter.
    data_path = ''
    lx, ly = 64, 64
    restart = False  # same type as the CLI's store_true flag
    epochs = 50
    batch_size = 32

    filters = 16
    num_downsampling_blocks = 2
    num_residual_blocks = 3
    num_upsample_blocks = 2
    num_downsampling = 3

    mname = '64-l1'

In [None]:
# Create the output directory for saved models. Pure Python (instead of the
# IPython-only `!mkdir -p`) so the file also runs as a plain script.
os.makedirs('models', exist_ok=True)

In [None]:
# Directory (including the trailing '/') where models are saved and loaded.
PREFIX = 'models/'
ch_mkdir(PREFIX.rstrip('/'))  # helper from utils; called with the bare directory name

def blocker(x, nside):
    """Cut every image in ``x`` into an ``nside`` x ``nside`` grid of tiles.

    ``x`` is split along axis 1 (rows) and then axis 2 (columns). The resulting
    tiles are stacked along axis 0, so an input of shape (N, H, W, ...) becomes
    (nside * nside * N, H / nside, W / nside, ...). Tile ordering is
    column-band-major: index = col_band * (nside * N) + row_band * N + n.
    """
    def _stack_pieces(arr, axis):
        # Split `arr` into nside pieces along `axis` and stack them on the batch axis.
        pieces = np.array_split(arr, nside, axis=axis)
        return np.concatenate(pieces, axis=0)

    row_bands = _stack_pieces(x, 1)
    return _stack_pieces(row_bands, 2)


# x0 = "healpix" maps, x1 = "sevem" maps (generator G maps x0 -> x1 below).
# Only the first N_MAPS arrays of each file are used.
N_MAPS = 10


def _load_maps(name):
    """Return the first N_MAPS entries of ``<data_path>/<name>.npy``."""
    # os.path.join works whether or not data_path ends with '/'. Plain string
    # concatenation turned the CLI default './dataset' into './datasethealpix.npy'.
    return np.load(os.path.join(data_path, name + '.npy'))[:N_MAPS]


def _rescale_to_unit_range(x):
    """Min-max rescale ``x`` to [-1, 1] and append a trailing channel axis.

    NOTE(review): a constant-valued input makes the shifted max zero and yields
    NaNs (unchanged from the original behaviour).
    """
    shifted = x - x.min()
    scaled = shifted / shifted.max()
    return (2 * scaled - 1)[:, :, :, None]


raw_x0 = _load_maps('healpix')
raw_x1 = _load_maps('sevem')

# Tile each map into an nside x nside grid of lx-sized patches. nside used to be
# hard-coded as 2048 // lx; deriving it from the map size gives the same value for
# 2048-pixel maps. It also refuses lx values that would silently give patches of
# a different size.
map_side = raw_x0.shape[1]
if lx <= 0 or map_side % lx:
    raise ValueError(
        'lx={} must be a positive divisor of the map side length {}'.format(lx, map_side))
nside = map_side // lx

train_x0 = _rescale_to_unit_range(blocker(raw_x0, nside))
train_x1 = _rescale_to_unit_range(blocker(raw_x1, nside))

# NOTE(review): these "test" patches are the first 20 *training* patches, not a
# held-out split, so the plots below only show data the model was trained on.
test_x0 = train_x0[:20]
test_x1 = train_x1[:20]

print(train_x0.shape, train_x1.shape)

In [None]:
# Show one randomly chosen training patch from each domain side by side.
# Both datasets were tiled identically, so the same index is the same patch
# position in both (assumes the two .npy files are spatially aligned).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
irr = np.random.randint(train_x0.shape[0])
ax1.imshow(train_x0[irr], cmap='jet')
ax1.set_title('healpix patch #{}'.format(irr))
ax2.imshow(train_x1[irr], cmap='jet')
ax2.set_title('sevem patch #{}'.format(irr))
plt.tight_layout()

In [None]:
# Shape of a single patch, e.g. (lx, lx, 1); both domains share it.
input_img_size = train_x0.shape[1:]

buffer_size = 256  # NOTE(review): not referenced anywhere else in this notebook

# Architecture settings shared by both generators.
_generator_config = dict(
    filters=filters,
    num_downsampling_blocks=num_downsampling_blocks,
    num_residual_blocks=num_residual_blocks,
    num_upsample_blocks=num_upsample_blocks,
)
# Architecture settings shared by both discriminators.
_discriminator_config = dict(
    filters=filters,
    kernel_initializer=kernel_init,
    num_downsampling=num_downsampling,
)

# Generators: G is applied to x0 patches and F to x1 patches in the plots below.
# Creation order (G, F, X, Y) is kept so weight initialisation order is unchanged.
gen_G = get_resnet_generator(input_img_size, name="generator_G", **_generator_config)
gen_F = get_resnet_generator(input_img_size, name="generator_F", **_generator_config)

# Discriminators X and Y (presumably one per domain; see CycleGan in utils).
disc_X = get_discriminator(input_img_size, name="discriminator_X", **_discriminator_config)
disc_Y = get_discriminator(input_img_size, name="discriminator_Y", **_discriminator_config)

# Least-squares (MSE) adversarial loss.
adv_loss_fn = keras.losses.MeanSquaredError()


def generator_loss_fn(fake):
    """Generator loss: push discriminator scores on generated images toward 1."""
    return adv_loss_fn(tf.ones_like(fake), fake)


def discriminator_loss_fn(real, fake):
    """Discriminator loss: real scores toward 1, fake scores toward 0, averaged."""
    loss_on_real = adv_loss_fn(tf.ones_like(real), real)
    loss_on_fake = adv_loss_fn(tf.zeros_like(fake), fake)
    return (loss_on_real + loss_on_fake) * 0.5


cycle_gan_model = CycleGan(
    generator_G=gen_G,
    generator_F=gen_F,
    discriminator_X=disc_X,
    discriminator_Y=disc_Y,
)


def _adam():
    """Return a fresh Adam optimizer; each sub-network needs its own state."""
    return keras.optimizers.Adam(learning_rate=5e-5, beta_1=0.5)


cycle_gan_model.compile(
    gen_G_optimizer=_adam(),
    gen_F_optimizer=_adam(),
    disc_X_optimizer=_adam(),
    disc_Y_optimizer=_adam(),
    gen_loss_fn=generator_loss_fn,
    disc_loss_fn=discriminator_loss_fn,
)

# NOTE(review): the `restart` setting is never used; resuming from a saved model
# was apparently meant to look like this:
# if os.path.exists(PREFIX+'{}'.format(mname)):
#     cycle_gan_model.loadit(PREFIX+'{}'.format(mname))

In [None]:
# Train the CycleGAN. Keras' own logging is silenced (verbose=0) and a tqdm
# progress bar is shown instead.
cycle_gan_model.fit(
    train_x0,
    train_x1,
    batch_size=batch_size,
    epochs=epochs,
    verbose=0,
    callbacks=[TqdmCallback(verbose=0)],
)

# Save via the model's saveit() helper under PREFIX + mname.
cycle_gan_model.saveit(PREFIX + mname)

In [None]:
cycle_gan_model.loadit(PREFIX + mname)  # reload the model saved under PREFIX + mname

In [None]:
# x0 -> x1: for the first four test_x0 patches, show the input, generator G's
# translation, and their absolute difference (one row per patch).
_, ax = plt.subplots(4, 3, figsize=(12, 15))
column_titles = ("Input image", "Translated image", "Difference")
for row in range(4):
    batch = test_x0[row:row + 1]  # slicing keeps the batch axis the model expects
    source = batch[0]
    translated = np.array(cycle_gan_model.gen_G(batch, training=False)[0])
    panels = (source, translated, np.abs(source - translated))
    for col, (panel, title) in enumerate(zip(panels, column_titles)):
        ax[row, col].imshow(panel, cmap='jet')
        ax[row, col].set_title(title)
        ax[row, col].axis("off")

plt.tight_layout()

In [None]:
# x1 -> x0: for the first four test_x1 patches, show the input, generator F's
# translation, and their absolute difference (one row per patch).
_, ax = plt.subplots(4, 3, figsize=(12, 15))
column_titles = ("Input image", "Translated image", "Difference")
for row in range(4):
    batch = test_x1[row:row + 1]  # slicing keeps the batch axis the model expects
    source = batch[0]
    translated = np.array(cycle_gan_model.gen_F(batch, training=False)[0])
    panels = (source, translated, np.abs(source - translated))
    for col, (panel, title) in enumerate(zip(panels, column_titles)):
        ax[row, col].imshow(panel, cmap='jet')
        ax[row, col].set_title(title)
        ax[row, col].axis("off")
plt.tight_layout()