//===-- Elementary operations for aarch64 --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_AARCH64_H
#define LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_AARCH64_H

#include <src/string/memory_utils/elements.h>
#include <stddef.h> // size_t
#include <stdint.h> // uint8_t, uint16_t, uint32_t, uint64_t

#ifdef __ARM_NEON
#include <arm_neon.h>
#endif
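
// AArch64-specific implementations of the fixed-size "element" operations from
// elements.h. The aarch64_memset namespace provides splat-store building
// blocks for memset; the aarch64 namespace provides comparison building blocks
// (Equals / ThreeWayCompare) for bcmp- and memcmp-style routines. Both fall
// back to the scalar elements when NEON is not available.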
namespace __llvm_libc {
namespace aarch64_memset {
#ifdef __ARM_NEON
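// NEON splat-stores: vdup_n_u8 / vdupq_n_u8 broadcast `value` into every lane
// of a 64-bit or 128-bit vector, and a single vst1_u8 / vst1q_u8 writes the
// whole vector to `dst`.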
struct Splat8 {
  static constexpr size_t kSize = 8;
  static void SplatSet(char *dst, const unsigned char value) {
    vst1_u8((uint8_t *)dst, vdup_n_u8(value));
  }
};

struct Splat16 {
  static constexpr size_t kSize = 16;
  static void SplatSet(char *dst, const unsigned char value) {
    vst1q_u8((uint8_t *)dst, vdupq_n_u8(value));
  }
};

using _8 = Splat8;
using _16 = Splat16;
#else
using _8 = __llvm_libc::scalar::_8;
using _16 = Repeated<_8, 2>;
#endif // __ARM_NEON

using _1 = __llvm_libc::scalar::_1;
using _2 = __llvm_libc::scalar::_2;
using _3 = __llvm_libc::scalar::_3;
using _4 = __llvm_libc::scalar::_4;
using _32 = Chained<_16, _16>;
using _64 = Chained<_32, _32>;
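
// `dc zva` zeroes an entire, naturally aligned block of memory (64 bytes when
// DCZID_EL0 reports a 64-byte block, as checked in AArch64ZVA below). The
// instruction always writes zeros, so `value` is ignored and this element is
// only valid for zero fills; `dst` must already be aligned to the block size
// so that the zeroed block starts at `dst`.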
struct ZVA {
  static constexpr size_t kSize = 64;
  static void SplatSet(char *dst, const unsigned char value) {
    asm("dc zva, %[dst]" : : [dst] "r"(dst) : "memory");
  }
};
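
// DCZID_EL0 bits [3:0] encode log2 of the block size in 4-byte words, and bit
// 4 (DZP) is set when `dc zva` is prohibited. `(zva_val & 31) == 4` therefore
// means "enabled, with a 2^4 * 4 = 64-byte block", the only geometry this
// element supports. When the check passes, `dst` is aligned to 64 bytes
// (covering the unaligned head with a regular 64-byte splat) and the rest of
// the range is zeroed with a loop of `dc zva` blocks plus a 64-byte splat
// tail.
//
// Callers are expected to use this only for zero fills, e.g. (an illustrative
// sketch, not the actual libc dispatcher; 64 is just the minimum the composed
// element needs, real dispatchers typically use a larger tuning threshold):
//
//   if (value == 0 && count >= 64 && AArch64ZVA(dst, count))
//     return; // zeroed via dc zva
//   // otherwise fall back to the NEON/scalar splat elements above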
inline static bool AArch64ZVA(char *dst, size_t count) {
  uint64_t zva_val;
  asm("mrs %[zva_val], dczid_el0" : [zva_val] "=r"(zva_val));
  if ((zva_val & 31) != 4)
    return false;
  SplatSet<Align<_64, Arg::_1>::Then<Loop<ZVA, _64>>>(dst, 0, count);
  return true;
}

} // namespace aarch64_memset

namespace aarch64 {

using _1 = __llvm_libc::scalar::_1;
using _2 = __llvm_libc::scalar::_2;
using _3 = __llvm_libc::scalar::_3;
using _4 = __llvm_libc::scalar::_4;
using _8 = __llvm_libc::scalar::_8;
using _16 = __llvm_libc::scalar::_16;

#ifdef __ARM_NEON
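// 32-byte comparison element built on NEON. Both Equals and ThreeWayCompare
// XOR the two 16-byte halves of each operand and fold the 32 bytes of
// differences down to a single 64-bit lane with two rounds of pairwise max
// (vpmaxq_u8); that lane is zero if and only if all 32 bytes match.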
struct N32 {
  static constexpr size_t kSize = 32;
  static bool Equals(const char *lhs, const char *rhs) {
    uint8x16_t l_0 = vld1q_u8((const uint8_t *)lhs);
    uint8x16_t r_0 = vld1q_u8((const uint8_t *)rhs);
    uint8x16_t l_1 = vld1q_u8((const uint8_t *)(lhs + 16));
    uint8x16_t r_1 = vld1q_u8((const uint8_t *)(rhs + 16));
    uint8x16_t temp = vpmaxq_u8(veorq_u8(l_0, r_0), veorq_u8(l_1, r_1));
    uint64_t res =
        vgetq_lane_u64(vreinterpretq_u64_u8(vpmaxq_u8(temp, temp)), 0);
    return res == 0;
  }
  static int ThreeWayCompare(const char *lhs, const char *rhs) {
    uint8x16_t l_0 = vld1q_u8((const uint8_t *)lhs);
    uint8x16_t r_0 = vld1q_u8((const uint8_t *)rhs);
    uint8x16_t l_1 = vld1q_u8((const uint8_t *)(lhs + 16));
    uint8x16_t r_1 = vld1q_u8((const uint8_t *)(rhs + 16));
    uint8x16_t temp = vpmaxq_u8(veorq_u8(l_0, r_0), veorq_u8(l_1, r_1));
    uint64_t res =
        vgetq_lane_u64(vreinterpretq_u64_u8(vpmaxq_u8(temp, temp)), 0);
    if (res == 0)
      return 0;
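    // After two rounds of pairwise max, each byte of `res` summarizes four
    // adjacent source bytes. ctz/8 locates the first non-zero byte of `res`;
    // <<2 converts that to the byte offset of the first differing 4-byte
    // chunk, which is then ordered with the scalar 32-bit comparison.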
    size_t index = (__builtin_ctzl(res) >> 3) << 2;
    uint32_t l = *((const uint32_t *)(lhs + index));
    uint32_t r = *((const uint32_t *)(rhs + index));
    return __llvm_libc::scalar::_4::ScalarThreeWayCompare(l, r);
  }
};

using _32 = N32;
#else
using _32 = __llvm_libc::scalar::_32;
#endif // __ARM_NEON
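
// Example (illustrative only, not the actual libc dispatcher): a memcmp-style
// routine can process one 32-byte block at a time through this element, e.g.
//
//   if (int result = aarch64::_32::ThreeWayCompare(lhs, rhs))
//     return result;
//   lhs += 32;
//   rhs += 32;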

} // namespace aarch64
} // namespace __llvm_libc

#endif // LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_AARCH64_H