1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 10 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/Operation.h" 13 #include "llvm/ADT/ScopeExit.h" 14 #include "llvm/ADT/SmallPtrSet.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // TransformState 20 //===----------------------------------------------------------------------===// 21 22 constexpr const Value transform::TransformState::kTopLevelValue; 23 24 transform::TransformState::TransformState(Region ®ion, Operation *root) 25 : topLevel(root) { 26 auto result = mappings.try_emplace(®ion); 27 assert(result.second && "the region scope is already present"); 28 (void)result; 29 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 30 regionStack.push_back(®ion); 31 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 32 } 33 34 Operation *transform::TransformState::getTopLevel() const { return topLevel; } 35 36 ArrayRef<Operation *> 37 transform::TransformState::getPayloadOps(Value value) const { 38 const TransformOpMapping &operationMapping = getMapping(value).direct; 39 auto iter = operationMapping.find(value); 40 assert(iter != operationMapping.end() && "unknown handle"); 41 return iter->getSecond(); 42 } 43 44 LogicalResult 45 transform::TransformState::setPayloadOps(Value value, 46 ArrayRef<Operation *> targets) { 47 assert(value != kTopLevelValue && 48 "attempting to reset the transformation root"); 49 50 if (value.use_empty()) 51 return success(); 52 53 // Setting new payload for the value without cleaning it first is a misuse of 54 // the API, assert here. 55 SmallVector<Operation *> storedTargets(targets.begin(), targets.end()); 56 Mappings &mappings = getMapping(value); 57 bool inserted = 58 mappings.direct.insert({value, std::move(storedTargets)}).second; 59 assert(inserted && "value is already associated with another list"); 60 (void)inserted; 61 62 // Having multiple handles to the same operation is an error in the transform 63 // expressed using the dialect and may be constructed by valid API calls from 64 // valid IR. Emit an error here. 65 for (Operation *op : targets) { 66 auto insertionResult = mappings.reverse.insert({op, value}); 67 if (!insertionResult.second) { 68 InFlightDiagnostic diag = op->emitError() 69 << "operation tracked by two handles"; 70 diag.attachNote(value.getLoc()) << "handle"; 71 diag.attachNote(insertionResult.first->second.getLoc()) << "handle"; 72 return diag; 73 } 74 } 75 76 return success(); 77 } 78 79 void transform::TransformState::removePayloadOps(Value value) { 80 Mappings &mappings = getMapping(value); 81 for (Operation *op : mappings.direct[value]) 82 mappings.reverse.erase(op); 83 mappings.direct.erase(value); 84 } 85 86 void transform::TransformState::updatePayloadOps( 87 Value value, function_ref<Operation *(Operation *)> callback) { 88 auto it = getMapping(value).direct.find(value); 89 assert(it != getMapping(value).direct.end() && "unknown handle"); 90 SmallVector<Operation *> &association = it->getSecond(); 91 SmallVector<Operation *> updated; 92 updated.reserve(association.size()); 93 94 for (Operation *op : association) 95 if (Operation *updatedOp = callback(op)) 96 updated.push_back(updatedOp); 97 98 std::swap(association, updated); 99 } 100 101 LogicalResult 102 transform::TransformState::applyTransform(TransformOpInterface transform) { 103 transform::TransformResults results(transform->getNumResults()); 104 if (failed(transform.apply(results, *this))) 105 return failure(); 106 107 // Remove the mapping for the operand if it is consumed by the operation. This 108 // allows us to catch use-after-free with assertions later on. 109 auto memEffectInterface = 110 cast<MemoryEffectOpInterface>(transform.getOperation()); 111 SmallVector<MemoryEffects::EffectInstance, 2> effects; 112 for (Value target : transform->getOperands()) { 113 effects.clear(); 114 memEffectInterface.getEffectsOnValue(target, effects); 115 if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 116 return isa<transform::TransformMappingResource>( 117 effect.getResource()) && 118 isa<MemoryEffects::Free>(effect.getEffect()); 119 })) { 120 removePayloadOps(target); 121 } 122 } 123 124 for (auto &en : llvm::enumerate(transform->getResults())) { 125 assert(en.value().getDefiningOp() == transform.getOperation() && 126 "payload IR association for a value other than the result of the " 127 "current transform op"); 128 if (failed(setPayloadOps(en.value(), results.get(en.index())))) 129 return failure(); 130 } 131 132 return success(); 133 } 134 135 transform::TransformState::Extension::~Extension() = default; 136 137 //===----------------------------------------------------------------------===// 138 // TransformResults 139 //===----------------------------------------------------------------------===// 140 141 transform::TransformResults::TransformResults(unsigned numSegments) { 142 segments.resize(numSegments, 143 ArrayRef<Operation *>(nullptr, static_cast<size_t>(0))); 144 } 145 146 void transform::TransformResults::set(OpResult value, 147 ArrayRef<Operation *> ops) { 148 unsigned position = value.getResultNumber(); 149 assert(position < segments.size() && 150 "setting results for a non-existent handle"); 151 assert(segments[position].data() == nullptr && "results already set"); 152 unsigned start = operations.size(); 153 llvm::append_range(operations, ops); 154 segments[position] = makeArrayRef(operations).drop_front(start); 155 } 156 157 ArrayRef<Operation *> 158 transform::TransformResults::get(unsigned resultNumber) const { 159 assert(resultNumber < segments.size() && 160 "querying results for a non-existent handle"); 161 assert(segments[resultNumber].data() != nullptr && "querying unset results"); 162 return segments[resultNumber]; 163 } 164 165 //===----------------------------------------------------------------------===// 166 // Utilities for PossibleTopLevelTransformOpTrait. 167 //===----------------------------------------------------------------------===// 168 169 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( 170 TransformState &state, Operation *op) { 171 SmallVector<Operation *> targets; 172 if (op->getNumOperands() != 0) 173 llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); 174 else 175 targets.push_back(state.getTopLevel()); 176 177 return state.mapBlockArguments(op->getRegion(0).front().getArgument(0), 178 targets); 179 } 180 181 LogicalResult 182 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) { 183 // Attaching this trait without the interface is a misuse of the API, but it 184 // cannot be caught via a static_assert because interface registration is 185 // dynamic. 186 assert(isa<TransformOpInterface>(op) && 187 "should implement TransformOpInterface to have " 188 "PossibleTopLevelTransformOpTrait"); 189 190 if (op->getNumRegions() != 1) 191 return op->emitOpError() << "expects one region"; 192 193 Region *bodyRegion = &op->getRegion(0); 194 if (!llvm::hasNItems(*bodyRegion, 1)) 195 return op->emitOpError() << "expects a single-block region"; 196 197 Block *body = &bodyRegion->front(); 198 if (body->getNumArguments() != 1 || 199 !body->getArgumentTypes()[0].isa<pdl::OperationType>()) { 200 return op->emitOpError() 201 << "expects the entry block to have one argument of type " 202 << pdl::OperationType::get(op->getContext()); 203 } 204 205 if (auto *parent = 206 op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) { 207 if (op->getNumOperands() == 0) { 208 InFlightDiagnostic diag = 209 op->emitOpError() 210 << "expects the root operation to be provided for a nested op"; 211 diag.attachNote(parent->getLoc()) 212 << "nested in another possible top-level op"; 213 return diag; 214 } 215 } 216 217 return success(); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // Generated interface implementation. 222 //===----------------------------------------------------------------------===// 223 224 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc" 225