CMSIS-DSP : faster Q.15/Q.31 Helium matrix multiplications. Uses an initial transpose stage, requiring extra scratch space to hold RHS transposed matrix.

pull/19/head
FabKlein 4 years ago committed by Christophe Favergeon
parent c520fb08f4
commit cfc30c12b8

@ -3,8 +3,8 @@
* Title: arm_mat_mult_q15.c * Title: arm_mat_mult_q15.c
* Description: Q15 matrix multiplication * Description: Q15 matrix multiplication
* *
* $Date: 23 April 2021 * $Date: 3 Nov 2021
* $Revision: V1.9.0 * $Revision: V1.10.0
* *
* Target Processor: Cortex-M and Cortex-A cores * Target Processor: Cortex-M and Cortex-A cores
* -------------------------------------------------------------------- */ * -------------------------------------------------------------------- */
@ -315,279 +315,308 @@ __STATIC_INLINE arm_status arm_mat_mult_q15_4x4_mve(
return (ARM_MATH_SUCCESS); return (ARM_MATH_SUCCESS);
} }
arm_status arm_mat_mult_q15( arm_status arm_mat_mult_q15(
const arm_matrix_instance_q15 * pSrcA, const arm_matrix_instance_q15 * pSrcA,
const arm_matrix_instance_q15 * pSrcB, const arm_matrix_instance_q15 * pSrcB,
arm_matrix_instance_q15 * pDst, arm_matrix_instance_q15 * pDst,
q15_t * pState) q15_t * pState)
{ {
q15_t *pInB = pSrcB->pData; /* input data matrix pointer B */ q15_t *pInA = pSrcA->pData; /* input data matrix pointer A */
q15_t *pInA = pSrcA->pData; /* input data matrix pointer A */ q15_t *pInB = pSrcB->pData; /* input data matrix pointer B */
q15_t *pOut = pDst->pData; /* output data matrix pointer */ q15_t *pInA2;
q15_t *px; /* Temporary output data matrix pointer */ q15_t *pInB2;
uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */ q15_t *px; /* Temporary output data matrix pointer */
uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */ q15_t *px2; /* Temporary output data matrix pointer */
uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */ uint32_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
uint16_t col, i = 0U, row = numRowsA; /* loop counters */ uint32_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
uint16x8_t vecOffs, vecColBOffs; uint32_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
uint32_t blkCnt,rowCnt; /* loop counters */ uint32_t numRowsB = pSrcB->numRows; /* number of rows of input matrix A */
arm_status status; /* Status of matrix multiplication */ uint32_t col, i = 0u, j, row = numRowsB; /* loop counters */
(void)pState; q15_t *pSrcBT = pState; /* input data matrix pointer for transpose */
uint32_t blkCnt; /* loop counters */
arm_status status; /* Status of matrix multiplication */
arm_matrix_instance_q15 BT;
#ifdef ARM_MATH_MATRIX_CHECK #ifdef ARM_MATH_MATRIX_CHECK
/* Check for matrix mismatch condition */ /* Check for matrix mismatch condition */
if ((pSrcA->numCols != pSrcB->numRows) || if ((pSrcA->numCols != pSrcB->numRows) ||
(pSrcA->numRows != pDst->numRows) || (pSrcA->numRows != pDst->numRows) ||
(pSrcB->numCols != pDst->numCols) ) (pSrcB->numCols != pDst->numCols) )
{ {
/* Set status as ARM_MATH_SIZE_MISMATCH */ /* Set status as ARM_MATH_SIZE_MISMATCH */
status = ARM_MATH_SIZE_MISMATCH; status = ARM_MATH_SIZE_MISMATCH;
} }
else else
#endif #endif
{ {
/* small squared matrix specialized routines */ /* small squared matrix specialized routines */
if(numRowsA == numColsB && numColsB == numColsA) { if (numRowsA == numColsB && numColsB == numColsA) {
if (numRowsA == 1) if (numRowsA == 1) {
{ q63_t sum;
q63_t sum; sum = pInA[0] * pInB[0];
sum = pInA[0] * pInB[0]; pDst->pData[0] = (q15_t) __SSAT((sum >> 15), 16);
pOut[0] = (q15_t) __SSAT((sum >> 15), 16); return (ARM_MATH_SUCCESS);
return (ARM_MATH_SUCCESS); } else if (numRowsA == 2)
return arm_mat_mult_q15_2x2_mve(pSrcA, pSrcB, pDst);
else if (numRowsA == 3)
return arm_mat_mult_q15_3x3_mve(pSrcA, pSrcB, pDst);
else if (numRowsA == 4)
return arm_mat_mult_q15_4x4_mve(pSrcA, pSrcB, pDst);
} }
else if(numRowsA == 2)
return arm_mat_mult_q15_2x2_mve(pSrcA, pSrcB, pDst);
else if(numRowsA == 3)
return arm_mat_mult_q15_3x3_mve(pSrcA, pSrcB, pDst);
else if (numRowsA == 4)
return arm_mat_mult_q15_4x4_mve(pSrcA, pSrcB, pDst);
}
vecColBOffs = vidupq_u16((uint32_t)0, 1);
vecColBOffs = vecColBOffs * (uint16_t) (numColsB);
/*
* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
*/
/*
* row loop
*/
rowCnt = row >> 2;
while (rowCnt > 0U)
{
/* /*
* Output pointer is set to starting address of the row being processed * Matrix transpose
*/ */
px = pOut + i;
i = i + 4 * numColsB; BT.numRows = numColsB;
BT.numCols = numRowsB;
BT.pData = pSrcBT;
arm_mat_trans_q15(pSrcB, &BT);
/* /*
* For every row wise process, the column loop counter is to be initiated * Reset the variables for the usage in the following multiplication process
*/ */
col = numColsB; i = 0;
row = numRowsA >> 1;
px = pDst->pData;
px2 = px + numColsB;
/* /*
* For every row wise process, the pInB pointer is set * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
* to the starting address of the pSrcB data
*/ */
pInB = pSrcB->pData;
/* /*
* column loop * row loop
*/ */
while (col > 0U) while (row > 0u) {
{
/* /*
* generate 4 columns elements * For every row wise process, the column loop counter is to be initiated
*/ */
col = numColsB >> 1;
/* /*
* Matrix A columns number of MAC operations are to be performed * For every row wise process, the pIn2 pointer is set
* to the starting address of the transposed pSrcB data
*/ */
pInB = pSrcBT;
pInB2 = pInB + numRowsB;
j = 0;
q15_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
q15_t *pInA0 = pInA;
q15_t *pInA1 = pInA0 + numColsA;
q15_t *pInA2 = pInA1 + numColsA;
q15_t *pInA3 = pInA2 + numColsA;
q63_t acc0, acc1, acc2, acc3;
acc0 = 0LL;
acc1 = 0LL;
acc2 = 0LL;
acc3 = 0LL;
pSrcA0Vec = (q15_t const *) pInA0;
pSrcA1Vec = (q15_t const *) pInA1;
pSrcA2Vec = (q15_t const *) pInA2;
pSrcA3Vec = (q15_t const *) pInA3;
vecOffs = vecColBOffs;
blkCnt = (numColsA) >> 3;
while (blkCnt > 0U)
{
q15x8_t vecB, vecA;
vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs);
vecOffs = vecOffs + (uint16_t) (numColsB * 8);
vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 8;
acc0 = vmlaldavaq(acc0, vecA, vecB);
vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 8;
acc1 = vmlaldavaq(acc1, vecA, vecB);
vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 8;
acc2 = vmlaldavaq(acc2, vecA, vecB);
vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 8;
acc3 = vmlaldavaq(acc3, vecA, vecB);
blkCnt--;
}
/* /*
* tail * column loop
*/ */
blkCnt = numColsA & 7; while (col > 0u) {
if (blkCnt > 0U) q15_t const *pSrcAVec, *pSrcBVec, *pSrcA2Vec, *pSrcB2Vec;
{ q15x8_t vecA, vecA2, vecB, vecB2;
mve_pred16_t p0 = vctp16q(blkCnt); q63_t acc0, acc1, acc2, acc3;
q15x8_t vecB, vecA;
/*
vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs); * Initiate the pointer pIn1 to point to the starting address of the column being processed
vecOffs = vecOffs + (uint16_t) (numColsB * 8); */
pInA = pSrcA->pData + i;
vecA = vld1q(pSrcA0Vec); pInA2 = pInA + numColsA;
acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0); pInB = pSrcBT + j;
vecA = vld1q(pSrcA1Vec); pInB2 = pInB + numRowsB;
acc1 = vmlaldavaq_p(acc1, vecA, vecB, p0);
vecA = vld1q(pSrcA2Vec);
acc2 = vmlaldavaq_p(acc2, vecA, vecB, p0); pSrcAVec = (q15_t const *) pInA;
vecA = vld1q(pSrcA3Vec); pSrcA2Vec = (q15_t const *) pInA2;
acc3 = vmlaldavaq_p(acc3, vecA, vecB, p0); pSrcBVec = (q15_t const *) pInB;
pSrcB2Vec = (q15_t const *) pInB2;
acc0 = 0LL;
acc1 = 0LL;
acc2 = 0LL;
acc3 = 0LL;
vecA = vld1q(pSrcAVec);
pSrcAVec += 8;
blkCnt = numColsA / 8;
while (blkCnt > 0U) {
vecB = vld1q(pSrcBVec);
pSrcBVec += 8;
acc0 = vmlaldavaq(acc0, vecA, vecB);
vecA2 = vld1q(pSrcA2Vec);
pSrcA2Vec += 8;
acc1 = vmlaldavaq(acc1, vecA2, vecB);
vecB2 = vld1q(pSrcB2Vec);
pSrcB2Vec += 8;
acc2 = vmlaldavaq(acc2, vecA, vecB2);
vecA = vld1q(pSrcAVec);
pSrcAVec += 8;
acc3 = vmlaldavaq(acc3, vecA2, vecB2);
blkCnt--;
}
/*
* tail
*/
blkCnt = numColsA & 7;
if (blkCnt > 0U) {
mve_pred16_t p0 = vctp16q(blkCnt);
vecB = vld1q(pSrcBVec);
acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
vecA2 = vld1q(pSrcA2Vec);
acc1 = vmlaldavaq_p(acc1, vecA2, vecB, p0);
vecB2 = vld1q(pSrcB2Vec);
acc2 = vmlaldavaq_p(acc2, vecA, vecB2, p0);
vecA = vld1q(pSrcAVec);
acc3 = vmlaldavaq_p(acc3, vecA2, vecB2, p0);
}
*px++ = (q15_t) MVE_ASRL_SAT16(acc0, 15);
*px++ = (q15_t) MVE_ASRL_SAT16(acc2, 15);
*px2++ = (q15_t) MVE_ASRL_SAT16(acc1, 15);
*px2++ = (q15_t) MVE_ASRL_SAT16(acc3, 15);
j += numRowsB * 2;
/*
* Decrement the column loop counter
*/
col--;
} }
px[0] = (q15_t)MVE_ASRL_SAT16(acc0, 15); i = i + numColsA * 2;
px[1 * numColsB] = (q15_t)MVE_ASRL_SAT16(acc1, 15); px = px2 + (numColsB & 1u);
px[2 * numColsB] = (q15_t)MVE_ASRL_SAT16(acc2, 15); px2 = px + numColsB;
px[3 * numColsB] = (q15_t)MVE_ASRL_SAT16(acc3, 15);
px++;
/* /*
* Decrement the column loop counter * Decrement the row loop counter
*/ */
col--; row--;
/*
* Update the pointer pInB to point to the starting address of the next column
*/
pInB = pSrcB->pData + (numColsB - col);
} }
/* /*
* Update the pointer pInA to point to the starting address of the next row * Compute remaining row and/or column below
*/ */
pInA += (numColsA * 4);
/*
* Decrement the row loop counter
*/
rowCnt --;
} if (numColsB & 1u) {
row = numRowsA & (~0x1); //avoid redundant computation
px = pDst->pData + numColsB - 1;
i = 0;
rowCnt = row & 3;
while (rowCnt > 0U)
{
/*
* Output pointer is set to starting address of the row being processed
*/
px = pOut + i;
i = i + numColsB;
/*
* For every row wise process, the column loop counter is to be initiated
*/
col = numColsB;
/*
* For every row wise process, the pInB pointer is set
* to the starting address of the pSrcB data
*/
pInB = pSrcB->pData;
/*
* column loop
*/
while (col > 0U)
{
/* /*
* generate 4 columns elements * row loop
*/ */
/* while (row > 0) {
* Matrix A columns number of MAC operations are to be performed q15_t const *pSrcAVec, *pSrcBVec;
*/ q15x8_t vecA, vecB;
q63_t acc0;
q15_t const *pSrcA0Vec;
q15_t *pInA0 = pInA; /*
q63_t acc0; * point to last column in matrix B
*/
acc0 = 0LL; pInB = pSrcBT + numRowsB * (numColsB - 1);
pInA = pSrcA->pData + i;
pSrcA0Vec = (q15_t const *) pInA0;
pSrcAVec = (q15_t const *) pInA;
vecOffs = vecColBOffs; pSrcBVec = (q15_t const *) pInB;
blkCnt = (numColsA) >> 3; acc0 = 0LL;
while (blkCnt > 0U) blkCnt = (numColsA) / 8;
{ while (blkCnt > 0U) {
q15x8_t vecB, vecA; vecA = vld1q(pSrcAVec);
pSrcAVec += 8;
vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs); vecB = vld1q(pSrcBVec);
vecOffs = vecOffs + (uint16_t) (numColsB * 8); pSrcBVec += 8;
acc0 = vmlaldavaq(acc0, vecA, vecB);
vecA = vld1q(pSrcA0Vec);
pSrcA0Vec += 8; blkCnt--;
acc0 = vmlaldavaq(acc0, vecA, vecB); }
/*
blkCnt--; * tail
*/
} blkCnt = (numColsA & 7);
/* if (blkCnt > 0U) {
* tail mve_pred16_t p0 = vctp16q(blkCnt);
*/ vecA = vld1q(pSrcAVec);
blkCnt = numColsA & 7; vecB = vld1q(pSrcBVec);
if (blkCnt > 0U) acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
{ }
mve_pred16_t p0 = vctp16q(blkCnt);
q15x8_t vecB, vecA; *px = (q15_t) MVE_ASRL_SAT16(acc0, 15);
vecB = vldrhq_gather_shifted_offset((int16_t const *)pInB, vecOffs); px += numColsB;
vecOffs = vecOffs + (uint16_t) (numColsB * 8);
i += numColsA;
vecA = vld1q(pSrcA0Vec); /*
acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0); * Decrement the row loop counter
*/
row--;
} }
}
px[0] = (q15_t)MVE_ASRL_SAT16(acc0, 15); if (numRowsA & 1u) {
col = numColsB;
px++; i = 0u;
/* /*
* Decrement the column loop counter * point to last row in output matrix
*/ */
col--; px = pDst->pData + (numColsB) * (numRowsA - 1);
/* /*
* Update the pointer pInB to point to the starting address of the next column * col loop
*/ */
pInB = pSrcB->pData + (numColsB - col); while (col > 0) {
q15_t const *pSrcAVec, *pSrcBVec;
q15x8_t vecA, vecB;
q63_t acc0;
/*
* point to last row in matrix A
*/
pInA = pSrcA->pData + (numRowsA - 1) * numColsA;
pInB = pSrcBT + i;
/*
* Set the variable sum, that acts as accumulator, to zero
*/
pSrcAVec = (q15_t const *) pInA;
pSrcBVec = (q15_t const *) pInB;
acc0 = 0LL;
blkCnt = ((numColsA) / 8);
while (blkCnt > 0U) {
vecA = vld1q(pSrcAVec);
pSrcAVec += 8;
vecB = vld1q(pSrcBVec);
pSrcBVec += 8;
acc0 = vmlaldavaq(acc0, vecA, vecB);
blkCnt--;
}
/*
* tail
*/
blkCnt = (numColsA & 7);
if (blkCnt > 0U) {
mve_pred16_t p0 = vctp16q(blkCnt);
vecA = vld1q(pSrcAVec);
vecB = vld1q(pSrcBVec);
acc0 = vmlaldavaq_p(acc0, vecA, vecB, p0);
}
*px++ = (q15_t) MVE_ASRL_SAT16(acc0, 15);
i += numColsA;
/*
* Decrement the col loop counter
*/
col--;
}
} }
/* /* Set status as ARM_MATH_SUCCESS */
* Update the pointer pInA to point to the starting address of the next row status = ARM_MATH_SUCCESS;
*/
pInA += (numColsA );
rowCnt--;
} }
/* Set status as ARM_MATH_SUCCESS */ /* Return to application */
status = ARM_MATH_SUCCESS; return (status);
}
/* Return to application */
return (status);
} }
#else #else
arm_status arm_mat_mult_q15( arm_status arm_mat_mult_q15(
const arm_matrix_instance_q15 * pSrcA, const arm_matrix_instance_q15 * pSrcA,

@ -3,8 +3,8 @@
* Title: arm_mat_mult_q31.c * Title: arm_mat_mult_q31.c
* Description: Q31 matrix multiplication * Description: Q31 matrix multiplication
* *
* $Date: 23 April 2021 * $Date: 3 Nov 2021
* $Revision: V1.9.0 * $Revision: V1.10.0
* *
* Target Processor: Cortex-M and Cortex-A cores * Target Processor: Cortex-M and Cortex-A cores
* -------------------------------------------------------------------- */ * -------------------------------------------------------------------- */
@ -332,44 +332,45 @@ __STATIC_INLINE arm_status arm_mat_mult_q31_4x4_mve(
return (ARM_MATH_SUCCESS); return (ARM_MATH_SUCCESS);
} }
arm_status arm_mat_mult_q31( arm_status arm_mat_mult_q31(
const arm_matrix_instance_q31 * pSrcA, const arm_matrix_instance_q31 * pSrcA,
const arm_matrix_instance_q31 * pSrcB, const arm_matrix_instance_q31 * pSrcB,
arm_matrix_instance_q31 * pDst) arm_matrix_instance_q31 * pDst)
{ {
q31_t const *pInB = (q31_t const *)pSrcB->pData; /* input data matrix pointer B */ q31_t *pInA = pSrcA->pData; /* input data matrix pointer A */
q31_t const *pInA = (q31_t const *)pSrcA->pData; /* input data matrix pointer A */ q31_t *pInB = pSrcB->pData; /* input data matrix pointer B */
q31_t *pOut = pDst->pData; /* output data matrix pointer */ q31_t *pInA2;
q31_t *px; /* Temporary output data matrix pointer */ q31_t *pInB2;
uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */ q31_t *px; /* Temporary output data matrix pointer */
uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */ q31_t *px2; /* Temporary output data matrix pointer */
uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */ uint32_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
uint16_t col, i = 0U, row = numRowsA; /* loop counters */ uint32_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
arm_status status; /* status of matrix multiplication */ uint32_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
uint32x4_t vecOffs, vecColBOffs; uint32_t numRowsB = pSrcB->numRows; /* number of rows of input matrix A */
uint32_t blkCnt, rowCnt; /* loop counters */ uint32_t col, i = 0u, j, row = numRowsB; /* loop counters */
q31_t State[numRowsB * numColsB * 1];
#ifdef ARM_MATH_MATRIX_CHECK q31_t *pSrcBT = State; /* input data matrix pointer for transpose */
uint32_t blkCnt; /* loop counters */
/* Check for matrix mismatch condition */ arm_status status; /* Status of matrix multiplication */
if ((pSrcA->numCols != pSrcB->numRows) || arm_matrix_instance_q31 BT;
(pSrcA->numRows != pDst->numRows) || #ifdef ARM_MATH_MATRIX_CHECK
(pSrcB->numCols != pDst->numCols) )
{
/* Set status as ARM_MATH_SIZE_MISMATCH */
status = ARM_MATH_SIZE_MISMATCH;
}
else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */ /* Check for matrix mismatch condition */
if ((pSrcA->numCols != pSrcB->numRows) ||
(pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols)) {
/* Set status as ARM_MATH_SIZE_MISMATCH */
status = ARM_MATH_SIZE_MISMATCH;
} else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */
{
{ /* small squared matrix specialized routines */
/* small squared matrix specialized routines */
if(numRowsA == numColsB && numColsB == numColsA) { if(numRowsA == numColsB && numColsB == numColsA) {
if (numRowsA == 1) if (numRowsA == 1)
{ {
q63_t sum = (q63_t) *pInA * *pInB; q63_t sum = (q63_t) *pInA * *pInB;
pOut[0] = (q31_t)(sum >> 31); pDst->pData[0] = (q31_t)(sum >> 31);
return (ARM_MATH_SUCCESS); return (ARM_MATH_SUCCESS);
} }
else if(numRowsA == 2) else if(numRowsA == 2)
@ -380,246 +381,263 @@ arm_status arm_mat_mult_q31(
return arm_mat_mult_q31_4x4_mve(pSrcA, pSrcB, pDst); return arm_mat_mult_q31_4x4_mve(pSrcA, pSrcB, pDst);
} }
vecColBOffs = vidupq_u32((uint32_t)0, 1);
vecColBOffs = vecColBOffs * (uint32_t) (numColsB);
/*
* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
*/
/*
* row loop
*/
rowCnt = row >> 2;
while (rowCnt > 0U)
{
/* /*
* Output pointer is set to starting address of the row being processed * Matrix transpose
*/ */
px = pOut + i; BT.numRows = numColsB;
i = i + 4 * numColsB; BT.numCols = numRowsB;
/* BT.pData = pSrcBT;
* For every row wise process, the column loop counter is to be initiated
*/ arm_mat_trans_q31(pSrcB, &BT);
col = numColsB;
/* /*
* For every row wise process, the pInB pointer is set * Reset the variables for the usage in the following multiplication process
* to the starting address of the pSrcB data
*/ */
pInB = (q31_t const *)pSrcB->pData; i = 0;
row = numRowsA >> 1;
px = pDst->pData;
px2 = px + numColsB;
/* /*
* column loop * main loop
* compute 2 x 2 output blocks
* with dot products (Matrix A rows * Transposed MAtrix B rows)
*/ */
while (col > 0U) while (row > 0u) {
{
/*
* generate 4 columns elements
*/
/* /*
* Matrix A columns number of MAC operations are to be performed * For every row wise process, the column loop counter is to be initiated
* Compute 2 columns and 2 rows in parrallel
*/ */
col = numColsB >> 1;
q31_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec; j = 0;
q31_t const *pInA0 = pInA;
q31_t const *pInA1 = pInA0 + numColsA;
q31_t const *pInA2 = pInA1 + numColsA;
q31_t const *pInA3 = pInA2 + numColsA;
q63_t acc0, acc1, acc2, acc3;
acc0 = 0LL;
acc1 = 0LL;
acc2 = 0LL;
acc3 = 0LL;
pSrcA0Vec = (q31_t const *) pInA0;
pSrcA1Vec = (q31_t const *) pInA1;
pSrcA2Vec = (q31_t const *) pInA2;
pSrcA3Vec = (q31_t const *) pInA3;
vecOffs = vecColBOffs;
/* process 1 x 4 block output */
blkCnt = numColsA >> 2;
while (blkCnt > 0U)
{
q31x4_t vecB, vecA;
vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
/* move Matrix B read offsets, 4 rows down */
vecOffs = vecOffs + (uint32_t) (numColsB * 4);
vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB);
vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 4;
acc1 = vrmlaldavhaq(acc1, vecA, vecB);
vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 4;
acc2 = vrmlaldavhaq(acc2, vecA, vecB);
vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 4;
acc3 = vrmlaldavhaq(acc3, vecA, vecB);
blkCnt--;
}
/* /*
* tail * column pair loop
* (will be merged thru tail predication)
*/ */
blkCnt = numColsA & 3; while (col > 0u) {
if (blkCnt > 0U) q31_t const *pSrcAVec, *pSrcBVec, *pSrcA2Vec, *pSrcB2Vec;
{ q31x4_t vecA, vecA2, vecB, vecB2;
mve_pred16_t p0 = vctp32q(blkCnt); q63_t acc0, acc1, acc2, acc3;
q31x4_t vecB, vecA;
/*
vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0); * Initiate the pointers
//vecOffs = vecOffs + (uint32_t) (numColsB * 4); * - 2 x consecutive Matrix A rows (i increment is 2 x numColsA)
* - 2 x consecutive Matrix B' rows (j increment is 2 x numRowsB)
vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 4; */
acc0 = vrmlaldavhaq(acc0, vecA, vecB); pInA = pSrcA->pData + i;
vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 4; pInA2 = pInA + numColsA;
acc1 = vrmlaldavhaq(acc1, vecA, vecB); pInB = pSrcBT + j;
vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 4; pInB2 = pInB + numRowsB;
acc2 = vrmlaldavhaq(acc2, vecA, vecB);
vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 4;
acc3 = vrmlaldavhaq(acc3, vecA, vecB); pSrcAVec = (q31_t const *) pInA;
} pSrcA2Vec = (q31_t const *) pInA2;
pSrcBVec = (q31_t const *) pInB;
pSrcB2Vec = (q31_t const *) pInB2;
acc0 = 0LL;
acc1 = 0LL;
acc2 = 0LL;
acc3 = 0LL;
/* load scheduling */
vecA = vld1q(pSrcAVec);
pSrcAVec += 4;
blkCnt = (numColsA / 4);
while (blkCnt > 0U) {
vecB = vld1q(pSrcBVec);
pSrcBVec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB);
vecA2 = vld1q(pSrcA2Vec);
pSrcA2Vec += 4;
acc1 = vrmlaldavhaq(acc1, vecA2, vecB);
vecB2 = vld1q(pSrcB2Vec);
pSrcB2Vec += 4;
acc2 = vrmlaldavhaq(acc2, vecA, vecB2);
vecA = vld1q(pSrcAVec);
pSrcAVec += 4;
acc3 = vrmlaldavhaq(acc3, vecA2, vecB2);
blkCnt--;
}
/*
* tail
* (will be merged thru tail predication)
*/
blkCnt = (numColsA & 3);
if (blkCnt > 0U) {
mve_pred16_t p0 = vctp32q(blkCnt);
vecB = vld1q(pSrcBVec);
acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
vecA2 = vld1q(pSrcA2Vec);
acc1 = vrmlaldavhaq_p(acc1, vecA2, vecB, p0);
vecB2 = vld1q(pSrcB2Vec);
acc2 = vrmlaldavhaq_p(acc2, vecA, vecB2, p0);
vecA = vld1q(pSrcAVec);
acc3 = vrmlaldavhaq_p(acc3, vecA2, vecB2, p0);
}
/* Convert to 1.31 */
acc0 = asrl(acc0, 23);
acc1 = asrl(acc1, 23);
acc2 = asrl(acc2, 23);
acc3 = asrl(acc3, 23);
/* Store the results (2 x 2 block) in the destination buffer */
*px++ = (q31_t) acc0;
*px++ = (q31_t) acc2;
*px2++ = (q31_t) acc1;
*px2++ = (q31_t) acc3;
j += numRowsB * 2;
/*
* Decrement the column pair loop counter
*/
col--;
acc0 = asrl(acc0, 23); }
acc1 = asrl(acc1, 23);
acc2 = asrl(acc2, 23);
acc3 = asrl(acc3, 23);
px[0] = (q31_t) acc0; i = i + numColsA * 2;
px[1 * numColsB] = (q31_t) acc1; px = px2 + (numColsB & 1u);
px[2 * numColsB] = (q31_t) acc2; px2 = px + numColsB;
px[3 * numColsB] = (q31_t) acc3;
px++;
/*
* Decrement the column loop counter
*/
col--;
/* /*
* Update the pointer pInB to point to the starting address of the next column * Decrement the row pair loop counter
*/ */
pInB = (q31_t const *)pSrcB->pData + (numColsB - col); row--;
} }
/* /*
* Update the pointer pInA to point to the starting address of the next row * Compute remaining row and/or column below
*/
pInA += (numColsA * 4);
/*
* Decrement the row loop counter
*/
rowCnt --;
}
rowCnt = row & 3;
while (rowCnt > 0U)
{
/*
* Output pointer is set to starting address of the row being processed
*/
px = pOut + i;
i = i + numColsB;
/*
* For every row wise process, the column loop counter is to be initiated
*/ */
col = numColsB; if (numColsB & 1u) {
/* row = numRowsA & (~0x1); //avoid redundant computation
* For every row wise process, the pInB pointer is set px = pDst->pData + numColsB - 1;
* to the starting address of the pSrcB data i = 0;
*/
pInB = (q31_t const *)pSrcB->pData;
/*
* column loop
*/
while (col > 0U)
{
/*
* generate 4 columns elements
*/
/*
* Matrix A columns number of MAC operations are to be performed
*/
q31_t const *pSrcA0Vec;
q31_t const *pInA0 = pInA;
q63_t acc0;
acc0 = 0LL;
pSrcA0Vec = (q31_t const *) pInA0;
vecOffs = vecColBOffs;
/* process 1 x 4 block output */
blkCnt = numColsA >> 2;
while (blkCnt > 0U)
{
q31x4_t vecB, vecA;
vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
/* move Matrix B read offsets, 4 rows down */
vecOffs = vecOffs + (uint32_t) (numColsB * 4);
vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB);
blkCnt--;
}
/* /*
* tail * row loop
* (will be merged thru tail predication)
*/ */
blkCnt = numColsA & 3; while (row > 0) {
if (blkCnt > 0U) q31_t const *pSrcAVec, *pSrcBVec;
{ q31x4_t vecA, vecB;
mve_pred16_t p0 = vctp32q(blkCnt); q63_t acc0;
q31x4_t vecB, vecA;
/*
vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0); * point to last column in matrix B
//vecOffs = vecOffs + (uint32_t) (numColsB * 4); */
pInB = pSrcBT + numRowsB * (numColsB - 1);
vecA = vld1q(pSrcA0Vec); pInA = pSrcA->pData + i;
pSrcA0Vec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB); pSrcAVec = (q31_t const *) pInA;
pSrcBVec = (q31_t const *) pInB;
/* single dot-product */
acc0 = 0LL;
blkCnt = (numColsA / 4);
while (blkCnt > 0U) {
vecA = vld1q(pSrcAVec);
pSrcAVec += 4;
vecB = vld1q(pSrcBVec);
pSrcBVec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB);
blkCnt--;
}
/*
* tail
* (will be merged thru tail predication)
*/
blkCnt = (numColsA & 3);
if (blkCnt > 0U) {
mve_pred16_t p0 = vctp32q(blkCnt);
vecA = vld1q(pSrcAVec);
vecB = vld1q(pSrcBVec);
acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
}
acc0 = asrl(acc0, 23);
*px = (q31_t) acc0;
px += numColsB;
i += numColsA;
/*
* Decrement the row loop counter
*/
row--;
} }
}
acc0 = asrl(acc0, 23); if (numRowsA & 1u) {
col = numColsB;
i = 0u;
px[0] = (q31_t) acc0;
px++;
/* /*
* Decrement the column loop counter * point to last row in output matrix
*/ */
col--; px = pDst->pData + (numColsB) * (numRowsA - 1);
/* /*
* Update the pointer pInB to point to the starting address of the next column * col loop
*/ */
pInB = (q31_t const *)pSrcB->pData + (numColsB - col); while (col > 0) {
q31_t const *pSrcAVec, *pSrcBVec;
q31x4_t vecA, vecB;
q63_t acc0;
/*
* point to last row in matrix A
*/
pInA = pSrcA->pData + (numRowsA - 1) * numColsA;
pInB = pSrcBT + i;
/*
* Set the variable sum, that acts as accumulator, to zero
*/
pSrcAVec = (q31_t const *) pInA;
pSrcBVec = (q31_t const *) pInB;
acc0 = 0LL;
blkCnt = (numColsA / 4);
while (blkCnt > 0U) {
vecA = vld1q(pSrcAVec);
pSrcAVec += 4;
vecB = vld1q(pSrcBVec);
pSrcBVec += 4;
acc0 = vrmlaldavhaq(acc0, vecA, vecB);
blkCnt--;
}
/*
* tail
* (will be merged thru tail predication)
*/
blkCnt = (numColsA & 3);
if (blkCnt > 0U) {
mve_pred16_t p0 = vctp32q(blkCnt);
vecA = vld1q(pSrcAVec);
vecB = vld1q(pSrcBVec);
acc0 = vrmlaldavhaq_p(acc0, vecA, vecB, p0);
}
acc0 = asrl(acc0, 23);
*px++ = (q31_t) acc0;
i += numColsA;
/*
* Decrement the col loop counter
*/
col--;
}
} }
/* Set status as ARM_MATH_SUCCESS */
/* status = ARM_MATH_SUCCESS;
* Update the pointer pInA to point to the starting address of the next row
*/
pInA += numColsA;
/*
* Decrement the row loop counter
*/
rowCnt--;
} }
/* /*
* set status as ARM_MATH_SUCCESS * Return to application
*/ */
status = ARM_MATH_SUCCESS; return (status);
}
/* Return to application */
return (status);
} }
#else #else

Loading…
Cancel
Save