/// This is an example of a c++ rewrite pattern for the TransposeOp. It /// optimizes the following scenario: transpose(transpose(x)) -> x structSimplifyRedundantTranspose :public mlir::OpRewritePattern { /// We register this pattern to match every toy.transpose in the IR. /// The "benefit" is used by the framework to order the patterns and process /// them in order of profitability. SimplifyRedundantTranspose(mlir::MLIRContext *context) : OpRewritePattern(context, /*benefit=*/1) {}
/// This method attempts to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. The pattern is /// expected to interact with it to perform any changes to the IR from here. mlir::LogicalResult matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter)constoverride{ // Look through the input of the current transpose. mlir::Value transposeInput = op.getOperand(); TransposeOp transposeInputOp = transposeInput.getDefiningOp();
// Input defined by another transpose? If not, no match. if (!transposeInputOp) return failure();
// Otherwise, we have a redundant transpose. Use the rewriter. rewriter.replaceOp(op, {transposeInputOp.getOperand()}); return success(); } };
/// Register our patterns as "canonicalization" patterns on the TransposeOp so /// that they can be picked up by the Canonicalization framework. voidTransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context){ results.add(context); }
MLIR还提供了另一种表达式重写的方法：基于DRR规则自动生成表达式匹配和重写函数，代码生成的部分仍然基于ODS框架实现。DRR（Declarative, Rule-based Pattern-match and Rewrite）即声明性、基于规则的模式匹配和重写方法。它是一种基于 DAG 的声明性重写器，提供基于表格的模式匹配和重写规则的句法。
/// This class defines the interface for handling inlining with Toy operations. /// We simplify inherit from the base interface class and override /// the necessary methods. structToyInlinerInterface :public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface;
/// This hook checks to see if the given callable operation is legal to inline /// into the given call. For Toy this hook can simply return true, as the Toy /// Call operation is always inlinable. boolisLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned)constfinal{ returntrue; }
/// This hook checks to see if the given operation is legal to inline into the /// given region. For Toy this hook can simply return true, as all Toy /// operations are inlinable. boolisLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &)constfinal{ returntrue; }
/// This hook is called when a terminator operation has been inlined. The only /// terminator that we have in the Toy dialect is the return /// operation(toy.return). We handle the return by replacing the values /// previously returned by the call operation with the operands of the /// return. voidhandleTerminator(Operation *op, ArrayRef valuesToRepl)constfinal{ // Only "toy.return" needs to be handled here. auto returnOp = cast(op);
// Replace the values directly with the return operands. assert(returnOp.getNumOperands() == valuesToRepl.size()); for (constauto
&it : llvm::enumerate(returnOp.getOperands())) valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } };
/// Dialect initialization, the instance will be owned by the context. This is /// the point of registration of types and operations for the dialect. voidToyDialect::initialize(){ addOperations< #define GET_OP_LIST #include"toy/Ops.cpp.inc" >(); addInterfaces(); }
def GenericCallOp : Toy_Op<"generic_call",
    [DeclareOpInterfaceMethods<CallOpInterface>]> {
  let summary = "generic call operation";
  let description = [{
    Generic calls represent calls to a user defined function that needs to be
    specialized for the shape of its arguments. The callee name is attached as
    a symbol reference via an attribute. The arguments list must match the
    arguments expected by the callee. For example:

    ```mlir
     %4 = toy.generic_call @my_func(%1, %3)
           : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
    ```

    This is only valid if a function named "my_func" exists and takes two
    arguments.
  }];

  // The generic call operation takes a symbol reference attribute as the
  // callee, and inputs for the call.
  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);

  // The generic call operation returns a single value of TensorType.
  let results = (outs F64Tensor);

  // Specialize assembly printing and parsing using a declarative format.
  let assemblyFormat = [{
    $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
  }];

  // Add custom build methods for the generic call operation.
  let builders = [
    OpBuilder<(ins "StringRef":$callee, "ArrayRef<Value>":$arguments)>
  ];
}
/// Return the callee of the generic call operation, this is required by the /// call interface. CallInterfaceCallable GenericCallOp::getCallableForCallee(){ return getAttrOfType("callee"); }
/// Get the argument operands to the called function, this is required by the /// call interface. Operation::operand_range GenericCallOp::getArgOperands(){ return inputs(); }
def CastOp : Toy_Op<"cast", [
    DeclareOpInterfaceMethods<CastOpInterface>,
    DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
    NoSideEffect,
    SameOperandsAndResultShape
  ]> {
  let summary = "shape cast operation";
  let description = [{
    The "cast" operation converts a tensor from one type to an equivalent type
    without changing any data elements. The source and destination types must
    both be tensor types with the same element type. If both are ranked, then
    shape is required to match. The operation is invalid if converting to a
    mismatching constant dimension.
  }];

  let arguments = (ins F64Tensor:$input);
  let results = (outs F64Tensor:$output);

  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
/// Returns true if the given set of input and result types are compatible with /// this cast operation. This is required by the `CastOpInterface` to verify /// this operation and provide other additional utilities. boolCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs){ if (inputs.size() != 1 || outputs.size() != 1) returnfalse; // The inputs must be Tensors with the same element type. TensorType input = inputs.front().dyn_cast(); TensorType output = outputs.front().dyn_cast(); if (!input || !output || input.getElementType() != output.getElementType()) returnfalse; // The shape is required to match if both types are ranked. return !input.hasRank() || !output.hasRank() || input == output; }
structToyInlinerInterface :public DialectInlinerInterface { .... /// Attempts to materialize a conversion for a type mismatch between a call /// from this dialect, and a callable region. This method should generate an /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc)constfinal{ return builder.create(conversionLoc, resultType, input); } };
if (enableOpt) {
  mlir::PassManager pm(&context);
  // Apply any generic pass manager command line options and run the pipeline.
  applyPassManagerCLOptions(pm);

  // Inline all functions into main and then delete them.
  pm.addPass(mlir::createInlinerPass());
  // ... (remaining passes elided)
}
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let description = [{
    Interface to access a registered method to infer the return types for an
    operation that can be used during type inference.
  }];

  let methods = [
    InterfaceMethod<"Infer and set the output shape for the current operation.",
                    "void", "inferShapes">
  ];
}
ShapeInferenceOpInterface接口继承自OpInterface，后者接收要赋予所生成的 C++ 接口类的名称"ShapeInference"作为模板参数。description字段提供了该Operation的简要说明，而methods字段定义了Operation需要提供的接口方法。
def MulOp : Toy_Op<"mul",
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  let summary = "element-wise multiplication operation";
  let description = [{
    The "mul" operation performs element-wise multiplication between two
    tensors. The shapes of the tensor operands are expected to match.
  }];

  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
  let results = (outs F64Tensor);

  // Specify a parser and printer method.
  let parser = [{ return ::parseBinaryOp(parser, result); }];
  let printer = [{ return ::printBinaryOp(p, *this); }];

  // Allow building a MulOp from the two input operands.
  let builders = [
    OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
  ];
}
/// Infer the output shape of the MulOp, this is required by the shape inference /// interface. voidMulOp::inferShapes(){ getResult().setType(getOperand(0).getType()); }
classShapeInferencePass :public mlir::PassWrapper { public: voidrunOnFunction()override{ auto f = getFunction();
// Populate the worklist with the operations that need shape inference: // these are operations that return a dynamic shape. llvm::SmallPtrSet<:operation style="color: #1c00cf;line-height: 26px;">16> opWorklist; f.walk([&](mlir::Operation *op) { if (returnsDynamicShape(op)) opWorklist.insert(op); });
// Iterate on the operations in the worklist until all operations have been // inferred or no change happened (fix point). while (!opWorklist.empty()) { // Find the next operation ready for inference, that is an operation // with all operands already resolved (non-generic). auto nextop = llvm::find_if(opWorklist, allOperandsInferred); if (nextop == opWorklist.end()) break;
Operation *op = *nextop; opWorklist.erase(op);
// Ask the operation to infer its output shapes. LLVM_DEBUG(llvm::dbgs() if (auto shapeOp = dyn_cast(op)) { shapeOp.inferShapes(); } else { op->emitError("unable to infer shape of operation without shape " "inference interface"); return signalPassFailure(); } }
// If the operation worklist isn't empty, this indicates a failure. if (!opWorklist.empty()) { f.emitError("Shape inference failed, ") signalPassFailure(); } }
/// A utility method that returns if the given operation has all of its /// operands inferred. staticboolallOperandsInferred(Operation *op){ return llvm::all_of(op->getOperandTypes(), [](Type operandType) { return operandType.isa(); }); }
/// A utility method that returns if the given operation has a dynamically /// shaped result. staticboolreturnsDynamicShape(Operation *op){ return llvm::any_of(op->getResultTypes(), [](Type resultType) { return !resultType.isa(); }); } }; } // end anonymous namespace