笔记(三)GMM与EM 高斯混合模型与EM算法

高斯混合模型是单一高斯机率密度函数的延伸,由于 GMM 能够平滑地近似任意形状的密度分布,因此近年来常被用在语音、图像识别等方面,得到不错的效果。

GMM认为数据是从几个SGM(单高斯模型(Single Gaussian Model, SGM))中生成出来的,即

$$P(x) = \sum_{k=1}^{K} \pi_k N(x; u_k, \sigma_k) \tag{1}$$

$K$ 需要事先确定好,就像 K-means 中的 $K$ 一样。$\pi_k$ 是权值因子。其中的任意一个高斯分布 $N(x; u_k, \sigma_k)$ 叫作这个模型的一个 component。这里有个问题:为什么我们要假设数据是由若干个高斯分布组合而成的,而不假设是其他分布呢?实际上不管是什么分布,只要 $K$ 取得足够大,这个 XX Mixture Model 就会变得足够复杂,就可以用来逼近任意连续的概率密度分布;只是因为高斯函数具有良好的计算性能,所以 GMM 被广泛地应用。

GMM是一种聚类算法,每个component就是一个聚类中心。即在只有样本点,不知道样本分类(含有隐含变量)的情况下,计算出模型参数($\pi$、$u$ 和 $\sigma$),这可以用EM算法来求解。再用训练好的模型去判别样本所属的分类,
方法是:

  • step1随机选择K个component中的一个(被选中的概率是π k π_k πk ​);
  • step2把样本代入刚选好的component,判断是否属于这个类别,如果不属于则回到step1。

2.1、em_gmm.h


#pragma once

// Fit a num_modes-component Gaussian mixture model to `data`
// (num_pts points of dimension `dim`, row major) via k-means
// initialization followed by EM.  Results are written into the
// caller-provided buffers:
//   means     : num_modes x dim, row major
//   diag_covs : num_modes x dim, row major
//   weights   : num_modes
// should_fit_spherical_gaussian selects a single shared variance per
// mode (true) or a full diagonal covariance per mode (false).
void em_gmm(
        const float *data,
        const long num_pts,
        const long dim,
        const int num_modes,
        float *means,
        float *diag_covs,
        float *weights,
        bool should_fit_spherical_gaussian = true);

// Evaluate per-point, per-mode log probabilities of an already-fitted GMM.
// log_probs receives num_pts x num_modes values, row major.
// is_spherical_gaussian must match the flag the model was fitted with.
void likelihood_gmm(
        const float *data,
        const long num_pts,
        const long dim,
        const int num_modes,
        const float *means,
        const float *diag_covs,
        const float *weights,
        float *log_probs,
        bool is_spherical_gaussian = true);

2.2、em_gmm.cpp

#include "em_gmm.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <limits>
#include <random>
#include <vector>

#include <Eigen/Dense>

// M_PI is not part of standard C++; define it only if the platform's
// <cmath> did not already provide it, and use full double precision.
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

namespace {
    // Numerical guards used throughout the fit.
    const float eps_covariance = 1e-10;        // floor for any variance estimate
    const float eps_zero = 1e-10;              // values below this are treated as zero
    const float eps_log_negative_inf = -1e30;  // sentinel standing in for log(0)
    const float eps_convergence = 1e-4;        // stopping tolerance for k-means / EM
    const float eps_regularize = 1e-30;        // keeps log()/division finite for empty modes
}

using namespace std;
using namespace Eigen;

// Row-major dynamic float matrix so Eigen's memory layout matches the raw
// C arrays passed through the public API (and Map<> can wrap them directly).
typedef Matrix<float, Dynamic, Dynamic, RowMajor> RowMatrixXf;

// Numerically stable log(exp(log_a) + exp(log_b)).
// Anchors on the larger argument so exp() cannot overflow, and uses
// std::log1p instead of log(1 + x) for accuracy when the exponential of
// the difference is tiny.
inline float log_sum(const float& log_a, const float& log_b) {
    return log_a < log_b ?
         (log_b + std::log1p (std::exp (log_a - log_b)))
         : (log_a + std::log1p (std::exp (log_b - log_a)));
}

