1 //===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml  ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an ODS (and C++) generator from a YAML form
10 // derived from the mathematical expression of linalg named ops. Typically a
11 // math oriented DSL will be used to export the essential representation to
12 // this form, and maintaining the SOT at the math level (versus recreating it
13 // in MLIR) is deemed to have systemic value.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/AsmParser/AsmParser.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/Support/FileUtilities.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/Optional.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/ToolOutputFile.h"
29 #include "llvm/Support/YAMLTraits.h"
30 
31 using namespace mlir;
32 
33 using llvm::yaml::Input;
34 using llvm::yaml::MappingTraits;
35 using llvm::yaml::ScalarEnumerationTraits;
36 using llvm::yaml::ScalarTraits;
37 
38 #define DEBUG_TYPE "linalg-ods-gen"
39 
40 //===----------------------------------------------------------------------===//
41 // Mapping structs (correspond to data types in the YAML description).
42 // TODO: Since this is a schema/part of the contract, it should be moved to
43 // a real header.
44 //===----------------------------------------------------------------------===//
45 
namespace {

/// Context object handed to the YAML reader. The affine map scalar parser
/// (`ScalarTraits<SerializedAffineMap>::input`) casts the raw YAML context to
/// this type to obtain the MLIRContext used for parsing.
struct LinalgYAMLContext {
  MLIRContext *mlirContext;
};

/// Op-level metadata (YAML `metadata` mapping): op name, C++ class name,
/// optional documentation, the interfaces the op implements, and the names of
/// ODS fields to emit as `let <name> = 1;`.
struct LinalgOpMetadata {
  std::string name;
  std::string cppClassName;
  Optional<std::string> doc;
  SmallVector<std::string> implements;
  SmallVector<std::string> defines;
};

/// An AffineMapAttr that is read from / written to YAML as its textual MLIR
/// form (see `ScalarTraits<SerializedAffineMap>`).
struct SerializedAffineMap {
  AffineMapAttr affineMapAttr;

