训练AI玩贪吃蛇(强化学习与Q-learning)

欢迎加入我们卧虎藏龙的python讨论qq群:729683466

导 语

AI近些年发展很火

不但可以下围棋

还可以玩各种各样的游戏

国外有人构建AI模型

然后训练出了非常厉害的拳击手

相信有一天

AI和机器人结合

肯定能训练出比泰森还强的高手

今天我们看一个简单的AI项目

训练AI玩贪吃蛇(强化学习与Q-learning)

代码及相关资源获取

1:关注”python趣味爱好者”公众号, 回复”AI玩贪吃蛇 “获取源代码

效果演示

游戏只训练了很短的时间,所以效果不太好,感兴趣的同学可以多训练几次。

训练AI玩贪吃蛇(强化学习与Q-learning)

往期精选

sd

🛠

开发工具

python3.10

第三方库:pygame,collections,enum

讲解部分

一:原理介绍

这里用到的核心知识是强化学习,这是一种无监督的算法,与有标签的监督学习不同,在深度学习中,监督学习的意思就是需要人工标注的正确答案,通过这些正确答案来训练神经网络,最终使神经网络可以准确模拟出下一步的操作。

但是在无监督学习中,并没有人告诉程序,你走的哪一步是正确的,哪一步是错误的。程序自动不断试错,如果撞到了墙壁或者尾巴直接死亡,吃到食物获得奖励。这样神经网络才知道哪些动作正确,哪些不正确。

二:Q-learning算法

在Q-learning算法自我训练的过程中,有5个比较重要的元素来评估当前状态的价值,首先是state,也就是当前的状态,对于贪吃蛇来说,这个状态很简单,就是蛇的位置,食物的位置,蛇移动方向等等。将这些信息组合在一起,形成一个向量,这就是当前状态,然后第二个元素就是蛇的选择action,一共4个方向,向前是无效的,向后是自杀,所以只有两个方向,向左向右(相对于蛇的移动方向来说)。这个选择就是action。随后,我们把state做出选择action后产生的新状态写作state1,这是第三个元素。如果state1是游戏结束(死亡),那么第四个元素done就是False,否则就是True。最后一个元素就是reward,也就是这个状态state1的奖励,如果吃到了食物就是奖励10分,否则就0分,如果死亡了,分数那就是-10分。

用一个图表示上述状态就是。

训练AI玩贪吃蛇(强化学习与Q-learning)

Q-learning核心公式如下:

训练AI玩贪吃蛇(强化学习与Q-learning)

在代码中实现公式如下:

Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))

用于评估每个状态的value,以训练神经网络

整个Q-learning的代码如下:

class QTrainer:
    def __init__(self,model,lr,gamma):
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.optimer = optim.Adam(model.parameters(),lr = self.lr)
        self.criterion = nn.MSELoss()
        for i in self.model.parameters():
            print(i.is_cuda)

    def train_step(self,state,action,reward,next_state,done):

        state = torch.tensor(state,dtype=torch.float)
        next_state = torch.tensor(next_state,dtype=torch.float)
        action = torch.tensor(action,dtype=torch.long)
        reward = torch.tensor(reward,dtype=torch.float)

        if(len(state.shape) == 1): # only one parameter to train , Hence convert to tuple of shape (1, x)
            #(1 , x)
            state = torch.unsqueeze(state,0)
            next_state = torch.unsqueeze(next_state,0)
            action = torch.unsqueeze(action,0)
            reward = torch.unsqueeze(reward,0)

            done = (done, )

        # 1. Predicted Q value with current state

        pred = self.model(state)
        target = pred.clone()
        for idx in range(len(done)):
            Q_new = reward[idx]
            if not done[idx]:
                #Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx])).cuda()
                Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))
            target[idx][torch.argmax(action).item()] = Q_new
        # 2. Q_new = reward + gamma * max(next_predicted Qvalue) -> only do this if not done
        # pred.clone()
        # preds[argmax(action)] = Q_new
        self.optimer.zero_grad()
        loss = self.criterion(target,pred)
        loss.backward()

        self.optimer.step()

三:神经网络

神经网络是深度学习的基本概念,这里不过多介绍,直接展示一下我们设计的神经网络代码,相信大家能看懂。

class Linear_QNet(nn.Module):
    def __init__(self,input_size,hidden_size,output_size):
        super().__init__()
        # self.linear1 = nn.Linear(input_size,hidden_size).cuda()
        # self.linear2 = nn.Linear(hidden_size,output_size).cuda()
        self.linear1 = nn.Linear(input_size,hidden_size)
        self.linear2 = nn.Linear(hidden_size,output_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

作者|齐

编辑|齐

感谢大家观看

有钱的老板可打赏一下小编哦

扫描二维码

关注我们

QQ群:729683466

参考来源

代码来源:github

Original: https://blog.csdn.net/qq_42680814/article/details/126397478
Author: 剑心211
Title: 训练AI玩贪吃蛇(强化学习与Q-learning)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/651619/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球