clear; close all; clc

%% generate data
% Exactly one data source is active below; the remaining ones are kept as
% commented-out alternatives that can be swapped in by hand.

% [X, truelabels] = simulate_data('twogaussians');
% [X, truelabels] = simulate_data('threegaussians');
[X, truelabels] = simulate_data('twocircles');
% [X, truelabels] = simulate_data('threecircles');

% load twomoons

% Hand-built alternative: three randn-based blobs with n points each.
% n = 100;
% X = [randn(n,2)*0.5 - 3; randn(n,2)*0.5 - 1; randn(n,2)*0.5 + 5];
% truelabels = repelem(1:3,n);

% Show the raw points (first column of X against the second), unlabeled.
figure
scatter(X(:,1), X(:,2), 36, "black");
set(gca, 'fontsize', 16);

%% normalized cut
% Builds a Gaussian affinity matrix W from the rows of X, forms the
% row-normalized matrix P = D^(-1) W, and computes its leading eigenpairs.
% Results left in the workspace: n, dists, sigma, W, P, V (eigenvectors of
% P, one per column) and Lambda (diagonal matrix of eigenvalues), with the
% eigenpairs ordered from the largest eigenvalue downwards.

n = size(X,1);

dists = pdist2(X,X,'squaredeuclidean'); % matrix of squared pairwise Euclidean distances

% Kernel width: mean distance from a random subset of points to their
% 7th nearest neighbor (the query point itself counts as the first one).
% Both sizes are clamped so that small data sets do not make randsample or
% knnsearch fail.
nsub = min(50, n);
nnbr = min(7, n);
subset = randsample(n, nsub);
[~, knndists] = knnsearch(X, X(subset,:), 'K', nnbr);
sigma = mean(knndists(:,end));

W = exp(-dists/(2*sigma^2));
W(1:n+1:end) = 0; % set diagonals to zero

figure;
imagesc(W);
colorbar
title('weight matrix W','fontsize',18)

% Row sums (degrees). A zero degree would put NaNs into P and the
% eigen-solver, so fail early with a readable message instead.
d = sum(W,2);
if any(d == 0)
    error('ncut:isolatedPoint', ...
        'At least one point has zero total weight; increase sigma.');
end

P = W ./ d;

% P is not symmetric, so calling eigs on it directly can return eigenvalues
% with spurious imaginary parts, and the 'LM' criterion ranks by magnitude
% (a large negative eigenvalue could displace a leading positive one).
% P is similar to the symmetric matrix S = D^(-1/2) W D^(-1/2): if
% S*u = lambda*u then P*(D^(-1/2)*u) = lambda*(D^(-1/2)*u). Decompose S,
% ask for the algebraically largest eigenvalues, and map the vectors back.
dinvsqrt = 1 ./ sqrt(d);
S = W .* (dinvsqrt * dinvsqrt');
S = (S + S')/2; % remove round-off asymmetry so the symmetric solver is used

k = min(6, n);
[U, Lambda] = eigs(S, k, 'LA');

% Enforce descending order explicitly rather than relying on eigs.
[lambda, order] = sort(diag(Lambda), 'descend');
Lambda = diag(lambda);

V = dinvsqrt .* U(:,order);   % eigenvectors of P
V = V ./ sqrt(sum(V.^2, 1));  % unit-length columns

figure;
plot(diag(Lambda), '.', 'markersize',16)
title('Top eigenvalues of P', 'fontsize',18)
set(gca,'fontsize',16)
grid on

%% 2 clusters
% Two-way cut: split the points by the sign of the second column of V.
% This assumes the columns of V are ordered by decreasing eigenvalue of P,
% so that V(:,2) belongs to the second largest eigenvalue.

V2 = V(:,2);

% Label 1 for non-positive entries of V2, label 2 for positive entries.
labels_ncut = ones(1,n);
labels_ncut(V2>0) = 2;

% Eigenvector entries against point index, colored by assigned cluster.
% The index is passed as a column so that it has the same shape as V2.
figure;
gscatter((1:n)', V2, labels_ncut, 'rbm', 'po.', 16)
set(gca,'fontsize',16)
legend('fontsize',16)
title('second largest eigenvector of P','fontsize',18)

% The same labels shown on the original points.
figure;
gscatter(X(:,1), X(:,2), labels_ncut, 'rbm', 'po.', 16)
set(gca,'fontsize',16)
legend('fontsize',16)
title('clusters found by Ncut','fontsize',18)

% No trailing semicolon: the error value is printed to the command window.
err = clustering_error(labels_ncut, truelabels)

% Stop here when the whole script is run; the section after this line is
% only reached when it is executed on its own (e.g. with "Run Section").
return

%% 3 (or more) clusters
% Multi-way variant: run k-means with 3 clusters on the points embedded by
% columns 2 and 3 of V, then display the result in both spaces.

embedding = V(:,2:3);

% Raw embedding before clustering.
figure
plot(embedding(:,1), embedding(:,2), '.', 'markersize', 16);
set(gca, 'fontsize', 16);
title('second and third largest eigenvectors of P', 'fontsize', 18);

% 10 restarts of k-means; kmeans returns the labels of its best run.
labels_ncut = kmeans(embedding, 3, 'replicates', 10);

% Cluster assignment shown in the eigenvector coordinates.
figure
gscatter(embedding(:,1), embedding(:,2), labels_ncut, 'rbm', 'po.', 16);
set(gca, 'fontsize', 16);
legend('fontsize', 16);
title('clusters found by kmeans in eigenvector space', 'fontsize', 18);
xlabel('v_2', 'fontsize', 16);
ylabel('v_3', 'fontsize', 16);

% Same assignment shown on the original points.
figure
gscatter(X(:,1), X(:,2), labels_ncut, 'rbm', 'po.', 16);
set(gca, 'fontsize', 16);
legend('fontsize', 16);
title('clusters found by Ncut', 'fontsize', 18);

% Left without a semicolon so the error value is echoed.
err = clustering_error(labels_ncut, truelabels)
