@ -3,8 +3,8 @@
* Title : arm_mat_mult_q15 . c
* Description : Q15 matrix multiplication
*
* $ Date : 23 April 2021
* $ Revision : V1 . 9 .0
* $ Date : 3 Nov 2021
* $ Revision : V1 . 10 .0
*
* Target Processor : Cortex - M and Cortex - A cores
* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
@ -315,279 +315,308 @@ __STATIC_INLINE arm_status arm_mat_mult_q15_4x4_mve(
return ( ARM_MATH_SUCCESS ) ;
}
/*
 * Q15 matrix multiplication ( MVE variant ) : writes the product of the matrices
 * described by pSrcA and pSrcB into the matrix described by pDst .
 *
 * pSrcA  : first input matrix instance ( numRows , numCols , pData are read )
 * pSrcB  : second input matrix instance ( numRows , numCols , pData are read )
 * pDst   : output matrix instance ( pData is written )
 * pState : either unused ( cast to void ) or used as the destination of the
 *          transpose of pSrcB , depending on which version of the body applies
 *
 * Returns ARM_MATH_SIZE_MISMATCH when ARM_MATH_MATRIX_CHECK is defined and the
 * dimensions of the three instances do not agree , otherwise ARM_MATH_SUCCESS .
 *
 * NOTE ( review ) : this span contains the removed and the added lines of a diff
 * hunk interleaved . The parameter list and several locals appear twice , and two
 * implementations are mixed : one gathers the columns of B with
 * vldrhq_gather_shifted_offset , the other transposes B with arm_mat_trans_q15
 * and then uses contiguous vld1q loads . Confirm which lines belong to which
 * version before changing any code here .
 */
