C#使用TensorFlow.NET训练自己的数据集的方法

2020-03-22 20:01:38王冬梅

数据集说明

为了模型测试的训练速度考虑,图像数据集主要节选了一小部分的OCR字符(X、Y、Z),数据集的特征如下:

分类数量:3 classes 【X/Y/Z】

图像尺寸:Width 64 × Height 64

图像通道:1 channel(灰度图)

数据集数量:

train:X - 384pcs ; Y - 384pcs ; Z - 384pcs validation:X - 96pcs ; Y - 96pcs ; Z - 96pcs test:X - 96pcs ; Y - 96pcs ; Z - 96pcs

其它说明:数据集已经经过 随机 翻转/平移/缩放/镜像 等预处理进行增强

整体数据集情况如下图所示:

代码说明

环境设置

.NET 框架:使用.NET Framework 4.7.2及以上,或者使用.NET CORE 2.2及以上 CPU 配置: Any CPU 或 X64 皆可 GPU 配置:需要自行配置好CUDA和环境变量,建议 CUDA v10.1,Cudnn v7.5

类库和命名空间引用

从NuGet安装必要的依赖项,主要是SciSharp相关的类库,如下图所示:

注意事项:尽量安装最新版本的类库,CV须使用 SciSharp 的 SharpCV 方便内部变量传递

<PackageReference Include="Colorful.Console" Version="1.2.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="1.15.0" />
<PackageReference Include="SciSharp.TensorFlowHub" Version="0.0.5" />
<PackageReference Include="SharpCV" Version="0.2.0" />
<PackageReference Include="SharpZipLib" Version="1.2.0" />
<PackageReference Include="System.Drawing.Common" Version="4.7.0" />
<PackageReference Include="TensorFlow.NET" Version="0.14.0" />

引用命名空间,包括 NumSharp、Tensorflow 和 SharpCV ;

using NumSharp;
using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;
using SharpCV;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using Tensorflow;
using static Tensorflow.Binding;
using static SharpCV.Binding;
using System.Collections.Concurrent;
using System.Threading.Tasks;

主逻辑结构

主逻辑:

准备数据

创建计算图

训练

预测

public bool Run()
{
 PrepareData();
 BuildGraph();
​
 using (var sess = tf.Session())
 {
  Train(sess);
  Test(sess);
 }
​
 TestDataOutput();
 return accuracy_test > 0.98;
}

数据集载入

数据集下载和解压

数据集地址:https://github.com/SciSharp/SciSharp-Stack-Examples/blob/master/data/data_CnnInYourOwnData.zip