1 //===- StructuredOpsUtils.cpp - Utilities used by structured ops ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 10 #include "mlir/IR/AffineMap.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 13 using namespace mlir; 14 15 bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) { 16 if (indexingMaps.size() != 3) 17 return false; 18 19 auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue(); 20 auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue(); 21 auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue(); 22 23 if (map0.getNumResults() != 2 || map1.getNumResults() != 2 || 24 map2.getNumResults() != 2 || map0.getNumInputs() != 3 || 25 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) { 26 return false; 27 } 28 29 // Extract dimensions for MxK * KxN -> MxN 30 AffineExpr m = map2.getResult(0); 31 AffineExpr n = map2.getResult(1); 32 AffineExpr k = map0.getResult(1); 33 auto *context = indexingMaps.getContext(); 34 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context)); 35 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context)); 36 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context)); 37 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); 38 return indexingMaps == maps; 39 } 40 41 bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) { 42 if (indexingMaps.size() != 3) 43 return false; 44 45 auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue(); 46 auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue(); 47 auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue(); 48 49 if (map0.getNumResults() != 2 || map1.getNumResults() != 2 || 50 map2.getNumResults() != 2 || map0.getNumInputs() != 3 || 51 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) { 52 return false; 53 } 54 55 // Extract dimensions for KxM * NxK -> NxM 56 AffineExpr n = map2.getResult(0); 57 AffineExpr m = map2.getResult(1); 58 AffineExpr k = map0.getResult(0); 59 auto *context = indexingMaps.getContext(); 60 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, context)); 61 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, context)); 62 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context)); 63 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); 64 return indexingMaps == maps; 65 } 66 67 bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) { 68 if (indexingMaps.size() != 3) 69 return false; 70 71 auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue(); 72 auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue(); 73 auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue(); 74 75 if (map0.getNumResults() != 3 || map1.getNumResults() != 3 || 76 map2.getNumResults() != 3 || map0.getNumInputs() != 4 || 77 map1.getNumInputs() != 4 || map2.getNumInputs() != 4) { 78 return false; 79 } 80 81 // Extract dimensions for BxMxK * BxKxN -> BxMxN 82 AffineExpr b = map2.getResult(0); 83 AffineExpr m = map2.getResult(1); 84 AffineExpr n = map2.getResult(2); 85 AffineExpr k = map0.getResult(2); 86 auto *context = indexingMaps.getContext(); 87 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, k}, context)); 88 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {b, k, n}, context)); 89 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, n}, context)); 90 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); 91 return indexingMaps == maps; 92 } 93