【TorchSharp】C#で機械学習 手書き文字認識アプリの実装(サンプルコード付き)

機械学習に関するプログラムは,基本的にpythonにより実装されます.しかしC#(.Net)で作成済みのプログラムに機械学習の機能を追加したい場合や,GUIも作成したい場合など,C#(.Net)で機械学習の処理を実装しなければならない状況もあると思います.C#用の機械学習ライブラリはいくつかありますが,.NET Foundationに組み込まれているTorchSharpが無難な選択肢だと思います.Pytorchベースなので,Pytorchに精通している方はTorchSharpも使いこなせると思います.しかし,TorchSharpに関する情報が少なかったため(Qiitaでは2件のみ),本稿ではTorchSharp + C#(.Net 8.0)による教師あり学習の実装例として,手書き数字のクラス分類アプリケーションをGUI付で作成します.

教師データの学習に加えて,以下のようにGUIにキャンバスを設置して,描かれた数字を認識する機能を持つアプリケーションを作成します.実際に作成したアプリケーションのソースコードは以下からダウンロード可能です.

ソースコードのダウンロード: https://github.com/kkaneko1090/TorchSharpSupervisedLearning

学習データセットは,ソースコード内のTrainningDataフォルダに格納しています.手書きの”0”,”1”,”2”の3種類を21枚ずつ用意したので,今回は3クラスの分類となります.詳しくはTrainningDataフォルダの中身を見てください.

1. プロジェクトの作成

今回はWPFでGUIも作成したいので,.Net WPFアプリケーションを選択します.

現時点での最新版のTorchSharp(v0.103)は.Net 6.0 を対象としていますが,.Net 8.0でも動作することを確認しているため,今回は.Net8.0でプロジェクトを作成します.

2. パッケージの準備

最新のTorchSharpをNugetでインストールします.GPUを使って学習を行いたいので,同じバージョンのTorchSharp-cuda-windowsも同様にインストールします.また,本プログラムではBitmapを扱うため,System.Drawing.Commonも追加しました.

3. 機械学習モデルの実装

以下のように,機械学習モデルのクラス ”MLModel” を実装しました.メンバ変数としてモデルを構成する各層をメンバ変数として記述します.各層の次元は,コンストラクタで初期化するようにしました.全結合層の入力次元を把握するために,ダミーの入力データを用いて,畳み込み層の出力”dammyConvOutput”を計算しています.

if (torch.cuda.is_available()) _device = CUDA;では,GPUが使用可能かを判定し,使用できる場合はGPUで学習を実行するようにコーディングしました._device = CUDAであれば,this.to(_device);でモデルがGPUに転送されます.

using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.torch.nn.functional;

namespace TorchSharpSupervisedLearning
{
    public class MLModel : Module<Tensor, Tensor>
    {
        #region メンバ変数
        /// <summary>
        /// 畳み込み層1
        /// </summary>
        private Conv2d _conv1;
        /// <summary>
        /// 畳み込み層2
        /// </summary>
        private Conv2d _conv2;
        /// <summary>
        /// 全結合層1
        /// </summary>
        private Linear _linear1;
        /// <summary>
        /// 全結合層2
        /// </summary>
        private Linear _linear2;
        /// <summary>
        /// 隠れ層のサイズ
        /// </summary>
        private int _hiddenLayerSize = 32;

        /// <summary>
        /// デバイス(CPU or GPU)
        /// </summary>
        private Device _device = CPU;
        #endregion

