基于 HOG 特征与 SVM 实现数字 0–6 的字符分类
数据准备：训练数据列表文件 train_data.txt（每两行为一组：第一行为图片路径，第二行为该图片的类别标签 0–6）：
D:/ocr/svm/train/imgs/0/0.png
0
D:/ocr/svm/train/imgs/0/0_1.jpg
0
D:/ocr/svm/train/imgs/1/1.png
1
D:/ocr/svm/train/imgs/1/1_1.jpg
1
D:/ocr/svm/train/imgs/1/1_2.jpg
1
D:/ocr/svm/train/imgs/1/1_3.jpg
1
D:/ocr/svm/train/imgs/2/2.png
2
D:/ocr/svm/train/imgs/2/2_1.jpg
2
D:/ocr/svm/train/imgs/2/2_2.jpg
2
D:/ocr/svm/train/imgs/3/3.png
3
D:/ocr/svm/train/imgs/3/3_1.jpg
3
D:/ocr/svm/train/imgs/3/3_2.jpg
3
D:/ocr/svm/train/imgs/4/4.png
4
D:/ocr/svm/train/imgs/4/4_1.jpg
4
D:/ocr/svm/train/imgs/4/4_2.jpg
4
D:/ocr/svm/train/imgs/5/5.png
5
D:/ocr/svm/train/imgs/5/5_1.jpg
5
D:/ocr/svm/train/imgs/5/5_2.jpg
5
D:/ocr/svm/train/imgs/6/6.png
6
D:/ocr/svm/train/imgs/6/6_1.jpg
6
D:/ocr/svm/train/imgs/6/6_2.jpg
6
数据处理及训练（读取列表、提取 HOG 特征、训练 SVM 并保存模型）：
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/core.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/ml/ml.hpp>
#include <opencv2/objdetect/objdetect.hpp>
void OcrTrain()
{
using namespace cv::ml;
using namespace std;
vector<string> imgpath; // path of train image
vector<int> imglabel; // label of train image
int nLine = 0;
string buf;
ifstream svm_data;
svm_data.open("d:/ocr/svm/train/train_data.txt", ios::in);
if (!svm_data.is_open())
{
cout << "read error" << endl;
exit(EXIT_FAILURE);
}
unsigned long n;
while (svm_data)
{
if (getline(svm_data, buf))
{
nLine++;
if (nLine % 2 == 0)
{
imglabel.push_back(atoi(buf.c_str()));
}
else
{
imgpath.push_back(buf);
}
}
}
svm_data.close(); //close file
/// <summary>
/// 训练
/// </summary>
Mat data_mat, res_mat;
int nImgNum = nLine / 2;
data_mat = Mat::zeros(nImgNum, 324, CV_32FC1); // store hog feature 324=9*4*9 single channel float
res_mat = Mat::zeros(nImgNum, 1, CV_32S); // store label 注意这里的数据类型32F不行
Mat src;
for (string::size_type i = 0; i != imgpath.size(); i++)
{
src = imread(imgpath[i], 1); // read train image
if (src.empty())
{
cout << "can not read the image: " << imgpath[i] << endl;
continue;
}
cout << "processing:" << endl;
Mat trainImg;
resize(src, trainImg, Size(28, 28)); // resize to 28*28
HOGDescriptor* hog = new HOGDescriptor(Size(28, 28), Size(14, 14), Size(7, 7), Size(7, 7), 9); // hog descriptor
vector<float> descriptors; // store result
hog->compute(trainImg, descriptors, Size(1, 1), Size(0, 0)); // compute hog descriptor
cout << "HOG dims:";
n = 0;
for (vector<float>::iterator iter = descriptors.begin(); iter != descriptors.end(); iter++)
{
data_mat.at<float>(i, n) = (*iter); // put hog descriptor into data_mat
n++;
}
res_mat.at<int>(i, 0) = imglabel[i]; // put label into res_mat
cout << "processing done:" << " " << endl;
}
Ptr<cv::ml::SVM> svm = SVM::create();//创建一个svm对象
svm->setType(cv::ml::SVM::C_SVC);
svm->setKernel(SVM::LINEAR);
svm->setDegree(0);
svm->setGamma(1);
svm->setCoef0(0);
svm->setC(1);
svm->setNu(0);
svm->setP(0);
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 1000, TermCriteria::EPS));//设置SVM训练时迭代终止条件 10的12次方
//训练
cout << "开始进行训练..." << endl;
Ptr<TrainData> tData = TrainData::create(data_mat, ROW_SAMPLE, res_mat);
//svm->train(tData); //这两行代码和下面一行代码等效
svm->train(data_mat, cv::ml::SampleTypes::ROW_SAMPLE, res_mat);
Mat resp;
float err = svm->calcError(tData, false, resp);
//CvSVM svm;
//CvSVMParams param;
//param = CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 0.09, 1.0, 10.0, 0.5, 1.0, NULL, criteria); // svm parameter
//svm->train(data_mat, res_mat, Mat(), Mat(), param); //train
svm->save("D:/ocr/svm/HOG_SVM_OCR.xml"); // preserve result
}
加载模型并预测（test_data.txt 中每行一个图片文件名，相对于 D:/ocr/svm/test/ 目录）：
Mat test;
char result[512];
vector<string> img_tst_path;
ifstream img_tst("D:/ocr/svm/test_data.txt");
string test_dir = "D:/ocr/svm/test/";
while (img_tst)
{
if (getline(img_tst, buf))
{
buf = test_dir + buf;
img_tst_path.push_back(buf);
}
}
img_tst.close();
// 预测阶段
Ptr<cv::ml::SVM> svmLoad = StatModel::load<SVM>("D:/ocr/svm/HOG_SVM_OCR.xml");
ofstream predict_txt("D:/ocr/svm/SVM_PREDICT1.txt");
for (string::size_type j = 0; j != img_tst_path.size(); j++)
{
test = imread(img_tst_path[j], 1);
if (test.empty())
{
cout << "can not load the image:" << endl;
continue;
}
Mat trainTempImg;
resize(test, trainTempImg, Size(28, 28));
HOGDescriptor* hog = new HOGDescriptor(Size(28, 28), Size(14, 14), Size(7, 7), Size(7, 7), 9);
vector<float> descriptors;
hog->compute(trainTempImg, descriptors, Size(1, 1), Size(0, 0));
cout << "HOG dims:" << endl;
Mat SVMtrainMat(1, descriptors.size(), CV_32FC1);
int n = 0;
for (vector<float>::iterator iter = descriptors.begin(); iter != descriptors.end(); iter++)
{
SVMtrainMat.at<float>(0, n) = (*iter);
n++;
}
int ret = svmLoad->predict(SVMtrainMat); // predict by svm
sprintf_s(result, "%s %d\r\n", img_tst_path[j], ret);
cout << img_tst_path[j]<<" " <<ret << endl;
predict_txt << result;//predict result
}
predict_txt.close();
结果输出：预测结果写入 D:/ocr/svm/SVM_PREDICT1.txt，每行格式为“图片路径 预测标签”。
标签：svm, mat, jpg, opencv, SVM, train, imgs, ocr　来源：https://www.cnblogs.com/hakula/p/17702189.html