Keras predict_generator로 confusion matrix를 그렸을 때 accuracy가 안 맞는 오류 해결방안 정리

1. 문제

ImageDataGenerator로 데이터셋을 생성할 경우 predict_classes 대신 predict_generator를 이용해 테스트 클래스를 예측하게 되는데, 이 때 evaluate_generator로 얻은 accuracy와 sklearnconfusion_matrix를 통해 계산한 accuracy가 일치하지 않는 문제가 발생했다.

2. 원인 및 해결방안 (Causes & Solutions)

https://groups.google.com/forum/#!topic/keras-users/bqWwFox_zZs

Tutorial on using Keras flow_from_directory and generators

1) 제너레이터 생성 시, shuffle=False로 지정되어있는지 확인한다.

valid_datagen = ImageDataGenerator(rescale=1.0 / 255)

validation_generator = valid_datagen.flow_from_directory(
    validation_dir,
    target_size=(height, width),
    batch_size=batch_size,
    class_mode="categorical",
    shuffle=False,  # For evaluation
)

2) predict_generator를 호출하기 전에 반드시 제너레이터를 reset해야 한다!!!

validation_generator.reset()
Y_pred = model.predict_generator(validation_generator, STEP_SIZE_VALID+1)#validation_generator.n // validation_generator.batch_size+1)
classes = validation_generator.classes[validation_generator.index_array]
y_pred = np.argmax(Y_pred, axis=-1)  # Returns maximum indices in each row

4. 참고

Keras ImageDataGenerator

Image Preprocessing - Keras Documentation

파일 구조를 이용해 자동으로 이미지 데이터셋을 읽고 라벨링할 수 있다.

Keras Model 저장하기 & 읽어서 evaluate하기

How to Save and Load Your Keras Deep Learning Model

Confusion Matrix란?

Confusion Matrix in Machine Learning - GeeksforGeeks

Scikit-learn으로 Confusion Matrix 그리기

sklearn.metrics.confusion_matrix — scikit-learn 0.21.3 documentation

Confusion matrix — scikit-learn 0.21.3 documentation

Scikit-learn으로 Classification Report 뽑기

sklearn.metrics.classification_report — scikit-learn 0.21.3 documentation

주요 코드

print("-- Evaluate --")
STEP_SIZE_VALID = validation_generator.n // validation_generator.batch_size
scores = model.evaluate_generator(generator=validation_generator, steps=STEP_SIZE_VALID)
print("%s: %.2f%%" %(model.metrics_names[1], scores[1]*100))


#Confusion Matrix and Classification Report

np.set_printoptions(precision=2)

validation_generator.reset()
Y_pred = model.predict_generator(validation_generator, STEP_SIZE_VALID+1)#validation_generator.n // validation_generator.batch_size+1)
classes = validation_generator.classes[validation_generator.index_array]
y_pred = np.argmax(Y_pred, axis=-1)  # Returns maximum indices in each row
print(sum(y_pred==classes)/10000)

class_names = ['cry', 'laugh', 'silence', 'speech(babble)']  # Alphanumeric order

print('-- Confusion Matrix --')
print(confusion_matrix(validation_generator.classes[validation_generator.index_array], y_pred))

# Plot non-normalized confusion matrix
plot_confusion_matrix(validation_generator.classes[validation_generator.index_array], y_pred, classes=class_names, title='Confusion matrix, without normalization')
# Plot normalized confusion matrix
plot_confusion_matrix(validation_generator.classes[validation_generator.index_array], y_pred, classes=class_names, normalize=True, title='Normalized confusion matrix')

plt.show()

print('-- Classification Report --')
print(classification_report(validation_generator.classes[validation_generator.index_array], y_pred, target_names=class_names))

2019

Redis 기초 특강 - 강대명 멘토

1 minute read

Redis 소개 In-memory data structure storage disk 접근을 하지 않음 -> 속도가 빠르다 오픈소스(BSD 3) 제공되는 자료구조들 Strings, set, sorted-set, ...

Node.js 특강 - 손영수 멘토 (2)

2 minute read

MongoDB 클라우드 서비스를 SaaS로 제공하려다가 그 중 DB 서비스가 제일 잘 나가서 MongoDB가 됨 No Schema: JSON data 삽입 시 field가 생성됨 Document data model JSON data를 그대로 넣음 ...

Node.js 특강 - 손영수 멘토 (1)

2 minute read

웹 서버의 구동 방식에는 8가지가 있음(3페이지) Node.js는 비동기, non-blocking Server side Javascript Event driven Asynchronous Non-Blocking I/O ...

TensorFlow.js: Speech Command Recognizer (번역)

less than 1 minute read

TensorFlow.js의 공식 모델 중 하나인 Speech command recognition에 대해 조사해 보았습니다. 원문 링크: tfjs-models/speech-commands at master · tensorflow/tfjs-models · GitHub

Chapter 04. HTML & CSS 필수 기초 (2)

2 minute read

이 포스트는 SW마에스트로 자기주도학습으로 패스트캠퍼스의 웹 프론트엔드 올인원 패키지 Online을 수강하면서 작성한 노트입니다.

Chapter 04. HTML & CSS 필수 기초 (1)

1 minute read

이 포스트는 SW마에스트로 자기주도학습으로 패스트캠퍼스의 웹 프론트엔드 올인원 패키지 Online을 수강하면서 작성한 노트입니다.

Chapter 03. CSS 입문

1 minute read

이 포스트는 SW마에스트로 자기주도학습으로 패스트캠퍼스의 웹 프론트엔드 올인원 패키지 Online을 수강하면서 작성한 노트입니다.

CC 라이센스

less than 1 minute read

소프트웨어 마에스트로 과정에서 아기 울음소리를 인식하고 분류하는 딥러닝 모델을 작성하고 있습니다. GitHub에 공개되어 있는 소리 데이터셋을 사용하기 위해 확인한 라이센스 몇 가지를 정리해 보았습니다. 참고한 곳: CC 라이선스 :: Creative Commons K...

URI와 URL의 차이점

less than 1 minute read

이 포스트는 What Is The Difference Between A URI And A URL? - DEV Community 👩‍💻👨‍💻 를 참고하여 작성되었습니다.

Chapter 02. HTML 입문

1 minute read

이 포스트는 SW마에스트로 자기주도학습으로 패스트캠퍼스의 웹 프론트엔드 올인원 패키지 Online을 수강하면서 작성한 노트입니다.

블로그 개설

less than 1 minute read

Github 블로그를 드디어 개설했습니다.🎉

Back to Top ↑