Informer Source Code Analysis

We begin with the entry point of the data-preparation stage, inside the train function of the Exp_Informer class:

train_data, train_loader = self._get_data(flag = 'train')

self._get_data is implemented as follows. It picks the dataset class for the selected dataset, loads the data, and then builds the Dataset and DataLoader:

def _get_data(self, flag):
    args = self.args

    data_dict = {
        'ETTh1':Dataset_ETT_hour,
        'ETTh2':Dataset_ETT_hour,
        'ETTm1':Dataset_ETT_minute,
        'ETTm2':Dataset_ETT_minute,
        'WTH':Dataset_Custom,
        'ECL':Dataset_Custom,
        'Solar':Dataset_Custom,
        'custom':Dataset_Custom,
    }
    Data = data_dict[self.args.data]

    timeenc = 0 if args.embed!='timeF' else 1

    if flag == 'test':
        shuffle_flag = False; drop_last = True; batch_size = args.batch_size; freq=args.freq
    elif flag=='pred':
        shuffle_flag = False; drop_last = False; batch_size = 1; freq=args.detail_freq
        Data = Dataset_Pred
    else:
        shuffle_flag = True; drop_last = True; batch_size = args.batch_size; freq=args.freq
    data_set = Data(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        inverse=args.inverse,
        timeenc=timeenc,
        freq=freq,
        cols=args.cols
    )
    print(flag, len(data_set))
    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=drop_last)

    return data_set, data_loader

Datasets can be constructed at different time granularities. Taking the Dataset_ETT_hour class as an example, its __init__ function is:

def __init__(self, root_path, flag='train', size=None,
             features='S', data_path='ETTh1.csv',
             target='OT', scale=True, inverse=False, timeenc=0, freq='h', cols=None):

    if size == None:
        self.seq_len = 24*4*4
        self.label_len = 24*4
        self.pred_len = 24*4
    else:
        self.seq_len = size[0]
        self.label_len = size[1]
        self.pred_len = size[2]

    assert flag in ['train', 'test', 'val']
    type_map = {'train':0, 'val':1, 'test':2}
    self.set_type = type_map[flag]

    self.features = features
    self.target = target
    self.scale = scale
    self.inverse = inverse
    self.timeenc = timeenc
    self.freq = freq

    self.root_path = root_path
    self.data_path = data_path
    self.__read_data__()

The key call during initialization is __read_data__:

def __read_data__(self):
    self.scaler = StandardScaler()

    df_raw = pd.read_csv(os.path.join(self.root_path,
                                      self.data_path))

    border1s = [0, 12*30*24 - self.seq_len, 12*30*24+4*30*24 - self.seq_len]
    border2s = [12*30*24, 12*30*24+4*30*24, 12*30*24+8*30*24]

    border1 = border1s[self.set_type]
    border2 = border2s[self.set_type]

    if self.features=='M' or self.features=='MS':
        cols_data = df_raw.columns[1:]
        df_data = df_raw[cols_data]
    elif self.features=='S':
        df_data = df_raw[[self.target]]

    if self.scale:
        train_data = df_data[border1s[0]:border2s[0]]
        self.scaler.fit(train_data.values)
        data = self.scaler.transform(df_data.values)
    else:
        data = df_data.values

    df_stamp = df_raw[['date']][border1:border2]
    df_stamp['date'] = pd.to_datetime(df_stamp.date)

    data_stamp = time_features(df_stamp, timeenc=self.timeenc, freq=self.freq)

    self.data_x = data[border1:border2]
    if self.inverse:
        self.data_y = df_data.values[border1:border2]
    else:
        self.data_y = data[border1:border2]
    self.data_stamp = data_stamp
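
The border arrays split the ETT-hour data by time: the first 12 months (12*30*24 = 8640 hourly points) are training data, the next 4 months validation, and the following 4 months test. The validation and test windows start seq_len points early so that their first samples still have a full encoder history. A quick sanity check of the arithmetic (with a hypothetical seq_len of 96; the class default is 24*4*4):

seq_len = 96  # hypothetical; the class default is 24*4*4
border1s = [0, 12*30*24 - seq_len, 12*30*24 + 4*30*24 - seq_len]
border2s = [12*30*24, 12*30*24 + 4*30*24, 12*30*24 + 8*30*24]
print(border1s)  # [0, 8544, 11424]
print(border2s)  # [8640, 11520, 14400]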

1.1 Handling Time Features

The part of __read_data__ worth examining in detail is the time_features function, implemented as follows:

def time_features(dates, timeenc=1, freq='h'):
"""
    > time_features takes in a dates dataframe with a 'dates' column and extracts the date down to freq where freq can be any of the following if timeenc is 0:
    > * m - [month]
    > * w - [month]
    > * d - [month, day, weekday]
    > * b - [month, day, weekday]
    > * h - [month, day, weekday, hour]
    > * t - [month, day, weekday, hour, *minute]
    >
    > If timeenc is 1, a similar, but different list of freq values are supported (all encoded between [-0.5 and 0.5]):
    > * Q - [month]
    > * M - [month]
    > * W - [Day of month, week of year]
    > * D - [Day of week, day of month, day of year]
    > * B - [Day of week, day of month, day of year]
    > * H - [Hour of day, day of week, day of month, day of year]
    > * T - [Minute of hour*, hour of day, day of week, day of month, day of year]
    > * S - [Second of minute, minute of hour, hour of day, day of week, day of month, day of year]

    *minute returns a number from 0-3 corresponding to the 15 minute period it falls into.

"""

    if timeenc==0:
        dates['month'] = dates.date.apply(lambda row:row.month,1)
        dates['day'] = dates.date.apply(lambda row:row.day,1)
        dates['weekday'] = dates.date.apply(lambda row:row.weekday(),1)
        dates['hour'] = dates.date.apply(lambda row:row.hour,1)
        dates['minute'] = dates.date.apply(lambda row:row.minute,1)
        dates['minute'] = dates.minute.map(lambda x:x//15)
        freq_map = {
            'y':[],'m':['month'],'w':['month'],'d':['month','day','weekday'],
            'b':['month','day','weekday'],'h':['month','day','weekday','hour'],
            't':['month','day','weekday','hour','minute'],
        }
        return dates[freq_map[freq.lower()]].values
    if timeenc==1:
        dates = pd.to_datetime(dates.date.values)

        return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]).transpose(1,0)
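
A quick smoke test of the timeenc=0 path might look like the following (a sketch, assuming the time_features function above is in scope):

import pandas as pd

df = pd.DataFrame({'date': pd.date_range('2016-07-01', periods=3, freq='h')})
print(time_features(df.copy(), timeenc=0, freq='h'))
# -> one integer row [month, day, weekday, hour] per timestamp, shape (3, 4)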

time_features_from_frequency_str is implemented as follows:

def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
"""
    Returns a list of time features that will be appropriate for the given frequency string.

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

"""

    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: [
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
        offsets.Second: [
            SecondOfMinute,
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
    }

    offset = to_offset(freq_str)

    for offset_type, feature_classes in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return [cls() for cls in feature_classes]

    supported_freq_msg = f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
"""
    raise RuntimeError(supported_freq_msg)
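
The TimeFeature classes referenced here each map a DatetimeIndex to values in [-0.5, 0.5]. As an illustration, HourOfDay is implemented along these lines (a sketch reconstructed for this walkthrough, assuming a TimeFeature base class as in the repo):

import numpy as np
import pandas as pd

class HourOfDay(TimeFeature):
    """Hour of day encoded as a value between [-0.5, 0.5]."""
    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.hour / 23.0 - 0.5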

Returning to the Dataset_ETT_hour class, we next analyze its __getitem__ function:

def __getitem__(self, index):
    s_begin = index
    s_end = s_begin + self.seq_len
    r_begin = s_end - self.label_len
    r_end = r_begin + self.label_len + self.pred_len

    seq_x = self.data_x[s_begin:s_end]
    if self.inverse:
        seq_y = np.concatenate([self.data_x[r_begin:r_begin+self.label_len], self.data_y[r_begin+self.label_len:r_end]], 0)
    else:
        seq_y = self.data_y[r_begin:r_end]
    seq_x_mark = self.data_stamp[s_begin:s_end]
    seq_y_mark = self.data_stamp[r_begin:r_end]

    return seq_x, seq_y, seq_x_mark, seq_y_mark
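
To make the indexing concrete, here is a toy trace of the window arithmetic (hypothetical sizes: seq_len=4, label_len=2, pred_len=2):

seq_len, label_len, pred_len = 4, 2, 2
index = 0
s_begin, s_end = index, index + seq_len   # encoder window: steps [0, 4)
r_begin = s_end - label_len               # decoder window starts at step 2
r_end = r_begin + label_len + pred_len    # and ends at step 6
print(s_begin, s_end, r_begin, r_end)     # 0 4 2 6
# seq_y reuses the last label_len steps of seq_x as the decoder's "start
# token" section, followed by the pred_len steps that are to be forecast.
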
With the data pipeline covered, we turn to the training loop, Exp_Informer.train:

def train(self, setting):
    train_data, train_loader = self._get_data(flag = 'train')
    vali_data, vali_loader = self._get_data(flag = 'val')
    test_data, test_loader = self._get_data(flag = 'test')

    path = os.path.join(self.args.checkpoints, setting)
    if not os.path.exists(path):
        os.makedirs(path)

    time_now = time.time()

    train_steps = len(train_loader)
    early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

    model_optim = self._select_optimizer()
    criterion =  self._select_criterion()

    if self.args.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    for epoch in range(self.args.train_epochs):
        iter_count = 0
        train_loss = []

        self.model.train()
        epoch_time = time.time()
        for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(train_loader):
            iter_count += 1

            model_optim.zero_grad()
            pred, true = self._process_one_batch(
                train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)
            loss = criterion(pred, true)
            train_loss.append(loss.item())

            if (i+1) % 100==0:
                print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                speed = (time.time()-time_now)/iter_count
                left_time = speed*((self.args.train_epochs - epoch)*train_steps - i)
                print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                iter_count = 0
                time_now = time.time()

            if self.args.use_amp:
                scaler.scale(loss).backward()
                scaler.step(model_optim)
                scaler.update()
            else:
                loss.backward()
                model_optim.step()

        print("Epoch: {} cost time: {}".format(epoch+1, time.time()-epoch_time))
        train_loss = np.average(train_loss)
        vali_loss = self.vali(vali_data, vali_loader, criterion)
        test_loss = self.vali(test_data, test_loader, criterion)

        print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
            epoch + 1, train_steps, train_loss, vali_loss, test_loss))
        early_stopping(vali_loss, self.model, path)
        if early_stopping.early_stop:
            print("Early stopping")
            break

        adjust_learning_rate(model_optim, epoch+1, self.args)

    best_model_path = path+'/'+'checkpoint.pth'
    self.model.load_state_dict(torch.load(best_model_path))

    return self.model

The core of the loop is _process_one_batch, implemented as follows:

def _process_one_batch(self, dataset_object, batch_x, batch_y, batch_x_mark, batch_y_mark):
    batch_x = batch_x.float().to(self.device)
    batch_y = batch_y.float()

    batch_x_mark = batch_x_mark.float().to(self.device)
    batch_y_mark = batch_y_mark.float().to(self.device)

    if self.args.padding==0:

        dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1]]).float()
    elif self.args.padding==1:
        dec_inp = torch.ones([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1]]).float()
    dec_inp = torch.cat([batch_y[:,:self.args.label_len,:], dec_inp], dim=1).float().to(self.device)

    if self.args.use_amp:
        with torch.cuda.amp.autocast():
            if self.args.output_attention:
                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
            else:
                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
    else:
        if self.args.output_attention:
            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
        else:
            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
    if self.args.inverse:
        outputs = dataset_object.inverse_transform(outputs)
    f_dim = -1 if self.args.features=='MS' else 0
    batch_y = batch_y[:,-self.args.pred_len:,f_dim:].to(self.device)

    return outputs, batch_y
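
The key step above is assembling the decoder input: the last label_len steps of the ground truth are concatenated with pred_len placeholder steps (zeros or ones). A toy illustration with hypothetical sizes:

import torch

label_len, pred_len = 3, 2
batch_y = torch.randn(8, label_len + pred_len, 7)   # (batch, label+pred, features)
dec_inp = torch.zeros(8, pred_len, 7)               # the padding == 0 branch
dec_inp = torch.cat([batch_y[:, :label_len, :], dec_inp], dim=1)
print(dec_inp.shape)  # torch.Size([8, 5, 7]): known label steps + zero placeholders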

Before digging further into the training logic, we should see when the model is initialized. This happens when the Exp_Informer class is instantiated, via a call to _build_model. Exp_Informer inherits from Exp_Basic, which is defined as follows:

class Exp_Basic(object):
    def __init__(self, args):
        self.args = args
        self.device = self._acquire_device()
        self.model = self._build_model().to(self.device)

    def _build_model(self):
        raise NotImplementedError
        return None

    def _acquire_device(self):
        if self.args.use_gpu:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
            device = torch.device('cuda:{}'.format(self.args.gpu))
            print('Use GPU: cuda:{}'.format(self.args.gpu))
        else:
            device = torch.device('cpu')
            print('Use CPU')
        return device

    def _get_data(self):
        pass

    def vali(self):
        pass

    def train(self):
        pass

    def test(self):
        pass

Below is the _build_model implementation in Exp_Informer:

def _build_model(self):
    model_dict = {
        'informer':Informer,
        'informerstack':InformerStack,
    }
    if self.args.model=='informer' or self.args.model=='informerstack':
        e_layers = self.args.e_layers if self.args.model=='informer' else self.args.s_layers
        model = model_dict[self.args.model](
            self.args.enc_in,
            self.args.dec_in,
            self.args.c_out,
            self.args.seq_len,
            self.args.label_len,
            self.args.pred_len,
            self.args.factor,
            self.args.d_model,
            self.args.n_heads,
            e_layers,
            self.args.d_layers,
            self.args.d_ff,
            self.args.dropout,
            self.args.attn,
            self.args.embed,
            self.args.freq,
            self.args.activation,
            self.args.output_attention,
            self.args.distil,
            self.args.mix,
            self.device
        ).float()

    if self.args.use_multi_gpu and self.args.use_gpu:
        model = nn.DataParallel(model, device_ids=self.args.device_ids)
    return model

3.1 Informer Architecture

The following is the initialization of the Informer model:

(Figure: encoder structure; original image unavailable.)

def __init__(self, enc_in, dec_in, c_out, seq_len, label_len, out_len,
            factor=5, d_model=512, n_heads=8, e_layers=3, d_layers=2, d_ff=512,
            dropout=0.0, attn='prob', embed='fixed', freq='h', activation='gelu',
            output_attention = False, distil=True, mix=True,
            device=torch.device('cuda:0')):
    super(Informer, self).__init__()
    self.pred_len = out_len
    self.attn = attn
    self.output_attention = output_attention

    self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)
    self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq, dropout)

    Attn = ProbAttention if attn=='prob' else FullAttention

    self.encoder = Encoder(
        [
            EncoderLayer(
                AttentionLayer(Attn(False, factor, attention_dropout=dropout, output_attention=output_attention),
                            d_model, n_heads, mix=False),
                d_model,
                d_ff,
                dropout=dropout,
                activation=activation
            ) for l in range(e_layers)
        ],
        [
            ConvLayer(
                d_model
            ) for l in range(e_layers-1)
        ] if distil else None,
        norm_layer=torch.nn.LayerNorm(d_model)
    )

    self.decoder = Decoder(
        [
            DecoderLayer(
                AttentionLayer(Attn(True, factor, attention_dropout=dropout, output_attention=False),
                            d_model, n_heads, mix=mix),
                AttentionLayer(FullAttention(False, factor, attention_dropout=dropout, output_attention=False),
                            d_model, n_heads, mix=False),
                d_model,
                d_ff,
                dropout=dropout,
                activation=activation,
            )
            for l in range(d_layers)
        ],
        norm_layer=torch.nn.LayerNorm(d_model)
    )

    self.projection = nn.Linear(d_model, c_out, bias=True)

The forward pass wires the embeddings, encoder, decoder, and output projection together:

def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
            enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
    enc_out = self.enc_embedding(x_enc, x_mark_enc)
    enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

    dec_out = self.dec_embedding(x_dec, x_mark_dec)
    dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
    dec_out = self.projection(dec_out)

    if self.output_attention:
        return dec_out[:,-self.pred_len:,:], attns
    else:
        return dec_out[:,-self.pred_len:,:]

Both encoder and decoder inputs first pass through DataEmbedding:

class DataEmbedding(nn.Module):
    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)

        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) if embed_type!='timeF' else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):

        x = self.value_embedding(x) + self.position_embedding(x) + self.temporal_embedding(x_mark)

        return self.dropout(x)

As the structure of DataEmbedding shows, it builds three modules: TokenEmbedding, PositionalEmbedding, and TemporalEmbedding.

class TokenEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        padding = 1 if torch.__version__>='1.5.0' else 2

        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                    kernel_size=3, padding=padding, padding_mode='circular')
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight,mode='fan_in',nonlinearity='leaky_relu')

    def forward(self, x):

        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1,2)
        return x

PositionalEmbedding is the standard sinusoidal positional encoding:

class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()

        pe = torch.zeros(max_len, d_model).float()
        pe.require_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe[:, :x.size(1)]

TemporalEmbedding sums per-field embeddings of the calendar components:

class TemporalEmbedding(nn.Module):
    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super(TemporalEmbedding, self).__init__()

        minute_size = 4; hour_size = 24
        weekday_size = 7; day_size = 32; month_size = 13

        Embed = FixedEmbedding if embed_type=='fixed' else nn.Embedding
        if freq=='t':
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        x = x.long()

        minute_x = self.minute_embed(x[:,:,4]) if hasattr(self, 'minute_embed') else 0.

        hour_x = self.hour_embed(x[:,:,3])
        weekday_x = self.weekday_embed(x[:,:,2])
        day_x = self.day_embed(x[:,:,1])
        month_x = self.month_embed(x[:,:,0])

        return hour_x + weekday_x + day_x + month_x + minute_x

FixedEmbedding, used when embed_type='fixed', is a frozen sinusoidal lookup table:

class FixedEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()
        w.require_grad = False

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        w[:, 1::2] = torch.cos(position * div_term)

        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        return self.emb(x).detach()

TimeFeatureEmbedding, used when embed_type='timeF', projects the float time features with a single linear layer:

class TimeFeatureEmbedding(nn.Module):
    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()

        freq_map = {'h':4, 't':5, 's':6, 'm':1, 'a':1, 'w':2, 'd':3, 'b':3}
        d_inp = freq_map[freq]
        self.embed = nn.Linear(d_inp, d_model)

    def forward(self, x):
        return self.embed(x)
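
A shape check ties the embedding modules together (a sketch, assuming the classes above are importable; with embed_type='timeF' and freq='h', x_mark carries 4 float time features):

import torch

emb = DataEmbedding(c_in=7, d_model=512, embed_type='timeF', freq='h')
x = torch.randn(2, 96, 7)       # (batch, seq_len, enc_in)
x_mark = torch.randn(2, 96, 4)  # 4 time features for freq='h'
print(emb(x, x_mark).shape)     # torch.Size([2, 96, 512])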

Before analyzing Informer's ProbAttention, let's first look at the source of the original self-attention. The code below is the overall wrapper for the attention mechanism: d_keys and d_values are derived from d_model and n_heads (the number of attention heads), and the forward pass of AttentionLayer projects the embedded input sequences into n_heads heads, then delegates the actual attention computation to the supplied inner_attention:

class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads,
                 d_keys=None, d_values=None, mix=False):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model//n_heads)
        d_values = d_values or (d_model//n_heads)

        self.inner_attention = attention

        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads
        self.mix = mix

    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries,
            keys,
            values,
            attn_mask
        )
        if self.mix:
            out = out.transpose(2,1).contiguous()
        out = out.view(B, L, -1)

        return self.out_projection(out), attn

Next, FullAttention, the self-attention of the original Transformer:

class FullAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1./sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)

            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return (V.contiguous(), A)
        else:
            return (V.contiguous(), None)
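
The einsum strings encode the head-wise score and output shapes; a standalone shape check:

import torch

q = torch.randn(2, 10, 8, 64)   # (B, L, H, E)
k = torch.randn(2, 12, 8, 64)   # (B, S, H, E)
v = torch.randn(2, 12, 8, 64)   # (B, S, H, D)
scores = torch.einsum("blhe,bshe->bhls", q, k)
print(scores.shape)             # torch.Size([2, 8, 10, 12]): per-head L x S scores
out = torch.einsum("bhls,bshd->blhd", torch.softmax(scores, -1), v)
print(out.shape)                # torch.Size([2, 10, 8, 64])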

The TriangularCausalMask class masks each sequence in the batch so that a position cannot attend to features from future time steps:

class TriangularCausalMask():
    def __init__(self, B, L, device="cpu"):
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask
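
Printing a small mask makes the pattern obvious: the True entries (the strict upper triangle) are the future positions that get filled with -inf (assumes the class above is in scope):

import torch

mask = TriangularCausalMask(B=1, L=4).mask
print(mask[0, 0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)
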
ProbAttention implements Informer's ProbSparse attention; its constructor is:

class ProbAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

First, the forward function of ProbAttention:

def forward(self, queries, keys, values, attn_mask):
    B, L_Q, H, D = queries.shape
    _, L_K, _, _ = keys.shape

    queries = queries.transpose(2,1)
    keys = keys.transpose(2,1)
    values = values.transpose(2,1)

    U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item()
    u = self.factor * np.ceil(np.log(L_Q)).astype('int').item()

    U_part = U_part if U_part<L_K else L_K
    u = u if u<L_Q else L_Q

    scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

    scale = self.scale or 1./sqrt(D)
    if scale is not None:
        scores_top = scores_top * scale

    context = self._get_initial_context(values, L_Q)

    context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)

    return context.transpose(2,1).contiguous(), attn
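
The sampling sizes U_part and u both follow the paper's factor * ceil(ln L) rule. For the common setting factor=5 and sequence length 96:

import numpy as np

factor, L_K, L_Q = 5, 96, 96
U_part = factor * int(np.ceil(np.log(L_K)))  # 5 * ceil(ln 96) = 25 sampled keys per query
u = factor * int(np.ceil(np.log(L_Q)))       # 25 queries kept as the "active" set
print(U_part, u)                             # 25 25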

The code above relies on several important helper functions, which we analyze in turn. The first is _prob_QK, which returns the dot products between the selected top-u queries and all keys, together with the indices of those selected queries:

def _prob_QK(self, Q, K, sample_k, n_top):

    B, H, L_K, E = K.shape
    _, _, L_Q, _ = Q.shape

    K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)

    index_sample = torch.randint(L_K, (L_Q, sample_k))

    K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]

    Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)

    M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)

    M_top = M.topk(n_top, sorted=False)[1]

    Q_reduce = Q[torch.arange(B)[:, None, None],
                 torch.arange(H)[None, :, None],
                 M_top, :]

    Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))

    return Q_K, M_top

Let's trace _prob_QK line by line. First we build toy inputs: K has shape (1, 2, 4, 6), and Q is identical to K:

import torch

B, H, L_Q, L_K, E = 1, 2, 4, 4, 6   # batch, heads, query len, key len, head dim
sample_k, n_top = 2, 2              # factor * ceil(log L) in the real code
K = torch.linspace(1, 48, steps=48).reshape(B, H, L_K, E)
Q = K.clone()
'''
Q is identical to K: shape (1, 2, 4, 6); values 1-24 belong to head 1, values 25-48 to head 2
tensor([[[[ 1.,  2.,  3.,  4.,  5.,  6.],
          [ 7.,  8.,  9., 10., 11., 12.],
          [13., 14., 15., 16., 17., 18.],
          [19., 20., 21., 22., 23., 24.]],

         [[25., 26., 27., 28., 29., 30.],
          [31., 32., 33., 34., 35., 36.],
          [37., 38., 39., 40., 41., 42.],
          [43., 44., 45., 46., 47., 48.]]]])
'''
K_unsqueeze = K.unsqueeze(-3)
'''
K_unsqueeze:(1, 2, 1, 4, 6)
tensor([[[[[ 1.,  2.,  3.,  4.,  5.,  6.],
           [ 7.,  8.,  9., 10., 11., 12.],
           [13., 14., 15., 16., 17., 18.],
           [19., 20., 21., 22., 23., 24.]]],

         [[[25., 26., 27., 28., 29., 30.],
           [31., 32., 33., 34., 35., 36.],
           [37., 38., 39., 40., 41., 42.],
           [43., 44., 45., 46., 47., 48.]]]]])
'''
K_expand = K_unsqueeze.expand(B, H, L_Q, L_K, E)
'''
K_expand: (1, 2, 4, 4, 6) -- the last two dims are replicated L_Q=4 times along the inserted dim
tensor([[[[[ 1.,  2.,  3.,  4.,  5.,  6.],
           [ 7.,  8.,  9., 10., 11., 12.],
           [13., 14., 15., 16., 17., 18.],
           [19., 20., 21., 22., 23., 24.]],

          [[ 1.,  2.,  3.,  4.,  5.,  6.],
           [ 7.,  8.,  9., 10., 11., 12.],
           [13., 14., 15., 16., 17., 18.],
           [19., 20., 21., 22., 23., 24.]],

          [[ 1.,  2.,  3.,  4.,  5.,  6.],
           [ 7.,  8.,  9., 10., 11., 12.],
           [13., 14., 15., 16., 17., 18.],
           [19., 20., 21., 22., 23., 24.]],

          [[ 1.,  2.,  3.,  4.,  5.,  6.],
           [ 7.,  8.,  9., 10., 11., 12.],
           [13., 14., 15., 16., 17., 18.],
           [19., 20., 21., 22., 23., 24.]]],

         [[[25., 26., 27., 28., 29., 30.],
           [31., 32., 33., 34., 35., 36.],
           [37., 38., 39., 40., 41., 42.],
           [43., 44., 45., 46., 47., 48.]],

          [[25., 26., 27., 28., 29., 30.],
           [31., 32., 33., 34., 35., 36.],
           [37., 38., 39., 40., 41., 42.],
           [43., 44., 45., 46., 47., 48.]],

          [[25., 26., 27., 28., 29., 30.],
           [31., 32., 33., 34., 35., 36.],
           [37., 38., 39., 40., 41., 42.],
           [43., 44., 45., 46., 47., 48.]],

          [[25., 26., 27., 28., 29., 30.],
           [31., 32., 33., 34., 35., 36.],
           [37., 38., 39., 40., 41., 42.],
           [43., 44., 45., 46., 47., 48.]]]]])
'''
index_sample = torch.randint(L_K, (L_Q, sample_k))
'''
index_sample:(4, 2)
tensor([[3, 3],
        [3, 0],
        [2, 3],
        [0, 3]])
'''
K_tmp_id = torch.arange(L_Q).unsqueeze(1)
'''
K_tmp_id:(4, 1)
tensor([[0],
        [1],
        [2],
        [3]])
'''
K_sample = K_expand[:, :, K_tmp_id, index_sample, :]
'''
K_sample:(1, 2, 4, 2, 6)
tensor([[[[[19., 20., 21., 22., 23., 24.],
           [19., 20., 21., 22., 23., 24.]],   head 1: K_tmp_id[0] with index_sample [3, 3]

          [[19., 20., 21., 22., 23., 24.],
           [ 1.,  2.,  3.,  4.,  5.,  6.]],   head 1: K_tmp_id[1] with index_sample [3, 0]

          [[13., 14., 15., 16., 17., 18.],
           [19., 20., 21., 22., 23., 24.]],   head 1: K_tmp_id[2] with index_sample [2, 3]

          [[ 1.,  2.,  3.,  4.,  5.,  6.],
           [19., 20., 21., 22., 23., 24.]]],  head 1: K_tmp_id[3] with index_sample [0, 3]

         [[[43., 44., 45., 46., 47., 48.],
           [43., 44., 45., 46., 47., 48.]],   head 2: K_tmp_id[0] with index_sample [3, 3]

          [[43., 44., 45., 46., 47., 48.],
           [25., 26., 27., 28., 29., 30.]],   head 2: K_tmp_id[1] with index_sample [3, 0]

          [[37., 38., 39., 40., 41., 42.],
           [43., 44., 45., 46., 47., 48.]],   head 2: K_tmp_id[2] with index_sample [2, 3]

          [[25., 26., 27., 28., 29., 30.],
           [43., 44., 45., 46., 47., 48.]]]]])  head 2: K_tmp_id[3] with index_sample [0, 3]
'''
Q_unsqueeze = Q.unsqueeze(-2)
'''
Q_unsqueeze:(1, 2, 4, 1, 6)
tensor([[[[[ 1.,  2.,  3.,  4.,  5.,  6.]],

          [[ 7.,  8.,  9., 10., 11., 12.]],

          [[13., 14., 15., 16., 17., 18.]],

          [[19., 20., 21., 22., 23., 24.]]],

         [[[25., 26., 27., 28., 29., 30.]],

          [[31., 32., 33., 34., 35., 36.]],

          [[37., 38., 39., 40., 41., 42.]],

          [[43., 44., 45., 46., 47., 48.]]]]])
'''
K_sample_trans = K_sample.transpose(-2, -1)
'''
K_sample_trans:(1, 2, 4, 6, 2)
tensor([[[[[19., 19.],
           [20., 20.],
           [21., 21.],
           [22., 22.],
           [23., 23.],
           [24., 24.]],

          [[19.,  1.],
           [20.,  2.],
           [21.,  3.],
           [22.,  4.],
           [23.,  5.],
           [24.,  6.]],

          [[13., 19.],
           [14., 20.],
           [15., 21.],
           [16., 22.],
           [17., 23.],
           [18., 24.]],

          [[ 1., 19.],
           [ 2., 20.],
           [ 3., 21.],
           [ 4., 22.],
           [ 5., 23.],
           [ 6., 24.]]],

         [[[43., 43.],
           [44., 44.],
           [45., 45.],
           [46., 46.],
           [47., 47.],
           [48., 48.]],

          [[43., 25.],
           [44., 26.],
           [45., 27.],
           [46., 28.],
           [47., 29.],
           [48., 30.]],

          [[37., 43.],
           [38., 44.],
           [39., 45.],
           [40., 46.],
           [41., 47.],
           [42., 48.]],

          [[25., 43.],
           [26., 44.],
           [27., 45.],
           [28., 46.],
           [29., 47.],
           [30., 48.]]]]])
'''
Q_K_sample_nonsqueeze = torch.matmul(Q_unsqueeze, K_sample_trans)
'''
Q_K_sample_nonsqueeze: (1, 2, 4, 1, 2) -- each query dotted with its two sampled keys
tensor([[[[[  469.,   469.]],

          [[ 1243.,   217.]],

          [[ 1459.,  2017.]],

          [[  469.,  2791.]]],

         [[[ 7525.,  7525.]],

          [[ 9163.,  5545.]],

          [[ 9379., 10801.]],

          [[ 7525., 12439.]]]]])
'''
Q_K_sample = Q_K_sample_nonsqueeze.squeeze(-2)
'''
Q_K_sample:(1, 2, 4, 2)
tensor([[[[  469.,   469.],
          [ 1243.,   217.],
          [ 1459.,  2017.],
          [  469.,  2791.]],

         [[ 7525.,  7525.],
          [ 9163.,  5545.],
          [ 9379., 10801.],
          [ 7525., 12439.]]]])
'''

Q_K_sample_sum = Q_K_sample.sum(-1)
'''
Q_K_sample_sum:(1, 2, 4)
tensor([[[  938.,  1460.,  3476.,  3260.],
         [15050., 14708., 20180., 19964.]]])
'''
Q_K_sample_max = Q_K_sample.max(-1)
'''
torch.return_types.max(
values=tensor([[[  469.,  1243.,  2017.,  2791.],
         [ 7525.,  9163., 10801., 12439.]]]),
indices=tensor([[[0, 0, 1, 1],
         [0, 0, 1, 1]]]))'''
div_tmp = torch.div(Q_K_sample_sum, L_K)
'''
div_tmp:(1, 2, 4)
tensor([[[ 234.5000,  365.0000,  869.0000,  815.0000],
         [3762.5000, 3677.0000, 5045.0000, 4991.0000]]])
'''

M = Q_K_sample_max[0] - torch.div(Q_K_sample_sum, L_K)
'''
M:(1, 2, 4)
tensor([[[ 234.5000,  878.0000, 1148.0000, 1976.0000],
         [3762.5000, 5486.0000, 5756.0000, 7448.0000]]])
'''

M_top_tmp = M.topk(n_top, sorted=False)
'''
torch.return_types.topk(
values=tensor([[[1976., 1148.],
         [7448., 5756.]]]),
indices=tensor([[[3, 2],
         [3, 2]]]))
'''
M_top = M_top_tmp[1]
'''
M_top:(1, 2, n_top:2)
tensor([[[3, 2],
         [3, 2]]])
'''

Q_reduce_B = torch.arange(B)[:, None, None]
'''
Q_reduce_B:(1, 1, 1)
tensor([[[0]]])
'''
Q_reduce_H = torch.arange(H)[None, :, None]
'''
Q_reduce_H:(1, 2, 1)
tensor([[[0],
         [1]]])
'''

Q_reduce = Q[Q_reduce_B,
             Q_reduce_H,
             M_top, :]
'''
Q:(1, 2, 4, 6)
tensor([[[[ 1.,  2.,  3.,  4.,  5.,  6.],
          [ 7.,  8.,  9., 10., 11., 12.],
          [13., 14., 15., 16., 17., 18.],
          [19., 20., 21., 22., 23., 24.]],

         [[25., 26., 27., 28., 29., 30.],
          [31., 32., 33., 34., 35., 36.],
          [37., 38., 39., 40., 41., 42.],
          [43., 44., 45., 46., 47., 48.]]]])

Q_reduce: (1, heads: 2, n_top: 2, dim: 6) -- the queries at positions [3, 2] of head 1 and positions [3, 2] of head 2
tensor([[[[19., 20., 21., 22., 23., 24.],
          [13., 14., 15., 16., 17., 18.]],

         [[43., 44., 45., 46., 47., 48.],
          [37., 38., 39., 40., 41., 42.]]]])
'''
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))
'''
Q_K:(1, 2, n_top:2, 4)
tensor([[[[  469.,  1243.,  2017.,  2791.],
          [  343.,   901.,  1459.,  2017.]],

         [[ 7525.,  9163., 10801., 12439.],
          [ 6535.,  7957.,  9379., 10801.]]]])
'''
return Q_K, M_top

The line-by-line trace above shows that the returned Q_K is exactly the $\overline{Q}K^{T}$ term of Equation (3) in the paper:

$$\mathcal{A}(Q,K,V)=\mathrm{Softmax}\left(\frac{\overline{Q}K^{T}}{\sqrt{d}}\right)V$$

Next we analyze _get_initial_context, which builds the initial context: without masking it averages V along the length dimension (dim -2) and broadcasts that mean to all L_Q query positions; with masking it takes a cumulative sum instead:

def _get_initial_context(self, V, L_Q):
    B, H, L_V, D = V.shape
    if not self.mask_flag:

        V_sum = V.mean(dim=-2)
        contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
    else:
        assert(L_Q == L_V)
        contex = V.cumsum(dim=-2)
    return contex

We unroll the function step by step; the input V is the same tensor as Q and K above:

B, H, L_V, D = V.shape
'''
V:(1, 2, 4, 6)
tensor([[[[ 1.,  2.,  3.,  4.,  5.,  6.],
          [ 7.,  8.,  9., 10., 11., 12.],
          [13., 14., 15., 16., 17., 18.],
          [19., 20., 21., 22., 23., 24.]],

         [[25., 26., 27., 28., 29., 30.],
          [31., 32., 33., 34., 35., 36.],
          [37., 38., 39., 40., 41., 42.],
          [43., 44., 45., 46., 47., 48.]]]])
'''
if not self.mask_flag:
    V_sum = V.mean(dim=-2)
    '''
    V_sum:(1, 2, 6)
    tensor([[[10., 11., 12., 13., 14., 15.],
         [34., 35., 36., 37., 38., 39.]]])
    '''
    V_sum_unsequeese = V_sum.unsqueeze(-2)
    '''
    V_sum_unsequeese:(1, 2, 1, 6)
    tensor([[[[10., 11., 12., 13., 14., 15.]],

         [[34., 35., 36., 37., 38., 39.]]]])
    '''
    contex = V_sum_unsequeese.expand(B, H, L_Q, V_sum.shape[-1]).clone()
    '''
    contex:(1, 2, 4, 6)
    tensor([[[[10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.]],

         [[34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.]]]])
    '''
else:
    assert(L_Q == L_V)
    contex = V.cumsum(dim=-2)

The job of _update_context is then to compute the ProbSparse self-attention from the paper, where $\overline{Q}$ denotes the reduced query matrix consisting of the selected top-u queries:

$$\mathcal{A}(Q,K,V)=\mathrm{Softmax}\left(\frac{\overline{Q}K^{T}}{\sqrt{d}}\right)V$$


def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
    B, H, L_V, D = V.shape

    if self.mask_flag:
        attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
        scores.masked_fill_(attn_mask.mask, -np.inf)

    attn = torch.softmax(scores, dim=-1)

    context_in[torch.arange(B)[:, None, None],
               torch.arange(H)[None, :, None],
               index, :] = torch.matmul(attn, V).type_as(context_in)
    if self.output_attention:
        attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device)
        attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
        return (context_in, attns)
    else:
        return (context_in, None)

Again we unroll the function to see exactly what it does:

B, H, L_V, D = V.shape

if self.mask_flag:
    attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
    scores.masked_fill_(attn_mask.mask, -np.inf)

attn = torch.softmax(scores, dim=-1)
'''
scores:(1, 2, 2, 4)
tensor([[[[ 140.0292,  367.8317,  595.6343,  823.4368],
          [ 191.4685,  507.4526,  823.4368, 1139.4210]],

         [[2667.9026, 3248.4319, 3828.9609, 4409.4897],
          [3072.0686, 3740.7793, 4409.4897, 5078.2007]]]])

attn:(1, 2, n_top:2, 4)
tensor([[[[0., 0., 0., 1.],
          [0., 0., 0., 1.]],

         [[0., 0., 0., 1.],
          [0., 0., 0., 1.]]]])
'''
attn_V = torch.matmul(attn, V).type_as(context_in)
'''
V:(1, 2, 4, 6)
tensor([[[[ 1.,  2.,  3.,  4.,  5.,  6.],
          [ 7.,  8.,  9., 10., 11., 12.],
          [13., 14., 15., 16., 17., 18.],
          [19., 20., 21., 22., 23., 24.]],

         [[25., 26., 27., 28., 29., 30.],
          [31., 32., 33., 34., 35., 36.],
          [37., 38., 39., 40., 41., 42.],
          [43., 44., 45., 46., 47., 48.]]]])

attn_V:(1, 2, 2, 6)
tensor([[[[19., 20., 21., 22., 23., 24.],
          [19., 20., 21., 22., 23., 24.]],

         [[43., 44., 45., 46., 47., 48.],
          [43., 44., 45., 46., 47., 48.]]]])
'''
context_in_B = torch.arange(B)[:, None, None]
context_in_H = torch.arange(H)[None, :, None]
context_in[context_in_B, context_in_H, index, :] = attn_V
'''
index: (1, 2, 2) -- per head, the positions of the selected top queries
tensor([[[3, 2],
         [3, 2]]])

context_in before the update: (1, 2, 4, 6)
tensor([[[[10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.]],

         [[34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.]]]])

context_in after the update: rows 3 and 2 of each head are overwritten by attn_V,
while the remaining rows keep the initial mean-of-V context: (1, 2, 4, 6)
tensor([[[[10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.],
          [19., 20., 21., 22., 23., 24.],
          [19., 20., 21., 22., 23., 24.]],

         [[34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.],
          [43., 44., 45., 46., 47., 48.],
          [43., 44., 45., 46., 47., 48.]]]])
'''

if self.output_attention:
    attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device)
    attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
    return (context_in, attns)
else:
    return (context_in, None)

Moving up a level, the Encoder interleaves attention layers with the distilling conv layers:

class Encoder(nn.Module):
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):

        attns = []
        if self.conv_layers is not None:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:

                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns

For reference, the attn_layers and conv_layers passed in by Informer.__init__ are repeated below:


[
     EncoderLayer(
         AttentionLayer(Attn(False, factor, attention_dropout=dropout, output_attention=output_attention),
                        d_model, n_heads, mix=False),
         d_model,
         d_ff,
         dropout=dropout,
         activation=activation
     ) for l in range(e_layers)
  ]

[
    ConvLayer(
        d_model
    ) for l in range(e_layers-1)
] if distil else None

EncoderLayer and ConvLayer are implemented as follows:

class EncoderLayer(nn.Module):
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4*d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):

        new_x, attn = self.attention(
            x, x, x,
            attn_mask = attn_mask
        )

        x = x + self.dropout(new_x)

        y = x = self.norm1(x)

        y = self.dropout(self.activation(self.conv1(y.transpose(-1,1))))
        y = self.dropout(self.conv2(y).transpose(-1,1))

        return self.norm2(x+y), attn

ConvLayer computes the distilling operation from the paper, where $[\cdot]_{AB}$ denotes an attention block:

$$X_{j+1}^{t}=\mathrm{MaxPool}\left(\mathrm{ELU}\left(\mathrm{Conv1d}\left([X_{j}^{t}]_{AB}\right)\right)\right)$$

class ConvLayer(nn.Module):

    def __init__(self, c_in):
        super(ConvLayer, self).__init__()
        padding = 1 if torch.__version__>='1.5.0' else 2
        self.downConv = nn.Conv1d(in_channels=c_in,
                                  out_channels=c_in,
                                  kernel_size=3,
                                  padding=padding,
                                  padding_mode='circular')
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):

        x = self.downConv(x.permute(0, 2, 1))
        x = self.norm(x)
        x = self.activation(x)
        x = self.maxPool(x)
        x = x.transpose(1,2)
        return x
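
The stride-2 max pool is what shrinks the sequence between encoder layers, roughly halving its length each time (a sketch, assuming the ConvLayer class above is in scope):

import torch

conv = ConvLayer(c_in=512)
x = torch.randn(2, 96, 512)   # (batch, seq_len, d_model)
print(conv(x).shape)          # torch.Size([2, 48, 512]): length halved by the max pool
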
The Decoder is a plain stack of DecoderLayers:

class Decoder(nn.Module):
    def __init__(self, layers, norm_layer=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        for layer in self.layers:
            x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)

        if self.norm is not None:
            x = self.norm(x)

        return x

The decoder's initialization in Informer is reproduced below. Each DecoderLayer contains two attention layers: a masked ProbSparse self-attention followed by a full cross-attention over the encoder output:

self.decoder = Decoder(
    [
        DecoderLayer(
            AttentionLayer(Attn(True, factor, attention_dropout=dropout, output_attention=False),
                        d_model, n_heads, mix=mix),
            AttentionLayer(FullAttention(False, factor, attention_dropout=dropout, output_attention=False),
                        d_model, n_heads, mix=False),
            d_model,
            d_ff,
            dropout=dropout,
            activation=activation,
        )
        for l in range(d_layers)
    ],
    norm_layer=torch.nn.LayerNorm(d_model)
)

DecoderLayer itself:

class DecoderLayer(nn.Module):
    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        d_ff = d_ff or 4*d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):

        x = x + self.dropout(self.self_attention(
            x, x, x,
            attn_mask=x_mask
        )[0])
        x = self.norm1(x)

        x = x + self.dropout(self.cross_attention(
            x, cross, cross,
            attn_mask=cross_mask
        )[0])

        y = x = self.norm2(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1,1))))
        y = self.dropout(self.conv2(y).transpose(-1,1))

        return self.norm3(x+y)
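
Putting it all together, an end-to-end shape check of the full model (a sketch with hypothetical hyperparameters; assumes the Informer class and its dependencies are importable, and uses embed='timeF' so the time marks are plain floats):

import torch

model = Informer(enc_in=7, dec_in=7, c_out=7, seq_len=96, label_len=48,
                 out_len=24, embed='timeF', freq='h', device=torch.device('cpu'))
x_enc = torch.randn(2, 96, 7);  x_mark_enc = torch.randn(2, 96, 4)
x_dec = torch.randn(2, 72, 7);  x_mark_dec = torch.randn(2, 72, 4)  # 72 = label_len + pred_len
print(model(x_enc, x_mark_enc, x_dec, x_mark_dec).shape)            # torch.Size([2, 24, 7])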

For deployment, the model is exported to ONNX, a format compatible with the MNN framework, and then deployed through MNN.

4.1 Converting the PyTorch Model to ONNX

Persisting the model as ONNX uses PyTorch's torch.onnx.export API (docs: https://pytorch.org/docs/master/onnx.html?highlight=torch%20onnx%20export#torch.onnx.export). First, an example from the web:

import torch
import torch.nn as nn
import onnx
import numpy as np
class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.conv1=nn.Conv2d(3,3, kernel_size=3, stride=2,padding=1)
    def forward(self,x,y):
        result1=self.conv1(x)
        result2=self.conv1(y)
        return result1,result2

model=Model()
model.eval()

input_names = ["input_0","input_1"]
output_names = ["output_0","output_1"]

x=torch.randn((1,3,12,12))
y=torch.randn((1,3,6,6))

torch.onnx.export(model,(x,y),'model.onnx',input_names=input_names,output_names=output_names,
  dynamic_axes={'input_0':[0],'output_0':[0]})

We first need to determine the input and output shapes of our Informer model, which fix the dimensions of the second argument passed to the export call.

The inputs are batch_x, batch_x_mark, dec_inp, and batch_y_mark; the outputs are outputs_app and outputs_user. In the app-prediction project their shapes are:

  • batch_x: (32, 12, 64)
  • batch_x_mark: (32, 12, 5)
  • dec_inp: (32, 7, 64)
  • batch_y_mark: (32, 7, 5)
  • outputs_app: (32, 1, 1521)
  • outputs_user: (32, 1, 851)
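
A hedged sketch of the corresponding export call, reusing the tensor names and shapes from the list above (model stands for the trained Informer instance; the file name and dynamic-axes choices are illustrative):

import torch

# Dummy inputs with the shapes listed above (illustrative only).
batch_x      = torch.randn(32, 12, 64)
batch_x_mark = torch.randn(32, 12, 5)
dec_inp      = torch.randn(32, 7, 64)
batch_y_mark = torch.randn(32, 7, 5)

torch.onnx.export(
    model, (batch_x, batch_x_mark, dec_inp, batch_y_mark), 'informer.onnx',
    input_names=['batch_x', 'batch_x_mark', 'dec_inp', 'batch_y_mark'],
    output_names=['outputs_app', 'outputs_user'],
    dynamic_axes={'batch_x': [0], 'batch_x_mark': [0], 'dec_inp': [0],
                  'batch_y_mark': [0], 'outputs_app': [0], 'outputs_user': [0]})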

Original: https://blog.csdn.net/jrh1223/article/details/122130497
Author: jrh1223
Title: Informer源码分析 (Informer Source Code Analysis)
