1 //===- Interchange.cpp - Linalg interchange transformation ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg interchange transformation. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/Linalg/Utils/Utils.h" 17 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 18 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 19 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/IR/AffineExpr.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Support/LLVM.h" 26 #include "llvm/Support/Debug.h" 27 #include "llvm/Support/raw_ostream.h" 28 #include <type_traits> 29 30 #define DEBUG_TYPE "linalg-interchange" 31 32 using namespace mlir; 33 using namespace mlir::linalg; 34 35 LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition( 36 Operation *op, ArrayRef<unsigned> interchangeVector) { 37 // Transformation applies to generic ops only. 38 if (!isa<GenericOp, IndexedGenericOp>(op)) 39 return failure(); 40 LinalgOp linalgOp = cast<LinalgOp>(op); 41 // Interchange vector must be non-empty and match the number of loops. 42 if (interchangeVector.empty() || 43 linalgOp.getNumLoops() != interchangeVector.size()) 44 return failure(); 45 // Permutation map must be invertible. 46 if (!inversePermutation( 47 AffineMap::getPermutationMap(interchangeVector, op->getContext()))) 48 return failure(); 49 return success(); 50 } 51 52 void mlir::linalg::interchange(PatternRewriter &rewriter, LinalgOp op, 53 ArrayRef<unsigned> interchangeVector) { 54 // 1. Compute the inverse permutation map. 55 MLIRContext *context = op.getContext(); 56 AffineMap permutationMap = inversePermutation( 57 AffineMap::getPermutationMap(interchangeVector, context)); 58 assert(permutationMap && "expected permutation to be invertible"); 59 assert(interchangeVector.size() == op.getNumLoops() && 60 "expected interchange vector to have entry for every loop"); 61 62 // 2. Compute the interchanged indexing maps. 63 SmallVector<Attribute, 4> newIndexingMaps; 64 ArrayRef<Attribute> indexingMaps = op.indexing_maps().getValue(); 65 for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) { 66 AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue(); 67 if (!permutationMap.isEmpty()) 68 m = m.compose(permutationMap); 69 newIndexingMaps.push_back(AffineMapAttr::get(m)); 70 } 71 op->setAttr(getIndexingMapsAttrName(), 72 ArrayAttr::get(context, newIndexingMaps)); 73 74 // 3. Compute the interchanged iterator types. 75 ArrayRef<Attribute> itTypes = op.iterator_types().getValue(); 76 SmallVector<Attribute, 4> itTypesVector; 77 llvm::append_range(itTypesVector, itTypes); 78 applyPermutationToVector(itTypesVector, interchangeVector); 79 op->setAttr(getIteratorTypesAttrName(), 80 ArrayAttr::get(context, itTypesVector)); 81 82 // 4. Transform the index operations by applying the permutation map. 83 if (op.hasIndexSemantics()) { 84 // TODO: Remove the assertion and add a getBody() method to LinalgOp 85 // interface once every LinalgOp has a body. 86 assert(op->getNumRegions() == 1 && 87 op->getRegion(0).getBlocks().size() == 1 && 88 "expected generic operation to have one block."); 89 Block &block = op->getRegion(0).front(); 90 OpBuilder::InsertionGuard guard(rewriter); 91 for (IndexOp indexOp : 92 llvm::make_early_inc_range(block.getOps<IndexOp>())) { 93 rewriter.setInsertionPoint(indexOp); 94 SmallVector<Value> allIndices; 95 allIndices.reserve(op.getNumLoops()); 96 llvm::transform(llvm::seq<int64_t>(0, op.getNumLoops()), 97 std::back_inserter(allIndices), [&](int64_t dim) { 98 return rewriter.create<IndexOp>(indexOp->getLoc(), dim); 99 }); 100 rewriter.replaceOpWithNewOp<AffineApplyOp>( 101 indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices); 102 } 103 } 104 } 105