CMSIS-DSP: faster Q.15/Q.31 Helium matrix multiplications. Uses an initial transpose stage, which requires extra scratch space to hold the transposed RHS matrix.

pull/19/head
FabKlein 4 years ago committed by Christophe Favergeon
parent c520fb08f4
commit cfc30c12b8

@ -3,8 +3,8 @@
* Title: arm_mat_mult_q15.c * Title: arm_mat_mult_q15.c
* Description: Q15 matrix multiplication * Description: Q15 matrix multiplication
* *
* $Date: 23 April 2021 * $Date: 3 Nov 2021
* $Revision: V1.9.0 * $Revision: V1.10.0
* *
* Target Processor: Cortex-M and Cortex-A cores * Target Processor: Cortex-M and Cortex-A cores
* -------------------------------------------------------------------- */ * -------------------------------------------------------------------- */
@ -315,24 +315,28 @@ __STATIC_INLINE arm_status arm_mat_mult_q15_4x4_mve(
return (ARM_MATH_SUCCESS); return (ARM_MATH_SUCCESS);
} }
arm_status arm_mat_mult_q15( arm_status arm_mat_mult_q15(
const arm_matrix_instance_q15 * pSrcA, const arm_matrix_instance_q15 * pSrcA,
const arm_matrix_instance_q15 * pSrcB, const arm_matrix_instance_q15 * pSrcB,
arm_matrix_instance_q15 * pDst, arm_matrix_instance_q15 * pDst,
q15_t * pState) q15_t * pState)
{ {
q15_t *pInB = pSrcB->pData; /* input data matrix pointer B */
q15_t *pInA = pSrcA->pData; /* input data matrix pointer A */ q15_t *pInA = pSrcA->pData; /* input data matrix pointer A */
q15_t *pOut = pDst->pData; /* output data matrix pointer */ q15_t *pInB = pSrcB->pData; /* input data matrix pointer B */
q15_t *pInA2;
q15_t *pInB2;
q15_t *px; /* Temporary output data matrix pointer */ q15_t *px; /* Temporary output data matrix pointer */
uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */ q15_t *px2; /* Temporary output data matrix pointer */
uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */ uint32_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */ uint32_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
uint16_t col, i = 0U, row = numRowsA; /* loop counters */ uint32_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
uint16x8_t vecOffs, vecColBOffs; uint32_t numRowsB = pSrcB->numRows; /* number of rows of input matrix B */
uint32_t blkCnt,rowCnt; /* loop counters */ uint32_t col, i = 0u, j, row = numRowsB; /* loop counters */
q15_t *pSrcBT = pState; /* input data matrix pointer for transpose */
uint32_t blkCnt; /* loop counters */
arm_status status; /* Status of matrix multiplication */ arm_status status; /* Status of matrix multiplication */
(void)pState; arm_matrix_instance_q15 BT;
#ifdef ARM_MATH_MATRIX_CHECK #ifdef ARM_MATH_MATRIX_CHECK
@ -350,14 +354,12 @@ arm_status arm_mat_mult_q15(
/* small squared matrix specialized routines */ /* small squared matrix specialized routines */
if (numRowsA == numColsB && numColsB == numColsA) { if (numRowsA == numColsB && numColsB == numColsA) {
if (numRowsA == 1) if (numRowsA == 1) {
{
q63_t sum; q63_t sum;
sum = pInA[0] * pInB[0]; sum = pInA[0] * pInB[0];
pOut[0] = (q15_t) __SSAT((sum >> 15), 16); pDst->pData[0] = (q15_t) __SSAT((sum >> 15), 16);
return (ARM_MATH_SUCCESS); return (ARM_MATH_SUCCESS);
} } else if (numRowsA == 2)
else if(numRowsA == 2)
return arm_mat_mult_q15_2x2_mve(pSrcA, pSrcB, pDst); return arm_mat_mult_q15_2x2_mve(pSrcA, pSrcB, pDst);
else if (numRowsA == 3) else if (numRowsA == 3)
return arm_mat_mult_q15_3x3_mve(pSrcA, pSrcB, pDst); return arm_mat_mult_q15_3x3_mve(pSrcA, pSrcB, pDst);
@ -365,229 +367,256 @@ arm_status arm_mat_mult_q15(
return arm_mat_mult_q15_4x4_mve(pSrcA, pSrcB, pDst); return arm_mat_mult_q15_4x4_mve(pSrcA, pSrcB, pDst);
} }
vecColBOffs = vidupq_u16((uint32_t)0, 1); /*
vecColBOffs = vecColBOffs * (uint16_t) (numColsB); * Matrix transpose
*/
BT.numRows = numColsB;
BT.numCols = numRowsB;
BT.pData = pSrcBT;
arm_mat_trans_q15(pSrcB, &BT);
/* /*
* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB * Reset the variables for the usage in the following multiplication process
*/ */
i = 0;
row = numRowsA >> 1;
px = pDst->pData;
px2 = px + numColsB;
/* /*
* row loop * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
*/ */
rowCnt = row >> 2;
while (rowCnt > 0U)
{
/* /*
* Output pointer is set to starting address of the row being processed * row loop
*/ */
px = pOut + i; while (row > 0u) {
i = i + 4 * numColsB;
/* /*
* For every row wise process, the column loop counter is to be initiated * For every row wise process, the column loop counter is to be initiated
*/ */
col = numColsB; col = numColsB >> 1;
/* /*
* For every row wise process, the pInB pointer is set * For every row wise process, the pInB pointer is set
* to the starting address of the pSrcB data * to the starting address of the transposed pSrcB data
*/ */
pInB = pSrcB->pData; pInB = pSrcBT;
pInB2 = pInB + numRowsB;
j = 0;
/* /*
* column loop * column loop
*/ */
while (col > 0U) while (col > 0u) {
{ q15_t const *pSrcAVec, *pSrcBVec, *pSrcA2Vec, *pSrcB2Vec;
/* q15x8_t vecA, vecA2, vecB, vecB2;
* generate 4 columns elements q63_t acc0, acc1, acc2, acc3;
*/
/* /*
* Matrix A columns number of MAC operations are to be performed * Initiate the pointers to the two matrix A rows and the two transposed matrix B rows being processed
*/ */
pInA = pSrcA->pData + i;
pInA2 = pInA + numColsA;
pInB = pSrcBT + j;
pInB2 = pInB + numRowsB;
q15_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
q15_t *pInA0 = pInA; pSrcAVec = (q15_t const *) pInA;
q15_t *pInA1 = pInA0 + numColsA; pSrcA2Vec = (q15_t const *) pInA2;
q15_t *pInA2 = pInA1 + numColsA; pSrcBVec = (q15_t const *) pInB;
q15_t *pInA3 = pInA2 + numColsA; pSrcB2Vec = (q15_t const *) pInB2;
q63_t acc0, acc1, acc2, acc3;
acc0 = 0LL; acc0 = 0LL;
acc1 = 0LL; acc1 = 0LL;
acc2 = 0LL; acc2 = 0LL;
acc3 = 0LL; acc3 = 0LL;
pSrcA0Vec = (q15_t const *) pInA0; vecA = vld1q(pSrcAVec);
pSrcA1Vec = (q15_t const *) pInA1; pSrcAVec += 8;
pSrcA2Vec = (q15_t const *) pInA2;
pSrcA3Vec = (q15_t const *) pInA3;
vecOffs = vecColBOffs;
blkCnt = (numColsA) >> 3; blkCnt = numColsA / 8;
while (blkCnt > 0U) while (blkCnt > 0U) {
{ vecB = vld1q(pSrcBVec);
q15x8_t vecB, vecA; pSrcBVec += 8;
vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs);
vecOffs = vecOffs + (uint16_t) (numColsB * 8);
vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 8;
acc0 = vmlaldavaq(acc0, vecA, vecB); acc0 = vmlaldavaq(acc0, vecA, vecB);
vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 8; vecA2 = vld1q(pSrcA2Vec);
acc1 = vmlaldavaq(acc1, vecA, vecB); pSrcA2Vec += 8;
vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 8; acc1 = vmlaldavaq(acc1, vecA2, vecB);
acc2 = vmlaldavaq(acc2, vecA, vecB); vecB2 = vld1q(pSrcB2Vec);
vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 8; pSrcB2Vec += 8;
acc3 = vmlaldavaq(acc3, vecA, vecB); acc2 = vmlaldavaq(acc2, vecA, vecB2);
blkCnt--; vecA = vld1q(pSrcAVec);
pSrcAVec += 8;
acc3 = vmlaldavaq(acc3, vecA2, vecB2);
blkCnt--;
} }
/* /*
* tail * tail
*/ */
blkCnt = numColsA & 7; blkCnt = numColsA & 7;
if (blkCnt > 0U) if (blkCnt > 0U) {
{
mve_pred16_t p0 = vctp16q(blkCnt); mve_pred16_t p0 = vctp16q(blkCnt);
q15x8_t vecB, vecA; vecB = vld1q(pSrcBVec);
vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs);
vecOffs = vecOffs + (uint16_t) (numColsB * 8);
vecA = vld1q(pSrcA0Vec);
acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0); acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
vecA = vld1q(pSrcA1Vec); vecA2 = vld1q(pSrcA2Vec);
acc1 = vmlaldavaq_p(acc1, vecA, vecB, p0); acc1 = vmlaldavaq_p(acc1, vecA2, vecB, p0);
vecA = vld1q(pSrcA2Vec); vecB2 = vld1q(pSrcB2Vec);
acc2 = vmlaldavaq_p(acc2, vecA, vecB, p0); acc2 = vmlaldavaq_p(acc2, vecA, vecB2, p0);
vecA = vld1q(pSrcA3Vec); vecA = vld1q(pSrcAVec);
acc3 = vmlaldavaq_p(acc3, vecA, vecB, p0); acc3 = vmlaldavaq_p(acc3, vecA2, vecB2, p0);
} }
px[0] = (q15_t)MVE_ASRL_SAT16(acc0, 15); *px++ = (q15_t) MVE_ASRL_SAT16(acc0, 15);
px[1 * numColsB] = (q15_t)MVE_ASRL_SAT16(acc1, 15); *px++ = (q15_t) MVE_ASRL_SAT16(acc2, 15);
px[2 * numColsB] = (q15_t)MVE_ASRL_SAT16(acc2, 15); *px2++ = (q15_t) MVE_ASRL_SAT16(acc1, 15);
px[3 * numColsB] = (q15_t)MVE_ASRL_SAT16(acc3, 15); *px2++ = (q15_t) MVE_ASRL_SAT16(acc3, 15);
px++; j += numRowsB * 2;
/* /*
* Decrement the column loop counter * Decrement the column loop counter
*/ */
col--; col--;
}
i = i + numColsA * 2;
px = px2 + (numColsB & 1u);
px2 = px + numColsB;
/* /*
* Update the pointer pInB to point to the starting address of the next column * Decrement the row loop counter
*/ */
pInB = pSrcB->pData + (numColsB - col); row--;
} }
/* /*
* Update the pointer pInA to point to the starting address of the next row * Compute remaining row and/or column below
*/ */
pInA += (numColsA * 4);
if (numColsB & 1u) {
row = numRowsA & (~0x1); //avoid redundant computation
px = pDst->pData + numColsB - 1;
i = 0;
/* /*
* Decrement the row loop counter * row loop
*/ */
rowCnt --; while (row > 0) {
q15_t const *pSrcAVec, *pSrcBVec;
q15x8_t vecA, vecB;
q63_t acc0;
} /*
* point to last column in matrix B
*/
pInB = pSrcBT + numRowsB * (numColsB - 1);
pInA = pSrcA->pData + i;
rowCnt = row & 3; pSrcAVec = (q15_t const *) pInA;
while (rowCnt > 0U) pSrcBVec = (q15_t const *) pInB;
{
acc0 = 0LL;
blkCnt = (numColsA) / 8;
while (blkCnt > 0U) {
vecA = vld1q(pSrcAVec);
pSrcAVec += 8;
vecB = vld1q(pSrcBVec);
pSrcBVec += 8;
acc0 = vmlaldavaq(acc0, vecA, vecB);
blkCnt--;
}
/* /*
* Output pointer is set to starting address of the row being processed * tail
*/ */
px = pOut + i; blkCnt = (numColsA & 7);
i = i + numColsB; if (blkCnt > 0U) {
mve_pred16_t p0 = vctp16q(blkCnt);
vecA = vld1q(pSrcAVec);
vecB = vld1q(pSrcBVec);
acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
}
*px = (q15_t) MVE_ASRL_SAT16(acc0, 15);
px += numColsB;
i += numColsA;
/* /*
* For every row wise process, the column loop counter is to be initiated * Decrement the row loop counter
*/ */
row--;
}
}
if (numRowsA & 1u) {
col = numColsB; col = numColsB;
i = 0u;
/* /*
* For every row wise process, the pInB pointer is set * point to last row in output matrix
* to the starting address of the pSrcB data
*/ */
pInB = pSrcB->pData; px = pDst->pData + (numColsB) * (numRowsA - 1);
/* /*
* column loop * col loop
*/ */
while (col > 0U) while (col > 0) {
{ q15_t const *pSrcAVec, *pSrcBVec;
q15x8_t vecA, vecB;
q63_t acc0;
/* /*
* generate 4 columns elements * point to last row in matrix A
*/ */
pInA = pSrcA->pData + (numRowsA - 1) * numColsA;
pInB = pSrcBT + i;
/* /*
* Matrix A columns number of MAC operations are to be performed * Set the variable sum, that acts as accumulator, to zero
*/ */
pSrcAVec = (q15_t const *) pInA;
q15_t const *pSrcA0Vec; pSrcBVec = (q15_t const *) pInB;
q15_t *pInA0 = pInA;
q63_t acc0;
acc0 = 0LL; acc0 = 0LL;
pSrcA0Vec = (q15_t const *) pInA0; blkCnt = ((numColsA) / 8);
while (blkCnt > 0U) {
vecOffs = vecColBOffs; vecA = vld1q(pSrcAVec);
pSrcAVec += 8;
blkCnt = (numColsA) >> 3; vecB = vld1q(pSrcBVec);
while (blkCnt > 0U) pSrcBVec += 8;
{
q15x8_t vecB, vecA;
vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs);
vecOffs = vecOffs + (uint16_t) (numColsB * 8);
vecA = vld1q(pSrcA0Vec);
pSrcA0Vec += 8;
acc0 = vmlaldavaq(acc0, vecA, vecB); acc0 = vmlaldavaq(acc0, vecA, vecB);
blkCnt--; blkCnt--;
} }
/* /*
* tail * tail
*/ */
blkCnt = numColsA & 7; blkCnt = (numColsA & 7);
if (blkCnt > 0U) if (blkCnt > 0U) {
{
mve_pred16_t p0 = vctp16q(blkCnt); mve_pred16_t p0 = vctp16q(blkCnt);
q15x8_t vecB, vecA; vecA = vld1q(pSrcAVec);
vecB = vld1q(pSrcBVec);
vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs);
vecOffs = vecOffs + (uint16_t) (numColsB * 8);
vecA = vld1q(pSrcA0Vec);
acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0); acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
} }
px[0] = (q15_t)MVE_ASRL_SAT16(acc0, 15); *px++ = (q15_t) MVE_ASRL_SAT16(acc0, 15);
i += numColsA;
px++;
/* /*
* Decrement the column loop counter * Decrement the col loop counter
*/ */
col--; col--;
/*
* Update the pointer pInB to point to the starting address of the next column
*/
pInB = pSrcB->pData + (numColsB - col);
} }
/*
* Update the pointer pInA to point to the starting address of the next row
*/
pInA += (numColsA );
rowCnt--;
} }
/* Set status as ARM_MATH_SUCCESS */ /* Set status as ARM_MATH_SUCCESS */
status = ARM_MATH_SUCCESS; status = ARM_MATH_SUCCESS;
} }
/* Return to application */ /* Return to application */
return (status); return (status);
} }
#else #else
arm_status arm_mat_mult_q15( arm_status arm_mat_mult_q15(
const arm_matrix_instance_q15 * pSrcA, const arm_matrix_instance_q15 * pSrcA,

@ -3,8 +3,8 @@
* Title: arm_mat_mult_q31.c * Title: arm_mat_mult_q31.c
* Description: Q31 matrix multiplication * Description: Q31 matrix multiplication
* *
* $Date: 23 April 2021 * $Date: 3 Nov 2021
* $Revision: V1.9.0 * $Revision: V1.10.0
* *
* Target Processor: Cortex-M and Cortex-A cores * Target Processor: Cortex-M and Cortex-A cores
* -------------------------------------------------------------------- */ * -------------------------------------------------------------------- */
@ -332,44 +332,45 @@ __STATIC_INLINE arm_status arm_mat_mult_q31_4x4_mve(
return (ARM_MATH_SUCCESS); return (ARM_MATH_SUCCESS);
} }
arm_status arm_mat_mult_q31( arm_status arm_mat_mult_q31(
const arm_matrix_instance_q31 * pSrcA, const arm_matrix_instance_q31 * pSrcA,
const arm_matrix_instance_q31 * pSrcB, const arm_matrix_instance_q31 * pSrcB,
arm_matrix_instance_q31 * pDst) arm_matrix_instance_q31 * pDst)
{ {
q31_t const *pInB = (q31_t const *)pSrcB->pData; /* input data matrix pointer B */ q31_t *pInA = pSrcA->pData; /* input data matrix pointer A */
q31_t const *pInA = (q31_t const *)pSrcA->pData; /* input data matrix pointer A */ q31_t *pInB = pSrcB->pData; /* input data matrix pointer B */
q31_t *pOut = pDst->pData; /* output data matrix pointer */ q31_t *pInA2;
q31_t *pInB2;
q31_t *px; /* Temporary output data matrix pointer */ q31_t *px; /* Temporary output data matrix pointer */
uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */ q31_t *px2; /* Temporary output data matrix pointer */
uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */ uint32_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */ uint32_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
uint16_t col, i = 0U, row = numRowsA; /* loop counters */ uint32_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
arm_status status; /* status of matrix multiplication */ uint32_t numRowsB = pSrcB->numRows; /* number of rows of input matrix B */
uint32x4_t vecOffs, vecColBOffs; uint32_t col, i = 0u, j, row = numRowsB; /* loop counters */
uint32_t blkCnt, rowCnt; /* loop counters */ q31_t State[numRowsB * numColsB * 1];
q31_t *pSrcBT = State; /* input data matrix pointer for transpose */
uint32_t blkCnt; /* loop counters */
arm_status status; /* Status of matrix multiplication */
arm_matrix_instance_q31 BT;
#ifdef ARM_MATH_MATRIX_CHECK #ifdef ARM_MATH_MATRIX_CHECK
/* Check for matrix mismatch condition */ /* Check for matrix mismatch condition */
if ((pSrcA->numCols != pSrcB->numRows) || if ((pSrcA->numCols != pSrcB->numRows) ||
(pSrcA->numRows != pDst->numRows) || (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols)) {
(pSrcB->numCols != pDst->numCols) )
{
/* Set status as ARM_MATH_SIZE_MISMATCH */ /* Set status as ARM_MATH_SIZE_MISMATCH */
status = ARM_MATH_SIZE_MISMATCH; status = ARM_MATH_SIZE_MISMATCH;
} } else
else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */ #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
{ {
/* small squared matrix specialized routines */ /* small squared matrix specialized routines */
if(numRowsA == numColsB && numColsB == numColsA) { if(numRowsA == numColsB && numColsB == numColsA) {
if (numRowsA == 1) if (numRowsA == 1)
{ {
q63_t sum = (q63_t) *pInA * *pInB; q63_t sum = (q63_t) *pInA * *pInB;
pOut[0] = (q31_t)(sum >> 31); pDst->pData[0] = (q31_t)(sum >> 31);
return (ARM_MATH_SUCCESS); return (ARM_MATH_SUCCESS);
} }
else if(numRowsA == 2) else if(numRowsA == 2)
@ -380,245 +381,262 @@ arm_status arm_mat_mult_q31(
return arm_mat_mult_q31_4x4_mve(pSrcA, pSrcB, pDst); return arm_mat_mult_q31_4x4_mve(pSrcA, pSrcB, pDst);
} }
vecColBOffs = vidupq_u32((uint32_t)0, 1);
vecColBOffs = vecColBOffs * (uint32_t) (numColsB);
/* /*
* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB * Matrix transpose
*/ */
BT.numRows = numColsB;
BT.numCols = numRowsB;
BT.pData = pSrcBT;
arm_mat_trans_q31(pSrcB, &BT);
/* /*
* row loop * Reset the variables for the usage in the following multiplication process
*/ */
rowCnt = row >> 2; i = 0;
while (rowCnt > 0U) row = numRowsA >> 1;
{ px = pDst->pData;
px2 = px + numColsB;
/* /*
* Output pointer is set to starting address of the row being processed * main loop
* compute 2 x 2 output blocks
* with dot products (Matrix A rows * Transposed Matrix B rows)
*/ */
px = pOut + i; while (row > 0u) {
i = i + 4 * numColsB;
/* /*
* For every row wise process, the column loop counter is to be initiated * For every row wise process, the column loop counter is to be initiated
* Compute 2 columns and 2 rows in parallel
*/ */
col = numColsB; col = numColsB >> 1;
/* j = 0;
* For every row wise process, the pInB pointer is set
* to the starting address of the pSrcB data
*/
pInB = (q31_t const *)pSrcB->pData;
/*
* column loop
*/
while (col > 0U)
{
/* /*
* generate 4 columns elements * column pair loop
*/ */
while (col > 0u) {
q31_t const *pSrcAVec, *pSrcBVec, *pSrcA2Vec, *pSrcB2Vec;
q31x4_t vecA, vecA2, vecB, vecB2;
q63_t acc0, acc1, acc2, acc3;
/* /*
* Matrix A columns number of MAC operations are to be performed * Initiate the pointers
* - 2 x consecutive Matrix A rows (i increment is 2 x numColsA)
* - 2 x consecutive Matrix B' rows (j increment is 2 x numRowsB)
*/ */
pInA = pSrcA->pData + i;
pInA2 = pInA + numColsA;
pInB = pSrcBT + j;
pInB2 = pInB + numRowsB;
q31_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
q31_t const *pInA0 = pInA; pSrcAVec = (q31_t const *) pInA;
q31_t const *pInA1 = pInA0 + numColsA; pSrcA2Vec = (q31_t const *) pInA2;
q31_t const *pInA2 = pInA1 + numColsA; pSrcBVec = (q31_t const *) pInB;
q31_t const *pInA3 = pInA2 + numColsA; pSrcB2Vec = (q31_t const *) pInB2;
q63_t acc0, acc1, acc2, acc3;
acc0 = 0LL; acc0 = 0LL;
acc1 = 0LL; acc1 = 0LL;
acc2 = 0LL; acc2 = 0LL;
acc3 = 0LL; acc3 = 0LL;
pSrcA0Vec = (q31_t const *) pInA0; /* load scheduling */
pSrcA1Vec = (q31_t const *) pInA1; vecA = vld1q(pSrcAVec);
pSrcA2Vec = (q31_t const *) pInA2; pSrcAVec += 4;
pSrcA3Vec = (q31_t const *) pInA3;
vecOffs = vecColBOffs;
/* process 1 x 4 block output */
blkCnt = numColsA >> 2;
while (blkCnt > 0U)
{
q31x4_t vecB, vecA;
vecB = vldrwq_gather_shifted_offset(pInB, vecOffs); blkCnt = (numColsA / 4);
/* move Matrix B read offsets, 4 rows down */ while (blkCnt > 0U) {
vecOffs = vecOffs + (uint32_t) (numColsB * 4); vecB = vld1q(pSrcBVec);
pSrcBVec += 4;
vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB); acc0 = vrmlaldavhaq(acc0, vecA, vecB);
vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 4; vecA2 = vld1q(pSrcA2Vec);
acc1 = vrmlaldavhaq(acc1, vecA, vecB); pSrcA2Vec += 4;
vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 4; acc1 = vrmlaldavhaq(acc1, vecA2, vecB);
acc2 = vrmlaldavhaq(acc2, vecA, vecB); vecB2 = vld1q(pSrcB2Vec);
vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 4; pSrcB2Vec += 4;
acc3 = vrmlaldavhaq(acc3, vecA, vecB); acc2 = vrmlaldavhaq(acc2, vecA, vecB2);
vecA = vld1q(pSrcAVec);
pSrcAVec += 4;
acc3 = vrmlaldavhaq(acc3, vecA2, vecB2);
blkCnt--; blkCnt--;
} }
/* /*
* tail * tail
* (will be merged thru tail predication) * (will be merged thru tail predication)
*/ */
blkCnt = numColsA & 3; blkCnt = (numColsA & 3);
if (blkCnt > 0U) if (blkCnt > 0U) {
{
mve_pred16_t p0 = vctp32q(blkCnt); mve_pred16_t p0 = vctp32q(blkCnt);
q31x4_t vecB, vecA; vecB = vld1q(pSrcBVec);
acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0); vecA2 = vld1q(pSrcA2Vec);
//vecOffs = vecOffs + (uint32_t) (numColsB * 4); acc1 = vrmlaldavhaq_p(acc1, vecA2, vecB, p0);
vecB2 = vld1q(pSrcB2Vec);
vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 4; acc2 = vrmlaldavhaq_p(acc2, vecA, vecB2, p0);
acc0 = vrmlaldavhaq(acc0, vecA, vecB); vecA = vld1q(pSrcAVec);
vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 4; acc3 = vrmlaldavhaq_p(acc3, vecA2, vecB2, p0);
acc1 = vrmlaldavhaq(acc1, vecA, vecB);
vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 4;
acc2 = vrmlaldavhaq(acc2, vecA, vecB);
vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 4;
acc3 = vrmlaldavhaq(acc3, vecA, vecB);
} }
/* Convert to 1.31 */
acc0 = asrl(acc0, 23); acc0 = asrl(acc0, 23);
acc1 = asrl(acc1, 23); acc1 = asrl(acc1, 23);
acc2 = asrl(acc2, 23); acc2 = asrl(acc2, 23);
acc3 = asrl(acc3, 23); acc3 = asrl(acc3, 23);
px[0] = (q31_t) acc0; /* Store the results (2 x 2 block) in the destination buffer */
px[1 * numColsB] = (q31_t) acc1; *px++ = (q31_t) acc0;
px[2 * numColsB] = (q31_t) acc2; *px++ = (q31_t) acc2;
px[3 * numColsB] = (q31_t) acc3; *px2++ = (q31_t) acc1;
px++; *px2++ = (q31_t) acc3;
j += numRowsB * 2;
/* /*
* Decrement the column loop counter * Decrement the column pair loop counter
*/ */
col--; col--;
}
i = i + numColsA * 2;
px = px2 + (numColsB & 1u);
px2 = px + numColsB;
/* /*
* Update the pointer pInB to point to the starting address of the next column * Decrement the row pair loop counter
*/ */
pInB = (q31_t const *)pSrcB->pData + (numColsB - col); row--;
} }
/* /*
* Update the pointer pInA to point to the starting address of the next row * Compute remaining row and/or column below
*/ */
pInA += (numColsA * 4); if (numColsB & 1u) {
row = numRowsA & (~0x1); //avoid redundant computation
px = pDst->pData + numColsB - 1;
i = 0;
/* /*
* Decrement the row loop counter * row loop
*/ */
rowCnt --; while (row > 0) {
q31_t const *pSrcAVec, *pSrcBVec;
q31x4_t vecA, vecB;
q63_t acc0;
/*
* point to last column in matrix B
*/
pInB = pSrcBT + numRowsB * (numColsB - 1);
pInA = pSrcA->pData + i;
pSrcAVec = (q31_t const *) pInA;
pSrcBVec = (q31_t const *) pInB;
/* single dot-product */
acc0 = 0LL;
blkCnt = (numColsA / 4);
while (blkCnt > 0U) {
vecA = vld1q(pSrcAVec);
pSrcAVec += 4;
vecB = vld1q(pSrcBVec);
pSrcBVec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB);
blkCnt--;
} }
rowCnt = row & 3;
while (rowCnt > 0U)
{
/* /*
* Output pointer is set to starting address of the row being processed * tail
* (will be merged thru tail predication)
*/ */
px = pOut + i; blkCnt = (numColsA & 3);
i = i + numColsB; if (blkCnt > 0U) {
mve_pred16_t p0 = vctp32q(blkCnt);
vecA = vld1q(pSrcAVec);
vecB = vld1q(pSrcBVec);
acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
}
acc0 = asrl(acc0, 23);
*px = (q31_t) acc0;
px += numColsB;
i += numColsA;
/* /*
* For every row wise process, the column loop counter is to be initiated * Decrement the row loop counter
*/ */
row--;
}
}
if (numRowsA & 1u) {
col = numColsB; col = numColsB;
i = 0u;
/* /*
* For every row wise process, the pInB pointer is set * point to last row in output matrix
* to the starting address of the pSrcB data
*/ */
pInB = (q31_t const *)pSrcB->pData; px = pDst->pData + (numColsB) * (numRowsA - 1);
/* /*
* column loop * col loop
*/ */
while (col > 0U) while (col > 0) {
{ q31_t const *pSrcAVec, *pSrcBVec;
q31x4_t vecA, vecB;
q63_t acc0;
/* /*
* generate 4 columns elements * point to last row in matrix A
*/ */
pInA = pSrcA->pData + (numRowsA - 1) * numColsA;
pInB = pSrcBT + i;
/* /*
* Matrix A columns number of MAC operations are to be performed * Set the variable sum, that acts as accumulator, to zero
*/ */
pSrcAVec = (q31_t const *) pInA;
q31_t const *pSrcA0Vec; pSrcBVec = (q31_t const *) pInB;
q31_t const *pInA0 = pInA;
q63_t acc0;
acc0 = 0LL; acc0 = 0LL;
blkCnt = (numColsA / 4);
pSrcA0Vec = (q31_t const *) pInA0; while (blkCnt > 0U) {
vecA = vld1q(pSrcAVec);
vecOffs = vecColBOffs; pSrcAVec += 4;
vecB = vld1q(pSrcBVec);
/* process 1 x 4 block output */ pSrcBVec += 4;
blkCnt = numColsA >> 2;
while (blkCnt > 0U)
{
q31x4_t vecB, vecA;
vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
/* move Matrix B read offsets, 4 rows down */
vecOffs = vecOffs + (uint32_t) (numColsB * 4);
vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB); acc0 = vrmlaldavhaq(acc0, vecA, vecB);
blkCnt--; blkCnt--;
} }
/* /*
* tail * tail
* (will be merged thru tail predication) * (will be merged thru tail predication)
*/ */
blkCnt = numColsA & 3; blkCnt = (numColsA & 3);
if (blkCnt > 0U) if (blkCnt > 0U) {
{
mve_pred16_t p0 = vctp32q(blkCnt); mve_pred16_t p0 = vctp32q(blkCnt);
q31x4_t vecB, vecA; vecA = vld1q(pSrcAVec);
vecB = vld1q(pSrcBVec);
vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0); acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
//vecOffs = vecOffs + (uint32_t) (numColsB * 4);
vecA = vld1q(pSrcA0Vec);
pSrcA0Vec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB);
} }
acc0 = asrl(acc0, 23); acc0 = asrl(acc0, 23);
*px++ = (q31_t) acc0;
i += numColsA;
px[0] = (q31_t) acc0;
px++;
/* /*
* Decrement the column loop counter * Decrement the col loop counter
*/ */
col--; col--;
/*
* Update the pointer pInB to point to the starting address of the next column
*/
pInB = (q31_t const *)pSrcB->pData + (numColsB - col);
} }
/*
* Update the pointer pInA to point to the starting address of the next row
*/
pInA += numColsA;
/*
* Decrement the row loop counter
*/
rowCnt--;
} }
/* Set status as ARM_MATH_SUCCESS */
/*
* set status as ARM_MATH_SUCCESS
*/
status = ARM_MATH_SUCCESS; status = ARM_MATH_SUCCESS;
} }
/*
/* Return to application */ * Return to application
*/
return (status); return (status);
} }

Loading…
Cancel
Save