【実験】MNISTで学習させたRaspberry Pi上のDNNで、PaPeRo iに手書き数字を音読させる

PaPeRo i の目の前に手で数字を手書きした紙を提示し、その数字をPaPeRo i に音読させる方法の一例を紹介します。
今回、手書き数字の認識にはディープラーニング用のライブラリであるTensorFlow及びKerasを使用し、学習用データとしてMNISTを使用します。
TensorFlowはRaspberry Pi 上にインストールし、Pythonアプリから呼び出して使用します。
使用したTensorFlowのバージョンは1.0.1、Kerasのバージョンは2.1.3、Raspberry Pi は Raspberry Pi 3 model B v1.2、RaspbianはStretchです。

TensorFlowについて、今回の作業実施時点でRaspberry Pi 用の公式バイナリは存在せず、ソースからビルドすると長時間かかる為、今回は https://github.com/samjabrahams/tensorflow-on-raspberry-pi/issues/92 で公開されている非公式のバイナリを使用します。

手順(PaPeRo i 側)

(1) PaPeRo i 制御用WebSocket通信アドオンシナリオをまだインストールしていない場合は、「PaPeRo iをRaspberry Pi上のpythonから操作する」の「PaPeRo iにアドオンシナリオをインストール」に従ってインストール作業を行います。

手順(Raspberry Pi 側)

(1)TensorFlowを、下記のコマンドでインストールします(20分程度かかります)。

$ wget https://www.dropbox.com/s/gy4kockdbdyx85j/tensorflow-1.0.1-cp35-cp35m-linux_armv7l.whl
$ sudo pip3 install tensorflow-1.0.1-cp35-cp35m-linux_armv7l.whl

(2)scipyをインストールします。

$ sudo apt-get install python3-scipy

(3)Kerasをインストールします。

$ sudo pip3 install keras

(4)h5pyをインストールします。

$ sudo apt-get install python3-h5py

(5)https://github.com/keras-team/keras/tree/master/examples で公開されているKerasのサンプルコードのうち、mnist_cnn.py を作業用のディレクトリにダウンロードします。

$ mkdir ~/papero
$ cd ~/papero
$ wget https://github.com/keras-team/keras/raw/master/examples/mnist_cnn.py

python3 mnist_cnn.py でサンプルコードを実行すると、
・MNISTデータのダウンロード(初回実行時のみ)と、ダウンロード済MNISTデータの読み込み
・DNNの一形態である、CNN(畳み込みニューラルネットワーク)の構築
・CNNの、MNISTデータによる学習
が行われるのですが、このままでは長時間かけて行う学習の成果物が残らず、後から再利用する事ができません。

(6)学習結果及びネットワークモデルが保存されるようにする為、

$ cp mnist_cnn.py mnist_cnn_save.py

で、mnist_cnn_save.pyにサンプルコードをコピーし、mnist_cnn_save.pyの末尾に下記の行を追加します。

# save model and weights
model_json_str = model.to_json()
open('mnist_cnn_model.json', 'w').write(model_json_str)
model.save_weights('mnist_cnn_weights.h5')

(7)修正後のコードを実行します。

$ python3 mnist_cnn_save.py

実行すると、学習の進捗状況が下記のような形で表示されます。

Using TensorFlow backend.
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
Train on 60000 samples, validate on 10000 samples
Epoch 1/12
60000/60000 [==============================] - 1279s 21ms/step - loss: 0.2614 - acc: 0.9194 - val_loss: 0.0619 - val_acc: 0.9801
Epoch 2/12
60000/60000 [==============================] - 1382s 23ms/step - loss: 0.0912 - acc: 0.9732 - val_loss: 0.0436 - val_acc: 0.9859
Epoch 3/12
60000/60000 [==============================] - 1405s 23ms/step - loss: 0.0669 - acc: 0.9798 - val_loss: 0.0335 - val_acc: 0.9883
Epoch 4/12
23296/60000 [==========>...................] - ETA: 12:49 - loss: 0.0575 - acc: 0.9833

上記の表示から、1 Epoch 当たりおよそ20~24分程度かかっており、それが学習完了まで12回繰り返される事が分かります。

12回目のEpoch終了後の表示は

Epoch 12/12
60000/60000 [==============================] - 1387s 23ms/step - loss: 0.0249 - acc: 0.9924 - val_loss: 0.0266 - val_acc: 0.9908
Test loss: 0.0266205102392