arm_status arm_mat_mult_q15 (
const arm_matrix_instance_q15 * pSrcA ,
const arm_matrix_instance_q15 * pSrcB ,
arm_matrix_instance_q15 * pDst ,
q15_t * pState )
const arm_matrix_instance_q15 * pSrcA ,
const arm_matrix_instance_q15 * pSrcB ,
arm_matrix_instance_q15 * pDst ,
q15_t * pState )
{
q15_t * pInB = pSrcB - > pData ; /* input data matrix pointer B */
q15_t * pInA = pSrcA - > pData ; /* input data matrix pointer A */
q15_t * pOut = pDst - > pData ; /* output data matrix pointer */
q15_t * px ; /* Temporary output data matrix pointer */
uint16_t numRowsA = pSrcA - > numRows ; /* number of rows of input matrix A */
uint16_t numColsB = pSrcB - > numCols ; /* number of columns of input matrix B */
uint16_t numColsA = pSrcA - > numCols ; /* number of columns of input matrix A */
uint16_t col , i = 0U , row = numRowsA ; /* loop counters */
uint16x8_t vecOffs , vecColBOffs ;
uint32_t blkCnt , rowCnt ; /* loop counters */
arm_status status ; /* Status of matrix multiplication */
/* pState is not used by the gather - load version of the body */
( void ) pState ;
q15_t * pInA = pSrcA - > pData ; /* input data matrix pointer A */
q15_t * pInB = pSrcB - > pData ; /* input data matrix pointer B */
q15_t * pInA2 ;
q15_t * pInB2 ;
q15_t * px ; /* Temporary output data matrix pointer */
q15_t * px2 ; /* Temporary output data matrix pointer */
uint32_t numRowsA = pSrcA - > numRows ; /* number of rows of input matrix A */
uint32_t numColsB = pSrcB - > numCols ; /* number of columns of input matrix B */
uint32_t numColsA = pSrcA - > numCols ; /* number of columns of input matrix A */
uint32_t numRowsB = pSrcB - > numRows ; /* number of rows of input matrix B */
uint32_t col , i = 0u , j , row = numRowsB ; /* loop counters */
/* pState is the buffer that receives the transpose of B in the transpose - based version */
q15_t * pSrcBT = pState ; /* input data matrix pointer for transpose */
uint32_t blkCnt ; /* loop counters */
arm_status status ; /* Status of matrix multiplication */
arm_matrix_instance_q15 BT ;
# ifdef ARM_MATH_MATRIX_CHECK
/* Check for matrix mismatch condition */
if ( ( pSrcA - > numCols ! = pSrcB - > numRows ) | |
/* Check for matrix mismatch condition */
if ( ( pSrcA - > numCols ! = pSrcB - > numRows ) | |
( pSrcA - > numRows ! = pDst - > numRows ) | |
( pSrcB - > numCols ! = pDst - > numCols ) )
{
/* Set status as ARM_MATH_SIZE_MISMATCH */
status = ARM_MATH_SIZE_MISMATCH ;
}
else
{
/* Set status as ARM_MATH_SIZE_MISMATCH */
status = ARM_MATH_SIZE_MISMATCH ;
}
else
# endif
{
/* small squared matrix specialized routines */
if ( numRowsA = = numColsB & & numColsB = = numColsA ) {
if ( numRowsA = = 1 )
{
q63_t sum ;
sum = pInA [ 0 ] * pInB [ 0 ] ;
pOut [ 0 ] = ( q15_t ) __SSAT ( ( sum > > 15 ) , 16 ) ;
return ( ARM_MATH_SUCCESS ) ;
{
/* small squared matrix specialized routines */
if ( numRowsA = = numColsB & & numColsB = = numColsA ) {
if ( numRowsA = = 1 ) {
q63_t sum ;
sum = pInA [ 0 ] * pInB [ 0 ] ;
pDst - > pData [ 0 ] = ( q15_t ) __SSAT ( ( sum > > 15 ) , 16 ) ;
return ( ARM_MATH_SUCCESS ) ;
} else if ( numRowsA = = 2 )
return arm_mat_mult_q15_2x2_mve ( pSrcA , pSrcB , pDst ) ;
else if ( numRowsA = = 3 )
return arm_mat_mult_q15_3x3_mve ( pSrcA , pSrcB , pDst ) ;
else if ( numRowsA = = 4 )
return arm_mat_mult_q15_4x4_mve ( pSrcA , pSrcB , pDst ) ;
}
else if ( numRowsA = = 2 )
return arm_mat_mult_q15_2x2_mve ( pSrcA , pSrcB , pDst ) ;
else if ( numRowsA = = 3 )
return arm_mat_mult_q15_3x3_mve ( pSrcA , pSrcB , pDst ) ;
else if ( numRowsA = = 4 )
return arm_mat_mult_q15_4x4_mve ( pSrcA , pSrcB , pDst ) ;
}
vecColBOffs = vidupq_u16 ( ( uint32_t ) 0 , 1 ) ;
vecColBOffs = vecColBOffs * ( uint16_t ) ( numColsB ) ;
/*
* The following loop performs the dot - product of each row in pSrcA with each column in pSrcB
*/
/*
* row loop
*/
rowCnt = row > > 2 ;
while ( rowCnt > 0U )
{
/*
* Output pointer is set to starting address of the row being processed
* Matrix transpose
*/
px = pOut + i ;
i = i + 4 * numColsB ;
BT . numRows = numColsB ;
BT . numCols = numRowsB ;
BT . pData = pSrcBT ;
arm_mat_trans_q15 ( pSrcB , & BT ) ;
/*
* For every row wise process , the column loop counter is to be initiated
* Reset the variables for the usage in the following multiplication process
*/
col = numColsB ;
i = 0 ;
row = numRowsA > > 1 ;
px = pDst - > pData ;
px2 = px + numColsB ;
/*
* For every row wise process , the pInB pointer is set
* to the starting address of the pSrcB data
* The following loop performs the dot - product of each row in pSrcA with each column in pSrcB
*/
pInB = pSrcB - > pData ;
/*
* column loop
* row loop
*/
while ( col > 0U )
{
while ( row > 0u ) {
/*
* generate 4 columns elements
* For every row wise process , the column loop counter is to be initiated
*/
col = numColsB > > 1 ;
/*
* Matrix A columns number of MAC operations are to be performed
* For every row wise process , the pInB pointer is set
* to the starting address of the transposed pSrcB data
*/
pInB = pSrcBT ;
pInB2 = pInB + numRowsB ;
j = 0 ;
q15_t const * pSrcA0Vec , * pSrcA1Vec , * pSrcA2Vec , * pSrcA3Vec ;
q15_t * pInA0 = pInA ;
q15_t * pInA1 = pInA0 + numColsA ;
q15_t * pInA2 = pInA1 + numColsA ;
q15_t * pInA3 = pInA2 + numColsA ;
q63_t acc0 , acc1 , acc2 , acc3 ;
acc0 = 0LL ;
acc1 = 0LL ;
acc2 = 0LL ;
acc3 = 0LL ;
pSrcA0Vec = ( q15_t const * ) pInA0 ;
pSrcA1Vec = ( q15_t const * ) pInA1 ;
pSrcA2Vec = ( q15_t const * ) pInA2 ;
pSrcA3Vec = ( q15_t const * ) pInA3 ;
vecOffs = vecColBOffs ;
blkCnt = ( numColsA ) > > 3 ;
while ( blkCnt > 0U )
{
q15x8_t vecB , vecA ;
vecB = vldrhq_gather_shifted_offset ( ( int16_t const * ) pInB , vecOffs ) ;
vecOffs = vecOffs + ( uint16_t ) ( numColsB * 8 ) ;
vecA = vld1q ( pSrcA0Vec ) ; pSrcA0Vec + = 8 ;
acc0 = vmlaldavaq ( acc0 , vecA , vecB ) ;
vecA = vld1q ( pSrcA1Vec ) ; pSrcA1Vec + = 8 ;
acc1 = vmlaldavaq ( acc1 , vecA , vecB ) ;
vecA = vld1q ( pSrcA2Vec ) ; pSrcA2Vec + = 8 ;
acc2 = vmlaldavaq ( acc2 , vecA , vecB ) ;
vecA = vld1q ( pSrcA3Vec ) ; pSrcA3Vec + = 8 ;
acc3 = vmlaldavaq ( acc3 , vecA , vecB ) ;
blkCnt - - ;
}
/*
* tail
* column loop
*/
blkCnt = numColsA & 7 ;
if ( blkCnt > 0U )
{
mve_pred16_t p0 = vctp16q ( blkCnt ) ;
q15x8_t vecB , vecA ;
vecB = vldrhq_gather_shifted_offset ( ( int16_t const * ) pInB , vecOffs ) ;
vecOffs = vecOffs + ( uint16_t ) ( numColsB * 8 ) ;
vecA = vld1q ( pSrcA0Vec ) ;
acc0 = vmlaldavaq_p ( acc0 , vecA , vecB , p0 ) ;
vecA = vld1q ( pSrcA1Vec ) ;
acc1 = vmlaldavaq_p ( acc1 , vecA , vecB , p0 ) ;
vecA = vld1q ( pSrcA2Vec ) ;
acc2 = vmlaldavaq_p ( acc2 , vecA , vecB , p0 ) ;
vecA = vld1q ( pSrcA3Vec ) ;
acc3 = vmlaldavaq_p ( acc3 , vecA , vecB , p0 ) ;
while ( col > 0u ) {
q15_t const * pSrcAVec , * pSrcBVec , * pSrcA2Vec , * pSrcB2Vec ;
q15x8_t vecA , vecA2 , vecB , vecB2 ;
q63_t acc0 , acc1 , acc2 , acc3 ;
/*
* Point pInA at offset i into the data of matrix A and pInB at offset j into the transposed B buffer
*/
pInA = pSrcA - > pData + i ;
pInA2 = pInA + numColsA ;
pInB = pSrcBT + j ;
pInB2 = pInB + numRowsB ;
pSrcAVec = ( q15_t const * ) pInA ;
pSrcA2Vec = ( q15_t const * ) pInA2 ;
pSrcBVec = ( q15_t const * ) pInB ;
pSrcB2Vec = ( q15_t const * ) pInB2 ;
acc0 = 0LL ;
acc1 = 0LL ;
acc2 = 0LL ;
acc3 = 0LL ;
vecA = vld1q ( pSrcAVec ) ;
pSrcAVec + = 8 ;
blkCnt = numColsA / 8 ;
while ( blkCnt > 0U ) {
vecB = vld1q ( pSrcBVec ) ;
pSrcBVec + = 8 ;
acc0 = vmlaldavaq ( acc0 , vecA , vecB ) ;
vecA2 = vld1q ( pSrcA2Vec ) ;
pSrcA2Vec + = 8 ;
acc1 = vmlaldavaq ( acc1 , vecA2 , vecB ) ;
vecB2 = vld1q ( pSrcB2Vec ) ;
pSrcB2Vec + = 8 ;
acc2 = vmlaldavaq ( acc2 , vecA , vecB2 ) ;
vecA = vld1q ( pSrcAVec ) ;
pSrcAVec + = 8 ;
acc3 = vmlaldavaq ( acc3 , vecA2 , vecB2 ) ;
blkCnt - - ;
}
/*
* tail
*/
blkCnt = numColsA & 7 ;
if ( blkCnt > 0U ) {
mve_pred16_t p0 = vctp16q ( blkCnt ) ;
vecB = vld1q ( pSrcBVec ) ;
acc0 = vmlaldavaq_p ( acc0 , vecA , vecB , p0 ) ;
vecA2 = vld1q ( pSrcA2Vec ) ;
acc1 = vmlaldavaq_p ( acc1 , vecA2 , vecB , p0 ) ;
vecB2 = vld1q ( pSrcB2Vec ) ;
acc2 = vmlaldavaq_p ( acc2 , vecA , vecB2 , p0 ) ;
vecA = vld1q ( pSrcAVec ) ;
acc3 = vmlaldavaq_p ( acc3 , vecA2 , vecB2 , p0 ) ;
}
* px + + = ( q15_t ) MVE_ASRL_SAT16 ( acc0 , 15 ) ;
* px + + = ( q15_t ) MVE_ASRL_SAT16 ( acc2 , 15 ) ;
* px2 + + = ( q15_t ) MVE_ASRL_SAT16 ( acc1 , 15 ) ;
* px2 + + = ( q15_t ) MVE_ASRL_SAT16 ( acc3 , 15 ) ;
j + = numRowsB * 2 ;
/*
* Decrement the column loop counter
*/
col - - ;
}
px [ 0 ] = ( q15_t ) MVE_ASRL_SAT16 ( acc0 , 15 ) ;
px [ 1 * numColsB ] = ( q15_t ) MVE_ASRL_SAT16 ( acc1 , 15 ) ;
px [ 2 * numColsB ] = ( q15_t ) MVE_ASRL_SAT16 ( acc2 , 15 ) ;
px [ 3 * numColsB ] = ( q15_t ) MVE_ASRL_SAT16 ( acc3 , 15 ) ;
px + + ;
i = i + numColsA * 2 ;
px = px2 + ( numColsB & 1u ) ;
px2 = px + numColsB ;
/*
* Decrement the column loop counter
* Decrement the row loop counter
*/
col - - ;
/*
* Update the pointer pInB to point to the starting address of the next column
*/
pInB = pSrcB - > pData + ( numColsB - col ) ;
row - - ;
}
/*
* Update the pointer pInA to point to the starting address of the next row
* Compute remaining row and / or column below
*/
pInA + = ( numColsA * 4 ) ;
/*
* Decrement the row loop counter
*/
rowCnt - - ;
}
if ( numColsB & 1u ) {
row = numRowsA & ( ~ 0x1 ) ; //avoid redundant computation
px = pDst - > pData + numColsB - 1 ;
i = 0 ;
rowCnt = row & 3 ;
while ( rowCnt > 0U )
{
/*
* Output pointer is set to starting address of the row being processed
*/
px = pOut + i ;
i = i + numColsB ;
/*
* For every row wise process , the column loop counter is to be initiated
*/
col = numColsB ;
/*
* For every row wise process , the pInB pointer is set
* to the starting address of the pSrcB data
*/
pInB = pSrcB - > pData ;
/*
* column loop
*/
while ( col > 0U )
{
/*
* generate 4 columns elements
* row loop
*/
/*
* Matrix A columns number of MAC operations are to be performed
*/
q15_t const * pSrcA0Vec ;
q15_t * pInA0 = pInA ;
q63_t acc0 ;
acc0 = 0LL ;
pSrcA0Vec = ( q15_t const * ) pInA0 ;
vecOffs = vecColBOffs ;
blkCnt = ( numColsA ) > > 3 ;
while ( blkCnt > 0U )
{
q15x8_t vecB , vecA ;
vecB = vldrhq_gather_shifted_offset ( ( int16_t const * ) pInB , vecOffs ) ;
vecOffs = vecOffs + ( uint16_t ) ( numColsB * 8 ) ;
vecA = vld1q ( pSrcA0Vec ) ;
pSrcA0Vec + = 8 ;
acc0 = vmlaldavaq ( acc0 , vecA , vecB ) ;
blkCnt - - ;
}
/*
* tail
*/
blkCnt = numColsA & 7 ;
if ( blkCnt > 0U )
{
mve_pred16_t p0 = vctp16q ( blkCnt ) ;
q15x8_t vecB , vecA ;
vecB = vldrhq_gather_shifted_offset ( ( int16_t const * ) pInB , vecOffs ) ;
vecOffs = vecOffs + ( uint16_t ) ( numColsB * 8 ) ;
vecA = vld1q ( pSrcA0Vec ) ;
acc0 = vmlaldavaq_p ( acc0 , vecA , vecB , p0 ) ;
while ( row > 0 ) {
q15_t const * pSrcAVec , * pSrcBVec ;
q15x8_t vecA , vecB ;
q63_t acc0 ;
/*
* point to last column in matrix B
*/
pInB = pSrcBT + numRowsB * ( numColsB - 1 ) ;
pInA = pSrcA - > pData + i ;
pSrcAVec = ( q15_t const * ) pInA ;
pSrcBVec = ( q15_t const * ) pInB ;
acc0 = 0LL ;
blkCnt = ( numColsA ) / 8 ;
while ( blkCnt > 0U ) {
vecA = vld1q ( pSrcAVec ) ;
pSrcAVec + = 8 ;
vecB = vld1q ( pSrcBVec ) ;
pSrcBVec + = 8 ;
acc0 = vmlaldavaq ( acc0 , vecA , vecB ) ;
blkCnt - - ;
}
/*
* tail
*/
blkCnt = ( numColsA & 7 ) ;
if ( blkCnt > 0U ) {
mve_pred16_t p0 = vctp16q ( blkCnt ) ;
vecA = vld1q ( pSrcAVec ) ;
vecB = vld1q ( pSrcBVec ) ;
acc0 = vmlaldavaq_p ( acc0 , vecA , vecB , p0 ) ;
}
* px = ( q15_t ) MVE_ASRL_SAT16 ( acc0 , 15 ) ;
px + = numColsB ;
i + = numColsA ;
/*
* Decrement the row loop counter
*/
row - - ;
}
}
px [ 0 ] = ( q15_t ) MVE_ASRL_SAT16 ( acc0 , 15 ) ;
px+ + ;
if ( numRowsA & 1u ) {
col = numColsB ;
i = 0u ;
/*
* Decrement the column loop counter
* point to last row in output matrix
*/
col- - ;
px = pDst - > pData + ( numColsB ) * ( numRowsA - 1 ) ;
/*
* Update the pointer pInB to point to the starting address of the next column
* col loop
*/
pInB = pSrcB - > pData + ( numColsB - col ) ;
while ( col > 0 ) {
q15_t const * pSrcAVec , * pSrcBVec ;
q15x8_t vecA , vecB ;
q63_t acc0 ;
/*
* point to last row in matrix A
*/
pInA = pSrcA - > pData + ( numRowsA - 1 ) * numColsA ;
pInB = pSrcBT + i ;
/*
* Set acc0 , which acts as the accumulator , to zero
*/
pSrcAVec = ( q15_t const * ) pInA ;
pSrcBVec = ( q15_t const * ) pInB ;
acc0 = 0LL ;
blkCnt = ( ( numColsA ) / 8 ) ;
while ( blkCnt > 0U ) {
vecA = vld1q ( pSrcAVec ) ;
pSrcAVec + = 8 ;
vecB = vld1q ( pSrcBVec ) ;
pSrcBVec + = 8 ;
acc0 = vmlaldavaq ( acc0 , vecA , vecB ) ;
blkCnt - - ;
}
/*
* tail
*/
blkCnt = ( numColsA & 7 ) ;
if ( blkCnt > 0U ) {
mve_pred16_t p0 = vctp16q ( blkCnt ) ;
vecA = vld1q ( pSrcAVec ) ;
vecB = vld1q ( pSrcBVec ) ;
acc0 = vmlaldavaq_p ( acc0 , vecA , vecB , p0 ) ;
}
* px + + = ( q15_t ) MVE_ASRL_SAT16 ( acc0 , 15 ) ;
i + = numColsA ;
/*
* Decrement the col loop counter
*/
col - - ;
}
}
/*
* Update the pointer pInA to point to the starting address of the next row
*/
pInA + = ( numColsA ) ;
rowCnt - - ;
/* Set status as ARM_MATH_SUCCESS */
status = ARM_MATH_SUCCESS ;
}
/* Set status as ARM_MATH_SUCCESS */
status = ARM_MATH_SUCCESS ;
}
/* Return to application */
return ( status ) ;
/* Return to application */
return ( status ) ;
}
# else
arm_status arm_mat_mult_q15 (
const arm_matrix_instance_q15 * pSrcA ,