# ViT结构详解（附pytorch代码）

ViT把tranformer用在了图像上, transformer的文章: Attention is all you need

ViT的结构如下：

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary


image输入要是224x224x3, 所以先reshape一下


transform = Compose([Resize((224, 224)), ToTensor()])
x = transform(img)
x = x.unsqueeze(0)
x.shape


patch_size = 16
pathes = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)


rearrange里面的(h s1)表示hxs1,而s1是patch_size=16, 那通过hx16=224可以算出height里面包含了h个patch_size，

class PatchEmbedding(nn.Module):
def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
self.patch_size = patch_size
super().__init__()
self.projection = nn.Sequential(

Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
nn.Linear(patch_size * patch_size * in_channels, emb_size)
)

def forward(self, x: Tensor) -> Tensor:
x = self.projection(x)
return x
PatchEmbedding()(x).shape


torch.Size([1, 196, 768])

##### CLS token

cls token就是每个sequence开头的一个数字。

class PatchEmbedding(nn.Module):
def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
self.patch_size = patch_size
super().__init__()
self.proj = nn.Sequential(

nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
Rearrange('b e (h) (w) -> b (h w) e'),
)

self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))

def forward(self, x: Tensor) -> Tensor:
b, _, _, _ = x.shape
x = self.proj(x)
cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)

x = torch.cat([cls_tokens, x], dim=1)
return x
PatchEmbedding()(x).shape


##### Position embedding

class PatchEmbedding(nn.Module):
def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
self.patch_size = patch_size
super().__init__()
self.projection = nn.Sequential(

nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
Rearrange('b e (h) (w) -> b (h w) e'),
)
self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))

self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))

def forward(self, x: Tensor) -> Tensor:
b, _, _, _ = x.shape
x = self.projection(x)
cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)

x = torch.cat([cls_tokens, x], dim=1)

x += self.positions
return x

PatchEmbedding()(x).shape


##### Attention

Attention有3个输入：query, key. value

class MultiHeadAttention(nn.Module):
def __init__(self, emb_size: int = 512, num_heads: int = 8, dropout: float = 0):
super().__init__()
self.emb_size = emb_size
self.keys = nn.Linear(emb_size, emb_size)
self.queries = nn.Linear(emb_size, emb_size)
self.values = nn.Linear(emb_size, emb_size)
self.att_drop = nn.Dropout(dropout)
self.projection = nn.Linear(emb_size, emb_size)

def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:

queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)

energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
fill_value = torch.finfo(torch.float32).min

scaling = self.emb_size ** (1/2)
att = F.softmax(energy, dim=-1) / scaling
att = self.att_drop(att)

out = torch.einsum('bhal, bhlv -> bhav ', att, values)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.projection(out)
return out


query, key, value的shape通常是相同的，这里只有一个input x。

queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.n_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.n_heads)
values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.n_heads)


‘bhqd, bhkd -> bhqk’这个看成矩阵的shape，(b,h,q,d)的矩阵 ✖ (b,h,k.d)的矩阵
qxd ✖ (kxd 的转置) -> qxk
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
att = F.softmax(energy, dim=-1) / scaling
out = torch.einsum('bhal, bhlv -> bhav ', att, values)


class MultiHeadAttention(nn.Module):
def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
super().__init__()
self.emb_size = emb_size

self.qkv = nn.Linear(emb_size, emb_size * 3)
self.att_drop = nn.Dropout(dropout)
self.projection = nn.Linear(emb_size, emb_size)

def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:

qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
queries, keys, values = qkv[0], qkv[1], qkv[2]

energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
fill_value = torch.finfo(torch.float32).min

scaling = self.emb_size ** (1/2)
att = F.softmax(energy, dim=-1) / scaling
att = self.att_drop(att)

out = torch.einsum('bhal, bhlv -> bhav ', att, values)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.projection(out)
return out

patches_embedded = PatchEmbedding()(x)

##### Residuals

class ResidualAdd(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn

def forward(self, x, **kwargs):
res = x
x = self.fn(x, **kwargs)
x += res
return x


MLP是多层感知器，结构如下


class FeedForwardBlock(nn.Sequential):
def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
super().__init__(
nn.Linear(emb_size, expansion * emb_size),
nn.GELU(),
nn.Dropout(drop_p),
nn.Linear(expansion * emb_size, emb_size),
)


class TransformerEncoderBlock(nn.Sequential):
def __init__(self,
emb_size: int = 768,
drop_p: float = 0.,
forward_expansion: int = 4,
forward_drop_p: float = 0.,
** kwargs):
super().__init__(
nn.LayerNorm(emb_size),
nn.Dropout(drop_p)
)),
nn.LayerNorm(emb_size),
FeedForwardBlock(
emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
nn.Dropout(drop_p)
)
))


patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape


Encoder是L个（图中的Lx）TransformerEncoderBlock,

class TransformerEncoder(nn.Sequential):
def __init__(self, depth: int = 12, **kwargs):
super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])


