TensorFlowLiteのJavaインタプリタでハマった話

環境

  • 学習はKeras(back-endはTensorFlow)を使用
  • スマートフォン上のランタイムはTensorFLowLiteを使用
  • TensorFlow(PBモデル)からTensorFLowLite(FBモデル)へ変換(変換には公式のものを使用)
  • 入力は8bitグレースケールの画像配列、出力はXY座標とした単純な回帰モデル

現象

下記のように8bitグレースケールで入力バッファを用意してXY座標を取得しようとした。
255で割っているのは、正規化しているため。

public synchronized void recognize(byte[] aImageBytes) {
    for(byte px : aImageBytes) {
        mInputBuffer.putFloat(px / 255f);
    }
    ...
}

しかし、推定結果のXY座標が誤差どころではなく、でたらめ(壊れているかのような)の推定結果が返ってくる。

原因

コードを読んで一瞬で"あっ"ってなった方、さすがです。

ここでJavaのbyteの範囲について思い出してみます。
byteは符号あり8bitなんですね。
つまり-128~127までなんです...

上記のコードは0~255(符号なし)の範囲で正規化しているつもりなので...あっあっあっ

解決策

普通にANDとればよいです。
ちなみにJava8からはByte.toUnsignedInt(byte)が使えます。

public synchronized void recognize(byte[] aImageBytes) {
    for(byte px : aImageBytes) {
        mInputBuffer.putFloat( ((int)px & 0xFF) / 255f);
    }
    ...
}

このバグ、他の人も巻き込んで悩むくらいめちゃくちゃ沼にはまってしまったので猛反省...

ちなみにputFloatは遅いのでput(byte[])を使ったほうがよいです。
自力でfloatからbyte[]に変換する手間がありますが。
※Float.floatToIntBitsとかであっさりいけますがエンディアンには注意。