//===- NormalizeMemRefs.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an interprocedural pass to normalize memrefs to have
// identity layout maps.
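//
// For example (illustrative): a memref with a tiled layout such as
//   memref<16xf64, affine_map<(d0) -> (d0 floordiv 4, d0 mod 4)>>
// is rewritten to the identity-layout type memref<4x4xf64>, and an access
// %A[%i] is remapped to %A[%i floordiv 4, %i mod 4].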
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "normalize-memrefs"

using namespace mlir;

namespace {

/// All memrefs passed across functions with non-trivial layout maps are
/// converted to ones with trivial identity layouts. If all the memref
/// types/uses in a function are normalizable, we treat the function as
/// normalizable. Also, if a normalizable function is known to call a
/// non-normalizable function, we treat it as non-normalizable as well.
/// We assume external functions to be normalizable.
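/// For example, if @f calls @g and @g contains an op without the
/// MemRefsNormalizable trait, then @g is non-normalizable, and so are @f
/// and, transitively, @f's other callers and callees.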
struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
  void runOnOperation() override;
  void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);
  bool areMemRefsNormalizable(FuncOp funcOp);
  void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp);
  void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp,
                                           DenseSet<FuncOp> &normalizableFuncs);
  Operation *createOpResultsNormalized(FuncOp funcOp, Operation *oldOp);
};

} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::memref::createNormalizeMemRefsPass() {
  return std::make_unique<NormalizeMemRefs>();
}

void NormalizeMemRefs::runOnOperation() {
  LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
  ModuleOp moduleOp = getOperation();
  // We maintain all normalizable FuncOps in a DenseSet. It is initialized
  // with all the functions within a module and then functions which are not
  // normalizable are removed from this set.
  // TODO: Change this to work on FuncLikeOp once there is an operation
  // interface for it.
  DenseSet<FuncOp> normalizableFuncs;
  // Initialize `normalizableFuncs` with all the functions within a module.
  moduleOp.walk([&](FuncOp funcOp) { normalizableFuncs.insert(funcOp); });

  // Traverse all the functions, applying a filter that determines whether
  // each one is normalizable. All callers/callees of a non-normalizable
  // function become non-normalizable as well, even if they don't themselves
  // pass any non-normalizable memrefs: a function that calls or is called by
  // a non-normalizable function is treated as non-normalizable.
  moduleOp.walk([&](FuncOp funcOp) {
    if (normalizableFuncs.contains(funcOp)) {
      if (!areMemRefsNormalizable(funcOp)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "@" << funcOp.getName()
                   << " contains ops that cannot normalize MemRefs\n");
        // Since this function is not normalizable, we set all the caller
        // functions and the callees of this function as not normalizable.
        // TODO: Drop this conservative assumption in the future.
        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
                                            normalizableFuncs);
      }
    }
  });

  LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
                          << " functions\n");
  // Those functions which can be normalized are subjected to normalization.
  for (FuncOp &funcOp : normalizableFuncs)
    normalizeFuncOpMemRefs(funcOp, moduleOp);
}

/// Check whether all the uses of oldMemRef are either dereferencing uses or
/// uses in ops of type DeallocOp, CallOp, or ReturnOp. Only if these
/// constraints are satisfied will the value become a candidate for
/// replacement.
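/// For instance, a memref whose only users are affine load/store ops,
/// DeallocOp, CallOp, and ReturnOp (all carrying the MemRefsNormalizable
/// trait) is a candidate, while one escaping into an op without that trait
/// is not.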
/// TODO: Extend this for DimOps.
static bool isMemRefNormalizable(Value::user_range opUsers) {
  return llvm::all_of(opUsers, [](Operation *op) {
    return op->hasTrait<OpTrait::MemRefsNormalizable>();
  });
}

/// Set all the calling functions and the callees of the function as not
/// normalizable.
void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
    FuncOp funcOp, ModuleOp moduleOp, DenseSet<FuncOp> &normalizableFuncs) {
  if (!normalizableFuncs.contains(funcOp))
    return;

  LLVM_DEBUG(
      llvm::dbgs() << "@" << funcOp.getName()
                   << " calls or is called by non-normalizable function\n");
  normalizableFuncs.erase(funcOp);
  // Callers of the function.
  Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
  for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
    // TODO: Extend this for ops that are FunctionOpInterface. This would
    // require creating an OpInterface for FunctionOpInterface ops.
    FuncOp parentFuncOp = symbolUse.getUser()->getParentOfType<FuncOp>();
    for (FuncOp &funcOp : normalizableFuncs) {
      if (parentFuncOp == funcOp) {
        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
                                            normalizableFuncs);
        break;
      }
    }
  }

  // Functions called by this function.
  funcOp.walk([&](CallOp callOp) {
    StringAttr callee = callOp.getCalleeAttr().getAttr();
    for (FuncOp &funcOp : normalizableFuncs) {
      // Compare the callee's name against each normalizable FuncOp's name.
      if (callee == funcOp.getNameAttr()) {
        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
                                            normalizableFuncs);
        break;
      }
    }
  });
}

/// Check whether all the uses of AllocOps, CallOps and function arguments of
/// a function are either dereferencing uses or uses in DeallocOp, CallOp or
/// ReturnOp. Only if these constraints are satisfied will the function become
/// a candidate for normalization. We follow a conservative approach here:
/// even if a non-normalizable memref is not part of the function's argument
/// or return types, we still label the entire function as non-normalizable.
/// We assume external functions to be normalizable.
bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
  // We assume external functions to be normalizable.
  if (funcOp.isExternal())
    return true;

  if (funcOp
          .walk([&](memref::AllocOp allocOp) -> WalkResult {
            Value oldMemRef = allocOp.getResult();
            if (!isMemRefNormalizable(oldMemRef.getUsers()))
              return WalkResult::interrupt();
            return WalkResult::advance();
          })
          .wasInterrupted())
    return false;

  if (funcOp
          .walk([&](CallOp callOp) -> WalkResult {
            for (unsigned resIndex :
                 llvm::seq<unsigned>(0, callOp.getNumResults())) {
              Value oldMemRef = callOp.getResult(resIndex);
              if (oldMemRef.getType().isa<MemRefType>())
                if (!isMemRefNormalizable(oldMemRef.getUsers()))
                  return WalkResult::interrupt();
            }
            return WalkResult::advance();
          })
          .wasInterrupted())
    return false;

  for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
    BlockArgument oldMemRef = funcOp.getArgument(argIndex);
    if (oldMemRef.getType().isa<MemRefType>())
      if (!isMemRefNormalizable(oldMemRef.getUsers()))
        return false;
  }

  return true;
}

/// Fetch the updated argument list and result types of the function and
/// update the function signature. This also updates the function's result
/// types at each call site; if a returned value is a normalized memref, the
/// calling function's signature is updated in turn.
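/// For example (illustrative): if this function's result type changes from
/// memref<16xf64, #tile> (where #tile is some non-identity layout map such
/// as (d0) -> (d0 floordiv 4, d0 mod 4)) to memref<4x4xf64>, every CallOp to
/// it is recreated with the new result type, and each calling function is
/// then revisited in case the normalized result flows into its own ReturnOp.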
/// TODO: An update to the calling function signature is required only if the
/// returned value is in turn used in ReturnOp of the calling function.
void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
                                               ModuleOp moduleOp) {
  FunctionType functionType = funcOp.getType();
  SmallVector<Type, 4> resultTypes;
  FunctionType newFuncType;
  resultTypes = llvm::to_vector<4>(functionType.getResults());

  // An external function's signature was already updated in
  // 'normalizeFuncOpMemRefs()'.
  if (!funcOp.isExternal()) {
    SmallVector<Type, 8> argTypes;
    for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
      argTypes.push_back(argEn.value().getType());

    // Traverse ReturnOps to check if an update to the return type in the
    // function signature is required.
    funcOp.walk([&](ReturnOp returnOp) {
      for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
        Type opType = operandEn.value().getType();
        MemRefType memrefType = opType.dyn_cast<MemRefType>();
        // If the type is not a memref, or the memref type is the same as the
        // one in the function's return signature, no update is required.
        if (!memrefType || memrefType == resultTypes[operandEn.index()])
          continue;
        // Update the function's return type signature.
        // A return type gets normalized either as a result of function
        // argument normalization, AllocOp normalization, or an update made at
        // a CallOp. There can be many control flows inside a function, and an
        // update to a specific ReturnOp may not have been made yet, so we
        // check that the result memref type is indeed normalized.
        // TODO: When selective normalization is implemented, handle multiple
        // results case where some are normalized, some aren't.
        if (memrefType.getLayout().isIdentity())
          resultTypes[operandEn.index()] = memrefType;
      }
    });

    // We create a new function type and modify the function signature with
    // this new type.
    newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
                                    /*results=*/resultTypes);
  }

  // Since we update the function signature, it might affect the result types
  // at the caller sites. Because such a result might even be used by the
  // caller function in its ReturnOps, the caller function's signature may
  // also change. Hence we record the caller functions in 'funcOpsToUpdate'
  // to update their signatures as well.
  llvm::SmallDenseSet<FuncOp, 8> funcOpsToUpdate;
  // We iterate over all symbolic uses of the function and update the return
  // type at the caller site.
  Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
  for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
    Operation *userOp = symbolUse.getUser();
    OpBuilder builder(userOp);
    // When `userOp` cannot be cast to `CallOp`, it is skipped. This assumes
    // that a non-CallOp use has no memrefs to be replaced.
    // TODO: Handle cases where a non-CallOp symbol use of a function deals
    // with memrefs.
    auto callOp = dyn_cast<CallOp>(userOp);
    if (!callOp)
      continue;
    Operation *newCallOp =
        builder.create<CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
                               resultTypes, userOp->getOperands());
    bool replacingMemRefUsesFailed = false;
    bool returnTypeChanged = false;
    for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
      OpResult oldResult = userOp->getResult(resIndex);
      OpResult newResult = newCallOp->getResult(resIndex);
      // If the result is not a memref, or the resulting memref already had an
      // identity layout map, no use replacement is needed here.
      if (oldResult.getType() == newResult.getType())
        continue;
      AffineMap layoutMap =
          oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
      if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
                                          /*extraIndices=*/{},
                                          /*indexRemap=*/layoutMap,
                                          /*extraOperands=*/{},
                                          /*symbolOperands=*/{},
                                          /*domOpFilter=*/nullptr,
                                          /*postDomOpFilter=*/nullptr,
                                          /*allowNonDereferencingOps=*/true,
                                          /*replaceInDeallocOp=*/true))) {
        // If the replacement failed (due to escapes, for example), bail out.
        // This should never be hit in practice because this code runs only
        // for functions that are normalizable.
        newCallOp->erase();
        replacingMemRefUsesFailed = true;
        break;
      }
      returnTypeChanged = true;
    }
    if (replacingMemRefUsesFailed)
      continue;
    // Replace all uses for other non-memref result types.
    userOp->replaceAllUsesWith(newCallOp);
    userOp->erase();
    if (returnTypeChanged) {
      // Since the return type changed, it might lead to a change in the
      // caller function's signature.
      // TODO: If funcOp doesn't return any memref type then no need to update
      // the signature.
      // TODO: Further optimization - check if the memref is indeed part of a
      // ReturnOp in the parentFuncOp, and only then is a signature update
      // required.
      // TODO: Extend this for ops that are FunctionOpInterface. This would
      // require creating an OpInterface for FunctionOpInterface ops.
      FuncOp parentFuncOp = newCallOp->getParentOfType<FuncOp>();
      funcOpsToUpdate.insert(parentFuncOp);
    }
  }
  // Because an external function's signature was already updated in
  // 'normalizeFuncOpMemRefs()', we don't need to update it here again.
  if (!funcOp.isExternal())
    funcOp.setType(newFuncType);

  // Update the signature types of the functions that call the current
  // function. Only if the return type of the current function contains a
  // normalized memref does the caller become a candidate for a signature
  // update.
  for (FuncOp parentFuncOp : funcOpsToUpdate)
    updateFunctionSignature(parentFuncOp, moduleOp);
}

/// Normalize the memrefs within a function, including those arising as
/// results of AllocOps and CallOps and as function arguments. The ModuleOp
/// argument is used to help update the function's signature after
/// normalization.
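/// For example (illustrative): an argument of type memref<16xf64, #tile> is
/// replaced by a new argument of type memref<4x4xf64>, and all uses of the
/// old argument are remapped through the old layout map.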
void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
                                              ModuleOp moduleOp) {
  // Turn memrefs' non-identity layout maps into identity ones. Collect the
  // alloc ops first and then process them, since normalizeMemRef
  // replaces/erases ops during memref rewriting.
  SmallVector<memref::AllocOp, 4> allocOps;
  funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
  for (memref::AllocOp allocOp : allocOps)
    (void)normalizeMemRef(&allocOp);

  // We use this OpBuilder to create new memref layouts later.
  OpBuilder b(funcOp);

  FunctionType functionType = funcOp.getType();
  SmallVector<Location> functionArgLocs(llvm::map_range(
      funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); }));
  SmallVector<Type, 8> inputTypes;
  // Walk over each argument of the function to perform memref normalization
  // (if possible).
  for (unsigned argIndex :
       llvm::seq<unsigned>(0, functionType.getNumInputs())) {
    Type argType = functionType.getInput(argIndex);
    MemRefType memrefType = argType.dyn_cast<MemRefType>();
    // Check whether the argument is of MemRef type. Any other argument type
    // can simply be part of the final function signature.
    if (!memrefType) {
      inputTypes.push_back(argType);
      continue;
    }
    // Fetch a new memref type after normalizing the old memref to have an
    // identity map layout.
    MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
                                                   /*numSymbolicOperands=*/0);
    if (newMemRefType == memrefType || funcOp.isExternal()) {
      // Either memrefType already had an identity map or the map couldn't be
      // transformed to an identity map.
      inputTypes.push_back(newMemRefType);
      continue;
    }

    // Insert a new temporary argument with the new memref type.
    BlockArgument newMemRef = funcOp.front().insertArgument(
        argIndex, newMemRefType, functionArgLocs[argIndex]);
    BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
    AffineMap layoutMap = memrefType.getLayout().getAffineMap();
    // Replace all uses of the old memref.
    if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
                                        /*extraIndices=*/{},
                                        /*indexRemap=*/layoutMap,
                                        /*extraOperands=*/{},
                                        /*symbolOperands=*/{},
                                        /*domOpFilter=*/nullptr,
                                        /*postDomOpFilter=*/nullptr,
                                        /*allowNonDereferencingOps=*/true,
                                        /*replaceInDeallocOp=*/true))) {
      // If the replacement failed (due to escapes, for example), bail out,
      // removing the temporary argument inserted previously.
      funcOp.front().eraseArgument(argIndex);
      continue;
    }

    // All uses of the argument with the old memref type were replaced
    // successfully, so we can remove the old argument now.
    funcOp.front().eraseArgument(argIndex + 1);
  }

  // Walk over normalizable operations to normalize the memrefs in the
  // operation results. When `op` has memrefs with affine layout maps in its
  // results, a new operation containing the normalized memrefs is created
  // and the old memrefs are replaced. `CallOp` is skipped here because it is
  // handled in `updateFunctionSignature()`.
  funcOp.walk([&](Operation *op) {
    if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
        op->getNumResults() > 0 && !isa<CallOp>(op) && !funcOp.isExternal()) {
      // Create newOp containing normalized memrefs in the operation results.
      Operation *newOp = createOpResultsNormalized(funcOp, op);
      // When none of the operation results are memrefs, or all the memref
      // results lack affine layout maps, `newOp` is the same as `op` and the
      // following process is skipped.
      if (op != newOp) {
        bool replacingMemRefUsesFailed = false;
        for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
          // Replace all uses of the old memrefs.
          Value oldMemRef = op->getResult(resIndex);
          Value newMemRef = newOp->getResult(resIndex);
          MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
          // Check whether the operation result is of MemRef type.
          if (!oldMemRefType)
            continue;
          MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
          if (oldMemRefType == newMemRefType)
            continue;
          // TODO: Assume single layout map. Multiple maps not supported.
          AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
          if (failed(replaceAllMemRefUsesWith(oldMemRef,
                                              /*newMemRef=*/newMemRef,
                                              /*extraIndices=*/{},
                                              /*indexRemap=*/layoutMap,
                                              /*extraOperands=*/{},
                                              /*symbolOperands=*/{},
                                              /*domOpFilter=*/nullptr,
                                              /*postDomOpFilter=*/nullptr,
                                              /*allowNonDereferencingOps=*/true,
                                              /*replaceInDeallocOp=*/true))) {
            newOp->erase();
            replacingMemRefUsesFailed = true;
            // `newOp` has been erased, so stop iterating over its results.
            break;
          }
        }
        if (!replacingMemRefUsesFailed) {
          // Replace the remaining non-memref uses with the new op and delete
          // the old op once the replacement has succeeded.
          op->replaceAllUsesWith(newOp);
          op->erase();
        }
      }
    }
  });

  // In a normal function, the memrefs in the return type signature get
  // normalized as a result of the normalization of function arguments,
  // AllocOps, or CallOp result types. Since an external function doesn't
  // have a body, the memrefs in its return type signature can only get
  // normalized by iterating over the individual return types.
  if (funcOp.isExternal()) {
    SmallVector<Type, 4> resultTypes;
    for (unsigned resIndex :
         llvm::seq<unsigned>(0, functionType.getNumResults())) {
      Type resType = functionType.getResult(resIndex);
      MemRefType memrefType = resType.dyn_cast<MemRefType>();
      // Check whether the result is of MemRef type. Any other result type
      // can simply be part of the final function signature.
      if (!memrefType) {
        resultTypes.push_back(resType);
        continue;
      }
      // Compute a new memref type after normalizing the old memref to have
      // an identity map layout.
      MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
                                                     /*numSymbolicOperands=*/0);
      resultTypes.push_back(newMemRefType);
    }

    FunctionType newFuncType =
        FunctionType::get(&getContext(), /*inputs=*/inputTypes,
                          /*results=*/resultTypes);
    // Set the new function signature for this external function.
    funcOp.setType(newFuncType);
  }
  updateFunctionSignature(funcOp, moduleOp);
}

/// Create an operation containing normalized memrefs in the operation
/// results. When the results of `oldOp` include memrefs with affine layout
/// maps, those memrefs are normalized and a new operation containing them in
/// its results is returned. If none of the results of `oldOp` are memrefs,
/// or all the memref results lack affine layout maps, `oldOp` is returned
/// without modification.
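/// For example (illustrative): a result of type memref<16xf64, #tile> is
/// rebuilt with type memref<4x4xf64>, while non-memref results and
/// identity-layout memref results are left unchanged.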
Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp,
                                                       Operation *oldOp) {
  // Prepare an OperationState to create newOp containing the normalized
  // memref types in the operation results.
  OperationState result(oldOp->getLoc(), oldOp->getName());
  result.addOperands(oldOp->getOperands());
  result.addAttributes(oldOp->getAttrs());
  // Add normalized MemRefTypes to the OperationState.
  SmallVector<Type, 4> resultTypes;
  OpBuilder b(funcOp);
  bool resultTypeNormalized = false;
  for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
    auto resultType = oldOp->getResult(resIndex).getType();
    MemRefType memrefType = resultType.dyn_cast<MemRefType>();
    // Check whether the operation result is of MemRef type.
    if (!memrefType) {
      resultTypes.push_back(resultType);
      continue;
    }
    // Fetch a new memref type after normalizing the old memref.
    MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
                                                   /*numSymbolicOperands=*/0);
    if (newMemRefType == memrefType) {
      // Either memrefType already had an identity map or the map couldn't
      // be transformed to an identity map.
      resultTypes.push_back(memrefType);
      continue;
    }
    resultTypes.push_back(newMemRefType);
    resultTypeNormalized = true;
  }
  result.addTypes(resultTypes);
  // Create the new operation only if at least one result type was
  // normalized; otherwise return `oldOp` without modification.
  if (resultTypeNormalized) {
    OpBuilder bb(oldOp);
    for (auto &oldRegion : oldOp->getRegions()) {
      Region *newRegion = result.addRegion();
      newRegion->takeBody(oldRegion);
    }
    return bb.createOperation(result);
  }
  return oldOp;
}