diff --git a/Source/BasicMathFunctions/arm_dot_prod_f32.c b/Source/BasicMathFunctions/arm_dot_prod_f32.c index 8510c022..3eee3b97 100644 --- a/Source/BasicMathFunctions/arm_dot_prod_f32.c +++ b/Source/BasicMathFunctions/arm_dot_prod_f32.c @@ -72,7 +72,7 @@ void arm_dot_prod_f32( float32x4_t vec1; float32x4_t vec2; float32x4_t res; - float32x2_t accum = vdup_n_f32(0); + float32x4_t accum = vdupq_n_f32(0); /* Compute 4 outputs at a time */ blkCnt = blockSize >> 2U; @@ -85,8 +85,7 @@ void arm_dot_prod_f32( /* C = A[0]*B[0] + A[1]*B[1] + A[2]*B[2] + ... + A[blockSize-1]*B[blockSize-1] */ /* Calculate dot product and then store the result in a temporary buffer. */ - res = vmulq_f32(vec1, vec2); - accum = vadd_f32(accum, vpadd_f32(vget_low_f32(res), vget_high_f32(res))); + accum = vmlaq_f32(accum, vec1, vec2); /* Increment pointers */ pSrcA += 4; @@ -98,7 +97,12 @@ void arm_dot_prod_f32( /* Decrement the loop counter */ blkCnt--; } - sum += accum[0] + accum[1]; + +#if __aarch64__ + sum = vpadds_f32(vpadd_f32(vget_low_f32(accum), vget_high_f32(accum))); +#else + sum = (vpadd_f32(vget_low_f32(accum), vget_high_f32(accum)))[0] + (vpadd_f32(vget_low_f32(accum), vget_high_f32(accum)))[1]; +#endif /* Tail */ blkCnt = blockSize & 0x3;