//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass interleaves around sext/zext/trunc instructions. MVE does not have
// a single sext/zext or trunc instruction that takes the bottom half of a
// vector and extends to a full width, like NEON has with MOVL. Instead it is
// expected that this happens through top/bottom instructions. So the MVE
// equivalent VMOVLT/B instructions take either the even or odd elements of the
// input and extend them to the larger type, producing a vector with half the
// number of elements each of double the bitwidth. As there is no simple
// instruction, we often have to turn sext/zext/trunc into a series of lane
// moves (or stack loads/stores, which we do not do yet).
//
// This pass takes vector code that starts at truncs, looks for interconnected
// blobs of operations that end with sext/zext (or constants/splats) of the
// form:
//   %sa = sext v8i16 %a to v8i32
//   %sb = sext v8i16 %b to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
// And adds shuffles to allow the use of VMOVL/VMOVN instructions:
//   %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sa = sext v8i16 %sha to v8i32
//   %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sb = sext v8i16 %shb to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
//   %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
// Which can then be split and lowered to MVE instructions efficiently:
//   %sa_b = VMOVLB.s16 %a
//   %sa_t = VMOVLT.s16 %a
//   %sb_b = VMOVLB.s16 %b
//   %sb_t = VMOVLT.s16 %b
//   %add_b = VADD.i32 %sa_b, %sb_b
//   %add_t = VADD.i32 %sa_t, %sb_t
//   %r = VMOVNT.i16 %add_b, %add_t
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "mve-laneinterleave"

// File-local option; cl::opts at file scope should be static to avoid
// external-linkage symbol clashes between targets.
static cl::opt<bool> EnableInterleave(
    "enable-mve-interleave", cl::Hidden, cl::init(true),
    cl::desc("Enable interleave MVE vector operation lowering"));

namespace {

class MVELaneInterleaving : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVELaneInterleaving() : FunctionPass(ID) {
    initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override { return "MVE lane interleaving"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    FunctionPass::getAnalysisUsage(AU);
  }
};

} // end anonymous namespace

char MVELaneInterleaving::ID = 0;

INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
                false)

Pass *llvm::createMVELaneInterleavingPass() {
  return new MVELaneInterleaving();
}

/// Heuristically decide whether adding interleaving shuffles around the group
/// of extends (Exts) and truncs (Truncs) is likely a win.
static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
                                     SmallSetVector<Instruction *, 4> &Truncs) {
  // This is not always beneficial to transform. Exts can be incorporated into
  // loads, Truncs can be folded into stores.
  // Truncs are usually the same number of instructions,
  //  VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving
  // Exts are unfortunately more instructions in the general case:
  //  A=VLDRH.32; B=VLDRH.32;
  // vs with interleaving:
  //  T=VLDRH.16; A=VMOVNB T; B=VMOVNT T
  // But those VMOVL may be folded into a VMULL.

  // But expensive extends/truncs are always good to remove. An ext whose
  // operand is not a load cannot be folded into a loads, so it will need a
  // real VMOVL, which interleaving turns into a cheap shuffle.
  for (auto *E : Exts)
    if (!isa<LoadInst>(E->getOperand(0))) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
      return true;
    }
  // Likewise a trunc whose only user is not a store needs a real VMOVN.
  for (auto *T : Truncs)
    if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
      return true;
    }

  // Otherwise, we know we have a load(ext), see if any of the Extends are a
  // vmull. This is a simple heuristic and certainly not perfect.
  for (auto *E : Exts) {
    if (!E->hasOneUse() ||
        cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
      LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
      return false;
    }
  }
  return true;
}

