C#使用TensorFlow.NET训练自己的数据集的方法

2020-03-22 20:01:38 王冬梅

数据集下载和解压代码（部分封装的方法请参考 GitHub 上的完整代码）：

// Download the sample dataset archive into the working folder `Name` and unpack it there.
// NOTE: a GitHub ".../blob/..." URL returns the HTML viewer page, not the file itself;
// ".../raw/..." returns the actual zip bytes.
string url = "https://github.com/SciSharp/SciSharp-Stack-Examples/raw/master/data/data_CnnInYourOwnData.zip";
Directory.CreateDirectory(Name);
Utility.Web.Download(url, Name, "data_CnnInYourOwnData.zip");
// Path.Combine inserts the platform's directory separator between folder and file name.
Utility.Compress.UnZip(Path.Combine(Name, "data_CnnInYourOwnData.zip"), Name);

字典创建

读取目录下的子文件夹名称，作为分类标签字典，方便后面进行 One-hot 编码。

/// <summary>
/// Builds the class-label dictionary from the immediate sub-folders of <paramref name="DirPath"/>.
/// Each sub-folder name becomes one class, keyed by its 0-based index, so that the index can
/// later be used for one-hot encoding.
/// </summary>
/// <param name="DirPath">Directory whose top-level sub-folders are the class names.</param>
/// <remarks>
/// <c>Dict_Label</c> and <c>n_classes</c> are assigned only when at least one sub-folder exists;
/// otherwise both are left untouched.
/// </remarks>
private void FillDictionaryLabel(string DirPath)
{
    string[] str_dir = Directory.GetDirectories(DirPath, "*", SearchOption.TopDirectoryOnly);
    if (str_dir.Length == 0)
    {
        return;
    }

    // Directory.GetDirectories does not guarantee any ordering; sort so the
    // label -> index mapping is identical on every run and every platform.
    Array.Sort(str_dir, StringComparer.Ordinal);

    Dict_Label = new Dictionary<Int64, string>();
    for (int i = 0; i < str_dir.Length; i++)
    {
        // The last path segment is the folder name; Path.GetFileName uses the
        // platform's separator instead of splitting on a hard-coded character.
        string label = Path.GetFileName(str_dir[i]);
        Dict_Label.Add(i, label);
        print(i.ToString() + " : " + label);
    }
    n_classes = Dict_Label.Count;
}

文件List读取和打乱

从文件夹中分别读取 train、validation、test 的文件列表，并随机打乱顺序。

读取目录

// Collect every file below each split folder and derive its class index from the folder it sits in.
// Path.Combine inserts the directory separator between `Name` and the split folder name.
ArrayFileName_Train = Directory.GetFiles(Path.Combine(Name, "train"), "*.*", SearchOption.AllDirectories);
ArrayLabel_Train = GetLabelArray(ArrayFileName_Train);

ArrayFileName_Validation = Directory.GetFiles(Path.Combine(Name, "validation"), "*.*", SearchOption.AllDirectories);
ArrayLabel_Validation = GetLabelArray(ArrayFileName_Validation);

ArrayFileName_Test = Directory.GetFiles(Path.Combine(Name, "test"), "*.*", SearchOption.AllDirectories);
ArrayLabel_Test = GetLabelArray(ArrayFileName_Test);

获得标签

/// <summary>
/// Maps each file path to the numeric class index of the folder that directly contains it.
/// </summary>
/// <param name="FilesArray">File paths of the form <c>.../&lt;label&gt;/&lt;file&gt;</c>.</param>
/// <returns>One class index per input path, in the same order as <paramref name="FilesArray"/>.</returns>
/// <exception cref="InvalidOperationException">
/// A file's parent folder name is not a value in <c>Dict_Label</c>.
/// </exception>
private Int64[] GetLabelArray(string[] FilesArray)
{
    // Invert Dict_Label once (label -> index) instead of scanning it for every file.
    Dictionary<string, Int64> indexByLabel = Dict_Label.ToDictionary(kv => kv.Value, kv => kv.Key);

    Int64[] ArrayLabel = new Int64[FilesArray.Length];
    for (int i = 0; i < ArrayLabel.Length; i++)
    {
        // The label is the name of the file's parent directory; Path helpers handle
        // the platform separator instead of splitting on a hard-coded character.
        string label = Path.GetFileName(Path.GetDirectoryName(FilesArray[i])) ?? string.Empty;
        if (!indexByLabel.TryGetValue(label, out Int64 index))
        {
            throw new InvalidOperationException(
                $"Unknown class label '{label}' for file '{FilesArray[i]}'.");
        }
        ArrayLabel[i] = index;
    }
    return ArrayLabel;
}

随机乱序