        /// <summary>
        /// コンストラクタ
        /// </summary>
        /// <param name="inputSize">入力する画像のサイズ</param>
        /// <param name="outputSize">出力するベクトルのサイズ</param>
        public MLModel(int[] inputSize, int outputSize) : base("CNN")
        {
            //ダミーの入力データを作成(各層の次元の初期化に使用)
            Tensor dammyInput = zeros([inputSize[0], inputSize[1]]).unsqueeze(0).unsqueeze(0);
            //畳み込み層の初期化
            _conv1 = Conv2d(in_channels: 1, out_channels: 16, kernelSize: 8, stride: 2);
            _conv2 = Conv2d(in_channels: 16, out_channels: 16, kernelSize: 8, stride: 2);
            //ダミーの畳み込み層の出力
            Tensor dammyConvOutput = _conv1.forward(dammyInput); //畳み込み層1
            dammyConvOutput = _conv2.forward(dammyConvOutput); //畳み込み層2
            dammyConvOutput = flatten(dammyConvOutput, start_dim: 1); //平滑化
            //全結合層の初期化
            _linear1 = Linear(inputSize: dammyConvOutput.shape[1], outputSize: _hiddenLayerSize);
            _linear2 = Linear(inputSize: _hiddenLayerSize, outputSize: outputSize);

            //コンポーネントの登録
            RegisterComponents();

            //GPUを使用できるか
            if (torch.cuda.is_available()) _device = CUDA; //GPUを活用
            //デバイスに転送
            this.to(_device); 
        }

        /// <summary>
        /// 順伝播処理のオーバーライド
        /// </summary>
        /// <param name="input">入力データ</param>
        /// <returns></returns>
        public override Tensor forward(Tensor input)
        {
          //後述する
        }

        /// <summary>
        /// バッチ学習
        /// </summary>
        /// <param name="dataset">教師データのリスト</param>
        /// <param name="epochCount">エポック数</param>
        /// <param name="batchSize">バッチサイズ</param>
        public void TrainOnBatch(List<(Tensor input, Tensor output)> dataset, int epochCount, int batchSize)
        {
         //後述する
        }

        /// <summary>
        /// 推論
        /// </summary>
        /// <param name="input">単一の入力データ</param>
        /// <returns></returns>
        public (int index, float probability) Predict(Tensor input) 
        {
              //後述する
        }
    }
}

コンストラクタに加えて,順伝播の関数”forward”,バッチ学習を行う”TrainOnBatch”,学習後に推論を行うための”Predict”の3つの関数も定義しています.それぞれの関数は以下の通りです.

”forward”は親クラスの関数をoverrideしています.単純に各層の出力を順番に計算していくだけです.今回はクラス分類を行うので,出力層の活性化関数はSoftmaxとしております.

”TrainOnBatch”では学習データをバッチに分割して,まずvar predicted = this.forward(input);で順伝播します.その後教師データとの差分をvar error = loss.forward(predicted, output);で計算することで,計算グラフが構築されるので,error.backward();でグラフを辿って逆伝播します.

”Predict”は学習後に呼び出す推論用の関数です.Tensor化した画像がどのクラスに分類されるか予測します.戻り値は予測されるクラスのインデックスと確率です.

        /// <summary>
        /// 順伝播処理のオーバーライド
        /// </summary>
        /// <param name="input">入力データ</param>
        /// <returns></returns>
        public override Tensor forward(Tensor input)
        {
            //畳み込み
            var x = relu(_conv1.forward(input)); //活性化関数はReLUを使用
            x = relu(_conv2.forward(x));
            //平坦化
            x = torch.flatten(x, start_dim: 1);
            //全結合層
            x = relu(_linear1.forward(x));
            x = softmax(_linear2.forward(x), dim:1); //クラス分類のためSoftmax関数
            return x;
        }

        /// <summary>
        /// バッチ学習
        /// </summary>
        /// <param name="dataset">教師データのリスト</param>
        /// <param name="epochCount">エポック数</param>
        /// <param name="batchSize">バッチサイズ</param>
        public void TrainOnBatch(List<(Tensor input, Tensor output)> dataset, int epochCount, int batchSize)
        {
            //オプティマザの初期化
            var optimizer = optim.Adam(parameters: this.parameters(), lr: 0.001); //学習率を0.001に設定
            //損失関数
            var loss = CrossEntropyLoss();

            //エポック数だけ繰り返し
            for (int epoch = 0; epoch < epochCount; epoch++)
            {
                //バッチ取り出し
                var batcheArray = Utility.GetBatch(dataset, batchSize);

                //バッチの繰り返し
                for (int batch = 0; batch < batcheArray.Length; batch++)
                {
                    //入力データ
                    var input = batcheArray[batch].input.to(_device);
                    //出力データ
                    var output = batcheArray[batch].output.to(_device);
                    
                    //オプティマイザの勾配を初期化
                    optimizer.zero_grad();
                    //推論
                    var predicted = this.forward(input);
                    //残差
                    var error = loss.forward(predicted, output);
                    //逆伝播
                    error.backward();
                    optimizer.step();

                    Console.WriteLine(error.ToSingle());
                }

                //メモリ解放
                GC.Collect();
            }
        }

        /// <summary>
        /// 推論
        /// </summary>
        /// <param name="input">単一の入力データ</param>
        /// <returns></returns>
        public (int index, float probability) Predict(Tensor input) 
        {
            //順伝播
            Tensor output =  this.forward(input.unsqueeze(0).to(_device)).squeeze(0);
            //配列に変換
            float[] array = new float[output.shape[0]];
            for(int i = 0; i < array.Length; i++) array[i] = output[i].ToSingle();
            //最大値をとるインデックスを取得
            int maxIndex = Array.IndexOf(array, array.Max());
            return (maxIndex, array[maxIndex]);
        }