のようになり、テストサンプルに対して99.08%の認識精度となった事が分かります。

※ニューラルネットワークの学習には乱数が使われる為、再試行では必ずしも上記の結果と一致するとは限りません。

(8)学習完了後、成果物が保存されている事を確認します。

$ ls -l mnist_cnn_*
-rw-r--r-- 1 pi pi    2781  3月  8 04:33 mnist_cnn_model.json
-rw-r--r-- 1 pi pi    2415  3月  7 05:24 mnist_cnn_save.py
-rw-r--r-- 1 pi pi 4821392  3月  8 04:33 mnist_cnn_weights.h5

mnist_cnn_model.json が学習前に構築されたネットワークモデルで、mnist_cnn_weights.h5 が学習結果(成果物)です。

(9)pythonの通信パッケージws4pyをインストールします。また、画像ファイル転送の為にparamikoとscpも使いますので、それらもインストールします。

$ sudo pip3 install ws4py
$ sudo pip3 install paramiko
$ sudo pip3 install scp

(10)「PaPeRo iをRaspberry Pi上のpythonから操作する」の「Raspberry Piへ通信ライブラリをインストール」に従い、ライブラリを配置します。
以下の説明では、ライブラリを ~/papero の下に配置したものとします。

(11)下記のコードをコピー・ペーストして、cnn_readdigit.pyというファイルを作成し、~/papero に置きます。
※ ****ユーザ名****、****パスワード**** の部分につきましては、PaPeRo i に一般ユーザでログインする際に使用するユーザ名とパスワードに置き換えて下さい。

import argparse
import time
from enum import Enum
import numpy as np
from PIL import Image

from paramiko import SSHClient,AutoAddPolicy
from scp import SCPClient
import keras
from keras.models import model_from_json
from keras import backend as K

import pypapero


class State(Enum):
    st0 = 10
    st1 = 11
    st2 = 12
    st3 = 13
    st4 = 14
    end = 999


def predict_by_img(img, model):
    img_gray = img.convert('L')
    # Crop
    iw, ih = img_gray.size;
    ic = iw/2
    il = int(ic - ih / 2)
    ir = int(ic + ih / 2)
    img_cropped = img_gray.crop((il, 0, ir, ih))
    # Binarize and reverse color
    thresh = 80
    im_vec = np.asarray(img_cropped)
    im_vec_bk = np.array(im_vec)
    im_vec.flags.writeable = True
    im_vec[im_vec_bk < thresh] = 255
    im_vec[im_vec_bk >= thresh] = 0
    # Thickening
    w_line = 20
    w_line_h = int(w_line / 2)
    ixy_end = ih - int(w_line / 2 + 0.5)
    lst_iy, lst_ix = np.where(im_vec[w_line_h:ixy_end, w_line_h:ixy_end] > 128)
    for i in range(len(lst_iy)):
        iy = lst_iy[i] + w_line_h
        ix = lst_ix[i] + w_line_h
        im_vec[iy - w_line_h : iy - w_line_h + w_line, ix - w_line_h : ix - w_line_h + w_line] = 255
    # Resize
    img_resized = Image.fromarray(np.uint8(im_vec)).resize((28, 28))
    img_resized.save("tmp2.png")
    # Predict
    im_vec = np.asarray(img_resized)
    if K.image_data_format() == 'channels_first':
        im_vec = im_vec.reshape(1, 1, 28, 28).astype('float32') / 255
    else:
        im_vec = im_vec.reshape(1, 28, 28, 1).astype('float32') / 255
    print("--------")
    predict = model.predict(im_vec, batch_size = 1)
    print(predict)
    i_res = predict.argmax(1)[0]
    prob = predict[0, i_res]
    print("Recognition: " + str(i_res))
    print("Probability: " + str(prob))
    return i_res


