cnn+lstm+attention对时序数据进行预测

1、摘要

本文主要讲解:bilstm-cnn-attention对时序数据进行预测
主要思路:

  1. 对时序数据进行分块,生成三维时序数据块
  2. 建立模型,卷积层-bilstm层-attention按顺序建立,attention层可放中间也可放前面,效果各不相同
  3. 训练模型,使用训练好的模型进行预测
  4. 调参优化,保存模型

2、数据介绍

需要完整代码和数据介绍请移步我的下载:
cnn+lstm+attention对时序数据进行预测

3、相关技术

BiLSTM:前向和方向的两条LSTM网络,被称为双向LSTM,也叫BiLSTM。其思想是将同一个输入序列分别接入向前和先后的两个LSTM中,然后将两个网络的隐含层连在一起,共同接入到输出层进行预测。

cnn+lstm+attention对时序数据进行预测

attention注意力机制

cnn+lstm+attention对时序数据进行预测

一维卷积

cnn+lstm+attention对时序数据进行预测
cnn+lstm+attention 网络结构图
cnn+lstm+attention对时序数据进行预测

; 4、完整代码和步骤

此程序运行代码版本为:

tensorflow==2.5.0
numpy==1.19.5
keras==2.6.0
matplotlib==3.5.2

主运行程序入口

from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers import Conv1D, Bidirectional, Multiply, LSTM
from keras.layers.core import *
from keras.models import *
from sklearn.metrics import mean_absolute_error
from keras import backend as K
from tensorflow.python.keras.layers import CuDNNLSTM

from my_utils.read_write import pdReadCsv
import numpy as np

SINGLE_ATTENTION_VECTOR = False
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["TF_KERAS"] = '1'

def attention_3d_block(inputs):
    input_dim = int(inputs.shape[2])
    a = inputs
    a = Dense(input_dim, activation='softmax')(a)

    a_probs = Permute((1, 2), name='attention_vec')(a)

    output_attention_mul = Multiply()([inputs, a_probs])
    return output_attention_mul

def create_dataset(dataset, look_back):
    dataX, dataY = [], []
    for i in range(len(dataset) - look_back - 1):
        a = dataset[i:(i + look_back), :]
        dataX.append(a)
        dataY.append(dataset[i + look_back, :])
    TrainX = np.array(dataX)
    Train_Y = np.array(dataY)

    return TrainX, Train_Y

def attention_model():
    inputs = Input(shape=(TIME_STEPS, INPUT_DIMS))

    x = Conv1D(filters=64, kernel_size=1, activation='relu')(inputs)
    x = Dropout(0.3)(x)

    lstm_out = Bidirectional(CuDNNLSTM(lstm_units, return_sequences=True))(x)
    lstm_out = Dropout(0.3)(lstm_out)
    attention_mul = attention_3d_block(lstm_out)

    attention_mul = Flatten()(attention_mul)

    output = Dense(1, activation='linear')(attention_mul)
    model = Model(inputs=[inputs], outputs=output)
    return model

def fit_size(x, y):
    from sklearn import preprocessing
    x_MinMax = preprocessing.MinMaxScaler()
    y_MinMax = preprocessing.MinMaxScaler()
    x = x_MinMax.fit_transform(x)
    y = y_MinMax.fit_transform(y)
    return x, y, y_MinMax

def flatten(X):
    flattened_X = np.empty((X.shape[0], X.shape[2]))
    for i in range(X.shape[0]):
        flattened_X[i] = X[i, (X.shape[1] - 1), :]
    return (flattened_X)

src = r'E:\dat'
path = r'E:\dat'
trials_path = r'E:\dat'
train_path = src + r'merpre.csv'
df = pdReadCsv(train_path, ',')
df = df.replace("--", '0')
df.fillna(0, inplace=True)
INPUT_DIMS = 43
TIME_STEPS = 12
lstm_units = 64

def load_data(df_train):
    X_train = df_train.drop(['Per'], axis=1)
    y_train = df_train['wap'].values.reshape(-1, 1)
    return X_train, y_train, X_train, y_train

groups = df.groupby(['Per'])
for name, group in groups:
    X_train, y_train, X_test, y_test = load_data(group)

    train_x, train_y, train_y_MinMax = fit_size(X_train, y_train)
    test_x, test_y, test_y_MinMax = fit_size(X_test, y_test)

    train_X, _ = create_dataset(train_x, TIME_STEPS)
    _, train_Y = create_dataset(train_y, TIME_STEPS)
    print(train_X.shape, train_Y.shape)

    m = attention_model()
    m.summary()
    m.compile(loss='mae', optimizer='Adam', metrics=['mae'])
    model_path = r'me_pre\\'
    callbacks = [
        EarlyStopping(monitor='val_loss', patience=2, verbose=0),
        ModelCheckpoint(model_path, monitor='val_loss', save_best_only=True, verbose=0),
    ]
    m.fit(train_X, train_Y, batch_size=32, epochs=111, shuffle=True, verbose=1,
          validation_split=0.1, callbacks=callbacks)

    test_X, _ = create_dataset(test_x, TIME_STEPS)
    _, test_Y = create_dataset(test_y, TIME_STEPS)

    pred_y = m.predict(test_X)
    inv_pred_y = test_y_MinMax.inverse_transform(pred_y)
    inv_test_Y = test_y_MinMax.inverse_transform(test_Y)
    mae = int(mean_absolute_error(inv_test_Y, inv_pred_y))
    print('test_mae : ', mae)
    mae = str(mae)
    print(name)
    m.save(
        model_path + name[0] + '_' + name[1] + '_' + name[2] + '_' + mae + '.h5')

