C#使用TensorFlow.NET训练自己的数据集的方法

2020-03-22 20:01:38 王冬梅

模型训练和模型保存

Batch数据集的读取,采用了 SharpCV 的cv2.imread,可以直接读取本地图像文件至NDArray,实现CV和Numpy的无缝对接;

使用.NET的异步线程安全队列BlockingCollection<T>,实现TensorFlow原生的队列管理器FIFOQueue;

在训练模型的时候,我们需要将样本从硬盘读取到内存之后,才能进行训练。我们在会话中运行多个线程,并加入队列管理器进行线程间的文件入队出队操作,并限制队列容量,主线程可以利用队列中的数据进行训练,另一个线程进行本地文件的IO读取,这样可以实现数据的读取和模型的训练是异步的,降低训练时间。

模型的保存,可以选择每轮训练都保存,或仅保存最佳训练模型。

#region Train
// Runs the full training loop: initializes variables, then for each epoch
// shuffles the training set, optionally decays the learning rate, streams
// mini-batches from disk on a background producer task through a bounded
// BlockingCollection, trains on each batch, validates, and saves a checkpoint
// (best-only or per-epoch). Finally writes the label dictionary to disk.
public void Train(Session sess)
{
 // Number of training iterations in each epoch
 var num_tr_iter = (ArrayLabel_Train.Length) / batch_size;

 var init = tf.global_variables_initializer();
 sess.run(init);

 // Keep at most the 10 most recent checkpoints on disk.
 var saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10);

 // NOTE(review): path_model has no trailing directory separator, so the
 // saver targets below (path_model + "CNN_Best", path_model + "dic.txt")
 // become prefixed sibling paths, not files inside the created directory —
 // confirm this is intentional.
 path_model = Name + "MODEL";
 Directory.CreateDirectory(path_model);

 float loss_val = 100.0f;
 float accuracy_val = 0f;

 // Stopwatch measures wall-clock time between progress printouts.
 var sw = new Stopwatch();
 sw.Start();
 foreach (var epoch in range(epochs))
 {
  print($"Training epoch: {epoch + 1}");
  // Randomly shuffle the training data at the beginning of each epoch 
  (ArrayFileName_Train, ArrayLabel_Train) = ShuffleArray(ArrayLabel_Train.Length, ArrayFileName_Train, ArrayLabel_Train);
  // One-hot encode the (re-shuffled) integer labels by indexing an identity matrix.
  y_train = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Train)];

  //decay learning rate
  if (learning_rate_step != 0)
  {
   if ((epoch != 0) && (epoch % learning_rate_step == 0))
   {
    // Multiplicative decay every learning_rate_step epochs, clamped at the floor.
    learning_rate_base = learning_rate_base * learning_rate_decay;
    if (learning_rate_base <= learning_rate_min) { learning_rate_base = learning_rate_min; }
    sess.run(tf.assign(learning_rate, learning_rate_base));
   }
  }

  //Load local images asynchronously,use queue,improve train efficiency
  // Bounded producer/consumer queue (capacity TrainQueueCapa): the background
  // task reads and preprocesses image batches from disk while the main thread
  // trains, overlapping I/O with computation.
  BlockingCollection<(NDArray c_x, NDArray c_y, int iter)> BlockC = new BlockingCollection<(NDArray C1, NDArray C2, int iter)>(TrainQueueCapa);
  Task.Run(() =>
     {
      foreach (var iteration in range(num_tr_iter))
      {
       var start = iteration * batch_size;
       var end = (iteration + 1) * batch_size;
       (NDArray x_batch, NDArray y_batch) = GetNextBatch(sess, ArrayFileName_Train, y_train, start, end);
       BlockC.Add((x_batch, y_batch, iteration));
      }
      // Signal the consumer that no more batches will arrive this epoch.
      BlockC.CompleteAdding();
     });

  // Consumer side: blocks until a batch is available, exits when producer completes.
  foreach (var item in BlockC.GetConsumingEnumerable())
  {
   sess.run(optimizer, (x, item.c_x), (y, item.c_y));

   if (item.iter % display_freq == 0)
   {
    // Calculate and display the batch loss and accuracy
    var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, item.c_x), new FeedItem(y, item.c_y));
    loss_val = result[0];
    accuracy_val = result[1];
    print("CNN:" + ($"iter {item.iter.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms"));
    sw.Restart();
   }
  }    

  // Run validation after every epoch
  // NOTE(review): "gloabl steps" below is a typo in the log text, kept as-is
  // because the identifier gloabl_steps is declared outside this block.
  (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_valid), (y, y_valid));
  print("CNN:" + "---------------------------------------------------------");
  print("CNN:" + $"gloabl steps: {sess.run(gloabl_steps) },learning rate: {sess.run(learning_rate)}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
  print("CNN:" + "---------------------------------------------------------");

  if (SaverBest)
  {
   // Checkpoint only when validation accuracy improves on the best seen so far.
   if (accuracy_val > max_accuracy)
   {
    max_accuracy = accuracy_val;
    saver.save(sess, path_model + "CNN_Best");
    print("CKPT Model is save.");
   }
  }
  else
  {
   // Checkpoint every epoch, encoding epoch/loss/accuracy in the file name.
   saver.save(sess, path_model + string.Format("CNN_Epoch_{0}_Loss_{1}_Acc_{2}", epoch, loss_val, accuracy_val));
   print("CKPT Model is save.");
  }
 }
 // Persist the label id -> name mapping alongside the model.
 Write_Dictionary(path_model + "dic.txt", Dict_Label);
}
// Persists the label dictionary as one "key,value" pair per line so it can be
// reloaded for inference. Overwrites any existing file at <paramref name="path"/>.
private void Write_Dictionary(string path, Dictionary<Int64, string> mydic)
{
 // BUG FIX: the original appended the literal text "rn" (the escape sequence
 // "\r\n" lost its backslashes), collapsing every entry onto one unparseable
 // line. WriteLine emits a real line terminator instead.
 // using-statements also guarantee the streams are closed/flushed even if a
 // write throws (the original leaked them on exception).
 using (FileStream fs = new FileStream(path, FileMode.Create))
 using (StreamWriter sw = new StreamWriter(fs))
 {
  foreach (var d in mydic) { sw.WriteLine(d.Key + "," + d.Value); }
 }
 print("Write_Dictionary");
}
// Shuffles x and y in unison using a single random permutation of the row
// indices, so sample i keeps its label after shuffling.
private (NDArray, NDArray) Randomize(NDArray x, NDArray y)
{
 // np.random.permutation already returns a fully shuffled index array; the
 // original's extra np.random.shuffle(perm) call was redundant and is removed.
 var perm = np.random.permutation(y.shape[0]);
 return (x[perm], y[perm]);
}
// Returns the [start, end) rows of both tensors as one mini-batch pair.
private (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
{
 var batchSlice = new Slice(start, end);
 return (x[batchSlice], y[batchSlice]);
}
// Builds an image mini-batch: loads each file in [start, end) from disk,
// pushes it through the graph's preprocessing op (normalized <- decodeJpeg),
// and pairs the stacked result with the matching slice of the label tensor.
private unsafe (NDArray, NDArray) GetNextBatch(Session sess, string[] x, NDArray y, int start, int end)
{
 int count = end - start;
 NDArray x_batch = np.zeros(count, img_h, img_w, n_channels);
 for (int offset = 0; offset < count; offset++)
 {
  // NOTE(review): despite the IMREAD_COLOR enum type, the flag requests
  // grayscale decoding — confirm this matches n_channels.
  NDArray image = cv2.imread(x[start + offset], IMREAD_COLOR.IMREAD_GRAYSCALE);
  x_batch[offset] = sess.run(normalized, (decodeJpeg, image));
 }
 NDArray y_batch = y[new Slice(start, end)];
 return (x_batch, y_batch);
}
#endregion