// Per-point, per-mode log joint probability for a spherical GMM:
//   log(pi_c) + log N(x_n | mean_c, cov_c * I)
// written into mat_log_probs (num_pts x num_modes).
// Uses ||x - mu||^2 = ||x||^2 - 2 x.mu + ||mu||^2 so the bulk of the
// work is a single matrix product.  Only column 0 of mat_diag_covs is
// read per mode (the variance is shared across dimensions).
void calculate_log_prob_spherical(
        const RowMatrixXf& mat_data,
        const VectorXf& vec_nrm2_pts,      // precomputed ||x_n||^2 per point
        const RowVectorXf& vec_weights,
        const RowMatrixXf& mat_means,
        const RowMatrixXf& mat_diag_covs,
        RowMatrixXf& mat_log_probs) {      // output, num_pts x num_modes

    const long num_modes = mat_means.rows();
    const long num_pts = mat_data.rows();
    const long dim = mat_data.cols();

    assert( vec_nrm2_pts.rows() == num_pts &&
            vec_weights.cols() == num_modes &&
            mat_means.cols() == dim &&
            mat_diag_covs.rows() == num_modes &&
            mat_diag_covs.cols() == dim &&
            mat_log_probs.rows() == num_pts &&
            mat_log_probs.cols() == num_modes );

    // ||mu_c||^2 for every mode.
    RowVectorXf vec_nrm2_centers(num_modes);
    for (long c = 0; c < num_modes; c++) {
        vec_nrm2_centers(c) = mat_means.row(c).squaredNorm();
    }

    // Build squared distances in place: -2 x.mu, then + ||x||^2, + ||mu||^2.
    mat_log_probs.noalias() = mat_data * mat_means.transpose();
    mat_log_probs *= -2;
    #pragma omp parallel for
    for (long c = 0; c < num_modes; c++) {
        mat_log_probs.col(c) += vec_nrm2_pts;
    }
    #pragma omp parallel for
    for (long n = 0; n < num_pts; n++) {
        mat_log_probs.row(n) += vec_nrm2_centers;
    }

    // Scale by -1/(2 cov) and add the per-mode constant: log prior plus
    // the Gaussian normalization term.  eps_regularize keeps the log of
    // a (numerically) zero weight finite.
    #pragma omp parallel for
    for (long c = 0; c < num_modes; c++) {
        const float cov = mat_diag_covs(c,0);
        const float c1 = log(vec_weights(c) + eps_regularize)
            - 0.5*dim*log(2*M_PI) - 0.5*dim*log(cov);
        const float c2 = -0.5f/cov;
        mat_log_probs.col(c) *= c2;
        mat_log_probs.col(c) = (mat_log_probs.col(c).array() + c1).matrix();
    }
}

void calculate_log_prob_diagonal(
        const RowMatrixXf& mat_data,
        const RowVectorXf& vec_weights,
        const RowMatrixXf& mat_means,
        const RowMatrixXf& mat_diag_covs,
        RowMatrixXf& mat_log_probs) {

    const long num_modes = mat_means.rows();
    const long num_pts = mat_data.rows();
    const long dim = mat_data.cols();

    assert( vec_weights.cols() == num_modes &&
            mat_means.cols() == dim &&
            mat_diag_covs.rows() == num_modes &&
            mat_diag_covs.cols() == dim &&
            mat_log_probs.rows() == num_pts &&
            mat_log_probs.cols() == num_modes);

    const float c0(-0.5f*dim*log(2*M_PI));
    RowVectorXf vec_c1_cov_prod(num_modes);
    for (long c = 0; c < num_modes; c++) {
        vec_c1_cov_prod(c) = -0.5f*(mat_diag_covs.row(c).array().log().sum());
    }

    #pragma omp parallel for
    for (long n = 0; n < num_pts; n++) {
        RowVectorXf vec_data = mat_data.row(n);
        for (long c = 0; c < num_modes; c++) {
            RowVectorXf delta = (vec_data - mat_means.row(c));
            mat_log_probs(n,c) = -0.5f*(delta.array()*mat_diag_covs.row(c).array().cwiseInverse()).matrix().dot(delta)
                + c0 + vec_c1_cov_prod(c);
        }
    }
}

// Fits a num_modes-component Gaussian mixture to `data` (num_pts rows of
// `dim` floats, row major).
//
// Pipeline:
//   1. k-means (squared-Euclidean distance, random seeding) to obtain
//      initial means, covariances and weights;
//   2. EM refinement with either one shared variance per mode (spherical)
//      or a per-dimension diagonal covariance.
//
// Results are written in place into means (num_modes x dim),
// diag_covs (num_modes x dim) and weights (num_modes).
void em_gmm(
        const float *data,
        const long num_pts,
        const long dim,
        const int num_modes,
        float *means,
        float *diag_covs,
        float *weights,
        bool should_fit_spherical_gaussian) {

    using namespace std;

    assert (num_modes < num_pts && "Not enough data for em");

    RowVectorXf vec_eps_regularize(num_modes);
    vec_eps_regularize.fill(eps_regularize);

    std::vector<int> labels(num_pts, -1);

    // Wrap the caller's raw buffers; no copies are made.
    Map<const RowMatrixXf> mat_data(data, num_pts, dim);
    Map<RowMatrixXf> mat_means(means, num_modes, dim);
    Map<RowMatrixXf> mat_diag_covs(diag_covs, num_modes, dim);
    Map<RowVectorXf> vec_weights(weights, num_modes);

    // Seed k-means with randomly chosen data points.
    // NOTE(review): indices are drawn independently, so two modes can be
    // seeded from the same point — consider sampling without replacement.
    random_device rd;
    default_random_engine gen(rd());
    uniform_int_distribution<long> kmeans_seed_dist(0, num_pts-1);
    vector<long> center_indices(num_modes);
    generate(center_indices.begin(), center_indices.end(), [&]{
        return kmeans_seed_dist(gen);
    });
    #pragma omp parallel for
    for (int c = 0; c < num_modes; c++) {
        long n = center_indices[c];
        mat_means.row(c) = mat_data.row(n);
    }

    const int max_kmeans_iterations = 20;
    bool is_converged = false;
    float eps(0.0f);

    RowMatrixXf mat_distance(num_pts, num_modes);

    // Per-point squared norms, reused by every distance evaluation.
    VectorXf vec_nrm2_pts(num_pts);
    #pragma omp parallel for
    for (long r = 0; r < num_pts; r++) {
        vec_nrm2_pts(r) = mat_data.row(r).squaredNorm();
    }

    RowVectorXf vec_nrm2_centers(num_modes);

    RowMatrixXf mat_saved_means(num_modes, dim);
    RowVectorXf assigned_counts(num_modes);

    // ---- k-means initialization -------------------------------------
    int iterations = 0;
    while ((iterations++ < max_kmeans_iterations) && !is_converged) {

        mat_saved_means = mat_means;

        // ||x - mu||^2 = ||x||^2 - 2 x.mu + ||mu||^2, built in passes.
        #pragma omp parallel for
        for (int c = 0; c < num_modes; c++) {
            vec_nrm2_centers(c) = mat_means.row(c).squaredNorm();
        }
        mat_distance.noalias() = mat_data * mat_means.transpose();
        mat_distance *= -2;
        #pragma omp parallel for
        for (long c = 0; c < num_modes; c++) {
            mat_distance.col(c) += vec_nrm2_pts;
        }

        #pragma omp parallel for
        for (long n = 0; n < num_pts; n++) {
            mat_distance.row(n) += vec_nrm2_centers;
            mat_distance.row(n).minCoeff(&labels[n]);
        }

        // Accumulate new means; the first point assigned to a cluster
        // overwrites the stale mean instead of adding to it.
        assigned_counts.fill(0.0f);
        for (long n = 0; n < num_pts; n++) {
            long c = labels[n];
            if (assigned_counts(c) < 1e-3) {
                mat_means.row(c) = mat_data.row(n);
            } else {
                mat_means.row(c) += mat_data.row(n);
            }
            assigned_counts(c) += 1.0f;
        }

        #pragma omp parallel for
        for (long c = 0; c < num_modes; c++) {
            if (assigned_counts(c) > 1e-3) {
                mat_means.row(c) /= assigned_counts(c);
            }
        }

        const float prev_eps = eps;
        eps = (mat_saved_means - mat_means).norm();
        is_converged = (eps < eps_convergence);
        cout << "kmeans " << "[" << iterations << "] " << prev_eps << " " << eps << endl;
    }

    // Initial GMM parameters from the hard k-means assignment:
    // weights from cluster occupancy, covariances via E[x^2] - E[x]^2.
    vec_weights = assigned_counts / (float)num_pts;
    mat_diag_covs.fill(0);
    for (long n = 0; n < num_pts; n++) {
        long c = labels[n];
        mat_diag_covs.row(c) += mat_data.row(n).array().square().matrix();
    }
    for (long c = 0; c < num_modes; c++) {
        if (assigned_counts(c) > 1e-3) {
            mat_diag_covs.row(c) /= assigned_counts(c);
            mat_diag_covs.row(c) -= mat_means.row(c).array().square().matrix();
        }
    }

    if (should_fit_spherical_gaussian) {
        // Collapse the diagonal to a single shared variance per mode.
        for (int c = 0; c < num_modes; c++) {
            const float spherical_val = std::max(mat_diag_covs.row(c).sum()/dim, eps_covariance);
            mat_diag_covs.row(c).fill(spherical_val);
        }
    }

    // The k-means distance buffer is reused to hold the log probabilities
    // (and, after the E-step, the posteriors) during EM.
    RowMatrixXf& mat_log_probs = mat_distance;

    const int max_em_iterations = 20;
    is_converged = false;
    float expectation(std::numeric_limits<float>::lowest());

    RowVectorXf vec_evals(num_pts);
    RowVectorXf vec_log_sum_probs(num_pts);
    RowVectorXf vec_occup_eN(num_modes);        // sum of posteriors per mode
    RowMatrixXf mat_occup_eX(num_modes, dim);   // posterior-weighted sum of x
    RowMatrixXf mat_occup_eX2(num_modes, dim);  // posterior-weighted sum of x^2

    // ---- EM refinement ----------------------------------------------
    iterations = 0;
    while ((iterations++ < max_em_iterations) && !is_converged) {

        // E-step, part 1: joint log prob log(pi_c) + log N(x_n | c).
        if (should_fit_spherical_gaussian) {
            calculate_log_prob_spherical(mat_data, vec_nrm2_pts, vec_weights, mat_means, mat_diag_covs, mat_log_probs);
        } else {
            calculate_log_prob_diagonal(mat_data, vec_weights, mat_means, mat_diag_covs, mat_log_probs);
        }

        // log sum_c exp(log_prob(n,c)), skipping -inf sentinels.
        #pragma omp parallel for
        for (long n = 0; n < num_pts; n++) {
            vec_log_sum_probs(n) = mat_log_probs(n,0);
            for (long c = 1; c < num_modes; c++) {
                if (mat_log_probs(n,c) > eps_log_negative_inf) {
                    vec_log_sum_probs(n) = log_sum(vec_log_sum_probs(n), mat_log_probs(n,c));
                }
            }
        }

        // E-step, part 2: convert log probs to posteriors in place and
        // accumulate the expected log likelihood.
        #pragma omp parallel for
        for (long n = 0; n < num_pts; n++) {
            RowVectorXf soft_count = (mat_log_probs.row(n).array() - vec_log_sum_probs(n)).exp().matrix();
            vec_evals(n) = 0;
            for (long c = 0; c < num_modes; c++) {
                if (soft_count(c) > eps_zero) {
                    vec_evals(n) += mat_log_probs(n,c)*soft_count(c);
                }
            }
            mat_log_probs.row(n) = soft_count;
        }

        // M-step: occupancy counts and posterior-weighted sums.
        #pragma omp parallel for
        for (long c = 0; c < num_modes; c++) {
            vec_occup_eN(c) = mat_log_probs.col(c).sum();
        }

        mat_occup_eX = mat_log_probs.transpose() * mat_data;

        vec_weights = vec_occup_eN / vec_occup_eN.sum();
        vec_occup_eN += vec_eps_regularize;  // guard the divisions below
        #pragma omp parallel for
        for (int c = 0; c < num_modes; c++) {
            if (vec_weights(c) > eps_zero) {
                mat_means.row(c) = mat_occup_eX.row(c)/vec_occup_eN(c);
            }
        }

        // Covariance update.  Spherical: E[||x - mu||^2]/dim expanded as
        // (sum_n gamma_nc ||x_n||^2 - 2 mu.eX + ||mu||^2 eN) / (dim eN).
        VectorXf vec_sum_occup_nrm2_pts = mat_log_probs.transpose() * vec_nrm2_pts;
        if (should_fit_spherical_gaussian) {
            for (int c = 0; c < num_modes; c++) {
                // BUGFIX: use the squared norm of the *updated* mean.  The
                // old code read vec_nrm2_centers, which still held the
                // norms from the last k-means iteration and was never
                // refreshed inside the EM loop.  Also floor the variance
                // at eps_covariance, matching the diagonal branch.
                const float spherical_val = std::max(
                    (vec_sum_occup_nrm2_pts(c) - 2*mat_means.row(c).dot(mat_occup_eX.row(c)) + mat_means.row(c).squaredNorm()*vec_occup_eN(c))
                    / (dim*vec_occup_eN(c)),
                    eps_covariance);
                mat_diag_covs.row(c).fill(spherical_val);
            }
        } else {
            mat_occup_eX2 = mat_log_probs.transpose() * mat_data.array().square().matrix();
            for (long c = 0; c < num_modes; c++) {
                mat_diag_covs.row(c) = ((mat_occup_eX2.row(c)/vec_occup_eN(c)).array() - mat_means.row(c).array().square()).max(eps_covariance).matrix();
            }
        }

        // Convergence test on the exponentially-smoothed growth of the
        // expected log likelihood.
        const float prev_expectation = expectation;
        expectation = vec_evals.sum();
        const float scale = 1e5;
        const float delta = exp((expectation - prev_expectation)/scale) - 1;
        is_converged = (iterations > 0 && delta < eps_convergence);
        cout << "em " << "[" << iterations << "] " << delta << " " << expectation << endl;
    }

}