  /// Returns the wrapped AffineMap.
  AffineMap affineMap() { return affineMapAttr.getValue(); }
};

/// The role of a named op argument (YAML `kind` of an entry in `args`).
enum class LinalgOperandDefKind {
  InputTensor,
  Scalar,
  OutputTensor,
  IndexAttr,
  UnaryFnAttr,
  BinaryFnAttr,
  TypeFnAttr
};

/// A named op argument. Field-to-key mapping and semantics are documented on
/// `MappingTraits<LinalgOperandDef>`.
struct LinalgOperandDef {
  std::string name;
  LinalgOperandDefKind kind;
  Optional<std::string> typeVar;
  Optional<SerializedAffineMap> shapeMap;
  Optional<SerializedAffineMap> indexAttrMap;
  Optional<SmallVector<int64_t>> defaultIndices;
  Optional<std::string> defaultFn;
};

/// Iterator type of a single loop dimension.
enum class LinalgIteratorTypeDef {
  parallel,
  reduction,
};

/// How the op's indexing maps are produced; currently only a static list.
struct LinalgIndexingMapsConfig {
  Optional<SmallVector<SerializedAffineMap>> staticIndexingMaps;
};

struct ScalarExpression;

/// The kind of function evaluated by a ScalarFn.
enum class ScalarFnKind { Unary, Binary, Type };

/// A scalar function application (YAML `scalar_fn`).
struct ScalarFn {
  ScalarFnKind kind;
  Optional<std::string> fnName;
  Optional<std::string> attrName;
  Optional<std::string> typeVar;
  // NOTE: The number of operands depends on `kind` (presumably one for unary
  // and type functions and two for binary functions). ScalarExpression is
  // still incomplete at this point, so a heap-allocated std::vector (which
  // allows incomplete element types) is used instead of SmallVector to break
  // the self-referential cycle.
  std::vector<ScalarExpression> operands;
};

/// The right-hand side of an assignment; exactly one of the fields is
/// expected to be set (see `MappingTraits<ScalarExpression>`).
struct ScalarExpression {
  Optional<std::string> arg;
  Optional<std::string> constant;
  Optional<int64_t> index;
  Optional<ScalarFn> scalarFn;
};

/// Assignment of a scalar expression to a named output argument.
struct ScalarAssign {
  std::string arg;
  ScalarExpression value;
};

/// Full description of a structured (linalg.generic-like) op.
struct LinalgStructuredOpConfig {
  SmallVector<LinalgOperandDef> args;
  LinalgIndexingMapsConfig indexingMaps;
  SmallVector<LinalgIteratorTypeDef> iteratorTypes;
  std::vector<ScalarAssign> assignments;
};

/// One YAML document: op metadata plus the concrete op description.
struct LinalgOpConfig {
  Optional<LinalgOpMetadata> metadata;
  Optional<LinalgStructuredOpConfig> structuredOp;
};

} // namespace
134 
135 //===----------------------------------------------------------------------===//
136 // Mapping traits.
137 //===----------------------------------------------------------------------===//
138 
// Let LLVM YAML I/O map these vector types to YAML sequences. A vector of
// LinalgOpConfig is mapped to a YAML document list, so each op is one document
// in the input stream.
LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef)
LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap)
LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef)
LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign)
LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression)
LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig)
145 
146 namespace llvm {
147 namespace yaml {
148 
149 /// Top-level type containing op metadata and one of a concrete op type.
150 /// Currently, the only defined op type is `structured_op` (maps to
151 /// `LinalgStructuredOpConfig`).
152 template <>
153 struct MappingTraits<LinalgOpConfig> {
mappingllvm::yaml::MappingTraits154   static void mapping(IO &io, LinalgOpConfig &info) {
155     io.mapOptional("metadata", info.metadata);
156     io.mapOptional("structured_op", info.structuredOp);
157   }
158 };
159 
160 /// A structured op models (at most) a single contraction by modeling
161 ///   - A list of named arguments (`LinalgOperandDef`), which can be inputs,
162 ///     outputs, or index attributes.
163 ///   - List of indexing maps (see `LinalgIndexingMaps`).
164 ///   - Iterator types (see `LinalgIteratorTypeDef`).
165 ///   - List of scalar level assignment (see `ScalarAssign`).
166 template <>
167 struct MappingTraits<LinalgStructuredOpConfig> {
mappingllvm::yaml::MappingTraits168   static void mapping(IO &io, LinalgStructuredOpConfig &info) {
169     io.mapRequired("args", info.args);
170     io.mapRequired("indexing_maps", info.indexingMaps);
171     io.mapRequired("iterator_types", info.iteratorTypes);
172     io.mapRequired("assignments", info.assignments);
173   }
174 };
175 
176 /// Maps a named tensor, scalar or attribute argument to an operation,
177 /// consisting of:
178 ///   - `name`: Must be unique within the operation.
179 ///   - `usage`: How the argument is used (input, output, attribute, etc).
180 ///   - `type_var`: The symbolic type variable that binds to the element or self
181 ///     type of the tensor or scalar argument, respectively.
182 ///   - `shape_map`: An optional AffineMap from all op symbols to the shape of
183 ///     the argument. Only tensor arguments have a `shape_map`. Each shape must
184 ///     be normalized over the same list of symbols and have no dimension
185 ///     inputs.
186 ///   - `index_attr_map`: An optional AffineMap from all op symbols to the
187 ///     index attribute symbols. During op creation these symbols are replaced
188 ///     by the corresponding `name` index attribue values. Only index attribute
189 ///     arguments have an `index_attr_map`.
190 ///   - `default_indices`: An optional default initialization for index
191 ///     attribute arguments.
192 ///   - `default_fn`: An optional default initialization for function attribute
193 ///     arguments.
194 template <>
195 struct MappingTraits<LinalgOperandDef> {
mappingllvm::yaml::MappingTraits196   static void mapping(IO &io, LinalgOperandDef &info) {
197     io.mapRequired("name", info.name);
198     io.mapRequired("kind", info.kind);
199     io.mapOptional("type_var", info.typeVar);
200     io.mapOptional("shape_map", info.shapeMap);
201     io.mapOptional("index_attr_map", info.indexAttrMap);
202     io.mapOptional("default_indices", info.defaultIndices);
203     io.mapOptional("default_fn", info.defaultFn);
204   }
205 };
206 
207 /// Usage enum for a named argument.
208 template <>
209 struct ScalarEnumerationTraits<LinalgOperandDefKind> {
enumerationllvm::yaml::ScalarEnumerationTraits210   static void enumeration(IO &io, LinalgOperandDefKind &value) {
211     io.enumCase(value, "input_tensor", LinalgOperandDefKind::InputTensor);
212     io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar);
213     io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor);
214     io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
215     io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr);
216     io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr);
217     io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
218   }
219 };
220 
221 /// Iterator type enum.
222 template <>
223 struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
enumerationllvm::yaml::ScalarEnumerationTraits224   static void enumeration(IO &io, LinalgIteratorTypeDef &value) {
225     io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel);
226     io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction);
227   }
228 };
229 
230 /// Metadata about the op (name, C++ name, and documentation).
231 template <>
232 struct MappingTraits<LinalgOpMetadata> {
mappingllvm::yaml::MappingTraits233   static void mapping(IO &io, LinalgOpMetadata &info) {
234     io.mapRequired("name", info.name);
235     io.mapRequired("cpp_class_name", info.cppClassName);
236     io.mapOptional("doc", info.doc);
237     io.mapOptional("implements", info.implements);
238     io.mapOptional("defines", info.defines);
239   }
240 };
241 
242 /// How the ops indexing maps are produced. Must be one of:
243 ///   - static_indexing_maps: A static list of AffineMaps, possibly with
244 ///     some symbols that bind to attributes of the op. Each indexing map must
245 ///     be normalized over the same list of dimensions, and its symbols must
246 ///     match the symbols for argument shapes.
247 template <>
248 struct MappingTraits<LinalgIndexingMapsConfig> {
mappingllvm::yaml::MappingTraits249   static void mapping(IO &io, LinalgIndexingMapsConfig &info) {
250     io.mapOptional("static_indexing_maps", info.staticIndexingMaps);
251   }
252 };
253 
254 /// Models an assignment to a named output.
255 ///   - The `arg` name must match a named output.
256 ///   - The `value` is a scalar expression for computing the value to
257 ///     assign (see `ScalarExpression`).
258 template <>
259 struct MappingTraits<ScalarAssign> {
mappingllvm::yaml::MappingTraits260   static void mapping(IO &io, ScalarAssign &info) {
261     io.mapRequired("arg", info.arg);
262     io.mapRequired("value", info.value);
263   }
264 };
265 
266 /// A scalar expression (RHS of an assignment). Must be one of:
267 ///   - `scalar_arg`: An operation argument.
268 ///   - `scalar_const`: A constant definition.
269 ///   - `scalar_index`: An iteration index.
270 ///   - `scalar_fn`: A named function (see `ScalarFn`).
271 template <>
272 struct MappingTraits<ScalarExpression> {
mappingllvm::yaml::MappingTraits273   static void mapping(IO &io, ScalarExpression &info) {
274     io.mapOptional("scalar_arg", info.arg);
275     io.mapOptional("scalar_const", info.constant);
276     io.mapOptional("scalar_index", info.index);
277     io.mapOptional("scalar_fn", info.scalarFn);
278   }
279 };
280 
281 /// Scalar function kind enum.
282 template <>
283 struct ScalarEnumerationTraits<ScalarFnKind> {
enumerationllvm::yaml::ScalarEnumerationTraits284   static void enumeration(IO &io, ScalarFnKind &value) {
285     io.enumCase(value, "unary", ScalarFnKind::Unary);
286     io.enumCase(value, "binary", ScalarFnKind::Binary);
287     io.enumCase(value, "type", ScalarFnKind::Type);
288   }
289 };
290 
291 /// A scalar expression that evaluates a named function.
292 /// Functions are generally "math" level and type polymorphic. Builtin
293 /// functions include:
294 ///   - `add(lhs, rhs)`
295 ///   - `mul(lhs, rhs)`
296 template <>
297 struct MappingTraits<ScalarFn> {
mappingllvm::yaml::MappingTraits298   static void mapping(IO &io, ScalarFn &info) {
299     io.mapRequired("kind", info.kind);
300     io.mapOptional("fn_name", info.fnName);
301     io.mapOptional("attr_name", info.attrName);
302     io.mapOptional("type_var", info.typeVar);
303     io.mapRequired("operands", info.operands);
304   }
305 };
306 
307 /// Helper mapping which accesses an AffineMapAttr as a serialized string of
308 /// the same.
309 template <>
310 struct ScalarTraits<SerializedAffineMap> {
outputllvm::yaml::ScalarTraits311   static void output(const SerializedAffineMap &value, void *rawYamlContext,
312                      raw_ostream &out) {
313     assert(value.affineMapAttr);
314     value.affineMapAttr.print(out);
315   }
inputllvm::yaml::ScalarTraits316   static StringRef input(StringRef scalar, void *rawYamlContext,
317                          SerializedAffineMap &value) {
318     assert(rawYamlContext);
319     auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
320     if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext)
321                         .dyn_cast_or_null<AffineMapAttr>())
322       value.affineMapAttr = attr;
323     else if (!value.affineMapAttr || !value.affineMapAttr.isa<AffineMapAttr>())
324       return "could not parse as an affine map attribute";
325     return StringRef();
326   }
mustQuotellvm::yaml::ScalarTraits327   static QuotingType mustQuote(StringRef) { return QuotingType::None; }
328 };
329 
330 } // namespace yaml
331 } // namespace llvm
332 
333 namespace {
334 
335 //===----------------------------------------------------------------------===//
336 // Generation utilities
337 //===----------------------------------------------------------------------===//
338 
339 class GenerationContext {
340 public:
GenerationContext(MLIRContext * context,raw_ostream * odsOut,raw_ostream * defnOut)341   GenerationContext(MLIRContext *context, raw_ostream *odsOut,
342                     raw_ostream *defnOut)
343       : context(context), loc(UnknownLoc::get(context)), odsOut(odsOut),
344         defnOut(defnOut) {}
345 
getContext()346   MLIRContext *getContext() { return context; }
347 
setLoc(Location loc)348   void setLoc(Location loc) { this->loc = loc; }
getLoc()349   Location getLoc() { return loc; }
350 
shouldGenerateOds()351   bool shouldGenerateOds() { return odsOut; }
shouldGenerateDefns()352   bool shouldGenerateDefns() { return defnOut; }
353 
odss()354   raw_ostream &odss() {
355     assert(odsOut && "ODS stream not defined");
356     return *odsOut;
357   }
358 
defns()359   raw_ostream &defns() {
360     assert(defnOut && "Definition stream not defined");
361     return *defnOut;
362   }
363 
364 private:
365   MLIRContext *context;
366   Location loc;
367   raw_ostream *odsOut;
368   raw_ostream *defnOut;
369 };
370 
371 } // namespace
372 
generateCppExpression(SerializedAffineMap self,StringRef contextName)373 static std::string generateCppExpression(SerializedAffineMap self,
374                                          StringRef contextName) {
375   std::string printedStr;
376   llvm::raw_string_ostream printedSs(printedStr);
377   self.affineMapAttr.print(printedSs);
378   printedSs.flush();
379 
380   static const char exprFormat[] =
381       R"FMT(mlir::parseAttribute("{0}", {1}).cast<AffineMapAttr>().getValue())FMT";
382   return llvm::formatv(exprFormat, printedStr, contextName);
383 }
384 
385 template <typename Container>
interleaveToString(Container & container,StringRef separator)386 static std::string interleaveToString(Container &container,
387                                       StringRef separator) {
388   std::string result;
389   llvm::raw_string_ostream ss(result);
390   llvm::interleave(container, ss, separator);
391   ss.flush();
392   return result;
393 }
394 
395 static Optional<int>
findTensorDefArgIndex(StringRef name,SmallVectorImpl<LinalgOperandDef> & args)396 findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) {
397   for (const auto &it : llvm::enumerate(args)) {
398     if (it.value().name == name)
399       return it.index();
400   }
401   return None;
402 }
403 
404 // Try to map the TypeVar to a predefined or an argument type.
405 static Optional<std::string>
findTypeValue(StringRef typeVar,SmallVectorImpl<LinalgOperandDef> & args)406 findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {
407   // Handle all predefined types.
408   if (typeVar == "I32")
409     return std::string("helper.getIntegerType(32)");
410   if (typeVar == "I64")
411     return std::string("helper.getIntegerType(64)");
412   if (typeVar == "F32")
413     return std::string("helper.getFloat32Type()");
414   if (typeVar == "F64")
415     return std::string("helper.getFloat64Type()");
416 
417   // Search all argument types.
418   for (const auto &it : llvm::enumerate(args)) {
419     if (it.value().kind != LinalgOperandDefKind::InputTensor &&
420         it.value().kind != LinalgOperandDefKind::Scalar &&
421         it.value().kind != LinalgOperandDefKind::OutputTensor)
422       continue;
423     if (*it.value().typeVar == typeVar)
424       return llvm::formatv("block.getArgument({0}).getType()", it.index())
425           .str();
426   }
427 
428   return None;
429 }
430 
findAssignment(StringRef name,std::vector<ScalarAssign> & assignments)431 static ScalarAssign *findAssignment(StringRef name,
432                                     std::vector<ScalarAssign> &assignments) {
433   for (auto &assign : assignments) {
434     if (assign.arg == name)
435       return &assign;
436   }
437   return nullptr;
438 }
439 
440 // Return true if the operand is a function attribute.
isFunctionAttribute(LinalgOperandDefKind kind)441 static bool isFunctionAttribute(LinalgOperandDefKind kind) {
442   return kind == LinalgOperandDefKind::UnaryFnAttr ||
443          kind == LinalgOperandDefKind::BinaryFnAttr ||
444          kind == LinalgOperandDefKind::TypeFnAttr;
445 }
446 
447 // Return true if the operand is an attribute.
isAttribute(LinalgOperandDefKind kind)448 static bool isAttribute(LinalgOperandDefKind kind) {
449   return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind);
450 }
451 
452 // Get the enum name for the given operand kind.
convertOperandKindToEnumName(LinalgOperandDefKind kind)453 std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
454   switch (kind) {
455   case LinalgOperandDefKind::UnaryFnAttr:
456     return std::string("UnaryFn");
457   case LinalgOperandDefKind::BinaryFnAttr:
458     return std::string("BinaryFn");
459   case LinalgOperandDefKind::TypeFnAttr:
460     return std::string("TypeFn");
461   default:
462     break;
463   }
464   llvm_unreachable("unsupported function attribute kind");
465 }
466 
467 // Get the enum name for the given function kind.
convertFunctionKindToEnumName(ScalarFnKind kind)468 std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
469   switch (kind) {
470   case ScalarFnKind::Unary:
471     return std::string("UnaryFn");
472   case ScalarFnKind::Binary:
473     return std::string("BinaryFn");
474   case ScalarFnKind::Type:
475     return std::string("TypeFn");
476   }
477   llvm_unreachable("unsupported function kind");
478 }
479 
480 //===----------------------------------------------------------------------===//
481 // Templates
482 //===----------------------------------------------------------------------===//
483 
// A single line banner format (e.g. emitted before an op's generated C++
// implementation). Parameters:
// {0}: Single line comment
static const char bannerFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)FMT";
491 
492 //===----------------------------------------------------------------------===//
493 // Named generic op generation.
494 // These ops map at most a single contraction that complies with the limitations
495 // of a linalg.generic.
496 //===----------------------------------------------------------------------===//
497 
// Template for Linalg named ops' ODS definitions. Parameters:
// {0}: ODS/C++ op name
// {1}: assembly op mnemonic
// {2}: op interface list
// {3}: documentation (summary + description)
// {4}: op attribute list
// {5}: builder methods taking standalone attribute parameters
// {6}: additional ODS definitions (`let <name> = 1;` lines from the metadata
//      `defines` list)
// {7}: additional methods for attributes used by indexing maps
// Literal braces are escaped as `{{` for llvm::formatv.
static const char structuredOpOdsHeaderFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// Op definition for {0}
//===----------------------------------------------------------------------===//

