1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 10 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/Operation.h" 13 #include "llvm/Support/Debug.h" 14 15 #define DEBUG_TYPE "transform-dialect" 16 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") 17 18 using namespace mlir; 19 20 //===----------------------------------------------------------------------===// 21 // TransformState 22 //===----------------------------------------------------------------------===// 23 24 constexpr const Value transform::TransformState::kTopLevelValue; 25 26 transform::TransformState::TransformState(Region ®ion, Operation *root, 27 const TransformOptions &options) 28 : topLevel(root), options(options) { 29 auto result = mappings.try_emplace(®ion); 30 assert(result.second && "the region scope is already present"); 31 (void)result; 32 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 33 regionStack.push_back(®ion); 34 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 35 } 36 37 Operation *transform::TransformState::getTopLevel() const { return topLevel; } 38 39 ArrayRef<Operation *> 40 transform::TransformState::getPayloadOps(Value value) const { 41 const TransformOpMapping &operationMapping = getMapping(value).direct; 42 auto iter = operationMapping.find(value); 43 assert(iter != operationMapping.end() && "unknown handle"); 44 return iter->getSecond(); 45 } 46 47 Value transform::TransformState::getHandleForPayloadOp(Operation *op) const { 48 for (const Mappings &mapping : llvm::make_second_range(mappings)) { 49 if (Value handle = mapping.reverse.lookup(op)) 50 return handle; 51 } 52 return Value(); 53 } 54 55 LogicalResult transform::TransformState::tryEmplaceReverseMapping( 56 Mappings &map, Operation *operation, Value handle) { 57 auto insertionResult = map.reverse.insert({operation, handle}); 58 if (!insertionResult.second) { 59 InFlightDiagnostic diag = operation->emitError() 60 << "operation tracked by two handles"; 61 diag.attachNote(handle.getLoc()) << "handle"; 62 diag.attachNote(insertionResult.first->second.getLoc()) << "handle"; 63 return diag; 64 } 65 return success(); 66 } 67 68 LogicalResult 69 transform::TransformState::setPayloadOps(Value value, 70 ArrayRef<Operation *> targets) { 71 assert(value != kTopLevelValue && 72 "attempting to reset the transformation root"); 73 74 if (value.use_empty()) 75 return success(); 76 77 // Setting new payload for the value without cleaning it first is a misuse of 78 // the API, assert here. 79 SmallVector<Operation *> storedTargets(targets.begin(), targets.end()); 80 Mappings &mappings = getMapping(value); 81 bool inserted = 82 mappings.direct.insert({value, std::move(storedTargets)}).second; 83 assert(inserted && "value is already associated with another list"); 84 (void)inserted; 85 86 // Having multiple handles to the same operation is an error in the transform 87 // expressed using the dialect and may be constructed by valid API calls from 88 // valid IR. Emit an error here. 89 for (Operation *op : targets) { 90 if (failed(tryEmplaceReverseMapping(mappings, op, value))) 91 return failure(); 92 } 93 94 return success(); 95 } 96 97 void transform::TransformState::removePayloadOps(Value value) { 98 Mappings &mappings = getMapping(value); 99 for (Operation *op : mappings.direct[value]) 100 mappings.reverse.erase(op); 101 mappings.direct.erase(value); 102 } 103 104 LogicalResult transform::TransformState::updatePayloadOps( 105 Value value, function_ref<Operation *(Operation *)> callback) { 106 Mappings &mappings = getMapping(value); 107 auto it = mappings.direct.find(value); 108 assert(it != mappings.direct.end() && "unknown handle"); 109 SmallVector<Operation *> &association = it->getSecond(); 110 SmallVector<Operation *> updated; 111 updated.reserve(association.size()); 112 113 for (Operation *op : association) { 114 mappings.reverse.erase(op); 115 if (Operation *updatedOp = callback(op)) { 116 updated.push_back(updatedOp); 117 if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value))) 118 return failure(); 119 } 120 } 121 122 std::swap(association, updated); 123 return success(); 124 } 125 126 void transform::TransformState::recordHandleInvalidation(OpOperand &handle) { 127 ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get()); 128 for (const Mappings &mapping : llvm::make_second_range(mappings)) { 129 for (const auto &kvp : mapping.reverse) { 130 // If the op is associated with invalidated handle, skip the check as it 131 // may be reading invalid IR. 132 Operation *op = kvp.first; 133 Value otherHandle = kvp.second; 134 if (invalidatedHandles.count(otherHandle)) 135 continue; 136 137 for (Operation *ancestor : potentialAncestors) { 138 if (!ancestor->isProperAncestor(op)) 139 continue; 140 141 // Make sure the error-reporting lambda doesn't capture anything 142 // by-reference because it will go out of scope. Additionally, extract 143 // location from Payload IR ops because the ops themselves may be 144 // deleted before the lambda gets called. 145 Location ancestorLoc = ancestor->getLoc(); 146 Location opLoc = op->getLoc(); 147 Operation *owner = handle.getOwner(); 148 unsigned operandNo = handle.getOperandNumber(); 149 invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, 150 otherHandle]() { 151 InFlightDiagnostic diag = 152 owner->emitOpError() 153 << "invalidated the handle to payload operations nested in the " 154 "payload operation associated with its operand #" 155 << operandNo; 156 diag.attachNote(ancestorLoc) << "ancestor op"; 157 diag.attachNote(opLoc) << "nested op"; 158 diag.attachNote(otherHandle.getLoc()) << "other handle"; 159 }; 160 } 161 } 162 } 163 } 164 165 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( 166 TransformOpInterface transform) { 167 auto memoryEffectsIface = 168 cast<MemoryEffectOpInterface>(transform.getOperation()); 169 SmallVector<MemoryEffects::EffectInstance> effects; 170 memoryEffectsIface.getEffectsOnResource( 171 transform::TransformMappingResource::get(), effects); 172 173 for (OpOperand &target : transform->getOpOperands()) { 174 // If the operand uses an invalidated handle, report it. 175 auto it = invalidatedHandles.find(target.get()); 176 if (it != invalidatedHandles.end()) 177 return it->getSecond()(), failure(); 178 179 // Invalidate handles pointing to the operations nested in the operation 180 // associated with the handle consumed by this operation. 181 auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) { 182 return isa<MemoryEffects::Free>(effect.getEffect()) && 183 effect.getValue() == target.get(); 184 }; 185 if (llvm::find_if(effects, consumesTarget) != effects.end()) 186 recordHandleInvalidation(target); 187 } 188 return success(); 189 } 190 191 DiagnosedSilenceableFailure 192 transform::TransformState::applyTransform(TransformOpInterface transform) { 193 LLVM_DEBUG(DBGS() << "applying: " << transform << "\n"); 194 if (options.getExpensiveChecksEnabled() && 195 failed(checkAndRecordHandleInvalidation(transform))) { 196 return DiagnosedSilenceableFailure::definiteFailure(); 197 } 198 199 transform::TransformResults results(transform->getNumResults()); 200 DiagnosedSilenceableFailure result(transform.apply(results, *this)); 201 if (!result.succeeded()) 202 return result; 203 204 // Remove the mapping for the operand if it is consumed by the operation. This 205 // allows us to catch use-after-free with assertions later on. 206 auto memEffectInterface = 207 cast<MemoryEffectOpInterface>(transform.getOperation()); 208 SmallVector<MemoryEffects::EffectInstance, 2> effects; 209 for (OpOperand &target : transform->getOpOperands()) { 210 effects.clear(); 211 memEffectInterface.getEffectsOnValue(target.get(), effects); 212 if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 213 return isa<transform::TransformMappingResource>( 214 effect.getResource()) && 215 isa<MemoryEffects::Free>(effect.getEffect()); 216 })) { 217 removePayloadOps(target.get()); 218 } 219 } 220 221 for (OpResult result : transform->getResults()) { 222 assert(result.getDefiningOp() == transform.getOperation() && 223 "payload IR association for a value other than the result of the " 224 "current transform op"); 225 if (failed(setPayloadOps(result, results.get(result.getResultNumber())))) 226 return DiagnosedSilenceableFailure::definiteFailure(); 227 } 228 229 return DiagnosedSilenceableFailure::success(); 230 } 231 232 //===----------------------------------------------------------------------===// 233 // TransformState::Extension 234 //===----------------------------------------------------------------------===// 235 236 transform::TransformState::Extension::~Extension() = default; 237 238 LogicalResult 239 transform::TransformState::Extension::replacePayloadOp(Operation *op, 240 Operation *replacement) { 241 return state.updatePayloadOps(state.getHandleForPayloadOp(op), 242 [&](Operation *current) { 243 return current == op ? replacement : current; 244 }); 245 } 246 247 //===----------------------------------------------------------------------===// 248 // TransformResults 249 //===----------------------------------------------------------------------===// 250 251 transform::TransformResults::TransformResults(unsigned numSegments) { 252 segments.resize(numSegments, 253 ArrayRef<Operation *>(nullptr, static_cast<size_t>(0))); 254 } 255 256 void transform::TransformResults::set(OpResult value, 257 ArrayRef<Operation *> ops) { 258 unsigned position = value.getResultNumber(); 259 assert(position < segments.size() && 260 "setting results for a non-existent handle"); 261 assert(segments[position].data() == nullptr && "results already set"); 262 unsigned start = operations.size(); 263 llvm::append_range(operations, ops); 264 segments[position] = makeArrayRef(operations).drop_front(start); 265 } 266 267 ArrayRef<Operation *> 268 transform::TransformResults::get(unsigned resultNumber) const { 269 assert(resultNumber < segments.size() && 270 "querying results for a non-existent handle"); 271 assert(segments[resultNumber].data() != nullptr && "querying unset results"); 272 return segments[resultNumber]; 273 } 274 275 //===----------------------------------------------------------------------===// 276 // Utilities for PossibleTopLevelTransformOpTrait. 277 //===----------------------------------------------------------------------===// 278 279 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( 280 TransformState &state, Operation *op, Region ®ion) { 281 SmallVector<Operation *> targets; 282 if (op->getNumOperands() != 0) 283 llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); 284 else 285 targets.push_back(state.getTopLevel()); 286 287 return state.mapBlockArguments(region.front().getArgument(0), targets); 288 } 289 290 LogicalResult 291 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) { 292 // Attaching this trait without the interface is a misuse of the API, but it 293 // cannot be caught via a static_assert because interface registration is 294 // dynamic. 295 assert(isa<TransformOpInterface>(op) && 296 "should implement TransformOpInterface to have " 297 "PossibleTopLevelTransformOpTrait"); 298 299 if (op->getNumRegions() < 1) 300 return op->emitOpError() << "expects at least one region"; 301 302 Region *bodyRegion = &op->getRegion(0); 303 if (!llvm::hasNItems(*bodyRegion, 1)) 304 return op->emitOpError() << "expects a single-block region"; 305 306 Block *body = &bodyRegion->front(); 307 if (body->getNumArguments() != 1 || 308 !body->getArgumentTypes()[0].isa<pdl::OperationType>()) { 309 return op->emitOpError() 310 << "expects the entry block to have one argument of type " 311 << pdl::OperationType::get(op->getContext()); 312 } 313 314 if (auto *parent = 315 op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) { 316 if (op->getNumOperands() == 0) { 317 InFlightDiagnostic diag = 318 op->emitOpError() 319 << "expects the root operation to be provided for a nested op"; 320 diag.attachNote(parent->getLoc()) 321 << "nested in another possible top-level op"; 322 return diag; 323 } 324 } 325 326 return success(); 327 } 328 329 //===----------------------------------------------------------------------===// 330 // Generated interface implementation. 331 //===----------------------------------------------------------------------===// 332 333 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc" 334