0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

DETR架构的内部工作方式分析

新机器视觉 来源:AI公园 作者:AI公园 2023-08-30 10:53 次阅读

这是一个Facebook的目标检测Transformer (DETR)的完整指南。

介绍

DEtection TRansformer (DETR)是Facebook研究团队巧妙地利用了Transformer 架构开发的一个目标检测模型。在这篇文章中,我将通过分析DETR架构的内部工作方式来帮助提供一些关于它的直觉。

下面,我将解释一些结构,但是如果你只是想了解如何使用模型,可以直接跳到代码部分。

结构

DETR模型由一个预训练的CNN骨干(如ResNet)组成,它产生一组低维特征集。这些特征被格式化为一个特征集合并添加位置编码,输入一个由Transformer组成的编码器和解码器中,和原始的Transformer论文中描述的Encoder-Decoder的使用方式非常的类似。解码器的输出然后被送入固定数量的预测头,这些预测头由预定义数量的前馈网络组成。每个预测头的输出都包含一个类预测和一个预测框。损失是通过计算二分匹配损失来计算的。

该模型做出了预定义数量的预测,并且每个预测都是并行计算的。

CNN主干

假设我们的输入图像,有三个输入通道。CNN backbone由一个(预训练过的)CNN(通常是ResNet)组成,我们用它来生成C个具有宽度W和高度H的低维特征(在实践中,我们设置C=2048, W=W₀/32和H=H₀/32)。

这留给我们的是C个二维特征,由于我们将把这些特征传递给一个transformer,每个特征必须允许编码器将每个特征处理为一个序列的方式重新格式化。这是通过将特征矩阵扁平化为H⋅W向量,然后将每个向量连接起来来实现的。

7a301044-4648-11ee-a2ef-92fbcf53809c.png

扁平化的卷积特征再加上空间位置编码,位置编码既可以学习,也可以预定义。

The Transformer

Transformer几乎与原始的编码器-解码器架构完全相同。不同之处在于,每个解码器层并行解码N个(预定义的数目)目标。该模型还学习了一组N个目标的查询,这些查询是(类似于编码器)学习出来的位置编码。

7a3c1754-4648-11ee-a2ef-92fbcf53809c.png

目标查询

下图描述了N=20个学习出来的目标查询(称为prediction slots)如何聚焦于一张图像的不同区域。

7a6fa8c6-4648-11ee-a2ef-92fbcf53809c.png

“我们观察到,在不同的操作模式下,每个slot 都会学习特定的区域和框大小。“ —— DETR的作者

理解目标查询的直观方法是想象每个目标查询都是一个人。每个人都可以通过注意力来查看图像的某个区域。一个目标查询总是会问图像中心是什么,另一个总是会问左下角是什么,以此类推。

使用PyTorch实现简单的DETR

importtorch
importtorch.nnasnn
fromtorchvision.modelsimportresnet50

