博主想对神经网络模型的参数写入 bin
文件,方便在后续创建IP的过程中读取数据进行验证,记录 python
读取 pytorch
的模块参数并进行bin文件写入和读取操作。本文以3×3卷积为例。
本文涉及的模块
pytorch
:神经网络框架
简单示例:
import struct
SAVE_DIR = "./conv3x3_pool_relu_outputs"
import struct
val = -1
a = struct.pack('i', val)
print(a)
file = os.path.join(SAVE_DIR, "wt.bin")
with open(file, "ab+") as fw:
fw.write(a)
with open(file, "rb") as fr:
b = struct.unpack('i', fr.read(4))
print(b[0])
print(b[0] == val)
完整保存参数代码:
"""
for generate conv3x3_pool_relu and data for test.
"""
import os
import torch
import torch.nn as nn
Hin = 6
Win = 12
CHin = 16
CHout = 16
step = 0.1
G_SIZE = 8
SAVE_DIR = "./conv3x3_pool_relu_outputs"
seed = 2021
torch.random.manual_seed(seed)
def format_num(x):
"""
>0 -> 1, -1. switch func.
"""
return (torch.randn_like(x) > 0).to(torch.float32) * 2 - 1
def save_conv3x3_weight(weight, save_dir="./outputs", filename="conv3x3", size=8):
"""
写入文件,
"""
shape = weight.shape
print("save {} weights(bin format) ".format(filename), shape, end=" ---------wait---------- ")
assert len(shape) == 4 and shape[0] % size == 0 and shape[1] % size == 0, "input error"
if not ".dat" in filename:
filename = filename + "_weight.bin"
if type(weight) in [torch.nn.Parameter, torch.Tensor]:
weight = weight.cpu().detach().numpy()
filepath = os.path.join(save_dir, filename)
with open(filepath, "wb+") as fw:
for i in range(0, shape[0], size):
for j in range(0, shape[1], size):
for co in range(i, i + size):
for ci in range(j, j + size):
for h in range(3):
for w in range(3):
fw.write(struct.pack('i', int(weight[co][ci][h][w])))
print("save conv3x3_weight done. save weights to {}".format(filepath))
return filepath
class Conv3x3PoolRelu(nn.Module):
def __init__(self, in_channels=16, out_channels=32, save=False, out_dir="./outputs", save_size=8):
super().__init__()
assert in_channels % G_SIZE == 0 and out_channels % G_SIZE == 0, "input error!!"
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.act = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2, 0)
self.init_weights()
self.save_dir = out_dir
def forward(self, x):
for name, module in self.named_children():
print(name, module)
if type(module) in [nn.Conv2d]:
save_conv3x3_weight(module.weight, self.save_dir)
x = module(x)
return x
def init_weights(self):
for idx, m in self.named_modules():
if type(m) in [nn.Conv2d]:
weight, bias = m.weight, m.bias
m_weight, m_bias = format_num(weight), format_num(bias)
m.weight, m.bias = nn.Parameter(m_weight, requires_grad=False), nn.Parameter(m_bias,
requires_grad=False)
if __name__ == '__main__':
model = Conv3x3PoolRelu(8, 8, out_dir=SAVE_DIR)
x = format_num(torch.randn(1, 8, 4, 4))
y = model(x)
print(y.shape, y)
w = torch.empty(8, 8, 3, 3)
con, cin, kh, kw = w.shape
with open("./conv3x3_pool_relu_outputs/conv3x3_weight.bin", "rb") as fr:
for co in range(con):
for ci in range(cin):
for i in range(kh):
for j in range(kw):
data = struct.unpack("i", fr.read(4))
w[co][ci][i][j] = data[0]
print(w)
详细步骤如下:
Note:写入文件的格式和数据类型之间的关系如下:
格式C 类型Python 类型标准大小
填充字节无
长度为 1 的字节串1
整数1
整数1
bool1
整数2
整数2
整数4
整数4
整数4
整数4
整数8
整数8
(6)float2
float4
float8
字节串
字节串
写入 bin
文件主要是将二进制数据写入,如果一开始就是二进制数据,那么就不需要进行 struct
的 pack
操作。另外,对于python的数据类型,写入文件的字节顺序、大小与对齐方式可以设置,详细见 官方文档[2]。
Original: https://blog.csdn.net/qq_33808481/article/details/121919344
Author: 浅尝这只
Title: 使用python对bin文件进行操作
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/710494/
转载文章受原作者版权保护。转载请注明原作者出处!