模型训练和模型保存
批次（Batch）数据的读取采用 SharpCV 的 cv2.imread，可将本地图像文件直接读取为 NDArray，实现 OpenCV 与 NumPy 风格数组（NumSharp）的无缝对接；
使用 .NET 的线程安全阻塞集合 BlockingCollection<T>，实现与 TensorFlow 原生队列管理器 FIFOQueue 类似的先进先出队列；
训练模型时，需要先把样本从硬盘读入内存才能进行训练。为此，我们在会话中运行多个线程，借助队列在线程间传递读取好的批次数据（入队/出队），并限制队列容量：主线程从队列中取出数据进行训练，另一个线程负责从本地文件读取图像（IO）。这样数据读取与模型训练可以异步并行进行，从而缩短训练时间。
模型保存：可以选择每轮（epoch）训练后都保存一次，或只保存验证准确率最高的模型。
#region Train
/// <summary>
/// Initializes all global variables, then trains for <c>epochs</c> epochs. Each epoch it:
/// reshuffles the training set, optionally decays the learning rate, runs one pass of
/// optimizer steps, evaluates on the validation set, and saves a checkpoint.
/// Finally the label dictionary is written next to the checkpoints.
/// </summary>
/// <param name="sess">Session whose graph contains the model; its global variables are (re)initialized here.</param>
public void Train(Session sess)
{
    // Number of training iterations in each epoch; a trailing partial batch is dropped.
    var num_tr_iter = (ArrayLabel_Train.Length) / batch_size;
    var init = tf.global_variables_initializer();
    sess.run(init);
    var saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10);
    // NOTE(review): path_model is created as a directory below, but the checkpoint and dictionary
    // paths append names without a separator (path_model + "CNN_Best"), so those files land beside
    // that directory rather than inside it. Left unchanged because a loader elsewhere must agree.
    path_model = Name + "MODEL";
    Directory.CreateDirectory(path_model);
    // Measures time between progress prints; keeps running across epochs.
    var sw = new Stopwatch();
    sw.Start();
    foreach (var epoch in range(epochs))
    {
        print($"Training epoch: {epoch + 1}");
        // Randomly shuffle the training data at the beginning of each epoch
        (ArrayFileName_Train, ArrayLabel_Train) = ShuffleArray(ArrayLabel_Train.Length, ArrayFileName_Train, ArrayLabel_Train);
        // One-hot labels in the new (shuffled) order.
        y_train = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Train)];

        // Step decay: every learning_rate_step epochs (never at epoch 0), multiply by the decay, clamped at the minimum.
        if (learning_rate_step != 0 && epoch != 0 && epoch % learning_rate_step == 0)
        {
            learning_rate_base = learning_rate_base * learning_rate_decay;
            if (learning_rate_base <= learning_rate_min) { learning_rate_base = learning_rate_min; }
            sess.run(tf.assign(learning_rate, learning_rate_base));
        }

        TrainOneEpoch(sess, num_tr_iter, sw);

        // Run validation after every epoch
        float loss_val, accuracy_val;
        (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_valid), (y, y_valid));
        print("CNN:" + "---------------------------------------------------------");
        print("CNN:" + $"gloabl steps: {sess.run(gloabl_steps) },learning rate: {sess.run(learning_rate)}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
        print("CNN:" + "---------------------------------------------------------");

        if (SaverBest)
        {
            // Keep only the checkpoint with the highest validation accuracy seen so far.
            if (accuracy_val > max_accuracy)
            {
                max_accuracy = accuracy_val;
                saver.save(sess, path_model + "CNN_Best");
                print("CKPT Model is save.");
            }
        }
        else
        {
            saver.save(sess, path_model + string.Format("CNN_Epoch_{0}_Loss_{1}_Acc_{2}", epoch, loss_val, accuracy_val));
            print("CKPT Model is save.");
        }
    }
    Write_Dictionary(path_model + "dic.txt", Dict_Label);
}

