//===- MatmulOptimizer.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "polly/MatmulOptimizer.h"
#include "polly/DependenceInfo.h"
#include "polly/Options.h"
#include "polly/ScheduleTreeTransform.h"
#include "polly/ScopInfo.h"
#include "polly/ScopPass.h"
#include "polly/Simplify.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Support/raw_ostream.h"
#include "isl/ctx.h"
#include "isl/schedule_node.h"
#include "isl/schedule_type.h"
#include "isl/union_map.h"
#include "isl/union_set.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>

#define DEBUG_TYPE "polly-opt-isl"

using namespace llvm;
using namespace polly;

namespace llvm {
class Value;
}

static cl::opt<int> LatencyVectorFma(
    "polly-target-latency-vector-fma",
    cl::desc("The minimal number of cycles between issuing two "
             "dependent consecutive vector fused multiply-add "
             "instructions."),
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> ThroughputVectorFma(
    "polly-target-throughput-vector-fma",
    cl::desc("The throughput of the processor's floating-point arithmetic "
             "units, expressed as the number of vector fused multiply-add "
             "instructions per clock cycle."),
    cl::Hidden, cl::init(1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelSize(
    "polly-target-1st-cache-level-size",
    cl::desc("The size of the first cache level specified in bytes."),
    cl::Hidden, cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelDefaultSize(
    "polly-target-1st-cache-level-default-size",
    cl::desc("The default size of the first cache level specified in bytes"
             " (used if the TargetTransformInfo does not provide it)."),
    cl::Hidden, cl::init(32768), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelSize(
    "polly-target-2nd-cache-level-size",
    cl::desc("The size of the second cache level specified in bytes."),
    cl::Hidden, cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelDefaultSize(
    "polly-target-2nd-cache-level-default-size",
    cl::desc("The default size of the second cache level specified in bytes"
             " (used if the TargetTransformInfo does not provide it)."),
    cl::Hidden, cl::init(262144), cl::ZeroOrMore, cl::cat(PollyCategory));

// This option, along with --polly-target-2nd-cache-level-associativity,
// --polly-target-1st-cache-level-size, and --polly-target-2nd-cache-level-size
// represents the parameters of the target cache, which do not have typical
// values that can be used by default. However, to apply the pattern matching
// optimizations, we use the values of the parameters of Intel Core i7-3820
// SandyBridge in case the parameters are not specified or not provided by the
// TargetTransformInfo.
static cl::opt<int> FirstCacheLevelAssociativity(
    "polly-target-1st-cache-level-associativity",
    cl::desc("The associativity of the first cache level."), cl::Hidden,
    cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelDefaultAssociativity(
    "polly-target-1st-cache-level-default-associativity",
    cl::desc("The default associativity of the first cache level"
             " (used if the TargetTransformInfo does not provide it)."),
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelAssociativity(
    "polly-target-2nd-cache-level-associativity",
    cl::desc("The associativity of the second cache level."), cl::Hidden,
    cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelDefaultAssociativity(
    "polly-target-2nd-cache-level-default-associativity",
    cl::desc("The default associativity of the second cache level"
             " (used if the TargetTransformInfo does not provide it)."),
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> VectorRegisterBitwidth(
    "polly-target-vector-register-bitwidth",
    cl::desc("The size in bits of a vector register (if not set, this "
             "information is taken from LLVM's target information)."),
    cl::Hidden, cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> PollyPatternMatchingNcQuotient(
    "polly-pattern-matching-nc-quotient",
    cl::desc("Quotient that is obtained by dividing Nc, the parameter of the "
             "macro-kernel, by Nr, the parameter of the micro-kernel"),
    cl::Hidden, cl::init(256), cl::ZeroOrMore, cl::cat(PollyCategory));

namespace {
/// Parameters of the micro kernel.
///
/// Parameters that determine the size of the rank-1 (i.e., outer product)
/// update used in the optimized matrix multiplication.
struct MicroKernelParamsTy {
  int Mr;
  int Nr;
};

/// Parameters of the macro kernel.
///
/// Parameters that determine the sizes of the blocks of partitioned matrices
/// used in the optimized matrix multiplication.
struct MacroKernelParamsTy {
  int Mc;
  int Nc;
  int Kc;
};
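
// For orientation, a sketch (pseudocode, not code emitted by this pass) of
// the BLIS-style loop nest that the micro- and macro-kernel parameters
// describe; N, M, and K are hypothetical problem sizes:
//
//   for (int jc = 0; jc < N; jc += Nc)          // loop around the macro-kernel
//     for (int pc = 0; pc < K; pc += Kc)        // pack a Kc x Nc block of B
//       for (int ic = 0; ic < M; ic += Mc)      // pack an Mc x Kc block of A
//         for (int jr = jc; jr < jc + Nc; jr += Nr)   // macro-kernel
//           for (int ir = ic; ir < ic + Mc; ir += Mr) // micro-kernel:
//             ;  // Mr x Nr rank-1 updates of C over the Kc dimension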

/// Parameters of the matrix multiplication operands.
///
/// Parameters that describe the access relations representing the operands of
/// a matrix multiplication of the form C[i][j] += A[i][k] * B[k][j].
struct MatMulInfoTy {
  MemoryAccess *A = nullptr;         // Read of the first operand, A[i][k].
  MemoryAccess *B = nullptr;         // Read of the second operand, B[k][j].
  MemoryAccess *ReadFromC = nullptr; // Read of the result matrix, C[i][j].
  MemoryAccess *WriteToC = nullptr;  // Write of the result matrix, C[i][j].
  int i = -1; // Dimension of the i loop (rows of C and A).
  int j = -1; // Dimension of the j loop (columns of C and B).
  int k = -1; // Dimension of the k (reduction) loop.
};

/// Create an isl::union_set, which describes the option of the form
/// [isolate[] -> unroll[x]].
///
/// @param Ctx An isl::ctx, which is used to create the isl::union_set.
static isl::union_set getUnrollIsolatedSetOptions(isl::ctx Ctx) {
  isl::space Space = isl::space(Ctx, 0, 0, 1);
  isl::map UnrollIsolatedSetOption = isl::map::universe(Space);
  isl::id DimInId = isl::id::alloc(Ctx, "isolate", nullptr);
  isl::id DimOutId = isl::id::alloc(Ctx, "unroll", nullptr);
  UnrollIsolatedSetOption =
      UnrollIsolatedSetOption.set_tuple_id(isl::dim::in, DimInId);
  UnrollIsolatedSetOption =
      UnrollIsolatedSetOption.set_tuple_id(isl::dim::out, DimOutId);
  return UnrollIsolatedSetOption.wrap();
}

/// Permute the two dimensions of the isl map.
///
/// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that
/// have type @p DimType.
///
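/// For example, permuting the output dimensions 0 and 2 of the map
/// { S[i, j, k] -> [i, j, k] } yields { S[i, j, k] -> [k, j, i] }.
///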
/// @param Map     The isl map to be modified.
/// @param DimType The type of the dimensions.
/// @param DstPos  The first dimension.
/// @param SrcPos  The second dimension.
/// @return        The modified map.
static isl::map permuteDimensions(isl::map Map, isl::dim DimType,
                                  unsigned DstPos, unsigned SrcPos) {
  assert((isl_size)DstPos < Map.dim(DimType) &&
         (isl_size)SrcPos < Map.dim(DimType));
  if (DstPos == SrcPos)
    return Map;
  isl::id DimId;
  if (Map.has_tuple_id(DimType))
    DimId = Map.get_tuple_id(DimType);
  auto FreeDim = DimType == isl::dim::in ? isl::dim::out : isl::dim::in;
  isl::id FreeDimId;
  if (Map.has_tuple_id(FreeDim))
    FreeDimId = Map.get_tuple_id(FreeDim);
  auto MaxDim = std::max(DstPos, SrcPos);
  auto MinDim = std::min(DstPos, SrcPos);
  Map = Map.move_dims(FreeDim, 0, DimType, MaxDim, 1);
  Map = Map.move_dims(FreeDim, 0, DimType, MinDim, 1);
  Map = Map.move_dims(DimType, MinDim, FreeDim, 1, 1);
  Map = Map.move_dims(DimType, MaxDim, FreeDim, 0, 1);
  if (DimId)
    Map = Map.set_tuple_id(DimType, DimId);
  if (FreeDimId)
    Map = Map.set_tuple_id(FreeDim, FreeDimId);
  return Map;
}

/// Check the form of the access relation.
///
/// Check that the access relation @p AccMap has the form M[i][j], where i
/// is a @p FirstPos and j is a @p SecondPos.
///
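/// For example, the access relation { S[i, j, k] -> A[i, k] } matches with
/// @p FirstPos = 0 and @p SecondPos = 2.
///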
/// @param Domain    The domain of the statement.
/// @param AccMap    The access relation to be checked.
/// @param FirstPos  The index of the input dimension that is mapped to
///                  the first output dimension.
/// @param SecondPos The index of the input dimension that is mapped to the
///                  second output dimension.
/// @return          True in case @p AccMap has the expected form and false,
///                  otherwise.
static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
                               int &SecondPos) {
  isl::space Space = AccMap.get_space();
  isl::map Universe = isl::map::universe(Space);

  if (Space.dim(isl::dim::out) != 2)
    return false;

  // MatMul has the form:
  // for (i = 0; i < N; i++)
  //   for (j = 0; j < M; j++)
  //     for (k = 0; k < P; k++)
  //       C[i, j] += A[i, k] * B[k, j]
  //
  // Permutation of three outer loops: 3! = 6 possibilities.
  int FirstDims[] = {0, 0, 1, 1, 2, 2};
  int SecondDims[] = {1, 2, 2, 0, 0, 1};
  for (int i = 0; i < 6; i += 1) {
    auto PossibleMatMul =
        Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0)
            .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1);

    AccMap = AccMap.intersect_domain(Domain);
    PossibleMatMul = PossibleMatMul.intersect_domain(Domain);

    // If AccMap spans the entire domain (i.e., it is not a partial write),
    // compute FirstPos and SecondPos. If AccMap != PossibleMatMul here (both
    // maps have been intersected with the statement domain at this point),
    // the writes are not complete; in other words, it is a partial write,
    // and partial writes must be rejected.
    if (AccMap.is_equal(PossibleMatMul)) {
      if (FirstPos != -1 && FirstPos != FirstDims[i])
        continue;
      FirstPos = FirstDims[i];
      if (SecondPos != -1 && SecondPos != SecondDims[i])
        continue;
      SecondPos = SecondDims[i];
      return true;
    }
  }

  return false;
}

/// Does the memory access represent a non-scalar operand of the matrix
/// multiplication?
///
/// Check that the memory access @p MemAccess is a read access to a non-scalar
/// operand of the matrix multiplication or to its result.
///
/// @param MemAccess The memory access to be checked.
/// @param MMI       Parameters of the matrix multiplication operands.
/// @return          True in case the memory access represents a read access
///                  to a non-scalar operand of the matrix multiplication and
///                  false, otherwise.
static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess,
                                        MatMulInfoTy &MMI) {
  if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
    return false;
  auto AccMap = MemAccess->getLatestAccessRelation();
  isl::set StmtDomain = MemAccess->getStatement()->getDomain();
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) {
    MMI.ReadFromC = MemAccess;
    return true;
  }
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) {
    MMI.A = MemAccess;
    return true;
  }
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) {
    MMI.B = MemAccess;
    return true;
  }
  return false;
}

/// Check accesses to operands of the matrix multiplication.
///
/// Check that accesses of the SCoP statement, which corresponds to
/// the partial schedule @p PartialSchedule, are scalar in terms of loops
/// containing the matrix multiplication, in case they do not represent
/// accesses to the non-scalar operands of the matrix multiplication or
/// its result.
///
/// @param  PartialSchedule The partial schedule of the SCoP statement.
/// @param  MMI             Parameters of the matrix multiplication operands.
/// @return                 True in case the corresponding SCoP statement
///                         represents matrix multiplication and false,
///                         otherwise.
static bool containsOnlyMatrMultAcc(isl::map PartialSchedule,
                                    MatMulInfoTy &MMI) {
  auto InputDimId = PartialSchedule.get_tuple_id(isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimId.get_user());
  isl_size OutDimNum = PartialSchedule.dim(isl::dim::out);
  assert(OutDimNum > 2 && "In case of the matrix multiplication the loop nest "
                          "and, consequently, the corresponding scheduling "
                          "functions have at least three dimensions.");
  auto MapI =
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.i, OutDimNum - 1);
  auto MapJ =
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.j, OutDimNum - 1);
  auto MapK =
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.k, OutDimNum - 1);

  auto Accesses = getAccessesInOrder(*Stmt);
  for (auto *MemA = Accesses.begin(); MemA != Accesses.end() - 1; MemA++) {
    auto *MemAccessPtr = *MemA;
    // Reject any remaining array access that is neither an operand of the
    // matrix multiplication nor scalar (stride zero) with respect to the
    // i, j, and k loops.
    if (MemAccessPtr->isLatestArrayKind() && MemAccessPtr != MMI.WriteToC &&
        !isMatMulNonScalarReadAccess(MemAccessPtr, MMI) &&
        !(MemAccessPtr->isStrideZero(MapI) &&
          MemAccessPtr->isStrideZero(MapJ) && MemAccessPtr->isStrideZero(MapK)))
      return false;
  }
  return true;
}

/// Check for dependencies corresponding to the matrix multiplication.
///
/// Check that the only true dependence is of the form
/// S(..., k, ...) -> S(..., k + 1, ...), where S is the SCoP statement
/// represented by @p Schedule and k is @p Pos. Such a dependence corresponds
/// to the dependency produced by the matrix multiplication.
///
/// @param  Schedule The schedule of the SCoP statement.
/// @param  D The SCoP dependencies.
/// @param  Pos The parameter to describe an acceptable true dependence.
///             In case it has a negative value, try to determine its
///             acceptable value.
/// @return True in case dependencies correspond to the matrix multiplication
///         and false, otherwise.
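///
/// For the canonical matrix multiplication nest, the only loop-carried true
/// dependence is S[i, j, k] -> S[i, j, k + 1]; its delta vector is (0, 0, 1),
/// so @p Pos becomes 2, the position of the k loop.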
static bool containsOnlyMatMulDep(isl::map Schedule, const Dependences *D,
                                  int &Pos) {
  isl::union_map Dep = D->getDependences(Dependences::TYPE_RAW);
  isl::union_map Red = D->getDependences(Dependences::TYPE_RED);
  if (Red)
    Dep = Dep.unite(Red);
  auto DomainSpace = Schedule.get_space().domain();
  auto Space = DomainSpace.map_from_domain_and_range(DomainSpace);
  auto Deltas = Dep.extract_map(Space).deltas();
  isl_size DeltasDimNum = Deltas.dim(isl::dim::set);
  for (int i = 0; i < DeltasDimNum; i++) {
    auto Val = Deltas.plain_get_val_if_fixed(isl::dim::set, i);
    Pos = Pos < 0 && Val.is_one() ? i : Pos;
    if (Val.is_nan() || !(Val.is_zero() || (i == Pos && Val.is_one())))
      return false;
  }
  if (DeltasDimNum == 0 || Pos < 0)
    return false;
  return true;
}

/// Check if the SCoP statement could probably be optimized with analytical
/// modeling.
///
/// containsMatrMult tries to determine whether the following conditions
/// are true:
/// 1. The last memory access modeling an array, MA1, represents writing to
///    memory and has the form S(..., i1, ..., i2, ...) -> M(i1, i2) or
///    S(..., i2, ..., i1, ...) -> M(i1, i2), where S is the SCoP statement
///    under consideration.
/// 2. There is only one loop-carried true dependency, and it has the
///    form S(..., i3, ...) -> S(..., i3 + 1, ...), and there are no other
///    loop-carried or anti dependencies.
/// 3. SCoP contains three access relations, MA2, MA3, and MA4 that represent
///    reading from memory and have the form S(..., i3, ...) -> M(i1, i3),
///    S(..., i3, ...) -> M(i3, i2), S(...) -> M(i1, i2), respectively,
///    and all memory accesses of the SCoP that are different from MA1, MA2,
///    MA3, and MA4 have stride 0, if the innermost loop is exchanged with any
///    of loops i1, i2 and i3.
///
/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
///        to check.
/// @param D   The SCoP dependencies.
/// @param MMI Parameters of the matrix multiplication operands.
static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D,
                             MatMulInfoTy &MMI) {
  auto InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
  if (Stmt->size() <= 1)
    return false;

  auto Accesses = getAccessesInOrder(*Stmt);
  for (auto *MemA = Accesses.end() - 1; MemA != Accesses.begin(); MemA--) {
    auto *MemAccessPtr = *MemA;
    if (!MemAccessPtr->isLatestArrayKind())
      continue;
    if (!MemAccessPtr->isWrite())
      return false;
    auto AccMap = MemAccessPtr->getLatestAccessRelation();
    if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
      return false;
    MMI.WriteToC = MemAccessPtr;
    break;
  }

  if (!containsOnlyMatMulDep(PartialSchedule, D, MMI.k))
    return false;

  if (!MMI.WriteToC || !containsOnlyMatrMultAcc(PartialSchedule, MMI))
    return false;

  if (!MMI.A || !MMI.B || !MMI.ReadFromC)
    return false;
  return true;
}

/// Permute two dimensions of the band node.
///
/// Permute @p FirstDim and @p SecondDim dimensions of the band node @p Node.
///
/// @param Node The band node to be modified.
/// @param FirstDim The first dimension to be permuted.
/// @param SecondDim The second dimension to be permuted.
/// @return The modified band node.
static isl::schedule_node permuteBandNodeDimensions(isl::schedule_node Node,
                                                    unsigned FirstDim,
                                                    unsigned SecondDim) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band &&
         (unsigned)isl_schedule_node_band_n_member(Node.get()) >
             std::max(FirstDim, SecondDim));
  auto PartialSchedule =
      isl::manage(isl_schedule_node_band_get_partial_schedule(Node.get()));
  auto PartialScheduleFirstDim = PartialSchedule.get_union_pw_aff(FirstDim);
  auto PartialScheduleSecondDim = PartialSchedule.get_union_pw_aff(SecondDim);
  PartialSchedule =
      PartialSchedule.set_union_pw_aff(SecondDim, PartialScheduleFirstDim);
  PartialSchedule =
      PartialSchedule.set_union_pw_aff(FirstDim, PartialScheduleSecondDim);
  Node = isl::manage(isl_schedule_node_delete(Node.release()));
  return Node.insert_partial_schedule(PartialSchedule);
}

/// Create the BLIS micro-kernel.
///
/// We create the BLIS micro-kernel by applying register tiling with the
/// MicroKernelParams's fields as tile sizes and interchanging the two
/// innermost modified dimensions.
///
/// @param Node The schedule node to be modified.
/// @param MicroKernelParams Parameters of the micro kernel
///                          to be used as tile sizes.
static isl::schedule_node
createMicroKernel(isl::schedule_node Node,
                  MicroKernelParamsTy MicroKernelParams) {
  Node = applyRegisterTiling(Node, {MicroKernelParams.Mr, MicroKernelParams.Nr},
                             1);
  Node = Node.parent().parent();
  return permuteBandNodeDimensions(Node, 0, 1).child(0).child(0);
}

/// Create the BLIS macro-kernel.
///
/// We create the BLIS macro-kernel by applying a combination of tiling
/// of dimensions of the band node and interchanging of two innermost
/// modified dimensions. The values of MacroKernelParams's fields are used
/// as tile sizes.
///
/// @param Node The schedule node to be modified.
/// @param MacroKernelParams Parameters of the macro kernel
///                          to be used as tile sizes.
static isl::schedule_node
createMacroKernel(isl::schedule_node Node,
                  MacroKernelParamsTy MacroKernelParams) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  if (MacroKernelParams.Mc == 1 && MacroKernelParams.Nc == 1 &&
      MacroKernelParams.Kc == 1)
    return Node;
  int DimOutNum = isl_schedule_node_band_n_member(Node.get());
  std::vector<int> TileSizes(DimOutNum, 1);
  TileSizes[DimOutNum - 3] = MacroKernelParams.Mc;
  TileSizes[DimOutNum - 2] = MacroKernelParams.Nc;
  TileSizes[DimOutNum - 1] = MacroKernelParams.Kc;
  Node = tileNode(Node, "1st level tiling", TileSizes, 1);
  Node = Node.parent().parent();
  Node = permuteBandNodeDimensions(Node, DimOutNum - 2, DimOutNum - 1);
  Node = permuteBandNodeDimensions(Node, DimOutNum - 3, DimOutNum - 1);

  // Mark the outermost loop as parallelizable.
  Node = Node.band_member_set_coincident(0, true);

  return Node.child(0).child(0);
}

/// Get the size of the widest type of the matrix multiplication operands
/// in bytes, including alignment padding.
///
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The size of the widest type of the matrix multiplication operands
///         in bytes, including alignment padding.
static uint64_t getMatMulAlignTypeSize(MatMulInfoTy MMI) {
  auto *S = MMI.A->getStatement()->getParent();
  auto &DL = S->getFunction().getParent()->getDataLayout();
  auto ElementSizeA = DL.getTypeAllocSize(MMI.A->getElementType());
  auto ElementSizeB = DL.getTypeAllocSize(MMI.B->getElementType());
  auto ElementSizeC = DL.getTypeAllocSize(MMI.WriteToC->getElementType());
  return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
}

/// Get the size of the widest type of the matrix multiplication operands
/// in bits.
///
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The size of the widest type of the matrix multiplication operands
///         in bits.
static uint64_t getMatMulTypeSize(MatMulInfoTy MMI) {
  auto *S = MMI.A->getStatement()->getParent();
  auto &DL = S->getFunction().getParent()->getDataLayout();
  auto ElementSizeA = DL.getTypeSizeInBits(MMI.A->getElementType());
  auto ElementSizeB = DL.getTypeSizeInBits(MMI.B->getElementType());
  auto ElementSizeC = DL.getTypeSizeInBits(MMI.WriteToC->getElementType());
  return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
}

/// Get parameters of the BLIS micro kernel.
///
/// We choose the Mr and Nr parameters of the micro kernel to be large enough
/// such that no stalls caused by the combination of latencies and dependencies
/// are introduced during the updates of the resulting matrix of the matrix
/// multiplication. However, they should also be as small as possible to
/// release more registers for entries of multiplied matrices.
///
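/// As a worked example, on a hypothetical target with 256-bit vector
/// registers and double-precision operands, Nvec = 256 / 64 = 4; with the
/// default LatencyVectorFma = 8 and ThroughputVectorFma = 1, this yields
/// Nr = ceil(sqrt(4 * 8 * 1) / 4) * 4 = 8 and Mr = ceil(32 / 8) = 4.
///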
/// @param TTI Target Transform Info.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The structure of type MicroKernelParamsTy.
/// @see MicroKernelParamsTy
static struct MicroKernelParamsTy
getMicroKernelParams(const TargetTransformInfo *TTI, MatMulInfoTy MMI) {
  assert(TTI && "The target transform info should be provided.");

  // Nvec - Number of double-precision floating-point numbers that can be held
  // by a vector register. Use 2 by default.
  long RegisterBitwidth = VectorRegisterBitwidth;

  if (RegisterBitwidth == -1)
    RegisterBitwidth =
        TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector);
  auto ElementSize = getMatMulTypeSize(MMI);
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
                            "operands should be greater than zero.");
  auto Nvec = RegisterBitwidth / ElementSize;
  if (Nvec == 0)
    Nvec = 2;
  int Nr = ceil(sqrt((double)(Nvec * LatencyVectorFma * ThroughputVectorFma)) /
                Nvec) *
           Nvec;
  int Mr = ceil((double)(Nvec * LatencyVectorFma * ThroughputVectorFma / Nr));
  return {Mr, Nr};
}

/// Determine parameters of the target cache.
///
/// @param TTI Target Transform Info.
static void getTargetCacheParameters(const llvm::TargetTransformInfo *TTI) {
  auto L1DCache = llvm::TargetTransformInfo::CacheLevel::L1D;
  auto L2DCache = llvm::TargetTransformInfo::CacheLevel::L2D;
  if (FirstCacheLevelSize == -1) {
    if (TTI->getCacheSize(L1DCache).hasValue())
      FirstCacheLevelSize = TTI->getCacheSize(L1DCache).getValue();
    else
      FirstCacheLevelSize = static_cast<int>(FirstCacheLevelDefaultSize);
  }
  if (SecondCacheLevelSize == -1) {
    if (TTI->getCacheSize(L2DCache).hasValue())
      SecondCacheLevelSize = TTI->getCacheSize(L2DCache).getValue();
    else
      SecondCacheLevelSize = static_cast<int>(SecondCacheLevelDefaultSize);
  }
  if (FirstCacheLevelAssociativity == -1) {
    if (TTI->getCacheAssociativity(L1DCache).hasValue())
      FirstCacheLevelAssociativity =
          TTI->getCacheAssociativity(L1DCache).getValue();
    else
      FirstCacheLevelAssociativity =
          static_cast<int>(FirstCacheLevelDefaultAssociativity);
  }
  if (SecondCacheLevelAssociativity == -1) {
    if (TTI->getCacheAssociativity(L2DCache).hasValue())
      SecondCacheLevelAssociativity =
          TTI->getCacheAssociativity(L2DCache).getValue();
    else
      SecondCacheLevelAssociativity =
          static_cast<int>(SecondCacheLevelDefaultAssociativity);
  }
}

/// Get parameters of the BLIS macro kernel.
///
/// During the computation of matrix multiplication, blocks of partitioned
/// matrices are mapped to different layers of the memory hierarchy.
/// To optimize data reuse, blocks should be ideally kept in cache between
/// iterations. Since parameters of the macro kernel determine sizes of these
/// blocks, there are upper and lower bounds on these parameters.
///
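/// As a worked instance (assuming the SandyBridge defaults below and, say,
/// micro-kernel parameters Mr = 4 and Nr = 8 for double-precision operands):
/// Car = floor(7 / (1 + 8.0 / 4)) = 2, Kc = (2 * 32768) / (4 * 8 * 8) = 256,
/// Cac = (256 * 8 * 8) / 262144 = 0.0625, Mc = floor((8 - 2) / 0.0625) = 96,
/// and Nc = 256 * 8 = 2048.
///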
/// @param TTI Target Transform Info.
/// @param MicroKernelParams Parameters of the micro-kernel
///                          to be taken into account.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The structure of type MacroKernelParamsTy.
/// @see MacroKernelParamsTy
/// @see MicroKernelParamsTy
static struct MacroKernelParamsTy
getMacroKernelParams(const llvm::TargetTransformInfo *TTI,
                     const MicroKernelParamsTy &MicroKernelParams,
                     MatMulInfoTy MMI) {
  getTargetCacheParameters(TTI);
  // According to www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf,
  // determining all parameters of the macro-kernel requires information about
  // the first two cache levels. The model also requires the associativity
  // degree of each cache level to be greater than two. Otherwise, another
  // algorithm for determination of the parameters should be used.
  if (!(MicroKernelParams.Mr > 0 && MicroKernelParams.Nr > 0 &&
        FirstCacheLevelSize > 0 && SecondCacheLevelSize > 0 &&
        FirstCacheLevelAssociativity > 2 && SecondCacheLevelAssociativity > 2))
    return {1, 1, 1};
  // The quotient should be greater than zero.
  if (PollyPatternMatchingNcQuotient <= 0)
    return {1, 1, 1};
  int Car = floor(
      (FirstCacheLevelAssociativity - 1) /
      (1 + static_cast<double>(MicroKernelParams.Nr) / MicroKernelParams.Mr));

  // Car can be zero since it is the floor of a double cast to int. A Car of
  // zero makes Kc and, in turn, Cac zero; on Mac OS, the resulting division
  // by zero does not raise a signal and causes negative tile sizes to be
  // computed. Prevent the division by Cac == 0 by returning early.
  if (Car == 0)
    return {1, 1, 1};

  auto ElementSize = getMatMulAlignTypeSize(MMI);
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
                            "operands should be greater than zero.");
  int Kc = (Car * FirstCacheLevelSize) /
           (MicroKernelParams.Mr * FirstCacheLevelAssociativity * ElementSize);
  double Cac =
      static_cast<double>(Kc * ElementSize * SecondCacheLevelAssociativity) /
      SecondCacheLevelSize;
  int Mc = floor((SecondCacheLevelAssociativity - 2) / Cac);
  int Nc = PollyPatternMatchingNcQuotient * MicroKernelParams.Nr;

  assert(Mc > 0 && Nc > 0 && Kc > 0 &&
         "Matrix block sizes should be greater than zero");
  return {Mc, Nc, Kc};
}

/// Create an access relation that is specific to
///        the matrix multiplication pattern.
///
/// Create an access relation of the following form:
/// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [OI, O5, OJ]
/// where I is @p FirstDim, J is @p SecondDim.
///
/// It can be used, for example, to create relations that help to consecutively
/// access elements of operands of a matrix multiplication after creation of
/// the BLIS micro and macro kernels.
///
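/// For example, the packing of matrix B in optimizeDataLayoutMatrMulPattern
/// calls this function with @p FirstDim = 3 and @p SecondDim = 7, producing
/// the relation [O0, ..., O8] -> [O3, O5, O7].
///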
/// @see ScheduleTreeOptimizer::createMicroKernel
/// @see ScheduleTreeOptimizer::createMacroKernel
///
/// Subsequently, the described access relation is applied to the range of
/// @p MapOldIndVar, which maps the original induction variables to the ones
/// produced by schedule transformations. It helps to define relations using
/// a new space and, at the same time, keep them in the original one.
///
/// @param MapOldIndVar The relation, which maps original induction variables
///                     to the ones, which are produced by schedule
///                     transformations.
/// @param FirstDim, SecondDim The input dimensions that are used to define
///        the specified access relation.
/// @return The specified access relation.
static isl::map getMatMulAccRel(isl::map MapOldIndVar, unsigned FirstDim,
                                unsigned SecondDim) {
  auto AccessRelSpace = isl::space(MapOldIndVar.get_ctx(), 0, 9, 3);
  auto AccessRel = isl::map::universe(AccessRelSpace);
  AccessRel = AccessRel.equate(isl::dim::in, FirstDim, isl::dim::out, 0);
  AccessRel = AccessRel.equate(isl::dim::in, 5, isl::dim::out, 1);
  AccessRel = AccessRel.equate(isl::dim::in, SecondDim, isl::dim::out, 2);
  return MapOldIndVar.apply_range(AccessRel);
}

static isl::schedule_node createExtensionNode(isl::schedule_node Node,
                                              isl::map ExtensionMap) {
  auto Extension = isl::union_map(ExtensionMap);
  auto NewNode = isl::schedule_node::from_extension(Extension);
  return Node.graft_before(NewNode);
}

/// Apply the packing transformation.
///
/// The packing transformation can be described as a data-layout
/// transformation that introduces a new array, copies data to the array,
/// and changes memory access locations to reference the array. It can be
/// used to ensure that elements of the new array are read with in-stride
/// access, aligned to cache line boundaries, and preloaded into certain
/// cache levels.
///
/// As an example let us consider the packing of the array A that would help
/// to read its elements with in-stride access. An access to the array A
/// is represented by an access relation that has the form
/// S[i, j, k] -> A[i, k]. The scheduling function of the SCoP statement S has
/// the form S[i, j, k] -> [floor((j mod Nc) / Nr), floor((i mod Mc) / Mr),
/// k mod Kc, j mod Nr, i mod Mr].
///
/// To ensure that elements of the array A are read with in-stride access, we
/// add a new array Packed_A[Mc/Mr][Kc][Mr] to the SCoP, using
/// Scop::createScopArrayInfo, change the access relation
/// S[i, j, k] -> A[i, k] to
/// S[i, j, k] -> Packed_A[floor((i mod Mc) / Mr), k mod Kc, i mod Mr], using
/// MemoryAccess::setNewAccessRelation, and copy the data to the array, using
/// the copy statement created by Scop::addScopStmt.
///
/// @param Node The schedule node to be optimized.
/// @param MapOldIndVar The relation, which maps original induction variables
///                     to the ones, which are produced by schedule
///                     transformations.
/// @param MicroParams, MacroParams Parameters of the BLIS kernel
///                                 to be taken into account.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The optimized schedule node.
static isl::schedule_node
optimizeDataLayoutMatrMulPattern(isl::schedule_node Node, isl::map MapOldIndVar,
                                 MicroKernelParamsTy MicroParams,
                                 MacroKernelParamsTy MacroParams,
                                 MatMulInfoTy &MMI) {
  auto InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());

  // Create a copy statement that corresponds to the memory access to the
  // matrix B, the second operand of the matrix multiplication.
  Node = Node.parent().parent().parent().parent().parent().parent();
  Node = isl::manage(isl_schedule_node_band_split(Node.release(), 2)).child(0);
  auto AccRel = getMatMulAccRel(MapOldIndVar, 3, 7);
  unsigned FirstDimSize = MacroParams.Nc / MicroParams.Nr;
  unsigned SecondDimSize = MacroParams.Kc;
  unsigned ThirdDimSize = MicroParams.Nr;
  auto *SAI = Stmt->getParent()->createScopArrayInfo(
      MMI.B->getElementType(), "Packed_B",
      {FirstDimSize, SecondDimSize, ThirdDimSize});
  AccRel = AccRel.set_tuple_id(isl::dim::out, SAI->getBasePtrId());
  auto OldAcc = MMI.B->getLatestAccessRelation();
  MMI.B->setNewAccessRelation(AccRel);
  auto ExtMap = MapOldIndVar.project_out(isl::dim::out, 2,
                                         MapOldIndVar.dim(isl::dim::out) - 2);
  ExtMap = ExtMap.reverse();
  ExtMap = ExtMap.fix_si(isl::dim::out, MMI.i, 0);
  auto Domain = Stmt->getDomain();

  // Restrict the domain of the copy statement so that it is executed only
  // when its originating statement is executed.
  auto DomainId = Domain.get_tuple_id();
  auto *NewStmt = Stmt->getParent()->addScopStmt(
      OldAcc, MMI.B->getLatestAccessRelation(), Domain);
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, DomainId);
  ExtMap = ExtMap.intersect_range(Domain);
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, NewStmt->getDomainId());
  Node = createExtensionNode(Node, ExtMap);

  // Create a copy statement that corresponds to the memory access
  // to the matrix A, the first operand of the matrix multiplication.
  Node = Node.child(0);
  AccRel = getMatMulAccRel(MapOldIndVar, 4, 6);
  FirstDimSize = MacroParams.Mc / MicroParams.Mr;
  ThirdDimSize = MicroParams.Mr;
  SAI = Stmt->getParent()->createScopArrayInfo(
      MMI.A->getElementType(), "Packed_A",
      {FirstDimSize, SecondDimSize, ThirdDimSize});
  AccRel = AccRel.set_tuple_id(isl::dim::out, SAI->getBasePtrId());
  OldAcc = MMI.A->getLatestAccessRelation();
  MMI.A->setNewAccessRelation(AccRel);
  ExtMap = MapOldIndVar.project_out(isl::dim::out, 3,
                                    MapOldIndVar.dim(isl::dim::out) - 3);
  ExtMap = ExtMap.reverse();
  ExtMap = ExtMap.fix_si(isl::dim::out, MMI.j, 0);
  NewStmt = Stmt->getParent()->addScopStmt(
      OldAcc, MMI.A->getLatestAccessRelation(), Domain);

  // Restrict the domain of the copy statement so that it is executed only
  // when its originating statement is executed.
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, DomainId);
  ExtMap = ExtMap.intersect_range(Domain);
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, NewStmt->getDomainId());
  Node = createExtensionNode(Node, ExtMap);
  return Node.child(0).child(0).child(0).child(0).child(0);
}

/// Get a relation mapping original induction variables to the ones
/// produced by schedule transformations.
///
/// @param Node The schedule node produced as the result of creation
///        of the BLIS kernels.
/// @param MicroKernelParams, MacroKernelParams Parameters of the BLIS kernel
///                                             to be taken into account.
/// @return  The relation mapping original induction variables to the ones
///          produced by schedule transformation.
/// @see ScheduleTreeOptimizer::createMicroKernel
/// @see ScheduleTreeOptimizer::createMacroKernel
/// @see getMacroKernelParams
static isl::map
getInductionVariablesSubstitution(isl::schedule_node Node,
                                  MicroKernelParamsTy MicroKernelParams,
                                  MacroKernelParamsTy MacroKernelParams) {
  auto Child = Node.child(0);
  auto UnMapOldIndVar = Child.get_prefix_schedule_union_map();
  auto MapOldIndVar = isl::map::from_union_map(UnMapOldIndVar);
  if (MapOldIndVar.dim(isl::dim::out) > 9)
    return MapOldIndVar.project_out(isl::dim::out, 0,
                                    MapOldIndVar.dim(isl::dim::out) - 9);
  return MapOldIndVar;
}

/// Isolate a set of partial tile prefixes and unroll the isolated part.
///
/// The set should ensure that it contains only partial tile prefixes that have
/// exactly Mr x Nr iterations of the two innermost loops produced by
/// the optimization of the matrix multiplication. Mr and Nr are parameters of
/// the micro-kernel.
///
/// In case of parametric bounds, this helps to auto-vectorize the unrolled
/// innermost loops, using the SLP vectorizer.
///
/// @param Node              The schedule node to be modified.
/// @param MicroKernelParams Parameters of the micro-kernel
///                          to be taken into account.
/// @return The modified isl_schedule_node.
static isl::schedule_node
isolateAndUnrollMatMulInnerLoops(isl::schedule_node Node,
                                 struct MicroKernelParamsTy MicroKernelParams) {
  isl::schedule_node Child = Node.get_child(0);
  isl::union_map UnMapOldIndVar = Child.get_prefix_schedule_relation();
  isl::set Prefix = isl::map::from_union_map(UnMapOldIndVar).range();
  isl_size Dims = Prefix.dim(isl::dim::set);
  Prefix = Prefix.project_out(isl::dim::set, Dims - 1, 1);
  Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Nr);
  Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Mr);

  isl::union_set IsolateOption =
      getIsolateOptions(Prefix.add_dims(isl::dim::set, 3), 3);
  isl::ctx Ctx = Node.get_ctx();
  auto Options = IsolateOption.unite(getDimOptions(Ctx, "unroll"));
  Options = Options.unite(getUnrollIsolatedSetOptions(Ctx));
  Node = Node.band_set_ast_build_options(Options);
  Node = Node.parent().parent().parent();
  IsolateOption = getIsolateOptions(Prefix, 3);
  Options = IsolateOption.unite(getDimOptions(Ctx, "separate"));
  Node = Node.band_set_ast_build_options(Options);
  Node = Node.child(0).child(0).child(0);
  return Node;
}

/// Mark @p BasePtr with "Inter iteration alias-free" mark node.
///
/// @param Node The child of the mark node to be inserted.
/// @param BasePtr The pointer to be marked.
/// @return The modified isl_schedule_node.
static isl::schedule_node markInterIterationAliasFree(isl::schedule_node Node,
                                                      Value *BasePtr) {
  if (!BasePtr)
    return Node;

  auto Id =
      isl::id::alloc(Node.get_ctx(), "Inter iteration alias-free", BasePtr);
  return Node.insert_mark(Id).child(0);
}

/// Insert "Loop Vectorizer Disabled" mark node.
///
/// @param Node The child of the mark node to be inserted.
/// @return The modified isl_schedule_node.
static isl::schedule_node markLoopVectorizerDisabled(isl::schedule_node Node) {
  auto Id = isl::id::alloc(Node.get_ctx(), "Loop Vectorizer Disabled", nullptr);
  return Node.insert_mark(Id).child(0);
}

/// Restore the initial ordering of dimensions of the band node.
///
/// In case the band node represents all the dimensions of the iteration
/// domain, recreate the band node to restore the initial ordering of the
/// dimensions.
///
/// @param Node The band node to be modified.
/// @return The modified schedule node.
static isl::schedule_node
getBandNodeWithOriginDimOrder(isl::schedule_node Node) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  if (isl_schedule_node_get_type(Node.child(0).get()) != isl_schedule_node_leaf)
    return Node;
  auto Domain = Node.get_universe_domain();
  assert(isl_union_set_n_set(Domain.get()) == 1);
  if (Node.get_schedule_depth() != 0 ||
      (isl::set(Domain).dim(isl::dim::set) !=
       isl_schedule_node_band_n_member(Node.get())))
    return Node;
  Node = isl::manage(isl_schedule_node_delete(Node.copy()));
  auto PartialSchedulePwAff = Domain.identity_union_pw_multi_aff();
  auto PartialScheduleMultiPwAff =
      isl::multi_union_pw_aff(PartialSchedulePwAff);
  PartialScheduleMultiPwAff =
      PartialScheduleMultiPwAff.reset_tuple_id(isl::dim::set);
  return Node.insert_partial_schedule(PartialScheduleMultiPwAff);
}

static isl::schedule_node optimizeMatMulPattern(isl::schedule_node Node,
                                                const TargetTransformInfo *TTI,
                                                MatMulInfoTy &MMI) {
  assert(TTI && "The target transform info should be provided.");
  Node = markInterIterationAliasFree(
      Node, MMI.WriteToC->getLatestScopArrayInfo()->getBasePtr());
  int DimOutNum = isl_schedule_node_band_n_member(Node.get());
  assert(DimOutNum > 2 && "In case of the matrix multiplication the loop nest "
                          "and, consequently, the corresponding scheduling "
                          "functions have at least three dimensions.");
  Node = getBandNodeWithOriginDimOrder(Node);
  Node = permuteBandNodeDimensions(Node, MMI.i, DimOutNum - 3);
  int NewJ = MMI.j == DimOutNum - 3 ? MMI.i : MMI.j;
  int NewK = MMI.k == DimOutNum - 3 ? MMI.i : MMI.k;
  Node = permuteBandNodeDimensions(Node, NewJ, DimOutNum - 2);
  NewK = NewK == DimOutNum - 2 ? NewJ : NewK;
  Node = permuteBandNodeDimensions(Node, NewK, DimOutNum - 1);
  auto MicroKernelParams = getMicroKernelParams(TTI, MMI);
  auto MacroKernelParams = getMacroKernelParams(TTI, MicroKernelParams, MMI);
  Node = createMacroKernel(Node, MacroKernelParams);
  Node = createMicroKernel(Node, MicroKernelParams);
  if (MacroKernelParams.Mc == 1 || MacroKernelParams.Nc == 1 ||
      MacroKernelParams.Kc == 1)
    return Node;
  auto MapOldIndVar = getInductionVariablesSubstitution(Node, MicroKernelParams,
                                                        MacroKernelParams);
  if (!MapOldIndVar)
    return Node;
  Node = markLoopVectorizerDisabled(Node.parent()).child(0);
  Node = isolateAndUnrollMatMulInnerLoops(Node, MicroKernelParams);
  return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
                                          MacroKernelParams, MMI);
}

/// Check if this node contains a partial schedule that could
///        probably be optimized with analytical modeling.
///
/// isMatrMultPattern tries to determine whether the following conditions
/// are true:
/// 1. the partial schedule contains only one statement.
/// 2. there are exactly three input dimensions.
/// 3. all memory accesses of the statement will have stride 0 or 1, if we
///    interchange loops (switch the variable used in the inner loop to
///    the outer loop).
/// 4. all memory accesses of the statement except for the last one are read
///    memory accesses, and the last one is a write memory access.
/// 5. all subscripts of the last memory access of the statement don't
///    contain the variable used in the inner loop.
/// If this is the case, we could try to use an approach that is similar to
/// the one used to get close-to-peak performance of matrix multiplications.
///
/// @param Node The node to check.
/// @param D    The SCoP dependencies.
/// @param MMI  Parameters of the matrix multiplication operands.
static bool isMatrMultPattern(isl::schedule_node Node, const Dependences *D,
                              MatMulInfoTy &MMI) {
  auto PartialSchedule = isl::manage(
      isl_schedule_node_band_get_partial_schedule_union_map(Node.get()));
  Node = Node.child(0);
  auto LeafType = isl_schedule_node_get_type(Node.get());
  Node = Node.parent();
  if (LeafType != isl_schedule_node_leaf ||
      isl_schedule_node_band_n_member(Node.get()) < 3 ||
      Node.get_schedule_depth() != 0 ||
      isl_union_map_n_map(PartialSchedule.get()) != 1)
    return false;
  auto NewPartialSchedule = isl::map::from_union_map(PartialSchedule);
  return containsMatrMult(NewPartialSchedule, D, MMI);
}

} // namespace

isl::schedule_node
polly::tryOptimizeMatMulPattern(isl::schedule_node Node,
                                const llvm::TargetTransformInfo *TTI,
                                const Dependences *D) {
  MatMulInfoTy MMI;
  if (isMatrMultPattern(Node, D, MMI)) {
    LLVM_DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
    return optimizeMatMulPattern(Node, TTI, MMI);
  }
  return {};
}