# ViT结构详解（附pytorch代码）

ViT把tranformer用在了图像上, transformer的文章: Attention is all you need

ViT的结构如下：

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary


image输入要是224x224x3, 所以先reshape一下


transform = Compose([Resize((224, 224)), ToTensor()])
x = transform(img)
x = x.unsqueeze(0)
x.shape


patch_size = 16
pathes = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)


rearrange里面的(h s1)表示hxs1,而s1是patch_size=16, 那通过hx16=224可以算出height里面包含了h个patch_size，

class PatchEmbedding(nn.Module):
def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
self.patch_size = patch_size
super().__init__()
self.projection = nn.Sequential(

Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
nn.Linear(patch_size * patch_size * in_channels, emb_size)
)

def forward(self, x: Tensor) -> Tensor:
x = self.projection(x)
return x
PatchEmbedding()(x).shape


torch.Size([1, 196, 768])

##### CLS token

cls token就是每个sequence开头的一个数字。

class PatchEmbedding(nn.Module):
def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
self.patch_size = patch_size
super().__init__()
self.proj = nn.Sequential(

nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
Rearrange('b e (h) (w) -> b (h w) e'),
)

self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))

def forward(self, x: Tensor) -> Tensor:
b, _, _, _ = x.shape
x = self.proj(x)
cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)

x = torch.cat([cls_tokens, x], dim=1)
return x
PatchEmbedding()(x).shape


##### Position embedding

class PatchEmbedding(nn.Module):
def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
self.patch_size = patch_size
super().__init__()
self.projection = nn.Sequential(

nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
Rearrange('b e (h) (w) -> b (h w) e'),
)
self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))

self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))

def forward(self, x: Tensor) -> Tensor:
b, _, _, _ = x.shape
x = self.projection(x)
cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)

x = torch.cat([cls_tokens, x], dim=1)

x += self.positions
return x

PatchEmbedding()(x).shape


##### Attention

Attention有3个输入：query, key. value

class MultiHeadAttention(nn.Module):
def __init__(self, emb_size: int = 512, num_heads: int = 8, dropout: float = 0):
super().__init__()
self.emb_size = emb_size
self.keys = nn.Linear(emb_size, emb_size)
self.queries = nn.Linear(emb_size, emb_size)
self.values = nn.Linear(emb_size, emb_size)
self.att_drop = nn.Dropout(dropout)
self.projection = nn.Linear(emb_size, emb_size)

def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:

queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)

energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
fill_value = torch.finfo(torch.float32).min

scaling = self.emb_size ** (1/2)
att = F.softmax(energy, dim=-1) / scaling
att = self.att_drop(att)

out = torch.einsum('bhal, bhlv -> bhav ', att, values)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.projection(out)
return out


query, key, value的shape通常是相同的，这里只有一个input x。

queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.n_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.n_heads)
values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.n_heads)


‘bhqd, bhkd -> bhqk’这个看成矩阵的shape，(b,h,q,d)的矩阵 ✖ (b,h,k.d)的矩阵
qxd ✖ (kxd 的转置) -> qxk
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
att = F.softmax(energy, dim=-1) / scaling
out = torch.einsum('bhal, bhlv -> bhav ', att, values)


class MultiHeadAttention(nn.Module):
def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
super().__init__()
self.emb_size = emb_size

self.qkv = nn.Linear(emb_size, emb_size * 3)
self.att_drop = nn.Dropout(dropout)
self.projection = nn.Linear(emb_size, emb_size)

def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:

qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
queries, keys, values = qkv[0], qkv[1], qkv[2]

energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
fill_value = torch.finfo(torch.float32).min

scaling = self.emb_size ** (1/2)
att = F.softmax(energy, dim=-1) / scaling
att = self.att_drop(att)

out = torch.einsum('bhal, bhlv -> bhav ', att, values)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.projection(out)
return out

patches_embedded = PatchEmbedding()(x)

##### Residuals

class ResidualAdd(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn

def forward(self, x, **kwargs):
res = x
x = self.fn(x, **kwargs)
x += res
return x


MLP是多层感知器，结构如下


class FeedForwardBlock(nn.Sequential):
def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
super().__init__(
nn.Linear(emb_size, expansion * emb_size),
nn.GELU(),
nn.Dropout(drop_p),
nn.Linear(expansion * emb_size, emb_size),
)


class TransformerEncoderBlock(nn.Sequential):
def __init__(self,
emb_size: int = 768,
drop_p: float = 0.,
forward_expansion: int = 4,
forward_drop_p: float = 0.,
** kwargs):
super().__init__(
nn.LayerNorm(emb_size),
nn.Dropout(drop_p)
)),
nn.LayerNorm(emb_size),
FeedForwardBlock(
emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
nn.Dropout(drop_p)
)
))


patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape


Encoder是L个（图中的Lx）TransformerEncoderBlock,

class TransformerEncoder(nn.Sequential):
def __init__(self, depth: int = 12, **kwargs):
super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])


class ClassificationHead(nn.Sequential):
def __init__(self, emb_size: int = 768, n_classes: int = 1000):
super().__init__(
Reduce('b n e -> b e', reduction='mean'),
nn.LayerNorm(emb_size),
nn.Linear(emb_size, n_classes))

##### ViT

