Professional Documents
Culture Documents
Stroke Prediction
Stroke Prediction
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
data = pd.read_csv('healthcare-dataset-stroke-data.csv')
data
id gender age hypertension heart_disease ever_married work_type Residence_type avg_glucose_level bmi smoking_status stroke
0 9046 Male 67.0 0 1 Yes Private Urban 228.69 36.6 formerly smoked
Self-
1 51676 Female 61.0 0 0 Yes Rural 202.21 NaN never smoked
employed
2 31112 Male 80.0 0 1 Yes Private Rural 105.92 32.5 never smoked
Self-
4 1665 Female 79.0 1 0 Yes Rural 174.12 24.0 never smoked
employed
... ... ... ... ... ... ... ... ... ... ... ...
5105 18234 Female 80.0 1 0 Yes Private Urban 83.75 NaN never smoked
Self-
5106 44873 Female 81.0 0 0 Yes Urban 125.20 40.0 never smoked
employed
Self-
5107 19723 Female 35.0 0 0 Yes Rural 82.99 30.6 never smoked
employed
5108 37544 Male 51.0 0 0 Yes Private Rural 166.29 25.6 formerly smoked
5109 44679 Female 44.0 0 0 Yes Govt_job Urban 85.28 26.2 Unknown
Data Preprocessing
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5110 entries, 0 to 5109
Data columns (total 12 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 id 5110 non-null int64
1 gender 5110 non-null object
2 age 5110 non-null float64
3 hypertension 5110 non-null int64
4 heart_disease 5110 non-null int64
5 ever_married 5110 non-null object
6 work_type 5110 non-null object
7 Residence_type 5110 non-null object
8 avg_glucose_level 5110 non-null float64
9 bmi 4909 non-null float64
10 smoking_status 5110 non-null object
11 stroke 5110 non-null int64
dtypes: float64(3), int64(4), object(5)
memory usage: 479.2+ KB
data.describe()
data.isnull().sum()
id 0
gender 0
age 0
hypertension 0
heart_disease 0
ever_married 0
work_type 0
Residence_type 0
avg_glucose_level 0
bmi 201
smoking_status 0
stroke 0
dtype: int64
plt.figure(figsize=(8,5))
data['bmi'].plot(kind='kde')
plt.show()
data.isnull().sum()
id 0
gender 0
age 0
hypertension 0
heart_disease 0
ever_married 0
work_type 0
Residence_type 0
avg_glucose_level 0
bmi 0
smoking_status 0
stroke 0
dtype: int64
data.head()
gender age hypertension heart_disease ever_married work_type Residence_type avg_glucose_level bmi smoking_status stroke
Self-
1 Female 61.0 0 0 Yes Rural 202.21 28.893237 never smoked 1
employed
Self-
4 Female 79.0 1 0 Yes Rural 174.12 24.000000 never smoked 1
employed
EDA
for i in num.columns:
sns.boxplot(data=num,x=i)
plt.show()
Gender
data['gender'].value_counts()
Female 2994
Male 2115
Other 1
Name: gender, dtype: int64
sns.countplot(data=data,x='gender')
plt.show()
sns.countplot(data=data,x='gender',hue='stroke')
plt.show()
data['stroke'].value_counts().plot(kind='pie',autopct='%0.2f%%')
plt.show()
Age
# More men than women had strokes
data.groupby('gender').mean()[['age', 'stroke']]
age stroke
gender
Ever married
data['ever_married'].value_counts()
Yes 3353
No 1757
Name: ever_married, dtype: int64
sns.countplot(data=data,x='ever_married',hue='stroke')
plt.show()
Work Type
data['work_type'].unique()
data['work_type'].unique()
data['work_type'].value_counts()
Private 2925
Self-employed 819
children 687
Govt_job 657
Never_worked 22
Name: work_type, dtype: int64
sns.countplot(data=data,x='work_type',hue='stroke')
plt.show()
Residence Type
data['Residence_type'].unique()
data['Residence_type'].value_counts()
Urban 2596
Rural 2514
Name: Residence_type, dtype: int64
sns.countplot(data=data,x='Residence_type',hue='stroke')
plt.show()
Smoking Features
data['smoking_status'].value_counts()
sns.countplot(data=data,x='smoking_status',hue='stroke')
plt.show()
Heatmap
sns.heatmap(data.corr(),annot=True,fmt='.2f')
plt.show()
gender object
age float64
hypertension int64
heart_disease int64
ever_married object
work_type object
Residence_type object
avg_glucose_level float64
bmi float64
smoking_status object
stroke int64
dtype: object
data['gender'] = lr.fit_transform(data['gender'])
data['ever_married'] = lr.fit_transform(data['ever_married'])
data['work_type'] = lr.fit_transform(data['work_type'])
data['Residence_type'] = lr.fit_transform(data['Residence_type'])
data['smoking_status'] = lr.fit_transform(data['smoking_status'])
Y=data['stroke'].values
Y
# splitting
Logistic Regression
from sklearn.linear_model import LogisticRegression
classifier = LogisticRegression()
classifier.fit(X_train, Y_train)
LogisticRegression()
predict = classifier.predict(X_test)
predict
Y_test
print(confusion_matrix(Y_test, predict))
[[968 0]
[ 54 0]]
print(classification_report(Y_test, predict))
KNN Classifier
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier()
knn.fit(X_train, Y_train)
KNeighborsClassifier()
pred = knn.predict(X_test)
pred
Y_test
Accuracy: 0.9422700587084148
classifier = DecisionTreeClassifier(max_depth=3)
classifier.fit(X_train, Y_train)
DecisionTreeClassifier(max_depth=3)
Y_pred = classifier.predict(X_test)
Y_pred
Y_test
Accuracy: 0.9461839530332681
fig = plt.figure(figsize=(15,10))
tree.plot_tree(classifier,filled=True,class_names=True,node_ids=True)
plt.show()
Random Forest Classifier
from sklearn.ensemble import RandomForestClassifier
classifier = RandomForestClassifier()
classifier.fit(X_train,Y_train)
RandomForestClassifier()
Y_pred1 = classifier.predict(X_test)
Y_pred1
Y_test
Accuracy: 0.9461839530332681
Loading [MathJax]/jax/output/CommonHTML/fonts/TeX/fontdata.js