TorchSharpの”DataLoader”を使えばバッチ学習をよりスマートに実装できそうですが,今回はUtility.csにバッチ生成用のコードを実装しました.Utility.csには,Bitmap→Tensorの変換なども記載しておりますが,詳細はGitHubに公開したソースコードを参照してください.

4. GUIの作成

手書きの数字について,学習と推論の両方の機能を備えたアプリケーションを作成したいので,以下のようなGUIを作成しました.

まず学習のために”Train”ボタンを配置しました.このボタンをクリックするとフォルダ選択ダイアログが立ち上がるので,学習用の画像データセット(ソースコードに付属のTrainningDataフォルダ)を選択すると学習が開始します.

GUIの上部の白い正方形領域はInkCanvasで,ここにマウスでドラッグして数字を描き,”Predict”ボタンをクリックするとGUI上に推論結果が表示されます.また”Reset”ボタンでInkCanvasを白紙に戻せます.

学習中のLossの推移は,GUIと同時に立ち上がるコンソールウィンドウに表示するようにしました.

XAMLは以下の通り.

<Window x:Class="TorchSharpSupervisedLearning.MainWindow"
        xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"
        xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
        xmlns:d="http://schemas.microsoft.com/expression/blend/2008"
        xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
        xmlns:local="clr-namespace:TorchSharpSupervisedLearning"
        mc:Ignorable="d"
        Title="MainWindow" Height="500" Width="309"
        Background="Gray">
    <Grid>
        <StackPanel Orientation="Vertical">
            <Canvas Background="Transparent" Margin="10" Width="257" Height="257">
                <InkCanvas x:Name="cnvDrawingArea" Width="256" Height="256" Margin="0" Background="White"/>
            </Canvas>
            <TextBlock x:Name="txtPredicted" Text="null" Height="32" Width="250" Foreground="Yellow" TextAlignment="Center" FontSize="18" FontStyle="Normal"/>
            <Button x:Name="btnReset" Content="Reset" Height="32" Width="150" Margin="5" Click="btnReset_Click"/>
            <Button x:Name="btnPredict" Content="Predict" Height="32" Width="150" Margin="5" Click="btnPredict_Click"/>
            <Button x:Name="btnTrain" Content="Train" Height="32" Width="150" Margin="5" Click="btnTrain_Click"/>
        </StackPanel>
    </Grid>
</Window>

5. アプリケーションの実行

以下の動画のように操作することで,①手書き数字の学習,②Canvasのリセット,③手書き文字の推論を行えます.動画では,手書きの数字を精度よく推論できていることがわかります.このようにTorchSharpを用いることで,C#でも十分に機械学習が行えます.

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

CAPTCHA