olap/spark-tungsten:codegen

2024-01-07 17:22:46

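There is not much to say about this chapter of CMU 15-721, so I won't reproduce the course material here. In industry, an engine usually implements either codegen or vectorization (SIMD), not both. Databricks is an example: Spark relied on codegen, and Photon then moved to a vectorized engine. Codegen typically targets Weld IR, LLVM IR, or the host language itself; when generating C++, you also have to check that a precompiled build is available locally, otherwise the generated code cannot be used. Since ClickHouse is primarily a vectorized engine rather than a codegen one, I'll use Spark's Tungsten, which I know better, as the example: Tungsten generates Java source and compiles it at runtime with Janino.
Tungsten has two main features: codegen and on-heap memory management. This post also walks through the memory management along the way. Allocating memory freely inside the JVM heap, memory that is not freed behind your back and is largely unaffected by GC, is quite interesting.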

WSCG (whole-stage code generation)

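Generating this "hand-written" code takes two steps:

  1. From parent to child, recursively call doProduce to generate the skeleton.
  2. From child to parent, recursively call doConsume to fill each operator's logic into the skeleton.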
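
The two steps can be sketched with a toy model. This is not Spark's API: the Operator, Scan, Filter, and Root classes below are hypothetical and mirror only the shape of the produce/consume protocol, with generated code represented as plain strings.

```java
// Toy model of the produce/consume protocol (hypothetical classes, not Spark's API).
abstract class Operator {
    Operator parent;  // remembered while produce() walks down the tree

    // Step 1: called from parent to child; the leaf emits the loop skeleton.
    abstract String produce(Operator parent);

    // Step 2: called from child to parent; each operator fills in its own logic.
    abstract String consume(String rowVar);
}

final class Scan extends Operator {
    @Override String produce(Operator parent) {
        this.parent = parent;
        // The leaf owns the skeleton: a loop over the input rows.
        return "while (input.hasNext()) {\n"
             + "  int row = input.next();\n"
             + parent.consume("row")
             + "}\n";
    }

    @Override String consume(String rowVar) {
        throw new UnsupportedOperationException("a leaf has no child to consume from");
    }
}

final class Filter extends Operator {
    private final Operator child;
    private final String predicate;  // e.g. "> 10", applied to the row variable

    Filter(Operator child, String predicate) {
        this.child = child;
        this.predicate = predicate;
    }

    @Override String produce(Operator parent) {
        this.parent = parent;
        return child.produce(this);  // keep walking down
    }

    @Override String consume(String rowVar) {
        // Drop the row if the predicate fails, otherwise hand it to the parent.
        return "  if (!(" + rowVar + " " + predicate + ")) continue;\n"
             + parent.consume(rowVar);
    }
}

// Plays the role of the node at the top of the stage.
final class Root extends Operator {
    private final Operator child;

    Root(Operator child) { this.child = child; }

    @Override String produce(Operator ignored) { return child.produce(this); }

    @Override String consume(String rowVar) { return "  append(" + rowVar + ");\n"; }
}
```

Calling new Root(new Filter(new Scan(), "> 10")).produce(null) walks down to Scan, which emits the loop, and the calls to consume then travel back up, placing the filter check and the append inside that loop.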

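First, a WholeStageCodegenExec node is added above the top node of the stage, here Project. WholeStageCodegenExec triggers the whole code generation process through doExecute, which recursively calls doProduce on the child nodes until it reaches a shuffle boundary, that is, either a data source or the output of the previous stage. The doProduce called on the leaf node (Scan) generates the skeleton of the hand-written code first.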

  override def doExecute(): RDD[InternalRow] = {
    // The next line calls the child's produce() to carry out the process described above.
    val (ctx, cleanedSource) = doCodeGen()
    // try to compile and fallback if it failed
    // Janino compiles the generated source at runtime.
    val (_, compiledCodeStats) = try {
      CodeGenerator.compile(cleanedSource)
    } catch {
      case NonFatal(_) if !Utils.isTesting && conf.codegenFallback =>
        // We should already saw the error message
        logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString")
        return child.execute()
    }

    // Check if compiled code has a too large function
    if (compiledCodeStats.maxMethodCodeSize > conf.hugeMethodLimit) {
      logInfo(s"Found too long generated codes and JIT optimization might not work: " +
        s"the bytecode size (${compiledCodeStats.maxMethodCodeSize}) is above the limit " +
        s"${conf.hugeMethodLimit}, and the whole-stage codegen was disabled " +
        s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " +
        s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString")
      return child.execute()
    }

    val references = ctx.references.toArray

    val durationMs = longMetric("pipelineTime")

    // Even though rdds is an RDD[InternalRow] it may actually be an RDD[ColumnarBatch] with
    // type erasure hiding that. This allows for the input to a code gen stage to be columnar,
    // but the output must be rows.
    val rdds = child.asInstanceOf[CodegenSupport].inputRDDs()
    assert(rdds.size <= 2, "Up to two input RDDs can be supported")
    if (rdds.length == 1) {
      rdds.head.mapPartitionsWithIndex { (index, iter) =>
        val (clazz, _) = CodeGenerator.compile(cleanedSource)
        val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
        buffer.init(index, Array(iter))
        new Iterator[InternalRow] {
          override def hasNext: Boolean = {
            val v = buffer.hasNext
            if (!v) durationMs += buffer.durationMs()
            v
          }
          override def next: InternalRow = buffer.next()
        }
      }
    } else {
      // Right now, we support up to two input RDDs.
      rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
        Iterator((leftIter, rightIter))
        // a small hack to obtain the correct partition index
      }.mapPartitionsWithIndex { (index, zippedIter) =>
        val (leftIter, rightIter) = zippedIter.next()
        val (clazz, _) = CodeGenerator.compile(cleanedSource)
        val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
        buffer.init(index, Array(leftIter, rightIter))
        new Iterator[InternalRow] {
          override def hasNext: Boolean = {
            val v = buffer.hasNext
            if (!v) durationMs += buffer.durationMs()
            v
          }
          override def next: InternalRow = buffer.next()
        }
      }
    }
  }

  def doCodeGen(): (CodegenContext, CodeAndComment) = {
    val startTime = System.nanoTime()
    val ctx = new CodegenContext
    val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)

    // main next function.
    ctx.addNewFunction("processNext",
      s"""
        protected void processNext() throws java.io.IOException {
          ${code.trim}
        }
       """, inlineToOuterClass = true)

    val className = generatedClassName()

    val source = s"""
      public Object generate(Object[] references) {
        return new $className(references);
      }

      ${ctx.registerComment(
        s"""Codegened pipeline for stage (id=$codegenStageId)
           |${this.treeString.trim}""".stripMargin,
         "wsc_codegenPipeline")}
      ${ctx.registerComment(s"codegenStageId=$codegenStageId", "wsc_codegenStageId", true)}
      final class $className extends ${classOf[BufferedRowIterator].getName} {

        private Object[] references;
        private scala.collection.Iterator[] inputs;
        ${ctx.declareMutableStates()}

        public $className(Object[] references) {
          this.references = references;
        }

        public void init(int index, scala.collection.Iterator[] inputs) {
          partitionIndex = index;
          this.inputs = inputs;
          ${ctx.initMutableStates()}
          ${ctx.initPartition()}
        }

        ${ctx.emitExtraCode()}

        ${ctx.declareAddedFunctions()}
      }
      """.trim

    // try to compile, helpful for debug
    val cleanedSource = CodeFormatter.stripOverlappingComments(
      new CodeAndComment(CodeFormatter.stripExtraNewLines(source), ctx.getPlaceHolderToComments()))

    val duration = System.nanoTime() - startTime
    WholeStageCodegenExec.increaseCodeGenTime(duration)

    logDebug(s"\n${CodeFormatter.format(cleanedSource)}")
    (ctx, cleanedSource)
  }

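Then doProduce in Scan calls the doConsume function of each parent in turn, back up the tree. While running doConsume, each operator translates its relational expressions into Java code and embeds that code into the skeleton, like filling in the blanks of a cloze test.
The doConsume code is not easy to follow, so take Filter as an example: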

  
  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    val numOutput = metricTerm(ctx, "numOutputRows")
	
    val predicateCode = generatePredicateCode(
      ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes)

    // Reset the isNull to false for the not-null columns, then the followed operators could
    // generate better code (remove dead branches).
    val resultVars = input.zipWithIndex.map { case (ev, i) =>
      if (notNullAttributes.contains(child.output(i).exprId)) {
        ev.isNull = FalseLiteral
      }
      ev
    }

    // Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;"
    s"""
       |do {
       |  $predicateCode
       |  $numOutput.add(1);
       |  ${consume(ctx, resultVars)}
       |} while(false);
     """.stripMargin
  }

  
  protected def generatePredicateCode(
      ctx: CodegenContext,
      inputAttrs: Seq[Attribute],
      inputExprCode: Seq[ExprCode],
      outputAttrs: Seq[Attribute],
      notNullPreds: Seq[Expression],
      otherPreds: Seq[Expression],
      nonNullAttrExprIds: Seq[ExprId]): String = {
    /**
     * Generates code for `c`, using `in` for input attributes and `attrs` for nullability.
     */
    def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = {
      val bound = BindReferences.bindReference(c, attrs)
      val evaluated = evaluateRequiredVariables(inputAttrs, in, c.references)

      // Generate the code for the predicate.
      val ev = ExpressionCanonicalizer.execute(bound).genCode(ctx)
      val nullCheck = if (bound.nullable) {
        s"${ev.isNull} || "
      } else {
        s""
      }

      s"""
         |$evaluated
         |${ev.code}
         |if (${nullCheck}!${ev.value}) continue;
       """.stripMargin
    }

    // To generate the predicates we will follow this algorithm.
    // For each predicate that is not IsNotNull, we will generate them one by one loading attributes
    // as necessary. For each of both attributes, if there is an IsNotNull predicate we will
    // generate that check *before* the predicate. After all of these predicates, we will generate
    // the remaining IsNotNull checks that were not part of other predicates.
    // This has the property of not doing redundant IsNotNull checks and taking better advantage of
    // short-circuiting, not loading attributes until they are needed.
    // This is very perf sensitive.
    // TODO: revisit this. We can consider reordering predicates as well.
    val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length)
    val extraIsNotNullAttrs = mutable.Set[Attribute]()
    val generated = otherPreds.map { c =>
      val nullChecks = c.references.map { r =>
        val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
        if (idx != -1 && !generatedIsNotNullChecks(idx)) {
          generatedIsNotNullChecks(idx) = true
          // Use the child's output. The nullability is what the child produced.
          genPredicate(notNullPreds(idx), inputExprCode, inputAttrs)
        } else if (nonNullAttrExprIds.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
          extraIsNotNullAttrs += r
          genPredicate(IsNotNull(r), inputExprCode, inputAttrs)
        } else {
          ""
        }
      }.mkString("\n").trim

      // Here we use *this* operator's output with this output's nullability since we already
      // enforced them with the IsNotNull checks above.
      s"""
         |$nullChecks
         |${genPredicate(c, inputExprCode, outputAttrs)}
       """.stripMargin.trim
    }.mkString("\n")

    val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) =>
      if (!generatedIsNotNullChecks(idx)) {
        genPredicate(c, inputExprCode, inputAttrs)
      } else {
        ""
      }
    }.mkString("\n")

    s"""
       |$generated
       |$nullChecks
     """.stripMargin
  }
}

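Null checks and filtering are interleaved here: for each ordinary predicate, the IsNotNull checks for the attributes it references are emitted first, then the predicate itself, which drops failing rows with continue; finally, any IsNotNull checks not already covered by a predicate are emitted. This path is performance-critical.
For the following SQL: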

SELECT department, AVG(salary) AS avg_salary
FROM employee
GROUP BY department
HAVING AVG(salary) > 60000

The output looks like this (the content of generated):

boolean filter_value_2 = !hashAgg_isNull_11;
if (!filter_value_2) continue;


boolean filter_value_3 = false;
filter_value_3 = org.apache.spark.sql.catalyst.util.SQLOrderingUtil.compareDoubles(hashAgg_value_11, 60000.0D) > 0;
if (!filter_value_3) continue;

If a WHERE salary IS NOT NULL clause is added, an extra null check is inserted before the hashAgg:

boolean rdd_isNull_3 = rdd_row_0.isNullAt(3);
double rdd_value_3 = rdd_isNull_3 ?
  -1.0 : (rdd_row_0.getDouble(3));

boolean filter_value_2 = !rdd_isNull_3;
if (!filter_value_2) continue;
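
The generated checks rely on a Java detail: inside do { } while(false), a continue jumps to the loop condition, which is false, so the rest of the body is skipped and control leaves the block. The sketch below reproduces that pattern in ordinary Java; the class and method names are hypothetical, and out.add stands in for the code of the downstream operators.

```java
import java.util.ArrayList;
import java.util.List;

final class DoWhileFilter {
    // Keeps the non-null salaries above the threshold, in the style of the generated code.
    static List<Double> filter(Double[] salaries, double threshold) {
        List<Double> out = new ArrayList<>();
        for (Double salary : salaries) {
            do {
                boolean isNull = (salary == null);
                if (isNull) continue;               // jumps to while(false): row dropped
                boolean pass = salary > threshold;
                if (!pass) continue;                // same exit for a failed predicate
                out.add(salary);                    // stands in for the downstream operators
            } while (false);
            // Anything placed here still runs for every row, because the
            // continue statements above only left the inner do/while block.
        }
        return out;
    }
}
```

The wrapper lets each operator emit "continue" without knowing what kind of loop surrounds it.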

Memory management

tungsten memory management

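The idea is simple: restructure the object model without changing the GC logic. Tungsten introduces a page table that stores what would otherwise be a large number of Java objects in a few large pages. The pages are still managed by the JVM; an encoded address is first resolved to a page and then looked up inside the JVM heap.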

spark-core

Before looking at the Tungsten allocators in spark-unsafe, let's look at the memory management module in spark-core.
MemoryManager picks the Tungsten allocator according to the memory mode:

  /**
   * Allocates memory for use by Unsafe/Tungsten code.
   */
  private[memory] final val tungstenMemoryAllocator: MemoryAllocator = {
    tungstenMemoryMode match {
      case MemoryMode.ON_HEAP => MemoryAllocator.HEAP
      case MemoryMode.OFF_HEAP => MemoryAllocator.UNSAFE
    }
  }

MemoryManager is the class that manages how memory is divided between Execution and Storage.
Execution and Storage each have on-heap and off-heap memory, which MemoryManager manages as memory pools.

  @GuardedBy("this")
  protected val onHeapStorageMemoryPool = new StorageMemoryPool(this, MemoryMode.ON_HEAP)
  @GuardedBy("this")
  protected val offHeapStorageMemoryPool = new StorageMemoryPool(this, MemoryMode.OFF_HEAP)
  @GuardedBy("this")
  protected val onHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.ON_HEAP)
  @GuardedBy("this")
  protected val offHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.OFF_HEAP)

Tungsten is actually invoked from TaskMemoryManager:

// Ask the MemoryManager's Tungsten allocator for a page, then register the page in the
// page table so that logical addresses can be mapped to it.
/**
   * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is
   * intended for allocating large blocks of memory that will be shared between operators.
   */
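  public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
    if (size > MAXIMUM_PAGE_SIZE_BYTES) {
      throw new IllegalArgumentException(
        "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes");
    }

    // Reserve execution memory first; this may force other consumers to spill.
    long acquired = acquireExecutionMemory(size, consumer);
    if (acquired <= 0) {
      return null;
    }

    final int pageNumber;
    synchronized (this) {
      // allocatedPages is a bitmap of the page-table slots in use.
      // PAGE_TABLE_SIZE is 1 << 13 = 8192 slots, one per 13-bit page number.
      pageNumber = allocatedPages.nextClearBit(0);
      if (pageNumber >= PAGE_TABLE_SIZE) {
        releaseExecutionMemory(acquired, consumer);
        throw new IllegalStateException(
          "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
      }
      allocatedPages.set(pageNumber);
    }
    MemoryBlock page = null;
    try {
      page = memoryManager.tungstenMemoryAllocator().allocate(acquired);
    } catch (OutOfMemoryError e) {
      // Retry until the request can be satisfied.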
      logger.warn("Failed to allocate a page ({} bytes), try again.", acquired);
      // there is no enough memory actually, it means the actual free memory is smaller than
      // MemoryManager thought, we should keep the acquired memory.
      synchronized (this) {
        acquiredButNotUsed += acquired;
        allocatedPages.clear(pageNumber);
      }
      // this could trigger spilling to free some pages.
      return allocatePage(size, consumer);
    }
    page.pageNumber = pageNumber;
    pageTable[pageNumber] = page;
    if (logger.isTraceEnabled()) {
      logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
    }
    return page;
  }
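
The slot bookkeeping in allocatePage can be isolated into a small sketch. The PageSlots class is hypothetical; it only illustrates how java.util.BitSet.nextClearBit yields the lowest free page number and how a freed slot becomes reusable. The constant 1 << 13 is chosen to match the 13-bit page numbers used by the address encoding.

```java
import java.util.BitSet;

// Hypothetical helper showing the bitmap-based page-number allocation.
final class PageSlots {
    static final int PAGE_TABLE_SIZE = 1 << 13;  // 8192 slots, one per 13-bit page number

    private final BitSet allocated = new BitSet(PAGE_TABLE_SIZE);

    // Returns the lowest free page number and marks it as used.
    synchronized int acquire() {
        int slot = allocated.nextClearBit(0);
        if (slot >= PAGE_TABLE_SIZE) {
            throw new IllegalStateException(
                "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
        }
        allocated.set(slot);
        return slot;
    }

    // Makes a page number available again.
    synchronized void release(int slot) {
        allocated.clear(slot);
    }
}
```

Because nextClearBit(0) always scans from the start, a released slot is handed out again before any higher, never-used slot.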
 
 Given an allocated page and an offset within it, produce a 64-bit logical address:
 /**
   * Given a memory page and offset within that page, encode this address into a 64-bit long.
   * This address will remain valid as long as the corresponding page has not been freed.
   *
   * @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}.
   * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
   *                     this should be the value that you would pass as the base offset into an
   *                     UNSAFE call (e.g. page.baseOffset() + something).
   * @return an encoded page address.
   */
  public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
    if (!inHeap) {
      // In off-heap mode, an offset is an absolute address that may require a full 64 bits to
      // encode. Due to our page size limitation, though, we can convert this into an offset that's
      // relative to the page's base offset; this relative offset will fit in 51 bits.
      offsetInPage -= page.getBaseOffset();
    }
    return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
  }

  The upper 13 bits hold the page number; the lower 51 bits hold the offset within the page:
  @VisibleForTesting
  public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
    assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
    return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
  }

  Given a logical address, get the page number:
  @VisibleForTesting
  public static int decodePageNumber(long pagePlusOffsetAddress) {
    return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS);
  }

  Given a logical address, get the offset within the page:
  private static long decodeOffset(long pagePlusOffsetAddress) {
    return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
  }

  Given an address, get the page's base object (null in off-heap mode):
  /**
   * Get the page associated with an address encoded by
   * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
   */
  public Object getPage(long pagePlusOffsetAddress) {
    if (inHeap) {
      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
      assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
      final MemoryBlock page = pageTable[pageNumber];
      assert (page != null);
      assert (page.getBaseObject() != null);
      return page.getBaseObject();
    } else {
      return null;
    }
  }

Given an address, get the offset within the page (in off-heap mode, converted back to an absolute address):
  /**
   * Get the offset associated with an address encoded by
   * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
   */
  public long getOffsetInPage(long pagePlusOffsetAddress) {
    final long offsetInPage = decodeOffset(pagePlusOffsetAddress);
    if (inHeap) {
      return offsetInPage;
    } else {
      // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we
      // converted the absolute address into a relative address. Here, we invert that operation:
      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
      assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
      final MemoryBlock page = pageTable[pageNumber];
      assert (page != null);
      return page.getBaseOffset() + offsetInPage;
    }
  }
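
The 13-bit/51-bit layout can be checked with a standalone round trip. The PageAddress class below is a hypothetical sketch: its constants are derived from the bit widths named above rather than copied from TaskMemoryManager, and it ignores the on-heap/off-heap base-offset adjustment.

```java
// Hypothetical stand-alone version of the page number + offset encoding.
final class PageAddress {
    static final int PAGE_NUMBER_BITS = 13;
    static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS;            // 51
    static final long MASK_LOWER_51_BITS = (1L << OFFSET_BITS) - 1;  // low 51 bits set
    static final long MASK_UPPER_13_BITS = ~MASK_LOWER_51_BITS;      // high 13 bits set

    // Packs the page number into the upper 13 bits and the offset into the lower 51.
    static long encode(int pageNumber, long offsetInPage) {
        return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LOWER_51_BITS);
    }

    // Unsigned shift, so page numbers with the top bit set decode correctly.
    static int decodePageNumber(long address) {
        return (int) ((address & MASK_UPPER_13_BITS) >>> OFFSET_BITS);
    }

    static long decodeOffset(long address) {
        return address & MASK_LOWER_51_BITS;
    }
}
```

For example, PageAddress.encode(5, 1024) decodes back to page 5 and offset 1024. The unsigned shift (>>>) matters: with a signed shift, any page number of 4096 or above would decode as a negative value.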

spark-storage

The relationships between the classes in spark-storage are fairly complex, so I won't expand on them here; these are the important ones:

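  1. BlockId:

The unique identifier of a data block in Spark.
Dependencies: usually passed as a parameter to, or held as a field by, other storage classes such as BlockManager.

  2. BlockInfo:

Holds metadata about a data block.
Dependencies: depends on BlockId and is used together with BlockManager.

  3. BlockManager:

Stores and retrieves distributed data blocks.
Dependencies: depends on BlockId, BlockInfo, and similar classes; works with DiskStore, MemoryStore, and others.

  4. BlockManagerMaster:

The master that coordinates all BlockManagers in the cluster.
Dependencies: depends on BlockManager; works with BlockManagerId and others.

  5. BlockManagerId:

The unique identifier of a BlockManager.
Dependencies: usually passed to BlockManagerMaster to tell BlockManagers apart.

  6. BlockManagerMasterEndpoint:

The endpoint through which BlockManagerMaster communicates with other nodes.
Dependencies: depends on BlockManagerMaster; used with RpcEndpoint and others.

  7. DiskBlockManager:

Creates and maintains the mapping from logical blocks to physical files on disk.
Dependencies: used by BlockManager and DiskStore.

  8. MemoryStore:

The component of BlockManager that keeps data blocks in memory.
Dependencies: depends on BlockManager and MemoryManager.

  9. DiskStore:

The component of BlockManager that persists data blocks to disk.
Dependencies: depends on BlockManager and DiskBlockManager.

  10. MemoryManager:

The component that manages memory; it works with MemoryStore and others.
Dependencies: usually used together with MemoryStore and BlockManager.

  11. ShuffleBlockId:

The identifier for shuffle-related data blocks.
Dependencies: a subclass of BlockId.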

spark-unsafe

HeapMemoryAllocator performs the actual on-heap allocation:


  @GuardedBy("this")
  private final Map<Long, LinkedList<WeakReference<long[]>>> bufferPoolsBySize = new HashMap<>();
  private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024;

  /**
   * Returns true if allocations of the given size should go through the pooling mechanism and
   * false otherwise.
   */
  private boolean shouldPool(long size) {
    // Very small allocations are less likely to benefit from pooling.
    return size >= POOLING_THRESHOLD_BYTES;
  }

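Freed buffers of 1 MB or more are pooled as weak references to long[] arrays. The weak references keep arrays that go unused for a long time from sitting in the pool and consuming memory.
This is also one reason Spark's memory usage is unstable: weakly referenced objects are still reclaimed by the JVM on its own schedule, so they cannot be released immediately.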

  @Override
  public MemoryBlock allocate(long size) throws OutOfMemoryError {
    int numWords = (int) ((size + 7) / 8);
    long alignedSize = numWords * 8L;
    assert (alignedSize >= size);
    if (shouldPool(alignedSize)) {
      synchronized (this) {
        final LinkedList<WeakReference<long[]>> pool = bufferPoolsBySize.get(alignedSize);
        if (pool != null) {
          while (!pool.isEmpty()) {
            final WeakReference<long[]> arrayReference = pool.pop();
            final long[] array = arrayReference.get();
            if (array != null) {
              assert (array.length * 8L >= size);
              MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size);
              if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
                memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE);
              }
              return memory;
            }
          }
          bufferPoolsBySize.remove(alignedSize);
        }
      }
    }
    long[] array = new long[numWords];
    MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size);
    if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
      memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE);
    }
    return memory;
  }

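On free, the block's reference to its array is always cleared; if the aligned size is at least 1 MB, the array is also returned to the pool: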


  @Override
  public void free(MemoryBlock memory) {
    assert (memory.obj != null) :
      "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?";
    assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
      "page has already been freed";
    assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER)
            || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) :
      "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " +
        "free()";

    final long size = memory.size();
    if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
      memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
    }

    // Mark the page as freed (so we can detect double-frees).
    memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER;

    // As an additional layer of defense against use-after-free bugs, we mutate the
    // MemoryBlock to null out its reference to the long[] array.
    long[] array = (long[]) memory.obj;
    memory.setObjAndOffset(null, 0);

    long alignedSize = ((size + 7) / 8) * 8;
    if (shouldPool(alignedSize)) {
      synchronized (this) {
        LinkedList<WeakReference<long[]>> pool =
          bufferPoolsBySize.computeIfAbsent(alignedSize, k -> new LinkedList<>());
        pool.add(new WeakReference<>(array));
      }
    }
  }
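
The pooling behaviour of allocate and free can be condensed into a sketch that works on bare long[] arrays. The PooledLongArrays class is hypothetical and leaves out MemoryBlock, the debug fill, and the double-free assertions; it keeps only the 8-byte alignment, the 1 MB threshold, and the weak-reference pool keyed by aligned size.

```java
import java.lang.ref.WeakReference;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;

// Hypothetical, simplified model of a size-keyed pool of weakly referenced buffers.
final class PooledLongArrays {
    static final long POOLING_THRESHOLD_BYTES = 1024 * 1024;

    private final Map<Long, Deque<WeakReference<long[]>>> pools = new HashMap<>();

    // Rounds a byte size up to a multiple of 8, the width of one long.
    static long align(long size) {
        return ((size + 7) / 8) * 8;
    }

    synchronized long[] allocate(long size) {
        long aligned = align(size);
        if (aligned >= POOLING_THRESHOLD_BYTES) {
            Deque<WeakReference<long[]>> pool = pools.get(aligned);
            if (pool != null) {
                while (!pool.isEmpty()) {
                    long[] array = pool.pop().get();
                    if (array != null) {
                        return array;          // still alive: reuse it (contents are stale)
                    }
                }
                pools.remove(aligned);         // every pooled array had been collected
            }
        }
        return new long[(int) (aligned / 8)];
    }

    synchronized void free(long[] array) {
        long aligned = array.length * 8L;
        if (aligned >= POOLING_THRESHOLD_BYTES) {
            pools.computeIfAbsent(aligned, k -> new ArrayDeque<>())
                 .add(new WeakReference<>(array));
        }
        // Smaller arrays are simply dropped and left to the garbage collector.
    }
}
```

A pooled array is reused only if the garbage collector has not reclaimed it in the meantime, which is why the hit rate of such a pool depends on GC timing rather than on the allocator.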

Source: https://blog.csdn.net/treblez/article/details/135438658