Program Listing for File intrinsics.h

Return to documentation for file (include/intrinsics.h)

#ifndef intrinsics_h
#define intrinsics_h

namespace exafmm_t {
#if FLOAT
  // Loads one complex float (re, im) from p and duplicates it into both halves
  // of the result: (re, im, re, im). p needs no particular alignment.
  // The 64-bit load goes through _mm_loadl_pi (an __m64 access, declared
  // may_alias by GCC/Clang) instead of dereferencing the float storage as a
  // double, which would violate strict aliasing and assume 8-byte alignment.
  inline __m128 load_complex_dup_ps(const float* p){
    const __m128 lo = _mm_loadl_pi(_mm_setzero_ps(), (const __m64*)p);
    return _mm_movelh_ps(lo, lo);
  }

  // For one pair of complex rows of one matrix column:
  //   acc0 += m * x0,  acc1 += m * x1   (complex, per lane pair)
  // m is the two complex entries at mp (16-byte aligned). x* hold a complex
  // input duplicated as (re, im, re, im); x*_sw hold it swapped, (im, re, im, re).
  // _mm_addsub_ps subtracts in even (real) lanes and adds in odd (imaginary)
  // lanes, which completes the complex product.
  inline void cmac2_ps(__m128& acc0, __m128& acc1, const float* mp,
                       __m128 x0, __m128 x0_sw, __m128 x1, __m128 x1_sw){
    const __m128 mv   = _mm_load_ps(mp);
    const __m128 m_re = _mm_shuffle_ps(mv, mv, _MM_SHUFFLE(2,2,0,0));  // (m0.re, m0.re, m1.re, m1.re)
    const __m128 m_im = _mm_shuffle_ps(mv, mv, _MM_SHUFFLE(3,3,1,1));  // (m0.im, m0.im, m1.im, m1.im)
    acc0 = _mm_add_ps(acc0, _mm_mul_ps(m_re, x0));
    acc0 = _mm_addsub_ps(acc0, _mm_mul_ps(m_im, x0_sw));
    acc1 = _mm_add_ps(acc1, _mm_mul_ps(m_re, x1));
    acc1 = _mm_addsub_ps(acc1, _mm_mul_ps(m_im, x1_sw));
  }

  // Single-precision SSE3 kernel: OUT0 += M * IN0 and OUT1 += M * IN1.
  //
  // Layout, as read by the loads below (all complex values interleaved re, im):
  //   M_         : 8 columns of 8 complex rows, column-major, 16 floats per
  //                column (128 floats in total); must be 16-byte aligned.
  //   IN0, IN1   : 8 complex inputs (16 floats); element c multiplies column c.
  //   OUT0, OUT1 : 8 complex accumulators (16 floats); must be 16-byte aligned.
  // Effect on the reference arguments: M_ is advanced by 128 floats (past the
  // matrix it consumed); IN0, IN1, OUT0 and OUT1 are left unchanged.
  inline void matmult_8x8x2(float*& M_, float*& IN0, float*& IN1, float*& OUT0, float*& OUT1){
    // accXY: X selects OUT0/OUT1, Y selects the pair of complex rows (2Y, 2Y+1).
    __m128 acc00 = _mm_load_ps(OUT0);
    __m128 acc01 = _mm_load_ps(OUT0+4);
    __m128 acc02 = _mm_load_ps(OUT0+8);
    __m128 acc03 = _mm_load_ps(OUT0+12);
    __m128 acc10 = _mm_load_ps(OUT1);
    __m128 acc11 = _mm_load_ps(OUT1+4);
    __m128 acc12 = _mm_load_ps(OUT1+8);
    __m128 acc13 = _mm_load_ps(OUT1+12);

    const float* col = M_;
    for(int c=0; c<8; ++c, col+=16){
      const __m128 x0    = load_complex_dup_ps(IN0 + 2*c);
      const __m128 x0_sw = _mm_shuffle_ps(x0, x0, _MM_SHUFFLE(2,3,0,1));
      const __m128 x1    = load_complex_dup_ps(IN1 + 2*c);
      const __m128 x1_sw = _mm_shuffle_ps(x1, x1, _MM_SHUFFLE(2,3,0,1));
      cmac2_ps(acc00, acc10, col,    x0, x0_sw, x1, x1_sw);
      cmac2_ps(acc01, acc11, col+4,  x0, x0_sw, x1, x1_sw);
      cmac2_ps(acc02, acc12, col+8,  x0, x0_sw, x1, x1_sw);
      cmac2_ps(acc03, acc13, col+12, x0, x0_sw, x1, x1_sw);
    }
    M_ += 128;  // the whole 8x8 complex matrix has been consumed

    _mm_store_ps(OUT0,    acc00);
    _mm_store_ps(OUT0+4,  acc01);
    _mm_store_ps(OUT0+8,  acc02);
    _mm_store_ps(OUT0+12, acc03);
    _mm_store_ps(OUT1,    acc10);
    _mm_store_ps(OUT1+4,  acc11);
    _mm_store_ps(OUT1+8,  acc12);
    _mm_store_ps(OUT1+12, acc13);
  }
#else
#ifdef __AVX512F__
  // For one block of four complex rows of one matrix column:
  //   acc0 += m * x0,  acc1 += m * x1   (complex, per lane pair)
  // m is the four complex entries at mp (64-byte aligned). x* hold a complex
  // input (re, im) replicated into every pair; x*_sw hold it as (im, re).
  // fmaddsub(p, 1, q) yields p - q in even (real) lanes and p + q in odd
  // (imaginary) lanes, which completes the complex product.
  inline void cmac2_pd512(__m512d& acc0, __m512d& acc1, const double* mp,
                          __m512d x0, __m512d x0_sw, __m512d x1, __m512d x1_sw){
    const __m512d ones = _mm512_set1_pd(1.0);
    const __m512d mv   = _mm512_load_pd(mp);
    const __m512d m_re = _mm512_unpacklo_pd(mv, mv);  // (m0.re, m0.re, m1.re, m1.re, ...)
    const __m512d m_im = _mm512_unpackhi_pd(mv, mv);  // (m0.im, m0.im, m1.im, m1.im, ...)
    const __m512d p0 = _mm512_mul_pd(m_re, x0);
    acc0 = _mm512_add_pd(acc0, _mm512_fmaddsub_pd(p0, ones, _mm512_mul_pd(m_im, x0_sw)));
    const __m512d p1 = _mm512_mul_pd(m_re, x1);
    acc1 = _mm512_add_pd(acc1, _mm512_fmaddsub_pd(p1, ones, _mm512_mul_pd(m_im, x1_sw)));
  }

