Jupyterでscala smileを使って機械学習する

Scala Almondの使い方

Scala Almondは、Jupyter NotebookでScalaを使うためのカーネルです。この記事では、Scala Almondをインストールする方法、Jupyter Notebook上でScala Almondを使う方法について説明します。

Scala Almondのインストール方法

今回、AlmondのインストールはDockerを使用します。このGitHubリポジトリをローカルにクローンしてください。Dockerイメージのベースは、DockerHubで公開されているAlmondのイメージを利用しておりますが、Java17を別途インストールしています。

以下のコマンドを端末で実行し、Dockerイメージをビルドしてください。予め、Dockerとdocker-composeをインストールしておいてください。

docker-compose build

Scala AlmondをJupyter Notebook上で使う方法

Scala AlmondをJupyter Notebook上で使うには、以下のコマンドを端末で実行してください。

docker-compose up

すると、Dockerコンテナが立ち上がり、Jupyterのプロセスが実行されます。端末上に以下のURLが表示されます。このURLをブラウザで開いてください。

http://127.0.0.1:8888/lab?token=d3bab8c8494393cf8c3551da3c6475bd85925aaa27efe950

すると以下の様なページが表示されます。これでAlmondが使えます!

Scala Almondをの使い方

Jupyterのページで表示されている左のウィンドウは、ファイル・フォルダブラウザです。work/src/notebookにデモで用意したJupyterノートブックがあります。demo_notebook.ipynbというファイル名を選びましょう。GitHubのページはこちらです。

最初のセルは、利用するライブラリをインストールしています。import $ivy.`XXXという感じでインストールすることができます。インストールしたパッケージ名やバージョンはMVN Repositoryで検索できます。

Jupyterを使う用途は主にデータ分析です。データを処理したり視覚化するのに非常に便利なツールです。

後は、Scalaのコードをそのまま書くだけです。使用するライブラリをインポートします。

import java.awt.Color
import scala.language.postfixOps
import smile.read
import smile.data.DataFrame
import smile.interpolation.BicubicInterpolation
import smile.data.formula._
import smile.plot.show
import smile.plot.swing.heatmap
import smile.plot.swing.Palette
import smile.plot.swing._
import smile.plot.Render._
import smile.validation.{cv, RegressionValidation}
import smile.regression.GradientTreeBoost
import smile.write

smileを使って配列を生成し、ヒートマップを作成します。

機械学習で有名なアヤメデータを読み込み、グラフで表示することもできます。

用意したJupyterノートブックには、機械学習の例も用意しました。Gradient Tree Boost(勾配ブースティング決定木)と糖尿病データを使って、血糖値の回帰モデルを学習します。

// データを読み込む
val diabetes = read.csv("diabetes.csv")

// 交差検定を使ってモデルを学習する
val validated = cv.regression(10, "y"~, diabetes) { (formula, data) => smile.regression.gbm(formula, data) }

// 最も結果の良かったモデルを取得する。この場合、RMSEが最も低かったもの
var bestVal: RegressionValidation[GradientTreeBoost] = null
var rmse = Double.MaxValue
validated.rounds.forEach{validation =>
    if (rmse > validation.metrics.rmse) {
        bestVal = validation
        rmse = validation.metrics.rmse
    }
}

// 使用した学習データとモデルを使って予測する
val results = diabetes.merge(DataFrame.of(bestVal.model.predict(diabetes).map(Array(_)), "prediction"))

// 散布図で予測結果と真値を描画する
val scatter = plot(results.select("prediction", "y").toArray(),  mark='o', color=Color.BLUE)
show(scatter)

結果が以下です。

おわり

以上が、Scala Almondのインストール方法と使い方についての説明です。Scala Almondは、Scalaを学ぶ上で非常に便利なツールです。ぜひ、お試しください。

Leave a Reply

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください