PINN学习与实验(一)

目录

今天第一天接触PINN,用深度学习的方法求解PDE,看来是非常不错的方法。做了一个简单易懂的例子,这个例子非常适合初学者。做了一个小demo, 大家可以参考参考

所用工具

使用了python和pytorch进行实现
python3.6
toch1.10

数学方程

使用一个最简单的常微分方程:
f ′ ( x ) = f ( x ) ( 1 ) f ( 0 ) = 1 ( 2 ) f'(x) = f(x) \hspace{2cm}(1) \ f(0) = 1 \hspace{2.6cm}(2)f ′(x )=f (x )(1 )f (0 )=1 (2 )
这个微分方程其实就是:
f ( x ) = e x ( 3 ) f(x)=e^{x} \hspace{2.45cm}(3)f (x )=e x (3 )

模型搭建

核心-使用最简单的全连接层:

class Net(nn.Module):
    def __init__(self, NL, NN):

        super(Net, self).__init__()
        self.input_layer = nn.Linear(1, NN)
        self.hidden_layer = nn.linear(NN,int(NN/2))
        self.output_layer = nn.Linear(int(NN/2), 1)

    def forward(self, x):
        out = torch.tanh(self.input_layer(x))
        out = torch.tanh(self.hidden_layer(out))
        out_final = self.output_layer(out)
        return out_final

偏微分方程定义,也就是公式(1):

def ode_01(x,net):
    y=net(x)
    y_x = autograd.grad(y, x,grad_outputs=torch.ones_like(net(x)),create_graph=True)[0]
    return y-y_x

所有实现代码

一下代码复制粘贴,可直接运行:

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch import autograd

"""
用神经网络模拟微分方程,f(x)'=f(x),初始条件f(0) = 1
"""

class Net(nn.Module):
    def __init__(self, NL, NN):

        super(Net, self).__init__()
        self.input_layer = nn.Linear(1, NN)
        self.hidden_layer = nn.Linear(NN,int(NN/2))
        self.output_layer = nn.Linear(int(NN/2), 1)

    def forward(self, x):
        out = torch.tanh(self.input_layer(x))
        out = torch.tanh(self.hidden_layer(out))
        out_final = self.output_layer(out)
        return out_final

net=Net(4,20)
mse_cost_function = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(net.parameters(),lr=1e-4)

def ode_01(x,net):
    y=net(x)
    y_x = autograd.grad(y, x,grad_outputs=torch.ones_like(net(x)),create_graph=True)[0]
    return y-y_x

plt.ion()
iterations=200000
for epoch in range(iterations):

    optimizer.zero_grad()

    x_0 = torch.zeros(2000, 1)
    y_0 = net(x_0)
    mse_i = mse_cost_function(y_0, torch.ones(2000, 1))

    x_in = np.random.uniform(low=0.0, high=2.0, size=(2000, 1))
    pt_x_in = autograd.Variable(torch.from_numpy(x_in).float(), requires_grad=True)
    pt_y_colection=ode_01(pt_x_in,net)
    pt_all_zeros= autograd.Variable(torch.from_numpy(np.zeros((2000,1))).float(), requires_grad=False)
    mse_f=mse_cost_function(pt_y_colection, pt_all_zeros)

    loss = mse_i + mse_f
    loss.backward()
    optimizer.step()

    if epoch%1000==0:
            y = torch.exp(pt_x_in)
            y_train0 = net(pt_x_in)
            print(epoch, "Traning Loss:", loss.data)
            print(f'times {epoch}  -  loss: {loss.item()} - y_0: {y_0}')
            plt.cla()
            plt.scatter(pt_x_in.detach().numpy(), y.detach().numpy())
            plt.scatter(pt_x_in.detach().numpy(), y_train0.detach().numpy(),c='red')
            plt.pause(0.1)

结果展示

训练0次时的结果也就是没训练,蓝色是真实值、红色是预测值:

PINN学习与实验(一)

训练2000次时的结果,蓝色是真实值、红色是预测值:

PINN学习与实验(一)

训练7000次和13000时的结果,蓝色是真实值、红色是预测值:

PINN学习与实验(一)
PINN学习与实验(一)

训练20000时的结果,蓝色是真实值、红色是预测值,不过红色已经完全把蓝色覆盖了,也就是完全拟合了:

PINN学习与实验(一)

; 参考文献

[1]. 每天进步一点点吧. PINN学习记录.

https://blog.csdn.net/weixin_45805559/article/details/121574293

Original: https://blog.csdn.net/qq_24211837/article/details/124383808
Author: 刘文凯
Title: PINN学习与实验(一)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/729131/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • pandas 第三章 索引

    import numpy as np import pandas as pd 一、索引器 列索引是最常见的索引形式,一般通过 []来实现。通过 [列&…

    Python 2023年8月21日
    044
  • JavaWeb-MVC、过滤器

    一、MVC架构图 Model 业务处理:业务逻辑(Service) 数据持久层:CRUD(Dao) View 展示数据 提供连接发起Servlet请求(a,form,img&#82…

    Python 2023年6月9日
    081
  • 重复造轮子 SimpleMapper

    接手的项目还在用 TinyMapper 的一个早期版本用来做自动映射工具,TinyMapper 虽然速度快,但在配置里不能转换类型,比如 deleted 在数据库中用 0、1 表示…

    Python 2023年10月22日
    052
  • python优质网站合集

    Python 键盘/鼠标监听及控制 – 简书正在上传…重新上传取消 jianshu.com Shell和Python获取键盘事件 – 星夜之夏 …

    Python 2023年9月22日
    056
  • django下

    django-admin startproject 项目名python manage.py runserver 运行项目python manage.py startapp 子应用名…

    Python 2023年8月5日
    074
  • pandas中DataFrame各种方法总结(持续更新)

    pandas 最有趣的地方在于里面隐藏了很多包。它是一个核心包,里面有很多其他包的功能。这点很棒,因为你只需要使用 pandas 就可以完成工作。pandas 相当于 python…

    Python 2023年8月6日
    048
  • scrapy框架数据库存储遇到问题即反馈

    1.mysql连接不上 (1)最开始以为这是因为mysql服务未启动原因 遭遇挫折,输入net start mysql命令后出现 MySQL 服务正在启动 . MySQL 服务无法…

    Python 2023年10月3日
    071
  • python发送请求给服务器参数传递方式以及服务器响应方式

    python发送请求给服务器参数传递方式以及服务器响应方式 1、(一) 2、利用Python进行图片发送与接收的两种方法—包含客户端和服务器端代码 * 1、方式一:第一…

    Python 2023年8月11日
    073
  • 使用pycharm+flask创建一个html网页

    准备工作:在pycharm中将flask设置为debug模式,点击Flask(app.py),编辑配置,进来后将FLASK_DEBUG的勾打上;上面的弄好之后,再来看一下Flask…

    Python 2023年8月9日
    055
  • YOLOv5的head详解

    YOLOv5的head详解 在前两篇文章中我们对YOLO的backbone和neck进行了详尽的解读,如果有小伙伴没看这里贴一下传送门:YOLOv5的Backbone设计YOLOv…

    Python 2023年10月24日
    059
  • DataFrame对象(创建,读取,添加,删除,方法)

    创建DataFrame对象 语法: pandas.DataFrame( data, index, columns, dtype, copy)data 支持多种数据类型,如:ndar…

    Python 2023年8月7日
    057
  • 八个超级好用的Python自动化脚本,简直太好用了

    每天你都可能会执行许多重复的任务,例如阅读新闻、发邮件、查看天气、打开书签、清理文件夹等等,使用自动化脚本,就无需手动一次又一次地完成这些任务,非常方便。而在某种程度上,Pytho…

    Python 2023年11月2日
    066
  • 用ChatGPT写一篇关于如何教新手Python的文章

    啊哦~你想找的内容离你而去了哦 内容不存在,可能为如下原因导致: ① 内容还在审核中 ② 内容以前存在,但是由于不符合新 的规定而被删除 ③ 内容地址错误 ④ 作者删除了内容。 可…

    Python 2023年11月4日
    024
  • vscode调试golang环境搭建及配置

    准备VSCode 在官网下载最新版的VSCode: 安装Golang插件 打开扩展面板 VSCode->查看->扩展 找到Go插件 在搜索框里输入Go, 找到第二行写有…

    Python 2023年6月16日
    059
  • Jmeter——结合Allure展示测试报告

    在平时用jmeter做测试时,生成报告的模板,不是特别好。大家应该也知道allure报告,页面美观。 先来看效果图,报告首页,如下所示: 报告详情信息,如下所示: 运行run.py…

    Python 2023年10月13日
    076
  • 关于DEJA_VU3D – Cesium功能集专栏说明

    博主简介 博主90后专业GIS行业开发人员,一直从事GIS相关工作5年左右,主要涉及三维和地图可视化等内容。工作中难免要接触到相关开发框架,对Cesium,Three.js,ope…

    Python 2023年11月5日
    049
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球