1 //===- Interchange.cpp - Linalg interchange transformation ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg interchange transformation. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/Linalg/Utils/Utils.h" 17 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 18 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 19 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/IR/AffineExpr.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Support/LLVM.h" 26 #include "llvm/Support/Debug.h" 27 #include "llvm/Support/raw_ostream.h" 28 #include <type_traits> 29 30 #define DEBUG_TYPE "linalg-interchange" 31 32 using namespace mlir; 33 using namespace mlir::linalg; 34 35 LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition( 36 Operation *op, ArrayRef<unsigned> interchangeVector) { 37 if (interchangeVector.empty()) 38 return failure(); 39 // Transformation applies to generic ops only. 40 if (!isa<GenericOp>(op) && !isa<IndexedGenericOp>(op)) 41 return failure(); 42 LinalgOp linOp = cast<LinalgOp>(op); 43 // Transformation applies to buffers only. 44 if (!linOp.hasBufferSemantics()) 45 return failure(); 46 // Permutation must be applicable. 47 if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size()) 48 return failure(); 49 // Permutation map must be invertible. 50 if (!inversePermutation( 51 AffineMap::getPermutationMap(interchangeVector, op->getContext()))) 52 return failure(); 53 return success(); 54 } 55 56 LinalgOp mlir::linalg::interchange(LinalgOp op, 57 ArrayRef<unsigned> interchangeVector) { 58 if (interchangeVector.empty()) 59 return op; 60 61 MLIRContext *context = op.getContext(); 62 auto permutationMap = inversePermutation( 63 AffineMap::getPermutationMap(interchangeVector, context)); 64 assert(permutationMap && "expected permutation to be invertible"); 65 SmallVector<Attribute, 4> newIndexingMaps; 66 auto indexingMaps = op.indexing_maps().getValue(); 67 for (unsigned i = 0, e = op.getNumInputsAndOutputs(); i != e; ++i) { 68 AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue(); 69 if (!permutationMap.isEmpty()) 70 m = m.compose(permutationMap); 71 newIndexingMaps.push_back(AffineMapAttr::get(m)); 72 } 73 auto itTypes = op.iterator_types().getValue(); 74 SmallVector<Attribute, 4> itTypesVector; 75 for (unsigned i = 0, e = itTypes.size(); i != e; ++i) 76 itTypesVector.push_back(itTypes[i]); 77 applyPermutationToVector(itTypesVector, interchangeVector); 78 79 op.setAttr(getIndexingMapsAttrName(), 80 ArrayAttr::get(newIndexingMaps, context)); 81 op.setAttr(getIteratorTypesAttrName(), 82 ArrayAttr::get(itTypesVector, context)); 83 84 return op; 85 } 86