def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
  /*extraInterfaces=*/[{2}])> {
    {3}
    let arguments = (ins
      Variadic<AnyType>:$inputs,
      Variadic<AnyShaped>:$outputs{4}
    );
    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
    let regions = (region AnyRegion:$region);

    let skipDefaultBuilders = 1;
    let builders = [
      OpBuilder<
      (ins "ValueRange":$inputs, "ValueRange":$outputs,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        buildStructuredOp($_builder, $_state, llvm::None, inputs, outputs,
          attributes, {0}::getRegionBuilder());
      }]>,
      OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
            "ValueRange":$outputs,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        buildStructuredOp($_builder, $_state, resultTensorTypes,
          inputs, outputs, attributes, {0}::getRegionBuilder());
      }]>,
      OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        $_state.addOperands(operands);
        $_state.addAttributes(attributes);
        $_state.addTypes(resultTensorTypes);
        (void)$_state.addRegion();
      }]>
      {5}
    ];
    let hasCustomAssemblyFormat = 1;
    let hasFolder = 1;
    {6}

    let extraClassDeclaration = structuredOpsBaseDecls # [{{
      // Auto-generated.
      ArrayAttr iterator_types();
      ArrayAttr getIndexingMaps();
      static void regionBuilder(ImplicitLocOpBuilder &b,
                                Block &block, ArrayRef<NamedAttribute> attrs);
      static std::function<void(ImplicitLocOpBuilder &,
                                Block &, ArrayRef<NamedAttribute>)>
      getRegionBuilder() {{
        return regionBuilder;
      }

      // Generic methods.
      static unsigned getNumRegionArgs();
      std::string getLibraryCallName();
      {7}
    }];
}
)FMT";
573 
// Builder method taking attribute parameters. Its text is substituted for {5}
// of structuredOpOdsHeaderFormat, hence the leading comma. Parameters:
// {0}: Class name
// {1}: Comma interleaved attribute parameters
// {2}: Attribute initialization
static const char structuredOpBuilderFormat[] = R"FMT(
  , OpBuilder<
  (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, {1},
       CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
  [{{
    {2}
    buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
      attributes, {0}::getRegionBuilder());
  }]>
)FMT";
589 
// The iterator_types() method for structured ops with a fixed iteration
// domain. Parameters:
// {0}: Class name
// {1}: Comma interleaved iterator type names.
static const char structuredOpIteratorTypesFormat[] =
    R"FMT(
