# The mean squared error for PCA is 0.8399
# We now try to improve on PCA using an autoencoder with two latent variables.
# There are two hidden layers between the inputs and the latent variables (with 6 and 4 neurons)
# and two hidden layers (with 4 and 6 neurons) between the latent variables and the output.
# The activation function is leaky ReLU with an alpha parameter of 0.7.
# Shared settings for the autoencoder below.
# Number of input features: the encoder's input width and the decoder's output
# width must both equal this, otherwise fit(raw_data, raw_data) cannot match shapes.
treas_n_features = len(raw_data.columns)
# Negative-slope coefficient used by every LeakyReLU activation.
treas_leaky_alpha = 0.7
# File the checkpoint callback writes and that is reloaded after training.
# Defined once so the save and load paths cannot drift apart.
treas_checkpoint_path = "treas_autoencoder_leakyrelu_multi_layer_vFinal_v2.h5"

# Set up layers with 6, 4 and 2 neurons for the encoder.
# Only the first layer of a Sequential model needs input_shape; Keras infers
# the input width of every later layer from the previous layer's output.
treas_encoder = keras.models.Sequential([
    Dense(6, input_shape=[treas_n_features], activation=LeakyReLU(alpha=treas_leaky_alpha)),
    Dense(4, activation=LeakyReLU(alpha=treas_leaky_alpha)),
    Dense(2, activation=LeakyReLU(alpha=treas_leaky_alpha)),
])
# Set up layers with 4, 6 and treas_n_features neurons for the decoder.
# The output width is derived from the data instead of being hard-coded, so the
# reconstruction always has the same number of columns as the input.
treas_decoder = keras.models.Sequential([
    Dense(4, input_shape=[2], activation=LeakyReLU(alpha=treas_leaky_alpha)),
    Dense(6, activation=LeakyReLU(alpha=treas_leaky_alpha)),
    Dense(treas_n_features, activation=LeakyReLU(alpha=treas_leaky_alpha)),
])
# Set up the autoencoder as encoder followed by decoder, trained on MSE.
treas_autoencoder = keras.models.Sequential([treas_encoder, treas_decoder])
treas_autoencoder.compile(loss="mse", optimizer="adam")
# The checkpoint callback saves a copy of the model whenever the training loss
# improves (save_best_only=True), so the file always holds the best model so far.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    treas_checkpoint_path, save_best_only=True, monitor='loss')
# Early stopping halts training if the loss shows no improvement for `patience`
# consecutive epochs, and then restores the best weights seen.
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor='loss', patience=500, restore_best_weights=True, verbose=1)
# `epochs` is the maximum number of epochs; early stopping may end training sooner.
treas_history = treas_autoencoder.fit(
    raw_data, raw_data, epochs=5000,
    callbacks=[checkpoint_cb, early_stopping_cb], verbose=0)
# Reload the best checkpointed model. custom_objects must map the name to the
# LeakyReLU *class*; its alpha is restored from the saved layer config.
treas_autoencoder = keras.models.load_model(
    treas_checkpoint_path, custom_objects={'LeakyReLU': LeakyReLU})
# NOTE: this is the in-sample reconstruction error. It is evaluated on the same
# raw_data used for training, not on a held-out test set.
mse_test = treas_autoencoder.evaluate(raw_data, raw_data, verbose=0)
print('Neural network mean squared error:', mse_test)