Added DTW to the PythonWrapper API

pull/94/head
Christophe Favergeon 3 years ago
parent b46a2f86b5
commit 82559adce2

@ -68,7 +68,7 @@ jobs:
- name: Archive documentation
if: ${{ github.event_name == 'pull_request' }}
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: documentation
path: Documentation/html/
@ -92,7 +92,7 @@ jobs:
- name: Archive pack
if: ${{ github.event_name != 'release' }}
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
path: output/*.pack
retention-days: 1
@ -108,7 +108,7 @@ jobs:
tag: ${{ github.ref }}
overwrite: true
- uses: actions/checkout@v2
- uses: actions/checkout@v3
if: ${{ github.event_name == 'release' || github.event_name == 'push' || github.event_name == 'workflow_dispatch' }}
with:
ref: gh-pages

@ -371,7 +371,7 @@ arm_status arm_dtw_distance_f32(const arm_matrix_instance_f32 *pDistance,
/**
* @brief Mapping between query and template
* @param[in] pDTW Cost matrix (Query rows * Template columns)
* @param[out] pPath Warping path in cost matrix 2*nb rows + nb columns)
* @param[out] pPath Warping path in cost matrix 2*(nb rows + nb columns)
* @param[out] pathLength Length of path in number of points
* @return none
*

@ -30,6 +30,13 @@
#define MODINITNAME cmsisdsp_distance
#include "cmsisdsp_module.h"
MATRIXFROMNUMPY(f32,float32_t,double,NPY_DOUBLE);
CREATEMATRIX(f32,float32_t);
NUMPYARRAYFROMMATRIX(f32,NPY_FLOAT);
MATRIXFROMNUMPY(q7,q7_t,int8_t,NPY_BYTE);
CREATEMATRIX(q7,q7_t);
NUMPYARRAYFROMMATRIX(q7,NPY_BYTE);
NUMPYVECTORFROMBUFFER(f32,float32_t,NPY_FLOAT);
@ -209,6 +216,140 @@ INTDIST(sokalmichener_distance);
INTDIST(sokalsneath_distance);
INTDIST(yule_distance);
static PyObject *
cmsis_arm_dtw_init_window_q7(PyObject *obj,
PyObject *args)
{
PyObject *pSrc=NULL; // input
int32_t winType;
int32_t winSize;
arm_matrix_instance_q7 pSrc_converted; // input
if (PyArg_ParseTuple(args,"iiO",&winType,&winSize,&pSrc));
{
q7MatrixFromNumpy(&pSrc_converted,pSrc);
uint32_t row = pSrc_converted.numCols ;
uint32_t column = pSrc_converted.numRows ;
arm_status returnValue =
arm_dtw_init_window_q7(winType,
winSize,
&pSrc_converted
);
PyObject* theReturnOBJ=Py_BuildValue("i",returnValue);
PyObject* dstOBJ=NumpyArrayFromq7Matrix(&pSrc_converted);
PyObject *pythonResult = Py_BuildValue("OO",theReturnOBJ,dstOBJ);
Py_DECREF(theReturnOBJ);
Py_DECREF(dstOBJ);
return(pythonResult);
}
Py_RETURN_NONE;
}
static PyObject *
cmsis_arm_dtw_distance_f32(PyObject *obj,
PyObject *args)
{
PyObject *pDist=NULL; // input
arm_matrix_instance_f32 pDist_converted; // input
PyObject *pWin=NULL; // input
arm_matrix_instance_q7 pWin_converted; // input
arm_matrix_instance_q7 *pWinMatrix;
arm_matrix_instance_f32 dtw_converted;
if (PyArg_ParseTuple(args,"OO",&pDist,&pWin));
{
f32MatrixFromNumpy(&pDist_converted,pDist);
if (pWin != Py_None)
{
q7MatrixFromNumpy(&pWin_converted,pWin);
pWinMatrix = &pWin_converted;
}
else
{
pWinMatrix = NULL;
}
uint32_t column = pDist_converted.numCols ;
uint32_t row = pDist_converted.numRows ;
createf32Matrix(&dtw_converted,row,column);
float32_t distance;
arm_status returnValue =
arm_dtw_distance_f32(&pDist_converted,
pWinMatrix,
&dtw_converted,
&distance
);
PyObject* theReturnOBJ=Py_BuildValue("i",returnValue);
PyObject* distOBJ=Py_BuildValue("f",distance);
PyObject* dstOBJ=NumpyArrayFromf32Matrix(&dtw_converted);
PyObject *pythonResult = Py_BuildValue("OOO",theReturnOBJ,distOBJ,dstOBJ);
Py_DECREF(theReturnOBJ);
Py_DECREF(distOBJ);
FREEMATRIX(&pDist_converted);
if (pWinMatrix)
{
FREEMATRIX(pWinMatrix);
}
Py_DECREF(dstOBJ);
return(pythonResult);
}
Py_RETURN_NONE;
}
static PyObject *
cmsis_arm_dtw_path_f32(PyObject *obj,
PyObject *args)
{
PyObject *pCost=NULL; // input
arm_matrix_instance_f32 pCost_converted; // input
int16_t *pDst=NULL; // output
if (PyArg_ParseTuple(args,"O",&pCost))
{
f32MatrixFromNumpy(&pCost_converted,pCost);
uint32_t pathLength;
int32_t blockSize;
blockSize=2*(pCost_converted.numRows+pCost_converted.numCols);
pDst=PyMem_Malloc(sizeof(int16_t)*blockSize);
arm_dtw_path_f32(&pCost_converted,
pDst,
&pathLength);
INT16ARRAY1(pDstOBJ,2*pathLength,pDst);
PyObject *pythonResult = Py_BuildValue("O",pDstOBJ);
FREEMATRIX(&pCost_converted);
Py_DECREF(pDstOBJ);
return(pythonResult);
}
}
static PyMethodDef CMSISDSPMethods[] = {
@ -241,6 +382,10 @@ static PyMethodDef CMSISDSPMethods[] = {
{"arm_sokalsneath_distance",cmsis_arm_sokalsneath_distance, METH_VARARGS,""},
{"arm_yule_distance",cmsis_arm_yule_distance, METH_VARARGS,""},
{"arm_dtw_init_window_q7", cmsis_arm_dtw_init_window_q7, METH_VARARGS,""},
{"arm_dtw_distance_f32", cmsis_arm_dtw_distance_f32, METH_VARARGS,""},
{"arm_dtw_path_f32", cmsis_arm_dtw_path_f32, METH_VARARGS,""},
{"error_out", (PyCFunction)error_out, METH_NOARGS, NULL},
{NULL, NULL, 0, NULL} /* Sentinel */
};

