1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/IR/AffineOps.h" 10 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/IR/FunctionInterfaces.h" 15 #include "mlir/Transforms/InliningUtils.h" 16 17 using namespace mlir; 18 using namespace mlir::bufferization; 19 20 #include "mlir/Dialect/Bufferization/IR/BufferizationOpsDialect.cpp.inc" 21 22 /// Attribute name used to mark function arguments who's buffers can be written 23 /// to during One-Shot Module Bufferize. 24 constexpr const ::llvm::StringLiteral BufferizationDialect::kWritableAttrName; 25 26 /// Attribute name used to mark the bufferization layout for region arguments 27 /// during One-Shot Module Bufferize. 28 constexpr const ::llvm::StringLiteral 29 BufferizationDialect::kBufferLayoutAttrName; 30 31 /// Attribute name used to mark escaping behavior of buffer allocations. 32 constexpr const ::llvm::StringLiteral BufferizationDialect::kEscapeAttrName; 33 34 //===----------------------------------------------------------------------===// 35 // Bufferization Dialect Interfaces 36 //===----------------------------------------------------------------------===// 37 38 namespace { 39 struct BufferizationInlinerInterface : public DialectInlinerInterface { 40 using DialectInlinerInterface::DialectInlinerInterface; 41 42 /// Operations in Bufferization dialect are always legal to inline. isLegalToInline__anon4220dc3d0111::BufferizationInlinerInterface43 bool isLegalToInline(Operation *, Region *, bool, 44 BlockAndValueMapping &) const final { 45 return true; 46 } 47 }; 48 } // namespace 49 50 //===----------------------------------------------------------------------===// 51 // Bufferization Dialect 52 //===----------------------------------------------------------------------===// 53 initialize()54void mlir::bufferization::BufferizationDialect::initialize() { 55 addOperations< 56 #define GET_OP_LIST 57 #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc" 58 >(); 59 addInterfaces<BufferizationInlinerInterface>(); 60 } 61 62 LogicalResult verifyOperationAttribute(Operation * op,NamedAttribute attr)63BufferizationDialect::verifyOperationAttribute(Operation *op, 64 NamedAttribute attr) { 65 using bufferization::BufferizableOpInterface; 66 67 if (attr.getName() == kWritableAttrName) { 68 if (!attr.getValue().isa<BoolAttr>()) { 69 return op->emitError() << "'" << kWritableAttrName 70 << "' is expected to be a boolean attribute"; 71 } 72 if (!isa<FunctionOpInterface>(op)) 73 return op->emitError() << "expected " << attr.getName() 74 << " to be used on function-like operations"; 75 return success(); 76 } 77 if (attr.getName() == kBufferLayoutAttrName) { 78 if (!attr.getValue().isa<AffineMapAttr>()) { 79 return op->emitError() << "'" << kBufferLayoutAttrName 80 << "' is expected to be a affine map attribute"; 81 } 82 if (!isa<FunctionOpInterface>(op)) 83 return op->emitError() << "expected " << attr.getName() 84 << " to be used on function-like operations"; 85 return success(); 86 } 87 if (attr.getName() == kEscapeAttrName) { 88 auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>(); 89 if (!arrayAttr) 90 return op->emitError() << "'" << kEscapeAttrName 91 << "' is expected to be a bool array attribute"; 92 if (arrayAttr.size() != op->getNumResults()) 93 return op->emitError() 94 << "'" << kEscapeAttrName 95 << "' has wrong number of elements, expected " 96 << op->getNumResults() << ", got " << arrayAttr.size(); 97 auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op); 98 if (!bufferizableOp) 99 return op->emitError() 100 << "'" << kEscapeAttrName << "' only valid on bufferizable ops"; 101 for (const auto &it : llvm::enumerate(arrayAttr)) { 102 auto attr = it.value(); 103 auto boolAttr = attr.dyn_cast<BoolAttr>(); 104 if (!boolAttr) 105 return op->emitError() << "'" << kEscapeAttrName 106 << "' is expected to be a bool array attribute"; 107 if (!boolAttr.getValue()) 108 continue; 109 if (!op->getResult(it.index()).getType().isa<TensorType>()) 110 return op->emitError() 111 << "'" << kEscapeAttrName << "' only valid for tensor results"; 112 if (!bufferizableOp.bufferizesToAllocation(op->getOpResult(it.index()))) 113 return op->emitError() << "'" << kEscapeAttrName 114 << "' only valid for allocation results"; 115 } 116 return success(); 117 } 118 119 return op->emitError() << "attribute '" << attr.getName() 120 << "' not supported by the bufferization dialect"; 121 } 122