// Evaluates per-point, per-mode log probabilities for an already-fitted
// GMM and writes the num_pts x num_modes result (row major) into
// log_probs.  is_spherical_gaussian must match the flag used when the
// model was fitted.
void likelihood_gmm(
        const float *data,
        const long num_pts,
        const long dim,
        const int num_modes,
        const float *means,
        const float *diag_covs,
        const float *weights,
        float *log_probs,
        bool is_spherical_gaussian) {

    // Wrap the caller's raw buffers; no copies are made.
    Map<const RowMatrixXf> mat_data(data, num_pts, dim);
    Map<const RowMatrixXf> mat_means(means, num_modes, dim);
    Map<const RowMatrixXf> mat_diag_covs(diag_covs, num_modes, dim);
    Map<const RowVectorXf> vec_weights(weights, num_modes);

    // Squared point norms, needed by the spherical fast path.
    VectorXf vec_nrm2_pts(num_pts);
    #pragma omp parallel for
    for (long r = 0; r < num_pts; r++) {
        vec_nrm2_pts(r) = mat_data.row(r).squaredNorm();
    }

    // (The unused vec_nrm2_centers local present in the original has
    // been removed; the helpers compute center norms themselves.)
    RowMatrixXf mat_log_probs(num_pts, num_modes);

    if (is_spherical_gaussian) {
        calculate_log_prob_spherical(mat_data, vec_nrm2_pts, vec_weights, mat_means, mat_diag_covs, mat_log_probs);
    } else {
        calculate_log_prob_diagonal(mat_data, vec_weights, mat_means, mat_diag_covs, mat_log_probs);
    }

    // mat_log_probs is row major, so a flat copy preserves the layout.
    std::copy(mat_log_probs.data(), mat_log_probs.data() + num_pts*num_modes, log_probs);
}

