1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 10 #include "mlir/IR/Diagnostics.h" 11 #include "mlir/IR/Operation.h" 12 #include "llvm/ADT/ScopeExit.h" 13 #include "llvm/ADT/SmallPtrSet.h" 14 15 using namespace mlir; 16 17 //===----------------------------------------------------------------------===// 18 // TransformState 19 //===----------------------------------------------------------------------===// 20 21 constexpr const Value transform::TransformState::kTopLevelValue; 22 23 transform::TransformState::TransformState(Region ®ion, Operation *root) 24 : topLevel(root) { 25 auto result = mappings.try_emplace(®ion); 26 assert(result.second && "the region scope is already present"); 27 (void)result; 28 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 29 regionStack.push_back(®ion); 30 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 31 } 32 33 Operation *transform::TransformState::getTopLevel() const { return topLevel; } 34 35 ArrayRef<Operation *> 36 transform::TransformState::getPayloadOps(Value value) const { 37 const TransformOpMapping &operationMapping = getMapping(value).direct; 38 auto iter = operationMapping.find(value); 39 assert(iter != operationMapping.end() && "unknown handle"); 40 return iter->getSecond(); 41 } 42 43 LogicalResult 44 transform::TransformState::setPayloadOps(Value value, 45 ArrayRef<Operation *> targets) { 46 assert(value != kTopLevelValue && 47 "attempting to reset the transformation root"); 48 49 if (value.use_empty()) 50 return success(); 51 52 // Setting new payload for the value without cleaning it first is a misuse of 53 // the API, assert here. 54 SmallVector<Operation *> storedTargets(targets.begin(), targets.end()); 55 Mappings &mappings = getMapping(value); 56 bool inserted = 57 mappings.direct.insert({value, std::move(storedTargets)}).second; 58 assert(inserted && "value is already associated with another list"); 59 (void)inserted; 60 61 // Having multiple handles to the same operation is an error in the transform 62 // expressed using the dialect and may be constructed by valid API calls from 63 // valid IR. Emit an error here. 64 for (Operation *op : targets) { 65 auto insertionResult = mappings.reverse.insert({op, value}); 66 if (!insertionResult.second) { 67 InFlightDiagnostic diag = op->emitError() 68 << "operation tracked by two handles"; 69 diag.attachNote(value.getLoc()) << "handle"; 70 diag.attachNote(insertionResult.first->second.getLoc()) << "handle"; 71 return diag; 72 } 73 } 74 75 return success(); 76 } 77 78 void transform::TransformState::removePayloadOps(Value value) { 79 Mappings &mappings = getMapping(value); 80 for (Operation *op : mappings.direct[value]) 81 mappings.reverse.erase(op); 82 mappings.direct.erase(value); 83 } 84 85 void transform::TransformState::updatePayloadOps( 86 Value value, function_ref<Operation *(Operation *)> callback) { 87 auto it = getMapping(value).direct.find(value); 88 assert(it != getMapping(value).direct.end() && "unknown handle"); 89 SmallVector<Operation *> &association = it->getSecond(); 90 SmallVector<Operation *> updated; 91 updated.reserve(association.size()); 92 93 for (Operation *op : association) 94 if (Operation *updatedOp = callback(op)) 95 updated.push_back(updatedOp); 96 97 std::swap(association, updated); 98 } 99 100 LogicalResult 101 transform::TransformState::applyTransform(TransformOpInterface transform) { 102 transform::TransformResults results(transform->getNumResults()); 103 if (failed(transform.apply(results, *this))) 104 return failure(); 105 106 for (Value target : transform->getOperands()) 107 removePayloadOps(target); 108 109 for (auto &en : llvm::enumerate(transform->getResults())) { 110 assert(en.value().getDefiningOp() == transform.getOperation() && 111 "payload IR association for a value other than the result of the " 112 "current transform op"); 113 if (failed(setPayloadOps(en.value(), results.get(en.index())))) 114 return failure(); 115 } 116 117 return success(); 118 } 119 120 //===----------------------------------------------------------------------===// 121 // TransformResults 122 //===----------------------------------------------------------------------===// 123 124 transform::TransformResults::TransformResults(unsigned numSegments) { 125 segments.resize(numSegments, 126 ArrayRef<Operation *>(nullptr, static_cast<size_t>(0))); 127 } 128 129 void transform::TransformResults::set(OpResult value, 130 ArrayRef<Operation *> ops) { 131 unsigned position = value.getResultNumber(); 132 assert(position < segments.size() && 133 "setting results for a non-existent handle"); 134 assert(segments[position].data() == nullptr && "results already set"); 135 unsigned start = operations.size(); 136 llvm::append_range(operations, ops); 137 segments[position] = makeArrayRef(operations).drop_front(start); 138 } 139 140 ArrayRef<Operation *> 141 transform::TransformResults::get(unsigned resultNumber) const { 142 assert(resultNumber < segments.size() && 143 "querying results for a non-existent handle"); 144 assert(segments[resultNumber].data() != nullptr && "querying unset results"); 145 return segments[resultNumber]; 146 } 147 148 //===----------------------------------------------------------------------===// 149 // Generated interface implementation. 150 //===----------------------------------------------------------------------===// 151 152 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc" 153