diff --git a/src/dsp/enc_sse2.c b/src/dsp/enc_sse2.c index b874225c..253b1c7c 100644 --- a/src/dsp/enc_sse2.c +++ b/src/dsp/enc_sse2.c @@ -463,28 +463,36 @@ static void FTransform2(const uint8_t* src, const uint8_t* ref, int16_t* out) { FTransformPass2(&v01h, &v32h, out + 16); } +static void FTransformWHTRow(const int16_t* const in, __m128i* const out) { + const __m128i kMult1 = _mm_set_epi16(0, 0, 0, 0, 1, 1, 1, 1); + const __m128i kMult2 = _mm_set_epi16(0, 0, 0, 0, -1, 1, -1, 1); + const __m128i src0 = _mm_loadl_epi64((__m128i*)&in[0 * 16]); + const __m128i src1 = _mm_loadl_epi64((__m128i*)&in[1 * 16]); + const __m128i src2 = _mm_loadl_epi64((__m128i*)&in[2 * 16]); + const __m128i src3 = _mm_loadl_epi64((__m128i*)&in[3 * 16]); + const __m128i A01 = _mm_unpacklo_epi16(src0, src1); // A0 A1 | ... + const __m128i A23 = _mm_unpacklo_epi16(src2, src3); // A2 A3 | ... + const __m128i B0 = _mm_adds_epi16(A01, A23); // a0 | a1 | ... + const __m128i B1 = _mm_subs_epi16(A01, A23); // a3 | a2 | ... + const __m128i C0 = _mm_unpacklo_epi32(B0, B1); // a0 | a1 | a3 | a2 + const __m128i C1 = _mm_unpacklo_epi32(B1, B0); // a3 | a2 | a0 | a1 + const __m128i D0 = _mm_madd_epi16(C0, kMult1); // out0, out1 + const __m128i D1 = _mm_madd_epi16(C1, kMult2); // out2, out3 + *out = _mm_unpacklo_epi64(D0, D1); +} + static void FTransformWHT(const int16_t* in, int16_t* out) { - int32_t tmp[16]; - int i; - for (i = 0; i < 4; ++i, in += 64) { - const int a0 = (in[0 * 16] + in[2 * 16]); - const int a1 = (in[1 * 16] + in[3 * 16]); - const int a2 = (in[1 * 16] - in[3 * 16]); - const int a3 = (in[0 * 16] - in[2 * 16]); - tmp[0 + i * 4] = a0 + a1; - tmp[1 + i * 4] = a3 + a2; - tmp[2 + i * 4] = a3 - a2; - tmp[3 + i * 4] = a0 - a1; - } + __m128i row0, row1, row2, row3; + FTransformWHTRow(in + 0 * 64, &row0); + FTransformWHTRow(in + 1 * 64, &row1); + FTransformWHTRow(in + 2 * 64, &row2); + FTransformWHTRow(in + 3 * 64, &row3); + { - const __m128i src0 = _mm_loadu_si128((__m128i*)&tmp[0]); - const __m128i src1 = _mm_loadu_si128((__m128i*)&tmp[4]); - const __m128i src2 = _mm_loadu_si128((__m128i*)&tmp[8]); - const __m128i src3 = _mm_loadu_si128((__m128i*)&tmp[12]); - const __m128i a0 = _mm_add_epi32(src0, src2); - const __m128i a1 = _mm_add_epi32(src1, src3); - const __m128i a2 = _mm_sub_epi32(src1, src3); - const __m128i a3 = _mm_sub_epi32(src0, src2); + const __m128i a0 = _mm_add_epi32(row0, row2); + const __m128i a1 = _mm_add_epi32(row1, row3); + const __m128i a2 = _mm_sub_epi32(row1, row3); + const __m128i a3 = _mm_sub_epi32(row0, row2); const __m128i b0 = _mm_srai_epi32(_mm_add_epi32(a0, a1), 1); const __m128i b1 = _mm_srai_epi32(_mm_add_epi32(a3, a2), 1); const __m128i b2 = _mm_srai_epi32(_mm_sub_epi32(a3, a2), 1);