DETR
RocheL
Aug 31, 2022
Last edited: 2022-9-4
type
Post
status
Published
date
Aug 31, 2022
slug
detr
summary
FAIR DETR论文+代码解析
tags
AI
PaperLib
engineer
category
学习思考
icon
password
Property
Aug 28, 2022 03:01 AM
URL
CNN+Transformer+集合预测做端到端的目标检测
DETR
文章挺好懂的,简单说一嘴,毕竟主打的就是简单易实现,性能其实也比较一般,是一个早期baseline,cite2k+,star9k+,半年就优化到sota也是很厉害了
传统的detection任务:
- anchor 针对中心点做可学习的初步的多样的锚框
- proposal 针对anchor做regression
具体可以看目标检测&语义分割(CNN) ,也可以直接搜目标检测、two-stage和one-stage等,有好多资料
总之最后需要对大量的锚框做后处理,包括非极大值抑制等等,不利于调参和部署。这些实际上是把一个集合预测(即哪一类对哪个框)变成一个间接的分类/回归问题,因此,detection领域呼唤一个真正端到端的架构。
训练:CNNbackbone抽特征,transformer encoder全局建模特征,过完之后连同一个object query进transformer decoder,object query限定出多少框,论文中选择100,即decoder固定出100个框。
集合预测,用二分图匹配(匈牙利算法)来解决loss的问题,在对比出来的100个框和gt时,先用matching loss得到跟gt最独一无二对应的输出中的两个框,然后与目标检测相同,算一个分类loss和banding box的loss,其余98个框被标记为背景类(无物体)
推理:推理的时候直接对100个数据卡阈值筛一步就完事啦
1066*800(三通道)进来,过cnn到batch*2048*25*34,1*1卷积降维到256*25*34,和256*25*34的位置编码相加,对长宽维度flatten成850*256的序列送transformer(是一个序列长度为256,每个元素850维度的input),过完encoder还是850*256。
encoder
:输入850*256的src,还有与之对应的850*256的pos,先二者相加得到q和k,再将src作为v做自注意力,过dropout,然后skip connection,然后norm,然后再过两层线性层,重复上述6次就ok了,得到850*256的encoder输出
decoder
:输入首先是一个100*256的objectquery,代码里叫tgt,初始化全0作为src,然后是100个元素torch.nn.Embedding作为跟src对应的位置编码(query_pos),- tgt和query_pos即decoder的src和pos相加做q和k,tgt做v,先做一个多头自注意力
- 过dropout,然后skip connection,然后norm(norm的位置对connection可前可后)更新tgt
- tgt和query_pos相加得到q,memory即encoder的输出和其对应的位置编码作k,memory作v,做一个多头cross attention
- 过dropout,然后skip connection,然后norm
- 重复上述6次
tgt是一个输出的100*256,加FFN,得到两个矩阵,一个1*91代表类别,1个1*4代表左上角点坐标和长宽,matcher和criterion与我无关,所以就结束嘞
DETR:CODE
facebookresearch/detr
先回顾一下python中@修饰符的用法
考虑到之后会从DETR里嫖代码的需求,以上面别人的blog为基础重新自己看了一遍做了注释,如下
不过我不做detection,所以看到搭完transformer就不看嘞