//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `subtensor` ops and are canonicalized
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source type has more static information than the result type.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  RankedTensorType sourceType =
      castOp.source().getType().dyn_cast<RankedTensorType>();
  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !resultType)
    return false;

  // Requires same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // If the cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    if (ShapedType::isDynamic(std::get<0>(t)) &&
        !ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
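/// For example, joining tensor<?x8xf32> with tensor<4x?xf32> yields
/// tensor<4x8xf32>, while joining tensor<4x8xf32> with tensor<5x8xf32> has no
/// join (a null type is returned) because the static sizes conflict.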
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract_element");

  return success();
}

OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};

  // If this is a splat elements attribute, simply return the value. All of
  // the elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : llvm::drop_begin(operands, 1)) {
    if (!indice || !indice.isa<IntegerAttr>())
      return {};
    indices.push_back(indice.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
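  // If the indices are out of range or do not match the attribute's rank,
  // isValidIndex rejects them and no folding is performed.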
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type elementType, ValueRange elements) {
  Type resultTy = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elementType);
  result.addOperands(elements);
  result.addTypes(resultTy);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  build(builder, result, elements.front().getType(), elements);
}

namespace {

// Canonicalizes the pattern of the form
//
// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 1)
      return failure();

    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    APInt index;
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
      return failure();
    // Prevent out-of-bounds accesses. This can happen in invalid code that
    // will never execute.
    if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
        index.getSExtValue() < 0)
      return failure();
    rewriter.replaceOp(extract,
                       tensorFromElements.getOperand(index.getZExtValue()));
    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromTensorFromElements>(context);
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
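  // The region receives one block argument of index type per dimension of the
  // result tensor; `bodyBuilder` is invoked with those arguments to emit the
  // per-element computation.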
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations with a constant
/// operand into the equivalent operation with the operand expressed in the
/// result type, instead. We also insert a type cast to make sure that the
/// resulting IR is still well-typed.
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        tensorFromElements.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = tensorFromElements.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
      return failure();

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
                                                newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):  // no predecessors
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
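/// Side-effect freedom matters because the body is cloned at the extract site
/// while the original tensor.generate is left in place (to be removed later if
/// it becomes dead), so a side-effecting body could otherwise be duplicated.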
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = tensorFromElements.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.source(), extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
              StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"