class ClassificationHead(nn.Sequential):
def __init__(self, emb_size: int = 768, n_classes: int = 1000):
super().__init__(
Reduce('b n e -> b e', reduction='mean'),
nn.LayerNorm(emb_size),
nn.Linear(emb_size, n_classes))

##### ViT

class ViT(nn.Sequential):
def __init__(self,
in_channels: int = 3,
patch_size: int = 16,
emb_size: int = 768,
img_size: int = 224,
depth: int = 12,
n_classes: int = 1000,
**kwargs):
super().__init__(
PatchEmbedding(in_channels, patch_size, emb_size, img_size),
TransformerEncoder(depth, emb_size=emb_size, **kwargs),
)


Original: https://blog.csdn.net/level_code/article/details/126173408
Author: 蓝羽飞鸟
Title: ViT结构详解（附pytorch代码）

(0)

### 大家都在看

• #### 《数字图像处理》dlib人脸检测获取关键点，delaunay三角划分，实现人脸的几何变换warpping,接着实现两幅人脸图像之间的渐变合成morphing

这学期在上《数字图像处理》这门课程，老师布置了几个大作业，自己和同学一起讨论完成后，感觉还挺有意思的，就想着把这个作业整理一下 ： 目录 1.实验任务和要求 2.实验原理 3.实验…

人工智能 2023年7月18日
027
• #### 数据挖掘框架（结构化数据）

数据量数据缺失情况描述性统计特征理解特征分布周期性分析对比分析相关性分析训练集和测试集的分布一致性 缺失值处理异常值处理内存优化数据增强欠采样/过采样 1.ID特征处理需要考虑训练…

人工智能 2023年7月18日
023
• #### 机器学习-有监督学习-分类算法：逻辑回归/Logistic回归（二分类模型）【值域符合二项分布律 ==似然函数最大化==＞ 交叉熵/对数损失函数】、Softmax回归（多分类模型）【交叉熵损失函数】

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

人工智能 2022年11月19日
0130
• #### QQ机器人制作教程，超详细

目录 前期准备 * 1、机器人框架的下载和配置 2、python的配置和安装 具体实现 * 1、发送信息 2、获取群成员列表 3、接收上报的事件 4、实现简单的自动回复 5、解决重…

人工智能 2023年7月3日
023
• #### 行为稀疏场景下的图模型实践

本系列将系统介绍召回技术在内容推荐的实践与总结。 第一篇： 第二篇： 第三篇： 第四篇：‍‍ 第五篇： 背景 在视频推荐系统中，有相当一部分用户的行为较为稀疏，基于行为序列建模的方…

人工智能 2023年6月10日
036
• #### 8月份，我靠这一份PDF文档面试BAT，收到了5个offer

这份PDF面经知识点包括了五个大部分，26个知识点： J ava部分：Java基础，集合，并发，多线程，JVM，设计模式 数据结构算法：Java算法，数据结构 开源框架部分：Spr…

人工智能 2023年6月28日
027
• #### [Python从零到壹] 三十四.OpenCV入门详解——显示读取修改及保存图像

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

人工智能 2022年11月24日
0170
• #### 性能测试场景设计之普通性能场景设计

常见的六种设计方法之一 普通性能场景设计 使用jmeter 线程组，模拟用户并发数，线程组Jmeter本身是没有对线程数做限制,但是对于机器本身性能有要求，受电脑cpu的主频限制，…

人工智能 2023年6月29日
031
• #### Linux搭建深度学习平台tensorflow，并使用jupyter notebook远程访问服务器。

文章目录 前言 一、Tensorflow 二、screen命令 运行jupyter notebook 前言 本文介绍如何搭建深度学习平台，并在jupyter notebook上运行…

人工智能 2023年5月24日
035
• #### 如何使用Python进行深度学习任务

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

人工智能 2023年3月29日
045
• #### tensorflow利用for循环进行训练遇到的内存爆炸问题(OOM)

最近在用tensorflow学习模型的知识蒸馏，自己基于cifar10数据集训练得到的teacher模型，在对3种不同参数量的student模型使用相同的alpha和tempera…

人工智能 2023年5月23日
035
• #### 从智能对话系统导论，到如何设计第一个对话机器人

从智能对话系统导论，到如何设计第一个对话机器人 一、智能对话系统导论 * 1、生活中的 Conversational AI 2、一种新的人机交互方式 3、一些关于 Conversa…

人工智能 2023年5月30日
036
• #### 使用python开发二维码识别功能、Docker镜像安装opencv-contrib-python、

使用python开发二维码识别功能、Docker镜像安装opencv-contrib-python、 背景 开发二维码识别功能，使用到开源三方库opencv-contrib-pyt…

人工智能 2023年7月19日
019
• #### 解决 cv_bridge 与 opencv4 版本冲突问题

解决了在 ROS melodic / noetic 下 cv_bridge 与 opencv4 版本冲突导致的 opencv 操作 导致 Segmentation fault (c…

人工智能 2023年6月19日
040
• #### Python分类实例之猫狗大战

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

人工智能 2022年11月27日
0130
• #### 人脸识别—-face_recognition安装与应用（附代码）

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

人工智能 2022年9月18日
0206