1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/Dialect/Linalg/IR/Linalg.h" 12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 13 #include "mlir/Dialect/PDL/IR/PDL.h" 14 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Interfaces/SideEffectInterfaces.h" 17 #include "mlir/Parser/Parser.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 #include "llvm/Support/FormatVariadic.h" 20 21 using namespace mlir; 22 using namespace mlir::linalg; 23 using namespace mlir::transform; 24 25 /// Extracts a vector of int64_t from an array attribute. Asserts if the 26 /// attribute contains values other than integers. 27 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) { 28 SmallVector<int64_t> result; 29 result.reserve(attr.size()); 30 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 31 result.push_back(value.getSExtValue()); 32 return result; 33 } 34 35 /// Extracts a vector of unsigned from an array attribute. Asserts if the 36 /// attribute contains values other than intergers. May truncate. 37 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { 38 SmallVector<unsigned> result; 39 result.reserve(attr.size()); 40 for (APInt value : attr.getAsValueRange<IntegerAttr>()) 41 result.push_back(value.getZExtValue()); 42 return result; 43 } 44 45 namespace { 46 /// A simple pattern rewriter that implements no special logic. 47 class SimpleRewriter : public PatternRewriter { 48 public: 49 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 50 }; 51 } // namespace 52 53 /// Attempts to apply the pattern specified as template argument to the given 54 /// operation. The pattern is expected to have a `returningMatchAndRewrite` 55 /// function that returns the "main" result or failure. Returns failure if the 56 /// pattern failed to apply. Extra arguments are forwarded to the pattern 57 /// constructor. 58 template <typename PatternTy, typename... Args> 59 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { 60 // Check if the given operation has the type expected by the pattern. 61 using OpTy = typename llvm::function_traits< 62 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; 63 auto op = dyn_cast<OpTy>(operation); 64 if (!op) 65 return failure(); 66 67 // Apply the pattern directly to the op. 68 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); 69 SimpleRewriter rewriter(operation->getContext()); 70 rewriter.setInsertionPoint(operation); 71 auto result = pattern.returningMatchAndRewrite(op, rewriter); 72 if (failed(result)) 73 return failure(); 74 return cast<LinalgOp>(result->getOperation()); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // DecomposeOp 79 //===----------------------------------------------------------------------===// 80 81 FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) { 82 FailureOr<LinalgOp> windowed = 83 tryApply<DownscaleSizeOneWindowed2DConvolution>(target); 84 if (succeeded(windowed)) 85 return windowed; 86 87 FailureOr<LinalgOp> depthwise = 88 tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target); 89 if (succeeded(depthwise)) 90 return depthwise; 91 92 InFlightDiagnostic diag = emitError() << "failed to apply"; 93 diag.attachNote(target.getLoc()) << "attempted to apply to this op"; 94 return diag; 95 } 96 97 //===----------------------------------------------------------------------===// 98 // GeneralizeOp 99 //===----------------------------------------------------------------------===// 100 101 FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) { 102 // Exit early if no transformation is needed. 103 if (isa<GenericOp>(target)) 104 return target; 105 106 FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target); 107 if (succeeded(generic)) 108 return generic; 109 110 InFlightDiagnostic diag = emitError() << "failed to apply"; 111 diag.attachNote(target.getLoc()) << "attempted to apply to this op"; 112 return diag; 113 } 114 115 //===----------------------------------------------------------------------===// 116 // InterchangeOp 117 //===----------------------------------------------------------------------===// 118 119 FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) { 120 SmallVector<unsigned> interchangeVector = 121 extractUIntArray(getIteratorInterchange()); 122 // Exit early if no transformation is needed. 123 if (interchangeVector.empty()) 124 return target; 125 126 auto genericTarget = dyn_cast<GenericOp>(target.getOperation()); 127 if (!genericTarget) { 128 InFlightDiagnostic diag = emitOpError() 129 << "applies to " << GenericOp::getOperationName() 130 << " ops"; 131 diag.attachNote(target.getLoc()) << "attempted to apply to this op"; 132 return diag; 133 } 134 135 return tryApply<GenericOpInterchangePattern>(target, interchangeVector); 136 } 137 138 LogicalResult transform::InterchangeOp::verify() { 139 SmallVector<unsigned> permutation = 140 extractUIntArray(getIteratorInterchange()); 141 auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size())); 142 if (!std::is_permutation(sequence.begin(), sequence.end(), 143 permutation.begin(), permutation.end())) { 144 return emitOpError() 145 << "expects iterator_interchange to be a permutation, found " 146 << getIteratorInterchange(); 147 } 148 return success(); 149 } 150 151 //===---------------------------------------------------------------------===// 152 // PadOp 153 //===---------------------------------------------------------------------===// 154 155 FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) { 156 // Convert the integer packing flags to booleans. 157 SmallVector<bool> packPaddings; 158 for (int64_t packPadding : extractI64Array(getPackPaddings())) 159 packPaddings.push_back(static_cast<bool>(packPadding)); 160 161 // Convert the padding values to attributes. 162 SmallVector<Attribute> paddingValues; 163 for (auto const &it : 164 llvm::zip(getPaddingValues(), target->getOperandTypes())) { 165 Attribute attr = std::get<0>(it); 166 Type elementType = getElementTypeOrSelf(std::get<1>(it)); 167 // Try to parse string attributes to obtain an attribute of element type. 168 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 169 paddingValues.push_back( 170 parseAttribute(attr.cast<StringAttr>(), elementType)); 171 if (!paddingValues.back()) { 172 InFlightDiagnostic diag = emitOpError() 173 << "expects a padding value that parses to " 174 << elementType << ", got " << std::get<0>(it); 175 diag.attachNote(target.getLoc()) << "when applied to this op"; 176 return diag; 177 } 178 continue; 179 } 180 // Otherwise, add the attribute directly. 181 if (attr.getType() != elementType) { 182 InFlightDiagnostic diag = emitOpError() 183 << "expects a padding value of type " 184 << elementType << ", got " << attr; 185 diag.attachNote(target.getLoc()) << "when applied to this op"; 186 return diag; 187 } 188 paddingValues.push_back(attr); 189 } 190 191 // Extract the transpose vectors. 192 SmallVector<SmallVector<int64_t>> transposePaddings; 193 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>()) 194 transposePaddings.push_back( 195 extractI64Array(transposeVector.cast<ArrayAttr>())); 196 197 LinalgPaddingOptions paddingOptions; 198 paddingOptions.setPaddingValues(paddingValues); 199 paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions())); 200 paddingOptions.setPackPaddings(packPaddings); 201 paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings())); 202 paddingOptions.setTransposePaddings(transposePaddings); 203 204 FailureOr<LinalgOp> result = 205 tryApply<LinalgPaddingPattern>(target, paddingOptions); 206 if (succeeded(result)) 207 return result; 208 209 InFlightDiagnostic diag = emitError() 210 << "failed to apply pattern to target op"; 211 diag.attachNote(target.getLoc()) << "target op"; 212 return diag; 213 } 214 215 LogicalResult transform::PadOp::verify() { 216 SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings()); 217 if (any_of(packPaddings, [](int64_t packPadding) { 218 return packPadding != 0 && packPadding != 1; 219 })) { 220 return emitOpError() 221 << "expects pack_paddings to contain booleans (0/1), found " 222 << getPackPaddings(); 223 } 224 225 SmallVector<int64_t> paddingDimensions = 226 extractI64Array(getPaddingDimensions()); 227 if (any_of(paddingDimensions, 228 [](int64_t paddingDimension) { return paddingDimension < 0; })) { 229 return emitOpError() 230 << "expects padding_dimensions to contain positive integers, found " 231 << getPaddingDimensions(); 232 } 233 234 SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings()); 235 if (any_of(hoistPaddings, 236 [](int64_t hoistPadding) { return hoistPadding < 0; })) { 237 return emitOpError() 238 << "expects hoist_paddings to contain positive integers, found " 239 << getHoistPaddings(); 240 } 241 242 ArrayAttr transposes = getTransposePaddings(); 243 for (Attribute attr : transposes) { 244 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr); 245 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 246 if (!std::is_permutation(sequence.begin(), sequence.end(), 247 transpose.begin(), transpose.end())) { 248 return emitOpError() 249 << "expects transpose_paddings to be a permutation, found " 250 << attr; 251 } 252 } 253 return success(); 254 } 255 256 //===----------------------------------------------------------------------===// 257 // ScalarizeOp 258 //===----------------------------------------------------------------------===// 259 260 FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) { 261 LinalgTilingOptions tilingOptions; 262 tilingOptions.scalarizeDynamicDims(); 263 // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile 264 // sizes and asserts that it is not already set. 265 SmallVector<int64_t> emptyTileSizes; 266 LinalgTilingPattern pattern(getContext(), tilingOptions); 267 SimpleRewriter rewriter(getContext()); 268 rewriter.setInsertionPoint(target); 269 FailureOr<TiledLinalgOp> result = 270 pattern.returningMatchAndRewrite(target, rewriter); 271 if (failed(result)) 272 return failure(); 273 274 return result->op; 275 } 276 277 //===----------------------------------------------------------------------===// 278 // TileOp 279 //===----------------------------------------------------------------------===// 280 281 /// Apply a tiling transformation to all payload ops and store both the 282 /// tiled operation as well as the created tile loops. 283 static LogicalResult 284 applyTilingToAll(Operation *transformOp, Value target, 285 ArrayRef<int64_t> tileSizes, 286 transform::TransformResults &transformResults, 287 transform::TransformState &state, 288 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) { 289 // Number of loops: Number of tiles sizes that are not zero. 290 size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0); 291 // All payload ops. These should all be LinalgOps for now. 292 ArrayRef<Operation *> payloadOps = state.getPayloadOps(target); 293 294 SmallVector<Operation *> tiledLinalgOps; 295 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 296 for (unsigned int i = 0; i < numLoops; ++i) 297 loopOps[i].reserve(payloadOps.size()); 298 299 for (Operation *target : payloadOps) { 300 auto linalgOp = dyn_cast<linalg::LinalgOp>(target); 301 if (!linalgOp) 302 return transformOp->emitError("only LinalgOps are supported"); 303 304 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp); 305 if (failed(tiled)) 306 return failure(); 307 308 tiledLinalgOps.push_back(tiled->op); 309 if (tiled->loops.size() != numLoops) 310 // Not enough loops were generated. This usually means that the input size 311 // was smaller than the tiling size. 312 // TODO: LinalgTilingPattern should return failure(). 313 return failure(); 314 for (unsigned int i = 0; i < numLoops; ++i) 315 loopOps[i].push_back(tiled->loops[i]); 316 } 317 318 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 319 for (unsigned int i = 0; i < numLoops; ++i) 320 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 321 return success(); 322 } 323 324 LogicalResult transform::TileOp::apply(TransformResults &transformResults, 325 TransformState &state) { 326 LinalgTilingOptions tilingOptions; 327 SmallVector<int64_t> tileSizes = extractI64Array(getSizes()); 328 329 if (!tileSizes.empty()) 330 tilingOptions.setTileSizes(tileSizes); 331 tilingOptions.setInterchange(extractUIntArray(getInterchange())); 332 LinalgTilingPattern pattern(getContext(), tilingOptions); 333 334 return applyTilingToAll(getOperation(), getTarget(), tileSizes, 335 transformResults, state, [&](LinalgOp linalgOp) { 336 SimpleRewriter rewriter(linalgOp.getContext()); 337 return pattern.returningMatchAndRewrite(linalgOp, 338 rewriter); 339 }); 340 } 341 342 ParseResult transform::TileOp::parse(OpAsmParser &parser, 343 OperationState &result) { 344 StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue(); 345 OpAsmParser::UnresolvedOperand targetOperand; 346 SMLoc opLoc = parser.getCurrentLocation(); 347 if (parser.parseOperand(targetOperand) || 348 parser.parseOptionalAttrDict(result.attributes)) 349 return failure(); 350 Attribute sizesAttr = result.attributes.get(sizesAttrName); 351 if (!sizesAttr) 352 return parser.emitError(opLoc) 353 << "expected '" << sizesAttrName << "' attribute"; 354 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); 355 if (!sizesArrayAttr) 356 return parser.emitError(opLoc) 357 << "'" << sizesAttrName << "' attribute must be an array"; 358 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); 359 size_t numExpectedLoops = 360 sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0); 361 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); 362 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) 363 return failure(); 364 return success(); 365 } 366 367 void TileOp::print(OpAsmPrinter &p) { 368 p << ' '; 369 p << getTarget(); 370 p.printOptionalAttrDict((*this)->getAttrs()); 371 } 372 373 void TileOp::getEffects( 374 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 375 &effects) { 376 // `target` arg is consumed and can no longer be used. 377 effects.emplace_back(MemoryEffects::Read::get(), getTarget(), 378 TransformMappingResource::get()); 379 effects.emplace_back(MemoryEffects::Free::get(), getTarget(), 380 TransformMappingResource::get()); 381 382 for (Value r : getResults()) { 383 effects.emplace_back(MemoryEffects::Write::get(), r, 384 TransformMappingResource::get()); 385 effects.emplace_back(MemoryEffects::Allocate::get(), r, 386 TransformMappingResource::get()); 387 } 388 389 effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); 390 effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); 391 } 392 393 //===----------------------------------------------------------------------===// 394 // VectorizeOp 395 //===----------------------------------------------------------------------===// 396 397 FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target) { 398 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 399 InFlightDiagnostic diag = emitOpError() 400 << "applies only to isolated-from-above targets"; 401 diag.attachNote(target->getLoc()) << "non-isolated target"; 402 return diag; 403 } 404 405 MLIRContext *ctx = getContext(); 406 RewritePatternSet patterns(ctx); 407 patterns.add<LinalgVectorizationPattern>(ctx); 408 409 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 410 vector::populateVectorReductionToContractPatterns(patterns); 411 patterns.add<linalg::LinalgCopyVTRForwardingPattern, 412 linalg::LinalgCopyVTWForwardingPattern>(ctx, 413 /*benefit=*/2); 414 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); 415 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); 416 if (getVectorizePadding()) 417 linalg::populatePadOpVectorizationPatterns(patterns); 418 419 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) { 420 InFlightDiagnostic diag = emitError() << "failed to apply"; 421 diag.attachNote(target->getLoc()) << "target op"; 422 return diag; 423 } 424 return target; 425 } 426 427 //===----------------------------------------------------------------------===// 428 // Transform op registration 429 //===----------------------------------------------------------------------===// 430 431 namespace { 432 /// Registers new ops and declares PDL as dependent dialect since the additional 433 /// ops are using PDL types for operands and results. 434 class LinalgTransformDialectExtension 435 : public transform::TransformDialectExtension< 436 LinalgTransformDialectExtension> { 437 public: 438 LinalgTransformDialectExtension() { 439 declareDependentDialect<pdl::PDLDialect>(); 440 declareDependentDialect<scf::SCFDialect>(); 441 declareDependentDialect<vector::VectorDialect>(); 442 registerTransformOps< 443 #define GET_OP_LIST 444 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 445 >(); 446 } 447 }; 448 } // namespace 449 450 #define GET_OP_CLASSES 451 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" 452 453 void mlir::linalg::registerTransformDialectExtension( 454 DialectRegistry ®istry) { 455 registry.addExtensions<LinalgTransformDialectExtension>(); 456 } 457