You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
CMSIS-DSP/Examples/ARM/arm_bayes_example/train.py

74 lines
2.1 KiB
Python

from sklearn.naive_bayes import GaussianNB
import random
import numpy as np
import math
from pylab import scatter,figure, clf, plot, xlabel, ylabel, xlim, ylim, title, grid, axes, show,semilogx, semilogy
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
# Training data for the classifier.
# Each class contributes NBVECS samples of dimension VECDIM (= 2), so every
# sample can be drawn as a point in the plane.
NBVECS = 100
VECDIM = 2

# Three clusters of normally distributed points, one per class, each centred
# on a fixed point and scaled by the same spread.
ballRadius = 1.0
x1, x2, x3 = (
    np.add(center, ballRadius * np.random.randn(NBVECS, VECDIM))
    for center in ([1.5, 1], [-1.5, 1], [0, -3])
)

# Stack the clusters into a single (3 * NBVECS, VECDIM) training matrix.
X_train = np.vstack((x1, x2, x3))
# Matching float labels: 0.0 for x1, 1.0 for x2, 2.0 for x3.
Y_train = np.repeat(np.arange(3, dtype=float), NBVECS)
# Fit a Gaussian naive Bayes model on the labelled training points.
gnb = GaussianNB()
gnb.fit(X_train, Y_train)

# Quick check: classify the three points used as cluster centres for
# x1, x2 and x3 and print each predicted class.
print("Testing")
for query in ([1.5, 1.0], [-1.5, 1.0], [0, -3.0]):
    y_pred = gnb.predict([query])
    print(y_pred)
# Dump the trained model parameters for the CMSIS-DSP example.
# Arrays are flattened row by row (C order) and converted with .tolist() so
# that every element prints as a plain Python float. Printing a list of NumPy
# scalars would instead show "np.float64(...)" under NumPy >= 2.0, which is
# not usable as C initializer text.
print("Parameters")
# Gaussian averages (one row per class), flattened.
print("Theta = ", np.ravel(gnb.theta_).tolist())
# Gaussian variances (one row per class), flattened.
# scikit-learn renamed `sigma_` to `var_` in 1.0 and removed `sigma_` in 1.2;
# prefer `var_` and fall back to `sigma_` for older releases.
variances = gnb.var_ if hasattr(gnb, "var_") else gnb.sigma_
print("Sigma = ", np.ravel(variances).tolist())
# Class priors.
print("Prior = ", np.ravel(gnb.class_prior_).tolist())
# Variance-smoothing term reported by the fitted model.
print("Epsilon = ", gnb.epsilon_)
# Extent of the training data along each axis (currently not used by the
# plot below).
x_min, x_max = X_train[:, 0].min(), X_train[:, 0].max()
y_min, y_max = X_train[:, 1].min(), X_train[:, 1].max()

# Large font for the cluster labels.
font = FontProperties()
font.set_size(20)

r = plt.figure()
plt.axis('off')

# Put a letter on each cluster centre.
for (cx, cy), label in (((1.5, 1.0), "A"), ((-1.5, 1.0), "B"), ((0, -3), "C")):
    plt.text(cx, cy, label,
             verticalalignment='center',
             horizontalalignment='center',
             fontproperties=font)

# Draw each cluster as small dots in its own colour.
for pts, colour in ((x1, '#FF6B00'), (x2, '#95D600'), (x3, '#00C1DE')):
    scatter(pts[:, 0], pts[:, 1], s=1.0, color=colour)

# To save the figure instead of only displaying it:
#r.savefig('fig.jpeg')
#plt.close(r)
show()