终于看了ThreadLocal源码

2023-12-28 13:08:37

前言

之前看ThreadLocal原理基本是博客,但对这个还是一知半解,趁着这几天有空看了一遍,印象深刻了很多。同时发现新大陆,原来ThreadLocal在进行set、get等操作,都会有槽位清理的逻辑,来防止内存泄漏,这也是之前一直没有关注的地方。
在看之前,希望大家先花亿分钟打开ThreadLocal的源码,跟着来一步一步的分析。

一、ThreadLocal介绍

1、ThreadLocal基本方法

ThreadLocal主要包含一个静态方法withInitial(),和三个基本实例方法set()get()remove();
大家可能对withInitial()方法会比较陌生,下面是这个方法的代码:

    public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
        return new SuppliedThreadLocal<>(supplier);
    }
    
    static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

        private final Supplier<? extends T> supplier;

        SuppliedThreadLocal(Supplier<? extends T> supplier) {
            this.supplier = Objects.requireNonNull(supplier);
        }

        @Override
        protected T initialValue() {
            return supplier.get();
        }
    }

withInitial()使用了java 1.8增加的函数式接口SupplierSupplier接口通常用于延迟计算,即在需要值的时候才进行计算,它提供了一种延迟执行的方式。
withInitial()其实就相当于new Thread(),并且给set值,使用起来就类似于这样:

   ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(()->100);
   //相当于-------------------------------------------------
   ThreadLocal<Integer> threadLocal = new ThreadLocal<>();
   threadLocal.set(100);

2、ThreadLocal的哈希值

每个Threadlocal哈希值是通过调用nextHashCode()方法生成的,最终调用的是AtomicInteger中的getAndAdd方法,保证自增的原子性。
每当创建ThreadLocal实例时这个值都会累加 0x61c88647,0x61c88647是散列算法中常用的一个魔法值,用于将哈希码能均匀分布在2的N次方的数组里,降低冲突几率。
虽然0x61c88647是一个比较大的值,但是即使AtomicInteger超出范围变为负数,也不影响计算索引位置,因为只用到了与运算。
最终得出的threadLocalHashCode作为创建ThreadLocalMap实例化对象和计算ThreadLocalMap索引位置使用。

    private final int threadLocalHashCode = nextHashCode();
    
    private static AtomicInteger nextHashCode = new AtomicInteger();

    /**
     * 哈希码累加值
     */
    private static final int HASH_INCREMENT = 0x61c88647;
    /**
     * 返回HashCode
     */
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

二、ThreadLocal内部结构

1、ThreadLocalMap

ThreadLocal主要结构是ThreadLocalMap,是ThreadLocal的静态内部类。但它不是基于Hashmap实现的,而是一个Entry数组,每一个Entry则是一个key-value元素。
Entry的key为ThreadLocal,并且为弱引用,value则为Object。

    static class Entry extends WeakReference<ThreadLocal<?>> {
         
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
为什么key要设置为WeakReference弱引用呢?

这个在面试题上已经烂大街了。。。直接上传送门:https://juejin.cn/post/7126708538440679460。

2、初始化

初始化步骤也不难:先通过threadLocalHashCodeINITIAL_CAPACITY(初始容量为16),算出在table(Entry数组)的位置,再new一个Entry赋值。

        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;  //当前的实际Entry数量
            setThreshold(INITIAL_CAPACITY);
        }

最后一行setThreshold方法,代码就简单一句:threshold = len * 2 / 3,其实是计算当前需要扩容的阈值,这里表示的是达到容量的三分之二就要扩容了。

3、set方法

这个set方法跟初始化差不多,也是先算出索引位置,再向tab[i]赋值。

