首页 > trainer setup_Detectron2源码阅读笔记-(一)Configamp;Trainer

trainer setup_Detectron2源码阅读笔记-(一)Configamp;Trainer

一、代码结构概览

1.核心部分

  • configs:储存各种网络的yaml配置文件
  • datasets:存放数据集的地方
  • detectron2:运行代码的核心组件
  • tools:提供了运行代码的入口以及一切可视化的代码文件。

2.Tutorial部分

  • demo:显而易见就是demo
  • docs: 同样显而易见。。
  • tests:提供了一些测试代码
  • projects:提供了真实的项目代码示例,之后自己的代码结构可参照这个结构写。

二、代码逻辑分析

1.超参数配置

进入tools/train_net.pymain函数,第一行cfg = setup(args)是配置参数。Detectron2中的参数配置使用了yacs这个库,这个库能够很好地重用和拼接超参数文件配置。

我们先看一下detrctron2/config/的文件结构:

  • compat.py: 应该是对之前的Detectron库的兼容吧,可忽略。
  • config.py: 定义了一个CfgNode类,这个类继承自fvcore库(fb写的一个共公共库,提供一些共享的函数,方便各种不同项目使用)中定义的CfgNode,总之就是不断继承。。。继承关系是这样的: detrctron2.config.CfgNode->fcvore.common.config.CfgNode->yacs.config.CfgNode->dict 另外该文件还提供了get_cfg()方法,该方法会返回一个含有默认配置的CfgNode,而这些默认的配置值在下面的default.py中定义了,之所以这样做是因为要配置的默认值太多了,所以为了文档清晰才写到了一个新的文件中去,不过,yacs库的作者也建议这样做。
  • default.py: 如上面所说,该文件定义了各种参数的默认值。

了解配置函数的方法后我们再回到tools/train_net.py,我们一行一行的来理解。

  • tools/train_net.py
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
...def setup(args):"""Create configs and perform basic setups."""cfg = get_cfg() cfg.merge_from_file(args.config_file) cfg.merge_from_list(args.opts)cfg.freeze()default_setup(cfg, args)return cfg
  • cfg = get_cfg(): 获取已经配置好默认参数的cfg
  • cfg.merge_from_file(args.config_file):config_file是指定的yaml配置文件,通过merge_from_file这个函数会将yaml文件中指定的超参数对默认值进行覆盖。
  • cfg.merge_from_list(args.opts):merge_from_list作用同上面的类似,只不过是通过命令行的方式覆盖。 例如
opts = ["SYSTEM.NUM_GPUS", 8, "TRAIN.SCALES", "(1, 2, 3, 4)"]
cfg.merge_from_list(opts)
print("cfgn",cfg)

那么最后会有

cfg
... (一些默认值超参数)
SYSTEM:NUM_GPUS: 8
TRAIN:SCALES: (1,2,3,4)
  • cfg.freeze(): freeze函数的作用是将超参数值冻结,避免被程序不小心修改。
  • default_setup(cfg, args):default_setupdetectron2/engine/default.py中提供的一个默认配置函数,具体是怎么配置的这里不详细说明了。不过需要知道的值这个文件中还提供了很多其他的配置函数,例如还提供了两个类:DefaultPredictorDefaultTrainer

2.Trainer

既然上面提到了DefaultTrainer,那么我们就从这个类入手了解一下detectron2.engine,其代码结构如下:

  • train_loop.py: 这个函数主要作用是提供了三个重要的类:
    • HookBase: 这是一个Hook的基类,用于指定在训练前后或者每一个step前后需要做什么事情,所以根据特定的需求需要对如下四种方法做不同的定义:before_train,after_train,before_step,after_step。以before_step
    • TrainerBase: 该类中定义的函数可以归纳成三种:
      • register_hooks:这个很好理解,就是将用户定义的一些hooks进行注册,说大白话就是把若干个Hook放在一个list里面去。之后只需要遍历这个list依次执行就可以了。
      • 第二类其实就是上面提到的遍历hook list并执行hook,不过这个遍历有四种,分别是before_train,after_train,before_step,after_step。还有一个就是run_step,这个函数其实就是平常我们在编写训练过程的代码,例如读数据,训练模型,获取损失值,求导数,反向梯度更新等,只不过在这个类里面没有定义。
      • 第三类就是train函数,它有两个参数,分别是开始的迭代数和最大的迭代数。之后就是重复依次执行第二类中的函数指定迭代次数。
    • SimpleTrainer:其实就是继承自TrainerBase,然后定义了run_step等方法。我们后面也可以继承这个类做进一步的自定义。
  • defaults.py: 上面已介绍,提供了两个类:DefaultPredictorDefaultTrainer,这个DefaultTrainer就继承自SimpleTrainer,所以存在如下继承关系: detectron2.engine.default.DefaultTrainer->detectron2.engine.train_loop.SimpleTrainer->detectron2.engine.train_loop.TrainerBase
  • hooks.py:定义了很多继承自train_loop.HookBase的Hook。
  • launch.py: 前面提到过,可以理解成代码启动器,可以根据命令决定是否采用分布式训练(或者单机多卡)或者单机单卡训练。

