반응형
4. VGG19
## import
import numpy as np
import keras
import tensorflow as tf
from tensorflow.keras.datasets import cifar100
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from keras.applications import VGG19
from tensorflow.keras.optimizers import Adam, SGD
## Hyperparameters
# Mini-batch size handed to fit().
batch_size = 128
# Width of the final softmax layer (CIFAR-100 has 100 fine-grained classes).
num_classes = 100
# Upper bound on training epochs; early stopping may end training sooner.
epochs = 50
# Step size for the Adam optimizer.
learn_rate = 1e-3
## Load CIFAR-100 (the dataset ships with its own train/test split)
(X_train, y_train), (X_test, y_test) = cifar100.load_data()
## Scale pixels to 0~1
# Cast to float32 before dividing: dividing the integer image arrays by a
# Python float would promote them to float64 and double the memory they hold.
X_train = X_train.astype("float32") / 255.0
X_test = X_test.astype("float32") / 255.0
## Build model
# ImageNet-pretrained VGG19 convolutional base. include_top=False drops the
# original classifier head, so no `classes` argument is passed here; the
# number of output classes is set by the final Dense layer below.
vgg19 = VGG19(include_top=False, weights='imagenet', input_shape=(32, 32, 3))

# Bound to `model_1` because the summary, compile, fit and plotting code all
# refer to the model by that name.
model_1 = Sequential([
    vgg19,
    Flatten(),
    # Dense classifier head; each layer infers its input size from the
    # previous layer, so no explicit input_dim is needed.
    Dense(1024, activation='relu'),
    Dense(512, activation='relu'),
    Dense(256, activation='relu'),
    Dropout(.3),
    Dense(128, activation='relu'),
    Dropout(.2),
    # One softmax output per class.
    Dense(num_classes, activation='softmax'),
])
## summary
model_1.summary()
## Train
# Stop once val_loss has gone 5 consecutive epochs without improving.
# The callback is taken from tf.keras so it comes from the same Keras
# implementation as the Sequential model and its layers.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=1)
# `learning_rate` is the supported keyword (`lr` is deprecated). The legacy
# `epsilon=None` / `decay=0.0` arguments are dropped; the remaining values
# are spelled out explicitly.
adam = Adam(learning_rate=learn_rate, beta_1=0.9, beta_2=0.999, amsgrad=False)
# The sparse loss takes integer class ids, so labels are not one-hot encoded.
model_1.compile(optimizer=adam, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# NOTE(review): the test split doubles as the validation set that drives
# early stopping, so accuracy reported on it is optimistically biased;
# consider holding out part of the training data for validation instead.
model_1.fit(X_train, y_train, batch_size=batch_size,
            validation_data=(X_test, y_test),
            callbacks=[early_stop],
            epochs=epochs, verbose=1)
## Visualize history
import matplotlib.pyplot as plt

# fit() leaves its History object on the model; its .history dict maps each
# metric name to the list of per-epoch values.
hist = model_1.history.history

f, ax = plt.subplots(2, 1)  # two stacked subplots: loss on top, accuracy below
ax[0].plot(hist['loss'], color='b', label='Training Loss')
ax[0].plot(hist['val_loss'], color='r', label='Validation Loss')
ax[0].set_title('Loss')
ax[0].legend()  # without legend() the curve labels are never drawn
ax[1].plot(hist['accuracy'], color='b', label='Training Accuracy')
ax[1].plot(hist['val_accuracy'], color='r', label='Validation Accuracy')
ax[1].set_title('Accuracy')
ax[1].legend()
f.tight_layout()  # keep the two subplot titles from overlapping
plt.show()  # needed when run as a plain script rather than in a notebook
반응형
'인공지능 > CV' 카테고리의 다른 글
객체탐지 (Object Detection) 2. YOLO !! (v1~v3) (2) | 2021.05.01 |
---|---|
객체탐지 (Object Detection) 1. YOLO 이전 까지 흐름 (0) | 2021.05.01 |
MASK RCNN 실행시 버전오류 (0) | 2021.04.20 |
[CODE] EfficientNetB7 (0) | 2021.04.10 |
[CODE] ANN, K-fold CV, CNN (0) | 2021.04.10 |
댓글