0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

如何使用Transformer来做物体检测?

中科院长春光机所 来源:AI公园 作者:Jacob Briones 2021-04-25 10:45 次阅读
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

导读

本文为一个Facebook的目标检测Transformer (DETR)的完整指南,详细介绍了DETR架构的内部工作方式以及代码。

介绍

DEtection TRansformer (DETR)是Facebook研究团队巧妙地利用了Transformer 架构开发的一个目标检测模型。在这篇文章中,我将通过分析DETR架构的内部工作方式来帮助提供一些关于它的含义。下面,我将解释一些结构,但是如果你只是想了解如何使用模型,可以直接跳到代码部分。

结构

DETR模型由一个预训练的CNN骨干(如ResNet)组成,它产生一组低维特征集。这些特征被格式化为一个特征集合并添加位置编码,输入一个由Transformer组成的编码器和解码器中,和原始的Transformer论文中描述的Encoder-Decoder的使用方式非常的类似。解码器的输出然后被送入固定数量的预测头,这些预测头由预定义数量的前馈网络组成。每个预测头的输出都包含一个类预测和一个预测框。损失是通过计算二分匹配损失来计算的。

de1e84951e16f003cad41eb9d065b884.png

该模型做出了预定义数量的预测,并且每个预测都是并行计算的。

CNN主干

假设我们的输入图像,有三个输入通道。CNN backbone由一个(预训练过的)CNN(通常是ResNet)组成,我们用它来生成_C_个具有宽度W和高度H的低维特征(在实践中,我们设置_C_=2048, W=W₀/32和H=H₀/32)。这留给我们的是C个二维特征,由于我们将把这些特征传递给一个transformer,每个特征必须允许编码器将每个特征处理为一个序列的方式重新格式化。这是通过将特征矩阵扁平化为H⋅W向量,然后将每个向量连接起来来实现的。

287bf2488bbde18c1bec3c4f78fd9329.png

扁平化的卷积特征再加上空间位置编码,位置编码既可以学习,也可以预定义。

The Transformer

Transformer几乎与原始的编码器-解码器架构完全相同。不同之处在于,每个解码器层并行解码N个(预定义的数目)目标。该模型还学习了一组N个目标的查询,这些查询是(类似于编码器)学习出来的位置编码。

cfa6d245dd8b156b006edab65640f0a8.png

目标查询

下图描述了N=20个学习出来的目标查询(称为prediction slots)如何聚焦于一张图像的不同区域。

964d2f602f92857bcc5274b7d0774bf1.png

“我们观察到,在不同的操作模式下,每个slot 都会学习特定的区域和框大小。“ —— DETR的作者

理解目标查询的直观方法是想象每个目标查询都是一个人。每个人都可以通过注意力来查看图像的某个区域。一个目标查询总是会问图像中心是什么,另一个总是会问左下角是什么,以此类推。

使用PyTorch实现简单的DETR

import torchimport torch.nn as nnfrom torchvision.models import resnet50class SimpleDETR(nn.Module):“”“Minimal Example of the Detection Transformer model with learned positional embedding”“” def __init__(self, num_classes, hidden_dim, num_heads, num_enc_layers, num_dec_layers): super(SimpleDETR, self).__init__() self.num_classes = num_classes self.hidden_dim = hidden_dim self.num_heads = num_heads self.num_enc_layers = num_enc_layers self.num_dec_layers = num_dec_layers # CNN Backbone self.backbone = nn.Sequential( *list(resnet50(pretrained=True).children())[:-2]) self.conv = nn.Conv2d(2048, hidden_dim, 1) # Transformer self.transformer = nn.Transformer(hidden_dim, num_heads, num_enc_layers, num_dec_layers) # Prediction Heads self.to_classes = nn.Linear(hidden_dim, num_classes+1) self.to_bbox = nn.Linear(hidden_dim, 4) # Positional Encodings self.object_query = nn.Parameter(torch.rand(100, hidden_dim)) self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2) self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2)) def forward(self, X): X = self.backbone(X) h = self.conv(X) H, W = h.shape[-2:] pos_enc = torch.cat([ self.col_embed[:W].unsqueeze(0).repeat(H,1,1), self.row_embed[:H].unsqueeze(1).repeat(1,W,1)], dim=-1).flatten(0,1).unsqueeze(1) h = self.transformer(pos_enc + h.flatten(2).permute(2,0,1), self.object_query.unsqueeze(1)) class_pred = self.to_classes(h) bbox_pred = self.to_bbox(h).sigmoid() return class_pred, bbox_pred

