MNIST 손글씨 인식 프로그램

앞에서 만들었던 간단한 그림판 프로그램과 TensorFlow의 MNIST 손글씨 이미지 인식 예제를 함께 이용해서 간단한 손글씨 인식 프로그램을 만들 수 있습니다.

TensorFlow - 인공 신경망 모델의 저장과 복원 예제를 함께 참고하세요.


예제

## Ex 10-6. MNIST 손글씨 인식 프로그램.

import sys
from PyQt5.QtWidgets import *
from PyQt5.QtGui import *
from PyQt5.QtCore import *
import numpy as np
import tensorflow as tf


class MyApp(QMainWindow):

    def __init__(self):
        super().__init__()
        self.image = QImage(QSize(400, 400), QImage.Format_RGB32)
        self.image.fill(Qt.white)
        self.drawing = False
        self.brush_size = 30
        self.brush_color = Qt.black
        self.last_point = QPoint()
        self.loaded_model = None
        self.initUI()

    def initUI(self):
        menubar = self.menuBar()
        menubar.setNativeMenuBar(False)
        filemenu = menubar.addMenu('File')

        load_model_action = QAction('Load model', self)
        load_model_action.setShortcut('Ctrl+L')
        load_model_action.triggered.connect(self.load_model)

        save_action = QAction('Save', self)
        save_action.setShortcut('Ctrl+S')
        save_action.triggered.connect(self.save)

        clear_action = QAction('Clear', self)
        clear_action.setShortcut('Ctrl+C')
        clear_action.triggered.connect(self.clear)

        filemenu.addAction(load_model_action)
        filemenu.addAction(save_action)
        filemenu.addAction(clear_action)

        self.statusbar = self.statusBar()

        self.setWindowTitle('MNIST Classifier')
        self.setGeometry(300, 300, 400, 400)
        self.show()

    def paintEvent(self, e):
        canvas = QPainter(self)
        canvas.drawImage(self.rect(), self.image, self.image.rect())

    def mousePressEvent(self, e):
        if e.button() == Qt.LeftButton:
            self.drawing = True
            self.last_point = e.pos()

    def mouseMoveEvent(self, e):
        if (e.buttons() & Qt.LeftButton) & self.drawing:
            painter = QPainter(self.image)
            painter.setPen(QPen(self.brush_color, self.brush_size, Qt.SolidLine, Qt.RoundCap))
            painter.drawLine(self.last_point, e.pos())
            self.last_point = e.pos()
            self.update()

    def mouseReleaseEvent(self, e):
        if e.button() == Qt.LeftButton:
            self.drawing = False

            arr = np.zeros((28, 28))
            for i in range(28):
                for j in range(28):
                    arr[j, i] = 1 - self.image.scaled(28, 28).pixelColor(i, j).getRgb()[0] / 255.0
            arr = arr.reshape(-1, 28, 28)

            if self.loaded_model:
                pred = self.loaded_model.predict(arr)[0]
                pred_num = str(np.argmax(pred))
                self.statusbar.showMessage('숫자 ' + pred_num + '입니다.')

    def load_model(self):
        fname, _ = QFileDialog.getOpenFileName(self, 'Load Model', '')

        if fname:
            self.loaded_model = tf.keras.models.load_model(fname)
            self.statusbar.showMessage('Model loaded.')

    def save(self):
        fpath, _ = QFileDialog.getSaveFileName(self, 'Save Image', '', "PNG(*.png);;JPEG(*.jpg *.jpeg);;All Files(*.*) ")

        if fpath:
            self.image.scaled(28, 28).save(fpath)

    def clear(self):
        self.image.fill(Qt.white)
        self.update()
        self.statusbar.clearMessage()


if __name__ == '__main__':
    app = QApplication(sys.argv)
    ex = MyApp()
    sys.exit(app.exec_())


설명

load_model_action = QAction('Load model', self)
load_model_action.setShortcut('Ctrl+L')
load_model_action.triggered.connect(self.load_model)

...

filemenu.addAction(load_model_action)
filemenu.addAction(save_action)
filemenu.addAction(clear_action)

미리 학습한 모델을 불러오기 위한 load_model_action을 하나 만들고, ‘File’ 메뉴에 추가합니다.



def mouseReleaseEvent(self, e):
    if e.button() == Qt.LeftButton:
        self.drawing = False

        arr = np.zeros((28, 28))
        for i in range(28):
            for j in range(28):
                arr[j, i] = 1 - self.image.scaled(28, 28).pixelColor(i, j).getRgb()[0] / 255.0
        arr = arr.reshape(-1, 28, 28)

self.image.scaled(28, 28)을 이용해서 사용자가 그린 이미지를 MNIST 데이터셋의 이미지 크기에 맞게 변환합니다.

픽셀의 값을 NumPy 어레이로 가져와서, 인공 신경망 모델의 입력에 맞게 형태를 변환합니다.



if self.loaded_model:
    pred = self.loaded_model.predict(arr)[0]
    pred_num = str(np.argmax(pred))
    self.statusbar.showMessage('숫자 ' + pred_num + '입니다.')

불러온 인공 신경망 모델의 predict()를 이용해서 사용자가 그린 손글씨에 대한 예측을 수행합니다.

showMessage()를 이용해서 이 예측에 따른 결과를 상태바에 표시합니다.

(참고: PyQt5 기초 - 상태바 만들기)



def load_model(self):
    fname, _ = QFileDialog.getOpenFileName(self, 'Load Model', '')

    if fname:
        self.loaded_model = tf.keras.models.load_model(fname)
        self.statusbar.showMessage('Model loaded.')

load_model() 메서드는 QFileDialog를 이용해서 미리 학습한 인공 신경망 모델을 불러오는 기능을 합니다.

모델을 불러왔다면 상태바에 메세지를 표시합니다.

(참고: PyQt5 다이얼로그 - QFileDialog)



결과

../_images/9_6_mnist_classifier.gif

그림 10-6. MNIST 손글씨 인식 프로그램.


이전글/다음글