transformer由encoder和decoder俩部分组成。
Encoder
一个encoder由多个encoder_layer组成,在detr中默认是6层。
1 | class TransformerEncoder(nn.Module): |
EncoderLayer 的前向过程分为两种情况,一种是在输入多头自注意力层和前向反馈层前先进行归一化,另一种则是在这两个层输出后再进行归一化操作。对应实现可以参考如下图左侧部分:
1 | class TransformerEncoderLayer(nn.Module): |
需要注意的是,在输入多头自注意力层时需要先进行位置嵌入,即结合位置编码。注意仅对query和key实施,而value不需要。query和key是在图像特征中各个位置之间计算相关性,而value作为原图像特征,使用计算出来的相关性加权上去,得到各位置结合了全局相关性(增强/削弱)后的特征表示。
Query Embedding
在解析Decoder前,有必要先简要地谈谈query embedding,因为它是Decoder的主要输入之一。query embedding 有点anchor的味道,而且是自学习的anchor,作者使用了nn.Embedding实现:
1 | self.query_embed = nn.Embedding(num_queries, hidden_dim) |
其中num_queries 代表图像中有多少个目标(位置),默认是100个,对这些目标(位置)全部进行嵌入,维度映射到 hidden_dim,将 query_embedding 的权重作为参数输入到Transformer的前向过程,使用时与position encoding的方式相同:直接相加。
1 | hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] |
而这个query embedding应该加在哪呢?当然是我们需要预测的目标(query object)咯!可是网络一开始还没有输出,我们都不知道预测目标在哪里呀,如何将它实体化?作者也不知道,于是就简单粗暴地直接将它初始化为全0,shape和query embedding 的权重一致(从而可以element-wise add)。
1 | class Transformer(nn.Module): |
Decoder
Decoder的结构和Encoder十分类似。
1 | class TransformerDecoder(nn.Module): |
注意,在detr中,tgt_mask和memory_mask并未使用。需要注意的是intermediate中记录的是每层输出后的归一化结果,而每一层的输入是前一层输出(没有归一化)的结果。
DecoderLayer与Encoder的实现类似,只不过多了一层cross attention,其实质也是多头自注意力层,但是key和value来自于Encoder的输出。
1 | class TransformerDecoderLayer(nn.Module): |
注意,在tgt在输入到self_attn之前,需要经过position embedding,tgt+query_pos。在第二个多头注意力模块multihead_attn上,key和value均来自Encoder的输出。同样地,query和key要进行位置嵌入(而value不用)。这里cross attention计算的相关性是目标物体与图像特征各位置的相关性,然后再把这个相关性系数加权到Encoder编码后的图像特征(value)上,相当于获得了object features的意思,更好地表征了图像中的各个物体。从上面encoder和decoder的实现可以看出,作者非常强调位置嵌入的作用,每次进行attention计算前都需要进行position embedding,究其原因是因为transformer的转置不变性,即对排列和位置是不care的,然而在detection任务中却是十分重要的。
Transformer
将Encoder和Decoder封装在一起构成Transformer。
1 | class Transformer(nn.Module): |
注意,tgt是与query embedding形状一直且设置为全0的结果,意为初始化需要预测的目标。因为一开始并不清楚这些目标,所以初始化为全0。其会在Decoder的各层不断被refine,相当于一个coarse-to-fine的过程,但是真正要学习的是query embedding,学习到的是整个数据集中目标物体的统计特征,而tgt在每次迭代训练(一个batch数据刚到来)时会被重新初始化为0。
DETR
DETR包含backbone,encoder, decoder, prediction heads四个部分。encoder和decoder通常会用一个transformer来实现。prediction heads部分包括分类和回归。
1 | class DETR(nn.Module): |
Postprocess
一部分DETR的输出并不是最终预测结果的形式,还需要进行简单的后处理。但是这里的后处理并不是NMS哦!DETR预测的是集合,并且在训练过程中经过匈牙利算法与GT一对一匹配学习,因此不存在重复框的情况。
1 | class PostProcess(nn.Module): |
Loss Fuction
这一部分主要介绍一下和损失函数相关的部分源码。先看一下与损失函数相关的代码:
1 | matcher = build_matcher(args) |
matcher是将预测结果与gt进行匹配的匈牙利算法,weight_dict是各部分loss设置的权重参数,包括分类与回归损失。分类使用的是CE loss,回归包括l1 loss和giou loss。如果包含分割任务,还有mask相关损失函数,另外如果设置了aux_loss,则代表计算decoder中间层预测结果对应的loss。 loss函数的实例化使用SetCriterion进行构建的。
1 | class SetCriterion(nn.Module): |
从forward函数可以看出,首先进行匈牙利匹配的是decoder最后一层的输出,之后再计算匹配后的损失函数包括losses = [‘labels’, ‘boxes’, ‘cardinality’],具体计算部分可以看get_loss方法中映射的对应计算方法,其中包括self.loss_labels,self.loss_cardinality,self.loss_boxes。
匈牙利匹配
匈牙利算法,在这里用于预测集(prediction set)和GT的匹配,最终匹配方案是选取“loss总和”最小的分配方式。注意,这里计算的loss与损失函数中计算loss并不相同,在这里是用来作为代价cost,cost大小决定匹配程度。
1 | class HungarianMatcher(nn.Module): |
从上面可以看到,匈牙利匹配在前向计算过程中,是不需要梯度的。其中分类cost是直接采用1减去预测概率的形式,同时由于1是常数,于是作者甚至连1都省去了,在box上计算了l1和giou两种cost,之后对各部分进行加权求和得到总的cost。匹配方法使用的是 scipy 优化模块中的 linear_sum_assignment(),其输入是二分图的度量矩阵,该方法是计算这个二分图度量矩阵的最小权重分配方式,返回的是匹配方案对应的矩阵行索引和列索引。
End
至此,DETR所有相关源码均已解读完毕。