ArrayAttr {0}::iterator_types() {{
  return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef>{{ {1} });
}
)FMT";
599 
// The iterator_types() method for rank polymorphic structured ops. The
// generated method returns one parallel iterator per dimension of the first
// output operand. Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIteratorTypesFormat[] =
    R"FMT(
ArrayAttr {0}::iterator_types() {{
  int64_t rank = getRank(getOutputOperand(0));
  return Builder(getContext()).getStrArrayAttr(
    SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
}
)FMT";
610 
// The getIndexingMaps() method for structured ops. The generated method
// memoizes its result in the "linalg.memoized_indexing_maps" attribute.
// Parameters:
// {0}: Class name
// {1}: Comma-separated list of dimension variable names.
//      NOTE(review): {1} is not referenced by this template.
// {2}: Statements that append to the local `maps` vector
static const char structuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
  if (cached)
    return cached;

  MLIRContext *context = getContext();
  auto symbolBindings = getSymbolBindings(*this);
  SmallVector<AffineMap> maps;
  {2}
  cached = Builder(context).getAffineMapArrayAttr(maps);
  getOperation()->setAttr(memoizeAttr, cached);
  return cached;
}
)FMT";
631 
// The getIndexingMaps() method for rank polymorphic structured ops. The
// generated method uses a constant (zero-result) map for rank-0 operands and
// an identity map over all parallel loops for every other operand.
// Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
  MLIRContext *context = getContext();
  AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
  AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
    getNumParallelLoops(), context);
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *opOperand : getInputAndOutputOperands())
    indexingMaps.push_back(getRank(opOperand) == 0 ? scalarMap : tensorMap);
  return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
}
)FMT";
646 
// Implementations of fold and getEffects.
// fold delegates to foldMemRefCast; getEffects reports effects on the op's
// results and its input/output buffer operands via getGenericEffectsImpl.
// Parameters:
// {0}: Class name
const char structuredOpFoldersFormat[] = R"FMT(
LogicalResult {0}::fold(ArrayRef<Attribute>,
                        SmallVectorImpl<OpFoldResult> &) {{
  return foldMemRefCast(*this);
}
void {0}::getEffects(SmallVectorImpl<
    SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
      SmallVector<Value> inputBuffers = getInputBufferOperands();
      SmallVector<Value> outputBuffers = getOutputBufferOperands();
      getGenericEffectsImpl(effects,
        getOperation()->getResults(), inputBuffers, outputBuffers);
}
)FMT";
663 
// Implementation of parse/print.
// Both delegate to the shared named-structured-op parser and printer.
// Parameters:
// {0}: Class name
static const char structuredOpParserFormat[] = R"FMT(
ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
  return ::parseNamedStructuredOp(parser, result,
    {0}::getNumRegionArgs(), {0}::getRegionBuilder());
}
void {0}::print(OpAsmPrinter &p) {{
  ::printNamedStructuredOp(p, getOperation(), inputs(), outputs());
}
)FMT";
676 
/// Emits the ODS (TableGen) definition of the named structured op described by
/// `opConfig` to the ODS stream of `genContext`, by instantiating
/// structuredOpOdsHeaderFormat. Does nothing when ODS output was not
/// requested. Always returns success().
///
/// NOTE(review): `opConfig.metadata` and `opConfig.structuredOp` are
/// dereferenced without checks; callers presumably guarantee both are set.
static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
                                               GenerationContext &genContext) {
  if (!genContext.shouldGenerateOds())
    return success();

  raw_ostream &os = genContext.odss();

  // Substitutions for parameters {2}, {4}, {7} and {5} of
  // structuredOpOdsHeaderFormat; empty unless filled in below.
  std::string interfaceNameList;
  std::string attrList;
  std::string attrMethods;
  std::string attrBuilder;

  // The first line of the (trimmed) doc string becomes the ODS summary and the
  // remainder becomes the description.
  std::string doc;
  if (opConfig.metadata->doc) {
    static const char structuredOpDocFmt[] = R"FMT(
  let summary = [{ {0} }];
  let description = [{
    {1}
  }];
)FMT";
    StringRef summary, description;
    std::tie(summary, description) =
        StringRef(*opConfig.metadata->doc).trim().split('\n');
    doc = llvm::formatv(structuredOpDocFmt, summary.trim(), description.trim());
  }

  interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");

  // Each entry of the metadata `defines` list becomes `let <name> = 1;`.
  std::string definitionList;
  for (const std::string &definition : opConfig.metadata->defines) {
    static const char definitionFmt[] = "let {0} = 1;\n";
    definitionList.append(llvm::formatv(definitionFmt, definition));
  }

  // Attribute arguments become extra ODS arguments (with defaults) and are
  // accepted by an additional builder that sets them on the operation state.
  if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
        return isAttribute(arg.kind);
      })) {
    SmallVector<std::string> attrDefs;
    SmallVector<std::string> attrParams;
    SmallVector<std::string> attrStmts;
    for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
      static const char paramFmt[] = "\"Attribute\":${0}";
      static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
      // Add the function attributes (unary/binary/type fn) to the op
      // definition and builders, defaulting to `<EnumName>::<default_fn>`.
      if (isFunctionAttribute(arg.kind)) {
        assert(arg.defaultFn);
        std::string enumName = convertOperandKindToEnumName(arg.kind);
        static const char typeFmt[] = "{0}::{1}";
        static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
        attrDefs.push_back(llvm::formatv(
            defFmt, llvm::formatv("{0}Attr", enumName),
            llvm::formatv(typeFmt, enumName, arg.defaultFn), arg.name));
        attrParams.push_back(llvm::formatv(paramFmt, arg.name));
        attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
      }
      // Add the index attributes to the op definition and builders. The
      // attribute is a ranked i64 elements attribute whose size is the number
      // of results of `index_attr_map`, defaulting to `default_indices`.
      if (arg.kind == LinalgOperandDefKind::IndexAttr) {
        assert(arg.indexAttrMap.has_value());
        assert(arg.defaultIndices.has_value());
        size_t size = arg.indexAttrMap->affineMap().getNumResults();
        assert(arg.defaultIndices.value().size() == size);
        static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
        static const char defFmt[] = "DefaultValuedAttr<{0}, \"{ {1} }\">:${2}";
        std::string defaultVals;
        llvm::raw_string_ostream ss(defaultVals);
        llvm::interleave(
            arg.defaultIndices.value(), ss,
            [&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
            ", ");
        attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
                                         ss.str(), arg.name));
        attrParams.push_back(llvm::formatv(paramFmt, arg.name));
        attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
      }
    }
    // Ops with index attributes additionally declare the methods related to
    // indexing maps that depend on attribute values.
    if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
          return arg.kind == LinalgOperandDefKind::IndexAttr;
        })) {
      attrMethods = R"(
        bool hasDynamicIndexingMaps();
        LogicalResult verifyIndexingMapRequiredAttributes();
      )";
    }
    // Leading ",\n" continues the `ins` list after `$outputs`.
    attrList = ",\n" + llvm::join(attrDefs, ",\n");
    attrBuilder = llvm::formatv(
        structuredOpBuilderFormat, opConfig.metadata->cppClassName,
        llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n"));
  }

  os << llvm::formatv(structuredOpOdsHeaderFormat,
                      opConfig.metadata->cppClassName, opConfig.metadata->name,
                      interfaceNameList, doc, attrList, attrBuilder,
                      definitionList, attrMethods);

  return success();
}
773 
774 static LogicalResult
generateNamedGenericOpDefns(LinalgOpConfig & opConfig,GenerationContext & genContext)775 generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
776                             GenerationContext &genContext) {
777   if (!genContext.shouldGenerateDefns())
778     return success();
779 
780   raw_ostream &os = genContext.defns();
781   StringRef className = opConfig.metadata->cppClassName;
782 
783   // Implementation banner.
784   std::string bannerComment = llvm::formatv("Implementation of {0}", className);
785   os << llvm::formatv(bannerFormat, bannerComment);
786 
787   // Compute the number of scalar and tensor arguments.
788   int64_t numOfArgs =
789       llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
790         return arg.kind == LinalgOperandDefKind::InputTensor ||
791                arg.kind == LinalgOperandDefKind::Scalar ||
792                arg.kind == LinalgOperandDefKind::OutputTensor;
793       });
794 
795   // An operation that accesses only scalars and scalar/rank zero tensors is
796   // rank polymorhpic. We implement rank polymorphism by generating different
797   // indexing maps and iterators that match the rank of the first output tensor.
798   // An operation is rank polymorphic if the iteration domain has rank zero.
799   bool isRankPolymorphic = opConfig.structuredOp->iteratorTypes.empty();
800 
801   // Generate the iterator_types() method.
802   if (!isRankPolymorphic) {
803     std::string iteratorsStr;
804     llvm::raw_string_ostream ss(iteratorsStr);
805     llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss,
806                           [&](LinalgIteratorTypeDef it) {
807                             switch (it) {
808                             case LinalgIteratorTypeDef::parallel:
809                               ss << "getParallelIteratorTypeName()";
810                               break;
811                             case LinalgIteratorTypeDef::reduction:
812                               ss << "getReductionIteratorTypeName()";
813                               break;
814                             }
815                           });
816     ss.flush();
817     os << llvm::formatv(structuredOpIteratorTypesFormat, className,
818                         iteratorsStr);
819   } else {
820     os << llvm::formatv(rankPolyStructuredOpIteratorTypesFormat, className);
821   }
822 
823   // Generating the getIndexingMaps() method.
824   if (auto &staticMaps =
825           opConfig.structuredOp->indexingMaps.staticIndexingMaps) {
826     if (staticMaps->empty())
827       return emitError(genContext.getLoc()) << "op has no indexing maps";
828     if (!isRankPolymorphic) {
829       AffineMap firstMap = staticMaps->front().affineMap();
830 
831       // Symbol bindings.
832       {
833         // For each symbol, generate a declaration for it, either with an
834         // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
835         // an attribute).
836         // TODO: Possibly lift into a top-level method.
837         static const char structuredOpSymbolBindingsFormat[] = R"FMT(
838 static SmallVector<AffineExpr> getSymbolBindings({0} self) {
839   MLIRContext *context = self.getContext();
840   SmallVector<AffineExpr> exprs;
841 {1}
842   return exprs;
843 }
844 )FMT";
845 
846         unsigned symbolCount = firstMap.getNumSymbols();
847         SmallVector<std::string> symbolBindings;
848         for (unsigned i = 0; i < symbolCount; ++i) {
849           symbolBindings.push_back(llvm::formatv(
850               "  exprs.push_back(getAffineSymbolExpr({0}, context));", i));
851         }
852 
853         // Access an index attribute. Parameters:
854         // {0}: Attribute name
855         // {1}: Symbol position
856         // {2}: Attribute index
857         static const char structuredOpAccessAttrFormat[] = R"FMT(
858 int64_t cst{1} = self.{0}().getValues<int64_t>()[{2}];
859 exprs.push_back(getAffineConstantExpr(cst{1}, context));
860 )FMT";
861         // Update all symbol bindings mapped to an attribute.
862         for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
863           if (arg.kind != LinalgOperandDefKind::IndexAttr)
864             continue;
865           assert(arg.indexAttrMap);
866           for (auto &en :
867                llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) {
868             if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) {
869               symbolBindings[symbol.getPosition()] =
870                   llvm::formatv(structuredOpAccessAttrFormat, arg.name,
871                                 symbol.getPosition(), en.index());
872             }
873           }
874         }
875 
876         std::string symbolBindingsStr;
877         llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
878         llvm::interleave(symbolBindings, symbolBindingsSs, "\n");
879         symbolBindingsSs.flush();
880 
881         os << llvm::formatv(structuredOpSymbolBindingsFormat, className,
882                             symbolBindingsStr);
883       }
884 
885       // Indexing maps.
886       {
887         unsigned dimCount = firstMap.getNumDims();
888 
889         // Generate a comma-separated list of dim identifiers to be passed to
890         // bindDims, ensuring tht AffineExpr identifiers are bound in the right
891         // order to the proper AffineDimExpr.
892         // This results in vars in scope like: d0, d1, d2...
893         SmallVector<unsigned> dimIndices;
894         for (unsigned i = 0; i < dimCount; ++i)
895           dimIndices.push_back(i);
896         std::string dimIdentsStr;
897         llvm::raw_string_ostream dimIdentsSs(dimIdentsStr);
898         llvm::interleaveComma(dimIndices, dimIdentsSs,
899                               [&](unsigned i) { dimIdentsSs << "d" << i; });
900         dimIdentsSs.flush();
901 
902         // Statements to add and simplify each affine map.
903         SmallVector<std::string> stmts;
904         for (auto &indexingMap : *staticMaps) {
905           // TODO: Assert that dim and symbol count match the first.
906           stmts.push_back(
907               llvm::formatv("maps.push_back({0});",
908                             generateCppExpression(indexingMap, "context")));
909           stmts.push_back(llvm::formatv(
910               "maps.back() = "
911               "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
912               "symbolBindings, {0}, 0));",
913               dimCount));
914         }
915 
916         // TODO: This needs to be memoized and/or converted to non-parser based
917         // C++ codegen prior to real use.
918         os << llvm::formatv(structuredOpIndexingMapsFormat, className,
919                             dimIdentsStr, interleaveToString(stmts, "\n  "));
920       }
921     } else {
922       os << llvm::formatv(rankPolyStructuredOpIndexingMapsFormat, className);
923     }
924   } else {
925     return emitError(genContext.getLoc())
926            << "generating code for non static indexing maps not currently "
927               "supported";
928   }
929 
930   // getNumRegionArgs()
931   {
932     // Generates a getNumRegionArgs() method. Parameters:
933     // {0}: Class name
934     // {1}: Number of region args
935     static const char structuredOpGetNumRegionArgsFormat[] = R"FMT(
936 unsigned {0}::getNumRegionArgs() {{ return {1}; }
937 )FMT";
938     os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className,
939                         numOfArgs);
940   }
941 
942   // getLibraryCallName()
943   {
944     // Generates a getLibraryCallName method. Parameters:
945     // {0}: Class name
946     static const char structuredOpGetLibraryCallFormat[] = R"FMT(
947 std::string {0}::getLibraryCallName() {{
948   return generateLibraryCallName(getOperation());
949 }
950 )FMT";
951     os << llvm::formatv(structuredOpGetLibraryCallFormat, className);
952   }
953 
954   // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
955   if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
956         return arg.kind == LinalgOperandDefKind::IndexAttr;
957       })) {
958     std::vector<std::string> attrVerifications;
959     for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
960       if (arg.kind != LinalgOperandDefKind::IndexAttr)
961         continue;
962       assert(arg.indexAttrMap);
963       // Verify index attribute. Paramters:
964       // {0}: Attribute name
965       // {1}: Attribute size
966       static const char attrFmt[] = R"FMT(
967 if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
968   if (!attr.getType().getElementType().isInteger(64))
969     return op->emitError("incorrect element type for index attribute '{0}'");
970   if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
971     return op->emitError("incorrect shape for index attribute '{0}'");
972 }
973 )FMT";
974       attrVerifications.push_back(llvm::formatv(
975           attrFmt, arg.name, arg.indexAttrMap->affineMap().getNumResults()));
976     }
977 
978     // Generates the verifyIndexingMapRequiredAttributes method. Parameters:
979     // {0}: Class name
980     // {1}: Attribute verification
981     static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT(
982 bool {0}::hasDynamicIndexingMaps() {{ return true; }
983 LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
984   Operation *op = getOperation();
985   {1}
986   return success();
987 }
988 )FMT";
989     os << llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes,
990                         className, llvm::join(attrVerifications, "\n"));
991   }
992 
993   // regionBuilder()
994   {
995     // Generates a regionBuilder method. Parameters.
996     // {0}: Class name
997     // {1}: Number of args
998     // {2}: Attributes
999     // {3}: Statements
1000     static const char structuredOpRegionBuilderFormat[] = R"FMT(
1001 void {0}::regionBuilder(ImplicitLocOpBuilder &b,
1002                         Block &block, ArrayRef<NamedAttribute> attrs) {{
1003   assert({1} > 0 && block.getNumArguments() == {1} &&
1004          "{0} regionBuilder expects {1} (>=0) args");
1005   RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
1006   SmallVector<Value> yields;
1007   {2}
1008   {3}
1009   helper.yieldOutputs(yields);
1010 }
1011 )FMT";
1012     auto &args = opConfig.structuredOp->args;
1013     auto &assignments = opConfig.structuredOp->assignments;
1014     size_t generatedAssignmentCount = 0;
1015     int localCounter = 0;
1016     SmallVector<std::string> attrs;
1017     SmallVector<std::string> stmts;
1018     for (LinalgOperandDef &arg : args) {
1019       if (!isFunctionAttribute(arg.kind))
1020         continue;
1021       // Obtain the type function attribute values. Parameters.
1022       // {0}: enum name
1023       // {1}: attribute name
1024       // {2}: default type function name
1025       static const char attrDef[] = R"FMT(
1026 {0} {1}Val = {0}::{2};
1027 auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
1028                               return attr.getName() == "{1}"; });
1029 if ({1}Iter != attrs.end()) {{
1030   if (auto attr = {1}Iter->getValue().dyn_cast<{0}Attr>())
1031     {1}Val = attr.getValue();
1032 }
1033 )FMT";
1034       std::string enumName = convertOperandKindToEnumName(arg.kind);
1035       attrs.push_back(
1036           llvm::formatv(attrDef, enumName, arg.name, arg.defaultFn));
1037     }
1038     for (LinalgOperandDef &arg : args) {
1039       if (arg.kind != LinalgOperandDefKind::OutputTensor)
1040         continue;
1041 
1042       // Find the assignment that correlates with the argument.
1043       ScalarAssign *assignment = findAssignment(arg.name, assignments);
1044       if (!assignment)
1045         return emitError(genContext.getLoc())
1046                << "no assignment found for output argument " << arg.name;
1047       ++generatedAssignmentCount;
1048 
1049       // Recursively generate the expression.
1050       std::function<Optional<std::string>(ScalarExpression &)>
1051           generateExpression =
1052               [&](ScalarExpression &expression) -> Optional<std::string> {
1053         if (expression.arg) {
1054           // Argument reference.
1055           Optional<int> argIndex = findTensorDefArgIndex(*expression.arg, args);
1056           if (!argIndex) {
1057             emitError(genContext.getLoc())
1058                 << "scalar argument not defined on the op: " << *expression.arg;
1059             return None;
1060           }
1061           return std::string(
1062               llvm::formatv("block.getArgument({0})", *argIndex));
1063         }
1064         if (expression.constant) {
1065           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
1066           stmts.push_back(
1067               llvm::formatv(R"FMT(Value {0} = helper.constant("{1}");)FMT",
1068                             cppIdent, expression.constant));
1069           return cppIdent;
1070         }
1071         if (expression.index) {
1072           // Access an iteration index.
1073           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
1074           stmts.push_back(llvm::formatv("Value {0} = helper.index({1});",
1075                                         cppIdent, *expression.index));
1076           return cppIdent;
1077         }
1078         if (expression.scalarFn) {
1079           std::string enumName =
1080               convertFunctionKindToEnumName(expression.scalarFn->kind);
1081 
1082           // Get the function or attribute name.
1083           assert(expression.scalarFn->fnName || expression.scalarFn->attrName);
1084           std::string funcType;
1085           if (expression.scalarFn->fnName) {
1086             funcType = llvm::formatv("{0}::{1}", enumName,
1087                                      *expression.scalarFn->fnName);
1088           }
1089           if (expression.scalarFn->attrName) {
1090             if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
1091                   return isFunctionAttribute(arg.kind) &&
1092                          arg.name == expression.scalarFn->attrName.value();
1093                 })) {
1094               emitError(genContext.getLoc())
1095                   << "missing function attribute "
1096                   << expression.scalarFn->attrName.value();
1097             }
1098             funcType = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
1099           }
1100           assert(!funcType.empty());
1101 
1102           // Add the optional type parameter to the operands.
1103           SmallVector<std::string> operandCppValues;
1104           if (expression.scalarFn->kind == ScalarFnKind::Type) {
1105             assert(expression.scalarFn->typeVar.has_value());
1106             Optional<std::string> typeCppValue =
1107                 findTypeValue(expression.scalarFn->typeVar.value(), args);
1108             if (!typeCppValue) {
1109               emitError(genContext.getLoc())
1110                   << "type variable " << expression.scalarFn->typeVar.value()
1111                   << ", used in a type conversion, must map to a predefined or "
1112                   << "an argument type but it does not";
1113               return None;
1114             }
1115             operandCppValues.push_back(typeCppValue.value());
1116           }
1117 
1118           // Collect the scalar operands.
1119           for (ScalarExpression &operand : expression.scalarFn->operands) {
1120             auto operandCppValue = generateExpression(operand);
1121             if (!operandCppValue)
1122               return None;
1123             operandCppValues.push_back(*operandCppValue);
1124           }
1125 
1126           // Call the function builder.
1127           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
1128           stmts.push_back(llvm::formatv(
1129               "Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName,
1130               funcType, interleaveToString(operandCppValues, ", ")));
1131           return cppIdent;
1132         }
1133         emitError(genContext.getLoc()) << "unknown ScalarExpression type";
1134         return None;
1135       };
1136       Optional<std::string> cppValue = generateExpression(assignment->value);
1137       if (!cppValue)
1138         return failure();
1139       stmts.push_back(llvm::formatv("yields.push_back({0});", cppValue));
1140     }
1141 
1142     if (generatedAssignmentCount != assignments.size())
1143       return emitError(genContext.getLoc())
1144              << "mismatched number of assignments vs output arguments";
1145 
1146     os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs,
1147                         interleaveToString(attrs, "\n  "),
1148                         interleaveToString(stmts, "\n  "));
1149   }
1150 
1151   // Parser and printer.
1152   os << llvm::formatv(structuredOpParserFormat, className);
1153 
1154   // Canonicalizers and folders.
1155   os << llvm::formatv(structuredOpFoldersFormat, className);
1156 
1157   return success();
1158 }
1159 
generateOp(LinalgOpConfig & opConfig,GenerationContext & genContext)1160 static LogicalResult generateOp(LinalgOpConfig &opConfig,
1161                                 GenerationContext &genContext) {
1162   // Switch on op type being generated.
1163   if (opConfig.structuredOp) {
1164     return success(
1165         succeeded(generateNamedGenericOpOds(opConfig, genContext)) &&
1166         succeeded(generateNamedGenericOpDefns(opConfig, genContext)));
1167   }
1168   return emitError(genContext.getLoc()) << "unsupported operation type";
1169 }
1170 
1171 //===----------------------------------------------------------------------===//
1172 // Command line options and main
1173 //===----------------------------------------------------------------------===//
1174 
1175 static llvm::cl::opt<std::string>
1176     inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
1177                   llvm::cl::init("-"), llvm::cl::value_desc("YAML filename"));
1178 
1179 static llvm::cl::opt<std::string>
1180     outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"),
1181                           llvm::cl::value_desc("filename"), llvm::cl::init(""));
1182 
1183 static llvm::cl::opt<std::string>
1184     outputCppImplFilename("o-impl",
1185                           llvm::cl::desc("C++ implementation file name"),
1186                           llvm::cl::value_desc("filename"), llvm::cl::init(""));
1187 
main(int argc,char ** argv)1188 int main(int argc, char **argv) {
1189   llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen from YAML");
1190 
1191   // Set up the input file.
1192   std::string errorMessage;
1193   std::unique_ptr<llvm::MemoryBuffer> file =
1194       mlir::openInputFile(inputFilename, &errorMessage);
1195   if (!file) {
1196     llvm::errs() << errorMessage << "\n";
1197     return 1;
1198   }
1199 
1200   MLIRContext mlirContext;
1201   LinalgYAMLContext yamlContext{&mlirContext};
1202 
1203   std::vector<LinalgOpConfig> opConfigs;
1204 
1205   // Parse input.
1206   Input yin(file->getBuffer(), &yamlContext);
1207   yin >> opConfigs;
1208 
1209   if (yin.error())
1210     return 1;
1211 
1212   // Open output files.
1213   std::unique_ptr<llvm::ToolOutputFile> outputOdsDecl;
1214   if (!outputOdsDeclFilename.empty()) {
1215     outputOdsDecl = openOutputFile(outputOdsDeclFilename, &errorMessage);
1216     if (!outputOdsDecl) {
1217       llvm::errs() << errorMessage << "\n";
1218       return 1;
1219     }
1220   }
1221 
1222   std::unique_ptr<llvm::ToolOutputFile> outputCppImpl;
1223   if (!outputCppImplFilename.empty()) {
1224     outputCppImpl = openOutputFile(outputCppImplFilename, &errorMessage);
1225     if (!outputCppImpl) {
1226       llvm::errs() << errorMessage << "\n";
1227       return 1;
1228     }
1229   }
1230 
1231   if (!outputOdsDecl && !outputCppImpl) {
1232     llvm::errs() << "error: No output files specified\n";
1233     return 1;
1234   }
1235 
1236   // Generate.
1237   GenerationContext genContext(&mlirContext,
1238                                outputOdsDecl ? &outputOdsDecl->os() : nullptr,
1239                                outputCppImpl ? &outputCppImpl->os() : nullptr);
1240 
1241   for (auto &opConfig : opConfigs) {
1242     if (!opConfig.metadata) {
1243       emitError(genContext.getLoc())
1244           << "missing operation metadata on subsequent op";
1245       return 1;
1246     }
1247 
1248     genContext.setLoc(NameLoc::get(
1249         StringAttr::get(&mlirContext, opConfig.metadata->cppClassName)));
1250     if (failed(generateOp(opConfig, genContext))) {
1251       return 1;
1252     }
1253   }
1254 
1255   if (outputOdsDecl)
1256     outputOdsDecl->keep();
1257   if (outputCppImpl)
1258     outputCppImpl->keep();
1259 
1260   return 0;
1261 }
1262