@ -2186,6 +2186,7 @@ cmsis_arm_mat_solve_upper_triangular_f64(PyObject *obj, PyObject *args)
Py_RETURN_NONE;
}
static PyMethodDef CMSISDSPMethods[] = {
{"arm_mat_add_f32", cmsis_arm_mat_add_f32, METH_VARARGS,""},
@ -2238,6 +2239,7 @@ static PyMethodDef CMSISDSPMethods[] = {
{"arm_householder_f64", cmsis_arm_householder_f64, METH_VARARGS,""},
{"arm_mat_qr_f32", cmsis_arm_mat_qr_f32, METH_VARARGS,""},
{"arm_mat_qr_f64", cmsis_arm_mat_qr_f64, METH_VARARGS,""},
{"error_out", (PyCFunction)error_out, METH_NOARGS, NULL},
{NULL, NULL, 0, NULL} /* Sentinel */

@ -35,7 +35,6 @@
#endif
#include <Python.h>
#define MAX(A,B) ((A) < (B) ? (B) : (A))
#define CAT1(A,B) A##B
#define CAT(A,B) CAT1(A,B)

@ -0,0 +1,154 @@
# Bug corrections for version 1.9
import cmsisdsp as dsp
import cmsisdsp.fixedpoint as f
import numpy as np
import colorama
from colorama import init,Fore, Back, Style
from numpy.testing import assert_allclose
from numpy.linalg import norm
from dtw import *
import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
init()
def printTitle(s):
print("\n" + Fore.GREEN + Style.BRIGHT + s + Style.RESET_ALL)
def printSubTitle(s):
print("\n" + Style.BRIGHT + s + Style.RESET_ALL)
printTitle("DTW Window")
printSubTitle("SAKOE_CHIBA_WINDOW")
refWin1=np.array([[1, 1, 1, 0, 0],
[1, 1, 1, 1, 0],
[1, 1, 1, 1, 1],
[0, 1, 1, 1, 1],
[0, 0, 1, 1, 1],
[0, 0, 0, 1, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]], dtype=np.int8)
dtwWindow=np.zeros((10,5),dtype=np.int8)
wsize=2
status,w=dsp.arm_dtw_init_window_q7(dsp.ARM_DTW_SAKOE_CHIBA_WINDOW,wsize,dtwWindow)
assert (w==refWin1).all()
printSubTitle("SLANTED_BAND_WINDOW")
refWin2=np.array([[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 1, 0, 0],
[0, 1, 1, 0, 0],
[0, 1, 1, 1, 0],
[0, 0, 1, 1, 0],
[0, 0, 1, 1, 1],
[0, 0, 0, 1, 1],
[0, 0, 0, 1, 1],
[0, 0, 0, 0, 1]], dtype=np.int8)
dtwWindow=np.zeros((10,5),dtype=np.int8)
wsize=1
status,w=dsp.arm_dtw_init_window_q7(dsp.ARM_DTW_SLANTED_BAND_WINDOW,wsize,dtwWindow)
assert (w==refWin2).all()
printTitle("DTW Cost Matrix and DTW Distance")
QUERY_LENGTH = 10
TEMPLATE_LENGTH = 5
query=np.array([ 0.08387197, 0.68082274, 1.06756417, 0.88914541, 0.42513398, -0.3259053,
-0.80934885, -0.90979435, -0.64026483, 0.06923695])
template=np.array([ 1.00000000e+00, 7.96326711e-04, -9.99998732e-01, -2.38897811e-03,
9.99994927e-01])
cols=np.array([1,2,3])
rows=np.array([10,11,12])
printSubTitle("Without a window")
referenceCost=np.array([[0.91612804, 0.9992037 , 2.0830743 , 2.1693354 , 3.0854583 ],
[1.2353053 , 1.6792301 , 3.3600516 , 2.8525472 , 2.8076797 ],
[1.3028694 , 2.3696373 , 4.4372 , 3.9225004 , 2.875249 ],
[1.4137241 , 2.302073 , 4.1912174 , 4.814035 , 2.9860985 ],
[1.98859 , 2.2623994 , 3.6875322 , 4.115055 , 3.5609593 ],
[3.3144953 , 2.589101 , 3.2631946 , 3.586711 , 4.8868594 ],
[5.123844 , 3.3992462 , 2.9704008 , 3.7773607 , 5.5867043 ],
[7.0336385 , 4.309837 , 3.0606053 , 3.9680107 , 5.8778 ],
[8.673903 , 4.950898 , 3.420339 , 4.058215 , 5.698475 ],
[9.604667 , 5.0193386 , 4.489575 , 3.563591 , 4.494349 ]],
dtype=np.float32)
referenceDistance = 0.2996232807636261
# Each row is a new query
a,b = np.meshgrid(template,query)
distance=abs(a-b).astype(np.float32)
status,dtwDistance,dtwMatrix = dsp.arm_dtw_distance_f32(distance,None)
assert_allclose(referenceDistance,dtwDistance)
assert_allclose(referenceCost,dtwMatrix)
printSubTitle("Path")
path=dsp.arm_dtw_path_f32(np.copy(dtwMatrix))
#print(path)
pathMatrix=np.zeros(dtwMatrix.shape)
for x in list(zip(path[0::2],path[1::2])):
pathMatrix[x] = 1
fig, ax = plt.subplots()
im = ax.imshow(pathMatrix,vmax=2.0)
for i in range(QUERY_LENGTH):
for j in range(TEMPLATE_LENGTH):
text = ax.text(j, i, "%.1f" % dtwMatrix[i, j],
ha="center", va="center", color="w")
fig.tight_layout()
plt.show()
printSubTitle("With a window")
referenceDistance = 0.617099940776825
referenceCost=np.array([[9.1612804e-01, 9.9920368e-01, np.NAN, np.NAN,
np.NAN],
[1.2353053e+00, 1.6792301e+00, np.NAN, np.NAN,
np.NAN],
[1.3028694e+00, 2.3696373e+00, 4.4372001e+00, np.NAN,
np.NAN],
[np.NAN, 3.0795674e+00, 4.9687119e+00, np.NAN,
np.NAN],
[np.NAN, 3.5039051e+00, 4.9290380e+00, 5.3565612e+00,
np.NAN],
[np.NAN, np.NAN, 4.8520918e+00, 5.1756082e+00,
np.NAN],
[np.NAN, np.NAN, 5.0427418e+00, 5.8497019e+00,
7.6590457e+00],
[np.NAN, np.NAN, np.NAN, 6.7571073e+00,
8.6668968e+00],
[np.NAN, np.NAN, np.NAN, 7.3949833e+00,
9.0352430e+00],
[np.NAN, np.NAN, np.NAN, np.NAN,
9.2564993e+00]], dtype=np.float32)
status,dtwDistance,dtwMatrix = dsp.arm_dtw_distance_f32(distance,w)
assert_allclose(referenceDistance,dtwDistance)
assert_allclose(referenceCost[w==1],dtwMatrix[w==1])

@ -239,6 +239,12 @@ The wrapper is now containing the compute graph Python scripts and you should re
# Change history
## Version 1.10.0:
* Dynamic Time Warping API
* New asynchronous mode for the compute graph
(see [compute graph documentation](https://github.com/ARM-software/CMSIS-DSP/tree/main/ComputeGraph) for more details.
## Version 1.9.3:
* Corrected real FFTs in the wrapper

@ -24,7 +24,7 @@ cmsis_dsp_version="1.15.0"
# CMSIS-DSP Commit hash used to build the wrapper
commit_hash="4c2501f71b9e021ea1f914df865890f16d539172"
commit_hash=" b46a2f86b5c9d8247ea5417fc0e0022876b80dcf"
# True if development version of CMSIS-DSP used
# (So several CMSIS-DSP versions may have same version number hence the commit hash)
@ -35,3 +35,7 @@ __all__ = ["datatype", "fixedpoint", "mfcc"]
# Default values
DEFAULT_HOUSEHOLDER_THRESHOLD_F64=1.0e-16
DEFAULT_HOUSEHOLDER_THRESHOLD_F32=1.0e-12
# DTW Window Types
ARM_DTW_SAKOE_CHIBA_WINDOW = 1
ARM_DTW_SLANTED_BAND_WINDOW = 3

Loading…
Cancel
Save