classSimpleDETR(nn.Module):
"""
MinimalExampleoftheDetectionTransformermodelwithlearnedpositionalembedding
"""
def__init__(self,num_classes,hidden_dim,num_heads,
num_enc_layers,num_dec_layers):
super(SimpleDETR,self).__init__()
self.num_classes=num_classes
self.hidden_dim=hidden_dim
self.num_heads=num_heads
self.num_enc_layers=num_enc_layers
self.num_dec_layers=num_dec_layers
#CNNBackbone
self.backbone=nn.Sequential(
*list(resnet50(pretrained=True).children())[:-2])
self.conv=nn.Conv2d(2048,hidden_dim,1)
#Transformer
self.transformer=nn.Transformer(hidden_dim,num_heads,
num_enc_layers,num_dec_layers)
#PredictionHeads
self.to_classes=nn.Linear(hidden_dim,num_classes+1)
self.to_bbox=nn.Linear(hidden_dim,4)
#PositionalEncodings
self.object_query=nn.Parameter(torch.rand(100,hidden_dim))
self.row_embed=nn.Parameter(torch.rand(50,hidden_dim//2)
self.col_embed=nn.Parameter(torch.rand(50,hidden_dim//2))

defforward(self,X):
X=self.backbone(X)
h=self.conv(X)
H,W=h.shape[-2:]
pos_enc=torch.cat([
self.col_embed[:W].unsqueeze(0).repeat(H,1,1),
self.row_embed[:H].unsqueeze(1).repeat(1,W,1)],
dim=-1).flatten(0,1).unsqueeze(1)
h=self.transformer(pos_enc+h.flatten(2).permute(2,0,1),
self.object_query.unsqueeze(1))
class_pred=self.to_classes(h)
bbox_pred=self.to_bbox(h).sigmoid()

returnclass_pred,bbox_pred

二分匹配损失 (Optional)

让为预测的集合,其中是包括了预测类别(可以是空类别)和包围框的二元组,其中上划线表示框的中心点,和表示框的宽和高。

设y为ground truth集合。假设y和ŷ之间的损失为L,每一个yᵢ和ŷᵢ之间的损失为L。由于我们是在集合的层次上工作,损失L必须是排列不变的,这意味着无论我们如何排序预测,我们都将得到相同的损失。因此,我们想找到一个排列,它将预测的索引映射到ground truth目标的索引上。在数学上,我们求解:

7a985e92-4648-11ee-a2ef-92fbcf53809c.png

计算的过程称为寻找最优的二元匹配。这可以用匈牙利算法找到。但为了找到最优匹配,我们需要实际定义一个损失函数,计算和之间的匹配成本。

回想一下,我们的预测包含一个边界框和一个类。现在让我们假设类预测实际上是一个类集合上的概率分布。那么第i个预测的总损失将是类预测产生的损失和边界框预测产生的损失之和。作者在http://arxiv.org/abs/1906.05909中将这种损失定义为边界框损失和类预测概率的差异:

7ab186e2-4648-11ee-a2ef-92fbcf53809c.png

其中,是的argmax,是是来自包围框的预测的损失,如果,则表示匹配损失为0。

框损失的计算为预测值与ground truth的L₁损失和的GIOU损失的线性组合。同样,如果你想象两个不相交的框,那么框的错误将不会提供任何有意义的上下文(我们可以从下面的框损失的定义中看到)。

7ab99706-4648-11ee-a2ef-92fbcf53809c.png

其中,λᵢₒᵤ和是超参数。注意,这个和也是面积和距离产生的误差的组合。为什么会这样呢?

可以把上面的等式看作是与预测相关联的总损失,其中面积误差的重要性是λᵢₒᵤ,距离误差的重要性是。

现在我们来定义GIOU损失函数。定义如下:

7acecafe-4648-11ee-a2ef-92fbcf53809c.png

由于我们从已知的已知类的数目来预测类,那么类预测就是一个分类问题,因此我们可以使用交叉熵损失来计算类预测误差。我们将损失函数定义为每N个预测损失的总和:

7ad9d0c0-4648-11ee-a2ef-92fbcf53809c.png

为目标检测使用DETR

在这里,你可以学习如何加载预训练的DETR模型,以便使用PyTorch进行目标检测。

加载模型

首先导入需要的模块。

#Importrequiredmodules
importtorch
fromtorchvisionimporttransformsasTimportrequests#forloadingimagesfromweb
fromPILimportImage#forviewingimages
importmatplotlib.pyplotasplt

下面的代码用ResNet50作为CNN骨干从torch hub加载预训练的模型。其他主干请参见DETR github:https://github.com/facebookresearch/detr

detr=torch.hub.load('facebookresearch/detr',
'detr_resnet50',
pretrained=True)

加载一张图像

要从web加载图像,我们使用requests库:

url='https://www.tempetourism.com/wp-content/uploads/Postino-Downtown-Tempe-2.jpg'#Sampleimageimage=Image.open(requests.get(url,stream=True).raw)plt.imshow(image)
plt.show()

设置目标检测的Pipeline

为了将图像输入到模型中,我们需要将PIL图像转换为张量,这是通过使用torchvision的transforms库来完成的。

transform=T.Compose([T.Resize(800),
T.ToTensor(),
T.Normalize([0.485,0.456,0.406],
[0.229,0.224,0.225])])

上面的变换调整了图像的大小,将PIL图像进行转换,并用均值-标准差对图像进行归一化。其中[0.485,0.456,0.406]为各颜色通道的均值,[0.229,0.224,0.225]为各颜色通道的标准差。

我们装载的模型是预先在COCO Dataset上训练的,有91个类,还有一个表示空类(没有目标)的附加类。我们用下面的代码手动定义每个标签

CLASSES=
['N/A','Person','Bicycle','Car','Motorcycle','Airplane','Bus','Train','Truck','Boat','Traffic-Light','Fire-Hydrant','N/A','Stop-Sign','ParkingMeter','Bench','Bird','Cat','Dog','Horse','Sheep','Cow','Elephant','Bear','Zebra','Giraffe','N/A','Backpack','Umbrella','N/A','N/A','Handbag','Tie','Suitcase','Frisbee','Skis','Snowboard','Sports-Ball','Kite','BaseballBat','BaseballGlove','Skateboard','Surfboard','TennisRacket','Bottle','N/A','WineGlass','Cup','Fork','Knife','Spoon','Bowl','Banana','Apple','Sandwich','Orange','Broccoli','Carrot','Hot-Dog','Pizza','Donut','Cake','Chair','Couch','PottedPlant','Bed','N/A','DiningTable','N/A','N/A','Toilet','N/A','TV','Laptop','Mouse','Remote','Keyboard','Cell-Phone','Microwave','Oven','Toaster','Sink','Refrigerator','N/A','Book','Clock','Vase','Scissors','Teddy-Bear','Hair-Dryer','Toothbrush']

如果我们想输出不同颜色的边框,我们可以手动定义我们想要的RGB格式的颜色

COLORS=[
[0.000,0.447,0.741],
[0.850,0.325,0.098],
[0.929,0.694,0.125],
[0.494,0.184,0.556],
[0.466,0.674,0.188],
[0.301,0.745,0.933]
]

格式化输出

我们还需要重新格式化模型的输出。给定一个转换后的图像,模型将输出一个字典,包含100个预测类的概率和100个预测边框。

每个包围框的形式为(x, y, w, h),其中(x,y)为包围框的中心(包围框是单位正方形[0,1]×[0,1]), w, h为包围框的宽度和高度。因此,我们需要将边界框输出转换为初始和最终坐标,并重新缩放框以适应图像的实际大小。

下面的函数返回边界框端点:

#Getcoordinates(x0,y0,x1,y0)frommodeloutput(x,y,w,h)defget_box_coords(boxes):
x,y,w,h=boxes.unbind(1)
x0,y0=(x-0.5*w),(y-0.5*h)
x1,y1=(x+0.5*w),(y+0.5*h)
box=[x0,y0,x1,y1]
returntorch.stack(box,dim=1)

我们还需要缩放了框的大小。下面的函数为我们做了这些:

#Scaleboxfrom[0,1]x[0,1]to[0,width]x[0,height]defscale_boxes(output_box,width,height):
box_coords=get_box_coords(output_box)
scale_tensor=torch.Tensor(
[width,height,width,height]).to(
torch.cuda.current_device())returnbox_coords*scale_tensor

现在我们需要一个函数来封装我们的目标检测pipeline。下面的detect函数为我们完成了这项工作。

#ObjectDetectionPipelinedefdetect(im,model,transform):
device=torch.cuda.current_device()
width=im.size[0]
height=im.size[1]

#mean-stdnormalizetheinputimage(batch-size:1)
img=transform(im).unsqueeze(0)
img=img.to(device)

#demomodelonlysupportbydefaultimageswithaspectratiobetween0.5and2assertimg.shape[-2]<= 1600 and img.shape[-1] <= 1600,    # propagate through the model
    outputs = model(img)    # keep only predictions with 0.7+ confidence
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values >0.85

#convertboxesfrom[0;1]toimagescales
bboxes_scaled=scale_boxes(outputs['pred_boxes'][0,keep],width,height)returnprobas[keep],bboxes_scaled

现在,我们需要做的是运行以下程序来获得我们想要的输出:

probs,bboxes=detect(image,detr,transform)

绘制结果

现在我们有了检测到的目标,我们可以使用一个简单的函数来可视化它们。

#PlotPredictedBoundingBoxesdefplot_results(pil_img,prob,boxes,labels=True):
plt.figure(figsize=(16,10))
plt.imshow(pil_img)
ax=plt.gca()

forprob,(x0,y0,x1,y1),colorinzip(prob,boxes.tolist(),COLORS*100):ax.add_patch(plt.Rectangle((x0,y0),x1-x0,y1-y0,
fill=False,color=color,linewidth=2))
cl=prob.argmax()
text=f'{CLASSES[cl]}:{prob[cl]:0.2f}'
iflabels:
ax.text(x0,y0,text,fontsize=15,
bbox=dict(facecolor=color,alpha=0.75))
plt.axis('off')
plt.show()

现在可以可视化结果:

plot_results(image,probs,bboxes,labels=True)

审核编辑:彭菁

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • Facebook
    +关注

    关注

    3

    文章

    1428

    浏览量

    54033
  • 代码
    +关注

    关注

    30

    文章

    4555

    浏览量

    66766
  • 检测模型
    +关注

    关注

    0

    文章

    15

    浏览量

    7275
  • Transformer
    +关注

    关注

    0

    文章

    130

    浏览量

    5898

原文标题:Transformer (DETR) 对象检测实操!

文章出处:【微信号:vision263com,微信公众号:新机器视觉】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    MCS-51内部定时/计数器有哪几种工作方式

    《单片机》实验——实验3 MCS-51内部定时/计数器实验(1)一、实验目的二、实验内容一、实验目的掌握定时/计数器的4种工作方式工作特点及应用掌握长时间段定时的实现方法掌握查询控制的定时/计数器
    发表于 12-01 08:08

    GPIO基本结构和工作方式介绍

    GPIO的8种工作方式一、GPIO基本结构和工作方式1、战舰/精英板2、Min板3、基本结构4、工作方式二、GPIO寄存器说明1、GPIO相关寄存器2、端口配置低寄存器(GPIOx_CRL)、端口
    发表于 01-11 07:02

    SPI总线的工作方式是什么?

    SPI总线具有哪些特点?SPI总线的工作方式是什么?
    发表于 01-25 06:57

    如何实现Protothread这种工作方式

    如何实现Protothread这种工作方式呢?
    发表于 02-25 07:56

    α调制工作方式原理

    以单相—单相直接变频电路为例说明α调制工作方式的原理及其实现方法。图4.2为单相桥式AC/AC变换电路。为了在负载一获得交变电压,可以交替地让正组变流器和负组变流器轮流
    发表于 07-27 09:10 511次阅读
    α调制<b class='flag-5'>工作方式</b>原理

    鼠标的工作方式

    鼠标的工作方式 工作方式是指鼠标采用什么工作原理或方式进行工作。常见的鼠标工作方式有滚轮式和光
    发表于 12-28 11:38 790次阅读

    Wifi模块的工作方式功能是什么?

    Wifi模块的工作方式是什么呢,Wifi模块的主要功能又有哪些呢?本文主要介绍了有关Wifi模块的基础知识即:Wifi模块的工作方式、主要功能及应用领域。
    发表于 06-12 14:22 5736次阅读

    步进电机及驱动电路工作原理及工作方式介绍

    步进电机及驱动电路工作原理及工作方式介绍
    发表于 05-11 18:00 0次下载

    cd4013无稳态工作方式及无稳态电路应用

    CD4013是CMOS双D触发器,内部集成了两个性能相同,引脚独立(电源共用)的D触发器,采用14引脚双列直插塑料封装,CD4013有四种基本方式,即数据锁存器,单稳态工作方式,无稳态工作方式
    发表于 12-01 11:14 9671次阅读
    cd4013无稳态<b class='flag-5'>工作方式</b>及无稳态电路应用

    一文总结蓝牙模块的工作方式汇总,很全值得收藏!

    蓝牙模块的工作方式有哪些呢?资料总结了蓝牙模块的常见的7种工作方式,需要的亲可以收藏下
    发表于 04-26 15:05 15次下载
    一文总结蓝牙模块的<b class='flag-5'>工作方式</b>汇总,很全值得收藏!

    8255a有哪几种工作方式?8251a的工作方式工作原理

    本文首先介绍了8255芯片的概念与特性,其次介绍了8255A引脚图及功能,最后介绍了8255a的几种工作方式工作原理。
    的头像 发表于 05-23 14:40 6.1w次阅读
    8255a有哪几种<b class='flag-5'>工作方式</b>?8251a的<b class='flag-5'>工作方式</b>及<b class='flag-5'>工作</b>原理

    单片机定时器的四种工作方式解析

    1 工作方式0 定时器/计数器T0工作方式0时,16位计数器只用了13位,即TH0的高8位和TL0的低5位,组成一个13位定时器/计数器。 1)、工作在定时
    发表于 09-18 15:57 4.9w次阅读
    单片机定时器的四种<b class='flag-5'>工作方式</b>解析

    AD级联的工作方式配置和AD双排序的工作方式配置详细说明

    本文档的主要内容详细介绍的是AD级联的工作方式配置和AD双排序的工作方式配置详细说明
    发表于 12-23 08:00 2次下载
    AD级联的<b class='flag-5'>工作方式</b>配置和AD双排序的<b class='flag-5'>工作方式</b>配置详细说明

    如何使用Transformer来做物体检测?

    )是Facebook研究团队巧妙地利用了Transformer 架构开发的一个目标检测模型。在这篇文章中,我将通过分析DETR架构内部
    的头像 发表于 04-25 10:45 2341次阅读
    如何使用Transformer来做物体检测?

    UPS电源有哪些工作方式

    UPS电源是较为常见的应急电源系统,其在市电正常与市电异常的情况下,工作方式也有所不同,以下介绍UPS电源的四种工作方式:正常运行、电池工作、旁路运行和旁路维护。1、正常运行方式
    发表于 11-09 09:06 31次下载
    UPS电源有哪些<b class='flag-5'>工作方式</b>?