2.3、测试程序

test.cpp

#include <iostream>
#include <vector>

#include "em_gmm.h"
#include "sample.h"

// Demo driver: fits a 2-component, 2-D GMM with diagonal covariances to
// the data set bundled in sample.h and prints the recovered mixture
// weights, means and per-dimension variances.
int main(int argc, char **argv) {

    using namespace std;

    const long num_pts(1e5);       // number of points in sample_data
    const long num_gaussians(2);
    const long dim(2);

    // Output buffers filled in place by em_gmm.
    vector<float> weights(num_gaussians);
    vector<float> means(num_gaussians*dim);
    vector<float> diag_covs(num_gaussians*dim);

    em_gmm(sample_data, num_pts, dim, num_gaussians,
            means.data(), diag_covs.data(),
            weights.data(), false );   // false => diagonal covariances

    cout << "weights: " << weights[0] << " " << weights[1] << endl;
    cout << "means: " << means[0] << " " << means[1] << " ;\n"
        << means[2] << " " << means[3] << endl;
    cout << "covs: " << diag_covs[0] << " " << diag_covs[1] << " ;\n"
        << diag_covs[2] << " " << diag_covs[3] << endl;

    return 0;
}

测试程序输入数据 100000组,维度为2,高斯核个数为2
需要求解的输出是:2个高斯函数的权重、2个高斯函数的参数 期望和方差
程序输出:

kmeans [1] 0 1.39343
kmeans [2] 1.39343 0.00233473
kmeans [3] 0.00233473 0
em [1] inf -284034
em [2] 9.53674e-07 -284034
weights: 0.499979 0.500021
means: -2.99555 -5.00304 ;
0.99933 2.00078
covs: 1.00493 0.997717 ;
2.01731 0.49929

Original: https://blog.csdn.net/juluwangriyue/article/details/122821186
Author: 落花逐流水
Title: 笔记(三)GMM与EM 高斯混合模型与EM算法

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/563215/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球