1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/Dialect/Linalg/IR/Linalg.h" 12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 13 #include "mlir/Dialect/PDL/IR/PDL.h" 14 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Interfaces/SideEffectInterfaces.h" 17 #include "mlir/Parser/Parser.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 #include "llvm/Support/FormatVariadic.h" 20 21 using namespace mlir; 22 using namespace mlir::linalg; 23 using namespace mlir::transform; 24 25 /// Extracts a vector of int64_t from an array attribute. Asserts if the 26 /// attribute contains values other than integers. 27 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) { 28 SmallVector<int64_t> result; 29 result.reserve(attr.size()); 30 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 31 result.push_back(value.getSExtValue()); 32 return result; 33 } 34 35 /// Extracts a vector of unsigned from an array attribute. Asserts if the 36 /// attribute contains values other than intergers. May truncate. 37 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { 38 SmallVector<unsigned> result; 39 result.reserve(attr.size()); 40 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 41 result.push_back(value.getZExtValue()); 42 return result; 43 } 44 45 namespace { 46 /// A simple pattern rewriter that implements no special logic. 47 class SimpleRewriter : public PatternRewriter { 48 public: 49 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 50 }; 51 } // namespace 52 53 //===----------------------------------------------------------------------===// 54 // InterchangeOp 55 //===----------------------------------------------------------------------===// 56 57 FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) { 58 SmallVector<unsigned> interchangeVector = 59 extractUIntArray(getIteratorInterchange()); 60 // Exit early if no transformation is needed. 61 if (interchangeVector.empty()) 62 return target; 63 64 auto genericTarget = dyn_cast<GenericOp>(target.getOperation()); 65 if (!genericTarget) { 66 InFlightDiagnostic diag = emitOpError() 67 << "applies to " << GenericOp::getOperationName() 68 << " ops"; 69 diag.attachNote(target.getLoc()) << "attempted to apply to this op"; 70 return diag; 71 } 72 73 GenericOpInterchangePattern pattern(getContext(), interchangeVector); 74 SimpleRewriter rewriter(getContext()); 75 rewriter.setInsertionPoint(target); 76 FailureOr<GenericOp> result = 77 pattern.returningMatchAndRewrite(genericTarget, rewriter); 78 if (failed(result)) 79 return failure(); 80 81 return cast<LinalgOp>(result->getOperation()); 82 } 83 84 LogicalResult transform::InterchangeOp::verify() { 85 SmallVector<unsigned> permutation = 86 extractUIntArray(getIteratorInterchange()); 87 auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size())); 88 if (!std::is_permutation(sequence.begin(), sequence.end(), 89 permutation.begin(), permutation.end())) { 90 return emitOpError() 91 << "expects iterator_interchange to be a permutation, found " 92 << getIteratorInterchange(); 93 } 94 return success(); 95 } 96 97 //===---------------------------------------------------------------------===// 98 // PadOp 99 //===---------------------------------------------------------------------===// 100 101 FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) { 102 // Convert the integer packing flags to booleans. 103 SmallVector<bool> packPaddings; 104 for (int64_t packPadding : extractI64Array(getPackPaddings())) 105 packPaddings.push_back(static_cast<bool>(packPadding)); 106 107 // Convert the padding values to attributes. 108 SmallVector<Attribute> paddingValues; 109 for (auto const &it : 110 llvm::zip(getPaddingValues(), target->getOperandTypes())) { 111 Attribute attr = std::get<0>(it); 112 Type elementType = getElementTypeOrSelf(std::get<1>(it)); 113 // Try to parse string attributes to obtain an attribute of element type. 114 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 115 paddingValues.push_back( 116 parseAttribute(attr.cast<StringAttr>(), elementType)); 117 if (!paddingValues.back()) { 118 InFlightDiagnostic diag = emitOpError() 119 << "expects a padding value that parses to " 120 << elementType << ", got " << std::get<0>(it); 121 diag.attachNote(target.getLoc()) << "when applied to this op"; 122 return diag; 123 } 124 continue; 125 } 126 // Otherwise, add the attribute directly. 127 if (attr.getType() != elementType) { 128 InFlightDiagnostic diag = emitOpError() 129 << "expects a padding value of type " 130 << elementType << ", got " << attr; 131 diag.attachNote(target.getLoc()) << "when applied to this op"; 132 return diag; 133 } 134 paddingValues.push_back(attr); 135 } 136 137 // Extract the transpose vectors. 138 SmallVector<SmallVector<int64_t>> transposePaddings; 139 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>()) 140 transposePaddings.push_back( 141 extractI64Array(transposeVector.cast<ArrayAttr>())); 142 143 LinalgPaddingOptions paddingOptions; 144 paddingOptions.setPaddingValues(paddingValues); 145 paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions())); 146 paddingOptions.setPackPaddings(packPaddings); 147 paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings())); 148 paddingOptions.setTransposePaddings(transposePaddings); 149 150 LinalgPaddingPattern pattern(getContext(), paddingOptions); 151 SimpleRewriter rewriter(getContext()); 152 rewriter.setInsertionPoint(target); 153 FailureOr<LinalgOp> patternResult = 154 pattern.returningMatchAndRewrite(target, rewriter); 155 if (failed(patternResult)) { 156 InFlightDiagnostic diag = emitError() 157 << "failed to apply pattern to target op"; 158 diag.attachNote(target.getLoc()) << "target op"; 159 return diag; 160 } 161 return patternResult; 162 } 163 164 LogicalResult transform::PadOp::verify() { 165 SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings()); 166 if (any_of(packPaddings, [](int64_t packPadding) { 167 return packPadding != 0 && packPadding != 1; 168 })) { 169 return emitOpError() 170 << "expects pack_paddings to contain booleans (0/1), found " 171 << getPackPaddings(); 172 } 173 174 SmallVector<int64_t> paddingDimensions = 175 extractI64Array(getPaddingDimensions()); 176 if (any_of(paddingDimensions, 177 [](int64_t paddingDimension) { return paddingDimension < 0; })) { 178 return emitOpError() 179 << "expects padding_dimensions to contain positive integers, found " 180 << getPaddingDimensions(); 181 } 182 183 SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings()); 184 if (any_of(hoistPaddings, 185 [](int64_t hoistPadding) { return hoistPadding < 0; })) { 186 return emitOpError() 187 << "expects hoist_paddings to contain positive integers, found " 188 << getHoistPaddings(); 189 } 190 191 ArrayAttr transposes = getTransposePaddings(); 192 for (Attribute attr : transposes) { 193 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr); 194 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 195 if (!std::is_permutation(sequence.begin(), sequence.end(), 196 transpose.begin(), transpose.end())) { 197 return emitOpError() 198 << "expects transpose_paddings to be a permutation, found " 199 << attr; 200 } 201 } 202 return success(); 203 } 204 205 //===----------------------------------------------------------------------===// 206 // ScalarizeOp 207 //===----------------------------------------------------------------------===// 208 209 FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) { 210 LinalgTilingOptions tilingOptions; 211 tilingOptions.scalarizeDynamicDims(); 212 // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile 213 // sizes and asserts that it is not already set. 214 SmallVector<int64_t> emptyTileSizes; 215 LinalgTilingPattern pattern(getContext(), tilingOptions); 216 SimpleRewriter rewriter(getContext()); 217 rewriter.setInsertionPoint(target); 218 FailureOr<TiledLinalgOp> result = 219 pattern.returningMatchAndRewrite(target, rewriter); 220 if (failed(result)) 221 return failure(); 222 223 return result->op; 224 } 225 226 //===----------------------------------------------------------------------===// 227 // TileOp 228 //===----------------------------------------------------------------------===// 229 230 /// Apply a tiling transformation to all payload ops and store both the 231 /// tiled operation as well as the created tile loops. 232 static LogicalResult 233 applyTilingToAll(Operation *transformOp, Value target, 234 ArrayRef<int64_t> tileSizes, 235 transform::TransformResults &transformResults, 236 transform::TransformState &state, 237 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) { 238 // Number of loops: Number of tiles sizes that are not zero. 239 size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0); 240 // All payload ops. These should all be LinalgOps for now. 241 ArrayRef<Operation *> payloadOps = state.getPayloadOps(target); 242 243 SmallVector<Operation *> tiledLinalgOps; 244 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 245 for (unsigned int i = 0; i < numLoops; ++i) 246 loopOps[i].reserve(payloadOps.size()); 247 248 for (Operation *target : payloadOps) { 249 auto linalgOp = dyn_cast<linalg::LinalgOp>(target); 250 if (!linalgOp) 251 return transformOp->emitError("only LinalgOps are supported"); 252 253 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp); 254 if (failed(tiled)) 255 return failure(); 256 257 tiledLinalgOps.push_back(tiled->op); 258 if (tiled->loops.size() != numLoops) 259 // Not enough loops were generated. This usually means that the input size 260 // was smaller than the tiling size. 261 // TODO: LinalgTilingPattern should return failure(). 262 return failure(); 263 for (unsigned int i = 0; i < numLoops; ++i) 264 loopOps[i].push_back(tiled->loops[i]); 265 } 266 267 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 268 for (unsigned int i = 0; i < numLoops; ++i) 269 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 270 return success(); 271 } 272 273 LogicalResult transform::TileOp::apply(TransformResults &transformResults, 274 TransformState &state) { 275 LinalgTilingOptions tilingOptions; 276 SmallVector<int64_t> tileSizes = extractI64Array(getSizes()); 277 278 if (!tileSizes.empty()) 279 tilingOptions.setTileSizes(tileSizes); 280 tilingOptions.setInterchange(extractUIntArray(getInterchange())); 281 LinalgTilingPattern pattern(getContext(), tilingOptions); 282 283 return applyTilingToAll(getOperation(), getTarget(), tileSizes, 284 transformResults, state, [&](LinalgOp linalgOp) { 285 SimpleRewriter rewriter(linalgOp.getContext()); 286 return pattern.returningMatchAndRewrite(linalgOp, 287 rewriter); 288 }); 289 } 290 291 ParseResult transform::TileOp::parse(OpAsmParser &parser, 292 OperationState &result) { 293 StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue(); 294 OpAsmParser::UnresolvedOperand targetOperand; 295 SMLoc opLoc = parser.getCurrentLocation(); 296 if (parser.parseOperand(targetOperand) || 297 parser.parseOptionalAttrDict(result.attributes)) 298 return failure(); 299 Attribute sizesAttr = result.attributes.get(sizesAttrName); 300 if (!sizesAttr) 301 return parser.emitError(opLoc) 302 << "expected '" << sizesAttrName << "' attribute"; 303 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); 304 if (!sizesArrayAttr) 305 return parser.emitError(opLoc) 306 << "'" << sizesAttrName << "' attribute must be an array"; 307 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 308 size_t numExpectedLoops = 309 sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0); 310 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); 311 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) 312 return failure(); 313 return success(); 314 } 315 316 void TileOp::print(OpAsmPrinter &p) { 317 p << ' '; 318 p << getTarget(); 319 p.printOptionalAttrDict((*this)->getAttrs()); 320 } 321 322 void TileOp::getEffects( 323 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 324 &effects) { 325 // `target` arg is consumed and can no longer be used. 326 effects.emplace_back(MemoryEffects::Read::get(), getTarget(), 327 TransformMappingResource::get()); 328 effects.emplace_back(MemoryEffects::Free::get(), getTarget(), 329 TransformMappingResource::get()); 330 331 for (Value r : getResults()) { 332 effects.emplace_back(MemoryEffects::Write::get(), r, 333 TransformMappingResource::get()); 334 effects.emplace_back(MemoryEffects::Allocate::get(), r, 335 TransformMappingResource::get()); 336 } 337 338 effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); 339 effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); 340 } 341 342 //===----------------------------------------------------------------------===// 343 // VectorizeOp 344 //===----------------------------------------------------------------------===// 345 346 FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target) { 347 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 348 InFlightDiagnostic diag = emitOpError() 349 << "applies only to isolated-from-above targets"; 350 diag.attachNote(target->getLoc()) << "non-isolated target"; 351 return diag; 352 } 353 354 MLIRContext *ctx = getContext(); 355 RewritePatternSet patterns(ctx); 356 patterns.add<LinalgVectorizationPattern>(ctx); 357 358 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 359 vector::populateVectorReductionToContractPatterns(patterns); 360 patterns.add<linalg::LinalgCopyVTRForwardingPattern, 361 linalg::LinalgCopyVTWForwardingPattern>(ctx, 362 /*benefit=*/2); 363 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); 364 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); 365 if (getVectorizePadding()) 366 linalg::populatePadOpVectorizationPatterns(patterns); 367 368 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) { 369 InFlightDiagnostic diag = emitError() << "failed to apply"; 370 diag.attachNote(target->getLoc()) << "target op"; 371 return diag; 372 } 373 return target; 374 } 375 376 //===----------------------------------------------------------------------===// 377 // Transform op registration 378 //===----------------------------------------------------------------------===// 379 380 namespace { 381 /// Registers new ops and declares PDL as dependent dialect since the additional 382 /// ops are using PDL types for operands and results. 383 class LinalgTransformDialectExtension 384 : public transform::TransformDialectExtension< 385 LinalgTransformDialectExtension> { 386 public: 387 LinalgTransformDialectExtension() { 388 declareDependentDialect<pdl::PDLDialect>(); 389 declareDependentDialect<scf::SCFDialect>(); 390 declareDependentDialect<vector::VectorDialect>(); 391 registerTransformOps< 392 #define GET_OP_LIST 393 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 394 >(); 395 } 396 }; 397 } // namespace 398 399 #define GET_OP_CLASSES 400 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 401 402 void mlir::linalg::registerTransformDialectExtension( 403 DialectRegistry ®istry) { 404 registry.addExtensions<LinalgTransformDialectExtension>(); 405 } 406