1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 10 #include "mlir/IR/Diagnostics.h" 11 #include "mlir/IR/Operation.h" 12 #include "llvm/ADT/SmallPtrSet.h" 13 14 using namespace mlir; 15 16 //===----------------------------------------------------------------------===// 17 // TransformState 18 //===----------------------------------------------------------------------===// 19 20 constexpr const Value transform::TransformState::kTopLevelValue; 21 22 transform::TransformState::TransformState(Operation *root) { 23 operationMapping[kTopLevelValue].push_back(root); 24 } 25 26 Operation *transform::TransformState::getTopLevel() const { 27 return operationMapping.lookup(kTopLevelValue).front(); 28 } 29 30 ArrayRef<Operation *> 31 transform::TransformState::getPayloadOps(Value value) const { 32 auto iter = operationMapping.find(value); 33 assert(iter != operationMapping.end() && "unknown handle"); 34 return iter->getSecond(); 35 } 36 37 LogicalResult 38 transform::TransformState::setPayloadOps(Value value, 39 ArrayRef<Operation *> targets) { 40 assert(value != kTopLevelValue && 41 "attempting to reset the transformation root"); 42 43 if (value.use_empty()) 44 return success(); 45 46 // Setting new payload for the value without cleaning it first is a misuse of 47 // the API, assert here. 48 SmallVector<Operation *> storedTargets(targets.begin(), targets.end()); 49 bool inserted = 50 operationMapping.insert({value, std::move(storedTargets)}).second; 51 assert(inserted && "value is already associated with another list"); 52 (void)inserted; 53 54 // Having multiple handles to the same operation is an error in the transform 55 // expressed using the dialect and may be constructed by valid API calls from 56 // valid IR. Emit an error here. 57 for (Operation *op : targets) { 58 auto insertionResult = reverseMapping.insert({op, value}); 59 if (!insertionResult.second) { 60 InFlightDiagnostic diag = op->emitError() 61 << "operation tracked by two handles"; 62 diag.attachNote(value.getLoc()) << "handle"; 63 diag.attachNote(insertionResult.first->second.getLoc()) << "handle"; 64 return diag; 65 } 66 } 67 68 return success(); 69 } 70 71 void transform::TransformState::removePayloadOps(Value value) { 72 for (Operation *op : operationMapping[value]) 73 reverseMapping.erase(op); 74 operationMapping.erase(value); 75 } 76 77 void transform::TransformState::updatePayloadOps( 78 Value value, function_ref<Operation *(Operation *)> callback) { 79 auto it = operationMapping.find(value); 80 assert(it != operationMapping.end() && "unknown handle"); 81 SmallVector<Operation *> &association = it->getSecond(); 82 SmallVector<Operation *> updated; 83 updated.reserve(association.size()); 84 85 for (Operation *op : association) 86 if (Operation *updatedOp = callback(op)) 87 updated.push_back(updatedOp); 88 89 std::swap(association, updated); 90 } 91 92 LogicalResult 93 transform::TransformState::applyTransform(TransformOpInterface transform) { 94 transform::TransformResults results(transform->getNumResults()); 95 if (failed(transform.apply(results, *this))) 96 return failure(); 97 98 for (Value target : transform->getOperands()) 99 removePayloadOps(target); 100 101 for (auto &en : llvm::enumerate(transform->getResults())) 102 if (failed(setPayloadOps(en.value(), results.get(en.index())))) 103 return failure(); 104 105 return success(); 106 } 107 108 //===----------------------------------------------------------------------===// 109 // TransformResults 110 //===----------------------------------------------------------------------===// 111 112 transform::TransformResults::TransformResults(unsigned numSegments) { 113 segments.resize(numSegments, 114 ArrayRef<Operation *>(nullptr, static_cast<size_t>(0))); 115 } 116 117 void transform::TransformResults::set(OpResult value, 118 ArrayRef<Operation *> ops) { 119 unsigned position = value.getResultNumber(); 120 assert(position < segments.size() && 121 "setting results for a non-existent handle"); 122 assert(segments[position].data() == nullptr && "results already set"); 123 unsigned start = operations.size(); 124 llvm::append_range(operations, ops); 125 segments[position] = makeArrayRef(operations).drop_front(start); 126 } 127 128 ArrayRef<Operation *> 129 transform::TransformResults::get(unsigned resultNumber) const { 130 assert(resultNumber < segments.size() && 131 "querying results for a non-existent handle"); 132 assert(segments[resultNumber].data() != nullptr && "querying unset results"); 133 return segments[resultNumber]; 134 } 135 136 //===----------------------------------------------------------------------===// 137 // Generated interface implementation. 138 //===----------------------------------------------------------------------===// 139 140 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc" 141