# You are on page 1 of 2

import os
import pathlib

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from keras.datasets import cifar10
from skimage.transform import resize
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.layers import (
    BatchNormalization,
    Dense,
    Dropout,
    Flatten,
    UpSampling2D,
)
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

# Load CIFAR-10; each of trainingData / testingData is an (images, labels) pair.
trainingData, testingData = cifar10.load_data()

# Apply the ImageNet preprocessing expected by the pretrained ResNet50 weights
# (per the Keras docs: RGB->BGR channel swap and per-channel mean subtraction).
X_train = preprocess_input(trainingData[0])
X_test = preprocess_input(testingData[0])

# One-hot encode the integer labels. num_classes is pinned to 10 so the encoding
# width always matches the 10-unit softmax head built below, instead of being
# inferred from the largest label that happens to appear in the array.
y_train = to_categorical(trainingData[1], num_classes=10)
y_test = to_categorical(testingData[1], num_classes=10)

# Classifier = upsampling front end + frozen pretrained ResNet50 base
# + small trainable dense head ending in a 10-way softmax.
resnetModel = Sequential()

# Each UpSampling2D (default size=(2, 2)) doubles height and width, so the
# three layers enlarge the 32x32 inputs 8x to 256x256 before ResNet50.
resnetModel.add(UpSampling2D())
resnetModel.add(UpSampling2D())
resnetModel.add(UpSampling2D())

# Pretrained convolutional base without ImageNet's classifier head.
# pooling='max' applies global max pooling, giving one feature vector per image.
# `classes` is intentionally not passed: Keras only uses it when include_top=True.
model = ResNet50(weights='imagenet',
                 include_top=False,
                 pooling='max')

# Freeze every layer of the pretrained base so that only the head is trained.
for layer in model.layers:
    layer.trainable = False

resnetModel.add(model)
# The globally pooled output is already flat; Flatten is kept so the layer
# list (and therefore the summary) is unchanged.
resnetModel.add(Flatten())

# Trainable head: two BatchNorm -> Dense(ReLU) -> Dropout(0.5) stages,
# then a final BatchNorm before the softmax output.
resnetModel.add(BatchNormalization())
resnetModel.add(Dense(128, activation='relu'))
resnetModel.add(Dropout(0.5))
resnetModel.add(BatchNormalization())
resnetModel.add(Dense(64, activation='relu'))
resnetModel.add(Dropout(0.5))
resnetModel.add(BatchNormalization())
resnetModel.add(Dense(10, activation='softmax'))

# Adam's step-size keyword is `learning_rate`; the legacy `lr` alias is
# deprecated and rejected by newer Keras optimizer implementations.
resnetModel.compile(optimizer=Adam(learning_rate=0.001),
                    loss='categorical_crossentropy',
                    metrics=['accuracy'])

# Declare the input as batches of 32x32 RGB images (batch size left open)
# so the model is built and summary() can report layer shapes.
resnetModel.build(input_shape=(None, 32, 32, 3))
resnetModel.summary()

# Train on the training split for 5 epochs (Keras default batch size).
# `history.history` records the per-epoch 'loss' and 'accuracy' values.
history = resnetModel.fit(x=X_train, y=y_train, epochs=5)


# Plot per-epoch training accuracy and training loss on shared axes.
fig1 = plt.gcf()
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['loss'], label='train loss')
# Fix only the lower bound: categorical cross-entropy can exceed 1
# (about ln(10) ~= 2.3 for a uniform 10-class prediction), so a fixed
# upper limit of 1 would clip the loss curve off the chart.
plt.ylim(bottom=0)
plt.grid()
plt.title('Training Accuracy and Loss')
plt.ylabel('Accuracy / Loss')
plt.xlabel('Epochs')
plt.legend()
plt.show()

# You might also like