Siamese Network for Handwritten Digit Clustering

A Siamese network is typically used for few-shot learning and is a meta-learning method.

A Siamese network uses a CNN as its feature extractor, and samples from different classes share this single CNN. Fully connected layers are added after the CNN so the network can judge whether two input samples belong to the same class, i.e., a binary classification problem.
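The weight-sharing idea can be sketched in a few lines: both inputs pass through the same embedding function, and the distance between the embeddings indicates same vs. different class. This is a minimal NumPy sketch, where `embed` is a stand-in for the shared CNN feature extractor:

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 2))  # one weight matrix shared by both branches

def embed(x):
    # stand-in for the shared CNN feature extractor
    return x @ W

x1 = rng.normal(size=4)
x2 = x1 + 0.01 * rng.normal(size=4)  # near-duplicate: plays the "same class" role
x3 = rng.normal(size=4)              # unrelated sample: "different class"

d_same = np.linalg.norm(embed(x1) - embed(x2))
d_diff = np.linalg.norm(embed(x1) - embed(x3))
print(d_same, d_diff)  # a small distance suggests the same class
```

Because both branches apply the identical `W`, similar inputs land close together in the embedding space regardless of which branch they entered through.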


Figure 1: Feature extraction

Figure 2: Contrastive loss

Similarity is 1 for pairs from the same class and 0 for pairs from different classes.

Figure 3: Triplet loss

Here, margin (> 0) is a hyperparameter that sets the desired degree of separation between different classes.

Written as a mathematical expression:

L(a, p, n) = max(d(a, p) - d(a, n) + margin, 0)

where d(x, y) is the squared Euclidean distance between the embeddings, a is the anchor, p the positive sample, and n the negative sample.
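As a quick numeric check of the formula (a NumPy sketch with made-up 2-D embeddings, using squared distances to match the TripletLoss code that follows):

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=1.0):
    # squared Euclidean distances, as in the TripletLoss module below
    d_pos = np.sum((anchor - positive) ** 2)
    d_neg = np.sum((anchor - negative) ** 2)
    return max(d_pos - d_neg + margin, 0.0)

a = np.array([0.0, 0.0])
p = np.array([0.1, 0.0])   # close positive: d_pos = 0.01
n = np.array([2.0, 0.0])   # far negative:  d_neg = 4.0
print(triplet_loss(a, p, n))        # 0.01 - 4.0 + 1.0 < 0, so the loss is 0.0

n_close = np.array([0.5, 0.0])      # negative too close: d_neg = 0.25
print(triplet_loss(a, p, n_close))  # 0.01 - 0.25 + 1.0 = 0.76
```

Once the negative is farther from the anchor than the positive by at least the margin, the loss is zero and the triplet no longer contributes a gradient.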


The corresponding loss function uses the triplet loss; this approach can achieve better results.

The triplet loss is given here first; the corresponding custom dataset will be added later.

import torch.nn as nn
import torch.nn.functional as F

class TripletLoss(nn.Module):
    """
    Triplet loss.
    Takes embeddings of an anchor sample, a positive sample and a negative sample.
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = (anchor - positive).pow(2).sum(1)  # .pow(.5)
        distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()

The complete code is given below; it consists of three files: siamese_dataset, model, and main().


import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import torchvision.utils
import numpy as np
import random
from torch.utils.data.sampler import BatchSampler
from PIL import Image

class SiameseMNIST(Dataset):
    """
    Train: for each sample, randomly creates a positive or a negative pair.
    Test: creates fixed pairs for testing.
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset

        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        if self.train:
            self.train_labels = self.mnist_dataset.targets
            self.train_data = self.mnist_dataset.data
            self.labels_set = set(self.train_labels.numpy())
            self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
                                     for label in self.labels_set}
        else:
            # generate fixed pairs for testing
            self.test_labels = self.mnist_dataset.targets
            self.test_data = self.mnist_dataset.data
            self.labels_set = set(self.test_labels.numpy())
            self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
                                     for label in self.labels_set}

            random_state = np.random.RandomState(29)

            positive_pairs = [[i,
                               random_state.choice(self.label_to_indices[self.test_labels[i].item()]),
                               1]
                              for i in range(0, len(self.test_data), 2)]

            negative_pairs = [[i,
                               random_state.choice(self.label_to_indices[
                                                       random_state.choice(
                                                           list(self.labels_set - set([self.test_labels[i].item()]))
                                                       )
                                                   ]),
                               0]
                              for i in range(1, len(self.test_data), 2)]
            self.test_pairs = positive_pairs + negative_pairs

    def __getitem__(self, index):
        if self.train:
            target = np.random.randint(0, 2)
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            if target == 1:
                siamese_index = index
                while siamese_index == index:
                    siamese_index = np.random.choice(self.label_to_indices[label1])
            else:
                siamese_label = np.random.choice(list(self.labels_set - set([label1])))
                siamese_index = np.random.choice(self.label_to_indices[siamese_label])
            img2 = self.train_data[siamese_index]
        else:
            img1 = self.test_data[self.test_pairs[index][0]]
            img2 = self.test_data[self.test_pairs[index][1]]
            target = self.test_pairs[index][2]

        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return (img1, img2), target

    def __len__(self):
        return len(self.mnist_dataset)
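The pair-sampling logic in `__getitem__` can be illustrated without torchvision. This is a minimal sketch with made-up labels, mirroring the `label_to_indices` lookup and the positive/negative index sampling above:

```python
import numpy as np

labels = np.array([0, 1, 0, 1, 2, 2, 0, 1])  # stand-in for MNIST targets
labels_set = set(labels.tolist())
label_to_indices = {l: np.where(labels == l)[0] for l in labels_set}

rng = np.random.default_rng(0)
index = 0
label1 = labels[index]

# positive pair: another index carrying the same label
pos = index
while pos == index:
    pos = rng.choice(label_to_indices[label1])

# negative pair: an index drawn from any other label
neg_label = rng.choice(list(labels_set - {label1}))
neg = rng.choice(label_to_indices[neg_label])

print(pos, labels[pos])  # same label as index 0
print(neg, labels[neg])  # different label
```

The `while` loop guards against pairing a sample with itself, just as in the dataset class.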

Because the MNIST dataset is relatively simple, the model is kept simple as well. The key part is the ContrastiveLoss function.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, kernel_size=5),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.fc1 = nn.Sequential(
            nn.Linear(64*4*4, 256),
            nn.PReLU(),

            nn.Linear(256, 256),
            nn.PReLU(),

            nn.Linear(256, 2))

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward_once(self, x):
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.eps = 1e-9

    def forward(self, output1, output2, target, size_average=True):
        distances = (output1-output2).pow(2).sum(1)
        loss = 0.5*(target.float()*distances +
                    (1 - target).float()*F.relu(self.margin - (distances+self.eps).sqrt()).pow(2))

        return loss.mean() if size_average else loss.sum()
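The two branches of the contrastive loss can be checked numerically. This NumPy sketch mirrors the ContrastiveLoss above for a single pair of made-up embeddings (y = 1 marks a same-class pair):

```python
import numpy as np

def contrastive_loss(e1, e2, y, margin=2.0):
    # y = 1: pull the pair together (penalize squared distance)
    # y = 0: push the pair apart, but only while closer than the margin
    d2 = np.sum((e1 - e2) ** 2)
    d = np.sqrt(d2)
    return 0.5 * (y * d2 + (1 - y) * max(margin - d, 0.0) ** 2)

e1 = np.array([0.0, 0.0])
e2 = np.array([1.0, 0.0])

print(contrastive_loss(e1, e2, y=1))  # 0.5 * 1.0 = 0.5
print(contrastive_loss(e1, e2, y=0))  # 0.5 * (2 - 1)^2 = 0.5
print(contrastive_loss(e1, np.array([3.0, 0.0]), y=0))  # farther than margin: 0.0
```

Dissimilar pairs already separated by more than the margin contribute nothing, which is what produces the well-separated clusters in the embedding plots.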


import sys
import os
import torch
import torch.nn as nn
import torchvision
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import nibabel as nib
import argparse
from tqdm import tqdm
import visdom
from Siamese_minist import SiameseMNIST
from siamese_model import SiameseNetwork, ContrastiveLoss

parser = argparse.ArgumentParser()
parser.add_argument('--train_dir', type=str, default='./data')
parser.add_argument('--test_dir', type=str, default=None)
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--epochs', type=int, default=20, help='number epoch to training')
parser.add_argument('--lr', type=float, default=1e-3, help='initial learning rate')
parser.add_argument('--nw', type=int, default=16, help='Dataloader num_works')
parser.add_argument('--save_path', type=str, default='/home/yang/cnn3d/mutipule_calssification/weight_path/Siamese_model.pth', help='model weight save path')
parser.add_argument('--train_path', type=str, default='/home/yang/cnn3d/mutipule_calssification/data_selected/category_10/train_data', help='training data path')
parser.add_argument('--test_path', type=str, default='/home/yang/cnn3d/mutipule_calssification/data_selected/category_10/test_data', help='test data path')
parser.add_argument('--margin', type=float, default=1.0, help='contrastive loss margin ')
parser.add_argument('--gamma', type=float, default=0.95, help='optimizer scheduler gamma')

torch.manual_seed(1)

opt = parser.parse_args()
print(opt)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# visdom visualization: first run `python3 -m visdom.server` in a terminal
viz = visdom.Visdom()
train_dataset_path = opt.train_path
test_dataset_path = opt.test_path
mean, std = 0.1307, 0.3081
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((mean,), (std,))])
minist_path = "/home/yang/cnn3d/mutipule_calssification/SiameseNetwork/mnist"

minist_train = dataset.MNIST(minist_path, train=True, transform=transform, download=False)
minist_test = dataset.MNIST(minist_path, train=False, transform=transform, download=False)

train_dataset = SiameseMNIST(minist_train)
test_dataset = SiameseMNIST(minist_test)

train_loader = DataLoader(minist_train, batch_size=64)
test_loader = DataLoader(minist_test, batch_size=64)

train_dataloader = DataLoader(train_dataset,
                        shuffle=True,
                        num_workers=opt.nw,
                        batch_size=opt.batch_size)
test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=opt.batch_size, num_workers=opt.nw)

net = SiameseNetwork().to(device)

criterion = ContrastiveLoss(margin=opt.margin)
optimizer = optim.Adam(net.parameters(), lr=opt.lr)
scheduler = ExponentialLR(optimizer, gamma=opt.gamma)
scheduler1 = MultiStepLR(optimizer, [10, 20], gamma=0.1)
def show_plot(iteration,loss):
    plt.plot(iteration, loss)
    plt.show()

mnist_classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
              '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
              '#bcbd22', '#17becf']

def plot_embeddings(embeddings, targets, xlim=None, ylim=None):
    plt.figure(figsize=(10, 10))
    for i in range(10):
        inds = np.where(targets==i)[0]
        plt.scatter(embeddings[inds,0], embeddings[inds,1], alpha=0.5, color=colors[i])
    if xlim:
        plt.xlim(xlim[0], xlim[1])
    if ylim:
        plt.ylim(ylim[0], ylim[1])
    plt.legend(mnist_classes)

def extract_embeddings(dataloader, model, cuda=True):
    with torch.no_grad():
        model.eval()
        embeddings = np.zeros((len(dataloader.dataset), 2))
        labels = np.zeros(len(dataloader.dataset))
        k = 0
        for images, target in dataloader:
            if cuda:
                images = images.to(device)
            embeddings[k:k+len(images)] = model.forward_once(images).data.cpu().numpy()
            labels[k:k+len(images)] = target.numpy()
            k += len(images)
    return embeddings, labels

viz.line([0.], [0.], win='train_loss', opts=dict(title='training Loss'))
viz.line([0.], [0.], win='val_loss', opts=dict(title='validation Loss'))

def main():
    net.train()
    counter = []
    loss_history = []
    iteration_number = 0
    global_step = 0.0
    val_step = 0.0
    for epoch in range(opt.epochs):
        train_loss = 0.0
        train_bar = tqdm(train_dataloader, file=sys.stdout)
        for index, data in enumerate(train_bar):
            (image0, image1), label = data
            image0, image1, label = image0.to(device), image1.to(device), label.to(device)
            optimizer.zero_grad()

            output1, output2 = net(image0, image1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            train_loss += loss_contrastive.item()
            global_step += 1

            optimizer.step()
            viz.line([loss_contrastive.item()], [global_step], win='train_loss', opts=dict(title='training Loss'),
                     update='append')

            if index % 10 == 0:
                iteration_number += 10
                counter.append(iteration_number)
                loss_history.append(loss_contrastive.item())

        print("Epoch number {} Current loss {}".format(epoch+1, train_loss/(len(train_dataloader))))
        print("Epoch %d learning rate: %f" % (epoch + 1, optimizer.param_groups[0]['lr']))
        scheduler1.step()
        if epoch % 5 == 0:
            net.eval()
            with torch.no_grad():
                loss = 0.0
                val_bar = tqdm(test_dataloader, file=sys.stdout)
                for index, data in enumerate(val_bar):
                    val_step += 1
                    (val_image0, val_image1), val_label = data
                    val_image0, val_image1, val_label = val_image0.to(device), val_image1.to(device), val_label.to(device)
                    output1, output2 = net(val_image0, val_image1)
                    loss_contrastive = criterion(output1, output2, val_label)
                    loss += loss_contrastive.item()
                    viz.line([loss_contrastive.item()], [val_step], win='val_loss', opts=dict(title='validation loss'), update='append')
                print('epoch %d| validation Loss:%.4f' % (epoch, loss/len(test_dataloader)))
            net.train()  # switch back to training mode after validation
    # torch.save(net.state_dict(), opt.save_path)
    show_plot(counter, loss_history)

def valuation():
    net.eval()
    dataiter = iter(test_dataloader)
    with torch.no_grad():
        num = 0.0
        # the dataset yields ((img1, img2), target), so unpack the pair tuple
        (x0, _), label1 = next(dataiter)
        min_distance = 10
        pred_label = None
        for i in range(len(test_dataset) - 1):
            (_, x1), label2 = next(dataiter)
            output1, output2 = net(x0.to(device), x1.to(device))
            # note: the scalar comparison below assumes batch_size == 1
            euclidean_distance = F.pairwise_distance(output1, output2)
            if euclidean_distance < min_distance:
                min_distance = euclidean_distance
                pred_label = label2
            if pred_label == label1:
                num += 1

        print('min distance: ', min_distance)
        print('predicted label', pred_label)

if __name__ == '__main__':
    main()
    # visualize the clustering results
    train_embeddings, train_labels = extract_embeddings(train_loader, net)
    #figure1 train data
    plot_embeddings(train_embeddings, train_labels)
    val_embeddings, val_labels = extract_embeddings(test_loader, net)
    #figure2 test data
    plot_embeddings(val_embeddings, val_labels)
    plt.show()

Results:

Training loss curve: [figure]

Training-set embeddings: [figure]

Original: https://blog.csdn.net/m0_58256026/article/details/125107657
Author: ‘韫玉’
Title: Siamese Network for Handwritten Digit Clustering (孪生网络实现手写数字聚类)
