mmdetection详解指北 (二)

上一篇博客主要介绍了 mmdetection 这个检测框架的结构设计以及代码的总体逻辑。这一篇主要介绍在 mmdetection 中被大量使用的两个机制：配置（Config）和注册器（Registry）。

配置类

配置文件支持 Python/JSON/YAML 三种格式，由 mmcv 的 `Config` 类解析。其功能与 maskrcnn-benchmark 所用的 yacs 类似：把字典的取值方式属性化，即既可以用 `cfg['key']`，也可以用 `cfg.key` 访问配置项。这里贴出部分代码，以供学习。

class Config(object):
    """A facility for config and config files.

    It supports common file formats as configs: python/json/yaml. The interface
    is the same as a dict object and also allows access config values as
    attributes.

    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = Config.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        '/home/kchen/projects/mmcv/tests/data/config/a.py'
        >>> cfg.item4
        'test'
        >>> cfg
        Config (path: /home/kchen/projects/mmcv/tests/data/config/a.py): {'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}
    """

    @staticmethod
    def fromfile(filename):
        """Load a config from a ``.py``, ``.yml``/``.yaml`` or ``.json`` file.

        Args:
            filename (str): Path to the config file. ``~`` is expanded and the
                path is made absolute before loading.

        Returns:
            Config: The loaded config, with ``filename`` and ``text`` set.

        Raises:
            ValueError: If a ``.py`` config's base name contains a dot.
            IOError: If the file extension is not one of py/yml/yaml/json.
        """
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        if filename.endswith('.py'):
            module_name = osp.basename(filename)[:-3]
            if '.' in module_name:
                # A dot would be interpreted by import_module as a package path.
                raise ValueError('Dots are not allowed in config file path.')
            config_dir = osp.dirname(filename)
            # Temporarily put the config's directory first on sys.path so the
            # file can be imported as a module. Restore sys.path even if
            # executing the config raises, so a broken config does not leave
            # a stray entry behind.
            # NOTE(review): import_module returns an already-cached module if
            # one with the same name was imported before (e.g. a same-named
            # config from another directory) — confirm this is acceptable.
            sys.path.insert(0, config_dir)
            try:
                mod = import_module(module_name)
            finally:
                sys.path.pop(0)
            # Every module-level name not starting with '__' becomes a config
            # entry (this skips __name__, __builtins__, etc.).
            cfg_dict = {
                name: value
                for name, value in mod.__dict__.items()
                if not name.startswith('__')
            }
        elif filename.endswith(('.yml', '.yaml', '.json')):
            import mmcv
            cfg_dict = mmcv.load(filename)
        else:
            raise IOError('Only py/yml/yaml/json type are supported now!')
        return Config(cfg_dict, filename=filename)

    @staticmethod
    def auto_argparser(description=None):
        """Generate argparser from config file automatically (experimental)

        A first parser reads only the positional ``config`` path (ignoring
        unknown arguments) so the config can be loaded. A second parser is
        then built and passed, together with the loaded config, to
        ``add_args`` (presumably to add one option per config entry — TODO
        confirm against ``add_args``).

        Returns:
            tuple: ``(parser, cfg)``, the full parser and the loaded Config.
        """
        partial_parser = ArgumentParser(description=description)
        partial_parser.add_argument('config', help='config file path')
        cfg_file = partial_parser.parse_known_args()[0].config
        cfg = Config.fromfile(cfg_file)
        parser = ArgumentParser(description=description)
        parser.add_argument('config', help='config file path')
        add_args(parser, cfg)
        return parser, cfg

    def __init__(self, cfg_dict=None, filename=None):
        """Wrap ``cfg_dict`` in a ConfigDict and remember its source file.

        Args:
            cfg_dict (dict | None): Config contents; ``None`` means empty.
            filename (str | None): Source file. If given, its raw text is read
                and exposed as ``text``.

        Raises:
            TypeError: If ``cfg_dict`` is neither ``None`` nor a dict.
        """
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but got {}'.format(
                type(cfg_dict)))

        # Bypass our own __setattr__, which would store these into the
        # ConfigDict instead of on the instance.
        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(Config, self).__setattr__('_filename', filename)
        if filename:
            with open(filename, 'r') as f:
                super(Config, self).__setattr__('_text', f.read())
        else:
            super(Config, self).__setattr__('_text', '')

    @property
    def filename(self):
        """str | None: Absolute path of the source file, if any."""
        return self._filename

    @property
    def text(self):
        """str: Raw text of the source file ('' when built from a dict)."""
        return self._text

    def __repr__(self):
        return 'Config (path: {}): {}'.format(self.filename,
                                              self._cfg_dict.__repr__())

    def __len__(self):
        return len(self._cfg_dict)

    # Attribute-style read access (``cfg.key``), delegated to the ConfigDict.
    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``_cfg_dict`` has
        # not been set yet (e.g. an instance created with __new__ by copy or
        # pickle before its state is restored), reading ``self._cfg_dict``
        # would re-enter __getattr__ forever. Read __dict__ directly and
        # report a plain AttributeError instead.
        try:
            cfg_dict = self.__dict__['_cfg_dict']
        except KeyError:
            raise AttributeError(name) from None
        return getattr(cfg_dict, name)

    # Item-style read access (``cfg['key']``), delegated to the ConfigDict.
    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    # Attribute-style write (``cfg.key = value``). Plain dicts are wrapped in
    # ConfigDict so nested values also support attribute access.
    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    # Item-style write (``cfg['key'] = value``), same dict wrapping as above.
    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    # Iteration is delegated to the underlying ConfigDict.
    def __iter__(self):
        return iter(self._cfg_dict)

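阅读这段代码时，主要应思考自己该如何实现类似的功能。核心在于 Python 魔法方法（`__getattr__`、`__getitem__`、`__setattr__`、`__setitem__`、`__iter__` 等）的运用：`Config` 把属性访问和下标访问都转发给内部的 `ConfigDict`。也可以同时参考 yacs 的实现。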

注册器

注册器把类按名字登记在内部的一个字典（`_module_dict`）中：键是类名或注册时指定的名字，值是类本身。这样就可以通过字符串名字灵活地查找和构建对象。注意 `Registry` 本身并没有继承 `dict`，而是持有一个字典。

class Registry:
    """Look up classes by a string name.

    Classes are stored in an internal ``name -> class`` dict. They can be
    added with :meth:`register_module`, used either as a decorator or as a
    plain call.

    Args:
        name (str): Name of this registry, shown in ``repr`` and in
            "already registered" error messages.
    """

    def __init__(self, name):
        self._name = name
        # registered name -> class object
        self._module_dict = {}

    def __len__(self):
        """Number of registered entries."""
        return len(self._module_dict)

    def __contains__(self, key):
        """True when ``key`` is registered to a non-None value."""
        return self.get(key) is not None

    def __repr__(self):
        cls_name = type(self).__name__
        return f'{cls_name}(name={self._name}, items={self._module_dict})'

    @property
    def name(self):
        """str: The name given at construction."""
        return self._name

    @property
    def module_dict(self):
        """dict: The live ``name -> class`` mapping (not a copy)."""
        return self._module_dict

    def get(self, key):
        """Return the class registered under ``key``, or ``None``.

        Args:
            key (str): The registered name.

        Returns:
            class | None: The matching class, or ``None`` if absent.
        """
        return self._module_dict.get(key)

    def _register_module(self, module_class, module_name=None, force=False):
        """Store ``module_class`` under ``module_name`` (default: its __name__).

        Raises:
            TypeError: If ``module_class`` is not a class.
            KeyError: If the name is taken and ``force`` is False.
        """
        if not inspect.isclass(module_class):
            message = f'module must be a class, but got {type(module_class)}'
            raise TypeError(message)

        key = module_class.__name__ if module_name is None else module_name
        if not force and key in self._module_dict:
            raise KeyError(f'{key} is already registered in {self.name}')
        self._module_dict[key] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        """Legacy ``register_module(cls, force=False)`` API.

        Always emits a deprecation warning. Called with a class, it registers
        the class and returns it. Called without one, it returns a decorator
        bound to ``force``.
        """
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is not None:
            self._register_module(cls, force=force)
            return cls
        return partial(self.deprecated_register_module, force=force)

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to ``self._module_dict``, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            ... class ResNet:
            ...     pass

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.

        Returns:
            The registered class when ``module`` is given (or on the legacy
            path), otherwise a decorator that registers and returns its class.

        Raises:
            TypeError: If ``force`` is not a bool, or ``name`` is neither
                ``None`` nor a str (checked only on the decorator path).
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # Backward compatibility: ``@registry.register_module`` without
        # parentheses passes the decorated class itself as ``name``.
        # NOTE: this workaround may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # Plain call: registry.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # Reject a bad name now, at decoration time, rather than later.
        if name is not None and not isinstance(name, str):
            raise TypeError(f'name must be a str, but got {type(name)}')

        # Decorator form: @registry.register_module()
        def decorator(target_cls):
            self._register_module(
                module_class=target_cls, module_name=name, force=force)
            return target_cls

        return decorator
def build_from_cfg(cfg, registry, default_args=None):
    """Instantiate the object described by a config dict.

    ``cfg['type']`` selects the class. It is either a name looked up in
    ``registry``, or a class object that is used directly. The remaining
    entries of ``cfg``, plus any ``default_args`` keys that ``cfg`` does not
    already set, are passed to that class as keyword arguments.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Fallback keyword arguments; values in
            ``cfg`` take precedence.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: If ``cfg``, ``registry`` or ``default_args`` has the wrong
            type, or ``cfg['type']`` is neither a str nor a class.
        KeyError: If ``cfg`` has no "type" key, or the named type is not
            registered.
    """
    # Argument validation; the first failing check determines the error.
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise KeyError(
            f'the cfg dict must contain the key "type", but got {cfg}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if default_args is not None and not isinstance(default_args, dict):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # Work on a copy so the caller's cfg keeps its "type" entry.
    kwargs = cfg.copy()
    type_spec = kwargs.pop('type')

    if is_str(type_spec):
        # Resolve a registered name to its class.
        target_cls = registry.get(type_spec)
        if target_cls is None:
            raise KeyError(
                f'{type_spec} is not in the {registry.name} registry')
    elif inspect.isclass(type_spec):
        target_cls = type_spec
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(type_spec)}')

    # default_args only fill keys that cfg did not provide.
    if default_args is not None:
        for key, value in default_args.items():
            kwargs.setdefault(key, value)

    # ``**kwargs`` passes each remaining entry as a keyword argument (k=v).
    return target_cls(**kwargs)