1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 10 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/Operation.h" 13 #include "llvm/ADT/ScopeExit.h" 14 #include "llvm/ADT/SmallPtrSet.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // TransformState 20 //===----------------------------------------------------------------------===// 21 22 constexpr const Value transform::TransformState::kTopLevelValue; 23 24 transform::TransformState::TransformState(Region ®ion, Operation *root, 25 const TransformOptions &options) 26 : topLevel(root), options(options) { 27 auto result = mappings.try_emplace(®ion); 28 assert(result.second && "the region scope is already present"); 29 (void)result; 30 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 31 regionStack.push_back(®ion); 32 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 33 } 34 35 Operation *transform::TransformState::getTopLevel() const { return topLevel; } 36 37 ArrayRef<Operation *> 38 transform::TransformState::getPayloadOps(Value value) const { 39 const TransformOpMapping &operationMapping = getMapping(value).direct; 40 auto iter = operationMapping.find(value); 41 assert(iter != operationMapping.end() && "unknown handle"); 42 return iter->getSecond(); 43 } 44 45 Value transform::TransformState::getHandleForPayloadOp(Operation *op) const { 46 for (const Mappings &mapping : llvm::make_second_range(mappings)) { 47 if (Value handle = mapping.reverse.lookup(op)) 48 return handle; 49 } 50 return Value(); 51 } 52 53 LogicalResult transform::TransformState::tryEmplaceReverseMapping( 54 Mappings &map, Operation *operation, Value handle) { 55 auto insertionResult = map.reverse.insert({operation, handle}); 56 if (!insertionResult.second) { 57 InFlightDiagnostic diag = operation->emitError() 58 << "operation tracked by two handles"; 59 diag.attachNote(handle.getLoc()) << "handle"; 60 diag.attachNote(insertionResult.first->second.getLoc()) << "handle"; 61 return diag; 62 } 63 return success(); 64 } 65 66 LogicalResult 67 transform::TransformState::setPayloadOps(Value value, 68 ArrayRef<Operation *> targets) { 69 assert(value != kTopLevelValue && 70 "attempting to reset the transformation root"); 71 72 if (value.use_empty()) 73 return success(); 74 75 // Setting new payload for the value without cleaning it first is a misuse of 76 // the API, assert here. 77 SmallVector<Operation *> storedTargets(targets.begin(), targets.end()); 78 Mappings &mappings = getMapping(value); 79 bool inserted = 80 mappings.direct.insert({value, std::move(storedTargets)}).second; 81 assert(inserted && "value is already associated with another list"); 82 (void)inserted; 83 84 // Having multiple handles to the same operation is an error in the transform 85 // expressed using the dialect and may be constructed by valid API calls from 86 // valid IR. Emit an error here. 87 for (Operation *op : targets) { 88 if (failed(tryEmplaceReverseMapping(mappings, op, value))) 89 return failure(); 90 } 91 92 return success(); 93 } 94 95 void transform::TransformState::removePayloadOps(Value value) { 96 Mappings &mappings = getMapping(value); 97 for (Operation *op : mappings.direct[value]) 98 mappings.reverse.erase(op); 99 mappings.direct.erase(value); 100 } 101 102 LogicalResult transform::TransformState::updatePayloadOps( 103 Value value, function_ref<Operation *(Operation *)> callback) { 104 Mappings &mappings = getMapping(value); 105 auto it = mappings.direct.find(value); 106 assert(it != mappings.direct.end() && "unknown handle"); 107 SmallVector<Operation *> &association = it->getSecond(); 108 SmallVector<Operation *> updated; 109 updated.reserve(association.size()); 110 111 for (Operation *op : association) { 112 mappings.reverse.erase(op); 113 if (Operation *updatedOp = callback(op)) { 114 updated.push_back(updatedOp); 115 if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value))) 116 return failure(); 117 } 118 } 119 120 std::swap(association, updated); 121 return success(); 122 } 123 124 void transform::TransformState::recordHandleInvalidation(OpOperand &handle) { 125 ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get()); 126 for (const Mappings &mapping : llvm::make_second_range(mappings)) { 127 for (const auto &kvp : mapping.reverse) { 128 // If the op is associated with invalidated handle, skip the check as it 129 // may be reading invalid IR. 130 Operation *op = kvp.first; 131 Value otherHandle = kvp.second; 132 if (invalidatedHandles.count(otherHandle)) 133 continue; 134 135 for (Operation *ancestor : potentialAncestors) { 136 if (!ancestor->isProperAncestor(op)) 137 continue; 138 139 // Make sure the error-reporting lambda doesn't capture anything 140 // by-reference because it will go out of scope. Additionally, extract 141 // location from Payload IR ops because the ops themselves may be 142 // deleted before the lambda gets called. 143 Location ancestorLoc = ancestor->getLoc(); 144 Location opLoc = op->getLoc(); 145 Operation *owner = handle.getOwner(); 146 unsigned operandNo = handle.getOperandNumber(); 147 invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, 148 otherHandle]() { 149 InFlightDiagnostic diag = 150 owner->emitOpError() 151 << "invalidated the handle to payload operations nested in the " 152 "payload operation associated with its operand #" 153 << operandNo; 154 diag.attachNote(ancestorLoc) << "ancestor op"; 155 diag.attachNote(opLoc) << "nested op"; 156 diag.attachNote(otherHandle.getLoc()) << "other handle"; 157 }; 158 } 159 } 160 } 161 } 162 163 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( 164 TransformOpInterface transform) { 165 auto memoryEffectsIface = 166 cast<MemoryEffectOpInterface>(transform.getOperation()); 167 SmallVector<MemoryEffects::EffectInstance> effects; 168 memoryEffectsIface.getEffectsOnResource( 169 transform::TransformMappingResource::get(), effects); 170 171 for (OpOperand &target : transform->getOpOperands()) { 172 // If the operand uses an invalidated handle, report it. 173 auto it = invalidatedHandles.find(target.get()); 174 if (it != invalidatedHandles.end()) 175 return it->getSecond()(), failure(); 176 177 // Invalidate handles pointing to the operations nested in the operation 178 // associated with the handle consumed by this operation. 179 auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) { 180 return isa<MemoryEffects::Free>(effect.getEffect()) && 181 effect.getValue() == target.get(); 182 }; 183 if (llvm::find_if(effects, consumesTarget) != effects.end()) 184 recordHandleInvalidation(target); 185 } 186 return success(); 187 } 188 189 LogicalResult 190 transform::TransformState::applyTransform(TransformOpInterface transform) { 191 if (options.getExpensiveChecksEnabled() && 192 failed(checkAndRecordHandleInvalidation(transform))) { 193 return failure(); 194 } 195 196 transform::TransformResults results(transform->getNumResults()); 197 if (failed(transform.apply(results, *this))) 198 return failure(); 199 200 // Remove the mapping for the operand if it is consumed by the operation. This 201 // allows us to catch use-after-free with assertions later on. 202 auto memEffectInterface = 203 cast<MemoryEffectOpInterface>(transform.getOperation()); 204 SmallVector<MemoryEffects::EffectInstance, 2> effects; 205 for (OpOperand &target : transform->getOpOperands()) { 206 effects.clear(); 207 memEffectInterface.getEffectsOnValue(target.get(), effects); 208 if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 209 return isa<transform::TransformMappingResource>( 210 effect.getResource()) && 211 isa<MemoryEffects::Free>(effect.getEffect()); 212 })) { 213 removePayloadOps(target.get()); 214 } 215 } 216 217 for (OpResult result : transform->getResults()) { 218 assert(result.getDefiningOp() == transform.getOperation() && 219 "payload IR association for a value other than the result of the " 220 "current transform op"); 221 if (failed(setPayloadOps(result, results.get(result.getResultNumber())))) 222 return failure(); 223 } 224 225 return success(); 226 } 227 228 //===----------------------------------------------------------------------===// 229 // TransformState::Extension 230 //===----------------------------------------------------------------------===// 231 232 transform::TransformState::Extension::~Extension() = default; 233 234 LogicalResult 235 transform::TransformState::Extension::replacePayloadOp(Operation *op, 236 Operation *replacement) { 237 return state.updatePayloadOps(state.getHandleForPayloadOp(op), 238 [&](Operation *current) { 239 return current == op ? replacement : current; 240 }); 241 } 242 243 //===----------------------------------------------------------------------===// 244 // TransformResults 245 //===----------------------------------------------------------------------===// 246 247 transform::TransformResults::TransformResults(unsigned numSegments) { 248 segments.resize(numSegments, 249 ArrayRef<Operation *>(nullptr, static_cast<size_t>(0))); 250 } 251 252 void transform::TransformResults::set(OpResult value, 253 ArrayRef<Operation *> ops) { 254 unsigned position = value.getResultNumber(); 255 assert(position < segments.size() && 256 "setting results for a non-existent handle"); 257 assert(segments[position].data() == nullptr && "results already set"); 258 unsigned start = operations.size(); 259 llvm::append_range(operations, ops); 260 segments[position] = makeArrayRef(operations).drop_front(start); 261 } 262 263 ArrayRef<Operation *> 264 transform::TransformResults::get(unsigned resultNumber) const { 265 assert(resultNumber < segments.size() && 266 "querying results for a non-existent handle"); 267 assert(segments[resultNumber].data() != nullptr && "querying unset results"); 268 return segments[resultNumber]; 269 } 270 271 //===----------------------------------------------------------------------===// 272 // Utilities for PossibleTopLevelTransformOpTrait. 273 //===----------------------------------------------------------------------===// 274 275 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( 276 TransformState &state, Operation *op) { 277 SmallVector<Operation *> targets; 278 if (op->getNumOperands() != 0) 279 llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); 280 else 281 targets.push_back(state.getTopLevel()); 282 283 return state.mapBlockArguments(op->getRegion(0).front().getArgument(0), 284 targets); 285 } 286 287 LogicalResult 288 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) { 289 // Attaching this trait without the interface is a misuse of the API, but it 290 // cannot be caught via a static_assert because interface registration is 291 // dynamic. 292 assert(isa<TransformOpInterface>(op) && 293 "should implement TransformOpInterface to have " 294 "PossibleTopLevelTransformOpTrait"); 295 296 if (op->getNumRegions() != 1) 297 return op->emitOpError() << "expects one region"; 298 299 Region *bodyRegion = &op->getRegion(0); 300 if (!llvm::hasNItems(*bodyRegion, 1)) 301 return op->emitOpError() << "expects a single-block region"; 302 303 Block *body = &bodyRegion->front(); 304 if (body->getNumArguments() != 1 || 305 !body->getArgumentTypes()[0].isa<pdl::OperationType>()) { 306 return op->emitOpError() 307 << "expects the entry block to have one argument of type " 308 << pdl::OperationType::get(op->getContext()); 309 } 310 311 if (auto *parent = 312 op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) { 313 if (op->getNumOperands() == 0) { 314 InFlightDiagnostic diag = 315 op->emitOpError() 316 << "expects the root operation to be provided for a nested op"; 317 diag.attachNote(parent->getLoc()) 318 << "nested in another possible top-level op"; 319 return diag; 320 } 321 } 322 323 return success(); 324 } 325 326 //===----------------------------------------------------------------------===// 327 // Generated interface implementation. 328 //===----------------------------------------------------------------------===// 329 330 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc" 331