private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;  // 获取ThreadLocalMap的Entry数组
    int len = tab.length;  // 获取数组的长度
    int i = key.threadLocalHashCode & (len - 1);  // 计算初始索引

    // 遍历数组,查找匹配的ThreadLocal
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        if (e.refersTo(key)) {
            // 如果找到匹配的ThreadLocal,更新其值并返回
            e.value = value;
            return;
        }

        if (e.refersTo(null)) {
            // 如果遇到过期的ThreadLocal,替换它并返回
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    // 如果没有匹配的ThreadLocal,创建新的Entry并插入
    tab[i] = new Entry(key, value);
    int sz = ++size;  // 当前的实际Entry数量增加1

    // 进行重新散列
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

其中有一段for循环,通过循环遍历表中的槽位,查找是否存在相同键的Entry:
(1)如果找到相同键的Entry,更新其值为新值。
(2)如果找到一个空槽位(键为null),则替换为空槽位。
(3)如果循环结束仍未找到相同键的Entry,则在表中的当前位置创建一个新的Entry对象。
末尾的if (!cleanSomeSlots(i, sz) && sz >= threshold)检查了是否有清理任何槽位,并且映射的大小大于或等于阈值threshold。如果这两个条件都为真,说明映射需要定期调整大小和重新散列,以保持其性能。

4、replaceStaleEntry替换槽位并清理

感觉看下来,最复杂就是replaceStaleEntry方法,刚开始看一脸懵。 replaceStaleEntry出现在set方法中,用来替换键为null的Entry,当e.refersTo(null)时,会进入replaceStaleEntry方法。
大致过程可以分为两步:
1、首先开始向前检索key为null的Entry,直到tab[i]为null停止,记录当前索引位置并赋值给slotToExpunge
2、开始向后检索元素,分为两种情况:
(1)找到匹配的key的Entry,更新值后,与tab[staleSlot]进行交换,并清理槽位后返回。注意的是tab[staleSlot]是一个key为null的Entry。
(2)检索后没有找到匹配,这时候就要在tab[staleSlot]新增一个Entry了,并清理slotToExpungelen范围内的槽位。

private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
    Entry[] tab = table; // 获取ThreadLocalMap的Entry数组
    int len = tab.length; // 获取数组的长度

    Entry e;

    int slotToExpunge = staleSlot; // 初始化要清理的槽位为staleSlot
    // 从staleSlot往前查找过期Entry
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.refersTo(null))
            slotToExpunge = i; // 更新要清理的槽位为找到的过期Entry的位置

    // 从staleSlot往后查找Entry
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        // 如果找到匹配的ThreadLocal键的Entry
        if (e.refersTo(key)) {
            e.value = value; // 更新值

            //将找到的Entry与staleSlot位置的Entry交换
            //相当于匹配的元素往前移,将key为null元素往后移
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // 如果前一个过期槽位存在,从前一个过期槽位开始清理
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;

            // 清理一些槽位,然后返回
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // 如果在向后扫描中未找到过期Entry,
        // slotToExpunge则取扫描键时看到的第一个key为null的Entry索引,为了后面的cleanSomeSlots清理
        if (e.refersTo(null) && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // 如果未找到匹配的ThreadLocal键的Entry,将新Entry放入staleSlot
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // 如果运行中还有其他过期Entry,会清理它们
    if (slotToExpunge != staleSlot)
        //从slotToExpunge下标索引开始清理槽位
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

5、get方法

get方法很简单:
1、直接查找元素,找到直接返回
2、第一步没有找到,再利用线性探索进行查找Entry,直到匹配Entry。期间遇到key为null的Entry,顺便清理。

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1); // 计算初始索引
    Entry e = table[i]; // 获取索引处的Entry
    if (e != null && e.refersTo(key))
        return e; // 如果找到匹配的Entry,直接返回
    else
        return getEntryAfterMiss(key, i, e); // 否则调用getEntryAfterMiss方法进行进一步处理
}

// 未找到匹配的ThreadLocal键时的处理
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table; // 获取ThreadLocalMap的Entry数组
    int len = tab.length; // 获取数组的长度,即映射的容量

    while (e != null) {
        if (e.refersTo(key))
            return e; // 如果找到匹配的Entry,直接返回
        if (e.refersTo(null))
            expungeStaleEntry(i); // 如果找到引用为null的过期Entry,则清理
        else
            i = nextIndex(i, len); // 否则计算下一个索引
        e = tab[i]; // 获取新索引处的Entry
    }

    return null; // 如果循环结束仍未找到匹配的Entry,返回null表示未找到
}

6、remove方法

remove步骤如下:
1、获取entry索引位置i,如果tab[i]与key不相等,则继续进行线性探测,直到找到与key相等的元素Entry。
2、找到元素后,调用clear方法进行清理,并且调用expungeStaleEntry,从数组中删除任何过期的Entry和进行哈希调整。

      private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.refersTo(key)) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }

7、cleanSomeSlots和expungeStaleEntry清理槽位

