15b03e692SNicolas Vasilache //===- Promotion.cpp - Implementation of linalg Promotion -----------------===//
25b03e692SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65b03e692SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85b03e692SNicolas Vasilache //
95b03e692SNicolas Vasilache // This file implements the linalg dialect Promotion pass.
105b03e692SNicolas Vasilache //
115b03e692SNicolas Vasilache //===----------------------------------------------------------------------===//
125b03e692SNicolas Vasilache
131834ad4aSRiver Riddle #include "PassDetail.h"
14a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1542e5f422STres Popp #include "mlir/Dialect/Complex/IR/Complex.h"
16b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
175b03e692SNicolas Vasilache #include "mlir/Dialect/Linalg/Passes.h"
18307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
195b03e692SNicolas Vasilache #include "mlir/Dialect/Linalg/Utils/Utils.h"
208b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
215b03e692SNicolas Vasilache #include "mlir/IR/AffineExpr.h"
225b03e692SNicolas Vasilache #include "mlir/IR/AffineExprVisitor.h"
235b03e692SNicolas Vasilache #include "mlir/IR/AffineMap.h"
24ef33c6e3SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
255b03e692SNicolas Vasilache #include "mlir/Support/LLVM.h"
265b03e692SNicolas Vasilache #include "mlir/Transforms/FoldUtils.h"
270ed2d4c7SMaheshRavishankar #include "llvm/ADT/MapVector.h"
2899069ab2SChristopher Bate #include "llvm/ADT/SmallBitVector.h"
294519ca3dSNicolas Vasilache #include "llvm/ADT/TypeSwitch.h"
305b03e692SNicolas Vasilache #include "llvm/Support/CommandLine.h"
31f89bb3c0SAlexander Belyaev #include "llvm/Support/Debug.h"
325b03e692SNicolas Vasilache
335b03e692SNicolas Vasilache using namespace mlir;
345b03e692SNicolas Vasilache using namespace mlir::linalg;
35c25b20c0SAlex Zinenko using namespace mlir::scf;
365b03e692SNicolas Vasilache
370ed2d4c7SMaheshRavishankar using llvm::MapVector;
385b03e692SNicolas Vasilache
395b03e692SNicolas Vasilache #define DEBUG_TYPE "linalg-promotion"
405b03e692SNicolas Vasilache
41ef33c6e3SNicolas Vasilache /// Alloc a new buffer of `size` * `width` i8; where `width` is given by the
42ef33c6e3SNicolas Vasilache /// data `layout` for `elementType`.
43ef33c6e3SNicolas Vasilache /// Use AllocOp or AllocaOp depending on `options`.
44ef33c6e3SNicolas Vasilache /// Take an optional alignment.
allocBuffer(ImplicitLocOpBuilder & b,const LinalgPromotionOptions & options,Type elementType,Value allocSize,DataLayout & layout,Optional<unsigned> alignment=None)45ef33c6e3SNicolas Vasilache static Value allocBuffer(ImplicitLocOpBuilder &b,
46ef33c6e3SNicolas Vasilache const LinalgPromotionOptions &options,
47ef33c6e3SNicolas Vasilache Type elementType, Value allocSize, DataLayout &layout,
488dbbb223SNicolas Vasilache Optional<unsigned> alignment = None) {
4942e5f422STres Popp auto width = layout.getTypeSize(elementType);
50ef33c6e3SNicolas Vasilache
51ef33c6e3SNicolas Vasilache IntegerAttr alignmentAttr;
52491d2701SKazu Hirata if (alignment.has_value())
53c27d8152SKazu Hirata alignmentAttr = b.getI64IntegerAttr(alignment.value());
54ef33c6e3SNicolas Vasilache
55ef33c6e3SNicolas Vasilache // Static buffer.
56a54f4eaeSMogball if (auto cst = allocSize.getDefiningOp<arith::ConstantIndexOp>()) {
57ef33c6e3SNicolas Vasilache auto staticBufferType =
58a54f4eaeSMogball MemRefType::get(width * cst.value(), b.getIntegerType(8));
59ef33c6e3SNicolas Vasilache if (options.useAlloca) {
60ef33c6e3SNicolas Vasilache return b.createOrFold<memref::AllocaOp>(staticBufferType, ValueRange{},
61ef33c6e3SNicolas Vasilache alignmentAttr);
62ef33c6e3SNicolas Vasilache }
63ef33c6e3SNicolas Vasilache return b.createOrFold<memref::AllocOp>(staticBufferType, ValueRange{},
64ef33c6e3SNicolas Vasilache alignmentAttr);
65ef33c6e3SNicolas Vasilache }
66ef33c6e3SNicolas Vasilache
67ef33c6e3SNicolas Vasilache // Fallback dynamic buffer.
68ef33c6e3SNicolas Vasilache auto dynamicBufferType = MemRefType::get(-1, b.getIntegerType(8));
69a54f4eaeSMogball Value mul = b.createOrFold<arith::MulIOp>(
70a54f4eaeSMogball b.create<arith::ConstantIndexOp>(width), allocSize);
71ef33c6e3SNicolas Vasilache if (options.useAlloca)
72ef33c6e3SNicolas Vasilache return b.create<memref::AllocaOp>(dynamicBufferType, mul, alignmentAttr);
73ef33c6e3SNicolas Vasilache return b.create<memref::AllocOp>(dynamicBufferType, mul, alignmentAttr);
745b03e692SNicolas Vasilache }
755b03e692SNicolas Vasilache
760ed2d4c7SMaheshRavishankar /// Default allocation callback function. This allocates a promoted buffer when
770ed2d4c7SMaheshRavishankar /// no call back to do so is provided. The default is to allocate a
780ed2d4c7SMaheshRavishankar /// memref<..xi8> and return a view to get a memref type of shape
790ed2d4c7SMaheshRavishankar /// boundingSubViewSize.
804519ca3dSNicolas Vasilache static Optional<Value>
defaultAllocBufferCallBack(const LinalgPromotionOptions & options,OpBuilder & builder,memref::SubViewOp subView,ArrayRef<Value> boundingSubViewSize,Optional<unsigned> alignment,DataLayout & layout)814519ca3dSNicolas Vasilache defaultAllocBufferCallBack(const LinalgPromotionOptions &options,
824519ca3dSNicolas Vasilache OpBuilder &builder, memref::SubViewOp subView,
834519ca3dSNicolas Vasilache ArrayRef<Value> boundingSubViewSize,
844519ca3dSNicolas Vasilache Optional<unsigned> alignment, DataLayout &layout) {
850ed2d4c7SMaheshRavishankar ShapedType viewType = subView.getType();
86ef33c6e3SNicolas Vasilache ImplicitLocOpBuilder b(subView.getLoc(), builder);
87a54f4eaeSMogball auto zero = b.createOrFold<arith::ConstantIndexOp>(0);
88a54f4eaeSMogball auto one = b.createOrFold<arith::ConstantIndexOp>(1);
890ed2d4c7SMaheshRavishankar
900ed2d4c7SMaheshRavishankar Value allocSize = one;
91e4853be2SMehdi Amini for (const auto &size : llvm::enumerate(boundingSubViewSize))
92a54f4eaeSMogball allocSize = b.createOrFold<arith::MulIOp>(allocSize, size.value());
93ef33c6e3SNicolas Vasilache Value buffer = allocBuffer(b, options, viewType.getElementType(), allocSize,
94ef33c6e3SNicolas Vasilache layout, alignment);
950ed2d4c7SMaheshRavishankar SmallVector<int64_t, 4> dynSizes(boundingSubViewSize.size(),
960ed2d4c7SMaheshRavishankar ShapedType::kDynamicSize);
97ef33c6e3SNicolas Vasilache Value view = b.createOrFold<memref::ViewOp>(
98ef33c6e3SNicolas Vasilache MemRefType::get(dynSizes, viewType.getElementType()), buffer, zero,
99ef33c6e3SNicolas Vasilache boundingSubViewSize);
1000ed2d4c7SMaheshRavishankar return view;
1010ed2d4c7SMaheshRavishankar }
1020ed2d4c7SMaheshRavishankar
1030ed2d4c7SMaheshRavishankar /// Default implementation of deallocation of the buffer use for promotion. It
1040ed2d4c7SMaheshRavishankar /// expects to get the same value that the default allocation method returned,
1050ed2d4c7SMaheshRavishankar /// i.e. result of a ViewOp.
1067d9518c8SNicolas Vasilache static LogicalResult
defaultDeallocBufferCallBack(const LinalgPromotionOptions & options,OpBuilder & b,Value fullLocalView)1077d9518c8SNicolas Vasilache defaultDeallocBufferCallBack(const LinalgPromotionOptions &options,
1087d9518c8SNicolas Vasilache OpBuilder &b, Value fullLocalView) {
1094519ca3dSNicolas Vasilache if (!options.useAlloca) {
1104519ca3dSNicolas Vasilache auto viewOp = cast<memref::ViewOp>(fullLocalView.getDefiningOp());
111136d746eSJacques Pienaar b.create<memref::DeallocOp>(viewOp.getSource().getLoc(),
112136d746eSJacques Pienaar viewOp.getSource());
1134519ca3dSNicolas Vasilache }
1140ed2d4c7SMaheshRavishankar return success();
1150ed2d4c7SMaheshRavishankar }
1160ed2d4c7SMaheshRavishankar
1170ed2d4c7SMaheshRavishankar namespace {
1180ed2d4c7SMaheshRavishankar
1190ed2d4c7SMaheshRavishankar /// Helper struct that captures the information required to apply the
1200ed2d4c7SMaheshRavishankar /// transformation on each op. This bridges the abstraction gap with the
1210ed2d4c7SMaheshRavishankar /// user-facing API which exposes positional arguments to control which operands
1220ed2d4c7SMaheshRavishankar /// are promoted.
1230ed2d4c7SMaheshRavishankar struct LinalgOpInstancePromotionOptions {
1240ed2d4c7SMaheshRavishankar LinalgOpInstancePromotionOptions(LinalgOp op,
1250ed2d4c7SMaheshRavishankar const LinalgPromotionOptions &options);
1260ed2d4c7SMaheshRavishankar /// SubViews to promote.
1274519ca3dSNicolas Vasilache MapVector<int64_t, Value> subViews;
1280ed2d4c7SMaheshRavishankar /// True if the full view should be used for the promoted buffer.
1290ed2d4c7SMaheshRavishankar DenseMap<Value, bool> useFullTileBuffers;
1300ed2d4c7SMaheshRavishankar
1310ed2d4c7SMaheshRavishankar /// Callback functions for allocation and deallocation of promoted buffers, as
1320ed2d4c7SMaheshRavishankar /// well as to copy the data into and out of these buffers.
1330ed2d4c7SMaheshRavishankar AllocBufferCallbackFn allocationFn;
1340ed2d4c7SMaheshRavishankar DeallocBufferCallbackFn deallocationFn;
1350ed2d4c7SMaheshRavishankar CopyCallbackFn copyInFn;
1360ed2d4c7SMaheshRavishankar CopyCallbackFn copyOutFn;
1370ed2d4c7SMaheshRavishankar
1380ed2d4c7SMaheshRavishankar /// Alignment of promoted buffer.
1390ed2d4c7SMaheshRavishankar Optional<unsigned> alignment;
1400ed2d4c7SMaheshRavishankar };
1410ed2d4c7SMaheshRavishankar } // namespace
1420ed2d4c7SMaheshRavishankar
LinalgOpInstancePromotionOptions(LinalgOp linalgOp,const LinalgPromotionOptions & options)1430ed2d4c7SMaheshRavishankar LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
1440ed2d4c7SMaheshRavishankar LinalgOp linalgOp, const LinalgPromotionOptions &options)
145*5a001136SNicolas Vasilache : subViews(), alignment(options.alignment) {
146b7ae1d3dSnicolasvasilache assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand");
1470ed2d4c7SMaheshRavishankar auto vUseFullTileBuffers =
14830c67587SKazu Hirata options.useFullTileBuffers.value_or(llvm::SmallBitVector());
149e70d2c8eSTobias Gysi vUseFullTileBuffers.resize(linalgOp.getNumInputsAndOutputs(),
150e70d2c8eSTobias Gysi options.useFullTileBuffersDefault);
1510ed2d4c7SMaheshRavishankar
152e70d2c8eSTobias Gysi for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
153e70d2c8eSTobias Gysi int64_t operandNumber = opOperand->getOperandNumber();
154e70d2c8eSTobias Gysi if (options.operandsToPromote &&
155e70d2c8eSTobias Gysi !options.operandsToPromote->count(operandNumber))
1560ed2d4c7SMaheshRavishankar continue;
157e70d2c8eSTobias Gysi Operation *op = opOperand->get().getDefiningOp();
158e2310704SJulian Gross if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
159e70d2c8eSTobias Gysi subViews[operandNumber] = sv;
160e70d2c8eSTobias Gysi useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
1610ed2d4c7SMaheshRavishankar }
1620ed2d4c7SMaheshRavishankar }
1630ed2d4c7SMaheshRavishankar
1644519ca3dSNicolas Vasilache if (options.allocationFn) {
1654519ca3dSNicolas Vasilache allocationFn = *options.allocationFn;
1664519ca3dSNicolas Vasilache } else {
1674519ca3dSNicolas Vasilache allocationFn = [&](OpBuilder &b, memref::SubViewOp subViewOp,
168ef33c6e3SNicolas Vasilache ArrayRef<Value> boundingSubViewSize,
169ef33c6e3SNicolas Vasilache DataLayout &layout) -> Optional<Value> {
1704519ca3dSNicolas Vasilache return defaultAllocBufferCallBack(options, b, subViewOp,
1714519ca3dSNicolas Vasilache boundingSubViewSize, alignment, layout);
1724519ca3dSNicolas Vasilache };
1734519ca3dSNicolas Vasilache }
1744519ca3dSNicolas Vasilache
1754519ca3dSNicolas Vasilache if (options.deallocationFn) {
1764519ca3dSNicolas Vasilache deallocationFn = *options.deallocationFn;
1774519ca3dSNicolas Vasilache } else {
1784519ca3dSNicolas Vasilache deallocationFn = [&](OpBuilder &b, Value buffer) {
1797d9518c8SNicolas Vasilache return defaultDeallocBufferCallBack(options, b, buffer);
1804519ca3dSNicolas Vasilache };
1814519ca3dSNicolas Vasilache }
1824519ca3dSNicolas Vasilache
1834519ca3dSNicolas Vasilache // Save the loc because `linalgOp` goes out of scope.
1844519ca3dSNicolas Vasilache Location loc = linalgOp.getLoc();
1854519ca3dSNicolas Vasilache auto defaultCopyCallBack = [loc](OpBuilder &b, Value src,
1860ed2d4c7SMaheshRavishankar Value dst) -> LogicalResult {
187ebc81537SAlexander Belyaev b.create<memref::CopyOp>(loc, src, dst);
1880ed2d4c7SMaheshRavishankar return success();
1890ed2d4c7SMaheshRavishankar };
1900ed2d4c7SMaheshRavishankar copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack);
1910ed2d4c7SMaheshRavishankar copyOutFn = (options.copyOutFn ? *(options.copyOutFn) : defaultCopyCallBack);
1920ed2d4c7SMaheshRavishankar }
1930ed2d4c7SMaheshRavishankar
1945b03e692SNicolas Vasilache // Performs promotion of a `subView` into a local buffer of the size of the
1955b03e692SNicolas Vasilache // *ranges* of the `subView`. This produces a buffer whose size may be bigger
1965b03e692SNicolas Vasilache // than the actual size of the `subView` at the boundaries.
1975b03e692SNicolas Vasilache // This is related to the full/partial tile problem.
1985b03e692SNicolas Vasilache // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
1995b03e692SNicolas Vasilache // `partialLocalView` such that:
2005b03e692SNicolas Vasilache // * `buffer` is always the size of the full tile.
2015b03e692SNicolas Vasilache // * `fullLocalView` is a dense contiguous view into that buffer.
2025b03e692SNicolas Vasilache // * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
2035b03e692SNicolas Vasilache // that corresponds to the size of `subView` and accounting for boundary
2045b03e692SNicolas Vasilache // effects.
2055b03e692SNicolas Vasilache // The point of the full tile buffer is that constant static tile sizes are
2065b03e692SNicolas Vasilache // folded and result in a buffer type with statically known size and alignment
2075b03e692SNicolas Vasilache // properties.
2085b03e692SNicolas Vasilache // To account for general boundary effects, padding must be performed on the
2095b03e692SNicolas Vasilache // boundary tiles. For now this is done with an unconditional `fill` op followed
2105b03e692SNicolas Vasilache // by a partial `copy` op.
promoteSubviewAsNewBuffer(OpBuilder & b,Location loc,memref::SubViewOp subView,const AllocBufferCallbackFn & allocationFn,DataLayout & layout)211489fec27SNicolas Vasilache FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
212e2310704SJulian Gross OpBuilder &b, Location loc, memref::SubViewOp subView,
2131fc096afSMehdi Amini const AllocBufferCallbackFn &allocationFn, DataLayout &layout) {
2140bd6390bSNicolas Vasilache auto viewType = subView.getType();
2155b03e692SNicolas Vasilache auto rank = viewType.getRank();
21605d5125dSNicolas Vasilache SmallVector<Value, 4> fullSizes;
21705d5125dSNicolas Vasilache SmallVector<OpFoldResult> partialSizes;
2183cb1f35dSNicolas Vasilache fullSizes.reserve(rank);
2193cb1f35dSNicolas Vasilache partialSizes.reserve(rank);
22099069ab2SChristopher Bate llvm::SmallBitVector droppedDims = subView.getDroppedDims();
22199069ab2SChristopher Bate int64_t resultDimIdx = 0;
222e4853be2SMehdi Amini for (const auto &en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
22399069ab2SChristopher Bate if (droppedDims[en.index()])
22499069ab2SChristopher Bate continue;
2255b03e692SNicolas Vasilache auto rangeValue = en.value();
2268dbbb223SNicolas Vasilache // Try to extract a tight constant.
2278dbbb223SNicolas Vasilache LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
22853da8600STobias Gysi FailureOr<int64_t> upperBound =
22953da8600STobias Gysi getConstantUpperBoundForIndex(rangeValue.size);
23053da8600STobias Gysi Value size =
23153da8600STobias Gysi failed(upperBound)
23253da8600STobias Gysi ? rangeValue.size
233c27d8152SKazu Hirata : b.create<arith::ConstantIndexOp>(loc, upperBound.value());
2348dbbb223SNicolas Vasilache LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
2353cb1f35dSNicolas Vasilache fullSizes.push_back(size);
2364519ca3dSNicolas Vasilache partialSizes.push_back(
23799069ab2SChristopher Bate b.createOrFold<memref::DimOp>(loc, subView, resultDimIdx++));
2385b03e692SNicolas Vasilache }
2393cb1f35dSNicolas Vasilache SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
2400ed2d4c7SMaheshRavishankar // If a callback is not specified, then use the default implementation for
2410ed2d4c7SMaheshRavishankar // allocating the promoted buffer.
242ef33c6e3SNicolas Vasilache Optional<Value> fullLocalView = allocationFn(b, subView, fullSizes, layout);
2430ed2d4c7SMaheshRavishankar if (!fullLocalView)
244489fec27SNicolas Vasilache return failure();
24505d5125dSNicolas Vasilache SmallVector<OpFoldResult, 4> zeros(fullSizes.size(), b.getIndexAttr(0));
24605d5125dSNicolas Vasilache SmallVector<OpFoldResult, 4> ones(fullSizes.size(), b.getIndexAttr(1));
247ef33c6e3SNicolas Vasilache auto partialLocalView = b.createOrFold<memref::SubViewOp>(
248ef33c6e3SNicolas Vasilache loc, *fullLocalView, zeros, partialSizes, ones);
2490ed2d4c7SMaheshRavishankar return PromotionInfo{*fullLocalView, partialLocalView};
2505b03e692SNicolas Vasilache }
2515b03e692SNicolas Vasilache
252489fec27SNicolas Vasilache static FailureOr<MapVector<int64_t, PromotionInfo>>
promoteSubViews(ImplicitLocOpBuilder & b,LinalgOpInstancePromotionOptions options,DataLayout & layout)2534519ca3dSNicolas Vasilache promoteSubViews(ImplicitLocOpBuilder &b,
254ef33c6e3SNicolas Vasilache LinalgOpInstancePromotionOptions options, DataLayout &layout) {
2558dbbb223SNicolas Vasilache if (options.subViews.empty())
256489fec27SNicolas Vasilache return failure();
2575b03e692SNicolas Vasilache
2584519ca3dSNicolas Vasilache MapVector<int64_t, PromotionInfo> promotionInfoMap;
2595b03e692SNicolas Vasilache
2608dbbb223SNicolas Vasilache for (auto v : options.subViews) {
261e2310704SJulian Gross memref::SubViewOp subView =
262e2310704SJulian Gross cast<memref::SubViewOp>(v.second.getDefiningOp());
263489fec27SNicolas Vasilache auto promotionInfo = promoteSubviewAsNewBuffer(
2644519ca3dSNicolas Vasilache b, b.getLoc(), subView, options.allocationFn, layout);
265489fec27SNicolas Vasilache if (failed(promotionInfo))
266489fec27SNicolas Vasilache return failure();
2670ed2d4c7SMaheshRavishankar promotionInfoMap[v.first] = *promotionInfo;
2680ed2d4c7SMaheshRavishankar
269d1866f89SPierre Oechsel // Only fill the buffer if the full local view is used
2700ed2d4c7SMaheshRavishankar if (!options.useFullTileBuffers[v.second])
271d1866f89SPierre Oechsel continue;
2724519ca3dSNicolas Vasilache Type subviewEltType = subView.getType().getElementType();
2734519ca3dSNicolas Vasilache Value fillVal =
2744519ca3dSNicolas Vasilache llvm::TypeSwitch<Type, Value>(subviewEltType)
2754519ca3dSNicolas Vasilache .Case([&](FloatType t) {
276a54f4eaeSMogball return b.create<arith::ConstantOp>(FloatAttr::get(t, 0.0));
2774519ca3dSNicolas Vasilache })
2784519ca3dSNicolas Vasilache .Case([&](IntegerType t) {
279a54f4eaeSMogball return b.create<arith::ConstantOp>(IntegerAttr::get(t, 0));
2804519ca3dSNicolas Vasilache })
2814519ca3dSNicolas Vasilache .Case([&](ComplexType t) {
2824519ca3dSNicolas Vasilache Value tmp;
28342e5f422STres Popp if (auto et = t.getElementType().dyn_cast<FloatType>())
284a54f4eaeSMogball tmp = b.create<arith::ConstantOp>(FloatAttr::get(et, 0.0));
28542e5f422STres Popp else if (auto et = t.getElementType().cast<IntegerType>())
286a54f4eaeSMogball tmp = b.create<arith::ConstantOp>(IntegerAttr::get(et, 0));
2874519ca3dSNicolas Vasilache return b.create<complex::CreateOp>(t, tmp, tmp);
2884519ca3dSNicolas Vasilache })
2894519ca3dSNicolas Vasilache .Default([](auto) { return Value(); });
2904519ca3dSNicolas Vasilache if (!fillVal)
291489fec27SNicolas Vasilache return failure();
2927cef24eeSTobias Gysi b.create<linalg::FillOp>(fillVal, promotionInfo->fullLocalView);
2935b03e692SNicolas Vasilache }
2945b03e692SNicolas Vasilache
2950ed2d4c7SMaheshRavishankar // Copy data into the promoted buffers. Use callback if provided.
2968dbbb223SNicolas Vasilache for (auto v : options.subViews) {
2970ed2d4c7SMaheshRavishankar auto info = promotionInfoMap.find(v.first);
2985b03e692SNicolas Vasilache if (info == promotionInfoMap.end())
2995b03e692SNicolas Vasilache continue;
300e2310704SJulian Gross if (failed(options.copyInFn(
301e2310704SJulian Gross b, cast<memref::SubViewOp>(v.second.getDefiningOp()),
3020ed2d4c7SMaheshRavishankar info->second.partialLocalView)))
303489fec27SNicolas Vasilache return failure();
3045b03e692SNicolas Vasilache }
3050ed2d4c7SMaheshRavishankar return promotionInfoMap;
3065b03e692SNicolas Vasilache }
3075b03e692SNicolas Vasilache
308489fec27SNicolas Vasilache static FailureOr<LinalgOp>
promoteSubViews(ImplicitLocOpBuilder & b,LinalgOp op,LinalgOpInstancePromotionOptions options,DataLayout & layout)3094519ca3dSNicolas Vasilache promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
310ef33c6e3SNicolas Vasilache LinalgOpInstancePromotionOptions options, DataLayout &layout) {
311f52d7173SNicolas Vasilache assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
312f52d7173SNicolas Vasilache
3135b03e692SNicolas Vasilache // 1. Promote the specified views and use them in the new op.
3144519ca3dSNicolas Vasilache auto promotedBuffersAndViews = promoteSubViews(b, options, layout);
315489fec27SNicolas Vasilache if (failed(promotedBuffersAndViews) ||
3160ed2d4c7SMaheshRavishankar promotedBuffersAndViews->size() != options.subViews.size())
317489fec27SNicolas Vasilache return failure();
3185b03e692SNicolas Vasilache
3195b03e692SNicolas Vasilache // 2. Append all other operands as they appear, this enforces that such
3205b03e692SNicolas Vasilache // operands are not views. This is to support cases such as FillOp taking
3210ed2d4c7SMaheshRavishankar // extra scalars etc. Keep a reference to output buffers;
3220ed2d4c7SMaheshRavishankar SmallVector<Value, 8> opViews;
323e70d2c8eSTobias Gysi opViews.reserve(op.getNumInputsAndOutputs());
3240ed2d4c7SMaheshRavishankar SmallVector<std::pair<Value, Value>, 8> writebackViews;
3250ed2d4c7SMaheshRavishankar writebackViews.reserve(promotedBuffersAndViews->size());
326e70d2c8eSTobias Gysi for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
327e70d2c8eSTobias Gysi int64_t operandNumber = opOperand->getOperandNumber();
328e70d2c8eSTobias Gysi if (options.subViews.count(operandNumber) != 0) {
329e70d2c8eSTobias Gysi if (options.useFullTileBuffers[opOperand->get()])
3300ed2d4c7SMaheshRavishankar opViews.push_back(
331e70d2c8eSTobias Gysi (*promotedBuffersAndViews)[operandNumber].fullLocalView);
3320ed2d4c7SMaheshRavishankar else
3330ed2d4c7SMaheshRavishankar opViews.push_back(
334e70d2c8eSTobias Gysi (*promotedBuffersAndViews)[operandNumber].partialLocalView);
335e70d2c8eSTobias Gysi if (operandNumber >= op.getNumInputs())
3360ed2d4c7SMaheshRavishankar writebackViews.emplace_back(std::make_pair(
337e70d2c8eSTobias Gysi opOperand->get(),
338e70d2c8eSTobias Gysi (*promotedBuffersAndViews)[operandNumber].partialLocalView));
3390ed2d4c7SMaheshRavishankar } else {
340e70d2c8eSTobias Gysi opViews.push_back(opOperand->get());
3410ed2d4c7SMaheshRavishankar }
3420ed2d4c7SMaheshRavishankar }
343c4a04059SChristian Sigg op->setOperands(0, opViews.size(), opViews);
3445b03e692SNicolas Vasilache
3458dbbb223SNicolas Vasilache OpBuilder::InsertionGuard guard(b);
3468dbbb223SNicolas Vasilache b.setInsertionPointAfter(op);
3475b03e692SNicolas Vasilache // 3. Emit write-back for the promoted output views: copy the partial view.
3480ed2d4c7SMaheshRavishankar for (auto viewAndPartialLocalView : writebackViews) {
3490ed2d4c7SMaheshRavishankar if (failed(options.copyOutFn(b, viewAndPartialLocalView.second,
3500ed2d4c7SMaheshRavishankar viewAndPartialLocalView.first)))
351489fec27SNicolas Vasilache return failure();
3520ed2d4c7SMaheshRavishankar }
3535b03e692SNicolas Vasilache
3548dbbb223SNicolas Vasilache // 4. Dealloc all local buffers.
3557d9518c8SNicolas Vasilache for (const auto &pi : *promotedBuffersAndViews)
356e21adfa3SRiver Riddle (void)options.deallocationFn(b, pi.second.fullLocalView);
3570ed2d4c7SMaheshRavishankar return op;
3585b03e692SNicolas Vasilache }
3595b03e692SNicolas Vasilache
3608dbbb223SNicolas Vasilache LogicalResult
promoteSubviewsPrecondition(Operation * op,LinalgPromotionOptions options)3618dbbb223SNicolas Vasilache mlir::linalg::promoteSubviewsPrecondition(Operation *op,
3628dbbb223SNicolas Vasilache LinalgPromotionOptions options) {
363e70d2c8eSTobias Gysi LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
364307cfdf5SNicolas Vasilache // Transformation applies to buffers only.
365e70d2c8eSTobias Gysi if (!linalgOp || !linalgOp.hasBufferSemantics())
366307cfdf5SNicolas Vasilache return failure();
3678dbbb223SNicolas Vasilache // Check that at least one of the requested operands is indeed a subview.
368e70d2c8eSTobias Gysi for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
369e70d2c8eSTobias Gysi auto sv =
370e70d2c8eSTobias Gysi isa_and_nonnull<memref::SubViewOp>(opOperand->get().getDefiningOp());
3718dbbb223SNicolas Vasilache if (sv) {
372037f0995SKazu Hirata if (!options.operandsToPromote ||
373e70d2c8eSTobias Gysi options.operandsToPromote->count(opOperand->getOperandNumber()))
374307cfdf5SNicolas Vasilache return success();
375307cfdf5SNicolas Vasilache }
3768dbbb223SNicolas Vasilache }
3778dbbb223SNicolas Vasilache // TODO: Check all subviews requested are bound by a static constant.
3788dbbb223SNicolas Vasilache // TODO: Check that the total footprint fits within a given size.
379307cfdf5SNicolas Vasilache return failure();
380307cfdf5SNicolas Vasilache }
381307cfdf5SNicolas Vasilache
382489fec27SNicolas Vasilache FailureOr<LinalgOp>
promoteSubViews(OpBuilder & builder,LinalgOp linalgOp,const LinalgPromotionOptions & options)3834519ca3dSNicolas Vasilache mlir::linalg::promoteSubViews(OpBuilder &builder, LinalgOp linalgOp,
3841fc096afSMehdi Amini const LinalgPromotionOptions &options) {
3858dbbb223SNicolas Vasilache LinalgOpInstancePromotionOptions linalgOptions(linalgOp, options);
38642e5f422STres Popp auto layout = DataLayout::closest(linalgOp);
3874519ca3dSNicolas Vasilache ImplicitLocOpBuilder b(linalgOp.getLoc(), builder);
388489fec27SNicolas Vasilache auto res = ::promoteSubViews(b, linalgOp, linalgOptions, layout);
389489fec27SNicolas Vasilache if (failed(res))
390489fec27SNicolas Vasilache return failure();
391489fec27SNicolas Vasilache return res;
3928dbbb223SNicolas Vasilache }
393