class ViT(nn.Sequential):
def __init__(self,
in_channels: int = 3,
patch_size: int = 16,
emb_size: int = 768,
img_size: int = 224,
depth: int = 12,
n_classes: int = 1000,
**kwargs):
super().__init__(
PatchEmbedding(in_channels, patch_size, emb_size, img_size),
TransformerEncoder(depth, emb_size=emb_size, **kwargs),
)


Original: https://blog.csdn.net/level_code/article/details/126173408
Author: 蓝羽飞鸟
Title: ViT结构详解（附pytorch代码）

(0)

### 大家都在看

• #### 新零售场景（图像检索、识别，分类）sku级别数据集

1.AiProducts-Challenge（阿里2020） 下载地址：2020-AiProducts-Challenge-dataset数据介绍：Large-scale Prod…

人工智能 2023年7月9日
0152
• #### 数据分析之pandas（一）

一、pandas简介 pandas是python的一种数据包，是基于numpy的一个工具，里面有很多丰富的库和复杂的函数，是专门用作数据处理和数据分析的（Tips:数据分析一般都会…

人工智能 2023年6月11日
0149
• #### openCV （c++）

文章目录 * – openCV 配置 （C++) – chapter 1 读取图像、视频和摄像头 – chapter 2 基本功能 &#8211…

人工智能 2023年7月19日
092
• #### 计算机视觉方面的三大顶级会议：ICCV,CVPR,ECCV（统称ICE）

ICCV/CVPR/ECCV发论文的难度， 相当于顶级SCI期刊 和目前国内评价学术水平是以在学术期刊发表SCI论文的情况不一样，大家要注意： 在计算机视觉方向，会议论文> …

人工智能 2023年7月30日
0101
• #### YOLOV2-理论笔记

1.加入BN层 卷积过后添加 舍弃Dropout 更加容易收敛 基本上都是卷积网络标配好处就是 ①加快收敛 ②提高了2%map2.训练时候不同v1训练使用224 _224 测试使用…

人工智能 2023年5月31日
0103
• #### CUDA入门技术路线及基础知识

最近工作主要集中在目标检测算法部署方面，在树莓派4B和NVIDIA GPU平台上做了一些内容，个人觉得GPU多核计算对于深度学习的加持作用意义重大，而NVIDIA出品的软硬件是GP…

人工智能 2023年7月14日
0111
• #### 【机器学习】数据增强(Data Augmentation)

文章目录 一、引言 – 背景 二、为什么需要数据增强？ 三、什么是数据增强？ * 定义 分类 四、有监督的数据增强 * 1. 单样本数据增强 – （1）几何…

人工智能 2023年7月25日
0116
• #### 使用Python对xlsx，csv, txt格式文件进行读、写并绘图

0. 背景 最近需要用到python通过读取，写入Excel数据，并画一些图。虽然以前学过一些，但是都忘得差不多了，故翻出以前学习的资料，整理在此，常用常新，也方便自己以后复习。 …

人工智能 2023年7月7日
098
• #### 双目相机基本原理

双目相机基本原理 * – + * 双目图像 * 视差 * 深度 * 深度与视差之间的关系 * 极平面 * 极线 * 极线约束 * 单应性矩阵 双目图像 如图所示，双目图…

人工智能 2023年6月19日
0134
• #### GRU(门控循环单元)，易懂。

一、什么是GRU？ GRU（Gate Recurrent Unit）是循环神经网络（RNN）的一种，可以解决RNN中不能长期记忆和反向传播中的梯度等问题，与LSTM的作用类似，不过…

人工智能 2023年5月27日
0143
• #### 有监督学习：回归（进阶——跳过简单的线性回归）

总所周知，回归是机器学习的入门，而对于这篇文章，我也是下了很大的功夫。对于最基础的线性回归（也就是），这里我就不再过多叙述了，并且在该文章里面涉及到回归基础的东西我也不再过多啰嗦，…

人工智能 2023年6月18日
0148
• #### 论文解读 | NeurIPS 2022：基于因果推理的多轮药物推荐模型

点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入！ 孙宏达： 中国人民大学高瓴人工智能学院直博三年级，研究兴趣包括药物发现、机器学习、自然语言处理等。 报告简介 人工智…

人工智能 2023年6月29日
0237
• #### 中文自动文本摘要生成指标计算，Rouge/Bleu/BertScore/QA代码实现

本部分讲述下如何计算生成摘要与参考摘要的指标，指标方面分为两类，一类基于n-grams计算，如Rouge-1，Rouge-2，Rouge-L，BLEU，主要衡量摘要的句法的连贯性，…

人工智能 2023年5月30日
0143
• #### [人工智能-深度学习-67]：目标检测 – 常见目标检测算法大汇总

### 回答1： 算法_7 –4：深度优先搜索 1. 从起点开始，将其标记为已访问。 2. 对于起点的每个未访问的邻居，递归地执行步骤1 –2。 3. 重复…

人工智能 2023年7月9日
0125
• #### Explaining Knowledge Graph Embedding via Latent Rule Learning

研究问题 将知识图谱嵌入由一步预测转变为多步推理预测，通过向量空间中的相似性解释每一步，提高其可解释性的同时保证预测效果不下降 背景动机 知识图谱嵌入由于是一步预测出最终结果，缺乏…

人工智能 2023年6月1日
0119
• #### 从0到1构建一个基于知识图谱的智能问答系统

目录 一、前言 二、知识图谱 * 2.1 数据入库 – 2.1.1 Nebula Graph搭建 2.1.2数据导入 三、后端 * 3.1 搭建Flask框架，处理ht…

人工智能 2023年7月27日
0117