二分匹配损失 (Optional)

让为预测的集合,其中是包括了预测类别(可以是空类别)和包围框的二元组,其中上划线表示框的中心点, 和表示框的宽和高。设y为ground truth集合。假设y和_ŷ_之间的损失为L,每一个yᵢ和_ŷ_ᵢ之间的损失为Lᵢ。由于我们是在集合的层次上工作,损失L必须是排列不变的,这意味着无论我们如何排序预测,我们都将得到相同的损失。因此,我们想找到一个排列,它将预测的索引映射到ground truth目标的索引上。在数学上,我们求解:

86dc5236fcca1b7bd7080630260c36d6.png

计算的过程称为寻找最优的二元匹配。这可以用匈牙利算法找到。但为了找到最优匹配,我们需要实际定义一个损失函数,计算和之间的匹配成本。

回想一下,我们的预测包含一个边界框和一个类。现在让我们假设类预测实际上是一个类集合上的概率分布。那么第_i_个预测的总损失将是类预测产生的损失和边界框预测产生的损失之和。作者在http://arxiv.org/abs/1906.05909中将这种损失定义为边界框损失和类预测概率的差异:

992dad5a7a1dc3075cbcd33f150d10f7.png

其中,是的argmax,是是来自包围框的预测的损失,如果,则表示匹配损失为0。

框损失的计算为预测值与ground truth的L₁损失和的GIOU损失的线性组合。同样,如果你想象两个不相交的框,那么框的错误将不会提供任何有意义的上下文(我们可以从下面的框损失的定义中看到)。

183c84881c17d3e38dced802e8291566.png

其中,λᵢₒᵤ和是超参数。注意,这个和也是面积和距离产生的误差的组合。为什么会这样呢?

可以把上面的等式看作是与预测相关联的总损失,其中面积误差的重要性是λᵢₒᵤ,距离误差的重要性是。现在我们来定义GIOU损失函数。定义如下:

1d3224e47d3956fe8afbefa144918b38.png

由于我们从已知的已知类的数目来预测类,那么类预测就是一个分类问题,因此我们可以使用交叉熵损失来计算类预测误差。我们将损失函数定义为每N个预测损失的总和:

0b00557f7e5daf116fe7264009ad9421.png

为目标检测使用DETR

在这里,你可以学习如何加载预训练的DETR模型,以便使用PyTorch进行目标检测。

加载模型

首先导入需要的模块。

# Import required modulesimport torchfrom torchvision import transforms as T import requests # for loading images from webfrom PIL import Image # for viewing imagesimport matplotlib.pyplot as plt

下面的代码用ResNet50作为CNN骨干从torch hub加载预训练的模型。其他主干请参见DETR github:https://github.com/facebookresearch/detr

detr = torch.hub.load(‘facebookresearch/detr’, ‘detr_resnet50’, pretrained=True)

加载一张图像

要从web加载图像,我们使用requests库:

url = ‘https://www.tempetourism.com/wp-content/uploads/Postino-Downtown-Tempe-2.jpg’ # Sample imageimage = Image.open(requests.get(url, stream=True).raw) plt.imshow(image)plt.show()

设置目标检测的Pipeline

为了将图像输入到模型中,我们需要将PIL图像转换为张量,这是通过使用torchvision的transforms库来完成的。

