CMSIS-DSP: Correcting a bug in matrix inversion

When pivot is 0, the row permutation code was not correct and failing on
some matrixes (but not all matrixes).
pull/19/head
Christophe Favergeon 5 years ago
parent ac7da660b7
commit 1019e4c4a8

@ -3,6 +3,7 @@
## How to use
This document is explaining how to use cmake with CMSIS-DSP.
(It is not official so not supported. The official way to build is to use the CMSIS-Pack).
The example arm_variance_f32 in folder Examples/ARM/arm_variance_f32 has been modified to also
support cmake and is used as an example in this document.

@ -67,7 +67,7 @@ arm_status arm_mat_inverse_f16(
float16_t *pTmpA, *pTmpB;
_Float16 in = 0.0f16; /* Temporary input values */
uint32_t i, rowCnt, flag = 0U, j, loopCnt, k, l; /* loop counters */
uint32_t i, rowCnt, flag = 0U, j, loopCnt, l; /* loop counters */
arm_status status; /* status of matrix inverse */
uint32_t blkCnt;
@ -191,10 +191,7 @@ arm_status arm_mat_inverse_f16(
* Temporary variable to hold the pivot value
*/
in = *pInT1;
/*
* Destination pointer modifier
*/
k = 1U;
/*
* Check if the pivot element is zero
@ -210,7 +207,7 @@ arm_status arm_mat_inverse_f16(
* Update the input and destination pointers
*/
pInT2 = pInT1 + (numCols * i);
pOutT2 = pOutT1 + (numCols * k);
pOutT2 = pOutT1 + (numCols * i);
/*
* Check if there is a non zero pivot element to
* * replace in the rows below
@ -300,10 +297,7 @@ arm_status arm_mat_inverse_f16(
*/
break;
}
/*
* Update the destination pointer modifier
*/
k++;
}
}
@ -569,7 +563,7 @@ arm_status arm_mat_inverse_f16(
uint32_t numCols = pSrc->numCols; /* Number of Cols in the matrix */
_Float16 Xchg, in = 0.0f16, in1; /* Temporary input values */
uint32_t i, rowCnt, flag = 0U, j, loopCnt, k, l; /* loop counters */
uint32_t i, rowCnt, flag = 0U, j, loopCnt, k,l; /* loop counters */
arm_status status; /* status of matrix inverse */
#ifdef ARM_MATH_MATRIX_CHECK
@ -681,9 +675,6 @@ arm_status arm_mat_inverse_f16(
in = *pInT1;
/* Destination pointer modifier */
k = 1U;
/* Check if the pivot element is zero */
if (*pInT1 == 0.0f16)
{
@ -693,7 +684,7 @@ arm_status arm_mat_inverse_f16(
{
/* Update the input and destination pointers */
pInT2 = pInT1 + (numCols * i);
pOutT2 = pOutT1 + (numCols * k);
pOutT2 = pOutT1 + (numCols * i);
/* Check if there is a non zero pivot element to
* replace in the rows below */
@ -735,10 +726,6 @@ arm_status arm_mat_inverse_f16(
break;
}
/* Update the destination pointer modifier */
k++;
/* Decrement loop counter */
}
}

@ -28,6 +28,7 @@
#include "dsp/matrix_functions.h"
/**
@ingroup groupMatrix
*/
@ -84,7 +85,7 @@ arm_status arm_mat_inverse_f32(
float32_t *pTmpA, *pTmpB;
float32_t in = 0.0f; /* Temporary input values */
uint32_t i, rowCnt, flag = 0U, j, loopCnt, k, l; /* loop counters */
uint32_t i, rowCnt, flag = 0U, j, loopCnt, l; /* loop counters */
arm_status status; /* status of matrix inverse */
uint32_t blkCnt;
@ -208,10 +209,7 @@ arm_status arm_mat_inverse_f32(
* Temporary variable to hold the pivot value
*/
in = *pInT1;
/*
* Destination pointer modifier
*/
k = 1U;
/*
* Check if the pivot element is zero
@ -227,7 +225,7 @@ arm_status arm_mat_inverse_f32(
* Update the input and destination pointers
*/
pInT2 = pInT1 + (numCols * i);
pOutT2 = pOutT1 + (numCols * k);
pOutT2 = pOutT1 + (numCols * i);
/*
* Check if there is a non zero pivot element to
* * replace in the rows below
@ -317,10 +315,7 @@ arm_status arm_mat_inverse_f32(
*/
break;
}
/*
* Update the destination pointer modifier
*/
k++;
}
}
@ -699,10 +694,6 @@ arm_status arm_mat_inverse_f32(
/* Temporary variable to hold the pivot value */
in = *pInT1;
/* Destination pointer modifier */
k = 1U;
/* Check if the pivot element is zero */
if (*pInT1 == 0.0f)
{
@ -711,7 +702,7 @@ arm_status arm_mat_inverse_f32(
{
/* Update the input and destination pointers */
pInT2 = pInT1 + (numCols * i);
pOutT2 = pOutT1 + (numCols * k);
pOutT2 = pOutT1 + (numCols * i);
/* Check if there is a non zero pivot element to
* replace in the rows below */
@ -753,8 +744,7 @@ arm_status arm_mat_inverse_f32(
break;
}
/* Update the destination pointer modifier */
k++;
}
}
@ -997,7 +987,7 @@ arm_status arm_mat_inverse_f32(
#if defined (ARM_MATH_DSP)
float32_t Xchg, in = 0.0f, in1; /* Temporary input values */
uint32_t i, rowCnt, flag = 0U, j, loopCnt, k, l; /* loop counters */
uint32_t i, rowCnt, flag = 0U, j, loopCnt, k,l; /* loop counters */
arm_status status; /* status of matrix inverse */
#ifdef ARM_MATH_MATRIX_CHECK
@ -1109,8 +1099,6 @@ arm_status arm_mat_inverse_f32(
in = *pInT1;
/* Destination pointer modifier */
k = 1U;
/* Check if the pivot element is zero */
if (*pInT1 == 0.0f)
@ -1121,7 +1109,7 @@ arm_status arm_mat_inverse_f32(
{
/* Update the input and destination pointers */
pInT2 = pInT1 + (numCols * i);
pOutT2 = pOutT1 + (numCols * k);
pOutT2 = pOutT1 + (numCols * i);
/* Check if there is a non zero pivot element to
* replace in the rows below */
@ -1163,8 +1151,6 @@ arm_status arm_mat_inverse_f32(
break;
}
/* Update the destination pointer modifier */
k++;
/* Decrement loop counter */
}
@ -1306,7 +1292,7 @@ arm_status arm_mat_inverse_f32(
#else
float32_t Xchg, in = 0.0f; /* Temporary input values */
uint32_t i, rowCnt, flag = 0U, j, loopCnt, k, l; /* loop counters */
uint32_t i, rowCnt, flag = 0U, j, loopCnt, l; /* loop counters */
arm_status status; /* status of matrix inverse */
#ifdef ARM_MATH_MATRIX_CHECK
@ -1417,9 +1403,6 @@ arm_status arm_mat_inverse_f32(
/* Temporary variable to hold the pivot value */
in = *pInT1;
/* Destination pointer modifier */
k = 1U;
/* Check if the pivot element is zero */
if (*pInT1 == 0.0f)
{
@ -1428,7 +1411,7 @@ arm_status arm_mat_inverse_f32(
{
/* Update the input and destination pointers */
pInT2 = pInT1 + (numCols * i);
pOutT2 = pOutT1 + (numCols * k);
pOutT2 = pOutT1 + (numCols * i);
/* Check if there is a non zero pivot element to
* replace in the rows below */
@ -1457,12 +1440,10 @@ arm_status arm_mat_inverse_f32(
/* Break after exchange is done */
break;
}
/* Update the destination pointer modifier */
k++;
}
}
/* Update the status if the matrix is singular */
if ((flag != 1U) && (in == 0.0f))
{

@ -63,7 +63,7 @@ arm_status arm_mat_inverse_f64(
#if defined (ARM_MATH_DSP)
float64_t Xchg, in = 0.0, in1; /* Temporary input values */
uint32_t i, rowCnt, flag = 0U, j, loopCnt, k, l; /* loop counters */
uint32_t i, rowCnt, flag = 0U, j, loopCnt,k, l; /* loop counters */
arm_status status; /* status of matrix inverse */
#ifdef ARM_MATH_MATRIX_CHECK
@ -174,9 +174,6 @@ arm_status arm_mat_inverse_f64(
/* Temporary variable to hold the pivot value */
in = *pInT1;
/* Destination pointer modifier */
k = 1U;
/* Check if the pivot element is zero */
if (*pInT1 == 0.0)
{
@ -185,7 +182,7 @@ arm_status arm_mat_inverse_f64(
{
/* Update the input and destination pointers */
pInT2 = pInT1 + (numCols * i);
pOutT2 = pOutT1 + (numCols * k);
pOutT2 = pOutT1 + (numCols * i);
/* Check if there is a non zero pivot element to
* replace in the rows below */
@ -227,11 +224,6 @@ arm_status arm_mat_inverse_f64(
break;
}
/* Update the destination pointer modifier */
k++;
/* Decrement loop counter */
i--;
}
}
@ -371,7 +363,7 @@ arm_status arm_mat_inverse_f64(
#else
float64_t Xchg, in = 0.0; /* Temporary input values */
uint32_t i, rowCnt, flag = 0U, j, loopCnt, k, l; /* loop counters */
uint32_t i, rowCnt, flag = 0U, j, loopCnt, l; /* loop counters */
arm_status status; /* status of matrix inverse */
#ifdef ARM_MATH_MATRIX_CHECK
@ -482,8 +474,7 @@ arm_status arm_mat_inverse_f64(
/* Temporary variable to hold the pivot value */
in = *pInT1;
/* Destination pointer modifier */
k = 1U;
/* Check if the pivot element is zero */
if (*pInT1 == 0.0)
@ -493,7 +484,7 @@ arm_status arm_mat_inverse_f64(
{
/* Update the input and destination pointers */
pInT2 = pInT1 + (numCols * i);
pOutT2 = pOutT1 + (numCols * k);
pOutT2 = pOutT1 + (numCols * i);
/* Check if there is a non zero pivot element to
* replace in the rows below */
@ -523,8 +514,7 @@ arm_status arm_mat_inverse_f64(
break;
}
/* Update the destination pointer modifier */
k++;
}
}

@ -172,9 +172,8 @@ def getInvertibleMatrix(d):
m=[[0.804738, -0.310617, 0.505879], [0.505879,
0.804738, -0.310617], [-0.310617, 0.505879, 0.804738]]
if d == 4:
m = [[0.82826, 0.337671, 0.564395, 0.576988], [0.403359, 0.369414,
0.597588, 0.436561], [0.783442, 0.3334, 0.525436,
0.0858155], [0.329328, 0.397682, 0.12816, 0.775337]]
m = [[1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 5.0, 6.0],
[3.0, 5.0, 9.0, 10.0], [4.0, 6.0, 10.0, 16.0]]
if d == 7:
m = [[0.978575, 0.330011, 0.951751, 0.304936, 0.924631, 0.502005,
0.235223], [0.185314, 0.46862, 0.955398, 0.970953, 0.637389,
@ -680,6 +679,7 @@ def getSemidefinitePositiveMatrix(d,k=3):
return(np.matmul(p,np.matmul(a,np.transpose(p))))
def writeUnaryTests(config,format):
config.setOverwrite(False)
# For benchmarks
NBSAMPLES=NBA*NBB
NBVECSAMPLES = NBB
@ -783,8 +783,12 @@ def writeUnaryTests(config,format):
else:
dims=[1,2,3,4,7,8,9,15,16,17,32,33]
#
if format == Tools.F32 or format == Tools.F16 or format == Tools.F64:
config.setOverwrite(True)
vals = []
inp=[]
for d in dims:
ma = getInvertibleMatrix(d)
inp = inp + list(ma.reshape(d*d))
@ -801,6 +805,8 @@ def writeUnaryTests(config,format):
config.writeInputS16(1, dims,"DimsInvert")
config.writeInput(1, inp,"InputInvert")
config.writeReference(1, vals,"RefInvert")
config.setOverwrite(False)
# One kind of matrix shape
# Cholesky and LDLT definite positive (DPO)
@ -854,7 +860,6 @@ def writeUnaryTests(config,format):
config.writeInputS16(1, dims,"DimsCholeskyDPO")
config.writeInput(1, inp,"InputCholeskyDPO")
config.writeReference(1, vals,"RefCholeskyDPO")
@ -916,6 +921,13 @@ def generatePatterns():
configBinaryq15=Tools.Config(PATTERNBINDIR,PARAMBINDIR,"q15")
configBinaryq7=Tools.Config(PATTERNBINDIR,PARAMBINDIR,"q7")
configBinaryf64.setOverwrite(False)
configBinaryf32.setOverwrite(False)
configBinaryf16.setOverwrite(False)
configBinaryq31.setOverwrite(False)
configBinaryq15.setOverwrite(False)
configBinaryq7.setOverwrite(False)
writeBinaryTests(configBinaryf64,Tools.F32)
writeBinaryTests(configBinaryf32,Tools.F32)
@ -934,6 +946,12 @@ def generatePatterns():
configUnaryq15=Tools.Config(PATTERNUNDIR,PARAMUNDIR,"q15")
configUnaryq7=Tools.Config(PATTERNUNDIR,PARAMUNDIR,"q7")
configUnaryf64.setOverwrite(False)
configUnaryf32.setOverwrite(False)
configUnaryf16.setOverwrite(False)
configUnaryq31.setOverwrite(False)
configUnaryq15.setOverwrite(False)
configUnaryq7.setOverwrite(False)
writeUnaryTests(configUnaryf64,Tools.F64)
writeUnaryTests(configUnaryf32,Tools.F32)

@ -28,38 +28,38 @@ H
0x380c
// 0.804738
0x3a70
// 0.828260
0x3aa0
// 0.337671
0x3567
// 0.564395
0x3884
// 0.576988
0x389e
// 0.403359
0x3674
// 0.369414
0x35e9
// 0.597588
0x38c8
// 0.436561
0x36fc
// 0.783442
0x3a44
// 0.333400
0x3556
// 0.525436
0x3834
// 0.085816
0x2d7e
// 0.329328
0x3545
// 0.397682
0x365d
// 0.128160
0x301a
// 0.775337
0x3a34
// 1.000000
0x3c00
// 2.000000
0x4000
// 3.000000
0x4200
// 4.000000
0x4400
// 2.000000
0x4000
// 4.000000
0x4400
// 5.000000
0x4500
// 6.000000
0x4600
// 3.000000
0x4200
// 5.000000
0x4500
// 9.000000
0x4880
// 10.000000
0x4900
// 4.000000
0x4400
// 6.000000
0x4600
// 10.000000
0x4900
// 16.000000
0x4c00
// 0.978575
0x3bd4
// 0.330011

@ -28,38 +28,38 @@ H
0xb4f8
// 0.804738
0x3a70
// 1.413088
0x3da7
// -2.081372
0xc02a
// 0.842716
0x3abe
// 0.027076
0x26ee
// -5.114009
0xc51d
// 0.949805
0x3b99
// 3.715447
0x436e
// 2.859700
0x41b8
// 0.830018
0x3aa4
// 2.503484
0x4102
// -1.378369
0xbd83
// -1.874732
0xbf80
// 1.885638
0x3f8b
// -0.016912
0xa454
// -2.035818
0xc012
// 0.121363
0x2fc4
// -6.500000
0xc680
// 1.500000
0x3e00
// 0.500000
0x3800
// 0.750000
0x3a00
// 1.500000
0x3e00
// 0.500000
0x3800
// -0.500000
0xb800
// -0.250000
0xb400
// 0.500000
0x3800
// -0.500000
0xb800
// 0.500000
0x3800
// -0.250000
0xb400
// 0.750000
0x3a00
// -0.250000
0xb400
// -0.250000
0xb400
// 0.125000
0x3000
// 1.305635
0x3d39
// -2.573542

@ -28,38 +28,38 @@ W
0x3f018149
// 0.804738
0x3f4e034f
// 0.828260
0x3f5408d9
// 0.337671
0x3eace337
// 0.564395
0x3f107c31
// 0.576988
0x3f13b57c
// 0.403359
0x3ece8512
// 0.369414
0x3ebd23d5
// 0.597588
0x3f18fb87
// 0.436561
0x3edf84ec
// 0.783442
0x3f488fa8
// 0.333400
0x3eaab368
// 0.525436
0x3f0682f9
// 0.085816
0x3dafc009
// 0.329328
0x3ea89dae
// 0.397682
0x3ecb9cfa
// 0.128160
0x3e033c60
// 0.775337
0x3f467c7c
// 1.000000
0x3f800000
// 2.000000
0x40000000
// 3.000000
0x40400000
// 4.000000
0x40800000
// 2.000000
0x40000000
// 4.000000
0x40800000
// 5.000000
0x40a00000
// 6.000000
0x40c00000
// 3.000000
0x40400000
// 5.000000
0x40a00000
// 9.000000
0x41100000
// 10.000000
0x41200000
// 4.000000
0x40800000
// 6.000000
0x40c00000
// 10.000000
0x41200000
// 16.000000
0x41800000
// 0.978575
0x3f7a83e4
// 0.330011

@ -28,38 +28,38 @@ W
0xbe9f093a
// 0.804738
0x3f4e0352
// 1.413088
0x3fb4e00e
// -2.081372
0xc0053535
// 0.842716
0x3f57bc42
// 0.027076
0x3cddcf41
// -5.114009
0xc0a3a5f6
// 0.949805
0x3f732671
// 3.715447
0x406dc9e3
// 2.859700
0x40370551
// 0.830018
0x3f547c13
// 2.503484
0x40203916
// -1.378369
0xbfb06e66
// -1.874732
0xbfeff734
// 1.885638
0x3ff15c95
// -0.016912
0xbc8a8bec
// -2.035818
0xc0024ad6
// 0.121363
0x3df88d64
// -6.500000
0xc0d00000
// 1.500000
0x3fc00000
// 0.500000
0x3f000000
// 0.750000
0x3f400000
// 1.500000
0x3fc00000
// 0.500000
0x3f000000
// -0.500000
0xbf000000
// -0.250000
0xbe800000
// 0.500000
0x3f000000
// -0.500000
0xbf000000
// 0.500000
0x3f000000
// -0.250000
0xbe800000
// 0.750000
0x3f400000
// -0.250000
0xbe800000
// -0.250000
0xbe800000
// 0.125000
0x3e000000
// 1.305635
0x3fa71f0a
// -2.573542

@ -28,38 +28,38 @@ D
0x3fe030292817763e
// 0.804738
0x3fe9c069e7fb267c
// 0.828260
0x3fea811b1d92b7fe
// 0.337671
0x3fd59c66d373affb
// 0.564395
0x3fe20f861a60d456
// 0.576988
0x3fe276af89c5e6ff
// 0.403359
0x3fd9d0a244630660
// 0.369414
0x3fd7a47a9e2bcf92
// 0.597588
0x3fe31f70de8f6cf0
// 0.436561
0x3fdbf09d8c6d612c
// 0.783442
0x3fe911f4f50a02b8
// 0.333400
0x3fd5566cf41f212d
// 0.525436
0x3fe0d05f28848388
// 0.085816
0x3fb5f8012dfd694d
// 0.329328
0x3fd513b5bf6a0dbb
// 0.397682
0x3fd9739f340d4dc6
// 0.128160
0x3fc0678c0053e2d6
// 0.775337
0x3fe8cf8f8a4c1ebd
// 1.000000
0x3ff0000000000000
// 2.000000
0x4000000000000000
// 3.000000
0x4008000000000000
// 4.000000
0x4010000000000000
// 2.000000
0x4000000000000000
// 4.000000
0x4010000000000000
// 5.000000
0x4014000000000000
// 6.000000
0x4018000000000000
// 3.000000
0x4008000000000000
// 5.000000
0x4014000000000000
// 9.000000
0x4022000000000000
// 10.000000
0x4024000000000000
// 4.000000
0x4010000000000000
// 6.000000
0x4018000000000000
// 10.000000
0x4024000000000000
// 16.000000
0x4030000000000000
// 0.978575
0x3fef507c84b5dcc6
// 0.330011

@ -28,38 +28,38 @@ D
0xbfd3e1273621427c
// 0.804738
0x3fe9c06a4dbb030e
// 1.413088
0x3ff69c01bcda645e
// -2.081372
0xc000a6a6921dae82
// 0.842716
0x3feaf788323d9e48
// 0.027076
0x3f9bb9e81fc8d11d
// -5.114009
0xc01474bec0fff89f
// 0.949805
0x3fee64ce235cf93e
// 3.715447
0x400db93c5f7d8b66
// 2.859700
0x4006e0aa2629c79f
// 0.830018
0x3fea8f8251c1a7fe
// 2.503484
0x40040722ce21a4d9
// -1.378369
0xbff60dccba5f5183
// -1.874732
0xbffdfee6844680da
// 1.885638
0x3ffe2b92ac107f21
// -0.016912
0xbf91517d8bf08170
// -2.035818
0xc000495ac61bce71
// 0.121363
0x3fbf11ac8d7c1e33
// -6.500000
0xc019fffffffffffe
// 1.500000
0x3ff8000000000000
// 0.500000
0x3fdffffffffffff8
// 0.750000
0x3fe8000000000002
// 1.500000
0x3ff7ffffffffffff
// 0.500000
0x3fe0000000000000
// -0.500000
0xbfdffffffffffffe
// -0.250000
0xbfd0000000000001
// 0.500000
0x3fdffffffffffffe
// -0.500000
0xbfe0000000000000
// 0.500000
0x3fe0000000000000
// -0.250000
0xbfd0000000000000
// 0.750000
0x3fe7ffffffffffff
// -0.250000
0xbfd0000000000000
// -0.250000
0xbfcffffffffffffe
// 0.125000
0x3fbffffffffffffe
// 1.305635
0x3ff4e3e14dd454c7
// -2.573542

Loading…
Cancel
Save