# ViT结构详解（附pytorch代码）

ViT把tranformer用在了图像上, transformer的文章: Attention is all you need

ViT的结构如下：

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary


image输入要是224x224x3, 所以先reshape一下


transform = Compose([Resize((224, 224)), ToTensor()])
x = transform(img)
x = x.unsqueeze(0)
x.shape


patch_size = 16
pathes = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)


rearrange里面的(h s1)表示hxs1,而s1是patch_size=16, 那通过hx16=224可以算出height里面包含了h个patch_size，

class PatchEmbedding(nn.Module):
def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
self.patch_size = patch_size
super().__init__()
self.projection = nn.Sequential(

Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
nn.Linear(patch_size * patch_size * in_channels, emb_size)
)

def forward(self, x: Tensor) -> Tensor:
x = self.projection(x)
return x
PatchEmbedding()(x).shape


torch.Size([1, 196, 768])

##### CLS token

cls token就是每个sequence开头的一个数字。

class PatchEmbedding(nn.Module):
def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
self.patch_size = patch_size
super().__init__()
self.proj = nn.Sequential(

nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
Rearrange('b e (h) (w) -> b (h w) e'),
)

self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))

def forward(self, x: Tensor) -> Tensor:
b, _, _, _ = x.shape
x = self.proj(x)
cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)

x = torch.cat([cls_tokens, x], dim=1)
return x
PatchEmbedding()(x).shape


##### Position embedding

class PatchEmbedding(nn.Module):
def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
self.patch_size = patch_size
super().__init__()
self.projection = nn.Sequential(

nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
Rearrange('b e (h) (w) -> b (h w) e'),
)
self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))

self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))

def forward(self, x: Tensor) -> Tensor:
b, _, _, _ = x.shape
x = self.projection(x)
cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)

x = torch.cat([cls_tokens, x], dim=1)

x += self.positions
return x

PatchEmbedding()(x).shape


##### Attention

Attention有3个输入：query, key. value

class MultiHeadAttention(nn.Module):
def __init__(self, emb_size: int = 512, num_heads: int = 8, dropout: float = 0):
super().__init__()
self.emb_size = emb_size
self.keys = nn.Linear(emb_size, emb_size)
self.queries = nn.Linear(emb_size, emb_size)
self.values = nn.Linear(emb_size, emb_size)
self.att_drop = nn.Dropout(dropout)
self.projection = nn.Linear(emb_size, emb_size)

def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:

queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)

energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
fill_value = torch.finfo(torch.float32).min

scaling = self.emb_size ** (1/2)
att = F.softmax(energy, dim=-1) / scaling
att = self.att_drop(att)

out = torch.einsum('bhal, bhlv -> bhav ', att, values)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.projection(out)
return out


query, key, value的shape通常是相同的，这里只有一个input x。

queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.n_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.n_heads)
values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.n_heads)


‘bhqd, bhkd -> bhqk’这个看成矩阵的shape，(b,h,q,d)的矩阵 ✖ (b,h,k.d)的矩阵
qxd ✖ (kxd 的转置) -> qxk
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
att = F.softmax(energy, dim=-1) / scaling
out = torch.einsum('bhal, bhlv -> bhav ', att, values)


class MultiHeadAttention(nn.Module):
def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
super().__init__()
self.emb_size = emb_size

self.qkv = nn.Linear(emb_size, emb_size * 3)
self.att_drop = nn.Dropout(dropout)
self.projection = nn.Linear(emb_size, emb_size)

def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:

qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
queries, keys, values = qkv[0], qkv[1], qkv[2]

energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
fill_value = torch.finfo(torch.float32).min

scaling = self.emb_size ** (1/2)
att = F.softmax(energy, dim=-1) / scaling
att = self.att_drop(att)

out = torch.einsum('bhal, bhlv -> bhav ', att, values)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.projection(out)
return out

patches_embedded = PatchEmbedding()(x)