transform = T.Compose([T.Resize(800), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

上面的变换调整了图像的大小,将PIL图像进行转换,并用均值-标准差对图像进行归一化。其中[0.485,0.456,0.406]为各颜色通道的均值,[0.229,0.224,0.225]为各颜色通道的标准差。我们装载的模型是预先在COCO Dataset上训练的,有91个类,还有一个表示空类(没有目标)的附加类。我们用下面的代码手动定义每个标签

CLASSES = [‘N/A’, ‘Person’, ‘Bicycle’, ‘Car’, ‘Motorcycle’, ‘Airplane’, ‘Bus’, ‘Train’, ‘Truck’, ‘Boat’, ‘Traffic-Light’, ‘Fire-Hydrant’, ‘N/A’, ‘Stop-Sign’, ‘Parking Meter’, ‘Bench’, ‘Bird’, ‘Cat’, ‘Dog’, ‘Horse’, ‘Sheep’, ‘Cow’, ‘Elephant’, ‘Bear’, ‘Zebra’, ‘Giraffe’, ‘N/A’, ‘Backpack’, ‘Umbrella’, ‘N/A’, ‘N/A’, ‘Handbag’, ‘Tie’, ‘Suitcase’, ‘Frisbee’, ‘Skis’, ‘Snowboard’, ‘Sports-Ball’, ‘Kite’, ‘Baseball Bat’, ‘Baseball Glove’, ‘Skateboard’, ‘Surfboard’, ‘Tennis Racket’, ‘Bottle’, ‘N/A’, ‘Wine Glass’, ‘Cup’, ‘Fork’, ‘Knife’, ‘Spoon’, ‘Bowl’, ‘Banana’, ‘Apple’, ‘Sandwich’, ‘Orange’, ‘Broccoli’, ‘Carrot’, ‘Hot-Dog’, ‘Pizza’, ‘Donut’, ‘Cake’, ‘Chair’, ‘Couch’, ‘Potted Plant’, ‘Bed’, ‘N/A’, ‘Dining Table’, ‘N/A’,‘N/A’, ‘Toilet’, ‘N/A’, ‘TV’, ‘Laptop’, ‘Mouse’, ‘Remote’, ‘Keyboard’, ‘Cell-Phone’, ‘Microwave’, ‘Oven’, ‘Toaster’, ‘Sink’, ‘Refrigerator’, ‘N/A’, ‘Book’, ‘Clock’, ‘Vase’, ‘Scissors’, ‘Teddy-Bear’, ‘Hair-Dryer’, ‘Toothbrush’]

如果我们想输出不同颜色的边框,我们可以手动定义我们想要的RGB格式的颜色

COLORS = [ [0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933] ]

格式化输出

我们还需要重新格式化模型的输出。给定一个转换后的图像,模型将输出一个字典,包含100个预测类的概率和100个预测边框。每个包围框的形式为(x, y, w, h),其中(x,y)为包围框的中心(包围框是单位正方形[0,1]×[0,1]), w, h为包围框的宽度和高度。因此,我们需要将边界框输出转换为初始和最终坐标,并重新缩放框以适应图像的实际大小。下面的函数返回边界框端点:

# Get coordinates (x0, y0, x1, y0) from model output (x, y, w, h)def get_box_coords(boxes): x, y, w, h = boxes.unbind(1) x0, y0 = (x - 0.5 * w), (y - 0.5 * h) x1, y1 = (x + 0.5 * w), (y + 0.5 * h) box = [x0, y0, x1, y1] return torch.stack(box, dim=1)

我们还需要缩放了框的大小。下面的函数为我们做了这些:

# Scale box from [0,1]x[0,1] to [0, width]x[0, height]def scale_boxes(output_box, width, height): box_coords = get_box_coords(output_box) scale_tensor = torch.Tensor( [width, height, width, height]).to( torch.cuda.current_device()) return box_coords * scale_tensor

现在我们需要一个函数来封装我们的目标检测pipeline。下面的detect函数为我们完成了这项工作。

# Object Detection Pipelinedef detect(im, model, transform): device = torch.cuda.current_device() width = im.size[0] height = im.size[1] # mean-std normalize the input image (batch-size: 1) img = transform(im).unsqueeze(0) img = img.to(device) # demo model only support by default images with aspect ratio between 0.5 and 2 assert img.shape[-2] 《= 1600 and img.shape[-1] 《= 1600, # propagate through the model outputs = model(img) # keep only predictions with 0.7+ confidence probas = outputs[‘pred_logits’].softmax(-1)[0, :, :-1] keep = probas.max(-1).values 》 0.85 # convert boxes from [0; 1] to image scales bboxes_scaled = scale_boxes(outputs[‘pred_boxes’][0, keep], width, height) return probas[keep], bboxes_scaled

现在,我们需要做的是运行以下程序来获得我们想要的输出:

probs, bboxes = detect(image, detr, transform)

绘制结果

现在我们有了检测到的目标,我们可以使用一个简单的函数来可视化它们。

# Plot Predicted Bounding Boxesdef plot_results(pil_img, prob, boxes,labels=True): plt.figure(figsize=(16,10)) plt.imshow(pil_img) ax = plt.gca() for prob, (x0, y0, x1, y1), color in zip(prob, boxes.tolist(), COLORS * 100): ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, color=color, linewidth=2)) cl = prob.argmax() text = f‘{CLASSES[cl]}: {prob[cl]:0.2f}’ if labels: ax.text(x0, y0, text, fontsize=15, bbox=dict(facecolor=color, alpha=0.75)) plt.axis(‘off’) plt.show()

现在可以可视化结果:

plot_results(image, probs, bboxes, labels=True)

英文原文:https://medium.com/swlh/object-detection-with-transformers-437217a3d62e

编辑:jq

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 函数
    +关注

    关注

    3

    文章

    4406

    浏览量

    66812
  • 代码
    +关注

    关注

    30

    文章

    4940

    浏览量

    73116
  • cnn
    cnn
    +关注

    关注

    3

    文章

    355

    浏览量

    23244
  • pytorch
    +关注

    关注

    2

    文章

    813

    浏览量

    14680

原文标题:实操教程|如何使用Transformer来做物体检测?DETR模型完整指南

文章出处:【微信号:cas-ciomp,微信公众号:中科院长春光机所】欢迎添加关注!文章转载请注明出处。

收藏 人收藏
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

    评论

    相关推荐
    热点推荐

    2025年12月多气体检测仪十大品牌权威榜单

    在工业生产、环境监测与公共安全领域,多气体检测仪扮演着至关重要的“电子哨兵”角色。面对市场上琳琅满目的品牌,如何选择一家技术可靠、服务完善的生产商,是众多企业安全负责人关注的焦点。本文基于2025年
    的头像 发表于 12-02 15:05 75次阅读
    2025年12月多气<b class='flag-5'>体检测</b>仪十大品牌权威榜单

    MTCH9010液体检测芯片技术解析:双模传感与低功耗设计的完美结合

    Microchip Technology MTCH9010液体检测器提供数字和原始数据输出,是一种在不同传感器上检测液体是否存在的灵活方式。该检测器支持各种形状和尺寸的传感器。合适的MTCH9010液体探测器允许系统运行电容式或
    的头像 发表于 09-28 11:22 485次阅读
    MTCH9010液<b class='flag-5'>体检测</b>芯片技术解析:双模传感与低功耗设计的完美结合

    自动驾驶汽车如何准确识别小物体

    地面上常见的小坑、碎石、塑料袋、纸箱角落、掉落的车载零件,甚至是一只小鸟或小猫,都可能对车辆的行驶安全与乘坐舒适性造成影响。 小物体检测听起来“小事儿一桩”,但实际难度会高很多。小物体具有目标体积小、与背景对
    的头像 发表于 08-22 09:11 451次阅读
    自动驾驶汽车如何准确识别小<b class='flag-5'>物体</b>?

    设计信息赋能 - AI 让半导体检测与诊断更给力

    人工智能(AI)的进步正为包括半导体制造在内的多个行业带来革命性变革。利用AI开展半导体检测与诊断工作,已成为一种可能改变行业格局的策略,有助于提高生产效率、识别以往无法察觉的缺陷,并缩短产品上市
    的头像 发表于 08-19 13:45 884次阅读
    设计信息赋能 - AI 让半导<b class='flag-5'>体检测</b>与诊断更给力

    如何在 M55M1 系列微控制器上以低功耗模式使用运动检测功能?

    如何在 M55M1 系列微控制器上以低功耗模式使用运动检测功能。根据物体检测结果,系统将动态启用或禁用运动检测块,以实现最佳性能和能效。
    发表于 08-19 06:56

    在树莓派5上使用YOLO进行物体和动物识别-入门指南

    大家好,接下来会为大家开一个树莓派5和YOLO的专题。内容包括四个部分:在树莓派5上使用YOLO进行物体和动物识别-入门指南在树莓派5上开启YOLO人体姿态估计识别之旅YOLO物体检测在树莓派
    的头像 发表于 07-17 17:16 1555次阅读
    在树莓派5上使用YOLO进行<b class='flag-5'>物体</b>和动物识别-入门指南

    【嘉楠堪智K230开发板试用体验】01 Studio K230开发板Test2——手掌,手势检测,字符检测

    理解: 它不仅能检测图像内容,更能进行精确的识别和定位。例如: 人脸检测与定位(位置和尺寸)。 物体检测、识别(分类)、定位(位置和尺寸)。 高性能: 其计算能力显著提升,官方数据显示其性能是前代
    发表于 07-10 09:45

    【HarmonyOS 5】VisionKit人脸活体检测详解

    【HarmonyOS 5】VisionKit人脸活体检测详解 ##鸿蒙开发能力 ##HarmonyOS SDK应用服务##鸿蒙金融类应用 (金融理财# 一、VisionKit人脸活体检测
    的头像 发表于 06-21 11:52 629次阅读
    【HarmonyOS 5】VisionKit人脸活<b class='flag-5'>体检测</b>详解

    云南恩田有毒有害气体检测系统# 的安全#隧道施工#有毒有害气体检测

    体检测
    恩田智能设备
    发布于 :2025年05月15日 15:06:40

    汉威科技推出新款便携式气体检测

    便携式气体检测仪是石油、化工、燃气、环境监测、应急救援等领域日常巡检、有限空间作业的必备工具。
    的头像 发表于 04-25 17:30 1036次阅读

    便携式+多功能+可定制!工厂直发,重新定义气体检测效率

    体检测
    奕帆科技
    发布于 :2025年04月25日 15:05:25

    体检漏仪如何操作?注意事项有哪些?

    体检漏仪 ,从名称上就能看出,这是一种用于检测气体泄漏情况的专业设备,在工业、环保等领域中有着广泛的应用潜力。那么,气体检漏仪如何操作?注意事项有哪些?为方便大家了解,下面就让小编
    发表于 03-12 15:08

    要设计CH气体检测设备应用的激光源波长为3370nm,请问DMD微镜的反射波长是多少?

    请问:我现在要设计CH气体检测设备应用的激光源波长为3370nm,请问贵司的DMD微镜的反射波长是多少?我们的要求能满足吗?
    发表于 02-24 08:08

    原来ESP32竟可《一“芯”两用》既做人体检测传感器也Wi-Fi数据传输

    今天将介绍ESP32如何"一芯两用",既做人体检测传感器也Wi-Fi数据传输模块;对于使用ESP32Wi-Fi数据通讯,相信玩ESP32的基本上都知道怎么玩了,但是
    的头像 发表于 12-18 18:12 5329次阅读
    原来ESP32竟可《一“芯”两用》既做人<b class='flag-5'>体检测</b>传感器也<b class='flag-5'>做</b>Wi-Fi数据传输

    研华AIMB-523工业主板:赋能半导体检测设备,性能提升超20%

    研华科技近日推出的AIMB-523工业主板,专为半导体检测设备设计,适配AMD Ryzen 7000系列处理器与B650芯片组,其卓越性能较同类设备提升超过20%,充分满足了半导体检测设备对超高
    的头像 发表于 12-11 11:32 1347次阅读