CMSIS-DSP : faster Q.15/Q.31 Helium matrix multiplications. Uses an initial transpose stage, requiring extra scratch space to hold RHS transposed matrix.

pull/19/head
FabKlein 4 years ago committed by Christophe Favergeon
parent c520fb08f4
commit cfc30c12b8

@ -3,8 +3,8 @@
* Title: arm_mat_mult_q15.c
* Description: Q15 matrix multiplication
*
* $Date: 23 April 2021
* $Revision: V1.9.0
* $Date: 3 Nov 2021
* $Revision: V1.10.0
*
* Target Processor: Cortex-M and Cortex-A cores
* -------------------------------------------------------------------- */
@ -315,24 +315,28 @@ __STATIC_INLINE arm_status arm_mat_mult_q15_4x4_mve(
return (ARM_MATH_SUCCESS);
}
arm_status arm_mat_mult_q15(
const arm_matrix_instance_q15 * pSrcA,
const arm_matrix_instance_q15 * pSrcB,
arm_matrix_instance_q15 * pDst,
q15_t * pState)
{
q15_t *pInB = pSrcB->pData; /* input data matrix pointer B */
q15_t *pInA = pSrcA->pData; /* input data matrix pointer A */
q15_t *pOut = pDst->pData; /* output data matrix pointer */
q15_t *pInB = pSrcB->pData; /* input data matrix pointer B */
q15_t *pInA2;
q15_t *pInB2;
q15_t *px; /* Temporary output data matrix pointer */
uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
uint16_t col, i = 0U, row = numRowsA; /* loop counters */
uint16x8_t vecOffs, vecColBOffs;
uint32_t blkCnt,rowCnt; /* loop counters */
q15_t *px2; /* Temporary output data matrix pointer */
uint32_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
uint32_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
uint32_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
uint32_t numRowsB = pSrcB->numRows; /* number of rows of input matrix A */
uint32_t col, i = 0u, j, row = numRowsB; /* loop counters */
q15_t *pSrcBT = pState; /* input data matrix pointer for transpose */
uint32_t blkCnt; /* loop counters */
arm_status status; /* Status of matrix multiplication */
(void)pState;
arm_matrix_instance_q15 BT;
#ifdef ARM_MATH_MATRIX_CHECK
@ -348,246 +352,271 @@ arm_status arm_mat_mult_q15(
#endif
{
/* small squared matrix specialized routines */
if(numRowsA == numColsB && numColsB == numColsA) {
if (numRowsA == numColsB && numColsB == numColsA) {
if (numRowsA == 1)
{
if (numRowsA == 1) {
q63_t sum;
sum = pInA[0] * pInB[0];
pOut[0] = (q15_t) __SSAT((sum >> 15), 16);
pDst->pData[0] = (q15_t) __SSAT((sum >> 15), 16);
return (ARM_MATH_SUCCESS);
}
else if(numRowsA == 2)
} else if (numRowsA == 2)
return arm_mat_mult_q15_2x2_mve(pSrcA, pSrcB, pDst);
else if(numRowsA == 3)
else if (numRowsA == 3)
return arm_mat_mult_q15_3x3_mve(pSrcA, pSrcB, pDst);
else if (numRowsA == 4)
return arm_mat_mult_q15_4x4_mve(pSrcA, pSrcB, pDst);
}
vecColBOffs = vidupq_u16((uint32_t)0, 1);
vecColBOffs = vecColBOffs * (uint16_t) (numColsB);
/*
* Matrix transpose
*/
BT.numRows = numColsB;
BT.numCols = numRowsB;
BT.pData = pSrcBT;
arm_mat_trans_q15(pSrcB, &BT);
/*
* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
* Reset the variables for the usage in the following multiplication process
*/
i = 0;
row = numRowsA >> 1;
px = pDst->pData;
px2 = px + numColsB;
/*
* row loop
* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
*/
rowCnt = row >> 2;
while (rowCnt > 0U)
{
/*
* Output pointer is set to starting address of the row being processed
* row loop
*/
px = pOut + i;
i = i + 4 * numColsB;
while (row > 0u) {
/*
* For every row wise process, the column loop counter is to be initiated
*/
col = numColsB;
col = numColsB >> 1;
/*
* For every row wise process, the pInB pointer is set
* to the starting address of the pSrcB data
* For every row wise process, the pIn2 pointer is set
* to the starting address of the transposed pSrcB data
*/
pInB = pSrcB->pData;
pInB = pSrcBT;
pInB2 = pInB + numRowsB;
j = 0;
/*
* column loop
*/
while (col > 0U)
{
/*
* generate 4 columns elements
*/
while (col > 0u) {
q15_t const *pSrcAVec, *pSrcBVec, *pSrcA2Vec, *pSrcB2Vec;
q15x8_t vecA, vecA2, vecB, vecB2;
q63_t acc0, acc1, acc2, acc3;
/*
* Matrix A columns number of MAC operations are to be performed
* Initiate the pointer pIn1 to point to the starting address of the column being processed
*/
pInA = pSrcA->pData + i;
pInA2 = pInA + numColsA;
pInB = pSrcBT + j;
pInB2 = pInB + numRowsB;
q15_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
q15_t *pInA0 = pInA;
q15_t *pInA1 = pInA0 + numColsA;
q15_t *pInA2 = pInA1 + numColsA;
q15_t *pInA3 = pInA2 + numColsA;
q63_t acc0, acc1, acc2, acc3;
pSrcAVec = (q15_t const *) pInA;
pSrcA2Vec = (q15_t const *) pInA2;
pSrcBVec = (q15_t const *) pInB;
pSrcB2Vec = (q15_t const *) pInB2;
acc0 = 0LL;
acc1 = 0LL;
acc2 = 0LL;
acc3 = 0LL;
pSrcA0Vec = (q15_t const *) pInA0;
pSrcA1Vec = (q15_t const *) pInA1;
pSrcA2Vec = (q15_t const *) pInA2;
pSrcA3Vec = (q15_t const *) pInA3;
vecOffs = vecColBOffs;
vecA = vld1q(pSrcAVec);
pSrcAVec += 8;
blkCnt = (numColsA) >> 3;
while (blkCnt > 0U)
{
q15x8_t vecB, vecA;
vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs);
vecOffs = vecOffs + (uint16_t) (numColsB * 8);
vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 8;
blkCnt = numColsA / 8;
while (blkCnt > 0U) {
vecB = vld1q(pSrcBVec);
pSrcBVec += 8;
acc0 = vmlaldavaq(acc0, vecA, vecB);
vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 8;
acc1 = vmlaldavaq(acc1, vecA, vecB);
vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 8;
acc2 = vmlaldavaq(acc2, vecA, vecB);
vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 8;
acc3 = vmlaldavaq(acc3, vecA, vecB);
blkCnt--;
vecA2 = vld1q(pSrcA2Vec);
pSrcA2Vec += 8;
acc1 = vmlaldavaq(acc1, vecA2, vecB);
vecB2 = vld1q(pSrcB2Vec);
pSrcB2Vec += 8;
acc2 = vmlaldavaq(acc2, vecA, vecB2);
vecA = vld1q(pSrcAVec);
pSrcAVec += 8;
acc3 = vmlaldavaq(acc3, vecA2, vecB2);
blkCnt--;
}
/*
* tail
*/
blkCnt = numColsA & 7;
if (blkCnt > 0U)
{
if (blkCnt > 0U) {
mve_pred16_t p0 = vctp16q(blkCnt);
q15x8_t vecB, vecA;
vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs);
vecOffs = vecOffs + (uint16_t) (numColsB * 8);
vecA = vld1q(pSrcA0Vec);
vecB = vld1q(pSrcBVec);
acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
vecA = vld1q(pSrcA1Vec);
acc1 = vmlaldavaq_p(acc1, vecA, vecB, p0);
vecA = vld1q(pSrcA2Vec);
acc2 = vmlaldavaq_p(acc2, vecA, vecB, p0);
vecA = vld1q(pSrcA3Vec);
acc3 = vmlaldavaq_p(acc3, vecA, vecB, p0);
vecA2 = vld1q(pSrcA2Vec);
acc1 = vmlaldavaq_p(acc1, vecA2, vecB, p0);
vecB2 = vld1q(pSrcB2Vec);
acc2 = vmlaldavaq_p(acc2, vecA, vecB2, p0);
vecA = vld1q(pSrcAVec);
acc3 = vmlaldavaq_p(acc3, vecA2, vecB2, p0);
}
px[0] = (q15_t)MVE_ASRL_SAT16(acc0, 15);
px[1 * numColsB] = (q15_t)MVE_ASRL_SAT16(acc1, 15);
px[2 * numColsB] = (q15_t)MVE_ASRL_SAT16(acc2, 15);
px[3 * numColsB] = (q15_t)MVE_ASRL_SAT16(acc3, 15);
px++;
*px++ = (q15_t) MVE_ASRL_SAT16(acc0, 15);
*px++ = (q15_t) MVE_ASRL_SAT16(acc2, 15);
*px2++ = (q15_t) MVE_ASRL_SAT16(acc1, 15);
*px2++ = (q15_t) MVE_ASRL_SAT16(acc3, 15);
j += numRowsB * 2;
/*
* Decrement the column loop counter
*/
col--;
}
i = i + numColsA * 2;
px = px2 + (numColsB & 1u);
px2 = px + numColsB;
/*
* Update the pointer pInB to point to the starting address of the next column
* Decrement the row loop counter
*/
pInB = pSrcB->pData + (numColsB - col);
row--;
}
/*
* Update the pointer pInA to point to the starting address of the next row
* Compute remaining row and/or column below
*/
pInA += (numColsA * 4);
if (numColsB & 1u) {
row = numRowsA & (~0x1); //avoid redundant computation
px = pDst->pData + numColsB - 1;
i = 0;
/*
* Decrement the row loop counter
* row loop
*/
rowCnt --;
while (row > 0) {
q15_t const *pSrcAVec, *pSrcBVec;
q15x8_t vecA, vecB;
q63_t acc0;
}
/*
* point to last column in matrix B
*/
pInB = pSrcBT + numRowsB * (numColsB - 1);
pInA = pSrcA->pData + i;
rowCnt = row & 3;
while (rowCnt > 0U)
{
pSrcAVec = (q15_t const *) pInA;
pSrcBVec = (q15_t const *) pInB;
acc0 = 0LL;
blkCnt = (numColsA) / 8;
while (blkCnt > 0U) {
vecA = vld1q(pSrcAVec);
pSrcAVec += 8;
vecB = vld1q(pSrcBVec);
pSrcBVec += 8;
acc0 = vmlaldavaq(acc0, vecA, vecB);
blkCnt--;
}
/*
* Output pointer is set to starting address of the row being processed
* tail
*/
px = pOut + i;
i = i + numColsB;
blkCnt = (numColsA & 7);
if (blkCnt > 0U) {
mve_pred16_t p0 = vctp16q(blkCnt);
vecA = vld1q(pSrcAVec);
vecB = vld1q(pSrcBVec);
acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
}
*px = (q15_t) MVE_ASRL_SAT16(acc0, 15);
px += numColsB;
i += numColsA;
/*
* For every row wise process, the column loop counter is to be initiated
* Decrement the row loop counter
*/
row--;
}
}
if (numRowsA & 1u) {
col = numColsB;
i = 0u;
/*
* For every row wise process, the pInB pointer is set
* to the starting address of the pSrcB data
* point to last row in output matrix
*/
pInB = pSrcB->pData;
px = pDst->pData + (numColsB) * (numRowsA - 1);
/*
* column loop
* col loop
*/
while (col > 0U)
{
while (col > 0) {
q15_t const *pSrcAVec, *pSrcBVec;
q15x8_t vecA, vecB;
q63_t acc0;
/*
* generate 4 columns elements
* point to last row in matrix A
*/
pInA = pSrcA->pData + (numRowsA - 1) * numColsA;
pInB = pSrcBT + i;
/*
* Matrix A columns number of MAC operations are to be performed
* Set the variable sum, that acts as accumulator, to zero
*/
q15_t const *pSrcA0Vec;
q15_t *pInA0 = pInA;
q63_t acc0;
pSrcAVec = (q15_t const *) pInA;
pSrcBVec = (q15_t const *) pInB;
acc0 = 0LL;
pSrcA0Vec = (q15_t const *) pInA0;
vecOffs = vecColBOffs;
blkCnt = (numColsA) >> 3;
while (blkCnt > 0U)
{
q15x8_t vecB, vecA;
vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs);
vecOffs = vecOffs + (uint16_t) (numColsB * 8);
vecA = vld1q(pSrcA0Vec);
pSrcA0Vec += 8;
blkCnt = ((numColsA) / 8);
while (blkCnt > 0U) {
vecA = vld1q(pSrcAVec);
pSrcAVec += 8;
vecB = vld1q(pSrcBVec);
pSrcBVec += 8;
acc0 = vmlaldavaq(acc0, vecA, vecB);
blkCnt--;
}
/*
* tail
*/
blkCnt = numColsA & 7;
if (blkCnt > 0U)
{
blkCnt = (numColsA & 7);
if (blkCnt > 0U) {
mve_pred16_t p0 = vctp16q(blkCnt);
q15x8_t vecB, vecA;
vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs);
vecOffs = vecOffs + (uint16_t) (numColsB * 8);
vecA = vld1q(pSrcA0Vec);
vecA = vld1q(pSrcAVec);
vecB = vld1q(pSrcBVec);
acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
}
px[0] = (q15_t)MVE_ASRL_SAT16(acc0, 15);
*px++ = (q15_t) MVE_ASRL_SAT16(acc0, 15);
i += numColsA;
px++;
/*
* Decrement the column loop counter
* Decrement the col loop counter
*/
col--;
/*
* Update the pointer pInB to point to the starting address of the next column
*/
pInB = pSrcB->pData + (numColsB - col);
}
/*
* Update the pointer pInA to point to the starting address of the next row
*/
pInA += (numColsA );
rowCnt--;
}
/* Set status as ARM_MATH_SUCCESS */
status = ARM_MATH_SUCCESS;
}
/* Return to application */
return (status);
}
#else
arm_status arm_mat_mult_q15(
const arm_matrix_instance_q15 * pSrcA,

@ -3,8 +3,8 @@
* Title: arm_mat_mult_q31.c
* Description: Q31 matrix multiplication
*
* $Date: 23 April 2021
* $Revision: V1.9.0
* $Date: 3 Nov 2021
* $Revision: V1.10.0
*
* Target Processor: Cortex-M and Cortex-A cores
* -------------------------------------------------------------------- */
@ -332,44 +332,45 @@ __STATIC_INLINE arm_status arm_mat_mult_q31_4x4_mve(
return (ARM_MATH_SUCCESS);
}
arm_status arm_mat_mult_q31(
const arm_matrix_instance_q31 * pSrcA,
const arm_matrix_instance_q31 * pSrcB,
arm_matrix_instance_q31 * pDst)
{
q31_t const *pInB = (q31_t const *)pSrcB->pData; /* input data matrix pointer B */
q31_t const *pInA = (q31_t const *)pSrcA->pData; /* input data matrix pointer A */
q31_t *pOut = pDst->pData; /* output data matrix pointer */
q31_t *pInA = pSrcA->pData; /* input data matrix pointer A */
q31_t *pInB = pSrcB->pData; /* input data matrix pointer B */
q31_t *pInA2;
q31_t *pInB2;
q31_t *px; /* Temporary output data matrix pointer */
uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
uint16_t col, i = 0U, row = numRowsA; /* loop counters */
arm_status status; /* status of matrix multiplication */
uint32x4_t vecOffs, vecColBOffs;
uint32_t blkCnt, rowCnt; /* loop counters */
#ifdef ARM_MATH_MATRIX_CHECK
q31_t *px2; /* Temporary output data matrix pointer */
uint32_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
uint32_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
uint32_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
uint32_t numRowsB = pSrcB->numRows; /* number of rows of input matrix A */
uint32_t col, i = 0u, j, row = numRowsB; /* loop counters */
q31_t State[numRowsB * numColsB * 1];
q31_t *pSrcBT = State; /* input data matrix pointer for transpose */
uint32_t blkCnt; /* loop counters */
arm_status status; /* Status of matrix multiplication */
arm_matrix_instance_q31 BT;
#ifdef ARM_MATH_MATRIX_CHECK
/* Check for matrix mismatch condition */
if ((pSrcA->numCols != pSrcB->numRows) ||
(pSrcA->numRows != pDst->numRows) ||
(pSrcB->numCols != pDst->numCols) )
{
(pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols)) {
/* Set status as ARM_MATH_SIZE_MISMATCH */
status = ARM_MATH_SIZE_MISMATCH;
}
else
} else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */
{
/* small squared matrix specialized routines */
if(numRowsA == numColsB && numColsB == numColsA) {
if (numRowsA == 1)
{
q63_t sum = (q63_t) *pInA * *pInB;
pOut[0] = (q31_t)(sum >> 31);
pDst->pData[0] = (q31_t)(sum >> 31);
return (ARM_MATH_SUCCESS);
}
else if(numRowsA == 2)
@ -380,245 +381,262 @@ arm_status arm_mat_mult_q31(
return arm_mat_mult_q31_4x4_mve(pSrcA, pSrcB, pDst);
}
vecColBOffs = vidupq_u32((uint32_t)0, 1);
vecColBOffs = vecColBOffs * (uint32_t) (numColsB);
/*
* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
* Matrix transpose
*/
BT.numRows = numColsB;
BT.numCols = numRowsB;
BT.pData = pSrcBT;
arm_mat_trans_q31(pSrcB, &BT);
/*
* row loop
* Reset the variables for the usage in the following multiplication process
*/
rowCnt = row >> 2;
while (rowCnt > 0U)
{
i = 0;
row = numRowsA >> 1;
px = pDst->pData;
px2 = px + numColsB;
/*
* Output pointer is set to starting address of the row being processed
* main loop
* compute 2 x 2 output blocks
* with dot products (Matrix A rows * Transposed MAtrix B rows)
*/
px = pOut + i;
i = i + 4 * numColsB;
while (row > 0u) {
/*
* For every row wise process, the column loop counter is to be initiated
* Compute 2 columns and 2 rows in parrallel
*/
col = numColsB;
/*
* For every row wise process, the pInB pointer is set
* to the starting address of the pSrcB data
*/
pInB = (q31_t const *)pSrcB->pData;
/*
* column loop
*/
while (col > 0U)
{
col = numColsB >> 1;
j = 0;
/*
* generate 4 columns elements
* column pair loop
*/
while (col > 0u) {
q31_t const *pSrcAVec, *pSrcBVec, *pSrcA2Vec, *pSrcB2Vec;
q31x4_t vecA, vecA2, vecB, vecB2;
q63_t acc0, acc1, acc2, acc3;
/*
* Matrix A columns number of MAC operations are to be performed
* Initiate the pointers
* - 2 x consecutive Matrix A rows (i increment is 2 x numColsA)
* - 2 x consecutive Matrix B' rows (j increment is 2 x numRowsB)
*/
pInA = pSrcA->pData + i;
pInA2 = pInA + numColsA;
pInB = pSrcBT + j;
pInB2 = pInB + numRowsB;
q31_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
q31_t const *pInA0 = pInA;
q31_t const *pInA1 = pInA0 + numColsA;
q31_t const *pInA2 = pInA1 + numColsA;
q31_t const *pInA3 = pInA2 + numColsA;
q63_t acc0, acc1, acc2, acc3;
pSrcAVec = (q31_t const *) pInA;
pSrcA2Vec = (q31_t const *) pInA2;
pSrcBVec = (q31_t const *) pInB;
pSrcB2Vec = (q31_t const *) pInB2;
acc0 = 0LL;
acc1 = 0LL;
acc2 = 0LL;
acc3 = 0LL;
pSrcA0Vec = (q31_t const *) pInA0;
pSrcA1Vec = (q31_t const *) pInA1;
pSrcA2Vec = (q31_t const *) pInA2;
pSrcA3Vec = (q31_t const *) pInA3;
vecOffs = vecColBOffs;
/* load scheduling */
vecA = vld1q(pSrcAVec);
pSrcAVec += 4;
/* process 1 x 4 block output */
blkCnt = numColsA >> 2;
while (blkCnt > 0U)
{
q31x4_t vecB, vecA;
vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
/* move Matrix B read offsets, 4 rows down */
vecOffs = vecOffs + (uint32_t) (numColsB * 4);
vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 4;
blkCnt = (numColsA / 4);
while (blkCnt > 0U) {
vecB = vld1q(pSrcBVec);
pSrcBVec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB);
vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 4;
acc1 = vrmlaldavhaq(acc1, vecA, vecB);
vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 4;
acc2 = vrmlaldavhaq(acc2, vecA, vecB);
vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 4;
acc3 = vrmlaldavhaq(acc3, vecA, vecB);
vecA2 = vld1q(pSrcA2Vec);
pSrcA2Vec += 4;
acc1 = vrmlaldavhaq(acc1, vecA2, vecB);
vecB2 = vld1q(pSrcB2Vec);
pSrcB2Vec += 4;
acc2 = vrmlaldavhaq(acc2, vecA, vecB2);
vecA = vld1q(pSrcAVec);
pSrcAVec += 4;
acc3 = vrmlaldavhaq(acc3, vecA2, vecB2);
blkCnt--;
}
/*
* tail
* (will be merged thru tail predication)
*/
blkCnt = numColsA & 3;
if (blkCnt > 0U)
{
blkCnt = (numColsA & 3);
if (blkCnt > 0U) {
mve_pred16_t p0 = vctp32q(blkCnt);
q31x4_t vecB, vecA;
vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0);
//vecOffs = vecOffs + (uint32_t) (numColsB * 4);
vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB);
vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 4;
acc1 = vrmlaldavhaq(acc1, vecA, vecB);
vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 4;
acc2 = vrmlaldavhaq(acc2, vecA, vecB);
vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 4;
acc3 = vrmlaldavhaq(acc3, vecA, vecB);
vecB = vld1q(pSrcBVec);
acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
vecA2 = vld1q(pSrcA2Vec);
acc1 = vrmlaldavhaq_p(acc1, vecA2, vecB, p0);
vecB2 = vld1q(pSrcB2Vec);
acc2 = vrmlaldavhaq_p(acc2, vecA, vecB2, p0);
vecA = vld1q(pSrcAVec);
acc3 = vrmlaldavhaq_p(acc3, vecA2, vecB2, p0);
}
/* Convert to 1.31 */
acc0 = asrl(acc0, 23);
acc1 = asrl(acc1, 23);
acc2 = asrl(acc2, 23);
acc3 = asrl(acc3, 23);
px[0] = (q31_t) acc0;
px[1 * numColsB] = (q31_t) acc1;
px[2 * numColsB] = (q31_t) acc2;
px[3 * numColsB] = (q31_t) acc3;
px++;
/* Store the results (2 x 2 block) in the destination buffer */
*px++ = (q31_t) acc0;
*px++ = (q31_t) acc2;
*px2++ = (q31_t) acc1;
*px2++ = (q31_t) acc3;
j += numRowsB * 2;
/*
* Decrement the column loop counter
* Decrement the column pair loop counter
*/
col--;
}
i = i + numColsA * 2;
px = px2 + (numColsB & 1u);
px2 = px + numColsB;
/*
* Update the pointer pInB to point to the starting address of the next column
* Decrement the row pair loop counter
*/
pInB = (q31_t const *)pSrcB->pData + (numColsB - col);
row--;
}
/*
* Update the pointer pInA to point to the starting address of the next row
* Compute remaining row and/or column below
*/
pInA += (numColsA * 4);
if (numColsB & 1u) {
row = numRowsA & (~0x1); //avoid redundant computation
px = pDst->pData + numColsB - 1;
i = 0;
/*
* Decrement the row loop counter
* row loop
*/
rowCnt --;
while (row > 0) {
q31_t const *pSrcAVec, *pSrcBVec;
q31x4_t vecA, vecB;
q63_t acc0;
/*
* point to last column in matrix B
*/
pInB = pSrcBT + numRowsB * (numColsB - 1);
pInA = pSrcA->pData + i;
pSrcAVec = (q31_t const *) pInA;
pSrcBVec = (q31_t const *) pInB;
/* single dot-product */
acc0 = 0LL;
blkCnt = (numColsA / 4);
while (blkCnt > 0U) {
vecA = vld1q(pSrcAVec);
pSrcAVec += 4;
vecB = vld1q(pSrcBVec);
pSrcBVec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB);
blkCnt--;
}
rowCnt = row & 3;
while (rowCnt > 0U)
{
/*
* Output pointer is set to starting address of the row being processed
* tail
* (will be merged thru tail predication)
*/
px = pOut + i;
i = i + numColsB;
blkCnt = (numColsA & 3);
if (blkCnt > 0U) {
mve_pred16_t p0 = vctp32q(blkCnt);
vecA = vld1q(pSrcAVec);
vecB = vld1q(pSrcBVec);
acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
}
acc0 = asrl(acc0, 23);
*px = (q31_t) acc0;
px += numColsB;
i += numColsA;
/*
* For every row wise process, the column loop counter is to be initiated
* Decrement the row loop counter
*/
row--;
}
}
if (numRowsA & 1u) {
col = numColsB;
i = 0u;
/*
* For every row wise process, the pInB pointer is set
* to the starting address of the pSrcB data
* point to last row in output matrix
*/
pInB = (q31_t const *)pSrcB->pData;
px = pDst->pData + (numColsB) * (numRowsA - 1);
/*
* column loop
* col loop
*/
while (col > 0U)
{
while (col > 0) {
q31_t const *pSrcAVec, *pSrcBVec;
q31x4_t vecA, vecB;
q63_t acc0;
/*
* generate 4 columns elements
* point to last row in matrix A
*/
pInA = pSrcA->pData + (numRowsA - 1) * numColsA;
pInB = pSrcBT + i;
/*
* Matrix A columns number of MAC operations are to be performed
* Set the variable sum, that acts as accumulator, to zero
*/
q31_t const *pSrcA0Vec;
q31_t const *pInA0 = pInA;
q63_t acc0;
pSrcAVec = (q31_t const *) pInA;
pSrcBVec = (q31_t const *) pInB;
acc0 = 0LL;
pSrcA0Vec = (q31_t const *) pInA0;
vecOffs = vecColBOffs;
/* process 1 x 4 block output */
blkCnt = numColsA >> 2;
while (blkCnt > 0U)
{
q31x4_t vecB, vecA;
vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
/* move Matrix B read offsets, 4 rows down */
vecOffs = vecOffs + (uint32_t) (numColsB * 4);
vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 4;
blkCnt = (numColsA / 4);
while (blkCnt > 0U) {
vecA = vld1q(pSrcAVec);
pSrcAVec += 4;
vecB = vld1q(pSrcBVec);
pSrcBVec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB);
blkCnt--;
}
/*
* tail
* (will be merged thru tail predication)
*/
blkCnt = numColsA & 3;
if (blkCnt > 0U)
{
blkCnt = (numColsA & 3);
if (blkCnt > 0U) {
mve_pred16_t p0 = vctp32q(blkCnt);
q31x4_t vecB, vecA;
vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0);
//vecOffs = vecOffs + (uint32_t) (numColsB * 4);
vecA = vld1q(pSrcA0Vec);
pSrcA0Vec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB);
vecA = vld1q(pSrcAVec);
vecB = vld1q(pSrcBVec);
acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
}
acc0 = asrl(acc0, 23);
*px++ = (q31_t) acc0;
px[0] = (q31_t) acc0;
px++;
i += numColsA;
/*
* Decrement the column loop counter
* Decrement the col loop counter
*/
col--;
/*
* Update the pointer pInB to point to the starting address of the next column
*/
pInB = (q31_t const *)pSrcB->pData + (numColsB - col);
}
/*
* Update the pointer pInA to point to the starting address of the next row
*/
pInA += numColsA;
/*
* Decrement the row loop counter
*/
rowCnt--;
}
/*
* set status as ARM_MATH_SUCCESS
*/
/* Set status as ARM_MATH_SUCCESS */
status = ARM_MATH_SUCCESS;
}
/* Return to application */
/*
* Return to application
*/
return (status);
}

Loading…
Cancel
Save