#Ejemplo: Simple CNN para el problema del MNIST
#Mis ambientes: env_pi38_tf25 (Python 3.8 con TensorFlow 2.5)

# Importing libraries
import os

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import plot_model

# Loading the MNIST dataset.
# Prefer the local copy at `path`. That path is specific to one machine, so if
# the file is not there, fall back to Keras' default cache file name and let
# load_data() fetch/cache the dataset itself.
path = 'D:/fcoj23/CIMAT_MERIDA/Investigacion/objectDetection/deepLearning/CNNs/mnist.npz'
if not os.path.isfile(path):
    path = 'mnist.npz'
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data(path)

# Preprocessing the data: add a trailing channel axis, because Conv2D expects
# (samples, height, width, channels), and divide the pixels by 255 to scale
# them to floats. -1 lets NumPy infer the number of samples instead of
# hard-coding 60000/10000, which would break on any other dataset size.
train_images = train_images.reshape((-1, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((-1, 28, 28, 1)).astype('float32') / 255

def build_model(version=1):
    """Return one of the three experimental CNN architectures for MNIST.

    Every variant takes (28, 28, 1) inputs and ends in a 10-way softmax.

    Args:
        version: which architecture to build.
            1 -- tiny: Conv2D(2) -> MaxPool -> Conv2D(4) -> Dense(4).
            2 -- small: Conv2D(4) -> MaxPool -> Conv2D(8) -> Dense(8).
            3 -- larger: Conv2D(32) -> MaxPool -> Conv2D(64) -> MaxPool
                 -> Conv2D(64) -> Dense(64).

    Returns:
        An uncompiled ``models.Sequential`` model.

    Raises:
        ValueError: if ``version`` is not 1, 2 or 3.
    """
    if version == 1:
        # CNN model V1
        return models.Sequential([
            layers.Conv2D(2, (3, 3), activation='relu',
                          input_shape=(28, 28, 1)),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(4, (3, 3), activation='relu'),
            layers.Flatten(),
            layers.Dense(4, activation='relu'),
            layers.Dense(10, activation='softmax'),
        ])
    if version == 2:
        # CNN model V2
        return models.Sequential([
            layers.Conv2D(4, (3, 3), activation='relu',
                          input_shape=(28, 28, 1)),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(8, (3, 3), activation='relu'),
            layers.Flatten(),
            layers.Dense(8, activation='relu'),
            layers.Dense(10, activation='softmax'),
        ])
    if version == 3:
        # CNN model V3
        return models.Sequential([
            layers.Conv2D(32, (3, 3), activation='relu',
                          input_shape=(28, 28, 1)),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.Flatten(),
            layers.Dense(64, activation='relu'),
            layers.Dense(10, activation='softmax'),
        ])
    raise ValueError(f'Unknown model version: {version!r} (expected 1, 2 or 3)')


# Building the CNN model. V1 is the active architecture; change
# MODEL_VERSION to 2 or 3 to try the other variants.
MODEL_VERSION = 1
model = build_model(MODEL_VERSION)

# Saving a diagram of the architecture. plot_model needs pydot and Graphviz
# and raises ImportError when they are missing. The diagram is optional, so
# report the problem and keep going instead of aborting before training.
try:
    plot_model(model, show_shapes=True,
               to_file='model_basic_CNN.png')
except ImportError as err:
    print(f'Skipping model diagram: {err}')

# Compiling the model with the Adam optimizer and sparse categorical
# cross-entropy, which takes integer class labels rather than one-hot
# vectors. Accuracy is reported alongside the loss.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Training hyperparameters.
EPOCHS = 10
BATCH_SIZE = 64
VALIDATION_SPLIT = 0.2  # fraction of the training arrays Keras holds out for validation

# Training the model; `history.history` records per-epoch train/val metrics.
history = model.fit(
    train_images,
    train_labels,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)

# Evaluating the trained model on the test split and reporting the
# (loss, accuracy) pair returned by evaluate(), one value per line.
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f'Test Loss: {test_loss}', f'Test Accuracy: {test_acc}', sep='\n')

def _plot_history(hist, metric):
    """Plot training vs. validation curves of one metric across epochs.

    Args:
        hist: the ``History.history`` dict returned by ``model.fit``. It must
            contain both ``metric`` and ``'val_' + metric``, which requires
            fit() to have been given validation data.
        metric: metric key to plot, e.g. ``'accuracy'`` or ``'loss'``.
    """
    # Start a fresh figure so the two charts never share axes, even with a
    # non-blocking backend where plt.show() returns immediately.
    plt.figure()
    plt.plot(hist[metric])
    plt.plot(hist['val_' + metric])
    plt.title('model ' + metric)
    plt.ylabel(metric)
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()


# "Accuracy" and "Loss" learning curves, one figure each.
for _metric in ('accuracy', 'loss'):
    _plot_history(history.history, _metric)