You are on page 1 of 4

Simple Case Study of Implementing K Means Clustering

on the IRIS Dataset

Import tools and libraries:


from time import time

import numpy as np

import pandas as pd

import matplotlib.pyplot as plt

from sklearn import metrics

from sklearn.cluster import KMeans

from sklearn.datasets import load_digits

from sklearn.decomposition import PCA

from sklearn.preprocessing import scale

Create a function that prints clustering quality metrics (homogeneity, completeness and V-measure) for a set of predicted cluster labels:


def get_cluster_metric(y_train, km_labels_):
    """Print homogeneity, completeness and V-measure for a clustering.

    Args:
        y_train: Ground-truth class labels, one per sample.
        km_labels_: Cluster labels assigned to the same samples.

    Returns:
        None. Each score is printed to stdout with three decimals,
        followed by one blank line.
    """
    print("Homogeneity: %0.3f" % metrics.homogeneity_score(y_train, km_labels_))
    print("Completeness: %0.3f" % metrics.completeness_score(y_train, km_labels_))
    print("V-measure: %0.3f" % metrics.v_measure_score(y_train, km_labels_))
    # Blank separator line so consecutive reports are easy to tell apart.
    print()

Load and standardise the scikit-learn digits dataset for practice:


# Seed NumPy's global random number generator.
np.random.seed(42)

# Load the digits dataset and pass its feature matrix through scale().
digits = load_digits()
data = scale(digits.data)
labels = digits.target

# Dataset dimensions and the number of distinct target classes.
n_samples, n_features = data.shape
n_digits = len(np.unique(digits.target))

# NOTE(review): sample_size is not used anywhere in the visible code.
sample_size = 300

print(
    "n_digits: %d, \t n_samples %d, \t n_features %d"
    % (n_digits, n_samples, n_features)
)

Output:

n_digits: 10, n_samples 1797, n_features 64

labels.shape

Output: (1797,)

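Importing the loader for the built-in Iris dataset (note: the code below clusters the digits dataset loaded above; `load_iris` is imported but never called):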


from sklearn.datasets import load_iris

Algorithm to extract the clusters and compute sum of squared errors:


y = labels

sse = {}  # k -> inertia of the model fitted with k clusters
accuracy = []  # fraction of raw label matches, one entry per k

# Fit one k-means model per candidate cluster count and record its scores.
for k in range(1, 20):
    kmeans = KMeans(n_clusters=k, max_iter=1000).fit(data)

    # Inertia: sum of squared distances of samples to their closest
    # cluster center.
    sse[k] = kmeans.inertia_

    labels_pred = kmeans.labels_
    # print(labels_pred.shape)

    # Count the samples whose cluster id equals the true label.
    # NOTE(review): k-means cluster ids are arbitrary indices, so this raw
    # match count is not a true classification accuracy; the scores printed
    # by get_cluster_metric() do not depend on how clusters are numbered.
    correct_labels = np.count_nonzero(y == labels_pred)
    match_ratio = correct_labels / float(y.size)
    accuracy.append(match_ratio)

    # print("Result: %d out of %d samples were correctly labeled. when k = %d "
    #       % (correct_labels, y.size, k))
    print("correct %.02f percent classification at k = %d" % (match_ratio * 100, k))

    get_cluster_metric(y, kmeans.labels_)

Visualisation:
# The k values that were evaluated, taken from the SSE dict (insertion
# order) so both plots follow the sweep instead of repeating its bounds.
cluster_counts = list(sse.keys())

# No. of clusters v/s SSE
plt.figure()
plt.plot(cluster_counts, list(sse.values()))
plt.xlabel("Number of cluster")
plt.ylabel("SSE")
plt.show()

# No. of clusters v/s accuracy
plt.figure()
plt.plot(cluster_counts, accuracy)
plt.xlabel("Number of cluster")
plt.ylabel("accuracy")
plt.show()

You might also like