1 //===- NormalizeMemRefs.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an interprocedural pass to normalize memrefs to have
10 // identity layout maps.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Affine/Utils.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
20 #include "llvm/ADT/SmallSet.h"
21 #include "llvm/Support/Debug.h"
22
23 #define DEBUG_TYPE "normalize-memrefs"
24
25 using namespace mlir;
26
27 namespace {
28
29 /// All memrefs passed across functions with non-trivial layout maps are
30 /// converted to ones with trivial identity layout ones.
31 /// If all the memref types/uses in a function are normalizable, we treat
32 /// such functions as normalizable. Also, if a normalizable function is known
33 /// to call a non-normalizable function, we treat that function as
34 /// non-normalizable as well. We assume external functions to be normalizable.
35 struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
36 void runOnOperation() override;
37 void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp);
38 bool areMemRefsNormalizable(func::FuncOp funcOp);
39 void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp);
40 void setCalleesAndCallersNonNormalizable(
41 func::FuncOp funcOp, ModuleOp moduleOp,
42 DenseSet<func::FuncOp> &normalizableFuncs);
43 Operation *createOpResultsNormalized(func::FuncOp funcOp, Operation *oldOp);
44 };
45
46 } // namespace
47
48 std::unique_ptr<OperationPass<ModuleOp>>
createNormalizeMemRefsPass()49 mlir::memref::createNormalizeMemRefsPass() {
50 return std::make_unique<NormalizeMemRefs>();
51 }
52
runOnOperation()53 void NormalizeMemRefs::runOnOperation() {
54 LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
55 ModuleOp moduleOp = getOperation();
56 // We maintain all normalizable FuncOps in a DenseSet. It is initialized
57 // with all the functions within a module and then functions which are not
58 // normalizable are removed from this set.
59 // TODO: Change this to work on FuncLikeOp once there is an operation
60 // interface for it.
61 DenseSet<func::FuncOp> normalizableFuncs;
62 // Initialize `normalizableFuncs` with all the functions within a module.
63 moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
64
65 // Traverse through all the functions applying a filter which determines
66 // whether that function is normalizable or not. All callers/callees of
67 // a non-normalizable function will also become non-normalizable even if
68 // they aren't passing any or specific non-normalizable memrefs. So,
69 // functions which calls or get called by a non-normalizable becomes non-
70 // normalizable functions themselves.
71 moduleOp.walk([&](func::FuncOp funcOp) {
72 if (normalizableFuncs.contains(funcOp)) {
73 if (!areMemRefsNormalizable(funcOp)) {
74 LLVM_DEBUG(llvm::dbgs()
75 << "@" << funcOp.getName()
76 << " contains ops that cannot normalize MemRefs\n");
77 // Since this function is not normalizable, we set all the caller
78 // functions and the callees of this function as not normalizable.
79 // TODO: Drop this conservative assumption in the future.
80 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
81 normalizableFuncs);
82 }
83 }
84 });
85
86 LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
87 << " functions\n");
88 // Those functions which can be normalized are subjected to normalization.
89 for (func::FuncOp &funcOp : normalizableFuncs)
90 normalizeFuncOpMemRefs(funcOp, moduleOp);
91 }
92
93 /// Check whether all the uses of oldMemRef are either dereferencing uses or the
94 /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
95 /// are satisfied will the value become a candidate for replacement.
96 /// TODO: Extend this for DimOps.
isMemRefNormalizable(Value::user_range opUsers)97 static bool isMemRefNormalizable(Value::user_range opUsers) {
98 return llvm::all_of(opUsers, [](Operation *op) {
99 return op->hasTrait<OpTrait::MemRefsNormalizable>();
100 });
101 }
102
103 /// Set all the calling functions and the callees of the function as not
104 /// normalizable.
setCalleesAndCallersNonNormalizable(func::FuncOp funcOp,ModuleOp moduleOp,DenseSet<func::FuncOp> & normalizableFuncs)105 void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
106 func::FuncOp funcOp, ModuleOp moduleOp,
107 DenseSet<func::FuncOp> &normalizableFuncs) {
108 if (!normalizableFuncs.contains(funcOp))
109 return;
110
111 LLVM_DEBUG(
112 llvm::dbgs() << "@" << funcOp.getName()
113 << " calls or is called by non-normalizable function\n");
114 normalizableFuncs.erase(funcOp);
115 // Caller of the function.
116 Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
117 for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
118 // TODO: Extend this for ops that are FunctionOpInterface. This would
119 // require creating an OpInterface for FunctionOpInterface ops.
120 func::FuncOp parentFuncOp =
121 symbolUse.getUser()->getParentOfType<func::FuncOp>();
122 for (func::FuncOp &funcOp : normalizableFuncs) {
123 if (parentFuncOp == funcOp) {
124 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
125 normalizableFuncs);
126 break;
127 }
128 }
129 }
130
131 // Functions called by this function.
132 funcOp.walk([&](func::CallOp callOp) {
133 StringAttr callee = callOp.getCalleeAttr().getAttr();
134 for (func::FuncOp &funcOp : normalizableFuncs) {
135 // We compare func::FuncOp and callee's name.
136 if (callee == funcOp.getNameAttr()) {
137 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
138 normalizableFuncs);
139 break;
140 }
141 }
142 });
143 }
144
145 /// Check whether all the uses of AllocOps, CallOps and function arguments of a
146 /// function are either of dereferencing type or are uses in: DeallocOp, CallOp
147 /// or ReturnOp. Only if these constraints are satisfied will the function
148 /// become a candidate for normalization. We follow a conservative approach here
149 /// wherein even if the non-normalizable memref is not a part of the function's
150 /// argument or return type, we still label the entire function as
151 /// non-normalizable. We assume external functions to be normalizable.
areMemRefsNormalizable(func::FuncOp funcOp)152 bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
153 // We assume external functions to be normalizable.
154 if (funcOp.isExternal())
155 return true;
156
157 if (funcOp
158 .walk([&](memref::AllocOp allocOp) -> WalkResult {
159 Value oldMemRef = allocOp.getResult();
160 if (!isMemRefNormalizable(oldMemRef.getUsers()))
161 return WalkResult::interrupt();
162 return WalkResult::advance();
163 })
164 .wasInterrupted())
165 return false;
166
167 if (funcOp
168 .walk([&](func::CallOp callOp) -> WalkResult {
169 for (unsigned resIndex :
170 llvm::seq<unsigned>(0, callOp.getNumResults())) {
171 Value oldMemRef = callOp.getResult(resIndex);
172 if (oldMemRef.getType().isa<MemRefType>())
173 if (!isMemRefNormalizable(oldMemRef.getUsers()))
174 return WalkResult::interrupt();
175 }
176 return WalkResult::advance();
177 })
178 .wasInterrupted())
179 return false;
180
181 for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
182 BlockArgument oldMemRef = funcOp.getArgument(argIndex);
183 if (oldMemRef.getType().isa<MemRefType>())
184 if (!isMemRefNormalizable(oldMemRef.getUsers()))
185 return false;
186 }
187
188 return true;
189 }
190
191 /// Fetch the updated argument list and result of the function and update the
192 /// function signature. This updates the function's return type at the caller
193 /// site and in case the return type is a normalized memref then it updates
194 /// the calling function's signature.
195 /// TODO: An update to the calling function signature is required only if the
196 /// returned value is in turn used in ReturnOp of the calling function.
updateFunctionSignature(func::FuncOp funcOp,ModuleOp moduleOp)197 void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
198 ModuleOp moduleOp) {
199 FunctionType functionType = funcOp.getFunctionType();
200 SmallVector<Type, 4> resultTypes;
201 FunctionType newFuncType;
202 resultTypes = llvm::to_vector<4>(functionType.getResults());
203
204 // External function's signature was already updated in
205 // 'normalizeFuncOpMemRefs()'.
206 if (!funcOp.isExternal()) {
207 SmallVector<Type, 8> argTypes;
208 for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
209 argTypes.push_back(argEn.value().getType());
210
211 // Traverse ReturnOps to check if an update to the return type in the
212 // function signature is required.
213 funcOp.walk([&](func::ReturnOp returnOp) {
214 for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
215 Type opType = operandEn.value().getType();
216 MemRefType memrefType = opType.dyn_cast<MemRefType>();
217 // If type is not memref or if the memref type is same as that in
218 // function's return signature then no update is required.
219 if (!memrefType || memrefType == resultTypes[operandEn.index()])
220 continue;
221 // Update function's return type signature.
222 // Return type gets normalized either as a result of function argument
223 // normalization, AllocOp normalization or an update made at CallOp.
224 // There can be many call flows inside a function and an update to a
225 // specific ReturnOp has not yet been made. So we check that the result
226 // memref type is normalized.
227 // TODO: When selective normalization is implemented, handle multiple
228 // results case where some are normalized, some aren't.
229 if (memrefType.getLayout().isIdentity())
230 resultTypes[operandEn.index()] = memrefType;
231 }
232 });
233
234 // We create a new function type and modify the function signature with this
235 // new type.
236 newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
237 /*results=*/resultTypes);
238 }
239
240 // Since we update the function signature, it might affect the result types at
241 // the caller site. Since this result might even be used by the caller
242 // function in ReturnOps, the caller function's signature will also change.
243 // Hence we record the caller function in 'funcOpsToUpdate' to update their
244 // signature as well.
245 llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate;
246 // We iterate over all symbolic uses of the function and update the return
247 // type at the caller site.
248 Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
249 for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
250 Operation *userOp = symbolUse.getUser();
251 OpBuilder builder(userOp);
252 // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes
253 // that the non-CallOp has no memrefs to be replaced.
254 // TODO: Handle cases where a non-CallOp symbol use of a function deals with
255 // memrefs.
256 auto callOp = dyn_cast<func::CallOp>(userOp);
257 if (!callOp)
258 continue;
259 Operation *newCallOp =
260 builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
261 resultTypes, userOp->getOperands());
262 bool replacingMemRefUsesFailed = false;
263 bool returnTypeChanged = false;
264 for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
265 OpResult oldResult = userOp->getResult(resIndex);
266 OpResult newResult = newCallOp->getResult(resIndex);
267 // This condition ensures that if the result is not of type memref or if
268 // the resulting memref was already having a trivial map layout then we
269 // need not perform any use replacement here.
270 if (oldResult.getType() == newResult.getType())
271 continue;
272 AffineMap layoutMap =
273 oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
274 if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
275 /*extraIndices=*/{},
276 /*indexRemap=*/layoutMap,
277 /*extraOperands=*/{},
278 /*symbolOperands=*/{},
279 /*domOpFilter=*/nullptr,
280 /*postDomOpFilter=*/nullptr,
281 /*allowNonDereferencingOps=*/true,
282 /*replaceInDeallocOp=*/true))) {
283 // If it failed (due to escapes for example), bail out.
284 // It should never hit this part of the code because it is called by
285 // only those functions which are normalizable.
286 newCallOp->erase();
287 replacingMemRefUsesFailed = true;
288 break;
289 }
290 returnTypeChanged = true;
291 }
292 if (replacingMemRefUsesFailed)
293 continue;
294 // Replace all uses for other non-memref result types.
295 userOp->replaceAllUsesWith(newCallOp);
296 userOp->erase();
297 if (returnTypeChanged) {
298 // Since the return type changed it might lead to a change in function's
299 // signature.
300 // TODO: If funcOp doesn't return any memref type then no need to update
301 // signature.
302 // TODO: Further optimization - Check if the memref is indeed part of
303 // ReturnOp at the parentFuncOp and only then updation of signature is
304 // required.
305 // TODO: Extend this for ops that are FunctionOpInterface. This would
306 // require creating an OpInterface for FunctionOpInterface ops.
307 func::FuncOp parentFuncOp = newCallOp->getParentOfType<func::FuncOp>();
308 funcOpsToUpdate.insert(parentFuncOp);
309 }
310 }
311 // Because external function's signature is already updated in
312 // 'normalizeFuncOpMemRefs()', we don't need to update it here again.
313 if (!funcOp.isExternal())
314 funcOp.setType(newFuncType);
315
316 // Updating the signature type of those functions which call the current
317 // function. Only if the return type of the current function has a normalized
318 // memref will the caller function become a candidate for signature update.
319 for (func::FuncOp parentFuncOp : funcOpsToUpdate)
320 updateFunctionSignature(parentFuncOp, moduleOp);
321 }
322
323 /// Normalizes the memrefs within a function which includes those arising as a
324 /// result of AllocOps, CallOps and function's argument. The ModuleOp argument
325 /// is used to help update function's signature after normalization.
normalizeFuncOpMemRefs(func::FuncOp funcOp,ModuleOp moduleOp)326 void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
327 ModuleOp moduleOp) {
328 // Turn memrefs' non-identity layouts maps into ones with identity. Collect
329 // alloc ops first and then process since normalizeMemRef replaces/erases ops
330 // during memref rewriting.
331 SmallVector<memref::AllocOp, 4> allocOps;
332 funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
333 for (memref::AllocOp allocOp : allocOps)
334 (void)normalizeMemRef(&allocOp);
335
336 // We use this OpBuilder to create new memref layout later.
337 OpBuilder b(funcOp);
338
339 FunctionType functionType = funcOp.getFunctionType();
340 SmallVector<Location> functionArgLocs(llvm::map_range(
341 funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); }));
342 SmallVector<Type, 8> inputTypes;
343 // Walk over each argument of a function to perform memref normalization (if
344 for (unsigned argIndex :
345 llvm::seq<unsigned>(0, functionType.getNumInputs())) {
346 Type argType = functionType.getInput(argIndex);
347 MemRefType memrefType = argType.dyn_cast<MemRefType>();
348 // Check whether argument is of MemRef type. Any other argument type can
349 // simply be part of the final function signature.
350 if (!memrefType) {
351 inputTypes.push_back(argType);
352 continue;
353 }
354 // Fetch a new memref type after normalizing the old memref to have an
355 // identity map layout.
356 MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
357 /*numSymbolicOperands=*/0);
358 if (newMemRefType == memrefType || funcOp.isExternal()) {
359 // Either memrefType already had an identity map or the map couldn't be
360 // transformed to an identity map.
361 inputTypes.push_back(newMemRefType);
362 continue;
363 }
364
365 // Insert a new temporary argument with the new memref type.
366 BlockArgument newMemRef = funcOp.front().insertArgument(
367 argIndex, newMemRefType, functionArgLocs[argIndex]);
368 BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
369 AffineMap layoutMap = memrefType.getLayout().getAffineMap();
370 // Replace all uses of the old memref.
371 if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
372 /*extraIndices=*/{},
373 /*indexRemap=*/layoutMap,
374 /*extraOperands=*/{},
375 /*symbolOperands=*/{},
376 /*domOpFilter=*/nullptr,
377 /*postDomOpFilter=*/nullptr,
378 /*allowNonDereferencingOps=*/true,
379 /*replaceInDeallocOp=*/true))) {
380 // If it failed (due to escapes for example), bail out. Removing the
381 // temporary argument inserted previously.
382 funcOp.front().eraseArgument(argIndex);
383 continue;
384 }
385
386 // All uses for the argument with old memref type were replaced
387 // successfully. So we remove the old argument now.
388 funcOp.front().eraseArgument(argIndex + 1);
389 }
390
391 // Walk over normalizable operations to normalize memrefs of the operation
392 // results. When `op` has memrefs with affine map in the operation results,
393 // new operation containin normalized memrefs is created. Then, the memrefs
394 // are replaced. `CallOp` is skipped here because it is handled in
395 // `updateFunctionSignature()`.
396 funcOp.walk([&](Operation *op) {
397 if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
398 op->getNumResults() > 0 && !isa<func::CallOp>(op) &&
399 !funcOp.isExternal()) {
400 // Create newOp containing normalized memref in the operation result.
401 Operation *newOp = createOpResultsNormalized(funcOp, op);
402 // When all of the operation results have no memrefs or memrefs without
403 // affine map, `newOp` is the same with `op` and following process is
404 // skipped.
405 if (op != newOp) {
406 bool replacingMemRefUsesFailed = false;
407 for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
408 // Replace all uses of the old memrefs.
409 Value oldMemRef = op->getResult(resIndex);
410 Value newMemRef = newOp->getResult(resIndex);
411 MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
412 // Check whether the operation result is MemRef type.
413 if (!oldMemRefType)
414 continue;
415 MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
416 if (oldMemRefType == newMemRefType)
417 continue;
418 // TODO: Assume single layout map. Multiple maps not supported.
419 AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
420 if (failed(replaceAllMemRefUsesWith(oldMemRef,
421 /*newMemRef=*/newMemRef,
422 /*extraIndices=*/{},
423 /*indexRemap=*/layoutMap,
424 /*extraOperands=*/{},
425 /*symbolOperands=*/{},
426 /*domOpFilter=*/nullptr,
427 /*postDomOpFilter=*/nullptr,
428 /*allowNonDereferencingOps=*/true,
429 /*replaceInDeallocOp=*/true))) {
430 newOp->erase();
431 replacingMemRefUsesFailed = true;
432 continue;
433 }
434 }
435 if (!replacingMemRefUsesFailed) {
436 // Replace other ops with new op and delete the old op when the
437 // replacement succeeded.
438 op->replaceAllUsesWith(newOp);
439 op->erase();
440 }
441 }
442 }
443 });
444
445 // In a normal function, memrefs in the return type signature gets normalized
446 // as a result of normalization of functions arguments, AllocOps or CallOps'
447 // result types. Since an external function doesn't have a body, memrefs in
448 // the return type signature can only get normalized by iterating over the
449 // individual return types.
450 if (funcOp.isExternal()) {
451 SmallVector<Type, 4> resultTypes;
452 for (unsigned resIndex :
453 llvm::seq<unsigned>(0, functionType.getNumResults())) {
454 Type resType = functionType.getResult(resIndex);
455 MemRefType memrefType = resType.dyn_cast<MemRefType>();
456 // Check whether result is of MemRef type. Any other argument type can
457 // simply be part of the final function signature.
458 if (!memrefType) {
459 resultTypes.push_back(resType);
460 continue;
461 }
462 // Computing a new memref type after normalizing the old memref to have an
463 // identity map layout.
464 MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
465 /*numSymbolicOperands=*/0);
466 resultTypes.push_back(newMemRefType);
467 }
468
469 FunctionType newFuncType =
470 FunctionType::get(&getContext(), /*inputs=*/inputTypes,
471 /*results=*/resultTypes);
472 // Setting the new function signature for this external function.
473 funcOp.setType(newFuncType);
474 }
475 updateFunctionSignature(funcOp, moduleOp);
476 }
477
478 /// Create an operation containing normalized memrefs in the operation results.
479 /// When the results of `oldOp` have memrefs with affine map, the memrefs are
480 /// normalized, and new operation containing them in the operation results is
481 /// returned. If all of the results of `oldOp` have no memrefs or memrefs
482 /// without affine map, `oldOp` is returned without modification.
createOpResultsNormalized(func::FuncOp funcOp,Operation * oldOp)483 Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
484 Operation *oldOp) {
485 // Prepare OperationState to create newOp containing normalized memref in
486 // the operation results.
487 OperationState result(oldOp->getLoc(), oldOp->getName());
488 result.addOperands(oldOp->getOperands());
489 result.addAttributes(oldOp->getAttrs());
490 // Add normalized MemRefType to the OperationState.
491 SmallVector<Type, 4> resultTypes;
492 OpBuilder b(funcOp);
493 bool resultTypeNormalized = false;
494 for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
495 auto resultType = oldOp->getResult(resIndex).getType();
496 MemRefType memrefType = resultType.dyn_cast<MemRefType>();
497 // Check whether the operation result is MemRef type.
498 if (!memrefType) {
499 resultTypes.push_back(resultType);
500 continue;
501 }
502 // Fetch a new memref type after normalizing the old memref.
503 MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
504 /*numSymbolicOperands=*/0);
505 if (newMemRefType == memrefType) {
506 // Either memrefType already had an identity map or the map couldn't
507 // be transformed to an identity map.
508 resultTypes.push_back(memrefType);
509 continue;
510 }
511 resultTypes.push_back(newMemRefType);
512 resultTypeNormalized = true;
513 }
514 result.addTypes(resultTypes);
515 // When all of the results of `oldOp` have no memrefs or memrefs without
516 // affine map, `oldOp` is returned without modification.
517 if (resultTypeNormalized) {
518 OpBuilder bb(oldOp);
519 for (auto &oldRegion : oldOp->getRegions()) {
520 Region *newRegion = result.addRegion();
521 newRegion->takeBody(oldRegion);
522 }
523 return bb.create(result);
524 }
525 return oldOp;
526 }
527