首页 > 其他分享 >7、caffe之train函数片段之初始网络开始

7、caffe之train函数片段之初始网络开始

时间:2022-09-28 11:33:47浏览次数:61  
标签:片段 solver param train FLAGS gpus caffe type


当运行下列命令的时候

ubuntu@ubuntu:~/caffe$ ./examples/mnist/train_lenet.sh 

这是脚本train_lenet.sh 命令行(如果只有cpu 需要修改这个文件的lenet_solver.prototxt,选择 cpu)

#!/usr/bin/env sh
set -e

./build/tools/caffe train --solver=examples/mnist/lenet_solver.prototxt $@

下面是caffe.cpp里面的函数段

// Train / Finetune a model.
int train() {
CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train.";
CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size())
<< "Give a snapshot to resume training or weights to finetune "
"but not both.";
vector<string> stages = get_stages_from_flags();

caffe::SolverParameter solver_param;
caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);

solver_param.mutable_train_state()->set_level(FLAGS_level);
for (int i = 0; i < stages.size(); i++) {
solver_param.mutable_train_state()->add_stage(stages[i]);
}

// If the gpus flag is not provided, allow the mode and device to be set
// in the solver prototxt.
if (FLAGS_gpu.size() == 0
&& solver_param.has_solver_mode()
&& solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) {
if (solver_param.has_device_id()) {
FLAGS_gpu = "" +
boost::lexical_cast<string>(solver_param.device_id());
} else { // Set default GPU if unspecified
FLAGS_gpu = "" + boost::lexical_cast<string>(0);
}
}

vector<int> gpus;
get_gpus(&gpus);
if (gpus.size() == 0) {
LOG(INFO) << "Use CPU.";
Caffe::set_mode(Caffe::CPU);
} else {
ostringstream s;
for (int i = 0; i < gpus.size(); ++i) {
s << (i ? ", " : "") << gpus[i];
}
LOG(INFO) << "Using GPUs " << s.str();
#ifndef CPU_ONLY
cudaDeviceProp device_prop;
for (int i = 0; i < gpus.size(); ++i) {
cudaGetDeviceProperties(&device_prop, gpus[i]);
LOG(INFO) << "GPU " << gpus[i] << ": " << device_prop.name;
}
#endif
solver_param.set_device_id(gpus[0]);
Caffe::SetDevice(gpus[0]);
Caffe::set_mode(Caffe::GPU);
Caffe::set_solver_count(gpus.size());
}

caffe::SignalHandler signal_handler(
GetRequestedAction(FLAGS_sigint_effect),
GetRequestedAction(FLAGS_sighup_effect));

shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

solver->SetActionFunction(signal_handler.GetActionFunction());

if (FLAGS_snapshot.size()) {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
solver->Restore(FLAGS_snapshot.c_str());
} else if (FLAGS_weights.size()) {
CopyLayers(solver.get(), FLAGS_weights);
}

LOG(INFO) << "Starting Optimization";
if (gpus.size() > 1) {
#ifdef USE_NCCL
caffe::NCCL<float> nccl(solver);
nccl.Run(gpus, FLAGS_snapshot.size() > 0 ? FLAGS_snapshot.c_str() : NULL);
#else
LOG(FATAL) << "Multi-GPU execution not available - rebuild with USE_NCCL";
#endif
} else {
solver->Solve();
}
LOG(INFO) << "Optimization Done.";
return 0;
}
RegisterBrewFunction(train);
--solver=examples/mnist/lenet_solver.prototxt 文件里面的参数

代码中如FLAGS_solver等都是命令解析的参数gflags的参数使用
关于train test time 注册 前面已经写过demo 不在详细解释

vector< string> stages = get_stages_from_flags();

这函数主要是针对命令行多个参数进行统计(详细参见本博客c++基本知识讲解)

 caffe::SolverParameter solver_param;

这行代码表示使用了probuff解析了caffe.proto文件,然后使用的变量在caffe.pb.cc文件里面定义了

caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);

函数的原定义在/caffe/src/caffe/util/update_proto.cpp中,这里主要是把命令行中​​--solver=examples/mnist/lenet_solver.prototxt​​​的lenet_solver.prototxt里面的参数读到caffe.proto变量中;
过程:FLAGS_solver 存放着命令行solver=“路径”;然后调用util/io.hpp文件中的函数进行ReadProtoFromTextFile(…)读操作 ;之后又调用io.cpp文件中的ReadProtoFromTextFile(…)函数,并写入了caffe.proto生成的.cc文件中;

shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

进入了caffe/include/solver_factory.hpp文件,进行”注册
简单总结如下:
(1) 在caffe中,不同的配置文件中存在不同的梯度下降算法;以minist 手写为例,因为其lenet_solver.prototxt 配置文件中并未指示明确,但是在caffe.proto中,已经默认设置梯度下降算法SGD(其它可参考不同的配置文件:如lenet_adadelta_solver.prototxt;lenet_solver_adam.prototxt)
SolverRegistry::CreateSolver(solver_param)。
(2)然后匹配到map中注册过字符串和函数指针,匹配到了 SGDSolver;
(3)执行new SGDSolver(solver_param)创建solver。
(4) 又因为 new 的新对象 SGDSolver ,需要构造自身的构造函数,因为其是子类,因此需要先初始化父类solver.cpp的构造函数;所以执行了slover.cpp 构造函数,进行了初始化网络
具体流程:
程序启动之前:以sgd函数为例 其它函数类似(AdaDelta,RMSProp,…)
执行sgd_solver.cpp 宏定义REGISTER_SOLVER_CLASS(SGD);
然后执行solver_factory.hpp 宏展开#define REGISTER_SOLVER_CLASS(type) 的 REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)的各类型梯度函数注册,第二个参数为

Solver* Creator_##type##Solver( 
const SolverParameter& param)
{ LOG(INFO) << "000000 solver type: " ;
return new type##Solver(param);
}
然后定义触发
#define REGISTER_SOLVER_CREATOR(type, creator)
static SolverRegisterer g_creator_f_##type(#type, creator);
static SolverRegisterer g_creator_d_##type(#type, creator)
在然后进行注册:
SolverRegistry::AddCreator(type, creator);
完成注册之后;
caffe.cpp 的
shared_ptr<caffe::Solver >
solver(caffe::SolverRegistry::CreateSolver(solver_param));
会进入solver_factory.hpp 执行registry[type](param);进行函数回调(map);
将具体执行
Solver* Creator_##type##Solver(
const SolverParameter& param)
{ LOG(INFO) << "000000 solver type: " ;
return new type##Solver(param);
}


子类继承父类构造 --初始化网络

   Solver<Dtype>::Solver(const SolverParameter& param)
: net_(), callbacks_(), requested_early_exit_(false) {
Init(param);
}

可以进入ubuntu系统的文件/tmp目录,会在运行一次产生caffe.ubuntu.ubuntu.log.INFO.20170627-141243.4440 日志信息,可以跟踪其打印的日志信息跟踪代码的阅读
下面是截取一小段glogs日志信息

Log file created at: 2017/06/27 14:12:43
Running on machine: ubuntu
Log line format: [IWEF]mmdd hh:mm:ss.uuuuuu threadid file:line] msg
I0627 14:12:43.187585 4440 caffe.cpp:211] Use CPU.
I0627 14:12:43.188031 4440 solver.cpp:44] Initializing solver from parameters:
test_iter: 100
test_interval: 500
base_lr: 0.01
display: 100
max_iter: 10000
lr_policy: "inv"
gamma: 0.0001
power: 0.75
momentum: 0.9
weight_decay: 0.0005
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
solver_mode: CPU
net: "examples/mnist/lenet_train_test.prototxt"
train_state {
level: 0
stage: ""
}
I0627 14:12:43.188195 4440 solver.cpp:87] Creating training net from net file: examples/mnist/lenet_train_test.prototxt
I0627 14:12:43.188627 4440 net.cpp:294] The NetState phase (0) differed from the phase (1) specified by a rule in layer mnist
I0627 14:12:43.188666 4440 net.cpp:294] The NetState phase (0) differed from the phase (1) specified by a rule in layer accuracy
I0627 14:12:43.188832 4440 net.cpp:51] Initializing net from parameters:
name: "LeNet"
state {
phase: TRAIN
level: 0
stage: ""
}
layer {


标签:片段,solver,param,train,FLAGS,gpus,caffe,type
From: https://blog.51cto.com/u_12504263/5719083

相关文章