/// Starting from the integer trunc Start, walk the def-use graph to collect
/// the connected blob of operations it belongs to, terminated by truncs on
/// one side and sext/zext (or non-instruction leaf operands) on the other.
/// If the group has supported, consistent types and looks profitable, insert
/// a de-interleaving shuffle after each leaf and a re-interleaving shuffle
/// after each trunc, so the extends/truncs can lower to VMOVLB/T and
/// VMOVNB/T. Truncs that were examined are recorded in Visited so the caller
/// does not restart from them. Returns true if the IR was changed.
static bool tryInterleave(Instruction *Start,
                          SmallPtrSetImpl<Instruction *> &Visited) {
  LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");
  auto *VT = cast<FixedVectorType>(Start->getType());

  if (!isa<Instruction>(Start->getOperand(0)))
    return false;

  // Look for connected operations starting from Ext's, terminating at Truncs.
  std::vector<Instruction *> Worklist;
  Worklist.push_back(Start);
  Worklist.push_back(cast<Instruction>(Start->getOperand(0)));

  SmallSetVector<Instruction *, 4> Truncs;
  SmallSetVector<Instruction *, 4> Exts;
  SmallSetVector<Use *, 4> OtherLeafs;
  SmallSetVector<Instruction *, 4> Ops;

  while (!Worklist.empty()) {
    Instruction *I = Worklist.back();
    Worklist.pop_back();

    switch (I->getOpcode()) {
    // Truncs terminate the blob downwards; their users are not followed.
    case Instruction::Trunc:
      if (Truncs.count(I))
        continue;
      Truncs.insert(I);
      Visited.insert(I);
      break;

    // Extend leafs terminate the blob upwards; only their users are followed.
    case Instruction::SExt:
    case Instruction::ZExt:
      if (Exts.count(I))
        continue;
      for (auto *U : I->users())
        Worklist.push_back(cast<Instruction>(U));
      Exts.insert(I);
      break;

    // Binary/tertiary ops: follow both operands (or record non-instruction
    // operands as leaves that need shuffling) and users.
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::AShr:
    case Instruction::LShr:
    case Instruction::Shl:
    case Instruction::ICmp:
    case Instruction::Select:
      if (Ops.count(I))
        continue;
      Ops.insert(I);

      for (Use &Op : I->operands()) {
        if (isa<Instruction>(Op))
          Worklist.push_back(cast<Instruction>(&Op));
        else
          OtherLeafs.insert(&Op);
      }

      for (auto *U : I->users())
        Worklist.push_back(cast<Instruction>(U));
      break;

    case Instruction::ShuffleVector:
      // A shuffle of a splat is a splat. Lane order is irrelevant for
      // splats, so they need no shuffling and can be skipped.
      if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
        continue;
      LLVM_FALLTHROUGH;

    default:
      LLVM_DEBUG(dbgs() << "  Unhandled instruction: " << *I << "\n");
      return false;
    }
  }

  // Without any extends or leaf operands there is nothing to shuffle.
  if (Exts.empty() && OtherLeafs.empty())
    return false;

  LLVM_DEBUG({
    dbgs() << "Found group:\n Exts:";
    for (auto *I : Exts)
      dbgs() << " " << *I << "\n";
    dbgs() << " Ops:";
    for (auto *I : Ops)
      dbgs() << " " << *I << "\n";
    dbgs() << " OtherLeafs:";
    for (auto *I : OtherLeafs)
      dbgs() << " " << *I->get() << "\n";
    dbgs() << "Truncs:";
    for (auto *I : Truncs)
      dbgs() << " " << *I << "\n";
  });

  assert(!Truncs.empty() && "Expected some truncs");

  // Check types. BaseElts is the number of elements in one 128-bit MVE
  // register of the narrow type: v8i16 or v16i8 (wider-than-128bit groups are
  // handled as multiples of it).
  unsigned NumElts = VT->getNumElements();
  unsigned BaseElts = VT->getScalarSizeInBits() == 16
                          ? 8
                          : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
  if (BaseElts == 0 || NumElts % BaseElts != 0) {
    LLVM_DEBUG(dbgs() << "  Type is unsupported\n");
    return false;
  }
  // VMOVL/VMOVN exactly double/halve the lane width.
  if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
      VT->getScalarSizeInBits() * 2) {
    LLVM_DEBUG(dbgs() << "  Type not double sized\n");
    return false;
  }
  for (Instruction *I : Exts)
    if (I->getOperand(0)->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }
  for (Instruction *I : Truncs)
    if (I->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }

  // Check that it looks beneficial.
  if (!isProfitableToInterleave(Exts, Truncs))
    return false;

  // Create new shuffles around the extends / truncs / other leaves.
  IRBuilder<> Builder(Start);

  SmallVector<int, 16> LeafMask;
  SmallVector<int, 16> TruncMask;
  // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7   8, 10, 12, 14,  9, 11, 13, 15
  // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7   8, 12,  9, 13, 10, 14, 11, 15
  // LeafMask de-interleaves (evens then odds, per 128-bit chunk) and
  // TruncMask is its inverse permutation, re-interleaving the result.
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2);
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2 + 1);
  }
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++) {
      TruncMask.push_back(Base + i);
      TruncMask.push_back(Base + i + BaseElts / 2);
    }
  }

  for (Instruction *I : Exts) {
    LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
    Builder.SetInsertPoint(I);
    Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
    bool Sext = isa<SExtInst>(I);
    Value *Ext = Sext ? Builder.CreateSExt(Shuffle, I->getType())
                      : Builder.CreateZExt(Shuffle, I->getType());
    I->replaceAllUsesWith(Ext);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  for (Use *I : OtherLeafs) {
    LLVM_DEBUG(dbgs() << "Replacing leaf " << *I << "\n");
    Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
    Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
    I->getUser()->setOperand(I->getOperandNo(), Shuffle);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  for (Instruction *I : Truncs) {
    LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");

    Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
    Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
    I->replaceAllUsesWith(Shuf);
    // replaceAllUsesWith also rewrote the new shuffle's own input to itself;
    // point it back at the trunc.
    cast<Instruction>(Shuf)->setOperand(0, I);

    LLVM_DEBUG(dbgs() << "  with " << *Shuf << "\n");
  }

  return true;
}

bool MVELaneInterleaving::runOnFunction(Function &F) {
  if (!EnableInterleave)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;

  bool Changed = false;

  SmallPtrSet<Instruction *, 16> Visited;
  // Only integer truncs can start a group: tryInterleave's opcode switch has
  // no FP handling, so starting from an FPTrunc could only ever bail out
  // after wasted analysis. Walk in reverse so groups are found from their
  // bottommost trunc first.
  for (Instruction &I : reverse(instructions(F))) {
    if (I.getType()->isVectorTy() && isa<TruncInst>(I) && !Visited.count(&I))
      Changed |= tryInterleave(&I, Visited);
  }

  return Changed;
}