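# NOTE(review): the lines below appear to be repository web-page residue,
# not part of the script; kept as comments so the file parses.
# You cannot select more than 25 topics
# Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# 74 lines
# 2.1 KiB
# Python
# 74 lines
# 2.1 KiB
# Python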
import math
import random

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.font_manager import FontProperties
from pylab import scatter, figure, clf, plot, xlabel, ylabel, xlim, ylim, title, grid, axes, show, semilogx, semilogy
from sklearn.naive_bayes import GaussianNB

# Generation of data to train the classifier.
# NBVECS vectors are generated per cluster (3 * NBVECS in total). Vectors have
# dimension VECDIM = 2, so they can be represented as points in the plane.
NBVECS = 100
VECDIM = 2

# 3 clusters of points are generated: standard normal noise scaled by
# ballRadius and shifted by the cluster centre (the 2-element centre is
# broadcast over every row of the (NBVECS, VECDIM) noise array).
ballRadius = 1.0
x1 = [1.5, 1] + ballRadius * np.random.randn(NBVECS, VECDIM)
x2 = [-1.5, 1] + ballRadius * np.random.randn(NBVECS, VECDIM)
x3 = [0, -3] + ballRadius * np.random.randn(NBVECS, VECDIM)

# All points are concatenated, cluster after cluster.
X_train = np.concatenate((x1, x2, x3))

# The classes are 0, 1 and 2, laid out in the same order as the clusters
# in X_train (x1 -> 0, x2 -> 1, x3 -> 2).
Y_train = np.concatenate((np.zeros(NBVECS), np.ones(NBVECS), 2 * np.ones(NBVECS)))

# Fit a Gaussian naive Bayes classifier on the labelled training points.
gnb = GaussianNB()
gnb.fit(X_train, Y_train)

# Sanity check: classify the three points that were used as cluster centres
# when generating x1, x2 and x3, and print each predicted class.
# predict() is given a list containing one sample, hence the extra brackets.
print("Testing")
for center in ([1.5, 1.0], [-1.5, 1.0], [0, -3.0]):
    y_pred = gnb.predict([center])
    print(y_pred)

# Dump of data for CMSIS-DSP
# Each 2-D parameter array is flattened (row-major) into a plain list so the
# printed values can be copied as a flat array.
print("Parameters")

# Gaussian averages
print("Theta = ", list(np.ravel(gnb.theta_)))

# Gaussian variances
# scikit-learn renamed `sigma_` to `var_` (sigma_ deprecated in 1.0 and
# removed in 1.2); prefer the new attribute and fall back for old releases.
variances = gnb.var_ if hasattr(gnb, "var_") else gnb.sigma_
print("Sigma = ", list(np.ravel(variances)))

# Class priors
print("Prior = ", list(np.ravel(gnb.class_prior_)))

# Value reported by the fitted model as epsilon_ (added to the variances by
# scikit-learn for numerical stability).
print("Epsilon = ", gnb.epsilon_)

# Some bounds are computed for the graphical representation
# NOTE(review): these four values are not used by the plotting code below.
x_min = X_train[:, 0].min()
x_max = X_train[:, 0].max()
y_min = X_train[:, 1].min()
y_max = X_train[:, 1].max()

# Font used for the cluster labels.
font = FontProperties()
font.set_size(20)

# Draw the three clusters on a figure without axes.
r = plt.figure()
plt.axis('off')

# Put a letter at the point used as centre when generating each cluster.
for label, (cx, cy) in (("A", (1.5, 1.0)), ("B", (-1.5, 1.0)), ("C", (0, -3))):
    plt.text(cx, cy, label, verticalalignment='center',
             horizontalalignment='center', fontproperties=font)

# One small-dot scatter per cluster, each with its own colour.
for points, colour in ((x1, '#FF6B00'), (x2, '#95D600'), (x3, '#00C1DE')):
    plt.scatter(points[:, 0], points[:, 1], s=1.0, color=colour)

# Uncomment the two lines below to write the figure to disk and close it.
#r.savefig('fig.jpeg')
#plt.close(r)
plt.show()