/// <summary>
/// Runs <paramref name="num_tr_iter"/> optimizer steps over the current (shuffled) training set.
/// A background task loads image batches from disk into a bounded queue (capacity
/// <c>TrainQueueCapa</c>) while the calling thread consumes them, so IO overlaps with training.
/// Exceptions from either side are propagated to the caller instead of hanging the loop.
/// </summary>
/// <param name="sess">Session used for both batch preprocessing and optimizer steps.</param>
/// <param name="num_tr_iter">Number of full batches to train on.</param>
/// <param name="sw">Stopwatch reported (and restarted) at every progress print.</param>
private void TrainOneEpoch(Session sess, int num_tr_iter, Stopwatch sw)
{
    using (var cts = new System.Threading.CancellationTokenSource())
    using (var BlockC = new BlockingCollection<(NDArray c_x, NDArray c_y, int iter)>(TrainQueueCapa))
    {
        // Producer: load local images asynchronously and enqueue them in iteration order.
        var producer = Task.Run(() =>
        {
            try
            {
                foreach (var iteration in range(num_tr_iter))
                {
                    var start = iteration * batch_size;
                    var end = (iteration + 1) * batch_size;
                    (NDArray x_batch, NDArray y_batch) = GetNextBatch(sess, ArrayFileName_Train, y_train, start, end);
                    // The token lets the consumer unblock us if it fails while the queue is full.
                    BlockC.Add((x_batch, y_batch, iteration), cts.Token);
                }
            }
            finally
            {
                // Always release the consumer, even if loading a batch failed;
                // otherwise GetConsumingEnumerable would block forever.
                BlockC.CompleteAdding();
            }
        });

        try
        {
            foreach (var item in BlockC.GetConsumingEnumerable())
            {
                sess.run(optimizer, (x, item.c_x), (y, item.c_y));
                if (item.iter % display_freq == 0)
                {
                    // Calculate and display the batch loss and accuracy
                    var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, item.c_x), new FeedItem(y, item.c_y));
                    float loss_val = result[0];
                    float accuracy_val = result[1];
                    print("CNN:" + ($"iter {item.iter.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms"));
                    sw.Restart();
                }
            }
        }
        catch
        {
            // Training failed: stop the producer and wait for it before the queue is disposed.
            // Its own (cancellation) failure is secondary to the exception being rethrown here.
            cts.Cancel();
            try { producer.Wait(); } catch (AggregateException) { }
            throw;
        }

        // Surface any producer failure (e.g. an unreadable image) on the training thread.
        producer.GetAwaiter().GetResult();
    }
}
/// <summary>
/// Writes the label dictionary to <paramref name="path"/> as one "key,value" entry per line
/// (CRLF-terminated), overwriting any existing file.
/// </summary>
/// <param name="path">Destination file path.</param>
/// <param name="mydic">Label id to label name mapping.</param>
private void Write_Dictionary(string path, Dictionary<Int64, string> mydic)
{
    // using guarantees the file handle is released even if a write throws.
    using (var fs = new FileStream(path, FileMode.Create))
    using (var sw = new StreamWriter(fs))
    {
        // The original literal "rn" was a stripped "\r\n": it glued all entries onto one line.
        foreach (var d in mydic) { sw.Write(d.Key + "," + d.Value + "\r\n"); }
    }
    print("Write_Dictionary");
}
/// <summary>
/// Returns copies of <paramref name="x"/> and <paramref name="y"/> reordered along their first
/// axis by the same random permutation, so sample/label pairs stay aligned.
/// </summary>
private (NDArray, NDArray) Randomize(NDArray x, NDArray y)
{
    var sampleCount = y.shape[0];
    var order = np.random.permutation(sampleCount);
    // NOTE(review): permutation already yields a random order, so this extra shuffle is redundant;
    // kept so the random-number stream consumed stays the same.
    np.random.shuffle(order);
    var shuffledX = x[order];
    var shuffledY = y[order];
    return (shuffledX, shuffledY);
}
/// <summary>
/// Slices rows [<paramref name="start"/>, <paramref name="end"/>) out of both arrays,
/// using one shared slice so inputs and labels stay aligned.
/// </summary>
private (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
{
    var rows = new Slice(start, end);
    return (x[rows], y[rows]);
}
/// <summary>
/// Loads the image files x[start..end) as grayscale, runs each through the graph's
/// <c>normalized</c> tensor (fed via the <c>decodeJpeg</c> placeholder), and returns the stacked
/// batch together with rows [start, end) of <paramref name="y"/>.
/// </summary>
/// <param name="sess">Session used to run the preprocessing subgraph.</param>
/// <param name="x">Image file paths.</param>
/// <param name="y">Labels aligned with <paramref name="x"/>.</param>
/// <param name="start">First index (inclusive).</param>
/// <param name="end">Last index (exclusive).</param>
/// <exception cref="FileNotFoundException">An image file in the range does not exist.</exception>
private (NDArray, NDArray) GetNextBatch(Session sess, string[] x, NDArray y, int start, int end)
{
    NDArray x_batch = np.zeros(end - start, img_h, img_w, n_channels);
    for (int n = 0; n < end - start; n++)
    {
        string imagePath = x[start + n];
        // OpenCV's imread signals a missing file by returning an empty image rather than throwing
        // (presumably SharpCV behaves the same); fail here with the path instead of inside sess.run.
        if (!File.Exists(imagePath))
        {
            throw new FileNotFoundException("Training image not found.", imagePath);
        }
        NDArray img4 = cv2.imread(imagePath, IMREAD_COLOR.IMREAD_GRAYSCALE);
        x_batch[n] = sess.run(normalized, (decodeJpeg, img4));
    }
    var slice = new Slice(start, end);
    var y_batch = y[slice];
    return (x_batch, y_batch);
}
#endregion










