@ -3,8 +3,8 @@
* Title : arm_mat_mult_q31 . c
* Description : Q31 matrix multiplication
*
* $ Date : 23 April 2021
* $ Revision : V1 . 9 .0
* $ Date : 3 Nov 2021
* $ Revision : V1 . 10 .0
*
* Target Processor : Cortex - M and Cortex - A cores
* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
@ -332,44 +332,45 @@ __STATIC_INLINE arm_status arm_mat_mult_q31_4x4_mve(
return ( ARM_MATH_SUCCESS ) ;
}
arm_status arm_mat_mult_q31 (
const arm_matrix_instance_q31 * pSrcA ,
const arm_matrix_instance_q31 * pSrcB ,
arm_matrix_instance_q31 * pDst )
{
q31_t const * pInB = ( q31_t const * ) pSrcB - > pData ; /* input data matrix pointer B */
q31_t const * pInA = ( q31_t const * ) pSrcA - > pData ; /* input data matrix pointer A */
q31_t * pOut = pDst - > pData ; /* output data matrix pointer */
q31_t * pInA = pSrcA - > pData ; /* input data matrix pointer A */
q31_t * pInB = pSrcB - > pData ; /* input data matrix pointer B */
q31_t * pInA2 ;
q31_t * pInB2 ;
q31_t * px ; /* Temporary output data matrix pointer */
uint16_t numRowsA = pSrcA - > numRows ; /* number of rows of input matrix A */
uint16_t numColsB = pSrcB - > numCols ; /* number of columns of input matrix B */
uint16_t numColsA = pSrcA - > numCols ; /* number of columns of input matrix A */
uint16_t col , i = 0U , row = numRowsA ; /* loop counters */
arm_status status ; /* status of matrix multiplication */
uint32x4_t vecOffs , vecColBOffs ;
uint32_t blkCnt , rowCnt ; /* loop counters */
# ifdef ARM_MATH_MATRIX_CHECK
q31_t * px2 ; /* Temporary output data matrix pointer */
uint32_t numRowsA = pSrcA - > numRows ; /* number of rows of input matrix A */
uint32_t numColsB = pSrcB - > numCols ; /* number of columns of input matrix B */
uint32_t numColsA = pSrcA - > numCols ; /* number of columns of input matrix A */
uint32_t numRowsB = pSrcB - > numRows ; /* number of rows of input matrix B */
uint32_t col , i = 0u , j , row = numRowsB ; /* loop counters */
q31_t State [ numRowsB * numColsB * 1 ] ;
q31_t * pSrcBT = State ; /* input data matrix pointer for transpose */
uint32_t blkCnt ; /* loop counters */
arm_status status ; /* Status of matrix multiplication */
arm_matrix_instance_q31 BT ;
# ifdef ARM_MATH_MATRIX_CHECK
/* Check for matrix mismatch condition */
if ( ( pSrcA - > numCols ! = pSrcB - > numRows ) | |
( pSrcA - > numRows ! = pDst - > numRows ) | |
( pSrcB - > numCols ! = pDst - > numCols ) )
{
( pSrcA - > numRows ! = pDst - > numRows ) | | ( pSrcB - > numCols ! = pDst - > numCols ) ) {
/* Set status as ARM_MATH_SIZE_MISMATCH */
status = ARM_MATH_SIZE_MISMATCH ;
}
else
} else
# endif /* #ifdef ARM_MATH_MATRIX_CHECK */
{
/* small squared matrix specialized routines */
if ( numRowsA = = numColsB & & numColsB = = numColsA ) {
if ( numRowsA = = 1 )
{
q63_t sum = ( q63_t ) * pInA * * pInB ;
p Out [ 0 ] = ( q31_t ) ( sum > > 31 ) ;
p Dst- > pData [ 0 ] = ( q31_t ) ( sum > > 31 ) ;
return ( ARM_MATH_SUCCESS ) ;
}
else if ( numRowsA = = 2 )
@ -380,245 +381,262 @@ arm_status arm_mat_mult_q31(
return arm_mat_mult_q31_4x4_mve ( pSrcA , pSrcB , pDst ) ;
}
vecColBOffs = vidupq_u32 ( ( uint32_t ) 0 , 1 ) ;
vecColBOffs = vecColBOffs * ( uint32_t ) ( numColsB ) ;
/*
* The following loop performs the dot - product of each row in pSrcA with each column in pSrcB
* Matrix transpose
*/
BT . numRows = numColsB ;
BT . numCols = numRowsB ;
BT . pData = pSrcBT ;
arm_mat_trans_q31 ( pSrcB , & BT ) ;
/*
* row loop
* Reset the variables for the usage in the following multiplication process
*/
rowCnt = row > > 2 ;
while ( rowCnt > 0U )
{
i = 0 ;
row = numRowsA > > 1 ;
px = pDst - > pData ;
px2 = px + numColsB ;
/*
* Output pointer is set to starting address of the row being processed
* main loop
* compute 2 x 2 output blocks
* with dot products ( Matrix A rows * Transposed Matrix B rows )
*/
px = pOut + i ;
i = i + 4 * numColsB ;
while ( row > 0u ) {
/*
* For every row wise process , the column loop counter is to be initiated
* Compute 2 columns and 2 rows in parallel
*/
col = numColsB ;
/*
* For every row wise process , the pInB pointer is set
* to the starting address of the pSrcB data
*/
pInB = ( q31_t const * ) pSrcB - > pData ;
/*
* column loop
*/
while ( col > 0U )
{
col = numColsB > > 1 ;
j = 0 ;
/*
* generate 4 columns elements
* column pair loop
*/
while ( col > 0u ) {
q31_t const * pSrcAVec , * pSrcBVec , * pSrcA2Vec , * pSrcB2Vec ;
q31x4_t vecA , vecA2 , vecB , vecB2 ;
q63_t acc0 , acc1 , acc2 , acc3 ;
/*
* Matrix A columns number of MAC operations are to be performed
* Initiate the pointers
* - 2 x consecutive Matrix A rows ( i increment is 2 x numColsA )
* - 2 x consecutive Matrix B ' rows ( j increment is 2 x numRowsB )
*/
pInA = pSrcA - > pData + i ;
pInA2 = pInA + numColsA ;
pInB = pSrcBT + j ;
pInB2 = pInB + numRowsB ;
q31_t const * pSrcA0Vec , * pSrcA1Vec , * pSrcA2Vec , * pSrcA3Vec ;
q31_t const * pInA0 = pInA ;
q31_t const * pInA1 = pInA0 + numColsA ;
q31_t const * pInA2 = pInA1 + numColsA ;
q31_t const * pInA3 = pInA2 + numColsA ;
q63_t acc0 , acc1 , acc2 , acc3 ;
pSrcAVec = ( q31_t const * ) pInA ;
pSrcA2Vec = ( q31_t const * ) pInA2 ;
pSrcBVec = ( q31_t const * ) pInB ;
pSrcB2Vec = ( q31_t const * ) pInB2 ;
acc0 = 0LL ;
acc1 = 0LL ;
acc2 = 0LL ;
acc3 = 0LL ;
pSrcA0Vec = ( q31_t const * ) pInA0 ;
pSrcA1Vec = ( q31_t const * ) pInA1 ;
pSrcA2Vec = ( q31_t const * ) pInA2 ;
pSrcA3Vec = ( q31_t const * ) pInA3 ;
vecOffs = vecColBOffs ;
/* load scheduling */
vecA = vld1q ( pSrcAVec ) ;
pSrcAVec + = 4 ;
/* process 1 x 4 block output */
blkCnt = numColsA > > 2 ;
while ( blkCnt > 0U )
{
q31x4_t vecB , vecA ;
vecB = vldrwq_gather_shifted_offset ( pInB , vecOffs ) ;
/* move Matrix B read offsets, 4 rows down */
vecOffs = vecOffs + ( uint32_t ) ( numColsB * 4 ) ;
vecA = vld1q ( pSrcA0Vec ) ; pSrcA0Vec + = 4 ;
blkCnt = ( numColsA / 4 ) ;
while ( blkCnt > 0U ) {
vecB = vld1q ( pSrcBVec ) ;
pSrcBVec + = 4 ;
acc0 = vrmlaldavhaq ( acc0 , vecA , vecB ) ;
vecA = vld1q ( pSrcA1Vec ) ; pSrcA1Vec + = 4 ;
acc1 = vrmlaldavhaq ( acc1 , vecA , vecB ) ;
vecA = vld1q ( pSrcA2Vec ) ; pSrcA2Vec + = 4 ;
acc2 = vrmlaldavhaq ( acc2 , vecA , vecB ) ;
vecA = vld1q ( pSrcA3Vec ) ; pSrcA3Vec + = 4 ;
acc3 = vrmlaldavhaq ( acc3 , vecA , vecB ) ;
vecA2 = vld1q ( pSrcA2Vec ) ;
pSrcA2Vec + = 4 ;
acc1 = vrmlaldavhaq ( acc1 , vecA2 , vecB ) ;
vecB2 = vld1q ( pSrcB2Vec ) ;
pSrcB2Vec + = 4 ;
acc2 = vrmlaldavhaq ( acc2 , vecA , vecB2 ) ;
vecA = vld1q ( pSrcAVec ) ;
pSrcAVec + = 4 ;
acc3 = vrmlaldavhaq ( acc3 , vecA2 , vecB2 ) ;
blkCnt - - ;
}
/*
* tail
* ( will be merged thru tail predication )
*/
blkCnt = numColsA & 3 ;
if ( blkCnt > 0U )
{
blkCnt = ( numColsA & 3 ) ;
if ( blkCnt > 0U ) {
mve_pred16_t p0 = vctp32q ( blkCnt ) ;
q31x4_t vecB , vecA ;
vecB = vldrwq_gather_shifted_offset_z ( pInB , vecOffs , p0 ) ;
//vecOffs = vecOffs + (uint32_t) (numColsB * 4);
vecA = vld1q ( pSrcA0Vec ) ; pSrcA0Vec + = 4 ;
acc0 = vrmlaldavhaq ( acc0 , vecA , vecB ) ;
vecA = vld1q ( pSrcA1Vec ) ; pSrcA1Vec + = 4 ;
acc1 = vrmlaldavhaq ( acc1 , vecA , vecB ) ;
vecA = vld1q ( pSrcA2Vec ) ; pSrcA2Vec + = 4 ;
acc2 = vrmlaldavhaq ( acc2 , vecA , vecB ) ;
vecA = vld1q ( pSrcA3Vec ) ; pSrcA3Vec + = 4 ;
acc3 = vrmlaldavhaq ( acc3 , vecA , vecB ) ;
vecB = vld1q ( pSrcBVec ) ;
acc0 = vrmlaldavhaq_p ( acc0 , vecA , vecB , p0 ) ;
vecA2 = vld1q ( pSrcA2Vec ) ;
acc1 = vrmlaldavhaq_p ( acc1 , vecA2 , vecB , p0 ) ;
vecB2 = vld1q ( pSrcB2Vec ) ;
acc2 = vrmlaldavhaq_p ( acc2 , vecA , vecB2 , p0 ) ;
vecA = vld1q ( pSrcAVec ) ;
acc3 = vrmlaldavhaq_p ( acc3 , vecA2 , vecB2 , p0 ) ;
}
/* Convert to 1.31 */
acc0 = asrl ( acc0 , 23 ) ;
acc1 = asrl ( acc1 , 23 ) ;
acc2 = asrl ( acc2 , 23 ) ;
acc3 = asrl ( acc3 , 23 ) ;
px [ 0 ] = ( q31_t ) acc0 ;
px [ 1 * numColsB ] = ( q31_t ) acc1 ;
px [ 2 * numColsB ] = ( q31_t ) acc2 ;
px [ 3 * numColsB ] = ( q31_t ) acc3 ;
px + + ;
/* Store the results (2 x 2 block) in the destination buffer */
* px + + = ( q31_t ) acc0 ;
* px + + = ( q31_t ) acc2 ;
* px2 + + = ( q31_t ) acc1 ;
* px2 + + = ( q31_t ) acc3 ;
j + = numRowsB * 2 ;
/*
* Decrement the column loop counter
* Decrement the column pair loop counter
*/
col - - ;
}
i = i + numColsA * 2 ;
px = px2 + ( numColsB & 1u ) ;
px2 = px + numColsB ;
/*
* Update the pointer pInB to point to the starting address of the next column
* Decrement the row pair loop counter
*/
pInB = ( q31_t const * ) pSrcB - > pData + ( numColsB - col ) ;
row- - ;
}
/*
* Update the pointer pInA to point to the starting address of the next row
* Compute remaining row and / or column below
*/
pInA + = ( numColsA * 4 ) ;
if ( numColsB & 1u ) {
row = numRowsA & ( ~ 0x1 ) ; //avoid redundant computation
px = pDst - > pData + numColsB - 1 ;
i = 0 ;
/*
* Decrement the row loop counter
* row loop
*/
rowCnt - - ;
while ( row > 0 ) {
q31_t const * pSrcAVec , * pSrcBVec ;
q31x4_t vecA , vecB ;
q63_t acc0 ;
/*
* point to last column in matrix B
*/
pInB = pSrcBT + numRowsB * ( numColsB - 1 ) ;
pInA = pSrcA - > pData + i ;
pSrcAVec = ( q31_t const * ) pInA ;
pSrcBVec = ( q31_t const * ) pInB ;
/* single dot-product */
acc0 = 0LL ;
blkCnt = ( numColsA / 4 ) ;
while ( blkCnt > 0U ) {
vecA = vld1q ( pSrcAVec ) ;
pSrcAVec + = 4 ;
vecB = vld1q ( pSrcBVec ) ;
pSrcBVec + = 4 ;
acc0 = vrmlaldavhaq ( acc0 , vecA , vecB ) ;
blkCnt - - ;
}
rowCnt = row & 3 ;
while ( rowCnt > 0U )
{
/*
* Output pointer is set to starting address of the row being processed
* tail
* ( will be merged thru tail predication )
*/
px = pOut + i ;
i = i + numColsB ;
blkCnt = ( numColsA & 3 ) ;
if ( blkCnt > 0U ) {
mve_pred16_t p0 = vctp32q ( blkCnt ) ;
vecA = vld1q ( pSrcAVec ) ;
vecB = vld1q ( pSrcBVec ) ;
acc0 = vrmlaldavhaq_p ( acc0 , vecA , vecB , p0 ) ;
}
acc0 = asrl ( acc0 , 23 ) ;
* px = ( q31_t ) acc0 ;
px + = numColsB ;
i + = numColsA ;
/*
* For every row wise process , the column loop counter is to be initiated
* Decrement the row loop counter
*/
row - - ;
}
}
if ( numRowsA & 1u ) {
col = numColsB ;
i = 0u ;
/*
* For every row wise process , the pInB pointer is set
* to the starting address of the pSrcB data
* point to last row in output matrix
*/
pInB = ( q31_t const * ) pSrcB - > pData ;
px = pDst - > pData + ( numColsB ) * ( numRowsA - 1 ) ;
/*
* column loop
* col loop
*/
while ( col > 0U )
{
while ( col > 0 ) {
q31_t const * pSrcAVec , * pSrcBVec ;
q31x4_t vecA , vecB ;
q63_t acc0 ;
/*
* generate 4 columns elements
* point to last row in matrix A
*/
pInA = pSrcA - > pData + ( numRowsA - 1 ) * numColsA ;
pInB = pSrcBT + i ;
/*
* Matrix A columns number of MAC operations are to be performed
* Set acc0 , which acts as the accumulator , to zero
*/
q31_t const * pSrcA0Vec ;
q31_t const * pInA0 = pInA ;
q63_t acc0 ;
pSrcAVec = ( q31_t const * ) pInA ;
pSrcBVec = ( q31_t const * ) pInB ;
acc0 = 0LL ;
pSrcA0Vec = ( q31_t const * ) pInA0 ;
vecOffs = vecColBOffs ;
/* process 1 x 4 block output */
blkCnt = numColsA > > 2 ;
while ( blkCnt > 0U )
{
q31x4_t vecB , vecA ;
vecB = vldrwq_gather_shifted_offset ( pInB , vecOffs ) ;
/* move Matrix B read offsets, 4 rows down */
vecOffs = vecOffs + ( uint32_t ) ( numColsB * 4 ) ;
vecA = vld1q ( pSrcA0Vec ) ; pSrcA0Vec + = 4 ;
blkCnt = ( numColsA / 4 ) ;
while ( blkCnt > 0U ) {
vecA = vld1q ( pSrcAVec ) ;
pSrcAVec + = 4 ;
vecB = vld1q ( pSrcBVec ) ;
pSrcBVec + = 4 ;
acc0 = vrmlaldavhaq ( acc0 , vecA , vecB ) ;
blkCnt - - ;
}
/*
* tail
* ( will be merged thru tail predication )
*/
blkCnt = numColsA & 3 ;
if ( blkCnt > 0U )
{
blkCnt = ( numColsA & 3 ) ;
if ( blkCnt > 0U ) {
mve_pred16_t p0 = vctp32q ( blkCnt ) ;
q31x4_t vecB , vecA ;
vecB = vldrwq_gather_shifted_offset_z ( pInB , vecOffs , p0 ) ;
//vecOffs = vecOffs + (uint32_t) (numColsB * 4);
vecA = vld1q ( pSrcA0Vec ) ;
pSrcA0Vec + = 4 ;
acc0 = vrmlaldavhaq ( acc0 , vecA , vecB ) ;
vecA = vld1q ( pSrcAVec ) ;
vecB = vld1q ( pSrcBVec ) ;
acc0 = vrmlaldavhaq_p ( acc0 , vecA , vecB , p0 ) ;
}
acc0 = asrl ( acc0 , 23 ) ;
* px + + = ( q31_t ) acc0 ;
px [ 0 ] = ( q31_t ) acc0 ;
px + + ;
i + = numColsA ;
/*
* Decrement the column loop counter
* Decrement the col loop counter
*/
col - - ;
/*
* Update the pointer pInB to point to the starting address of the next column
*/
pInB = ( q31_t const * ) pSrcB - > pData + ( numColsB - col ) ;
}
/*
* Update the pointer pInA to point to the starting address of the next row
*/
pInA + = numColsA ;
/*
* Decrement the row loop counter
*/
rowCnt - - ;
}
/*
* set status as ARM_MATH_SUCCESS
*/
/* Set status as ARM_MATH_SUCCESS */
status = ARM_MATH_SUCCESS ;
}
/* Return to application */
/*
* Return to application
*/
return ( status ) ;
}