Lines Matching refs:tensorCast
198 LogicalResult matchAndRewrite(CastOp tensorCast, in matchAndRewrite()
200 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>(); in matchAndRewrite()
208 auto resultType = tensorCast.getType().cast<TensorType>(); in matchAndRewrite()
226 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType, in matchAndRewrite()
247 LogicalResult matchAndRewrite(CastOp tensorCast, in matchAndRewrite()
250 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>(); in matchAndRewrite()
252 if (!extractOperand || !canFoldIntoProducerOp(tensorCast) || in matchAndRewrite()
253 tensorCast.getType().getShape() == tensorCast.getSource() in matchAndRewrite()
267 int64_t dim = tensorCast.getType().getShape()[dimIndex++]; in matchAndRewrite()
274 tensorCast, tensorCast.getType().cast<RankedTensorType>(), in matchAndRewrite()
717 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>(); in matchAndRewrite() local
718 if (!tensorCast) in matchAndRewrite()
722 extract, tensorCast.getSource(), extract.getIndices()); in matchAndRewrite()