好了,我们继续回到tools/train_net.py的main函数,代码如下所示。

def main(args):cfg = setup(args)if args.eval_only:...trainer = Trainer(cfg)trainer.resume_or_load(resume=args.resume)if cfg.TEST.AUG.ENABLED:trainer.register_hooks([hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))])return trainer.train()

可以看到下面定义了一个Trainer,它继承自detectron2.engine.default.DefaultTrainer,这个父类会自动解析cfg。之后只需要调用trainer.train()就可以开始训练了。

三、小结

v2-acc028de4e2fdaaee59a2e6723422fd6_b.jpg

至此我们对detectron2的逻辑有了大致的了解了,那么接下来我们来了解一下detectron2.engine.default.DefaultTrainer是如何解析cfg的,这部分内容请参见

Detectron2代码阅读笔记-(二) - marsggbo - 博客园​www.cnblogs.com
v2-a43c03c054b5cfbdd69037b3f710cc31_180x120.jpg

MARSGGBO♥原创

微信公众号: 【AutoML机器学习】

v2-3454a181dd5eadde86b29fe7c44f1fa2_b.jpg
AutoML机器学习
2019-10-15 10:37:50

更多相关:

  •     我刚刚接手这个项目的时候就被一系列不知所措的文件命名给深深的震惊了,那种振聋发聩不亚于听到赌王离世的消息。 首先请看,文件本来是用于处理“请假审批”,但是文件名居然叫做“teaApprove”,不要欺负我的初中英语不好,这个teaApprove我第一个感觉就是和“喝茶、茶叶”有关的业务,可是和我们这个项目八竿子打不着...

  • 这个问题简单,不做过多描述,如题所述,如果因为这个导致错误,请安装 npm install stylus-loader css-loader style-loader -D...

  • 使用这个宏TS_VERSION_MAOR来判断,这个宏定义在编译时生成在apidefs.h,它包含在ts/ts.h中,所以请在插件这包含...

  • linux valgrind Memcheck–内存检查工具 使用方法: 注意,这里要用debug版本,如果是release的运行文件,则用debug编译出来的可执行文件替换 输出到终端: valgrind --tool=memcheck --leak-check=full ./test.out 输出到文件: valgri...

  • THE START更新堪称轻量级MATLAB的一款软件最新版-Maplesoft Maple 2019.2 中文版。Maple是符号和数字计算环境,也是一种多范式编程语言,由Maplesoft开发,还涵盖了技术计算的其他方面,包括可视化,数据分析,矩阵计算和MATLAB连接。MapleSim工具箱添加了用于多域物理建模和代码生成的...

  • 同学们,你们在学习他人的代码,是否见过这样的代码 def main(): def user_info(gender): 当你还是个小萌新时,你一定会认为这是个很牛逼的语法。 当你有了一点基础时,你一定会想要了解这个语法,并且尝试去使用它。 那么今天,我们便来了解这个牛语法。 有了一点点的python基础,我们来看这段代...

  •     自从用了这些快捷键,鼓励师也不需要了,代码开发效率蹭蹭提升!!! ctrl+shift+[折叠代码 (这个比ctrl+k ctrl+l、ctrl+k ctr+j不知道好用多少倍!) ctrl+shift+]展开代码 ctrl+shift+T打开手贱不小心关掉的窗口 【推荐】ctrl+shift+O打开当前文件...

  • 在提交代码之前,建议最好先Fetch代码下来(如果有冲突,系统会提示),然后再操作Merge到本地分支,这样做是为了避免有其他人同时修改了当前分支,如果直接用Ctrl+T(pull代码)极有可能覆盖本地分支最新代码,安全起见先Fetch代码(Ctrl+Alt+Shift+1)——所谓:小心驶得万年船!...

  • 每次复制代码时,如果代码里有 // 这样的注释就容易让格式乱掉,通过下面的设置就可以避免这种情况。 粘贴代码时取消自动缩进 VIM在粘贴代码时会自动缩进,把代码搞得一团糟糕,甚至可能因为某行的一个注释造成后面的代码全部被注释掉,我知道有同学这个时候会用vi去打开文件再粘贴上去(鄙人以前就是这样),其实需要先设置一下 s...