[AArch64] Add SME implementation of TransposeWxH

We can make use of the ZA tile register to do the transpose without any
explicit permute instructions: just load the tile horizontally and store
it vertically.

Change-Id: I1c31e89af52a408e3491e62d6c9e6fee41b1b80a
Reviewed-on: https://chromium-review.googlesource.com/c/libyuv/libyuv/+/5703587
Reviewed-by: Frank Barchard <fbarchard@chromium.org>
This commit is contained in:
George Steed 2024-05-22 15:46:15 +01:00 committed by Frank Barchard
parent a4ccf9940e
commit 15ecca81f7
3 changed files with 99 additions and 1 deletions

View File

@ -26,6 +26,14 @@ extern "C" {
#if defined(__native_client__)
#define LIBYUV_DISABLE_NEON
#endif
// clang >= 19.0.0 required for SME
#if defined(__clang__) && defined(__aarch64__) && !defined(LIBYUV_DISABLE_SME)
#if __clang_major__ < 19
#define LIBYUV_DISABLE_SME
#endif
#endif
// MemorySanitizer does not support assembly code yet. http://crbug.com/344505
#if defined(__has_feature)
#if __has_feature(memory_sanitizer) && !defined(LIBYUV_DISABLE_NEON)
@ -66,6 +74,10 @@ extern "C" {
#define HAS_TRANSPOSE4X4_32_NEON
#endif
#if !defined(LIBYUV_DISABLE_SME) && defined(__aarch64__)
#define HAS_TRANSPOSEWXH_SME
#endif
#if !defined(LIBYUV_DISABLE_MSA) && defined(__mips_msa)
#define HAS_TRANSPOSEWX16_MSA
#define HAS_TRANSPOSEUVWX16_MSA
@ -103,6 +115,12 @@ void TransposeWx16_NEON(const uint8_t* src,
uint8_t* dst,
int dst_stride,
int width);
void TransposeWxH_SME(const uint8_t* src,
int src_stride,
uint8_t* dst,
int dst_stride,
int width,
int height);
void TransposeWx8_SSSE3(const uint8_t* src,
int src_stride,
uint8_t* dst,

View File

@ -31,6 +31,10 @@ void TransposePlane(const uint8_t* src,
int width,
int height) {
int i = height;
#if defined(HAS_TRANSPOSEWXH_SME)
void (*TransposeWxH)(const uint8_t* src, int src_stride, uint8_t* dst,
int dst_stride, int width, int height) = nullptr;
#endif
#if defined(HAS_TRANSPOSEWX16_MSA) || defined(HAS_TRANSPOSEWX16_LSX) || \
defined(HAS_TRANSPOSEWX16_NEON)
void (*TransposeWx16)(const uint8_t* src, int src_stride, uint8_t* dst,
@ -56,6 +60,11 @@ void TransposePlane(const uint8_t* src,
}
}
#endif
#if defined(HAS_TRANSPOSEWXH_SME)
if (TestCpuFlag(kCpuHasSME)) {
TransposeWxH = TransposeWxH_SME;
}
#endif
#if defined(HAS_TRANSPOSEWX8_SSSE3)
if (TestCpuFlag(kCpuHasSSSE3)) {
TransposeWx8 = TransposeWx8_Any_SSSE3;
@ -89,6 +98,12 @@ void TransposePlane(const uint8_t* src,
}
#endif
#if defined(HAS_TRANSPOSEWXH_SME)
if (TransposeWxH) {
TransposeWxH(src, src_stride, dst, dst_stride, width, height);
return;
}
#endif
#if defined(HAS_TRANSPOSEWX16_MSA) || defined(HAS_TRANSPOSEWX16_LSX) || \
defined(HAS_TRANSPOSEWX16_NEON)
// Work across the source in 16x16 tiles

View File

@ -20,7 +20,72 @@ extern "C" {
#if !defined(LIBYUV_DISABLE_SME) && defined(__aarch64__)
// TODO: Port rotate kernels to SME.
__arm_locally_streaming __arm_new("za") void TransposeWxH_SME(
const uint8_t* src,
int src_stride,
uint8_t* dst,
int dst_stride,
int width,
int height) {
int vl;
asm("cntb %x0" : "=r"(vl));
do {
const uint8_t* src2 = src;
uint8_t* dst2 = dst;
// Process up to VL elements per iteration of the inner loop.
int block_height = height > vl ? vl : height;
int width2 = width;
do {
const uint8_t* src3 = src2;
// Process up to VL elements per iteration of the inner loop.
int block_width = width2 > vl ? vl : width2;
asm volatile(
"mov w12, #0 \n"
// Create a predicate to handle loading partial rows.
"whilelt p0.b, wzr, %w[block_width] \n"
// Load H <= VL rows into ZA0.
"1: \n"
"ld1b {za0h.b[w12, 0]}, p0/z, [%[src3]] \n"
"add %[src3], %[src3], %[src_stride] \n"
"add w12, w12, #1 \n"
"cmp w12, %w[block_height] \n"
"b.ne 1b \n"
// Create a predicate to handle storing partial columns.
"whilelt p0.b, wzr, %w[block_height] \n"
"mov w12, #0 \n"
// Store W <= VL columns from ZA0.
"2: \n"
"st1b {za0v.b[w12, 0]}, p0, [%[dst2]] \n"
"add %[dst2], %[dst2], %[dst_stride] \n"
"add w12, w12, #1 \n"
"cmp w12, %w[block_width] \n"
"b.ne 2b \n"
: [src3] "+r"(src3), // %[src3]
[dst2] "+r"(dst2) // %[dst2]
: [src_stride] "r"((ptrdiff_t)src_stride), // %[src_stride]
[dst_stride] "r"((ptrdiff_t)dst_stride), // %[dst_stride]
[block_width] "r"(block_width), // %[block_width]
[block_height] "r"(block_height) // %[block_height]
: "cc", "memory", "p0", "w12", "za");
src2 += vl;
width2 -= vl;
} while (width2 > 0);
src += vl * src_stride;
dst += vl;
height -= vl;
} while (height > 0);
}
#endif // !defined(LIBYUV_DISABLE_SME) && defined(__aarch64__)