From cfc30c12b88c0d9b1bd3365478d6c592e73ce537 Mon Sep 17 00:00:00 2001 From: FabKlein Date: Mon, 8 Nov 2021 17:07:27 +0000 Subject: [PATCH] CMSIS-DSP : faster Q.15/Q.31 Helium matrix multiplications. Uses an initial transpose stage, requiring extra scratch space to hold RHS transposed matrix. --- Source/MatrixFunctions/arm_mat_mult_q15.c | 485 +++++++++++---------- Source/MatrixFunctions/arm_mat_mult_q31.c | 496 +++++++++++----------- 2 files changed, 514 insertions(+), 467 deletions(-) diff --git a/Source/MatrixFunctions/arm_mat_mult_q15.c b/Source/MatrixFunctions/arm_mat_mult_q15.c index 8eed6ee5..9219ed02 100644 --- a/Source/MatrixFunctions/arm_mat_mult_q15.c +++ b/Source/MatrixFunctions/arm_mat_mult_q15.c @@ -3,8 +3,8 @@ * Title: arm_mat_mult_q15.c * Description: Q15 matrix multiplication * - * $Date: 23 April 2021 - * $Revision: V1.9.0 + * $Date: 3 Nov 2021 + * $Revision: V1.10.0 * * Target Processor: Cortex-M and Cortex-A cores * -------------------------------------------------------------------- */ @@ -315,279 +315,308 @@ __STATIC_INLINE arm_status arm_mat_mult_q15_4x4_mve( return (ARM_MATH_SUCCESS); } + arm_status arm_mat_mult_q15( - const arm_matrix_instance_q15 * pSrcA, - const arm_matrix_instance_q15 * pSrcB, - arm_matrix_instance_q15 * pDst, - q15_t * pState) + const arm_matrix_instance_q15 * pSrcA, + const arm_matrix_instance_q15 * pSrcB, + arm_matrix_instance_q15 * pDst, + q15_t * pState) { - q15_t *pInB = pSrcB->pData; /* input data matrix pointer B */ - q15_t *pInA = pSrcA->pData; /* input data matrix pointer A */ - q15_t *pOut = pDst->pData; /* output data matrix pointer */ - q15_t *px; /* Temporary output data matrix pointer */ - uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */ - uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */ - uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */ - uint16_t col, i = 0U, row = numRowsA; /* loop counters */ - uint16x8_t vecOffs, vecColBOffs; - uint32_t blkCnt,rowCnt; /* loop counters */ - arm_status status; /* Status of matrix multiplication */ - (void)pState; + q15_t *pInA = pSrcA->pData; /* input data matrix pointer A */ + q15_t *pInB = pSrcB->pData; /* input data matrix pointer B */ + q15_t *pInA2; + q15_t *pInB2; + q15_t *px; /* Temporary output data matrix pointer */ + q15_t *px2; /* Temporary output data matrix pointer */ + uint32_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */ + uint32_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */ + uint32_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */ + uint32_t numRowsB = pSrcB->numRows; /* number of rows of input matrix A */ + uint32_t col, i = 0u, j, row = numRowsB; /* loop counters */ + q15_t *pSrcBT = pState; /* input data matrix pointer for transpose */ + uint32_t blkCnt; /* loop counters */ + arm_status status; /* Status of matrix multiplication */ + arm_matrix_instance_q15 BT; #ifdef ARM_MATH_MATRIX_CHECK - /* Check for matrix mismatch condition */ - if ((pSrcA->numCols != pSrcB->numRows) || + /* Check for matrix mismatch condition */ + if ((pSrcA->numCols != pSrcB->numRows) || (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols) ) - { - /* Set status as ARM_MATH_SIZE_MISMATCH */ - status = ARM_MATH_SIZE_MISMATCH; - } - else + { + /* Set status as ARM_MATH_SIZE_MISMATCH */ + status = ARM_MATH_SIZE_MISMATCH; + } + else #endif - { - /* small squared matrix specialized routines */ - if(numRowsA == numColsB && numColsB == numColsA) { - - if (numRowsA == 1) - { - q63_t sum; - sum = pInA[0] * pInB[0]; - pOut[0] = (q15_t) __SSAT((sum >> 15), 16); - return (ARM_MATH_SUCCESS); + { + /* small squared matrix specialized routines */ + if (numRowsA == numColsB && numColsB == numColsA) { + + if (numRowsA == 1) { + q63_t sum; + sum = pInA[0] * pInB[0]; + pDst->pData[0] = (q15_t) __SSAT((sum >> 15), 16); + return (ARM_MATH_SUCCESS); + } else if (numRowsA == 2) + return arm_mat_mult_q15_2x2_mve(pSrcA, pSrcB, pDst); + else if (numRowsA == 3) + return arm_mat_mult_q15_3x3_mve(pSrcA, pSrcB, pDst); + else if (numRowsA == 4) + return arm_mat_mult_q15_4x4_mve(pSrcA, pSrcB, pDst); } - else if(numRowsA == 2) - return arm_mat_mult_q15_2x2_mve(pSrcA, pSrcB, pDst); - else if(numRowsA == 3) - return arm_mat_mult_q15_3x3_mve(pSrcA, pSrcB, pDst); - else if (numRowsA == 4) - return arm_mat_mult_q15_4x4_mve(pSrcA, pSrcB, pDst); - } - vecColBOffs = vidupq_u16((uint32_t)0, 1); - vecColBOffs = vecColBOffs * (uint16_t) (numColsB); - - /* - * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB - */ - - /* - * row loop - */ - rowCnt = row >> 2; - while (rowCnt > 0U) - { /* - * Output pointer is set to starting address of the row being processed + * Matrix transpose */ - px = pOut + i; - i = i + 4 * numColsB; + + BT.numRows = numColsB; + BT.numCols = numRowsB; + BT.pData = pSrcBT; + + arm_mat_trans_q15(pSrcB, &BT); + + /* - * For every row wise process, the column loop counter is to be initiated + * Reset the variables for the usage in the following multiplication process */ - col = numColsB; + i = 0; + row = numRowsA >> 1; + px = pDst->pData; + px2 = px + numColsB; + /* - * For every row wise process, the pInB pointer is set - * to the starting address of the pSrcB data + * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */ - pInB = pSrcB->pData; + /* - * column loop + * row loop */ - while (col > 0U) - { + while (row > 0u) { /* - * generate 4 columns elements + * For every row wise process, the column loop counter is to be initiated */ + col = numColsB >> 1; /* - * Matrix A columns number of MAC operations are to be performed + * For every row wise process, the pIn2 pointer is set + * to the starting address of the transposed pSrcB data */ + pInB = pSrcBT; + pInB2 = pInB + numRowsB; + j = 0; - q15_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec; - q15_t *pInA0 = pInA; - q15_t *pInA1 = pInA0 + numColsA; - q15_t *pInA2 = pInA1 + numColsA; - q15_t *pInA3 = pInA2 + numColsA; - q63_t acc0, acc1, acc2, acc3; - - acc0 = 0LL; - acc1 = 0LL; - acc2 = 0LL; - acc3 = 0LL; - - pSrcA0Vec = (q15_t const *) pInA0; - pSrcA1Vec = (q15_t const *) pInA1; - pSrcA2Vec = (q15_t const *) pInA2; - pSrcA3Vec = (q15_t const *) pInA3; - - vecOffs = vecColBOffs; - - blkCnt = (numColsA) >> 3; - while (blkCnt > 0U) - { - q15x8_t vecB, vecA; - - vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs); - vecOffs = vecOffs + (uint16_t) (numColsB * 8); - - vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 8; - acc0 = vmlaldavaq(acc0, vecA, vecB); - vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 8; - acc1 = vmlaldavaq(acc1, vecA, vecB); - vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 8; - acc2 = vmlaldavaq(acc2, vecA, vecB); - vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 8; - acc3 = vmlaldavaq(acc3, vecA, vecB); - blkCnt--; - - } /* - * tail + * column loop */ - blkCnt = numColsA & 7; - if (blkCnt > 0U) - { - mve_pred16_t p0 = vctp16q(blkCnt); - q15x8_t vecB, vecA; - - vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs); - vecOffs = vecOffs + (uint16_t) (numColsB * 8); - - vecA = vld1q(pSrcA0Vec); - acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0); - vecA = vld1q(pSrcA1Vec); - acc1 = vmlaldavaq_p(acc1, vecA, vecB, p0); - vecA = vld1q(pSrcA2Vec); - acc2 = vmlaldavaq_p(acc2, vecA, vecB, p0); - vecA = vld1q(pSrcA3Vec); - acc3 = vmlaldavaq_p(acc3, vecA, vecB, p0); + while (col > 0u) { + q15_t const *pSrcAVec, *pSrcBVec, *pSrcA2Vec, *pSrcB2Vec; + q15x8_t vecA, vecA2, vecB, vecB2; + q63_t acc0, acc1, acc2, acc3; + + /* + * Initiate the pointer pIn1 to point to the starting address of the column being processed + */ + pInA = pSrcA->pData + i; + pInA2 = pInA + numColsA; + pInB = pSrcBT + j; + pInB2 = pInB + numRowsB; + + + pSrcAVec = (q15_t const *) pInA; + pSrcA2Vec = (q15_t const *) pInA2; + pSrcBVec = (q15_t const *) pInB; + pSrcB2Vec = (q15_t const *) pInB2; + + acc0 = 0LL; + acc1 = 0LL; + acc2 = 0LL; + acc3 = 0LL; + + vecA = vld1q(pSrcAVec); + pSrcAVec += 8; + + blkCnt = numColsA / 8; + while (blkCnt > 0U) { + vecB = vld1q(pSrcBVec); + pSrcBVec += 8; + acc0 = vmlaldavaq(acc0, vecA, vecB); + vecA2 = vld1q(pSrcA2Vec); + pSrcA2Vec += 8; + acc1 = vmlaldavaq(acc1, vecA2, vecB); + vecB2 = vld1q(pSrcB2Vec); + pSrcB2Vec += 8; + acc2 = vmlaldavaq(acc2, vecA, vecB2); + vecA = vld1q(pSrcAVec); + pSrcAVec += 8; + acc3 = vmlaldavaq(acc3, vecA2, vecB2); + + blkCnt--; + } + /* + * tail + */ + blkCnt = numColsA & 7; + if (blkCnt > 0U) { + mve_pred16_t p0 = vctp16q(blkCnt); + vecB = vld1q(pSrcBVec); + acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0); + vecA2 = vld1q(pSrcA2Vec); + acc1 = vmlaldavaq_p(acc1, vecA2, vecB, p0); + vecB2 = vld1q(pSrcB2Vec); + acc2 = vmlaldavaq_p(acc2, vecA, vecB2, p0); + vecA = vld1q(pSrcAVec); + acc3 = vmlaldavaq_p(acc3, vecA2, vecB2, p0); + } + + *px++ = (q15_t) MVE_ASRL_SAT16(acc0, 15); + *px++ = (q15_t) MVE_ASRL_SAT16(acc2, 15); + *px2++ = (q15_t) MVE_ASRL_SAT16(acc1, 15); + *px2++ = (q15_t) MVE_ASRL_SAT16(acc3, 15); + j += numRowsB * 2; + /* + * Decrement the column loop counter + */ + col--; + } - px[0] = (q15_t)MVE_ASRL_SAT16(acc0, 15); - px[1 * numColsB] = (q15_t)MVE_ASRL_SAT16(acc1, 15); - px[2 * numColsB] = (q15_t)MVE_ASRL_SAT16(acc2, 15); - px[3 * numColsB] = (q15_t)MVE_ASRL_SAT16(acc3, 15); - px++; + i = i + numColsA * 2; + px = px2 + (numColsB & 1u); + px2 = px + numColsB; /* - * Decrement the column loop counter + * Decrement the row loop counter */ - col--; - /* - * Update the pointer pInB to point to the starting address of the next column - */ - pInB = pSrcB->pData + (numColsB - col); + row--; } /* - * Update the pointer pInA to point to the starting address of the next row + * Compute remaining row and/or column below */ - pInA += (numColsA * 4); - /* - * Decrement the row loop counter - */ - rowCnt --; - } + if (numColsB & 1u) { + row = numRowsA & (~0x1); //avoid redundant computation + px = pDst->pData + numColsB - 1; + i = 0; - rowCnt = row & 3; - while (rowCnt > 0U) - { - /* - * Output pointer is set to starting address of the row being processed - */ - px = pOut + i; - i = i + numColsB; - /* - * For every row wise process, the column loop counter is to be initiated - */ - col = numColsB; - /* - * For every row wise process, the pInB pointer is set - * to the starting address of the pSrcB data - */ - pInB = pSrcB->pData; - /* - * column loop - */ - while (col > 0U) - { /* - * generate 4 columns elements + * row loop */ - /* - * Matrix A columns number of MAC operations are to be performed - */ - - q15_t const *pSrcA0Vec; - q15_t *pInA0 = pInA; - q63_t acc0; - - acc0 = 0LL; - - pSrcA0Vec = (q15_t const *) pInA0; - - vecOffs = vecColBOffs; - - blkCnt = (numColsA) >> 3; - while (blkCnt > 0U) - { - q15x8_t vecB, vecA; - - vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs); - vecOffs = vecOffs + (uint16_t) (numColsB * 8); - - vecA = vld1q(pSrcA0Vec); - pSrcA0Vec += 8; - acc0 = vmlaldavaq(acc0, vecA, vecB); - - blkCnt--; - - } - /* - * tail - */ - blkCnt = numColsA & 7; - if (blkCnt > 0U) - { - mve_pred16_t p0 = vctp16q(blkCnt); - q15x8_t vecB, vecA; - - vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs); - vecOffs = vecOffs + (uint16_t) (numColsB * 8); - - vecA = vld1q(pSrcA0Vec); - acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0); - + while (row > 0) { + q15_t const *pSrcAVec, *pSrcBVec; + q15x8_t vecA, vecB; + q63_t acc0; + + /* + * point to last column in matrix B + */ + pInB = pSrcBT + numRowsB * (numColsB - 1); + pInA = pSrcA->pData + i; + + pSrcAVec = (q15_t const *) pInA; + pSrcBVec = (q15_t const *) pInB; + + acc0 = 0LL; + blkCnt = (numColsA) / 8; + while (blkCnt > 0U) { + vecA = vld1q(pSrcAVec); + pSrcAVec += 8; + vecB = vld1q(pSrcBVec); + pSrcBVec += 8; + acc0 = vmlaldavaq(acc0, vecA, vecB); + + blkCnt--; + } + /* + * tail + */ + blkCnt = (numColsA & 7); + if (blkCnt > 0U) { + mve_pred16_t p0 = vctp16q(blkCnt); + vecA = vld1q(pSrcAVec); + vecB = vld1q(pSrcBVec); + acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0); + } + + *px = (q15_t) MVE_ASRL_SAT16(acc0, 15); + + px += numColsB; + + i += numColsA; + /* + * Decrement the row loop counter + */ + row--; } + } - px[0] = (q15_t)MVE_ASRL_SAT16(acc0, 15); - - px++; + if (numRowsA & 1u) { + col = numColsB; + i = 0u; /* - * Decrement the column loop counter + * point to last row in output matrix */ - col--; + px = pDst->pData + (numColsB) * (numRowsA - 1); /* - * Update the pointer pInB to point to the starting address of the next column + * col loop */ - pInB = pSrcB->pData + (numColsB - col); + while (col > 0) { + q15_t const *pSrcAVec, *pSrcBVec; + q15x8_t vecA, vecB; + q63_t acc0; + + /* + * point to last row in matrix A + */ + pInA = pSrcA->pData + (numRowsA - 1) * numColsA; + pInB = pSrcBT + i; + + /* + * Set the variable sum, that acts as accumulator, to zero + */ + pSrcAVec = (q15_t const *) pInA; + pSrcBVec = (q15_t const *) pInB; + acc0 = 0LL; + + blkCnt = ((numColsA) / 8); + while (blkCnt > 0U) { + vecA = vld1q(pSrcAVec); + pSrcAVec += 8; + vecB = vld1q(pSrcBVec); + pSrcBVec += 8; + acc0 = vmlaldavaq(acc0, vecA, vecB); + + blkCnt--; + } + /* + * tail + */ + blkCnt = (numColsA & 7); + if (blkCnt > 0U) { + mve_pred16_t p0 = vctp16q(blkCnt); + vecA = vld1q(pSrcAVec); + vecB = vld1q(pSrcBVec); + acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0); + } + + *px++ = (q15_t) MVE_ASRL_SAT16(acc0, 15); + + i += numColsA; + + /* + * Decrement the col loop counter + */ + col--; + } } - /* - * Update the pointer pInA to point to the starting address of the next row - */ - pInA += (numColsA ); - rowCnt--; + /* Set status as ARM_MATH_SUCCESS */ + status = ARM_MATH_SUCCESS; } - /* Set status as ARM_MATH_SUCCESS */ - status = ARM_MATH_SUCCESS; - } - - /* Return to application */ - return (status); - + /* Return to application */ + return (status); } + #else arm_status arm_mat_mult_q15( const arm_matrix_instance_q15 * pSrcA, diff --git a/Source/MatrixFunctions/arm_mat_mult_q31.c b/Source/MatrixFunctions/arm_mat_mult_q31.c index 18738279..08001cc2 100644 --- a/Source/MatrixFunctions/arm_mat_mult_q31.c +++ b/Source/MatrixFunctions/arm_mat_mult_q31.c @@ -3,8 +3,8 @@ * Title: arm_mat_mult_q31.c * Description: Q31 matrix multiplication * - * $Date: 23 April 2021 - * $Revision: V1.9.0 + * $Date: 3 Nov 2021 + * $Revision: V1.10.0 * * Target Processor: Cortex-M and Cortex-A cores * -------------------------------------------------------------------- */ @@ -332,44 +332,45 @@ __STATIC_INLINE arm_status arm_mat_mult_q31_4x4_mve( return (ARM_MATH_SUCCESS); } + arm_status arm_mat_mult_q31( - const arm_matrix_instance_q31 * pSrcA, - const arm_matrix_instance_q31 * pSrcB, - arm_matrix_instance_q31 * pDst) + const arm_matrix_instance_q31 * pSrcA, + const arm_matrix_instance_q31 * pSrcB, + arm_matrix_instance_q31 * pDst) { - q31_t const *pInB = (q31_t const *)pSrcB->pData; /* input data matrix pointer B */ - q31_t const *pInA = (q31_t const *)pSrcA->pData; /* input data matrix pointer A */ - q31_t *pOut = pDst->pData; /* output data matrix pointer */ - q31_t *px; /* Temporary output data matrix pointer */ - uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */ - uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */ - uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */ - uint16_t col, i = 0U, row = numRowsA; /* loop counters */ - arm_status status; /* status of matrix multiplication */ - uint32x4_t vecOffs, vecColBOffs; - uint32_t blkCnt, rowCnt; /* loop counters */ - - #ifdef ARM_MATH_MATRIX_CHECK - - /* Check for matrix mismatch condition */ - if ((pSrcA->numCols != pSrcB->numRows) || - (pSrcA->numRows != pDst->numRows) || - (pSrcB->numCols != pDst->numCols) ) - { - /* Set status as ARM_MATH_SIZE_MISMATCH */ - status = ARM_MATH_SIZE_MISMATCH; - } - else + q31_t *pInA = pSrcA->pData; /* input data matrix pointer A */ + q31_t *pInB = pSrcB->pData; /* input data matrix pointer B */ + q31_t *pInA2; + q31_t *pInB2; + q31_t *px; /* Temporary output data matrix pointer */ + q31_t *px2; /* Temporary output data matrix pointer */ + uint32_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */ + uint32_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */ + uint32_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */ + uint32_t numRowsB = pSrcB->numRows; /* number of rows of input matrix A */ + uint32_t col, i = 0u, j, row = numRowsB; /* loop counters */ + q31_t State[numRowsB * numColsB * 1]; + q31_t *pSrcBT = State; /* input data matrix pointer for transpose */ + uint32_t blkCnt; /* loop counters */ + arm_status status; /* Status of matrix multiplication */ + arm_matrix_instance_q31 BT; +#ifdef ARM_MATH_MATRIX_CHECK -#endif /* #ifdef ARM_MATH_MATRIX_CHECK */ + /* Check for matrix mismatch condition */ + if ((pSrcA->numCols != pSrcB->numRows) || + (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols)) { + /* Set status as ARM_MATH_SIZE_MISMATCH */ + status = ARM_MATH_SIZE_MISMATCH; + } else +#endif /* #ifdef ARM_MATH_MATRIX_CHECK */ + { - { - /* small squared matrix specialized routines */ + /* small squared matrix specialized routines */ if(numRowsA == numColsB && numColsB == numColsA) { if (numRowsA == 1) { q63_t sum = (q63_t) *pInA * *pInB; - pOut[0] = (q31_t)(sum >> 31); + pDst->pData[0] = (q31_t)(sum >> 31); return (ARM_MATH_SUCCESS); } else if(numRowsA == 2) @@ -380,246 +381,263 @@ arm_status arm_mat_mult_q31( return arm_mat_mult_q31_4x4_mve(pSrcA, pSrcB, pDst); } - vecColBOffs = vidupq_u32((uint32_t)0, 1); - vecColBOffs = vecColBOffs * (uint32_t) (numColsB); - /* - * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB - */ - - /* - * row loop - */ - rowCnt = row >> 2; - while (rowCnt > 0U) - { /* - * Output pointer is set to starting address of the row being processed + * Matrix transpose */ - px = pOut + i; - i = i + 4 * numColsB; - /* - * For every row wise process, the column loop counter is to be initiated - */ - col = numColsB; + BT.numRows = numColsB; + BT.numCols = numRowsB; + BT.pData = pSrcBT; + + arm_mat_trans_q31(pSrcB, &BT); + + /* - * For every row wise process, the pInB pointer is set - * to the starting address of the pSrcB data + * Reset the variables for the usage in the following multiplication process */ - pInB = (q31_t const *)pSrcB->pData; + i = 0; + row = numRowsA >> 1; + px = pDst->pData; + px2 = px + numColsB; + /* - * column loop + * main loop + * compute 2 x 2 output blocks + * with dot products (Matrix A rows * Transposed MAtrix B rows) */ - while (col > 0U) - { - /* - * generate 4 columns elements - */ + while (row > 0u) { /* - * Matrix A columns number of MAC operations are to be performed + * For every row wise process, the column loop counter is to be initiated + * Compute 2 columns and 2 rows in parrallel */ - - q31_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec; - q31_t const *pInA0 = pInA; - q31_t const *pInA1 = pInA0 + numColsA; - q31_t const *pInA2 = pInA1 + numColsA; - q31_t const *pInA3 = pInA2 + numColsA; - q63_t acc0, acc1, acc2, acc3; - - acc0 = 0LL; - acc1 = 0LL; - acc2 = 0LL; - acc3 = 0LL; - - pSrcA0Vec = (q31_t const *) pInA0; - pSrcA1Vec = (q31_t const *) pInA1; - pSrcA2Vec = (q31_t const *) pInA2; - pSrcA3Vec = (q31_t const *) pInA3; - - vecOffs = vecColBOffs; - - /* process 1 x 4 block output */ - blkCnt = numColsA >> 2; - while (blkCnt > 0U) - { - q31x4_t vecB, vecA; - - vecB = vldrwq_gather_shifted_offset(pInB, vecOffs); - /* move Matrix B read offsets, 4 rows down */ - vecOffs = vecOffs + (uint32_t) (numColsB * 4); - - vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 4; - acc0 = vrmlaldavhaq(acc0, vecA, vecB); - vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 4; - acc1 = vrmlaldavhaq(acc1, vecA, vecB); - vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 4; - acc2 = vrmlaldavhaq(acc2, vecA, vecB); - vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 4; - acc3 = vrmlaldavhaq(acc3, vecA, vecB); - blkCnt--; - } + col = numColsB >> 1; + j = 0; /* - * tail - * (will be merged thru tail predication) + * column pair loop */ - blkCnt = numColsA & 3; - if (blkCnt > 0U) - { - mve_pred16_t p0 = vctp32q(blkCnt); - q31x4_t vecB, vecA; - - vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0); - //vecOffs = vecOffs + (uint32_t) (numColsB * 4); - - vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 4; - acc0 = vrmlaldavhaq(acc0, vecA, vecB); - vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 4; - acc1 = vrmlaldavhaq(acc1, vecA, vecB); - vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 4; - acc2 = vrmlaldavhaq(acc2, vecA, vecB); - vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 4; - acc3 = vrmlaldavhaq(acc3, vecA, vecB); - } + while (col > 0u) { + q31_t const *pSrcAVec, *pSrcBVec, *pSrcA2Vec, *pSrcB2Vec; + q31x4_t vecA, vecA2, vecB, vecB2; + q63_t acc0, acc1, acc2, acc3; + + /* + * Initiate the pointers + * - 2 x consecutive Matrix A rows (i increment is 2 x numColsA) + * - 2 x consecutive Matrix B' rows (j increment is 2 x numRowsB) + */ + pInA = pSrcA->pData + i; + pInA2 = pInA + numColsA; + pInB = pSrcBT + j; + pInB2 = pInB + numRowsB; + + + pSrcAVec = (q31_t const *) pInA; + pSrcA2Vec = (q31_t const *) pInA2; + pSrcBVec = (q31_t const *) pInB; + pSrcB2Vec = (q31_t const *) pInB2; + + acc0 = 0LL; + acc1 = 0LL; + acc2 = 0LL; + acc3 = 0LL; + + /* load scheduling */ + vecA = vld1q(pSrcAVec); + pSrcAVec += 4; + + blkCnt = (numColsA / 4); + while (blkCnt > 0U) { + vecB = vld1q(pSrcBVec); + pSrcBVec += 4; + acc0 = vrmlaldavhaq(acc0, vecA, vecB); + vecA2 = vld1q(pSrcA2Vec); + pSrcA2Vec += 4; + acc1 = vrmlaldavhaq(acc1, vecA2, vecB); + vecB2 = vld1q(pSrcB2Vec); + pSrcB2Vec += 4; + acc2 = vrmlaldavhaq(acc2, vecA, vecB2); + vecA = vld1q(pSrcAVec); + pSrcAVec += 4; + acc3 = vrmlaldavhaq(acc3, vecA2, vecB2); + + blkCnt--; + } + /* + * tail + * (will be merged thru tail predication) + */ + blkCnt = (numColsA & 3); + if (blkCnt > 0U) { + mve_pred16_t p0 = vctp32q(blkCnt); + vecB = vld1q(pSrcBVec); + acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0); + vecA2 = vld1q(pSrcA2Vec); + acc1 = vrmlaldavhaq_p(acc1, vecA2, vecB, p0); + vecB2 = vld1q(pSrcB2Vec); + acc2 = vrmlaldavhaq_p(acc2, vecA, vecB2, p0); + vecA = vld1q(pSrcAVec); + acc3 = vrmlaldavhaq_p(acc3, vecA2, vecB2, p0); + } + + /* Convert to 1.31 */ + acc0 = asrl(acc0, 23); + acc1 = asrl(acc1, 23); + acc2 = asrl(acc2, 23); + acc3 = asrl(acc3, 23); + + /* Store the results (2 x 2 block) in the destination buffer */ + *px++ = (q31_t) acc0; + *px++ = (q31_t) acc2; + *px2++ = (q31_t) acc1; + *px2++ = (q31_t) acc3; + + j += numRowsB * 2; + /* + * Decrement the column pair loop counter + */ + col--; - acc0 = asrl(acc0, 23); - acc1 = asrl(acc1, 23); - acc2 = asrl(acc2, 23); - acc3 = asrl(acc3, 23); + } - px[0] = (q31_t) acc0; - px[1 * numColsB] = (q31_t) acc1; - px[2 * numColsB] = (q31_t) acc2; - px[3 * numColsB] = (q31_t) acc3; - px++; + i = i + numColsA * 2; + px = px2 + (numColsB & 1u); + px2 = px + numColsB; /* - * Decrement the column loop counter + * Decrement the row pair loop counter */ - col--; - /* - * Update the pointer pInB to point to the starting address of the next column - */ - pInB = (q31_t const *)pSrcB->pData + (numColsB - col); + row--; } /* - * Update the pointer pInA to point to the starting address of the next row - */ - pInA += (numColsA * 4); - /* - * Decrement the row loop counter - */ - rowCnt --; - - } - rowCnt = row & 3; - while (rowCnt > 0U) - { - /* - * Output pointer is set to starting address of the row being processed - */ - px = pOut + i; - i = i + numColsB; - /* - * For every row wise process, the column loop counter is to be initiated - */ - col = numColsB; - /* - * For every row wise process, the pInB pointer is set - * to the starting address of the pSrcB data - */ - pInB = (q31_t const *)pSrcB->pData; - /* - * column loop + * Compute remaining row and/or column below */ - while (col > 0U) - { - /* - * generate 4 columns elements - */ - /* - * Matrix A columns number of MAC operations are to be performed - */ - - q31_t const *pSrcA0Vec; - q31_t const *pInA0 = pInA; - q63_t acc0; - - acc0 = 0LL; - - - pSrcA0Vec = (q31_t const *) pInA0; - - vecOffs = vecColBOffs; - - /* process 1 x 4 block output */ - blkCnt = numColsA >> 2; - while (blkCnt > 0U) - { - q31x4_t vecB, vecA; - - vecB = vldrwq_gather_shifted_offset(pInB, vecOffs); - /* move Matrix B read offsets, 4 rows down */ - vecOffs = vecOffs + (uint32_t) (numColsB * 4); - - vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 4; - acc0 = vrmlaldavhaq(acc0, vecA, vecB); - - blkCnt--; - } + if (numColsB & 1u) { + row = numRowsA & (~0x1); //avoid redundant computation + px = pDst->pData + numColsB - 1; + i = 0; /* - * tail - * (will be merged thru tail predication) + * row loop */ - blkCnt = numColsA & 3; - if (blkCnt > 0U) - { - mve_pred16_t p0 = vctp32q(blkCnt); - q31x4_t vecB, vecA; - - vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0); - //vecOffs = vecOffs + (uint32_t) (numColsB * 4); - - vecA = vld1q(pSrcA0Vec); - pSrcA0Vec += 4; - acc0 = vrmlaldavhaq(acc0, vecA, vecB); - + while (row > 0) { + q31_t const *pSrcAVec, *pSrcBVec; + q31x4_t vecA, vecB; + q63_t acc0; + + /* + * point to last column in matrix B + */ + pInB = pSrcBT + numRowsB * (numColsB - 1); + pInA = pSrcA->pData + i; + + pSrcAVec = (q31_t const *) pInA; + pSrcBVec = (q31_t const *) pInB; + + /* single dot-product */ + acc0 = 0LL; + blkCnt = (numColsA / 4); + while (blkCnt > 0U) { + vecA = vld1q(pSrcAVec); + pSrcAVec += 4; + vecB = vld1q(pSrcBVec); + pSrcBVec += 4; + acc0 = vrmlaldavhaq(acc0, vecA, vecB); + + blkCnt--; + } + /* + * tail + * (will be merged thru tail predication) + */ + blkCnt = (numColsA & 3); + if (blkCnt > 0U) { + mve_pred16_t p0 = vctp32q(blkCnt); + vecA = vld1q(pSrcAVec); + vecB = vld1q(pSrcBVec); + acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0); + } + + acc0 = asrl(acc0, 23); + *px = (q31_t) acc0; + + px += numColsB; + + i += numColsA; + /* + * Decrement the row loop counter + */ + row--; } + } - acc0 = asrl(acc0, 23); - - - px[0] = (q31_t) acc0; - px++; + if (numRowsA & 1u) { + col = numColsB; + i = 0u; /* - * Decrement the column loop counter + * point to last row in output matrix */ - col--; + px = pDst->pData + (numColsB) * (numRowsA - 1); /* - * Update the pointer pInB to point to the starting address of the next column + * col loop */ - pInB = (q31_t const *)pSrcB->pData + (numColsB - col); + while (col > 0) { + q31_t const *pSrcAVec, *pSrcBVec; + q31x4_t vecA, vecB; + q63_t acc0; + + /* + * point to last row in matrix A + */ + pInA = pSrcA->pData + (numRowsA - 1) * numColsA; + pInB = pSrcBT + i; + + /* + * Set the variable sum, that acts as accumulator, to zero + */ + pSrcAVec = (q31_t const *) pInA; + pSrcBVec = (q31_t const *) pInB; + acc0 = 0LL; + + blkCnt = (numColsA / 4); + while (blkCnt > 0U) { + vecA = vld1q(pSrcAVec); + pSrcAVec += 4; + vecB = vld1q(pSrcBVec); + pSrcBVec += 4; + acc0 = vrmlaldavhaq(acc0, vecA, vecB); + + blkCnt--; + } + /* + * tail + * (will be merged thru tail predication) + */ + blkCnt = (numColsA & 3); + if (blkCnt > 0U) { + mve_pred16_t p0 = vctp32q(blkCnt); + vecA = vld1q(pSrcAVec); + vecB = vld1q(pSrcBVec); + acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0); + } + + acc0 = asrl(acc0, 23); + *px++ = (q31_t) acc0; + + i += numColsA; + /* + * Decrement the col loop counter + */ + col--; + } } - - /* - * Update the pointer pInA to point to the starting address of the next row - */ - pInA += numColsA; - /* - * Decrement the row loop counter - */ - rowCnt--; + /* Set status as ARM_MATH_SUCCESS */ + status = ARM_MATH_SUCCESS; } - /* - * set status as ARM_MATH_SUCCESS + * Return to application */ - status = ARM_MATH_SUCCESS; - } - - /* Return to application */ - return (status); + return (status); } #else