##### Residuals

class ResidualAdd(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn

def forward(self, x, **kwargs):
res = x
x = self.fn(x, **kwargs)
x += res
return x


MLP是多层感知器，结构如下


class FeedForwardBlock(nn.Sequential):
def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
super().__init__(
nn.Linear(emb_size, expansion * emb_size),
nn.GELU(),
nn.Dropout(drop_p),
nn.Linear(expansion * emb_size, emb_size),
)


class TransformerEncoderBlock(nn.Sequential):
def __init__(self,
emb_size: int = 768,
drop_p: float = 0.,
forward_expansion: int = 4,
forward_drop_p: float = 0.,
** kwargs):
super().__init__(
nn.LayerNorm(emb_size),
nn.Dropout(drop_p)
)),
nn.LayerNorm(emb_size),
FeedForwardBlock(
emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
nn.Dropout(drop_p)
)
))


patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape


Encoder是L个（图中的Lx）TransformerEncoderBlock,

class TransformerEncoder(nn.Sequential):
def __init__(self, depth: int = 12, **kwargs):
super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])


class ClassificationHead(nn.Sequential):
def __init__(self, emb_size: int = 768, n_classes: int = 1000):
super().__init__(
Reduce('b n e -> b e', reduction='mean'),
nn.LayerNorm(emb_size),
nn.Linear(emb_size, n_classes))

##### ViT

class ViT(nn.Sequential):
def __init__(self,
in_channels: int = 3,
patch_size: int = 16,
emb_size: int = 768,
img_size: int = 224,
depth: int = 12,
n_classes: int = 1000,
**kwargs):
super().__init__(
PatchEmbedding(in_channels, patch_size, emb_size, img_size),
TransformerEncoder(depth, emb_size=emb_size, **kwargs),
)


Original: https://blog.csdn.net/level_code/article/details/126173408
Author: 蓝羽飞鸟
Title: ViT结构详解（附pytorch代码）

(0)

### 大家都在看

• #### pytest入门看着一篇就够了

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

Python 2023年1月18日
075
• #### No module named ‘Torch’解决办法

作者：非妃是公主专栏：《python学习》个性签：顺境不惰，逆境不馁，以心制境，万事可成。——曾国藩 转载请标明，原文链接：https://blog.csdn.net/myf_66…

Python 2023年8月1日
037
• #### 关于gym新版本0.23.0版本的一些问题以及Box2D的安装

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

Python 2023年1月20日
073

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

Python 2023年1月4日
073
• #### c++ trivial, standard layout和POD类型解析

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

Python 2023年1月29日
086
• #### YOLOv7 Tensorrt Python部署教程

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

Python 2022年8月19日
0273
• #### 在不同电脑间同步pycharm的配置

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

Python 2023年2月3日
0111
• #### python神奇功能_python – PyOpenGL如何使用glGenBuffers实现它的神奇功能？

以下是使用glGenBuffers创建一个或多个缓冲区的一些示例.如果必须使用像glGenBuffers(n,buffers)这样的函数调用,可以直接使用ctypes,也可以使用P…

Python 2023年9月25日
015
• #### CUDA error: device-side assert triggered CUDA kernel errors might be asynchronously reported at some

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

Python 2022年12月24日
0172

Python 2023年6月3日
032
• #### 2022需求最大的8种编程语言排名

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

Python 2023年1月19日
097
• #### Django 学习

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

Python 2022年12月28日
0142
• #### RV1126笔记十七：吸烟行为检测及部署＜五＞

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

Python 2023年1月24日
063
• #### Python 数据集：乳腺癌数据集（from sklearn.datasets import load_breast_cancer）。

注入产生的原理: 数据库设置为GBK编码: 宽字节注入源于程序员设置MySQL连接时错误配置为:set character_set_client=gbk,这样配置会引发编码转换从而…

Python 2022年12月25日
0103