cleanSomeSlots会检查一部分数据进行清理,内部实际是调用expungeStaleEntry

 private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;  // 用于标记是否有Entry被移除
    Entry[] tab = table;  // 获取ThreadLocalMap的Entry数组
    int len = tab.length;  // 获取数组的长度
    do {
        i = nextIndex(i, len);  // 计算下一个索引
        Entry e = tab[i];  // 获取当前索引位置的Entry
        if (e != null && e.refersTo(null)) {
            n = len;      // 如果找到符合条件的Entry,则重新设置n为数组长度
            removed = true;  // 设置标记为true,表示有Entry被移除
            i = expungeStaleEntry(i);  // 清理过期的Entry,并返回下一个有效索引
        }
    } while ((n >>>= 1) != 0);  // 继续循环,直到n变为0,比如n为16,则进行4次循环

    return removed;  // 返回标记是否有Entry被移除
}

expungeStaleEntry方法有三个作用:
1、及时清理key为null的Entry
2、重新计算Entry的索引位置,调整后如果遇到哈希冲突,则调用nextIndex进行线性探测,直到获取空槽位索引
3、返回最后一个有效的索引(空槽位)
可以看到这个清理的过程只是覆盖了一段范围,并不是全部区间。

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;  // 获取ThreadLocalMap的Entry数组
    int len = tab.length;  // 获取数组的长度

    // 清理过期Entry
    tab[staleSlot].value = null;  // 将过期槽位的值设为null
    tab[staleSlot] = null;  // 将过期槽位设为null
    size--;  // 减小映射的大小

    // 重新散列直到遇到null
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        if (k == null) {
            // 如果ThreadLocal为null,清理该槽位的Entry
            e.value = null;
            tab[i] = null;
            size--;  // 减小映射的大小
        } else {
            int h = k.threadLocalHashCode & (len - 1);

            if (h != i) {
                // 如果计算的哈希码与当前索引不同,说明需要重新散列
                tab[i] = null;

                // 重新哈希后,可能会遇到哈希冲突,使用线性探索法获取空槽位
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;  // 在新的位置存储Entry
            }
        }
    }
    return i;  // 返回最后一个有效的索引(空槽位)
}
8、rehash再哈希

1、先扫描全表,清理key为null的Entry
2、遍历旧数组的每个Entry,计算新的哈希码,并将新位置Entry储存到新数组。
3、再哈希完成后,最后将新Entry[]替换旧Entry[]数组。

 private void rehash() {
    expungeStaleEntries(); // 清理过期的Entry

    // 为了避免滞后,使用较小的阈值进行加倍
    if (size >= threshold - threshold / 4)
        resize(); // 调整映射的大小
}

private void resize() {
    Entry[] oldTab = table; // 获取旧的Entry数组
    int oldLen = oldTab.length; // 获取旧数组的长度
    int newLen = oldLen * 2; // 计算新数组的长度为旧数组的两倍
    Entry[] newTab = new Entry[newLen]; // 创建新的Entry数组
    int count = 0; // 计数非空的Entry

    // 遍历旧数组的每个Entry
    for (Entry e : oldTab) {
        if (e != null) {
            ThreadLocal<?> k = e.get(); // 获取Entry中的ThreadLocal键
            if (k == null) {
                e.value = null; // 如果键为null,帮助垃圾回收
            } else {
                int h = k.threadLocalHashCode & (newLen - 1); // 计算新的哈希码
                while (newTab[h] != null)
                    h = nextIndex(h, newLen); // 处理哈希冲突,找到新的位置
                newTab[h] = e; // 在新位置存储Entry
                count++; // 非空Entry计数加一
            }
        }
    }

    setThreshold(newLen); // 设置新的阈值
    size = count; // 更新映射的大小
    table = newTab; // 将映射的Entry数组指向新的数组
}

虚拟线程

因为用的是jdk21的最新版本,所以ThreadLocal出现虚拟线程的影子。
虚拟线程的目标是提供一种轻量级的线程模型,以便更好地支持大规模并发。与传统的本地线程(Native Thread)相比,虚拟线程更轻量,创建和销毁的成本更低,并且更容易扩展。其他的就不展开讲了,大家可以自行尝尝鲜。
下图代码主要判断当前是否虚拟线程:
在这里插入图片描述

结束

点个赞点个关注再走啦~

文章来源:https://blog.csdn.net/weixin_38225800/article/details/135260284
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。