//===- MatmulOptimizer.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "polly/MatmulOptimizer.h"
#include "polly/DependenceInfo.h"
#include "polly/Options.h"
#include "polly/ScheduleTreeTransform.h"
#include "polly/ScopInfo.h"
#include "polly/ScopPass.h"
#include "polly/Simplify.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Support/raw_ostream.h"
#include "isl/ctx.h"
#include "isl/schedule_node.h"
#include "isl/schedule_type.h"
#include "isl/union_map.h"
#include "isl/union_set.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>

#define DEBUG_TYPE "polly-opt-isl"

using namespace llvm;
using namespace polly;

namespace llvm {
class Value;
}

// Command-line knobs describing the target microarchitecture. They feed the
// analytical model (getMicroKernelParams/getMacroKernelParams) that sizes the
// BLIS-style micro- and macro-kernels.

static cl::opt<int> LatencyVectorFma(
    "polly-target-latency-vector-fma",
    cl::desc("The minimal number of cycles between issuing two "
             "dependent consecutive vector fused multiply-add "
             "instructions."),
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> ThroughputVectorFma(
    "polly-target-throughput-vector-fma",
    cl::desc("A throughput of the processor floating-point arithmetic units "
             "expressed in the number of vector fused multiply-add "
             "instructions per clock cycle."),
    cl::Hidden, cl::init(1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelSize(
    "polly-target-1st-cache-level-size",
    cl::desc("The size of the first cache level specified in bytes."),
    cl::Hidden, cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelDefaultSize(
    "polly-target-1st-cache-level-default-size",
    cl::desc("The default size of the first cache level specified in bytes"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(32768), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelSize(
    "polly-target-2nd-cache-level-size",
    cl::desc("The size of the second level specified in bytes."), cl::Hidden,
    cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelDefaultSize(
    "polly-target-2nd-cache-level-default-size",
    cl::desc("The default size of the second cache level specified in bytes"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(262144), cl::ZeroOrMore, cl::cat(PollyCategory));

// This option, along with --polly-target-2nd-cache-level-associativity,
// --polly-target-1st-cache-level-size, and --polly-target-2nd-cache-level-size
// represent the parameters of the target cache, which do not have typical
// values that can be used by default. However, to apply the pattern matching
// optimizations, we use the values of the parameters of Intel Core i7-3820
// SandyBridge in case the parameters are not specified or not provided by the
// TargetTransformInfo.
static cl::opt<int> FirstCacheLevelAssociativity(
    "polly-target-1st-cache-level-associativity",
    cl::desc("The associativity of the first cache level."), cl::Hidden,
    cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelDefaultAssociativity(
    "polly-target-1st-cache-level-default-associativity",
    cl::desc("The default associativity of the first cache level"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelAssociativity(
    "polly-target-2nd-cache-level-associativity",
    cl::desc("The associativity of the second cache level."), cl::Hidden,
    cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelDefaultAssociativity(
    "polly-target-2nd-cache-level-default-associativity",
    cl::desc("The default associativity of the second cache level"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> VectorRegisterBitwidth(
    "polly-target-vector-register-bitwidth",
    cl::desc("The size in bits of a vector register (if not set, this "
             "information is taken from LLVM's target information."),
    cl::Hidden, cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> PollyPatternMatchingNcQuotient(
    "polly-pattern-matching-nc-quotient",
    cl::desc("Quotient that is obtained by dividing Nc, the parameter of the"
             "macro-kernel, by Nr, the parameter of the micro-kernel"),
    cl::Hidden, cl::init(256), cl::ZeroOrMore, cl::cat(PollyCategory));

namespace {
/// Parameters of the micro kernel.
///
/// Parameters, which determine sizes of rank-1 (i.e., outer product) update
/// used in the optimized matrix multiplication.
133 struct MicroKernelParamsTy { 134 int Mr; 135 int Nr; 136 }; 137 138 /// Parameters of the macro kernel. 139 /// 140 /// Parameters, which determine sizes of blocks of partitioned matrices 141 /// used in the optimized matrix multiplication. 142 struct MacroKernelParamsTy { 143 int Mc; 144 int Nc; 145 int Kc; 146 }; 147 148 /// Parameters of the matrix multiplication operands. 149 /// 150 /// Parameters, which describe access relations that represent operands of the 151 /// matrix multiplication. 152 struct MatMulInfoTy { 153 MemoryAccess *A = nullptr; 154 MemoryAccess *B = nullptr; 155 MemoryAccess *ReadFromC = nullptr; 156 MemoryAccess *WriteToC = nullptr; 157 int i = -1; 158 int j = -1; 159 int k = -1; 160 }; 161 162 /// Create an isl::union_set, which describes the option of the form 163 /// [isolate[] -> unroll[x]]. 164 /// 165 /// @param Ctx An isl::ctx, which is used to create the isl::union_set. 166 static isl::union_set getUnrollIsolatedSetOptions(isl::ctx Ctx) { 167 isl::space Space = isl::space(Ctx, 0, 0, 1); 168 isl::map UnrollIsolatedSetOption = isl::map::universe(Space); 169 isl::id DimInId = isl::id::alloc(Ctx, "isolate", nullptr); 170 isl::id DimOutId = isl::id::alloc(Ctx, "unroll", nullptr); 171 UnrollIsolatedSetOption = 172 UnrollIsolatedSetOption.set_tuple_id(isl::dim::in, DimInId); 173 UnrollIsolatedSetOption = 174 UnrollIsolatedSetOption.set_tuple_id(isl::dim::out, DimOutId); 175 return UnrollIsolatedSetOption.wrap(); 176 } 177 178 /// Permute the two dimensions of the isl map. 179 /// 180 /// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that 181 /// have type @p DimType. 182 /// 183 /// @param Map The isl map to be modified. 184 /// @param DimType The type of the dimensions. 185 /// @param DstPos The first dimension. 186 /// @param SrcPos The second dimension. 187 /// @return The modified map. 
static isl::map permuteDimensions(isl::map Map, isl::dim DimType,
                                  unsigned DstPos, unsigned SrcPos) {
  assert((isl_size)DstPos < Map.dim(DimType) &&
         (isl_size)SrcPos < Map.dim(DimType));
  // Nothing to do if both positions coincide.
  if (DstPos == SrcPos)
    return Map;
  // Remember the tuple ids of both sides: move_dims drops them, so they are
  // restored at the end.
  isl::id DimId;
  if (Map.has_tuple_id(DimType))
    DimId = Map.get_tuple_id(DimType);
  auto FreeDim = DimType == isl::dim::in ? isl::dim::out : isl::dim::in;
  isl::id FreeDimId;
  if (Map.has_tuple_id(FreeDim))
    FreeDimId = Map.get_tuple_id(FreeDim);
  auto MaxDim = std::max(DstPos, SrcPos);
  auto MinDim = std::min(DstPos, SrcPos);
  // Use the opposite side of the map as temporary storage: park both
  // dimensions there (larger position first so indices stay valid), then
  // move them back in swapped order.
  Map = Map.move_dims(FreeDim, 0, DimType, MaxDim, 1);
  Map = Map.move_dims(FreeDim, 0, DimType, MinDim, 1);
  Map = Map.move_dims(DimType, MinDim, FreeDim, 1, 1);
  Map = Map.move_dims(DimType, MaxDim, FreeDim, 0, 1);
  if (DimId)
    Map = Map.set_tuple_id(DimType, DimId);
  if (FreeDimId)
    Map = Map.set_tuple_id(FreeDim, FreeDimId);
  return Map;
}

/// Check the form of the access relation.
///
/// Check that the access relation @p AccMap has the form M[i][j], where i
/// is a @p FirstPos and j is a @p SecondPos.
///
/// @param AccMap The access relation to be checked.
/// @param FirstPos The index of the input dimension that is mapped to
///                 the first output dimension.
/// @param SecondPos The index of the input dimension that is mapped to the
///                  second output dimension.
/// @return True in case @p AccMap has the expected form and false,
///         otherwise.
226 static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos, 227 int &SecondPos) { 228 isl::space Space = AccMap.get_space(); 229 isl::map Universe = isl::map::universe(Space); 230 231 if (Space.dim(isl::dim::out) != 2) 232 return false; 233 234 // MatMul has the form: 235 // for (i = 0; i < N; i++) 236 // for (j = 0; j < M; j++) 237 // for (k = 0; k < P; k++) 238 // C[i, j] += A[i, k] * B[k, j] 239 // 240 // Permutation of three outer loops: 3! = 6 possibilities. 241 int FirstDims[] = {0, 0, 1, 1, 2, 2}; 242 int SecondDims[] = {1, 2, 2, 0, 0, 1}; 243 for (int i = 0; i < 6; i += 1) { 244 auto PossibleMatMul = 245 Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0) 246 .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1); 247 248 AccMap = AccMap.intersect_domain(Domain); 249 PossibleMatMul = PossibleMatMul.intersect_domain(Domain); 250 251 // If AccMap spans entire domain (Non-partial write), 252 // compute FirstPos and SecondPos. 253 // If AccMap != PossibleMatMul here (the two maps have been gisted at 254 // this point), it means that the writes are not complete, or in other 255 // words, it is a Partial write and Partial writes must be rejected. 256 if (AccMap.is_equal(PossibleMatMul)) { 257 if (FirstPos != -1 && FirstPos != FirstDims[i]) 258 continue; 259 FirstPos = FirstDims[i]; 260 if (SecondPos != -1 && SecondPos != SecondDims[i]) 261 continue; 262 SecondPos = SecondDims[i]; 263 return true; 264 } 265 } 266 267 return false; 268 } 269 270 /// Does the memory access represent a non-scalar operand of the matrix 271 /// multiplication. 272 /// 273 /// Check that the memory access @p MemAccess is the read access to a non-scalar 274 /// operand of the matrix multiplication or its result. 275 /// 276 /// @param MemAccess The memory access to be checked. 277 /// @param MMI Parameters of the matrix multiplication operands. 
278 /// @return True in case the memory access represents the read access 279 /// to a non-scalar operand of the matrix multiplication and 280 /// false, otherwise. 281 static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess, 282 MatMulInfoTy &MMI) { 283 if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead()) 284 return false; 285 auto AccMap = MemAccess->getLatestAccessRelation(); 286 isl::set StmtDomain = MemAccess->getStatement()->getDomain(); 287 if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) { 288 MMI.ReadFromC = MemAccess; 289 return true; 290 } 291 if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) { 292 MMI.A = MemAccess; 293 return true; 294 } 295 if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) { 296 MMI.B = MemAccess; 297 return true; 298 } 299 return false; 300 } 301 302 /// Check accesses to operands of the matrix multiplication. 303 /// 304 /// Check that accesses of the SCoP statement, which corresponds to 305 /// the partial schedule @p PartialSchedule, are scalar in terms of loops 306 /// containing the matrix multiplication, in case they do not represent 307 /// accesses to the non-scalar operands of the matrix multiplication or 308 /// its result. 309 /// 310 /// @param PartialSchedule The partial schedule of the SCoP statement. 311 /// @param MMI Parameters of the matrix multiplication operands. 312 /// @return True in case the corresponding SCoP statement 313 /// represents matrix multiplication and false, 314 /// otherwise. 
315 static bool containsOnlyMatrMultAcc(isl::map PartialSchedule, 316 MatMulInfoTy &MMI) { 317 auto InputDimId = PartialSchedule.get_tuple_id(isl::dim::in); 318 auto *Stmt = static_cast<ScopStmt *>(InputDimId.get_user()); 319 isl_size OutDimNum = PartialSchedule.dim(isl::dim::out); 320 assert(OutDimNum > 2 && "In case of the matrix multiplication the loop nest " 321 "and, consequently, the corresponding scheduling " 322 "functions have at least three dimensions."); 323 auto MapI = 324 permuteDimensions(PartialSchedule, isl::dim::out, MMI.i, OutDimNum - 1); 325 auto MapJ = 326 permuteDimensions(PartialSchedule, isl::dim::out, MMI.j, OutDimNum - 1); 327 auto MapK = 328 permuteDimensions(PartialSchedule, isl::dim::out, MMI.k, OutDimNum - 1); 329 330 auto Accesses = getAccessesInOrder(*Stmt); 331 for (auto *MemA = Accesses.begin(); MemA != Accesses.end() - 1; MemA++) { 332 auto *MemAccessPtr = *MemA; 333 if (MemAccessPtr->isLatestArrayKind() && MemAccessPtr != MMI.WriteToC && 334 !isMatMulNonScalarReadAccess(MemAccessPtr, MMI) && 335 !(MemAccessPtr->isStrideZero(MapI)) && 336 MemAccessPtr->isStrideZero(MapJ) && MemAccessPtr->isStrideZero(MapK)) 337 return false; 338 } 339 return true; 340 } 341 342 /// Check for dependencies corresponding to the matrix multiplication. 343 /// 344 /// Check that there is only true dependence of the form 345 /// S(..., k, ...) -> S(..., k + 1, …), where S is the SCoP statement 346 /// represented by @p Schedule and k is @p Pos. Such a dependence corresponds 347 /// to the dependency produced by the matrix multiplication. 348 /// 349 /// @param Schedule The schedule of the SCoP statement. 350 /// @param D The SCoP dependencies. 351 /// @param Pos The parameter to describe an acceptable true dependence. 352 /// In case it has a negative value, try to determine its 353 /// acceptable value. 354 /// @return True in case dependencies correspond to the matrix multiplication 355 /// and false, otherwise. 
static bool containsOnlyMatMulDep(isl::map Schedule, const Dependences *D,
                                  int &Pos) {
  // Consider flow (RAW) dependences together with reduction dependences,
  // since the matmul update C[i][j] += ... may be modeled as a reduction.
  isl::union_map Dep = D->getDependences(Dependences::TYPE_RAW);
  isl::union_map Red = D->getDependences(Dependences::TYPE_RED);
  if (Red)
    Dep = Dep.unite(Red);
  // Compute the dependence distance vectors within the statement's own
  // schedule space.
  auto DomainSpace = Schedule.get_space().domain();
  auto Space = DomainSpace.map_from_domain_and_range(DomainSpace);
  auto Deltas = Dep.extract_map(Space).deltas();
  isl_size DeltasDimNum = Deltas.dim(isl::dim::set);
  for (int i = 0; i < DeltasDimNum; i++) {
    auto Val = Deltas.plain_get_val_if_fixed(isl::dim::set, i);
    // The first dimension with a constant distance of one becomes the
    // candidate reduction dimension k, unless the caller fixed Pos already.
    Pos = Pos < 0 && Val.is_one() ? i : Pos;
    // Every other dimension must have distance zero; a non-constant distance
    // (NaN) or any other value does not match the matmul pattern.
    if (Val.is_nan() || !(Val.is_zero() || (i == Pos && Val.is_one())))
      return false;
  }
  // Reject statements without dependences or without a detected k dimension.
  if (DeltasDimNum == 0 || Pos < 0)
    return false;
  return true;
}

/// Check if the SCoP statement could probably be optimized with analytical
/// modeling.
///
/// containsMatrMult tries to determine whether the following conditions
/// are true:
/// 1. The last memory access modeling an array, MA1, represents writing to
///    memory and has the form S(..., i1, ..., i2, ...) -> M(i1, i2) or
///    S(..., i2, ..., i1, ...) -> M(i1, i2), where S is the SCoP statement
///    under consideration.
/// 2. There is only one loop-carried true dependency, and it has the
///    form S(..., i3, ...) -> S(..., i3 + 1, ...), and there are no
///    loop-carried or anti dependencies.
/// 3. SCoP contains three access relations, MA2, MA3, and MA4 that represent
///    reading from memory and have the form S(..., i3, ...) -> M(i1, i3),
///    S(..., i3, ...) -> M(i3, i2), S(...) -> M(i1, i2), respectively,
///    and all memory accesses of the SCoP that are different from MA1, MA2,
///    MA3, and MA4 have stride 0, if the innermost loop is exchanged with any
///    of loops i1, i2 and i3.
///
/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
///        to check.
/// @param D   The SCoP dependencies.
/// @param MMI Parameters of the matrix multiplication operands.
static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D,
                             MatMulInfoTy &MMI) {
  auto InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
  // A matmul statement needs at least a write and a read access.
  if (Stmt->size() <= 1)
    return false;

  // Condition 1: the last array access must be the write of the result
  // matrix C. Walk backwards over the accesses to find it.
  auto Accesses = getAccessesInOrder(*Stmt);
  for (auto *MemA = Accesses.end() - 1; MemA != Accesses.begin(); MemA--) {
    auto *MemAccessPtr = *MemA;
    if (!MemAccessPtr->isLatestArrayKind())
      continue;
    if (!MemAccessPtr->isWrite())
      return false;
    auto AccMap = MemAccessPtr->getLatestAccessRelation();
    if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
      return false;
    MMI.WriteToC = MemAccessPtr;
    break;
  }

  // Condition 2: only the k-carried dependence of distance one may exist;
  // this also determines MMI.k.
  if (!containsOnlyMatMulDep(PartialSchedule, D, MMI.k))
    return false;

  // Condition 3: all remaining accesses are operands of the multiplication
  // or scalar with respect to its loops.
  if (!MMI.WriteToC || !containsOnlyMatrMultAcc(PartialSchedule, MMI))
    return false;

  // All three operands (A, B and the read of C) must have been identified.
  if (!MMI.A || !MMI.B || !MMI.ReadFromC)
    return false;
  return true;
}

/// Permute two dimensions of the band node.
///
/// Permute FirstDim and SecondDim dimensions of the Node.
///
/// @param Node The band node to be modified.
/// @param FirstDim The first dimension to be permuted.
/// @param SecondDim The second dimension to be permuted.
/// @return The modified band node.
static isl::schedule_node permuteBandNodeDimensions(isl::schedule_node Node,
                                                    unsigned FirstDim,
                                                    unsigned SecondDim) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band &&
         (unsigned)isl_schedule_node_band_n_member(Node.get()) >
             std::max(FirstDim, SecondDim));
  // Swap the two members of the band's partial schedule, then replace the
  // old band node with a new one carrying the permuted schedule.
  auto PartialSchedule =
      isl::manage(isl_schedule_node_band_get_partial_schedule(Node.get()));
  auto PartialScheduleFirstDim = PartialSchedule.get_union_pw_aff(FirstDim);
  auto PartialScheduleSecondDim = PartialSchedule.get_union_pw_aff(SecondDim);
  PartialSchedule =
      PartialSchedule.set_union_pw_aff(SecondDim, PartialScheduleFirstDim);
  PartialSchedule =
      PartialSchedule.set_union_pw_aff(FirstDim, PartialScheduleSecondDim);
  Node = isl::manage(isl_schedule_node_delete(Node.release()));
  return Node.insert_partial_schedule(PartialSchedule);
}

/// Create the BLIS micro-kernel.
///
/// Apply register tiling with the Mr and Nr parameters of the micro-kernel
/// as tile sizes and interchange the two tiled dimensions.
///
/// @param Node              The schedule node to be transformed.
/// @param MicroKernelParams Parameters of the micro-kernel used as tile
///                          sizes.
/// @return The transformed schedule node.
static isl::schedule_node
createMicroKernel(isl::schedule_node Node,
                  MicroKernelParamsTy MicroKernelParams) {
  Node = applyRegisterTiling(Node, {MicroKernelParams.Mr, MicroKernelParams.Nr},
                             1);
  Node = Node.parent().parent();
  return permuteBandNodeDimensions(Node, 0, 1).child(0).child(0);
}

/// Create the BLIS macro-kernel.
///
/// We create the BLIS macro-kernel by applying a combination of tiling
/// of dimensions of the band node and interchanging of two innermost
/// modified dimensions. The values of MacroKernelParams's fields are used
/// as tile sizes.
///
/// @param Node The schedule node to be modified.
/// @param MacroKernelParams Parameters of the macro kernel
///                          to be used as tile sizes.
static isl::schedule_node
createMacroKernel(isl::schedule_node Node,
                  MacroKernelParamsTy MacroKernelParams) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  // Tile sizes of one would leave the schedule unchanged.
  if (MacroKernelParams.Mc == 1 && MacroKernelParams.Nc == 1 &&
      MacroKernelParams.Kc == 1)
    return Node;
  int DimOutNum = isl_schedule_node_band_n_member(Node.get());
  // Tile only the three innermost dimensions (i, j, k) with Mc, Nc, Kc.
  std::vector<int> TileSizes(DimOutNum, 1);
  TileSizes[DimOutNum - 3] = MacroKernelParams.Mc;
  TileSizes[DimOutNum - 2] = MacroKernelParams.Nc;
  TileSizes[DimOutNum - 1] = MacroKernelParams.Kc;
  Node = tileNode(Node, "1st level tiling", TileSizes, 1);
  Node = Node.parent().parent();
  // Interchange the two innermost modified dimensions.
  Node = permuteBandNodeDimensions(Node, DimOutNum - 2, DimOutNum - 1);
  Node = permuteBandNodeDimensions(Node, DimOutNum - 3, DimOutNum - 1);

  // Mark the outermost loop as parallelizable.
  Node = Node.band_member_set_coincident(0, true);

  return Node.child(0).child(0);
}

/// Get the size of the widest type of the matrix multiplication operands
/// in bytes, including alignment padding.
///
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The size of the widest type of the matrix multiplication operands
///         in bytes, including alignment padding.
static uint64_t getMatMulAlignTypeSize(MatMulInfoTy MMI) {
  auto *S = MMI.A->getStatement()->getParent();
  auto &DL = S->getFunction().getParent()->getDataLayout();
  auto ElementSizeA = DL.getTypeAllocSize(MMI.A->getElementType());
  auto ElementSizeB = DL.getTypeAllocSize(MMI.B->getElementType());
  auto ElementSizeC = DL.getTypeAllocSize(MMI.WriteToC->getElementType());
  return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
}

/// Get the size of the widest type of the matrix multiplication operands
/// in bits.
///
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The size of the widest type of the matrix multiplication operands
///         in bits.
static uint64_t getMatMulTypeSize(MatMulInfoTy MMI) {
  auto *S = MMI.A->getStatement()->getParent();
  auto &DL = S->getFunction().getParent()->getDataLayout();
  auto ElementSizeA = DL.getTypeSizeInBits(MMI.A->getElementType());
  auto ElementSizeB = DL.getTypeSizeInBits(MMI.B->getElementType());
  auto ElementSizeC = DL.getTypeSizeInBits(MMI.WriteToC->getElementType());
  return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
}

/// Get parameters of the BLIS micro kernel.
///
/// We choose the Mr and Nr parameters of the micro kernel to be large enough
/// such that no stalls caused by the combination of latencies and dependencies
/// are introduced during the updates of the resulting matrix of the matrix
/// multiplication. However, they should also be as small as possible to
/// release more registers for entries of multiplied matrices.
///
/// @param TTI Target Transform Info.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The structure of type MicroKernelParamsTy.
/// @see MicroKernelParamsTy
static struct MicroKernelParamsTy
getMicroKernelParams(const TargetTransformInfo *TTI, MatMulInfoTy MMI) {
  assert(TTI && "The target transform info should be provided.");

  // Nvec - Number of double-precision floating-point numbers that can be held
  // by a vector register. Use 2 by default.
  long RegisterBitwidth = VectorRegisterBitwidth;

  // Fall back to the target's fixed-width vector register size if the option
  // was not set on the command line.
  if (RegisterBitwidth == -1)
    RegisterBitwidth =
        TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector);
  auto ElementSize = getMatMulTypeSize(MMI);
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
                            "operands should be greater than zero.");
  auto Nvec = RegisterBitwidth / ElementSize;
  if (Nvec == 0)
    Nvec = 2;
  // Mr * Nr >= Nvec * LatencyVectorFma * ThroughputVectorFma, with Nr rounded
  // up to a multiple of the vector width Nvec.
  int Nr = ceil(sqrt((double)(Nvec * LatencyVectorFma * ThroughputVectorFma)) /
                Nvec) *
           Nvec;
  int Mr = ceil((double)(Nvec * LatencyVectorFma * ThroughputVectorFma / Nr));
  return {Mr, Nr};
}

/// Determine parameters of the target cache.
///
/// Fill in any cache size/associativity options that were not set on the
/// command line, preferring values reported by @p TTI and falling back to
/// the built-in defaults otherwise.
///
/// @param TTI Target Transform Info.
static void getTargetCacheParameters(const llvm::TargetTransformInfo *TTI) {
  auto L1DCache = llvm::TargetTransformInfo::CacheLevel::L1D;
  auto L2DCache = llvm::TargetTransformInfo::CacheLevel::L2D;
  if (FirstCacheLevelSize == -1) {
    if (TTI->getCacheSize(L1DCache).hasValue())
      FirstCacheLevelSize = TTI->getCacheSize(L1DCache).getValue();
    else
      FirstCacheLevelSize = static_cast<int>(FirstCacheLevelDefaultSize);
  }
  if (SecondCacheLevelSize == -1) {
    if (TTI->getCacheSize(L2DCache).hasValue())
      SecondCacheLevelSize = TTI->getCacheSize(L2DCache).getValue();
    else
      SecondCacheLevelSize = static_cast<int>(SecondCacheLevelDefaultSize);
  }
  if (FirstCacheLevelAssociativity == -1) {
    if (TTI->getCacheAssociativity(L1DCache).hasValue())
      FirstCacheLevelAssociativity =
          TTI->getCacheAssociativity(L1DCache).getValue();
    else
      FirstCacheLevelAssociativity =
          static_cast<int>(FirstCacheLevelDefaultAssociativity);
  }
  if (SecondCacheLevelAssociativity == -1) {
    if (TTI->getCacheAssociativity(L2DCache).hasValue())
      SecondCacheLevelAssociativity =
          TTI->getCacheAssociativity(L2DCache).getValue();
    else
      SecondCacheLevelAssociativity =
          static_cast<int>(SecondCacheLevelDefaultAssociativity);
  }
}

/// Get parameters of the BLIS macro kernel.
///
/// During the computation of matrix multiplication, blocks of partitioned
/// matrices are mapped to different layers of the memory hierarchy.
/// To optimize data reuse, blocks should be ideally kept in cache between
/// iterations. Since parameters of the macro kernel determine sizes of these
/// blocks, there are upper and lower bounds on these parameters.
///
/// @param TTI Target Transform Info.
/// @param MicroKernelParams Parameters of the micro-kernel
///                          to be taken into account.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The structure of type MacroKernelParamsTy.
/// @see MacroKernelParamsTy
/// @see MicroKernelParamsTy
static struct MacroKernelParamsTy
getMacroKernelParams(const llvm::TargetTransformInfo *TTI,
                     const MicroKernelParamsTy &MicroKernelParams,
                     MatMulInfoTy MMI) {
  getTargetCacheParameters(TTI);
  // According to www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf,
  // it requires information about the first two levels of a cache to determine
  // all the parameters of a macro-kernel. It also checks that an associativity
  // degree of a cache level is greater than two. Otherwise, another algorithm
  // for determination of the parameters should be used.
  if (!(MicroKernelParams.Mr > 0 && MicroKernelParams.Nr > 0 &&
        FirstCacheLevelSize > 0 && SecondCacheLevelSize > 0 &&
        FirstCacheLevelAssociativity > 2 && SecondCacheLevelAssociativity > 2))
    return {1, 1, 1};
  // The quotient should be greater than zero.
  if (PollyPatternMatchingNcQuotient <= 0)
    return {1, 1, 1};
  int Car = floor(
      (FirstCacheLevelAssociativity - 1) /
      (1 + static_cast<double>(MicroKernelParams.Nr) / MicroKernelParams.Mr));

  // Car can be computed to be zero since it is floor to int.
  // On Mac OS, division by 0 does not raise a signal. This causes negative
  // tile sizes to be computed. Car == 0 makes Kc and hence Cac zero, so
  // prevent the later division by Cac == 0 by early returning here.
  if (Car == 0)
    return {1, 1, 1};

  auto ElementSize = getMatMulAlignTypeSize(MMI);
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
                            "operands should be greater than zero.");
  int Kc = (Car * FirstCacheLevelSize) /
           (MicroKernelParams.Mr * FirstCacheLevelAssociativity * ElementSize);
  double Cac =
      static_cast<double>(Kc * ElementSize * SecondCacheLevelAssociativity) /
      SecondCacheLevelSize;
  int Mc = floor((SecondCacheLevelAssociativity - 2) / Cac);
  int Nc = PollyPatternMatchingNcQuotient * MicroKernelParams.Nr;

  assert(Mc > 0 && Nc > 0 && Kc > 0 &&
         "Matrix block sizes should be greater than zero");
  return {Mc, Nc, Kc};
}

/// Create an access relation that is specific to
/// the matrix multiplication pattern.
///
/// Create an access relation of the following form:
/// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [OI, O5, OJ]
/// where I is @p FirstDim, J is @p SecondDim.
///
/// It can be used, for example, to create relations that helps to consequently
/// access elements of operands of a matrix multiplication after creation of
/// the BLIS micro and macro kernels.
///
/// @see ScheduleTreeOptimizer::createMicroKernel
/// @see ScheduleTreeOptimizer::createMacroKernel
///
/// Subsequently, the described access relation is applied to the range of
/// @p MapOldIndVar, that is used to map original induction variables to
/// the ones, which are produced by schedule transformations. It helps to
/// define relations using a new space and, at the same time, keep them
/// in the original one.
679 /// 680 /// @param MapOldIndVar The relation, which maps original induction variables 681 /// to the ones, which are produced by schedule 682 /// transformations. 683 /// @param FirstDim, SecondDim The input dimensions that are used to define 684 /// the specified access relation. 685 /// @return The specified access relation. 686 static isl::map getMatMulAccRel(isl::map MapOldIndVar, unsigned FirstDim, 687 unsigned SecondDim) { 688 auto AccessRelSpace = isl::space(MapOldIndVar.get_ctx(), 0, 9, 3); 689 auto AccessRel = isl::map::universe(AccessRelSpace); 690 AccessRel = AccessRel.equate(isl::dim::in, FirstDim, isl::dim::out, 0); 691 AccessRel = AccessRel.equate(isl::dim::in, 5, isl::dim::out, 1); 692 AccessRel = AccessRel.equate(isl::dim::in, SecondDim, isl::dim::out, 2); 693 return MapOldIndVar.apply_range(AccessRel); 694 } 695 696 static isl::schedule_node createExtensionNode(isl::schedule_node Node, 697 isl::map ExtensionMap) { 698 auto Extension = isl::union_map(ExtensionMap); 699 auto NewNode = isl::schedule_node::from_extension(Extension); 700 return Node.graft_before(NewNode); 701 } 702 703 /// Apply the packing transformation. 704 /// 705 /// The packing transformation can be described as a data-layout 706 /// transformation that requires to introduce a new array, copy data 707 /// to the array, and change memory access locations to reference the array. 708 /// It can be used to ensure that elements of the new array are read in-stride 709 /// access, aligned to cache lines boundaries, and preloaded into certain cache 710 /// levels. 711 /// 712 /// As an example let us consider the packing of the array A that would help 713 /// to read its elements with in-stride access. An access to the array A 714 /// is represented by an access relation that has the form 715 /// S[i, j, k] -> A[i, k]. 
The scheduling function of the SCoP statement S has
/// the form S[i,j, k] -> [floor((j mod Nc) / Nr), floor((i mod Mc) / Mr),
/// k mod Kc, j mod Nr, i mod Mr].
///
/// To ensure that elements of the array A are read in-stride access, we add
/// a new array Packed_A[Mc/Mr][Kc][Mr] to the SCoP, using
/// Scop::createScopArrayInfo, change the access relation
/// S[i, j, k] -> A[i, k] to
/// S[i, j, k] -> Packed_A[floor((i mod Mc) / Mr), k mod Kc, i mod Mr], using
/// MemoryAccess::setNewAccessRelation, and copy the data to the array, using
/// the copy statement created by Scop::addScopStmt.
///
/// @param Node The schedule node to be optimized.
/// @param MapOldIndVar The relation, which maps original induction variables
///                     to the ones, which are produced by schedule
///                     transformations.
/// @param MicroParams, MacroParams Parameters of the BLIS kernel
///                                 to be taken into account.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The optimized schedule node.
static isl::schedule_node
optimizeDataLayoutMatrMulPattern(isl::schedule_node Node, isl::map MapOldIndVar,
                                 MicroKernelParamsTy MicroParams,
                                 MacroKernelParamsTy MacroParams,
                                 MatMulInfoTy &MMI) {
  // The input tuple id of MapOldIndVar carries the ScopStmt whose accesses
  // are rewritten below; recover it from the id's user pointer.
  auto InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());

  // Create a copy statement that corresponds to the memory access to the
  // matrix B, the second operand of the matrix multiplication.
  // Walk six levels up the schedule tree and split the band after its two
  // outermost members so the extension node can be grafted at that depth.
  Node = Node.parent().parent().parent().parent().parent().parent();
  Node = isl::manage(isl_schedule_node_band_split(Node.release(), 2)).child(0);
  // Packed_B subscripts come from transformed schedule dimensions 3, 5 and 7
  // (see getMatMulAccRel: dimension 5 always supplies the middle subscript).
  auto AccRel = getMatMulAccRel(MapOldIndVar, 3, 7);
  unsigned FirstDimSize = MacroParams.Nc / MicroParams.Nr;
  unsigned SecondDimSize = MacroParams.Kc;
  unsigned ThirdDimSize = MicroParams.Nr;
  auto *SAI = Stmt->getParent()->createScopArrayInfo(
      MMI.B->getElementType(), "Packed_B",
      {FirstDimSize, SecondDimSize, ThirdDimSize});
  // Retarget the access to the packed array; keep the old relation around to
  // build the copy statement (copy reads the original location).
  AccRel = AccRel.set_tuple_id(isl::dim::out, SAI->getBasePtrId());
  auto OldAcc = MMI.B->getLatestAccessRelation();
  MMI.B->setNewAccessRelation(AccRel);
  // Keep only the two outermost transformed dimensions, invert the mapping,
  // and fix schedule dimension MMI.i to 0 so the copy is not replicated over
  // that loop.
  auto ExtMap = MapOldIndVar.project_out(isl::dim::out, 2,
                                         MapOldIndVar.dim(isl::dim::out) - 2);
  ExtMap = ExtMap.reverse();
  ExtMap = ExtMap.fix_si(isl::dim::out, MMI.i, 0);
  auto Domain = Stmt->getDomain();

  // Restrict the domains of the copy statements to only execute when also its
  // originating statement is executed.
  auto DomainId = Domain.get_tuple_id();
  auto *NewStmt = Stmt->getParent()->addScopStmt(
      OldAcc, MMI.B->getLatestAccessRelation(), Domain);
  // Temporarily retag the range with the original statement's id so the
  // intersection with its domain applies, then switch to the copy statement.
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, DomainId);
  ExtMap = ExtMap.intersect_range(Domain);
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, NewStmt->getDomainId());
  Node = createExtensionNode(Node, ExtMap);

  // Create a copy statement that corresponds to the memory access
  // to the matrix A, the first operand of the matrix multiplication.
  Node = Node.child(0);
  // Packed_A subscripts come from transformed schedule dimensions 4, 5 and 6.
  AccRel = getMatMulAccRel(MapOldIndVar, 4, 6);
  FirstDimSize = MacroParams.Mc / MicroParams.Mr;
  ThirdDimSize = MicroParams.Mr;
  SAI = Stmt->getParent()->createScopArrayInfo(
      MMI.A->getElementType(), "Packed_A",
      {FirstDimSize, SecondDimSize, ThirdDimSize});
  AccRel = AccRel.set_tuple_id(isl::dim::out, SAI->getBasePtrId());
  OldAcc = MMI.A->getLatestAccessRelation();
  MMI.A->setNewAccessRelation(AccRel);
  // For A keep the three outermost transformed dimensions and fix schedule
  // dimension MMI.j to 0 instead.
  ExtMap = MapOldIndVar.project_out(isl::dim::out, 3,
                                    MapOldIndVar.dim(isl::dim::out) - 3);
  ExtMap = ExtMap.reverse();
  ExtMap = ExtMap.fix_si(isl::dim::out, MMI.j, 0);
  NewStmt = Stmt->getParent()->addScopStmt(
      OldAcc, MMI.A->getLatestAccessRelation(), Domain);

  // Restrict the domains of the copy statements to only execute when also its
  // originating statement is executed.
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, DomainId);
  ExtMap = ExtMap.intersect_range(Domain);
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, NewStmt->getDomainId());
  Node = createExtensionNode(Node, ExtMap);
  // Descend back below the grafted extension nodes to the position the
  // caller expects.
  return Node.child(0).child(0).child(0).child(0).child(0);
}

/// Get a relation mapping induction variables produced by schedule
/// transformations to the original ones.
///
/// @param Node The schedule node produced as the result of creation
///             of the BLIS kernels.
/// @param MicroKernelParams, MacroKernelParams Parameters of the BLIS kernel
///                                             to be taken into account.
/// @return The relation mapping original induction variables to the ones
///         produced by schedule transformation.
810 /// @see ScheduleTreeOptimizer::createMicroKernel 811 /// @see ScheduleTreeOptimizer::createMacroKernel 812 /// @see getMacroKernelParams 813 static isl::map 814 getInductionVariablesSubstitution(isl::schedule_node Node, 815 MicroKernelParamsTy MicroKernelParams, 816 MacroKernelParamsTy MacroKernelParams) { 817 auto Child = Node.child(0); 818 auto UnMapOldIndVar = Child.get_prefix_schedule_union_map(); 819 auto MapOldIndVar = isl::map::from_union_map(UnMapOldIndVar); 820 if (MapOldIndVar.dim(isl::dim::out) > 9) 821 return MapOldIndVar.project_out(isl::dim::out, 0, 822 MapOldIndVar.dim(isl::dim::out) - 9); 823 return MapOldIndVar; 824 } 825 826 /// Isolate a set of partial tile prefixes and unroll the isolated part. 827 /// 828 /// The set should ensure that it contains only partial tile prefixes that have 829 /// exactly Mr x Nr iterations of the two innermost loops produced by 830 /// the optimization of the matrix multiplication. Mr and Nr are parameters of 831 /// the micro-kernel. 832 /// 833 /// In case of parametric bounds, this helps to auto-vectorize the unrolled 834 /// innermost loops, using the SLP vectorizer. 835 /// 836 /// @param Node The schedule node to be modified. 837 /// @param MicroKernelParams Parameters of the micro-kernel 838 /// to be taken into account. 839 /// @return The modified isl_schedule_node. 
static isl::schedule_node
isolateAndUnrollMatMulInnerLoops(isl::schedule_node Node,
                                 struct MicroKernelParamsTy MicroKernelParams) {
  isl::schedule_node Child = Node.get_child(0);
  isl::union_map UnMapOldIndVar = Child.get_prefix_schedule_relation();
  isl::set Prefix = isl::map::from_union_map(UnMapOldIndVar).range();
  isl_size Dims = Prefix.dim(isl::dim::set);
  // Drop the innermost dimension, then narrow the prefixes to those whose
  // two innermost loops have full Nr x Mr iteration counts.
  Prefix = Prefix.project_out(isl::dim::set, Dims - 1, 1);
  Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Nr);
  Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Mr);

  // On the current band: isolate the computed prefixes (padded with three
  // dimensions) and request unrolling of the isolated part.
  isl::union_set IsolateOption =
      getIsolateOptions(Prefix.add_dims(isl::dim::set, 3), 3);
  isl::ctx Ctx = Node.get_ctx();
  auto Options = IsolateOption.unite(getDimOptions(Ctx, "unroll"));
  Options = Options.unite(getUnrollIsolatedSetOptions(Ctx));
  Node = Node.band_set_ast_build_options(Options);
  // On the band three levels up: isolate the same prefixes with the
  // "separate" option instead.
  Node = Node.parent().parent().parent();
  IsolateOption = getIsolateOptions(Prefix, 3);
  Options = IsolateOption.unite(getDimOptions(Ctx, "separate"));
  Node = Node.band_set_ast_build_options(Options);
  // Return to the node position this function was entered at.
  Node = Node.child(0).child(0).child(0);
  return Node;
}

/// Mark @p BasePtr with "Inter iteration alias-free" mark node.
///
/// @param Node The child of the mark node to be inserted.
/// @param BasePtr The pointer to be marked.
/// @return The modified isl_schedule_node.
static isl::schedule_node markInterIterationAliasFree(isl::schedule_node Node,
                                                      Value *BasePtr) {
  // Without a base pointer there is nothing to mark.
  if (!BasePtr)
    return Node;

  auto Id =
      isl::id::alloc(Node.get_ctx(), "Inter iteration alias-free", BasePtr);
  return Node.insert_mark(Id).child(0);
}

/// Insert "Loop Vectorizer Disabled" mark node.
///
/// @param Node The child of the mark node to be inserted.
/// @return The modified isl_schedule_node.
static isl::schedule_node markLoopVectorizerDisabled(isl::schedule_node Node) {
  auto Id = isl::id::alloc(Node.get_ctx(), "Loop Vectorizer Disabled", nullptr);
  return Node.insert_mark(Id).child(0);
}

/// Restore the initial ordering of dimensions of the band node
///
/// In case the band node represents all the dimensions of the iteration
/// domain, recreate the band node to restore the initial ordering of the
/// dimensions.
///
/// @param Node The band node to be modified.
/// @return The modified schedule node.
static isl::schedule_node
getBandNodeWithOriginDimOrder(isl::schedule_node Node) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  // Only rewrite a band whose single child is a leaf.
  if (isl_schedule_node_get_type(Node.child(0).get()) != isl_schedule_node_leaf)
    return Node;
  auto Domain = Node.get_universe_domain();
  assert(isl_union_set_n_set(Domain.get()) == 1);
  // Bail out unless the band is at schedule depth zero and its member count
  // matches the dimensionality of the iteration domain.
  if (Node.get_schedule_depth() != 0 ||
      (isl::set(Domain).dim(isl::dim::set) !=
       isl_schedule_node_band_n_member(Node.get())))
    return Node;
  // Delete the band and insert a fresh partial schedule built from the
  // identity mapping over the domain, i.e. the original dimension order.
  Node = isl::manage(isl_schedule_node_delete(Node.copy()));
  auto PartialSchedulePwAff = Domain.identity_union_pw_multi_aff();
  auto PartialScheduleMultiPwAff =
      isl::multi_union_pw_aff(PartialSchedulePwAff);
  PartialScheduleMultiPwAff =
      PartialScheduleMultiPwAff.reset_tuple_id(isl::dim::set);
  return Node.insert_partial_schedule(PartialScheduleMultiPwAff);
}

static isl::schedule_node optimizeMatMulPattern(isl::schedule_node Node,
                                                const TargetTransformInfo *TTI,
                                                MatMulInfoTy &MMI) {
  assert(TTI && "The target transform info should be provided.");
  Node = markInterIterationAliasFree(
      Node, MMI.WriteToC->getLatestScopArrayInfo()->getBasePtr());
  int DimOutNum = isl_schedule_node_band_n_member(Node.get());
  assert(DimOutNum > 2 && "In case of the matrix multiplication the loop nest "
                          "and, consequently, the corresponding scheduling "
                          "functions have at least three dimensions.");
  Node = getBandNodeWithOriginDimOrder(Node);
  // Move the i, j and k dimensions to the three innermost band positions,
  // adjusting the remaining indices for the displacement that each earlier
  // permutation may cause.
  Node = permuteBandNodeDimensions(Node, MMI.i, DimOutNum - 3);
  int NewJ = MMI.j == DimOutNum - 3 ? MMI.i : MMI.j;
  int NewK = MMI.k == DimOutNum - 3 ? MMI.i : MMI.k;
  Node = permuteBandNodeDimensions(Node, NewJ, DimOutNum - 2);
  NewK = NewK == DimOutNum - 2 ? NewJ : NewK;
  Node = permuteBandNodeDimensions(Node, NewK, DimOutNum - 1);
  auto MicroKernelParams = getMicroKernelParams(TTI, MMI);
  auto MacroKernelParams = getMacroKernelParams(TTI, MicroKernelParams, MMI);
  Node = createMacroKernel(Node, MacroKernelParams);
  Node = createMicroKernel(Node, MicroKernelParams);
  // Skip the data-layout transformation when any macro-kernel tile size
  // degenerates to 1.
  if (MacroKernelParams.Mc == 1 || MacroKernelParams.Nc == 1 ||
      MacroKernelParams.Kc == 1)
    return Node;
  auto MapOldIndVar = getInductionVariablesSubstitution(Node, MicroKernelParams,
                                                        MacroKernelParams);
  if (!MapOldIndVar)
    return Node;
  Node = markLoopVectorizerDisabled(Node.parent()).child(0);
  Node = isolateAndUnrollMatMulInnerLoops(Node, MicroKernelParams);
  return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
                                          MacroKernelParams, MMI);
}

