//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass interleaves around sext/zext/trunc instructions. MVE does not have
10 // a single sext/zext or trunc instruction that takes the bottom half of a
11 // vector and extends to a full width, like NEON has with MOVL. Instead it is
12 // expected that this happens through top/bottom instructions. So the MVE
13 // equivalent VMOVLT/B instructions take either the even or odd elements of the
14 // input and extend them to the larger type, producing a vector with half the
15 // number of elements each of double the bitwidth. As there is no simple
16 // instruction, we often have to turn sext/zext/trunc into a series of lane
17 // moves (or stack loads/stores, which we do not do yet).
18 //
19 // This pass takes vector code that starts at truncs, looks for interconnected
20 // blobs of operations that end with sext/zext (or constants/splats) of the
21 // form:
22 //   %sa = sext v8i16 %a to v8i32
23 //   %sb = sext v8i16 %b to v8i32
24 //   %add = add v8i32 %sa, %sb
25 //   %r = trunc %add to v8i16
// And adds shuffles to allow the use of VMOVL/VMOVN instructions:
27 //   %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
28 //   %sa = sext v8i16 %sha to v8i32
29 //   %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
30 //   %sb = sext v8i16 %shb to v8i32
31 //   %add = add v8i32 %sa, %sb
32 //   %r = trunc %add to v8i16
33 //   %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
34 // Which can then be split and lowered to MVE instructions efficiently:
35 //   %sa_b = VMOVLB.s16 %a
36 //   %sa_t = VMOVLT.s16 %a
37 //   %sb_b = VMOVLB.s16 %b
38 //   %sb_t = VMOVLT.s16 %b
39 //   %add_b = VADD.i32 %sa_b, %sb_b
40 //   %add_t = VADD.i32 %sa_t, %sb_t
41 //   %r = VMOVNT.i16 %add_b, %add_t
42 //
43 //===----------------------------------------------------------------------===//
44 
45 #include "ARM.h"
46 #include "ARMBaseInstrInfo.h"
47 #include "ARMSubtarget.h"
48 #include "llvm/Analysis/TargetTransformInfo.h"
49 #include "llvm/CodeGen/TargetLowering.h"
50 #include "llvm/CodeGen/TargetPassConfig.h"
51 #include "llvm/CodeGen/TargetSubtargetInfo.h"
52 #include "llvm/IR/BasicBlock.h"
53 #include "llvm/IR/Constant.h"
54 #include "llvm/IR/Constants.h"
55 #include "llvm/IR/DerivedTypes.h"
56 #include "llvm/IR/Function.h"
57 #include "llvm/IR/IRBuilder.h"
58 #include "llvm/IR/InstIterator.h"
59 #include "llvm/IR/InstrTypes.h"
60 #include "llvm/IR/Instruction.h"
61 #include "llvm/IR/Instructions.h"
62 #include "llvm/IR/IntrinsicInst.h"
63 #include "llvm/IR/Intrinsics.h"
64 #include "llvm/IR/IntrinsicsARM.h"
65 #include "llvm/IR/PatternMatch.h"
66 #include "llvm/IR/Type.h"
67 #include "llvm/IR/Value.h"
68 #include "llvm/InitializePasses.h"
69 #include "llvm/Pass.h"
70 #include "llvm/Support/Casting.h"
71 #include <algorithm>
72 #include <cassert>
73 
74 using namespace llvm;
75 
// Tag for LLVM_DEBUG output; also used as the pass's registration name in
// INITIALIZE_PASS below.
#define DEBUG_TYPE "mve-laneinterleave"

// Hidden command-line switch, on by default. When false, runOnFunction returns
// immediately without changing anything.
// NOTE(review): this has external linkage; consider making it `static` if no
// other translation unit refers to EnableInterleave.
cl::opt<bool> EnableInterleave(
    "enable-mve-interleave", cl::Hidden, cl::init(true),
    cl::desc("Enable interleave MVE vector operation lowering"));
