CMSIS-DSP: Improved tests on matrix inversions

And correction of an internal pointer bug in pivot code.
pull/19/head
Christophe Favergeon 5 years ago
parent 1019e4c4a8
commit 28746aeadb

@ -201,7 +201,7 @@ arm_status arm_mat_inverse_f16(
/* /*
* Loop over the number rows present below * Loop over the number rows present below
*/ */
for (i = (l + 1U); i < numRows; i++) for (i = 1U; i < numRows-l; i++)
{ {
/* /*
* Update the input and destination pointers * Update the input and destination pointers
@ -680,7 +680,7 @@ arm_status arm_mat_inverse_f16(
{ {
/* Loop over the number rows present below */ /* Loop over the number rows present below */
for (i = (l + 1U); i < numRows; i++) for (i = 1U; i < numRows-l; i++)
{ {
/* Update the input and destination pointers */ /* Update the input and destination pointers */
pInT2 = pInT1 + (numCols * i); pInT2 = pInT1 + (numCols * i);

@ -219,7 +219,7 @@ arm_status arm_mat_inverse_f32(
/* /*
* Loop over the number rows present below * Loop over the number rows present below
*/ */
for (i = (l + 1U); i < numRows; i++) for (i = 1U; i < numRows-l; i++)
{ {
/* /*
* Update the input and destination pointers * Update the input and destination pointers
@ -698,7 +698,7 @@ arm_status arm_mat_inverse_f32(
if (*pInT1 == 0.0f) if (*pInT1 == 0.0f)
{ {
/* Loop over the number rows present below */ /* Loop over the number rows present below */
for (i = (l + 1U); i < numRows; i++) for (i = 1U; i < numRows - l; i++)
{ {
/* Update the input and destination pointers */ /* Update the input and destination pointers */
pInT2 = pInT1 + (numCols * i); pInT2 = pInT1 + (numCols * i);
@ -1105,7 +1105,7 @@ arm_status arm_mat_inverse_f32(
{ {
/* Loop over the number rows present below */ /* Loop over the number rows present below */
for (i = (l + 1U); i < numRows; i++) for (i = 1U; i < numRows - l; i++)
{ {
/* Update the input and destination pointers */ /* Update the input and destination pointers */
pInT2 = pInT1 + (numCols * i); pInT2 = pInT1 + (numCols * i);
@ -1407,7 +1407,7 @@ arm_status arm_mat_inverse_f32(
if (*pInT1 == 0.0f) if (*pInT1 == 0.0f)
{ {
/* Loop over the number rows present below */ /* Loop over the number rows present below */
for (i = (l + 1U); i < numRows; i++) for (i = 1U; i < numRows-l; i++)
{ {
/* Update the input and destination pointers */ /* Update the input and destination pointers */
pInT2 = pInT1 + (numCols * i); pInT2 = pInT1 + (numCols * i);

@ -174,11 +174,14 @@ arm_status arm_mat_inverse_f64(
/* Temporary variable to hold the pivot value */ /* Temporary variable to hold the pivot value */
in = *pInT1; in = *pInT1;
/* Check if the pivot element is zero */ /* Check if the pivot element is zero */
if (*pInT1 == 0.0) if (*pInT1 == 0.0)
{ {
/* Loop over the number rows present below */ /* Loop over the number rows present below */
for (i = (l + 1U); i < numRows; i++)
for (i = 1U; i < numRows - l; i++)
{ {
/* Update the input and destination pointers */ /* Update the input and destination pointers */
pInT2 = pInT1 + (numCols * i); pInT2 = pInT1 + (numCols * i);
@ -224,6 +227,8 @@ arm_status arm_mat_inverse_f64(
break; break;
} }
/* Decrement loop counter */
} }
} }
@ -474,13 +479,11 @@ arm_status arm_mat_inverse_f64(
/* Temporary variable to hold the pivot value */ /* Temporary variable to hold the pivot value */
in = *pInT1; in = *pInT1;
/* Check if the pivot element is zero */ /* Check if the pivot element is zero */
if (*pInT1 == 0.0) if (*pInT1 == 0.0)
{ {
/* Loop over the number rows present below */ /* Loop over the number rows present below */
for (i = (l + 1U); i < numRows; i++) for (i = 1U; i < numRows-l; i++)
{ {
/* Update the input and destination pointers */ /* Update the input and destination pointers */
pInT2 = pInT1 + (numCols * i); pInT2 = pInT1 + (numCols * i);
@ -513,11 +516,10 @@ arm_status arm_mat_inverse_f64(
/* Break after exchange is done */ /* Break after exchange is done */
break; break;
} }
} }
} }
/* Update the status if the matrix is singular */ /* Update the status if the matrix is singular */
if ((flag != 1U) && (in == 0.0)) if ((flag != 1U) && (in == 0.0))
{ {

@ -299,6 +299,22 @@ void UnaryTestsF16::test_mat_cmplx_trans_f16()
} }
static void refInnerTail(float16_t *b)
{
b[0] = 1.0f16;
b[1] = -2.0f16;
b[2] = 3.0f16;
b[3] = -4.0f16;
}
static void checkInnerTail(float16_t *b)
{
ASSERT_TRUE(b[0] == 1.0f16);
ASSERT_TRUE(b[1] == -2.0f16);
ASSERT_TRUE(b[2] == 3.0f16);
ASSERT_TRUE(b[3] == -4.0f16);
}
void UnaryTestsF16::test_mat_inverse_f16() void UnaryTestsF16::test_mat_inverse_f16()
{ {
const float16_t *inp1=input1.ptr(); const float16_t *inp1=input1.ptr();
@ -319,15 +335,18 @@ void UnaryTestsF16::test_mat_inverse_f16()
PREPAREDATA1(false); PREPAREDATA1(false);
refInnerTail(outp+(rows * columns));
status=arm_mat_inverse_f16(&this->in1,&this->out); status=arm_mat_inverse_f16(&this->in1,&this->out);
ASSERT_TRUE(status==ARM_MATH_SUCCESS); ASSERT_TRUE(status==ARM_MATH_SUCCESS);
outp += (rows * columns); outp += (rows * columns);
inp1 += (rows * columns); inp1 += (rows * columns);
checkInnerTail(outp);
} }
ASSERT_EMPTY_TAIL(output);
ASSERT_SNR(output,ref,(float16_t)SNR_THRESHOLD_INV); ASSERT_SNR(output,ref,(float16_t)SNR_THRESHOLD_INV);

@ -323,6 +323,24 @@ void UnaryTestsF32::test_mat_cmplx_trans_f32()
} }
static void refInnerTail(float32_t *b)
{
b[0] = 1.0f;
b[1] = -2.0f;
b[2] = 3.0f;
b[3] = -4.0f;
}
static void checkInnerTail(float32_t *b)
{
ASSERT_TRUE(b[0] == 1.0f);
ASSERT_TRUE(b[1] == -2.0f);
ASSERT_TRUE(b[2] == 3.0f);
ASSERT_TRUE(b[3] == -4.0f);
}
void UnaryTestsF32::test_mat_inverse_f32() void UnaryTestsF32::test_mat_inverse_f32()
{ {
const float32_t *inp1=input1.ptr(); const float32_t *inp1=input1.ptr();
@ -343,15 +361,18 @@ void UnaryTestsF32::test_mat_inverse_f32()
PREPAREDATA1(false); PREPAREDATA1(false);
refInnerTail(outp+(rows * columns));
status=arm_mat_inverse_f32(&this->in1,&this->out); status=arm_mat_inverse_f32(&this->in1,&this->out);
ASSERT_TRUE(status==ARM_MATH_SUCCESS); ASSERT_TRUE(status==ARM_MATH_SUCCESS);
outp += (rows * columns); outp += (rows * columns);
inp1 += (rows * columns); inp1 += (rows * columns);
checkInnerTail(outp);
} }
ASSERT_EMPTY_TAIL(output);
ASSERT_SNR(output,ref,(float32_t)SNR_THRESHOLD_INV); ASSERT_SNR(output,ref,(float32_t)SNR_THRESHOLD_INV);

@ -171,6 +171,24 @@ void UnaryTestsF64::test_mat_trans_f64()
} }
/*
Test framework is only adding 16 bytes of free memory after the end of a buffer.
So, we limit to 2 float64 for checking out of buffer write.
*/
static void refInnerTail(float64_t *b)
{
b[0] = 1.0;
b[1] = -2.0;
}
static void checkInnerTail(float64_t *b)
{
ASSERT_TRUE(b[0] == 1.0);
ASSERT_TRUE(b[1] == -2.0);
}
void UnaryTestsF64::test_mat_inverse_f64() void UnaryTestsF64::test_mat_inverse_f64()
{ {
const float64_t *inp1=input1.ptr(); const float64_t *inp1=input1.ptr();
@ -191,15 +209,18 @@ void UnaryTestsF64::test_mat_inverse_f64()
PREPAREDATA1(false); PREPAREDATA1(false);
refInnerTail(outp+(rows * columns));
status=arm_mat_inverse_f64(&this->in1,&this->out); status=arm_mat_inverse_f64(&this->in1,&this->out);
ASSERT_TRUE(status==ARM_MATH_SUCCESS); ASSERT_TRUE(status==ARM_MATH_SUCCESS);
outp += (rows * columns); outp += (rows * columns);
inp1 += (rows * columns); inp1 += (rows * columns);
checkInnerTail(outp);
} }
ASSERT_EMPTY_TAIL(output);
ASSERT_SNR(output,ref,(float64_t)SNR_THRESHOLD); ASSERT_SNR(output,ref,(float64_t)SNR_THRESHOLD);

Loading…
Cancel
Save