/// Check if this node contains a partial schedule that could
/// probably be optimized with analytical modeling.
///
/// isMatrMultPattern tries to determine whether the following conditions
/// are true:
/// 1. the partial schedule contains only one statement.
/// 2. there are exactly three input dimensions.
/// 3. all memory accesses of the statement will have stride 0 or 1, if we
///    interchange loops (switch the variable used in the inner loop to
///    the outer loop).
/// 4. all memory accesses of the statement except from the last one, are
///    read memory access and the last one is write memory access.
/// 5.
all subscripts of the last memory access of the statement don't 964 /// contain the variable used in the inner loop. 965 /// If this is the case, we could try to use an approach that is similar to 966 /// the one used to get close-to-peak performance of matrix multiplications. 967 /// 968 /// @param Node The node to check. 969 /// @param D The SCoP dependencies. 970 /// @param MMI Parameters of the matrix multiplication operands. 971 static bool isMatrMultPattern(isl::schedule_node Node, const Dependences *D, 972 MatMulInfoTy &MMI) { 973 auto PartialSchedule = isl::manage( 974 isl_schedule_node_band_get_partial_schedule_union_map(Node.get())); 975 Node = Node.child(0); 976 auto LeafType = isl_schedule_node_get_type(Node.get()); 977 Node = Node.parent(); 978 if (LeafType != isl_schedule_node_leaf || 979 isl_schedule_node_band_n_member(Node.get()) < 3 || 980 Node.get_schedule_depth() != 0 || 981 isl_union_map_n_map(PartialSchedule.get()) != 1) 982 return false; 983 auto NewPartialSchedule = isl::map::from_union_map(PartialSchedule); 984 if (containsMatrMult(NewPartialSchedule, D, MMI)) 985 return true; 986 return false; 987 } 988 989 } // namespace 990 991 isl::schedule_node 992 polly::tryOptimizeMatMulPattern(isl::schedule_node Node, 993 const llvm::TargetTransformInfo *TTI, 994 const Dependences *D) { 995 MatMulInfoTy MMI; 996 if (isMatrMultPattern(Node, D, MMI)) { 997 LLVM_DEBUG(dbgs() << "The matrix multiplication pattern was detected\n"); 998 return optimizeMatMulPattern(Node, TTI, MMI); 999 } 1000 return {}; 1001 } 1002