15b03e692SNicolas Vasilache //===- Promotion.cpp - Implementation of linalg Promotion -----------------===//
25b03e692SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65b03e692SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85b03e692SNicolas Vasilache //
95b03e692SNicolas Vasilache // This file implements the linalg dialect Promotion pass.
105b03e692SNicolas Vasilache //
115b03e692SNicolas Vasilache //===----------------------------------------------------------------------===//
125b03e692SNicolas Vasilache 
131834ad4aSRiver Riddle #include "PassDetail.h"
14a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1542e5f422STres Popp #include "mlir/Dialect/Complex/IR/Complex.h"
16b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
175b03e692SNicolas Vasilache #include "mlir/Dialect/Linalg/Passes.h"
18307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
195b03e692SNicolas Vasilache #include "mlir/Dialect/Linalg/Utils/Utils.h"
208b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
215b03e692SNicolas Vasilache #include "mlir/IR/AffineExpr.h"
225b03e692SNicolas Vasilache #include "mlir/IR/AffineExprVisitor.h"
235b03e692SNicolas Vasilache #include "mlir/IR/AffineMap.h"
24ef33c6e3SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
255b03e692SNicolas Vasilache #include "mlir/Support/LLVM.h"
265b03e692SNicolas Vasilache #include "mlir/Transforms/FoldUtils.h"
270ed2d4c7SMaheshRavishankar #include "llvm/ADT/MapVector.h"
2899069ab2SChristopher Bate #include "llvm/ADT/SmallBitVector.h"
294519ca3dSNicolas Vasilache #include "llvm/ADT/TypeSwitch.h"
305b03e692SNicolas Vasilache #include "llvm/Support/CommandLine.h"
31f89bb3c0SAlexander Belyaev #include "llvm/Support/Debug.h"
325b03e692SNicolas Vasilache 
335b03e692SNicolas Vasilache using namespace mlir;
345b03e692SNicolas Vasilache using namespace mlir::linalg;
35c25b20c0SAlex Zinenko using namespace mlir::scf;
365b03e692SNicolas Vasilache 
370ed2d4c7SMaheshRavishankar using llvm::MapVector;
385b03e692SNicolas Vasilache 
395b03e692SNicolas Vasilache #define DEBUG_TYPE "linalg-promotion"
405b03e692SNicolas Vasilache 
41ef33c6e3SNicolas Vasilache /// Alloc a new buffer of `size` * `width` i8; where `width` is given by the
42ef33c6e3SNicolas Vasilache /// data `layout` for `elementType`.
43ef33c6e3SNicolas Vasilache /// Use AllocOp or AllocaOp depending on `options`.
44ef33c6e3SNicolas Vasilache /// Take an optional alignment.
allocBuffer(ImplicitLocOpBuilder & b,const LinalgPromotionOptions & options,Type elementType,Value allocSize,DataLayout & layout,Optional<unsigned> alignment=None)45ef33c6e3SNicolas Vasilache static Value allocBuffer(ImplicitLocOpBuilder &b,
46ef33c6e3SNicolas Vasilache                          const LinalgPromotionOptions &options,
47ef33c6e3SNicolas Vasilache                          Type elementType, Value allocSize, DataLayout &layout,
488dbbb223SNicolas Vasilache                          Optional<unsigned> alignment = None) {
4942e5f422STres Popp   auto width = layout.getTypeSize(elementType);
50ef33c6e3SNicolas Vasilache 
51ef33c6e3SNicolas Vasilache   IntegerAttr alignmentAttr;
52491d2701SKazu Hirata   if (alignment.has_value())
53c27d8152SKazu Hirata     alignmentAttr = b.getI64IntegerAttr(alignment.value());
54ef33c6e3SNicolas Vasilache 
55ef33c6e3SNicolas Vasilache   // Static buffer.
56a54f4eaeSMogball   if (auto cst = allocSize.getDefiningOp<arith::ConstantIndexOp>()) {
57ef33c6e3SNicolas Vasilache     auto staticBufferType =
58a54f4eaeSMogball         MemRefType::get(width * cst.value(), b.getIntegerType(8));
59ef33c6e3SNicolas Vasilache     if (options.useAlloca) {
60ef33c6e3SNicolas Vasilache       return b.createOrFold<memref::AllocaOp>(staticBufferType, ValueRange{},
61ef33c6e3SNicolas Vasilache                                               alignmentAttr);
62ef33c6e3SNicolas Vasilache     }
63ef33c6e3SNicolas Vasilache     return b.createOrFold<memref::AllocOp>(staticBufferType, ValueRange{},
64ef33c6e3SNicolas Vasilache                                            alignmentAttr);
65ef33c6e3SNicolas Vasilache   }
66ef33c6e3SNicolas Vasilache 
67ef33c6e3SNicolas Vasilache   // Fallback dynamic buffer.
68ef33c6e3SNicolas Vasilache   auto dynamicBufferType = MemRefType::get(-1, b.getIntegerType(8));
69a54f4eaeSMogball   Value mul = b.createOrFold<arith::MulIOp>(
70a54f4eaeSMogball       b.create<arith::ConstantIndexOp>(width), allocSize);
71ef33c6e3SNicolas Vasilache   if (options.useAlloca)
72ef33c6e3SNicolas Vasilache     return b.create<memref::AllocaOp>(dynamicBufferType, mul, alignmentAttr);
73ef33c6e3SNicolas Vasilache   return b.create<memref::AllocOp>(dynamicBufferType, mul, alignmentAttr);
745b03e692SNicolas Vasilache }
755b03e692SNicolas Vasilache 
760ed2d4c7SMaheshRavishankar /// Default allocation callback function. This allocates a promoted buffer when
770ed2d4c7SMaheshRavishankar /// no call back to do so is provided. The default is to allocate a
780ed2d4c7SMaheshRavishankar /// memref<..xi8> and return a view to get a memref type of shape
790ed2d4c7SMaheshRavishankar /// boundingSubViewSize.
804519ca3dSNicolas Vasilache static Optional<Value>
defaultAllocBufferCallBack(const LinalgPromotionOptions & options,OpBuilder & builder,memref::SubViewOp subView,ArrayRef<Value> boundingSubViewSize,Optional<unsigned> alignment,DataLayout & layout)814519ca3dSNicolas Vasilache defaultAllocBufferCallBack(const LinalgPromotionOptions &options,
824519ca3dSNicolas Vasilache                            OpBuilder &builder, memref::SubViewOp subView,
834519ca3dSNicolas Vasilache                            ArrayRef<Value> boundingSubViewSize,
844519ca3dSNicolas Vasilache                            Optional<unsigned> alignment, DataLayout &layout) {
850ed2d4c7SMaheshRavishankar   ShapedType viewType = subView.getType();
86ef33c6e3SNicolas Vasilache   ImplicitLocOpBuilder b(subView.getLoc(), builder);
87a54f4eaeSMogball   auto zero = b.createOrFold<arith::ConstantIndexOp>(0);
88a54f4eaeSMogball   auto one = b.createOrFold<arith::ConstantIndexOp>(1);
890ed2d4c7SMaheshRavishankar 
900ed2d4c7SMaheshRavishankar   Value allocSize = one;
91e4853be2SMehdi Amini   for (const auto &size : llvm::enumerate(boundingSubViewSize))
92a54f4eaeSMogball     allocSize = b.createOrFold<arith::MulIOp>(allocSize, size.value());
93ef33c6e3SNicolas Vasilache   Value buffer = allocBuffer(b, options, viewType.getElementType(), allocSize,
94ef33c6e3SNicolas Vasilache                              layout, alignment);
950ed2d4c7SMaheshRavishankar   SmallVector<int64_t, 4> dynSizes(boundingSubViewSize.size(),
960ed2d4c7SMaheshRavishankar                                    ShapedType::kDynamicSize);
97ef33c6e3SNicolas Vasilache   Value view = b.createOrFold<memref::ViewOp>(
98ef33c6e3SNicolas Vasilache       MemRefType::get(dynSizes, viewType.getElementType()), buffer, zero,
99ef33c6e3SNicolas Vasilache       boundingSubViewSize);
1000ed2d4c7SMaheshRavishankar   return view;
1010ed2d4c7SMaheshRavishankar }
1020ed2d4c7SMaheshRavishankar 
1030ed2d4c7SMaheshRavishankar /// Default implementation of deallocation of the buffer use for promotion. It
1040ed2d4c7SMaheshRavishankar /// expects to get the same value that the default allocation method returned,
1050ed2d4c7SMaheshRavishankar /// i.e. result of a ViewOp.
1067d9518c8SNicolas Vasilache static LogicalResult
defaultDeallocBufferCallBack(const LinalgPromotionOptions & options,OpBuilder & b,Value fullLocalView)1077d9518c8SNicolas Vasilache defaultDeallocBufferCallBack(const LinalgPromotionOptions &options,
1087d9518c8SNicolas Vasilache                              OpBuilder &b, Value fullLocalView) {
1094519ca3dSNicolas Vasilache   if (!options.useAlloca) {
1104519ca3dSNicolas Vasilache     auto viewOp = cast<memref::ViewOp>(fullLocalView.getDefiningOp());
111136d746eSJacques Pienaar     b.create<memref::DeallocOp>(viewOp.getSource().getLoc(),
112136d746eSJacques Pienaar                                 viewOp.getSource());
1134519ca3dSNicolas Vasilache   }
1140ed2d4c7SMaheshRavishankar   return success();
1150ed2d4c7SMaheshRavishankar }
1160ed2d4c7SMaheshRavishankar 
1170ed2d4c7SMaheshRavishankar namespace {
1180ed2d4c7SMaheshRavishankar 
1190ed2d4c7SMaheshRavishankar /// Helper struct that captures the information required to apply the
1200ed2d4c7SMaheshRavishankar /// transformation on each op. This bridges the abstraction gap with the
1210ed2d4c7SMaheshRavishankar /// user-facing API which exposes positional arguments to control which operands
1220ed2d4c7SMaheshRavishankar /// are promoted.
1230ed2d4c7SMaheshRavishankar struct LinalgOpInstancePromotionOptions {
1240ed2d4c7SMaheshRavishankar   LinalgOpInstancePromotionOptions(LinalgOp op,
1250ed2d4c7SMaheshRavishankar                                    const LinalgPromotionOptions &options);
1260ed2d4c7SMaheshRavishankar   /// SubViews to promote.
1274519ca3dSNicolas Vasilache   MapVector<int64_t, Value> subViews;
1280ed2d4c7SMaheshRavishankar   /// True if the full view should be used for the promoted buffer.
1290ed2d4c7SMaheshRavishankar   DenseMap<Value, bool> useFullTileBuffers;
1300ed2d4c7SMaheshRavishankar 
1310ed2d4c7SMaheshRavishankar   /// Callback functions for allocation and deallocation of promoted buffers, as
1320ed2d4c7SMaheshRavishankar   /// well as to copy the data into and out of these buffers.
1330ed2d4c7SMaheshRavishankar   AllocBufferCallbackFn allocationFn;
1340ed2d4c7SMaheshRavishankar   DeallocBufferCallbackFn deallocationFn;
1350ed2d4c7SMaheshRavishankar   CopyCallbackFn copyInFn;
1360ed2d4c7SMaheshRavishankar   CopyCallbackFn copyOutFn;
1370ed2d4c7SMaheshRavishankar 
1380ed2d4c7SMaheshRavishankar   /// Alignment of promoted buffer.
1390ed2d4c7SMaheshRavishankar   Optional<unsigned> alignment;
1400ed2d4c7SMaheshRavishankar };
1410ed2d4c7SMaheshRavishankar } // namespace
1420ed2d4c7SMaheshRavishankar 
LinalgOpInstancePromotionOptions(LinalgOp linalgOp,const LinalgPromotionOptions & options)1430ed2d4c7SMaheshRavishankar LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
1440ed2d4c7SMaheshRavishankar     LinalgOp linalgOp, const LinalgPromotionOptions &options)
145*5a001136SNicolas Vasilache     : subViews(), alignment(options.alignment) {
146b7ae1d3dSnicolasvasilache   assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand");
1470ed2d4c7SMaheshRavishankar   auto vUseFullTileBuffers =
14830c67587SKazu Hirata       options.useFullTileBuffers.value_or(llvm::SmallBitVector());
149e70d2c8eSTobias Gysi   vUseFullTileBuffers.resize(linalgOp.getNumInputsAndOutputs(),
150e70d2c8eSTobias Gysi                              options.useFullTileBuffersDefault);
1510ed2d4c7SMaheshRavishankar 
152e70d2c8eSTobias Gysi   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
153e70d2c8eSTobias Gysi     int64_t operandNumber = opOperand->getOperandNumber();
154e70d2c8eSTobias Gysi     if (options.operandsToPromote &&
155e70d2c8eSTobias Gysi         !options.operandsToPromote->count(operandNumber))
1560ed2d4c7SMaheshRavishankar       continue;
157e70d2c8eSTobias Gysi     Operation *op = opOperand->get().getDefiningOp();
158e2310704SJulian Gross     if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
159e70d2c8eSTobias Gysi       subViews[operandNumber] = sv;
160e70d2c8eSTobias Gysi       useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
1610ed2d4c7SMaheshRavishankar     }
1620ed2d4c7SMaheshRavishankar   }
1630ed2d4c7SMaheshRavishankar 
1644519ca3dSNicolas Vasilache   if (options.allocationFn) {
1654519ca3dSNicolas Vasilache     allocationFn = *options.allocationFn;
1664519ca3dSNicolas Vasilache   } else {
1674519ca3dSNicolas Vasilache     allocationFn = [&](OpBuilder &b, memref::SubViewOp subViewOp,
168ef33c6e3SNicolas Vasilache                        ArrayRef<Value> boundingSubViewSize,
169ef33c6e3SNicolas Vasilache                        DataLayout &layout) -> Optional<Value> {
1704519ca3dSNicolas Vasilache       return defaultAllocBufferCallBack(options, b, subViewOp,
1714519ca3dSNicolas Vasilache                                         boundingSubViewSize, alignment, layout);
1724519ca3dSNicolas Vasilache     };
1734519ca3dSNicolas Vasilache   }
1744519ca3dSNicolas Vasilache 
1754519ca3dSNicolas Vasilache   if (options.deallocationFn) {
1764519ca3dSNicolas Vasilache     deallocationFn = *options.deallocationFn;
1774519ca3dSNicolas Vasilache   } else {
1784519ca3dSNicolas Vasilache     deallocationFn = [&](OpBuilder &b, Value buffer) {
1797d9518c8SNicolas Vasilache       return defaultDeallocBufferCallBack(options, b, buffer);
1804519ca3dSNicolas Vasilache     };
1814519ca3dSNicolas Vasilache   }
1824519ca3dSNicolas Vasilache 
1834519ca3dSNicolas Vasilache   // Save the loc because `linalgOp` goes out of scope.
1844519ca3dSNicolas Vasilache   Location loc = linalgOp.getLoc();
1854519ca3dSNicolas Vasilache   auto defaultCopyCallBack = [loc](OpBuilder &b, Value src,
1860ed2d4c7SMaheshRavishankar                                    Value dst) -> LogicalResult {
187ebc81537SAlexander Belyaev     b.create<memref::CopyOp>(loc, src, dst);
1880ed2d4c7SMaheshRavishankar     return success();
1890ed2d4c7SMaheshRavishankar   };
1900ed2d4c7SMaheshRavishankar   copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack);
1910ed2d4c7SMaheshRavishankar   copyOutFn = (options.copyOutFn ? *(options.copyOutFn) : defaultCopyCallBack);
1920ed2d4c7SMaheshRavishankar }
1930ed2d4c7SMaheshRavishankar 
1945b03e692SNicolas Vasilache // Performs promotion of a `subView` into a local buffer of the size of the
1955b03e692SNicolas Vasilache // *ranges* of the `subView`. This produces a buffer whose size may be bigger
1965b03e692SNicolas Vasilache // than the actual size of the `subView` at the boundaries.
1975b03e692SNicolas Vasilache // This is related to the full/partial tile problem.
1985b03e692SNicolas Vasilache // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
1995b03e692SNicolas Vasilache // `partialLocalView` such that:
2005b03e692SNicolas Vasilache //   * `buffer` is always the size of the full tile.
2015b03e692SNicolas Vasilache //   * `fullLocalView` is a dense contiguous view into that buffer.
2025b03e692SNicolas Vasilache //   * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
2035b03e692SNicolas Vasilache //     that corresponds to the size of `subView` and accounting for boundary
2045b03e692SNicolas Vasilache //     effects.
2055b03e692SNicolas Vasilache // The point of the full tile buffer is that constant static tile sizes are
2065b03e692SNicolas Vasilache // folded and result in a buffer type with statically known size and alignment
2075b03e692SNicolas Vasilache // properties.
2085b03e692SNicolas Vasilache // To account for general boundary effects, padding must be performed on the
2095b03e692SNicolas Vasilache // boundary tiles. For now this is done with an unconditional `fill` op followed
2105b03e692SNicolas Vasilache // by a partial `copy` op.
promoteSubviewAsNewBuffer(OpBuilder & b,Location loc,memref::SubViewOp subView,const AllocBufferCallbackFn & allocationFn,DataLayout & layout)211489fec27SNicolas Vasilache FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
212e2310704SJulian Gross     OpBuilder &b, Location loc, memref::SubViewOp subView,
2131fc096afSMehdi Amini     const AllocBufferCallbackFn &allocationFn, DataLayout &layout) {
2140bd6390bSNicolas Vasilache   auto viewType = subView.getType();
2155b03e692SNicolas Vasilache   auto rank = viewType.getRank();
21605d5125dSNicolas Vasilache   SmallVector<Value, 4> fullSizes;
21705d5125dSNicolas Vasilache   SmallVector<OpFoldResult> partialSizes;
2183cb1f35dSNicolas Vasilache   fullSizes.reserve(rank);
2193cb1f35dSNicolas Vasilache   partialSizes.reserve(rank);
22099069ab2SChristopher Bate   llvm::SmallBitVector droppedDims = subView.getDroppedDims();
22199069ab2SChristopher Bate   int64_t resultDimIdx = 0;
222e4853be2SMehdi Amini   for (const auto &en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
22399069ab2SChristopher Bate     if (droppedDims[en.index()])
22499069ab2SChristopher Bate       continue;
2255b03e692SNicolas Vasilache     auto rangeValue = en.value();
2268dbbb223SNicolas Vasilache     // Try to extract a tight constant.
2278dbbb223SNicolas Vasilache     LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
22853da8600STobias Gysi     FailureOr<int64_t> upperBound =
22953da8600STobias Gysi         getConstantUpperBoundForIndex(rangeValue.size);
23053da8600STobias Gysi     Value size =
23153da8600STobias Gysi         failed(upperBound)
23253da8600STobias Gysi             ? rangeValue.size
233c27d8152SKazu Hirata             : b.create<arith::ConstantIndexOp>(loc, upperBound.value());
2348dbbb223SNicolas Vasilache     LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
2353cb1f35dSNicolas Vasilache     fullSizes.push_back(size);
2364519ca3dSNicolas Vasilache     partialSizes.push_back(
23799069ab2SChristopher Bate         b.createOrFold<memref::DimOp>(loc, subView, resultDimIdx++));
2385b03e692SNicolas Vasilache   }
2393cb1f35dSNicolas Vasilache   SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
2400ed2d4c7SMaheshRavishankar   // If a callback is not specified, then use the default implementation for
2410ed2d4c7SMaheshRavishankar   // allocating the promoted buffer.
242ef33c6e3SNicolas Vasilache   Optional<Value> fullLocalView = allocationFn(b, subView, fullSizes, layout);
2430ed2d4c7SMaheshRavishankar   if (!fullLocalView)
244489fec27SNicolas Vasilache     return failure();
24505d5125dSNicolas Vasilache   SmallVector<OpFoldResult, 4> zeros(fullSizes.size(), b.getIndexAttr(0));
24605d5125dSNicolas Vasilache   SmallVector<OpFoldResult, 4> ones(fullSizes.size(), b.getIndexAttr(1));
247ef33c6e3SNicolas Vasilache   auto partialLocalView = b.createOrFold<memref::SubViewOp>(
248ef33c6e3SNicolas Vasilache       loc, *fullLocalView, zeros, partialSizes, ones);
2490ed2d4c7SMaheshRavishankar   return PromotionInfo{*fullLocalView, partialLocalView};
2505b03e692SNicolas Vasilache }
2515b03e692SNicolas Vasilache 
252489fec27SNicolas Vasilache static FailureOr<MapVector<int64_t, PromotionInfo>>
promoteSubViews(ImplicitLocOpBuilder & b,LinalgOpInstancePromotionOptions options,DataLayout & layout)2534519ca3dSNicolas Vasilache promoteSubViews(ImplicitLocOpBuilder &b,
254ef33c6e3SNicolas Vasilache                 LinalgOpInstancePromotionOptions options, DataLayout &layout) {
2558dbbb223SNicolas Vasilache   if (options.subViews.empty())
256489fec27SNicolas Vasilache     return failure();
2575b03e692SNicolas Vasilache 
2584519ca3dSNicolas Vasilache   MapVector<int64_t, PromotionInfo> promotionInfoMap;
2595b03e692SNicolas Vasilache 
2608dbbb223SNicolas Vasilache   for (auto v : options.subViews) {
261e2310704SJulian Gross     memref::SubViewOp subView =
262e2310704SJulian Gross         cast<memref::SubViewOp>(v.second.getDefiningOp());
263489fec27SNicolas Vasilache     auto promotionInfo = promoteSubviewAsNewBuffer(
2644519ca3dSNicolas Vasilache         b, b.getLoc(), subView, options.allocationFn, layout);
265489fec27SNicolas Vasilache     if (failed(promotionInfo))
266489fec27SNicolas Vasilache       return failure();
2670ed2d4c7SMaheshRavishankar     promotionInfoMap[v.first] = *promotionInfo;
2680ed2d4c7SMaheshRavishankar 
269d1866f89SPierre Oechsel     // Only fill the buffer if the full local view is used
2700ed2d4c7SMaheshRavishankar     if (!options.useFullTileBuffers[v.second])
271d1866f89SPierre Oechsel       continue;
2724519ca3dSNicolas Vasilache     Type subviewEltType = subView.getType().getElementType();
2734519ca3dSNicolas Vasilache     Value fillVal =
2744519ca3dSNicolas Vasilache         llvm::TypeSwitch<Type, Value>(subviewEltType)
2754519ca3dSNicolas Vasilache             .Case([&](FloatType t) {
276a54f4eaeSMogball               return b.create<arith::ConstantOp>(FloatAttr::get(t, 0.0));
2774519ca3dSNicolas Vasilache             })
2784519ca3dSNicolas Vasilache             .Case([&](IntegerType t) {
279a54f4eaeSMogball               return b.create<arith::ConstantOp>(IntegerAttr::get(t, 0));
2804519ca3dSNicolas Vasilache             })
2814519ca3dSNicolas Vasilache             .Case([&](ComplexType t) {
2824519ca3dSNicolas Vasilache               Value tmp;
28342e5f422STres Popp               if (auto et = t.getElementType().dyn_cast<FloatType>())
284a54f4eaeSMogball                 tmp = b.create<arith::ConstantOp>(FloatAttr::get(et, 0.0));
28542e5f422STres Popp               else if (auto et = t.getElementType().cast<IntegerType>())
286a54f4eaeSMogball                 tmp = b.create<arith::ConstantOp>(IntegerAttr::get(et, 0));
2874519ca3dSNicolas Vasilache               return b.create<complex::CreateOp>(t, tmp, tmp);
2884519ca3dSNicolas Vasilache             })
2894519ca3dSNicolas Vasilache             .Default([](auto) { return Value(); });
2904519ca3dSNicolas Vasilache     if (!fillVal)
291489fec27SNicolas Vasilache       return failure();
2927cef24eeSTobias Gysi     b.create<linalg::FillOp>(fillVal, promotionInfo->fullLocalView);
2935b03e692SNicolas Vasilache   }
2945b03e692SNicolas Vasilache 
2950ed2d4c7SMaheshRavishankar   // Copy data into the promoted buffers. Use callback if provided.
2968dbbb223SNicolas Vasilache   for (auto v : options.subViews) {
2970ed2d4c7SMaheshRavishankar     auto info = promotionInfoMap.find(v.first);
2985b03e692SNicolas Vasilache     if (info == promotionInfoMap.end())
2995b03e692SNicolas Vasilache       continue;
300e2310704SJulian Gross     if (failed(options.copyInFn(
301e2310704SJulian Gross             b, cast<memref::SubViewOp>(v.second.getDefiningOp()),
3020ed2d4c7SMaheshRavishankar             info->second.partialLocalView)))
303489fec27SNicolas Vasilache       return failure();
3045b03e692SNicolas Vasilache   }
3050ed2d4c7SMaheshRavishankar   return promotionInfoMap;
3065b03e692SNicolas Vasilache }
3075b03e692SNicolas Vasilache 
308489fec27SNicolas Vasilache static FailureOr<LinalgOp>
promoteSubViews(ImplicitLocOpBuilder & b,LinalgOp op,LinalgOpInstancePromotionOptions options,DataLayout & layout)3094519ca3dSNicolas Vasilache promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
310ef33c6e3SNicolas Vasilache                 LinalgOpInstancePromotionOptions options, DataLayout &layout) {
311f52d7173SNicolas Vasilache   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
312f52d7173SNicolas Vasilache 
3135b03e692SNicolas Vasilache   // 1. Promote the specified views and use them in the new op.
3144519ca3dSNicolas Vasilache   auto promotedBuffersAndViews = promoteSubViews(b, options, layout);
315489fec27SNicolas Vasilache   if (failed(promotedBuffersAndViews) ||
3160ed2d4c7SMaheshRavishankar       promotedBuffersAndViews->size() != options.subViews.size())
317489fec27SNicolas Vasilache     return failure();
3185b03e692SNicolas Vasilache 
3195b03e692SNicolas Vasilache   // 2. Append all other operands as they appear, this enforces that such
3205b03e692SNicolas Vasilache   // operands are not views. This is to support cases such as FillOp taking
3210ed2d4c7SMaheshRavishankar   // extra scalars etc.  Keep a reference to output buffers;
3220ed2d4c7SMaheshRavishankar   SmallVector<Value, 8> opViews;
323e70d2c8eSTobias Gysi   opViews.reserve(op.getNumInputsAndOutputs());
3240ed2d4c7SMaheshRavishankar   SmallVector<std::pair<Value, Value>, 8> writebackViews;
3250ed2d4c7SMaheshRavishankar   writebackViews.reserve(promotedBuffersAndViews->size());
326e70d2c8eSTobias Gysi   for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
327e70d2c8eSTobias Gysi     int64_t operandNumber = opOperand->getOperandNumber();
328e70d2c8eSTobias Gysi     if (options.subViews.count(operandNumber) != 0) {
329e70d2c8eSTobias Gysi       if (options.useFullTileBuffers[opOperand->get()])
3300ed2d4c7SMaheshRavishankar         opViews.push_back(
331e70d2c8eSTobias Gysi             (*promotedBuffersAndViews)[operandNumber].fullLocalView);
3320ed2d4c7SMaheshRavishankar       else
3330ed2d4c7SMaheshRavishankar         opViews.push_back(
334e70d2c8eSTobias Gysi             (*promotedBuffersAndViews)[operandNumber].partialLocalView);
335e70d2c8eSTobias Gysi       if (operandNumber >= op.getNumInputs())
3360ed2d4c7SMaheshRavishankar         writebackViews.emplace_back(std::make_pair(
337e70d2c8eSTobias Gysi             opOperand->get(),
338e70d2c8eSTobias Gysi             (*promotedBuffersAndViews)[operandNumber].partialLocalView));
3390ed2d4c7SMaheshRavishankar     } else {
340e70d2c8eSTobias Gysi       opViews.push_back(opOperand->get());
3410ed2d4c7SMaheshRavishankar     }
3420ed2d4c7SMaheshRavishankar   }
343c4a04059SChristian Sigg   op->setOperands(0, opViews.size(), opViews);
3445b03e692SNicolas Vasilache 
3458dbbb223SNicolas Vasilache   OpBuilder::InsertionGuard guard(b);
3468dbbb223SNicolas Vasilache   b.setInsertionPointAfter(op);
3475b03e692SNicolas Vasilache   // 3. Emit write-back for the promoted output views: copy the partial view.
3480ed2d4c7SMaheshRavishankar   for (auto viewAndPartialLocalView : writebackViews) {
3490ed2d4c7SMaheshRavishankar     if (failed(options.copyOutFn(b, viewAndPartialLocalView.second,
3500ed2d4c7SMaheshRavishankar                                  viewAndPartialLocalView.first)))
351489fec27SNicolas Vasilache       return failure();
3520ed2d4c7SMaheshRavishankar   }
3535b03e692SNicolas Vasilache 
3548dbbb223SNicolas Vasilache   // 4. Dealloc all local buffers.
3557d9518c8SNicolas Vasilache   for (const auto &pi : *promotedBuffersAndViews)
356e21adfa3SRiver Riddle     (void)options.deallocationFn(b, pi.second.fullLocalView);
3570ed2d4c7SMaheshRavishankar   return op;
3585b03e692SNicolas Vasilache }
3595b03e692SNicolas Vasilache 
3608dbbb223SNicolas Vasilache LogicalResult
promoteSubviewsPrecondition(Operation * op,LinalgPromotionOptions options)3618dbbb223SNicolas Vasilache mlir::linalg::promoteSubviewsPrecondition(Operation *op,
3628dbbb223SNicolas Vasilache                                           LinalgPromotionOptions options) {
363e70d2c8eSTobias Gysi   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
364307cfdf5SNicolas Vasilache   // Transformation applies to buffers only.
365e70d2c8eSTobias Gysi   if (!linalgOp || !linalgOp.hasBufferSemantics())
366307cfdf5SNicolas Vasilache     return failure();
3678dbbb223SNicolas Vasilache   // Check that at least one of the requested operands is indeed a subview.
368e70d2c8eSTobias Gysi   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
369e70d2c8eSTobias Gysi     auto sv =
370e70d2c8eSTobias Gysi         isa_and_nonnull<memref::SubViewOp>(opOperand->get().getDefiningOp());
3718dbbb223SNicolas Vasilache     if (sv) {
372037f0995SKazu Hirata       if (!options.operandsToPromote ||
373e70d2c8eSTobias Gysi           options.operandsToPromote->count(opOperand->getOperandNumber()))
374307cfdf5SNicolas Vasilache         return success();
375307cfdf5SNicolas Vasilache     }
3768dbbb223SNicolas Vasilache   }
3778dbbb223SNicolas Vasilache   // TODO: Check all subviews requested are bound by a static constant.
3788dbbb223SNicolas Vasilache   // TODO: Check that the total footprint fits within a given size.
379307cfdf5SNicolas Vasilache   return failure();
380307cfdf5SNicolas Vasilache }
381307cfdf5SNicolas Vasilache 
382489fec27SNicolas Vasilache FailureOr<LinalgOp>
promoteSubViews(OpBuilder & builder,LinalgOp linalgOp,const LinalgPromotionOptions & options)3834519ca3dSNicolas Vasilache mlir::linalg::promoteSubViews(OpBuilder &builder, LinalgOp linalgOp,
3841fc096afSMehdi Amini                               const LinalgPromotionOptions &options) {
3858dbbb223SNicolas Vasilache   LinalgOpInstancePromotionOptions linalgOptions(linalgOp, options);
38642e5f422STres Popp   auto layout = DataLayout::closest(linalgOp);
3874519ca3dSNicolas Vasilache   ImplicitLocOpBuilder b(linalgOp.getLoc(), builder);
388489fec27SNicolas Vasilache   auto res = ::promoteSubViews(b, linalgOp, linalgOptions, layout);
389489fec27SNicolas Vasilache   if (failed(res))
390489fec27SNicolas Vasilache     return failure();
391489fec27SNicolas Vasilache   return res;
3928dbbb223SNicolas Vasilache }
393