需要数据和代码代写请私聊,pytorch版本的也有,cnn+lstm+attention对时序数据进行预测

参考论文:
基于CNN和LSTM的多通道注意力机制文本分类模型

Original: https://blog.csdn.net/qq_30803353/article/details/121875376
Author: AI信仰者
Title: cnn+lstm+attention对时序数据进行预测

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/726667/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • 如何在SpringBoot中优雅地重试调用第三方API?

    作为后端程序员,我们的日常工作就是调用一些第三方服务,将数据存入数据库,返回信息给前端。但你不能保证所有的事情一直都很顺利。像有些第三方API,偶尔会出现超时。此时,我们要重试几次…

    Python 2023年10月12日
    046
  • GraphPad Prism使用

    GraphPad Prism使用 一、基本表 * 1.XY型 2.列类 3.分组 二、图设计 * 1.图形选择 2.坐标轴 3.图表 三、例图 * 1.热图 一、基本表 1.XY型…

    Python 2023年9月20日
    053
  • 详解使用SSH远程连接Ubuntu服务器系统

    演示环境: 1.Windows10系统2.VMware Workstation Pro虚拟机2.Ubuntu16.04.6(以上版本通用) 回归正题 一、在Ubuntu端:1.首先…

    Python 2023年11月5日
    029
  • 商品销售关联分析

    商品销售关联分析 导入相关库 * 读取数据 数据编码 使用算法进行关联运算 导入相关库 import pandas as pd from mlxtend.frequent_patt…

    Python 2023年8月21日
    060
  • Matplotlib基础绘图

    目录 Matplotlib基础绘图 * 1.pyplot 2.绘图标记 – + marker参数 fmt参数 标记的大小和颜色 3.绘图线 轴标签和标题 Matplot…

    Python 2023年9月3日
    042
  • matplotlib升级遇到到的问题, “You probably need to get an updated matplotlibrc file from”

    环境与问题描述: windows 10 专业版 python 3.6 matplotlib 3.0.3 未查看分析matplotlib版本支持情况,贸然升级。 升级matplotl…

    Python 2023年9月1日
    069
  • MMDetection 使用示例:从入门到出门

    前言 最近对目标识别感兴趣,想做一些有趣目标识别项目自己玩耍,本来选择的是 YOLOV5 的,但无奈自己使用 YOLOV5 环境训练模型时,不管训练多少次 mAP 指标总是为 0,…

    Python 2023年10月25日
    041
  • SpringBoot自定义注解+异步+观察者模式实现业务日志保存

    一、前言 我们在企业级的开发中,必不可少的是对日志的记录,实现有很多种方式,常见的就是基于 AOP+注解进行保存,但是考虑到程序的流畅和…

    Python 2023年10月17日
    079
  • windos系统下Scrapy下载失败,报错解决方法(三步解决)

    windos系统下Scrapy下载失败,报错解决方法(三步解决) 如果觉得本文章对你有用,请给我一个免费的点赞和收藏,谢谢!我也是偶尔才上来看看而已,你们点赞收藏对我来说是一种认可…

    Python 2023年10月2日
    048
  • 打印 Logger 日志时,需不需要再封装一下工具类?

    在开发过程中,打印日志是必不可少的,因为日志关乎于应用的问题排查、应用监控等。现在打印日志一般都是使用 slf4j,因为使用日志门面,有助于打印方式统一,即使后面更换日志框架,也非…

    Python 2023年10月20日
    081
  • 汇编逆向-Qt

    Qt源码解析 索引 汇编逆向— 授权破解示例分析 问题模拟 运行环境 x64dbg Windows 10 Qt5.12.3 示例代码 使用Qt显示当前时间,模拟一般授权软件的时间判…

    Python 2023年9月15日
    042
  • IOS脱壳+反编译

    dumpdecryptedios砸壳 git clone https://github.com/stefanesser/dumpdecrypted.git #git 下载源码 su…

    Python 2023年11月9日
    038
  • 【Python】数据预处理之将类别数据转换为数值的方法(含Python代码分析)

    在进行Python数据分析的时候,首先要进行数据预处理。但是有时候不得不处理一些非数值类别的数据,遇到这类问题时该怎么解决呢? 目前为止,总结了三种方法,这里分享给大家。 这种方法…

    Python 2023年8月2日
    063
  • mysql索引

    MySQL索引: MySQL索引的建立对于MySQL的高效运行是很重要的,索引可以大大提高MySQL的检索速度。 1.创建索引1.1单独创建索引 1.2修改表结构创建索引 1.3创…

    Python 2023年6月16日
    080
  • Python基础编程(二十二)——类与对象3

    本篇文章主要是对python学习时的一些总结,作为学习笔记记录。 上篇文章主要介绍类中的方法。本篇文章主要是对python中的继承进行简单介绍。 C/C++是一种面向对象的编程语言…

    Python 2023年9月24日
    039
  • Pandas数据分析库

    一、Pandas 数据结构 Series :⽤列表⽣成 Series 时, Pandas 默认⾃动⽣成整数索引,也可以指定索引 s1 = pd.Series(data = l) #…

    Python 2023年8月8日
    059
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球