連載第3回の目的

この回では、前回の内容を踏まえて、実用的な事例として自動車のエンジン出力から燃費を予測するサンプルを作ります。大量データを使うときのデータ処理の方法、tfjs-visライブラリを使ったデータの可視化などの応用事例をハンズオンで実践します。

  • 図1:完成サンプルイメージ

▼完成サンプル
https://github.com/wateryinhare62/mynavi_tensorflowjs/

前回に引き続き、今回のテーマも「線形単回帰分析」です。今回の事例であるエンジン出力(kW)と燃費(km/リットル)には、一定の法則(エンジン出力が大きいと燃費が悪くなる)が働くことが予想されます。このような場合に用いることができるのが、まさに線形単回帰分析です。
ハンズオンは、図2のステップで実施していきます。データの準備から予測(汎化)に至る、第1回でも紹介したプロセスです。

  • 図2:ハンズオンのステップ

[NOTE]サンプルについて
本記事のサンプルは、Google Codelabのチュートリアル「TensorFlow.js - 2D データから予測を行う」(https://codelabs.developers.google.com/codelabs/tfjs-training-regression?hl=ja#0)におけるサンプルをベースに、一部改変して作成、掲載しています。

今回作成するサンプルでは、モデルの訓練のためのデータを外部から読み込み、モデルを訓練した後、実際のエンジン出力データを与えて燃費を予測させてみます。HTMLページには、指定したエンジン出力データで予測できるように、入力欄と予測開始のボタンから構成されるフォームを配置します。
ページがブラウザに読み込まれると、ただちにデータの読み込みや訓練が始まり、その経過はフォームの下、またはページの右ペインに表示されます。右ペインは「バイザー」と呼ばれる領域で、可視化ライブラリtfjs-visが作成するグラフや表を表示するために利用します。後ほど明らかになりますが、同一バイザーへの描画は蓄積されるので、必要に応じてスクロール操作します。
訓練が済むと予測が可能になるので、入力ボックスにエンジン出力を入力してボタンをクリックすると、予測できる燃費が表示されます。

[NOTE]バイザー(Visor)
tfjs-visでは、「バイザー」(Visor)という領域がブラウザに作成され、配下のサーフェスと呼ばれる領域にグラフや表が描画されます。図1では、右に表示されている領域がバイザーです。バイザー上部にあるボタンで最大化(Maximize)したり消去(Hide)したりできます。サーフェスは複数作成することができますが、本サンプルでは既定のサーフェスのみを使います。

ファイルの準備

まず、サンプルの基点となるHTMLファイル(index.html)を用意します(図3)。ポイントは、3つのJavaScriptファイルを読み込んでいることです。TensorFlow.jsに加えて、可視化ライブラリtfjs-vis、そして本サンプル用JavaScriptファイルを、それぞれインポートします。

  • 図3:index.html

この読み込みによって、以下のグローバル変数が作成されます。サンプルでは、この変数を使って各ライブラリの機能を使っていきます。

  • tf:TensorFlow.jsライブラリを使うための変数
  • tfvis:tfjs-visライブラリを使うための変数

続けて、ライブラリを使うコードを書いていくJavaScriptファイル(script.js)を用意します(図4)。初期状態として、HTMLの読み込み後に起動するイベントハンドラ関数trainを作成し、登録します。
現時点でのtrain関数の処理は、処理状況をページ(body要素)に表示するinsertAdjacentHTMLメソッドの呼び出しのみです。以降、モデルの作成、訓練などさまざまな関数をscript.jsに記述していきますが、それらの呼び出しコードはtrain関数に追記していくものとします。覚えておいてください。

  • 図4:script.js

フォームなどの形を整えるためのCSSファイル(style.css)も用意しますが、動作上は重要でないので内容は配布サンプルを参照してください。
HTMLファイルをブラウザで読み込んで、図5のように表示されればOKです。

  • 図5:初期表示

オリジナルデータの読み込み

訓練用のオリジナルデータを読み込んでフォーマットし、tfjs-visを使ってその散布図を表示します(図6)。

  • 図6:script.js(readData関数)

作成するのは、readData関数です。この関数は、オリジナルデータを読み込んでフォーマットを実行し、散布図を表示した後、読み込んだデータのオブジェクトを返します。大きく3つの処理に分かれています。

  1. オリジナルデータの読み込み
  2. フォーマット
  3. 散布図の表示

オリジナルデータは、TensorFlow.jsのチュートリアル用に用意されているJSON形式のファイルを利用します(1.)。このオリジナルデータは、自動車のエンジン出力と燃費の組み合わせを多数収録したものです。
2.では、オリジナルデータはエンジン出力が馬力(HP)、燃費がMPG(1ガロンあたりのマイル数)という米国仕様なので、日本国内で分かりやすいようにエンジン出力をkW、燃費をkm/リットルに換算します。また、どちらかが欠落しているような不正なデータが混じっているので、それを除外します。このようなノイズを除去する作業により、訓練の精度が向上します。
3.では、フォーマットしたオリジナルデータを散布図として表示します。散布図の表示には、tfjs-visのscatterplotメソッド(散布図の意)を使います。scatterplotメソッドの引数の意味については、図6を参照してください。
readData関数の呼び出しを、train関数の末尾に追加します(図7)。

  • 図7:script.js(train関数)

HTMLファイルを読み込んで、図8のように散布図が表示されることを確認してください。なお、グラフ右上に表示されるseriesとは、「系列」という意味です。

  • 図8:読み込んだデータをフォーマットして散布図を表示

モデルの作成

オリジナルデータの準備ができたら、機械学習の核となるモデルを作成し、その概要を表示させます(図9)。

  • 図9:script.js(createModel関数)

TensorFlow.jsにおけるモデルの作成方法には3種類あると、第2回で説明しました。今回も第2回と同様に単回帰分析が目的なので「Sequential」を採用します。モデルには、層を2つ追加します。それぞれ、入力層、出力層に相当します。ここでは、入力はエンジン出力のみ、出力は燃費のみなので入力数(inputShape)とユニット数(units)ともに1となるDense層とします(図10)。Dense層についても、第2回で説明しました。ここでは、useBiasによってバイアス(y=ax+bのbの部分)を使用した計算も指定しています(Dense層の既定値)。

  • 図10:作成するモデル

createModel関数の呼び出しを、train関数の末尾に追加します(図11)。

  • 図11:script.js(train関数)

HTMLファイルを読み込んで、図12のようにモデルの概要が表示されることを確認してください。今回、層を2つ作成したので、表も2行となっています。表の各列の意味は、図を参照してください。

  • 図13:モデルの概要を表示

訓練用データの準備

オリジナルデータをTensorFlow.jsで有効に扱えるようにするために、いくつかの変換作業を実施します(図13)。

  • 図13:script.js(prepareTraining関数)

  1. シャッフル
  2. テンソル(第1回を参照)に変換
  3. 正規化

1.のシャッフルは、第2回の事例ではなかった作業です。シャッフルとは、データをランダムに並び替えることです。これは訓練におけるデータの偏りを防ぐとともに、順序による影響を受けないようにする作業です。訓練では、大量のデータを小さなサブネット(バッチ)に分解して処理するので、全てのバッチでデータが偏りのない状態になることは学習上重要です。
2.について、TensorFlow.jsにおいてはデータはテンソルで流れるので、まずはオリジナルデータをエンジン出力と燃費のそれぞれの1次元配列に分解し、テンソル(この場合は1階テンソル)へ変換します。具体的には、例えばエンジン出力の1次元配列(要素数N)が、[N]形式の1階テンソルに変換されます。
3.の正規化とは、第2回で説明したようにテンソル内のデータの分布を0~1か-1~1の範囲(ここでは0~1)に収まるように調整することです。正規化は、データの絶対値に影響されにくくする、モデルを収束しやすくする、といった目的で行います。ここで用いている正規化の方法は「Min-Maxスケーリング」と呼ばれており、以下のように最小値と最大値から、その範囲内の値を調整するアルゴリズムとなっています。今回のようにゼロから離れた入力がほとんどである場合に特に有効です。

正規化された値=(値-最小値)÷(最大値-最小値)

なお、この一連の処理は、無名関数としてTensorFlow.jsのtidyメソッドに渡されています(4.)。tidyメソッドの役割は、無名関数の実行中に生成された中間的なテンソルを消去し、それに使われていたGPUリソースを解放することです。
この工程を終えたら、このデータを利用して訓練を実行します。

[NOTE]教師あり学習
今回使用する訓練方法は、「教師あり学習」です。教師あり学習とは、入力に対する出力が分かっている形式での学習をいいます。今回の場合は、エンジン出力と燃費の対応です。正解が分かっているので、傾向を見いだすのに適した方式といえます。これに対し「教師なし学習」というものもあります。正解を持たないので、入力から共通のパターンを見つけてグループ化したり、データ構造を抽出したりする目的で使用されます。
教師あり学習と教師なし学習の、どちらが優れているということはなく、それぞれに適した目的があります。教師あり学習は「回帰」や「分類」に適しています。教師なし学習には正解がないので、たとえグループ化してもそれへの意味付けが必要になりますが、「顧客の購買傾向のルールを見いだしたい」といった場合などに適しています。
今回の事例では、「教師あり学習」によって回帰分析がなされている、ということを理解できれば十分です。

訓練の実行

訓練用データの準備ができたら、いよいよ訓練を実行します(図14)。

  • 図14:script.js(trainModel関数)

作成するのは、trainModel関数です。この関数は、モデルと訓練用のデータ、検証用のデータを受け取り、モデルを訓練します。大きく2つの処理に分かれています。

  1. コンパイル
  2. 訓練のループ

コンパイルの意味については、バッチサイズ、エポック数などとともに第2回で説明しました。今回は、表1の関数を設定します。

表1:compileメソッドの引数

引数 意味 概要
optimizer オプティマイザ adam(Adaptive Moment Estimation)という安定性の高いアルゴリズムを使用
loss 損失関数 第2回同様にmse(meanSquaredError)すなわち平均二乗誤差を使用
metrics 評価関数 mse(平均二乗誤差)を使用

2.のfitメソッドによる訓練ループは、訓練用のデータ、検証用のデータ、バッチサイズ、エポック数などのオプションを与えて実行します。コールバックを指定することで、訓練中の出力を随時グラフに描画することなどが可能です。バッチサイズとエポック数は訓練の結果につながるのでデータ量に応じた適切な選択が必要ですが、ここでは固定値(32,100)としています。なお、fitメソッドは戻り値として訓練中の履歴を返してくれますが、本サンプルでは使用していません。
最後に、prepareTraining関数とtrainModel関数との呼び出しを、train関数の末尾に追加します(図15)。

  • 図15:script.js(train関数)

HTMLファイルを読み込んで、図16のように下降曲線を持つグラフが表示されることを確認してください。上は損失関数によるグラフ、下は評価関数によるグラフです。両者とも平均二乗誤差(mse)を用いているので、同じ曲線となります。最大エポックで0に近ければ近いほど優れた学習といえます。グラフは、訓練中にエポックの回数分だけ随時再描画されます。エポック数を増やすと精度は向上しますが、その分、訓練に要する時間は増大します。

  • 図16:モデルの訓練

予測の実施

いよいよ予測の実施です。予測は、連続した100個のエンジン出力データで実施します(図17)。この予測は、フォームの入力ボックスの内容には無関係に実施し、それに応じた表示は次の手順「予測結果の表示」に委ねることにしています。

  • 図17:script.js(doPredict関数)

作成するのは、doPredict関数です。この関数は、モデルと正規化済みのデータを受け取り、予測を実行して正規化されていない状態に戻した結果を返します。大きく2つの処理に分かれています。

  1. 連続データの作成と予測の実施
  2. 予測データの範囲をオリジナルデータと同じ範囲に戻す

1.では、正規化された(0~1の範囲)連続データを100個作成し、これを燃費を予測したいエンジン出力データとしてモデルに入力します(predictメソッド)。出力として得られるテンソルは、予測結果のデータとなります。これを2.で、オリジナルデータと同じ数値範囲となるように非正規化し、最後にテンソルを通常のオブジェクトに変換して関数の戻り値とします。 この工程を終えたら、このデータを利用して結果を表示します。

予測結果の表示

この記事は
Members+会員の方のみ御覧いただけます

ログイン/無料会員登録

会員サービスの詳細はこちら