1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/Dialect/Linalg/IR/Linalg.h" 12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 13 #include "mlir/Dialect/PDL/IR/PDL.h" 14 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Parser/Parser.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace mlir::linalg; 21 using namespace mlir::transform; 22 23 /// Extracts a vector of int64_t from an array attribute. Asserts if the 24 /// attribute contains values other than integers. 25 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) { 26 SmallVector<int64_t> result; 27 result.reserve(attr.size()); 28 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 29 result.push_back(value.getSExtValue()); 30 return result; 31 } 32 33 /// Extracts a vector of unsigned from an array attribute. Asserts if the 34 /// attribute contains values other than intergers. May truncate. 35 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { 36 SmallVector<unsigned> result; 37 result.reserve(attr.size()); 38 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 39 result.push_back(value.getZExtValue()); 40 return result; 41 } 42 43 namespace { 44 /// A simple pattern rewriter that implements no special logic. 45 class SimpleRewriter : public PatternRewriter { 46 public: 47 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 48 }; 49 } // namespace 50 51 /// Attempts to apply the pattern specified as template argument to the given 52 /// operation. The pattern is expected to have a `returningMatchAndRewrite` 53 /// function that returns the "main" result or failure. Returns failure if the 54 /// pattern failed to apply. Extra arguments are forwarded to the pattern 55 /// constructor. 56 template <typename PatternTy, typename... Args> 57 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { 58 // Check if the given operation has the type expected by the pattern. 59 using OpTy = typename llvm::function_traits< 60 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; 61 auto op = dyn_cast<OpTy>(operation); 62 if (!op) 63 return failure(); 64 65 // Apply the pattern directly to the op. 66 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); 67 SimpleRewriter rewriter(operation->getContext()); 68 rewriter.setInsertionPoint(operation); 69 auto result = pattern.returningMatchAndRewrite(op, rewriter); 70 if (failed(result)) 71 return failure(); 72 return cast<LinalgOp>(result->getOperation()); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // DecomposeOp 77 //===----------------------------------------------------------------------===// 78 79 FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) { 80 FailureOr<LinalgOp> windowed = 81 tryApply<DownscaleSizeOneWindowed2DConvolution>(target); 82 if (succeeded(windowed)) 83 return windowed; 84 85 FailureOr<LinalgOp> depthwise = 86 tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target); 87 if (succeeded(depthwise)) 88 return depthwise; 89 90 return reportUnknownTransformError(target); 91 } 92 93 //===----------------------------------------------------------------------===// 94 // FuseOp 95 //===----------------------------------------------------------------------===// 96 97 /// Apply a tiling transformation to all payload ops and store both the 98 /// tiled operation as well as the created tile loops. 99 static LogicalResult 100 applyTilingToAll(Operation *transformOp, Value target, 101 ArrayRef<int64_t> tileSizes, 102 transform::TransformResults &transformResults, 103 transform::TransformState &state, 104 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) { 105 // Number of loops: Number of tiles sizes that are not zero. 106 size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0); 107 // All payload ops. These should all be LinalgOps for now. 108 ArrayRef<Operation *> payloadOps = state.getPayloadOps(target); 109 110 SmallVector<Operation *> tiledLinalgOps; 111 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 112 for (unsigned int i = 0; i < numLoops; ++i) 113 loopOps[i].reserve(payloadOps.size()); 114 115 for (Operation *target : payloadOps) { 116 auto linalgOp = dyn_cast<linalg::LinalgOp>(target); 117 if (!linalgOp) 118 return transformOp->emitError("only LinalgOps are supported"); 119 120 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp); 121 if (failed(tiled)) 122 return failure(); 123 124 tiledLinalgOps.push_back(tiled->op); 125 if (tiled->loops.size() != numLoops) 126 // Not enough loops were generated. This usually means that the input size 127 // was smaller than the tiling size. 128 // TODO: LinalgTilingPattern should return failure(). 129 return failure(); 130 for (unsigned int i = 0; i < numLoops; ++i) 131 loopOps[i].push_back(tiled->loops[i]); 132 } 133 134 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 135 for (unsigned int i = 0; i < numLoops; ++i) 136 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 137 return success(); 138 } 139 140 /// Parse a tiling-like operation that returns the tiled op as well as the 141 /// created tile loops. The function counts the non-zero tile sizes to compute 142 /// the number of results. 143 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, 144 StringRef sizesAttrName) { 145 OpAsmParser::UnresolvedOperand targetOperand; 146 SMLoc opLoc = parser.getCurrentLocation(); 147 if (parser.parseOperand(targetOperand) || 148 parser.parseOptionalAttrDict(result.attributes)) 149 return failure(); 150 Attribute sizesAttr = result.attributes.get(sizesAttrName); 151 if (!sizesAttr) 152 return parser.emitError(opLoc) 153 << "expected '" << sizesAttrName << "' attribute"; 154 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); 155 if (!sizesArrayAttr) 156 return parser.emitError(opLoc) 157 << "'" << sizesAttrName << "' attribute must be an array"; 158 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 159 size_t numExpectedLoops = 160 sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0); 161 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); 162 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) 163 return failure(); 164 return success(); 165 } 166 167 DiagnosedSilencableFailure 168 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, 169 mlir::transform::TransformState &state) { 170 LinalgTilingAndFusionOptions fusionOptions; 171 fusionOptions.tileSizes = extractI64Array(getTileSizes()); 172 fusionOptions.tileInterchange = extractI64Array(getTileInterchange()); 173 174 LogicalResult result = applyTilingToAll( 175 getOperation(), getTarget(), fusionOptions.tileSizes, transformResults, 176 state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> { 177 LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); 178 SimpleRewriter rewriter(getContext()); 179 rewriter.setInsertionPoint(linalgOp); 180 FailureOr<TileLoopNest> tileLoopNest = 181 pattern.returningMatchAndRewrite(linalgOp, rewriter); 182 if (failed(tileLoopNest)) 183 return failure(); 184 185 TiledLinalgOp tiledLinalgOp; 186 tiledLinalgOp.op = tileLoopNest->getRootOp(); 187 tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(), 188 tileLoopNest->getLoopOps().end()}; 189 return tiledLinalgOp; 190 }); 191 return failed(result) ? DiagnosedSilencableFailure::definiteFailure() 192 : DiagnosedSilencableFailure::success(); 193 } 194 195 ParseResult transform::FuseOp::parse(OpAsmParser &parser, 196 OperationState &result) { 197 return parseTileLikeOp( 198 parser, result, 199 transform::FuseOp::getTileSizesAttrName(result.name).getValue()); 200 } 201 202 void transform::FuseOp::print(OpAsmPrinter &p) { 203 p << ' '; 204 p << getTarget(); 205 p.printOptionalAttrDict((*this)->getAttrs()); 206 } 207 208 LogicalResult transform::FuseOp::verify() { 209 SmallVector<int64_t> permutation = extractI64Array(getTileInterchange()); 210 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); 211 if (!std::is_permutation(sequence.begin(), sequence.end(), 212 permutation.begin(), permutation.end())) { 213 return emitOpError() << "expects interchange to be a permutation, found " 214 << getTileInterchange(); 215 } 216 return success(); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // GeneralizeOp 221 //===----------------------------------------------------------------------===// 222 223 FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) { 224 // Exit early if no transformation is needed. 225 if (isa<GenericOp>(target)) 226 return target; 227 228 FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target); 229 if (succeeded(generic)) 230 return generic; 231 232 return reportUnknownTransformError(target); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // InterchangeOp 237 //===----------------------------------------------------------------------===// 238 239 FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) { 240 SmallVector<unsigned> interchangeVector = 241 extractUIntArray(getIteratorInterchange()); 242 // Exit early if no transformation is needed. 243 if (interchangeVector.empty()) 244 return target; 245 246 auto genericTarget = dyn_cast<GenericOp>(target.getOperation()); 247 if (!genericTarget) { 248 InFlightDiagnostic diag = emitOpError() 249 << "applies to " << GenericOp::getOperationName() 250 << " ops"; 251 diag.attachNote(target.getLoc()) << "attempted to apply to this op"; 252 return diag; 253 } 254 255 return tryApply<GenericOpInterchangePattern>(target, interchangeVector); 256 } 257 258 LogicalResult transform::InterchangeOp::verify() { 259 SmallVector<unsigned> permutation = 260 extractUIntArray(getIteratorInterchange()); 261 auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size())); 262 if (!std::is_permutation(sequence.begin(), sequence.end(), 263 permutation.begin(), permutation.end())) { 264 return emitOpError() 265 << "expects iterator_interchange to be a permutation, found " 266 << getIteratorInterchange(); 267 } 268 return success(); 269 } 270 271 //===---------------------------------------------------------------------===// 272 // PadOp 273 //===---------------------------------------------------------------------===// 274 275 FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) { 276 // Convert the integer packing flags to booleans. 277 SmallVector<bool> packPaddings; 278 for (int64_t packPadding : extractI64Array(getPackPaddings())) 279 packPaddings.push_back(static_cast<bool>(packPadding)); 280 281 // Convert the padding values to attributes. 282 SmallVector<Attribute> paddingValues; 283 for (auto const &it : 284 llvm::zip(getPaddingValues(), target->getOperandTypes())) { 285 Attribute attr = std::get<0>(it); 286 Type elementType = getElementTypeOrSelf(std::get<1>(it)); 287 // Try to parse string attributes to obtain an attribute of element type. 288 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 289 paddingValues.push_back( 290 parseAttribute(attr.cast<StringAttr>(), elementType)); 291 if (!paddingValues.back()) { 292 InFlightDiagnostic diag = emitOpError() 293 << "expects a padding value that parses to " 294 << elementType << ", got " << std::get<0>(it); 295 diag.attachNote(target.getLoc()) << "when applied to this op"; 296 return diag; 297 } 298 continue; 299 } 300 // Otherwise, add the attribute directly. 301 if (attr.getType() != elementType) { 302 InFlightDiagnostic diag = emitOpError() 303 << "expects a padding value of type " 304 << elementType << ", got " << attr; 305 diag.attachNote(target.getLoc()) << "when applied to this op"; 306 return diag; 307 } 308 paddingValues.push_back(attr); 309 } 310 311 // Extract the transpose vectors. 312 SmallVector<SmallVector<int64_t>> transposePaddings; 313 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>()) 314 transposePaddings.push_back( 315 extractI64Array(transposeVector.cast<ArrayAttr>())); 316 317 LinalgPaddingOptions paddingOptions; 318 paddingOptions.setPaddingValues(paddingValues); 319 paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions())); 320 paddingOptions.setPackPaddings(packPaddings); 321 paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings())); 322 paddingOptions.setTransposePaddings(transposePaddings); 323 324 FailureOr<LinalgOp> result = 325 tryApply<LinalgPaddingPattern>(target, paddingOptions); 326 if (succeeded(result)) 327 return result; 328 329 InFlightDiagnostic diag = emitError() 330 << "failed to apply pattern to target op"; 331 diag.attachNote(target.getLoc()) << "target op"; 332 return diag; 333 } 334 335 LogicalResult transform::PadOp::verify() { 336 SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings()); 337 if (any_of(packPaddings, [](int64_t packPadding) { 338 return packPadding != 0 && packPadding != 1; 339 })) { 340 return emitOpError() 341 << "expects pack_paddings to contain booleans (0/1), found " 342 << getPackPaddings(); 343 } 344 345 SmallVector<int64_t> paddingDimensions = 346 extractI64Array(getPaddingDimensions()); 347 if (any_of(paddingDimensions, 348 [](int64_t paddingDimension) { return paddingDimension < 0; })) { 349 return emitOpError() 350 << "expects padding_dimensions to contain positive integers, found " 351 << getPaddingDimensions(); 352 } 353 354 SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings()); 355 if (any_of(hoistPaddings, 356 [](int64_t hoistPadding) { return hoistPadding < 0; })) { 357 return emitOpError() 358 << "expects hoist_paddings to contain positive integers, found " 359 << getHoistPaddings(); 360 } 361 362 ArrayAttr transposes = getTransposePaddings(); 363 for (Attribute attr : transposes) { 364 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr); 365 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 366 if (!std::is_permutation(sequence.begin(), sequence.end(), 367 transpose.begin(), transpose.end())) { 368 return emitOpError() 369 << "expects transpose_paddings to be a permutation, found " 370 << attr; 371 } 372 } 373 return success(); 374 } 375 376 //===----------------------------------------------------------------------===// 377 // ScalarizeOp 378 //===----------------------------------------------------------------------===// 379 380 FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) { 381 LinalgTilingOptions tilingOptions; 382 tilingOptions.scalarizeDynamicDims(); 383 // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile 384 // sizes and asserts that it is not already set. 385 SmallVector<int64_t> emptyTileSizes; 386 LinalgTilingPattern pattern(getContext(), tilingOptions); 387 SimpleRewriter rewriter(getContext()); 388 rewriter.setInsertionPoint(target); 389 FailureOr<TiledLinalgOp> result = 390 pattern.returningMatchAndRewrite(target, rewriter); 391 if (failed(result)) 392 return failure(); 393 394 return result->op; 395 } 396 397 //===----------------------------------------------------------------------===// 398 // TileOp 399 //===----------------------------------------------------------------------===// 400 401 DiagnosedSilencableFailure 402 transform::TileOp::apply(TransformResults &transformResults, 403 TransformState &state) { 404 LinalgTilingOptions tilingOptions; 405 SmallVector<int64_t> tileSizes = extractI64Array(getSizes()); 406 407 if (!tileSizes.empty()) 408 tilingOptions.setTileSizes(tileSizes); 409 tilingOptions.setInterchange(extractUIntArray(getInterchange())); 410 LinalgTilingPattern pattern(getContext(), tilingOptions); 411 412 LogicalResult result = applyTilingToAll( 413 getOperation(), getTarget(), tileSizes, transformResults, state, 414 [&](LinalgOp linalgOp) { 415 SimpleRewriter rewriter(linalgOp.getContext()); 416 return pattern.returningMatchAndRewrite(linalgOp, rewriter); 417 }); 418 return DiagnosedSilencableFailure(result); 419 } 420 421 ParseResult transform::TileOp::parse(OpAsmParser &parser, 422 OperationState &result) { 423 return parseTileLikeOp(parser, result, 424 TileOp::getSizesAttrName(result.name).getValue()); 425 } 426 427 void TileOp::print(OpAsmPrinter &p) { 428 p << ' '; 429 p << getTarget(); 430 p.printOptionalAttrDict((*this)->getAttrs()); 431 } 432 433 //===----------------------------------------------------------------------===// 434 // VectorizeOp 435 //===----------------------------------------------------------------------===// 436 437 FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target) { 438 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 439 InFlightDiagnostic diag = emitOpError() 440 << "applies only to isolated-from-above targets"; 441 diag.attachNote(target->getLoc()) << "non-isolated target"; 442 return diag; 443 } 444 445 MLIRContext *ctx = getContext(); 446 RewritePatternSet patterns(ctx); 447 patterns.add<LinalgVectorizationPattern>(ctx); 448 449 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 450 vector::populateVectorReductionToContractPatterns(patterns); 451 patterns.add<linalg::LinalgCopyVTRForwardingPattern, 452 linalg::LinalgCopyVTWForwardingPattern>(ctx, 453 /*benefit=*/2); 454 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); 455 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); 456 if (getVectorizePadding()) 457 linalg::populatePadOpVectorizationPatterns(patterns); 458 459 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) 460 return reportUnknownTransformError(target); 461 return target; 462 } 463 464 //===----------------------------------------------------------------------===// 465 // Transform op registration 466 //===----------------------------------------------------------------------===// 467 468 namespace { 469 /// Registers new ops and declares PDL as dependent dialect since the additional 470 /// ops are using PDL types for operands and results. 471 class LinalgTransformDialectExtension 472 : public transform::TransformDialectExtension< 473 LinalgTransformDialectExtension> { 474 public: 475 LinalgTransformDialectExtension() { 476 declareDependentDialect<pdl::PDLDialect>(); 477 declareDependentDialect<scf::SCFDialect>(); 478 declareDependentDialect<vector::VectorDialect>(); 479 registerTransformOps< 480 #define GET_OP_LIST 481 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 482 >(); 483 } 484 }; 485 } // namespace 486 487 #define GET_OP_CLASSES 488 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 489 490 void mlir::linalg::registerTransformDialectExtension( 491 DialectRegistry ®istry) { 492 registry.addExtensions<LinalgTransformDialectExtension>(); 493 } 494