DETR

RocheL
Aug 31, 2022
Last edited: 2022-9-4
type
Post
status
Published
date
Aug 31, 2022
slug
detr
summary
FAIR DETR论文+代码解析
tags
AI
PaperLib
engineer
category
学习思考
icon
password
Property
Aug 28, 2022 03:01 AM
URL
CNN+Transformer+集合预测做端到端的目标检测

DETR

文章挺好懂的,简单说一嘴,毕竟主打的就是简单易实现,性能其实也比较一般,是一个早期baseline,cite2k+,star9k+,半年就优化到sota也是很厉害了

传统的detection任务:

  • anchor 针对中心点做可学习的初步的多样的锚框
  • proposal 针对anchor做regression
具体可以看
🧱
目标检测&语义分割(CNN)
,也可以直接搜目标检测、two-stage和one-stage等,有好多资料
总之最后需要对大量的锚框做后处理,包括非极大值抑制等等,不利于调参和部署。这些实际上是把一个集合预测(即哪一类对哪个框)变成一个间接的分类/回归问题,因此,detection领域呼唤一个真正端到端的架构。
notion image
训练:CNNbackbone抽特征,transformer encoder全局建模特征,过完之后连同一个object query进transformer decoder,object query限定出多少框,论文中选择100,即decoder固定出100个框。 集合预测,用二分图匹配(匈牙利算法)来解决loss的问题,在对比出来的100个框和gt时,先用matching loss得到跟gt最独一无二对应的输出中的两个框,然后与目标检测相同,算一个分类loss和banding box的loss,其余98个框被标记为背景类(无物体)
推理:推理的时候直接对100个数据卡阈值筛一步就完事啦
notion image
notion image
1066*800(三通道)进来,过cnn到batch*2048*25*34,1*1卷积降维到256*25*34,和256*25*34的位置编码相加,对长宽维度flatten成850*256的序列送transformer(是一个序列长度为256,每个元素850维度的input),过完encoder还是850*256。
encoder:输入850*256的src,还有与之对应的850*256的pos,先二者相加得到q和k,再将src作为v做自注意力,过dropout,然后skip connection,然后norm,然后再过两层线性层,重复上述6次就ok了,得到850*256的encoder输出 decoder:输入首先是一个100*256的objectquery,代码里叫tgt,初始化全0作为src,然后是100个元素torch.nn.Embedding作为跟src对应的位置编码(query_pos),
  1. tgt和query_pos即decoder的src和pos相加做q和k,tgt做v,先做一个多头自注意力
  1. 过dropout,然后skip connection,然后norm(norm的位置对connection可前可后)更新tgt
  1. tgt和query_pos相加得到q,memory即encoder的输出和其对应的位置编码作k,memory作v,做一个多头cross attention
  1. 过dropout,然后skip connection,然后norm
  1. 重复上述6次
tgt是一个输出的100*256,加FFN,得到两个矩阵,一个1*91代表类别,1个1*4代表左上角点坐标和长宽,matcher和criterion与我无关,所以就结束嘞
 
 

DETR:CODE

facebookresearch/detr
 
先回顾一下python中@修饰符的用法
DETR源码笔记(一)_在努力的松鼠的博客-CSDN博客_detr源码
源码获取:https://gitee.com/fgy120/DETR 首先对DETR做个简单介绍 目录 main函数 (一) 参数设置 main函数(二)搭建模型 build_backbone():包括构建位置编码器以及backbone 上图即为DETR的流程pipeline,相比以前的RCNN系列、YOLO系列等,最特别的在于加入了Transformer。 直接看源码,从train.py的主函数开始。 if __name__ == '__main__': parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()]) args = parser.parse_args() if args.output_dir: Path(args.output_dir).mkdir(parents=True, exist_ok=True)#以output_dir创建Path对象并执行mkdir创建文件夹操作 main(args) 首先是常规的参数解析操作,利用的argparse库,主要通过解析命令行输入的参数来设置模型训练的超参数或其他设置。第一步创建解析对象parser,运行parser.parse_args方法得到解析后的各个参数args,默认为解析运行代码的命令行。如果其中包含output_dir参数且output_dir不存在,利用Pathlib中的Path库的mkdir方法创建output_dir的路径文件夹。 Path(args.output_dir).mkdir(parents=True, exist_ok=True) parents:如果父目录不存在,是否创建父目录。 exist_ok:只有在目录不存在时创建目录,目录已存在时不会抛出异常。 argparse具体介绍可以看这篇。 argparse解析器_在努力的松鼠的博客-CSDN博客 argparse是一个Python模块,用来解析命令行参数,主要有三个步骤:1、创建 ArgumentParser() 对象2、调用 add_argument() 方法添加参数3、使用 parse_args() 解析添加的参数一、创建解析器代码示例:parser = argparse.ArgumentParser('DETR training and evaluation script',
DETR源码笔记(一)_在努力的松鼠的博客-CSDN博客_detr源码
考虑到之后会从DETR里嫖代码的需求,以上面别人的blog为基础重新自己看了一遍做了注释,如下
不过我不做detection,所以看到搭完transformer就不看嘞
 
AI&IS部分知识点Transformer系列