1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 10 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/Operation.h" 13 #include "llvm/Support/Debug.h" 14 15 #define DEBUG_TYPE "transform-dialect" 16 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") 17 18 using namespace mlir; 19 20 //===----------------------------------------------------------------------===// 21 // TransformState 22 //===----------------------------------------------------------------------===// 23 24 constexpr const Value transform::TransformState::kTopLevelValue; 25 26 transform::TransformState::TransformState(Region ®ion, Operation *root, 27 const TransformOptions &options) 28 : topLevel(root), options(options) { 29 auto result = mappings.try_emplace(®ion); 30 assert(result.second && "the region scope is already present"); 31 (void)result; 32 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 33 regionStack.push_back(®ion); 34 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 35 } 36 37 Operation *transform::TransformState::getTopLevel() const { return topLevel; } 38 39 ArrayRef<Operation *> 40 transform::TransformState::getPayloadOps(Value value) const { 41 const TransformOpMapping &operationMapping = getMapping(value).direct; 42 auto iter = operationMapping.find(value); 43 assert(iter != operationMapping.end() && "unknown handle"); 44 return iter->getSecond(); 45 } 46 47 Value transform::TransformState::getHandleForPayloadOp(Operation *op) const { 48 for (const Mappings &mapping : llvm::make_second_range(mappings)) { 49 if (Value handle = mapping.reverse.lookup(op)) 50 return handle; 51 } 52 return Value(); 53 } 54 55 LogicalResult transform::TransformState::tryEmplaceReverseMapping( 56 Mappings &map, Operation *operation, Value handle) { 57 auto insertionResult = map.reverse.insert({operation, handle}); 58 if (!insertionResult.second && insertionResult.first->second != handle) { 59 InFlightDiagnostic diag = operation->emitError() 60 << "operation tracked by two handles"; 61 diag.attachNote(handle.getLoc()) << "handle"; 62 diag.attachNote(insertionResult.first->second.getLoc()) << "handle"; 63 return diag; 64 } 65 return success(); 66 } 67 68 LogicalResult 69 transform::TransformState::setPayloadOps(Value value, 70 ArrayRef<Operation *> targets) { 71 assert(value != kTopLevelValue && 72 "attempting to reset the transformation root"); 73 74 if (value.use_empty()) 75 return success(); 76 77 // Setting new payload for the value without cleaning it first is a misuse of 78 // the API, assert here. 79 SmallVector<Operation *> storedTargets(targets.begin(), targets.end()); 80 Mappings &mappings = getMapping(value); 81 bool inserted = 82 mappings.direct.insert({value, std::move(storedTargets)}).second; 83 assert(inserted && "value is already associated with another list"); 84 (void)inserted; 85 86 // Having multiple handles to the same operation is an error in the transform 87 // expressed using the dialect and may be constructed by valid API calls from 88 // valid IR. Emit an error here. 89 for (Operation *op : targets) { 90 if (failed(tryEmplaceReverseMapping(mappings, op, value))) 91 return failure(); 92 } 93 94 return success(); 95 } 96 97 void transform::TransformState::removePayloadOps(Value value) { 98 Mappings &mappings = getMapping(value); 99 for (Operation *op : mappings.direct[value]) 100 mappings.reverse.erase(op); 101 mappings.direct.erase(value); 102 } 103 104 LogicalResult transform::TransformState::updatePayloadOps( 105 Value value, function_ref<Operation *(Operation *)> callback) { 106 Mappings &mappings = getMapping(value); 107 auto it = mappings.direct.find(value); 108 assert(it != mappings.direct.end() && "unknown handle"); 109 SmallVector<Operation *> &association = it->getSecond(); 110 SmallVector<Operation *> updated; 111 updated.reserve(association.size()); 112 113 for (Operation *op : association) { 114 mappings.reverse.erase(op); 115 if (Operation *updatedOp = callback(op)) { 116 updated.push_back(updatedOp); 117 if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value))) 118 return failure(); 119 } 120 } 121 122 std::swap(association, updated); 123 return success(); 124 } 125 126 void transform::TransformState::recordHandleInvalidation(OpOperand &handle) { 127 ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get()); 128 for (const Mappings &mapping : llvm::make_second_range(mappings)) { 129 for (const auto &kvp : mapping.reverse) { 130 // If the op is associated with invalidated handle, skip the check as it 131 // may be reading invalid IR. 132 Operation *op = kvp.first; 133 Value otherHandle = kvp.second; 134 if (invalidatedHandles.count(otherHandle)) 135 continue; 136 137 for (Operation *ancestor : potentialAncestors) { 138 if (!ancestor->isProperAncestor(op)) 139 continue; 140 141 // Make sure the error-reporting lambda doesn't capture anything 142 // by-reference because it will go out of scope. Additionally, extract 143 // location from Payload IR ops because the ops themselves may be 144 // deleted before the lambda gets called. 145 Location ancestorLoc = ancestor->getLoc(); 146 Location opLoc = op->getLoc(); 147 Operation *owner = handle.getOwner(); 148 unsigned operandNo = handle.getOperandNumber(); 149 invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, 150 otherHandle]() { 151 InFlightDiagnostic diag = 152 owner->emitOpError() 153 << "invalidated the handle to payload operations nested in the " 154 "payload operation associated with its operand #" 155 << operandNo; 156 diag.attachNote(ancestorLoc) << "ancestor op"; 157 diag.attachNote(opLoc) << "nested op"; 158 diag.attachNote(otherHandle.getLoc()) << "other handle"; 159 }; 160 } 161 } 162 } 163 } 164 165 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( 166 TransformOpInterface transform) { 167 auto memoryEffectsIface = 168 cast<MemoryEffectOpInterface>(transform.getOperation()); 169 SmallVector<MemoryEffects::EffectInstance> effects; 170 memoryEffectsIface.getEffectsOnResource( 171 transform::TransformMappingResource::get(), effects); 172 173 for (OpOperand &target : transform->getOpOperands()) { 174 // If the operand uses an invalidated handle, report it. 175 auto it = invalidatedHandles.find(target.get()); 176 if (it != invalidatedHandles.end()) 177 return it->getSecond()(), failure(); 178 179 // Invalidate handles pointing to the operations nested in the operation 180 // associated with the handle consumed by this operation. 181 auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) { 182 return isa<MemoryEffects::Free>(effect.getEffect()) && 183 effect.getValue() == target.get(); 184 }; 185 if (llvm::any_of(effects, consumesTarget)) 186 recordHandleInvalidation(target); 187 } 188 return success(); 189 } 190 191 DiagnosedSilenceableFailure 192 transform::TransformState::applyTransform(TransformOpInterface transform) { 193 LLVM_DEBUG(DBGS() << "applying: " << transform << "\n"); 194 if (options.getExpensiveChecksEnabled()) { 195 if (failed(checkAndRecordHandleInvalidation(transform))) 196 return DiagnosedSilenceableFailure::definiteFailure(); 197 198 for (OpOperand &operand : transform->getOpOperands()) { 199 if (!isHandleConsumed(operand.get(), transform)) 200 continue; 201 202 DenseSet<Operation *> seen; 203 for (Operation *op : getPayloadOps(operand.get())) { 204 if (!seen.insert(op).second) { 205 DiagnosedSilenceableFailure diag = 206 transform.emitSilenceableError() 207 << "a handle passed as operand #" << operand.getOperandNumber() 208 << " and consumed by this operation points to a payload " 209 "operation more than once"; 210 diag.attachNote(op->getLoc()) << "repeated target op"; 211 return diag; 212 } 213 } 214 } 215 } 216 217 transform::TransformResults results(transform->getNumResults()); 218 DiagnosedSilenceableFailure result(transform.apply(results, *this)); 219 if (!result.succeeded()) 220 return result; 221 222 // Remove the mapping for the operand if it is consumed by the operation. This 223 // allows us to catch use-after-free with assertions later on. 224 auto memEffectInterface = 225 cast<MemoryEffectOpInterface>(transform.getOperation()); 226 SmallVector<MemoryEffects::EffectInstance, 2> effects; 227 for (OpOperand &target : transform->getOpOperands()) { 228 effects.clear(); 229 memEffectInterface.getEffectsOnValue(target.get(), effects); 230 if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 231 return isa<transform::TransformMappingResource>( 232 effect.getResource()) && 233 isa<MemoryEffects::Free>(effect.getEffect()); 234 })) { 235 removePayloadOps(target.get()); 236 } 237 } 238 239 for (OpResult result : transform->getResults()) { 240 assert(result.getDefiningOp() == transform.getOperation() && 241 "payload IR association for a value other than the result of the " 242 "current transform op"); 243 if (failed(setPayloadOps(result, results.get(result.getResultNumber())))) 244 return DiagnosedSilenceableFailure::definiteFailure(); 245 } 246 247 return DiagnosedSilenceableFailure::success(); 248 } 249 250 //===----------------------------------------------------------------------===// 251 // TransformState::Extension 252 //===----------------------------------------------------------------------===// 253 254 transform::TransformState::Extension::~Extension() = default; 255 256 LogicalResult 257 transform::TransformState::Extension::replacePayloadOp(Operation *op, 258 Operation *replacement) { 259 return state.updatePayloadOps(state.getHandleForPayloadOp(op), 260 [&](Operation *current) { 261 return current == op ? replacement : current; 262 }); 263 } 264 265 //===----------------------------------------------------------------------===// 266 // TransformResults 267 //===----------------------------------------------------------------------===// 268 269 transform::TransformResults::TransformResults(unsigned numSegments) { 270 segments.resize(numSegments, 271 ArrayRef<Operation *>(nullptr, static_cast<size_t>(0))); 272 } 273 274 void transform::TransformResults::set(OpResult value, 275 ArrayRef<Operation *> ops) { 276 unsigned position = value.getResultNumber(); 277 assert(position < segments.size() && 278 "setting results for a non-existent handle"); 279 assert(segments[position].data() == nullptr && "results already set"); 280 unsigned start = operations.size(); 281 llvm::append_range(operations, ops); 282 segments[position] = makeArrayRef(operations).drop_front(start); 283 } 284 285 ArrayRef<Operation *> 286 transform::TransformResults::get(unsigned resultNumber) const { 287 assert(resultNumber < segments.size() && 288 "querying results for a non-existent handle"); 289 assert(segments[resultNumber].data() != nullptr && "querying unset results"); 290 return segments[resultNumber]; 291 } 292 293 //===----------------------------------------------------------------------===// 294 // Utilities for PossibleTopLevelTransformOpTrait. 295 //===----------------------------------------------------------------------===// 296 297 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( 298 TransformState &state, Operation *op, Region ®ion) { 299 SmallVector<Operation *> targets; 300 if (op->getNumOperands() != 0) 301 llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); 302 else 303 targets.push_back(state.getTopLevel()); 304 305 return state.mapBlockArguments(region.front().getArgument(0), targets); 306 } 307 308 LogicalResult 309 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) { 310 // Attaching this trait without the interface is a misuse of the API, but it 311 // cannot be caught via a static_assert because interface registration is 312 // dynamic. 313 assert(isa<TransformOpInterface>(op) && 314 "should implement TransformOpInterface to have " 315 "PossibleTopLevelTransformOpTrait"); 316 317 if (op->getNumRegions() < 1) 318 return op->emitOpError() << "expects at least one region"; 319 320 Region *bodyRegion = &op->getRegion(0); 321 if (!llvm::hasNItems(*bodyRegion, 1)) 322 return op->emitOpError() << "expects a single-block region"; 323 324 Block *body = &bodyRegion->front(); 325 if (body->getNumArguments() != 1 || 326 !body->getArgumentTypes()[0].isa<pdl::OperationType>()) { 327 return op->emitOpError() 328 << "expects the entry block to have one argument of type " 329 << pdl::OperationType::get(op->getContext()); 330 } 331 332 if (auto *parent = 333 op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) { 334 if (op->getNumOperands() == 0) { 335 InFlightDiagnostic diag = 336 op->emitOpError() 337 << "expects the root operation to be provided for a nested op"; 338 diag.attachNote(parent->getLoc()) 339 << "nested in another possible top-level op"; 340 return diag; 341 } 342 } 343 344 return success(); 345 } 346 347 //===----------------------------------------------------------------------===// 348 // Memory effects. 349 //===----------------------------------------------------------------------===// 350 351 void transform::consumesHandle( 352 ValueRange handles, 353 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 354 for (Value handle : handles) { 355 effects.emplace_back(MemoryEffects::Read::get(), handle, 356 TransformMappingResource::get()); 357 effects.emplace_back(MemoryEffects::Free::get(), handle, 358 TransformMappingResource::get()); 359 } 360 } 361 362 /// Returns `true` if the given list of effects instances contains an instance 363 /// with the effect type specified as template parameter. 364 template <typename EffectTy, typename ResourceTy = SideEffects::DefaultResource> 365 static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) { 366 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 367 return isa<EffectTy>(effect.getEffect()) && 368 isa<ResourceTy>(effect.getResource()); 369 }); 370 } 371 372 bool transform::isHandleConsumed(Value handle, 373 transform::TransformOpInterface transform) { 374 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation()); 375 SmallVector<MemoryEffects::EffectInstance> effects; 376 iface.getEffectsOnValue(handle, effects); 377 return hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) && 378 hasEffect<MemoryEffects::Free, TransformMappingResource>(effects); 379 } 380 381 void transform::producesHandle( 382 ValueRange handles, 383 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 384 for (Value handle : handles) { 385 effects.emplace_back(MemoryEffects::Allocate::get(), handle, 386 TransformMappingResource::get()); 387 effects.emplace_back(MemoryEffects::Write::get(), handle, 388 TransformMappingResource::get()); 389 } 390 } 391 392 void transform::onlyReadsHandle( 393 ValueRange handles, 394 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 395 for (Value handle : handles) { 396 effects.emplace_back(MemoryEffects::Read::get(), handle, 397 TransformMappingResource::get()); 398 } 399 } 400 401 void transform::modifiesPayload( 402 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 403 effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); 404 effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); 405 } 406 407 void transform::onlyReadsPayload( 408 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 409 effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); 410 } 411 412 //===----------------------------------------------------------------------===// 413 // Generated interface implementation. 414 //===----------------------------------------------------------------------===// 415 416 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc" 417