81 
82 namespace {
83 
84 class MVELaneInterleaving : public FunctionPass {
85 public:
86   static char ID; // Pass identification, replacement for typeid
87 
88   explicit MVELaneInterleaving() : FunctionPass(ID) {
89     initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
90   }
91 
92   bool runOnFunction(Function &F) override;
93 
94   StringRef getPassName() const override { return "MVE lane interleaving"; }
95 
96   void getAnalysisUsage(AnalysisUsage &AU) const override {
97     AU.setPreservesCFG();
98     AU.addRequired<TargetPassConfig>();
99     FunctionPass::getAnalysisUsage(AU);
100   }
101 };
102 
103 } // end anonymous namespace
104 
105 char MVELaneInterleaving::ID = 0;
106 
107 INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
108                 false)
109 
110 Pass *llvm::createMVELaneInterleavingPass() {
111   return new MVELaneInterleaving();
112 }
113 
114 static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
115                                      SmallSetVector<Instruction *, 4> &Truncs) {
116   // This is not always beneficial to transform. Exts can be incorporated into
117   // loads, Truncs can be folded into stores.
118   // Truncs are usually the same number of instructions,
119   //  VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving
120   // Exts are unfortunately more instructions in the general case:
121   //  A=VLDRH.32; B=VLDRH.32;
122   // vs with interleaving:
123   //  T=VLDRH.16; A=VMOVNB T; B=VMOVNT T
124   // But those VMOVL may be folded into a VMULL.
125 
126   // But expensive extends/truncs are always good to remove.
127   for (auto *E : Exts)
128     if (!isa<LoadInst>(E->getOperand(0))) {
129       LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
130       return true;
131     }
132   for (auto *T : Truncs)
133     if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
134       LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
135       return true;
136     }
137 
138   // Otherwise, we know we have a load(ext), see if any of the Extends are a
139   // vmull. This is a simple heuristic and certainly not perfect.
140   for (auto *E : Exts) {
141     if (!E->hasOneUse() ||
142         cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
143       LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
144       return false;
145     }
146   }
147   return true;
148 }
149 
150 static bool tryInterleave(Instruction *Start,
151                           SmallPtrSetImpl<Instruction *> &Visited) {
152   LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");
153   auto *VT = cast<FixedVectorType>(Start->getType());
154 
155   if (!isa<Instruction>(Start->getOperand(0)))
156     return false;
157 
158   // Look for connected operations starting from Ext's, terminating at Truncs.
159   std::vector<Instruction *> Worklist;
160   Worklist.push_back(Start);
161   Worklist.push_back(cast<Instruction>(Start->getOperand(0)));
162 
163   SmallSetVector<Instruction *, 4> Truncs;
164   SmallSetVector<Instruction *, 4> Exts;
165   SmallSetVector<Use *, 4> OtherLeafs;
166   SmallSetVector<Instruction *, 4> Ops;
167 
168   while (!Worklist.empty()) {
169     Instruction *I = Worklist.back();
170     Worklist.pop_back();
171 
172     switch (I->getOpcode()) {
173     // Truncs
174     case Instruction::Trunc:
175       if (Truncs.count(I))
176         continue;
177       Truncs.insert(I);
178       Visited.insert(I);
179       break;
180 
181     // Extend leafs
182     case Instruction::SExt:
183     case Instruction::ZExt:
184       if (Exts.count(I))
185         continue;
186       for (auto *Use : I->users())
187         Worklist.push_back(cast<Instruction>(Use));
188       Exts.insert(I);
189       break;
190 
191     // Binary/tertiary ops
192     case Instruction::Add:
193     case Instruction::Sub:
194     case Instruction::Mul:
195     case Instruction::AShr:
196     case Instruction::LShr:
197     case Instruction::Shl:
198     case Instruction::ICmp:
199     case Instruction::Select:
200       if (Ops.count(I))
201         continue;
202       Ops.insert(I);
203 
204       for (Use &Op : I->operands()) {
205         if (isa<Instruction>(Op))
206           Worklist.push_back(cast<Instruction>(&Op));
207         else
208           OtherLeafs.insert(&Op);
209       }
210 
211       for (auto *Use : I->users())
212         Worklist.push_back(cast<Instruction>(Use));
213       break;
214 
215     case Instruction::ShuffleVector:
216       // A shuffle of a splat is a splat.
217       if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
218         continue;
219       LLVM_FALLTHROUGH;
220 
221     default:
222       LLVM_DEBUG(dbgs() << "  Unhandled instruction: " << *I << "\n");
223       return false;
224     }
225   }
226 
227   if (Exts.empty() && OtherLeafs.empty())
228     return false;
229 
230   LLVM_DEBUG({
231     dbgs() << "Found group:\n  Exts:";
232     for (auto *I : Exts)
233       dbgs() << "  " << *I << "\n";
234     dbgs() << "  Ops:";
235     for (auto *I : Ops)
236       dbgs() << "  " << *I << "\n";
237     dbgs() << "  OtherLeafs:";
238     for (auto *I : OtherLeafs)
239       dbgs() << "  " << *I << "\n";
240     dbgs() << "Truncs:";
241     for (auto *I : Truncs)
242       dbgs() << "  " << *I << "\n";
243   });
244 
245   assert(!Truncs.empty() && "Expected some truncs");
246 
247   // Check types
248   unsigned NumElts = VT->getNumElements();
249   unsigned BaseElts = VT->getScalarSizeInBits() == 16
250                           ? 8
251                           : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
252   if (BaseElts == 0 || NumElts % BaseElts != 0) {
253     LLVM_DEBUG(dbgs() << "  Type is unsupported\n");
254     return false;
255   }
256   if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
257       VT->getScalarSizeInBits() * 2) {
258     LLVM_DEBUG(dbgs() << "  Type not double sized\n");
259     return false;
260   }
261   for (Instruction *I : Exts)
262     if (I->getOperand(0)->getType() != VT) {
263       LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
264       return false;
265     }
266   for (Instruction *I : Truncs)
267     if (I->getType() != VT) {
268       LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
269       return false;
270     }
271 
272   // Check that it looks beneficial
273   if (!isProfitableToInterleave(Exts, Truncs))
274     return false;
275 
276   // Create new shuffles around the extends / truncs / other leaves.
277   IRBuilder<> Builder(Start);
278 
279   SmallVector<int, 16> LeafMask;
280   SmallVector<int, 16> TruncMask;
281   // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7   8, 10, 12, 14,  9, 11, 13, 15
282   // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7   8, 12,  9, 13, 10, 14, 11, 15
283   for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
284     for (unsigned i = 0; i < BaseElts / 2; i++)
285       LeafMask.push_back(Base + i * 2);
286     for (unsigned i = 0; i < BaseElts / 2; i++)
287       LeafMask.push_back(Base + i * 2 + 1);
288   }
289   for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
290     for (unsigned i = 0; i < BaseElts / 2; i++) {
291       TruncMask.push_back(Base + i);
292       TruncMask.push_back(Base + i + BaseElts / 2);
293     }
294   }
295 
296   for (Instruction *I : Exts) {
297     LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
298     Builder.SetInsertPoint(I);
299     Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
300     bool Sext = isa<SExtInst>(I);
301     Value *Ext = Sext ? Builder.CreateSExt(Shuffle, I->getType())
302                       : Builder.CreateZExt(Shuffle, I->getType());
303     I->replaceAllUsesWith(Ext);
304     LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
305   }
306 
307   for (Use *I : OtherLeafs) {
308     LLVM_DEBUG(dbgs() << "Replacing leaf " << *I << "\n");
309     Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
310     Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
311     I->getUser()->setOperand(I->getOperandNo(), Shuffle);
312     LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
313   }
314 
315   for (Instruction *I : Truncs) {
316     LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");
317 
318     Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
319     Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
320     I->replaceAllUsesWith(Shuf);
321     cast<Instruction>(Shuf)->setOperand(0, I);
322 
323     LLVM_DEBUG(dbgs() << "  with " << *Shuf << "\n");
324   }
325 
326   return true;
327 }
328 
329 bool MVELaneInterleaving::runOnFunction(Function &F) {
330   if (!EnableInterleave)
331     return false;
332   auto &TPC = getAnalysis<TargetPassConfig>();
333   auto &TM = TPC.getTM<TargetMachine>();
334   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
335   if (!ST->hasMVEIntegerOps())
336     return false;
337 
338   bool Changed = false;
339 
340   SmallPtrSet<Instruction *, 16> Visited;
341   for (Instruction &I : reverse(instructions(F))) {
342     if (I.getType()->isVectorTy() &&
343         (isa<TruncInst>(I) || isa<FPTruncInst>(I)) && !Visited.count(&I))
344       Changed |= tryInterleave(&I, Visited);
345   }
346 
347   return Changed;
348 }
349