/// <summary>
/// Returns a random permutation of the first <paramref name="count"/> image/label pairs,
/// keeping every image aligned with its label.
/// </summary>
/// <param name="count">Number of leading pairs to shuffle; must not exceed either array's length.</param>
/// <param name="images">Image file paths.</param>
/// <param name="labels">Class indices aligned element-by-element with <paramref name="images"/>.</param>
/// <param name="rng">
/// Optional random source (pass a seeded instance for reproducible runs);
/// a new <see cref="Random"/> is created when null.
/// </param>
/// <returns>New arrays of length <paramref name="count"/>; the input arrays are not modified.</returns>
/// <exception cref="ArgumentNullException"><paramref name="images"/> or <paramref name="labels"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="count"/> is negative or larger than either input array.
/// </exception>
public (string[], Int64[]) ShuffleArray(int count, string[] images, Int64[] labels, Random rng = null)
{
    if (images == null)
    {
        throw new ArgumentNullException(nameof(images));
    }
    if (labels == null)
    {
        throw new ArgumentNullException(nameof(labels));
    }
    if (count < 0 || count > images.Length || count > labels.Length)
    {
        throw new ArgumentOutOfRangeException(nameof(count), count,
            "count must be between 0 and the length of both images and labels.");
    }

    Random r = rng ?? new Random();

    // Fisher-Yates shuffle of the indices [0, count): O(count) with no boxing, and the
    // same uniform distribution as repeatedly drawing from a shrinking list.
    int[] order = new int[count];
    for (int i = 0; i < count; i++)
    {
        order[i] = i;
    }
    for (int i = count - 1; i > 0; i--)
    {
        int j = r.Next(i + 1);
        int tmp = order[i];
        order[i] = order[j];
        order[j] = tmp;
    }

    string[] new_images = new string[count];
    Int64[] new_labels = new Int64[count];
    for (int i = 0; i < count; i++)
    {
        new_images[i] = images[order[i]];
        new_labels[i] = labels[order[i]];
    }
    print("shuffle array list: " + count.ToString());
    return (new_images, new_labels);
}

部分数据集预先载入

Validation/Test数据集和标签一次性预先载入成NDArray格式。

/// <summary>
/// Eagerly loads the validation and test splits into memory. Labels become one-hot
/// NDArrays, and images are decoded into pre-allocated, zero-initialised NDArrays.
/// </summary>
private void LoadImagesToNDArray()
{
    // Row k of an identity matrix is the one-hot vector for class k, so indexing the
    // identity matrix with the label indices yields the one-hot label matrix.
    NDArray OneHot(Int64[] classIds) => np.eye(Dict_Label.Count)[new NDArray(classIds)];

    // Zero-filled image buffer of shape (imageCount, img_h, img_w, n_channels).
    NDArray Allocate(int imageCount) => np.zeros(imageCount, img_h, img_w, n_channels);

    y_valid = OneHot(ArrayLabel_Validation);
    y_test = OneHot(ArrayLabel_Test);
    print("Load Labels To NDArray : OK!");

    x_valid = Allocate(ArrayFileName_Validation.Length);
    x_test = Allocate(ArrayFileName_Test.Length);
    LoadImage(ArrayFileName_Validation, x_valid, "validation");
    LoadImage(ArrayFileName_Test, x_test, "test");
    print("Load Images To NDArray : OK!");
}
/// <summary>
/// Decodes every file in <paramref name="a"/> and stores it in the matching slot of <paramref name="b"/>.
/// </summary>
/// <param name="a">Image file paths; element i is written to <c>b[i]</c>.</param>
/// <param name="b">Pre-allocated destination array with at least <c>a.Length</c> slots.</param>
/// <param name="c">Split name printed once loading finishes.</param>
private void LoadImage(string[] a, NDArray b, string c)
{
    int slot = 0;
    foreach (string path in a)
    {
        b[slot] = ReadTensorFromImageFile(path);
        slot++;
        // One dot per decoded image as a simple progress indicator.
        Console.Write(".");
    }
    Console.WriteLine();
    Console.WriteLine("Load Images To NDArray: " + c);
}
/// <summary>
/// Reads one JPEG file and returns it as a float32 NDArray with a leading batch axis of size 1.
/// The image is bilinearly resized to img_h x img_w, decoded with n_channels channels, and
/// normalised as (pixel - img_mean) / img_std.
/// </summary>
/// <param name="file_name">Path of the image file to decode.</param>
/// <remarks>Builds and runs a throw-away graph and session on every call.</remarks>
private NDArray ReadTensorFromImageFile(string file_name)
{
    using (var g = tf.Graph().as_default())
    {
        // Pipeline: read bytes -> decode JPEG -> float32 -> add batch axis -> resize -> (x - mean) / std.
        var raw = tf.read_file(file_name, "file_reader");
        var pixels = tf.image.decode_jpeg(raw, channels: n_channels, name: "DecodeJpeg");
        var asFloat = tf.cast(pixels, tf.float32);
        var batched = tf.expand_dims(asFloat, 0);
        var targetSize = tf.constant(new int[] { img_h, img_w });
        var resized = tf.image.resize_bilinear(batched, targetSize);
        var centered = tf.subtract(resized, new float[] { img_mean });
        var scaled = tf.divide(centered, new float[] { img_std });

        using (var session = tf.Session(g))
        {
            return session.run(scaled);
        }
    }
}