  // Double-precision AVX-512 kernel: OUT0 += M * IN0 and OUT1 += M * IN1.
  //
  // Layout, as read by the loads below (all complex values interleaved re, im):
  //   M_         : 8 columns of 8 complex rows, column-major, 16 doubles per
  //                column (128 doubles in total); must be 64-byte aligned.
  //   IN0, IN1   : 8 complex inputs (16 doubles); element c multiplies column c.
  //   OUT0, OUT1 : 8 complex accumulators (16 doubles); must be 64-byte aligned.
  // Effect on the reference arguments: M_ is advanced by 128 doubles; IN0, IN1,
  // OUT0 and OUT1 are left unchanged.
  inline void matmult_8x8x2(double*& M_, double*& IN0, double*& IN1, double*& OUT0, double*& OUT1){
    // accX_lo / accX_hi: complex rows 0-3 / 4-7 of OUT0 (X=0) or OUT1 (X=1).
    __m512d acc0_lo = _mm512_load_pd(OUT0);
    __m512d acc0_hi = _mm512_load_pd(OUT0+8);
    __m512d acc1_lo = _mm512_load_pd(OUT1);
    __m512d acc1_hi = _mm512_load_pd(OUT1+8);

    for(int col=0; col<8; ++col){
      // Input element `col` replicated into all four complex slots, and its
      // re/im-swapped copy (0x55 swaps the two doubles of every 128-bit lane).
      const __m512d x0    = _mm512_broadcast_f64x4(_mm256_broadcast_pd((const __m128d*)(IN0 + 2*col)));
      const __m512d x1    = _mm512_broadcast_f64x4(_mm256_broadcast_pd((const __m128d*)(IN1 + 2*col)));
      const __m512d x0_sw = _mm512_permute_pd(x0, 0x55);
      const __m512d x1_sw = _mm512_permute_pd(x1, 0x55);

      cmac2_pd512(acc0_lo, acc1_lo, M_,   x0, x0_sw, x1, x1_sw);  // rows 0-3
      cmac2_pd512(acc0_hi, acc1_hi, M_+8, x0, x0_sw, x1, x1_sw);  // rows 4-7
      M_ += 16;  // next column
    }

    _mm512_store_pd(OUT0,   acc0_lo);
    _mm512_store_pd(OUT0+8, acc0_hi);
    _mm512_store_pd(OUT1,   acc1_lo);
    _mm512_store_pd(OUT1+8, acc1_hi);
  }
#elif __AVX__
  // For one pair of complex rows of one matrix column:
  //   acc0 += m * x0,  acc1 += m * x1   (complex, per lane pair)
  // m is the two complex entries at mp (32-byte aligned). x* hold a complex
  // input (re, im) in both 128-bit halves; x*_sw hold it as (im, re).
  // _mm256_addsub_pd subtracts in even (real) lanes and adds in odd
  // (imaginary) lanes, which completes the complex product.
  inline void cmac2_pd256(__m256d& acc0, __m256d& acc1, const double* mp,
                          __m256d x0, __m256d x0_sw, __m256d x1, __m256d x1_sw){
    const __m256d mv   = _mm256_load_pd(mp);
    const __m256d m_re = _mm256_unpacklo_pd(mv, mv);  // (m0.re, m0.re, m1.re, m1.re)
    const __m256d m_im = _mm256_unpackhi_pd(mv, mv);  // (m0.im, m0.im, m1.im, m1.im)
    const __m256d p0 = _mm256_mul_pd(m_re, x0);
    acc0 = _mm256_add_pd(acc0, _mm256_addsub_pd(p0, _mm256_mul_pd(m_im, x0_sw)));
    const __m256d p1 = _mm256_mul_pd(m_re, x1);
    acc1 = _mm256_add_pd(acc1, _mm256_addsub_pd(p1, _mm256_mul_pd(m_im, x1_sw)));
  }