def main(papero, host, model):
    prev_time = time.monotonic()
    past_time = 0
    interval_time = 0
    state = State.st0
    first = True
    print("HOST=" + host)
    PORT = 22
    USER = "****ユーザ名****"
    PSWD = "****パスワード****"
    scp = None
    ssh = SSHClient() 
    ssh.set_missing_host_key_policy(AutoAddPolicy())
    ssh.connect(host, port=PORT, username=USER, password=PSWD)
    scp = SCPClient(ssh.get_transport())
    digit_now = (-1)
    while state != State.end:
        messages = papero.papero_robot_message_recv(0.1)
        now_time = time.monotonic()
        delta_time = now_time - prev_time
        prev_time = now_time
        if messages is not None:
            msg_dic_rcv = messages[0]
        else:
            msg_dic_rcv = None
        if papero.errOccurred != 0:
            print("------Error occured(main()). Detail : " + papero.errDetail)
            break
        if state == State.st0:
            papero.send_start_speech("中ボタンで目の前の数字1文字を発話します。右又は左ボタンで終了します。")
            past_time = 0
            state = State.st1
        elif state == State.st1:
            past_time += delta_time
            if past_time > 0.5:
                papero.send_get_speech_status()
                state = State.st2
        elif state == State.st2:
            if msg_dic_rcv is not None:
                if msg_dic_rcv["Name"] == "getSpeechStatusRes":
                    if str(msg_dic_rcv["Return"]) == "0":
                        state = State.st3
                    else:
                        past_time = 0
                        state = State.st1
        elif state == State.st3:
            if msg_dic_rcv is not None:
                if msg_dic_rcv["Name"] == "detectButton":
                    if msg_dic_rcv["Status"] == "C":
                        papero.send_take_picture("JPEG", filename="tmp.jpg", camera="VGA")
                        past_time = 0
                        state = State.st4
        elif state == State.st4:
            if msg_dic_rcv is not None:
                if msg_dic_rcv["Name"] == "takePictureRes":
                    scp.get("/tmp/tmp.jpg")
                    img = Image.open("tmp.jpg")
                    if img is not None:
                        digit_now = predict_by_img(img, model)
                        papero.send_start_speech(str(digit_now)+"に見えます")
                    state = State.st3
        if msg_dic_rcv is not None:
            if msg_dic_rcv["Name"] == "detectButton":
                if msg_dic_rcv["Status"] != "C":
                    state = State.end


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description = "Usage:")
    parser.add_argument("host", type=str, help = "Host IP address")
    command_arguments = parser.parse_args()
    simulator_id = ""
    robot_name = ""
    host = command_arguments.host
    ws_server_addr = "ws://" + host + ":8088/papero"
    papero = pypapero.Papero(simulator_id, robot_name, ws_server_addr)
    # Load model
    model = model_from_json(open('mnist_cnn_model.json').read())
    model.load_weights('mnist_cnn_weights.h5')
    model.summary();
    main(papero, host, model)
    papero.papero_cleanup()

(12)cnn_readdigit.py を実行します。

$ cd ~/papero
$ python3 cnn_readdigit.py PaPeRoiのIPアドレス

実行すると、PaPeRo i が「中ボタンで目の前の数字1文字を発話します。右又は左ボタンで終了します」と発話します。
白い紙に黒い太目の線で数字を書いた物を PaPeRo i に見せながら、座布団の中ボタンを押すと、紙に書かれた数字が0~9のどれに見えるかを、PaPeRo i が発話します。
PaPeRo i のカメラで撮影された画像は ~/papero/tmp.jpg に保存されます。
また、撮影された画像をMNISTの画像になるべく近づけるためにグレースケール化・中央正方形範囲切出し・2値化・白黒反転・太線化・縮小したものが ~/papero/tmp2.png に保存されますので、画像を確認したい場合はこれらのファイルをブラウザ等で確認します。

認識状況

試しに白い紙の中央に太さ3mm程度の線で1辺6cm程度の正方形に収まる程度の大きさで0~9までの数字を1つずつ書いたものを用意し、~/papero/tmp2.png を目視確認しながら正しく発話させる為に中ボタンの押下がそれぞれの紙に対して何度必要だったかを数え、平均した所、3回程度となりました。
この結果は試行者によっても多少異なる可能性がありますが、学習完了時の成績表示では99.08%の認識精度であったものの、実際に数字を正しく認識できるようにPaPeRo i に見せる事は容易ではなく、数字以外の物がカメラに写ると誤認識の原因になったり、一度正しく認識できても、位置や角度を少しずらすと再び誤認識してしまうといった状況でした。

また、今回の学習では、画像が数字1文字である事が前提となっています。
その為、例えば
・画像の中に数字があるかないかを判別する
・数字の一部が見えた時に数字全体が見えるように首を動かす
・複数の数字を順番に読み上げる
といった事をするには、数字でない画像も学習データに加えたり、数字の一部の画像の認識や、それに対して首をどう動かすかの判断、画像から複数の数字を一つずつ取り出すための方法等、様々な工夫が必要になりそうです。