  // Double-precision AVX kernel: OUT0 += M * IN0 and OUT1 += M * IN1.
  //
  // Layout, as read by the loads below (all complex values interleaved re, im):
  //   M_         : 8 columns of 8 complex rows, column-major, 16 doubles per
  //                column (128 doubles in total); must be 32-byte aligned.
  //   IN0, IN1   : 8 complex inputs (16 doubles); element c multiplies column c.
  //   OUT0, OUT1 : 8 complex accumulators (16 doubles); must be 32-byte aligned.
  // Effect on the reference arguments: M_ is advanced by 128 doubles; IN0, IN1,
  // OUT0 and OUT1 are left unchanged.
  inline void matmult_8x8x2(double*& M_, double*& IN0, double*& IN1, double*& OUT0, double*& OUT1){
    // accX_RR: X selects OUT0/OUT1, RR names the two complex rows held.
    __m256d acc0_01 = _mm256_load_pd(OUT0);
    __m256d acc0_23 = _mm256_load_pd(OUT0+4);
    __m256d acc0_45 = _mm256_load_pd(OUT0+8);
    __m256d acc0_67 = _mm256_load_pd(OUT0+12);
    __m256d acc1_01 = _mm256_load_pd(OUT1);
    __m256d acc1_23 = _mm256_load_pd(OUT1+4);
    __m256d acc1_45 = _mm256_load_pd(OUT1+8);
    __m256d acc1_67 = _mm256_load_pd(OUT1+12);

    for(int col=0; col<8; ++col){
      // Input element `col` in both halves, and its re/im-swapped copy
      // (0x5 swaps the two doubles of each 128-bit lane).
      const __m256d x0    = _mm256_broadcast_pd((const __m128d*)(IN0 + 2*col));
      const __m256d x1    = _mm256_broadcast_pd((const __m128d*)(IN1 + 2*col));
      const __m256d x0_sw = _mm256_permute_pd(x0, 0x5);
      const __m256d x1_sw = _mm256_permute_pd(x1, 0x5);

      cmac2_pd256(acc0_01, acc1_01, M_,    x0, x0_sw, x1, x1_sw);
      cmac2_pd256(acc0_23, acc1_23, M_+4,  x0, x0_sw, x1, x1_sw);
      cmac2_pd256(acc0_45, acc1_45, M_+8,  x0, x0_sw, x1, x1_sw);
      cmac2_pd256(acc0_67, acc1_67, M_+12, x0, x0_sw, x1, x1_sw);
      M_ += 16;  // next column
    }

    _mm256_store_pd(OUT0,    acc0_01);
    _mm256_store_pd(OUT0+4,  acc0_23);
    _mm256_store_pd(OUT0+8,  acc0_45);
    _mm256_store_pd(OUT0+12, acc0_67);
    _mm256_store_pd(OUT1,    acc1_01);
    _mm256_store_pd(OUT1+4,  acc1_23);
    _mm256_store_pd(OUT1+8,  acc1_45);
    _mm256_store_pd(OUT1+12, acc1_67);
  }
#else
  // Portable fallback: OUT0 += M * IN0 and OUT1 += M * IN1.
  //
  // Layout, as read by the indexing below (all complex values interleaved re, im):
  //   M_         : 8 columns of 8 complex rows, column-major; column c, row r is
  //                at M_[16*c + 2*r] (re) and M_[16*c + 2*r + 1] (im); 128 reals.
  //   IN0, IN1   : 8 complex inputs; IN*[2*c], IN*[2*c+1] multiplies column c.
  //   OUT0, OUT1 : 8 complex accumulators (16 reals each).
  // Effect on the reference arguments, matching the SIMD variants of this
  // function: M_ is advanced by 128 reals (past the matrix it consumed) and
  // IN0, IN1, OUT0, OUT1 are left unchanged. Local cursors are used for the
  // walk so the caller-visible pointer state does not depend on which ISA
  // branch was compiled.
  inline void matmult_8x8x2(real_t*& M_, real_t*& IN0, real_t*& IN1, real_t*& OUT0, real_t*& OUT1){
    for(int i1=0;i1<8;i1+=2){
      // This pass accumulates complex rows i1 and i1+1 of both outputs.
      const real_t* M    = M_ + 2*i1;  // column 0, row i1
      const real_t* IN0_ = IN0;
      const real_t* IN1_ = IN1;
      real_t* OUT0_ = OUT0 + 2*i1;
      real_t* OUT1_ = OUT1 + 2*i1;

      // out_regXYZ: X = output (0/1), Y = row offset (0/1), Z = re(0)/im(1).
      real_t out_reg000=OUT0_[0], out_reg001=OUT0_[1];
      real_t out_reg010=OUT0_[2], out_reg011=OUT0_[3];
      real_t out_reg100=OUT1_[0], out_reg101=OUT1_[1];
      real_t out_reg110=OUT1_[2], out_reg111=OUT1_[3];

      for(int i2=0;i2<8;i2+=2){
        // m_regXYZ: X = column offset (0/1), Y = row offset (0/1), Z = re/im.
        const real_t m_reg000=M[ 0], m_reg001=M[ 1];
        const real_t m_reg010=M[ 2], m_reg011=M[ 3];
        const real_t m_reg100=M[16], m_reg101=M[17];
        const real_t m_reg110=M[18], m_reg111=M[19];

        // in_regXYZ: X = input (0/1), Y = column offset (0/1), Z = re/im.
        const real_t in_reg000=IN0_[0], in_reg001=IN0_[1];
        const real_t in_reg010=IN0_[2], in_reg011=IN0_[3];
        const real_t in_reg100=IN1_[0], in_reg101=IN1_[1];
        const real_t in_reg110=IN1_[2], in_reg111=IN1_[3];

        out_reg000 += m_reg000*in_reg000 - m_reg001*in_reg001;
        out_reg001 += m_reg000*in_reg001 + m_reg001*in_reg000;
        out_reg010 += m_reg010*in_reg000 - m_reg011*in_reg001;
        out_reg011 += m_reg010*in_reg001 + m_reg011*in_reg000;

        out_reg000 += m_reg100*in_reg010 - m_reg101*in_reg011;
        out_reg001 += m_reg100*in_reg011 + m_reg101*in_reg010;
        out_reg010 += m_reg110*in_reg010 - m_reg111*in_reg011;
        out_reg011 += m_reg110*in_reg011 + m_reg111*in_reg010;

        out_reg100 += m_reg000*in_reg100 - m_reg001*in_reg101;
        out_reg101 += m_reg000*in_reg101 + m_reg001*in_reg100;
        out_reg110 += m_reg010*in_reg100 - m_reg011*in_reg101;
        out_reg111 += m_reg010*in_reg101 + m_reg011*in_reg100;

        out_reg100 += m_reg100*in_reg110 - m_reg101*in_reg111;
        out_reg101 += m_reg100*in_reg111 + m_reg101*in_reg110;
        out_reg110 += m_reg110*in_reg110 - m_reg111*in_reg111;
        out_reg111 += m_reg110*in_reg111 + m_reg111*in_reg110;

        M+=32; // Jump to (column+2), same rows.
        IN0_+=4;
        IN1_+=4;
      }
      OUT0_[0]=out_reg000; OUT0_[1]=out_reg001;
      OUT0_[2]=out_reg010; OUT0_[3]=out_reg011;
      OUT1_[0]=out_reg100; OUT1_[1]=out_reg101;
      OUT1_[2]=out_reg110; OUT1_[3]=out_reg111;
    }
    M_ += 128; // The whole 8x8 complex matrix has been consumed.
  }
#endif
#endif
}
#endif