初学主席树

今天学长心血来潮,跟我讲了主席树。
身为资深笨比的我从早上琢磨到第二天凌晨才过了主席树的洛谷板子题。赶紧整理一下。

主席树又叫做可持久化线段树,他可以解决区间第k大问题。

问题:给定一系列数,问这些书从第l个数到第r个数这段闭区间内,第k大数是多少。

基于朴素思路思考

最直接的思路就是暴力,每次询问都对这个区间排个序,排好了查一查。显然时间复杂度是O(nmlogn)。

但这样不行。因为这个题数据范围很大,比如洛谷的板子题,n能到2e5,查询次数m也2e5,交个暴力上去那不是妥妥要拿TLE了。

基于朴素的改进

朴素的思路可以给我们一定启发。现在对朴素做法改进一下。

改进1:每个点改成桶排并离散化处理

朴素做法就只是单单的在每一个点放上了这个点对应的数字。那这样每次询问都要排序处理。显然这个做法瓶颈在于排序。那有没有办法把排序去掉呢?
假如我们在每个点不去写对应的数字了。我们把这个一维数组二维化来看,每个位置再对应一个一维数组,拿他做桶排处理,也就是每个点不维护数字了,维护一个桶排。那么会有什么变化?

原先,我输入数字,比如上一个位置是i-1,我只需要在i这个位置写上输入的数字就行了。
现在,我输入数字,我先把上一个位置i-1的桶排数组复制到i上,然后在i位置的桶数组上,以这个数字作为下标,在那个位置+1就行了。

比如现在有一系列数字是:4 5 3 7 2 9。那怎么维护?

原先就是开了一个一维数组存:

现在我们改成每个点都是个桶排:

这样我们显然省掉了排序的操作,时间复杂度理论上应该更优了。
但是这对数据范围有所限制,因为桶不能开到1e9个吧。可以考虑对数据进行离散化处理,这样桶就可以存的下那么大的数据了。

桶排节约了排序的时间。如果有了一个区间询问的话,比如询问[l,r]这个区间,我们可以拿位置r上的桶数组,每个元素都对应减去l-1的桶数组。诶?神奇的事情就在这里:
r桶数组减l-1桶数组,得到的新桶数组,他恰好就等效成我开了一个新桶数组,然后把[l,r]这个区间里的数字依次放进桶里的桶数组结果。不信你试试。

但我们的改进还远没有结束。r减l-1得到的桶想找第k小,还是线性的。
诶?在桶数组上,下标是递增的吧。递增,emmm,看来符合单调性,也许我们可以有一种方法,每次查询不再是线性的了,我们能不能用上二分的思路来遍历这个第k大位置呢?

显然是可以的!

改进2:基于桶排,把桶排改成前缀和

我们不桶排了,每个位置改成前缀和数组。这样每个位置都是递增的。

现在我们还是按照刚才的方法,r减l-1以后得到[l,r]的等效前缀和数组,然后这个数组找第k大怎么找?
可以用二分了。也就是转换成了,给你一个数组,找k第一次出现的位置呗。

C++正好有lower_bound能用。

显然,我们把每个点优化为前缀和,相比最开始的暴力来讲已经有了质的飞跃。
但是这个飞跃还是太少了,比如每个下标的1号位置,从来就没被用过,一直是0,这是不是有点浪费。比如我们做减法操作,本身是O(n)的,这是技术瓶颈,怎么办。

对于减法操作的优化,可以边二分边减法。这里先不展开,后面会再说到的。

等等?现在我们要做的是什么操作?区间查询是吧。
有一种数据结构是不是做区间查询很有一把刷子?对!我们可以用线段树来优化这个区间查询!

改进3:每个点改成线段树

我们刚才把每个店改成前缀和数组了,现在我们再进一步:每个点改成线段树。

首先我们离散化:4 5 3 7 2 9离散化为3 4 2 5 1 6。
然后开个线段树,每个结点维护大小在[l,r]这个区间的的个数。

我们在每个节点上放的这种线段树又叫权值线段树。每个点放权值线段树可以帮助我们查区间查的更快。

改进4:可持久化的空间优化

可持久化是啥意思?

可以换个审视问题的方法:
我认为我其实是对这颗线段树做了6次插入操作。我现在希望这六次操作后得出来的新线段树都能保存下来,也就是要保存六个线段树。
想想看,这六颗线段树意义其实不一般,他就相当于是一种历史版本备份,每次操作后线段树长啥样都能保存备份下来。
确实,我们这里的需求其实就是想保存每次插入后的历史版本线段树。

嗯?我们当前造出来的数据结构,每个点都有颗线段树,下标是第i次操作,值是操作后得到的线段树。
其实我们现在这个数据结构,已经是一个可持久化的了!

其实现在我们离成功已经很接近了。主席树正式名称就叫做可持久化线段树,我们现在这个数据结构不就是一种可持久化的线段树嘛?

但我们的方法有个弊端,我们创建第二棵树的时候,没有更新的点我们也创建了一份。我们浪费了很多空间。

主席树用到了一种优化方式:没有变化的点,我们再连回去。那么这棵树现在就变成了这个样子。

优化完毕后,OK!其实目前我们存储的这个数据结构,就是大名鼎鼎的主席树了!

主席树

主席树就是一种这样的数据结构,他其实是每个点放了一个权值线段树,然后做了一定的空间优化得出来的数据结构。有了主席树,我们显然可以轻松解决[l,r]区间上第k大点是多少的问题了。

技术实现的细节上,主席树需要:

  1. 首先要离散化,方便存储,压缩值域。
  2. 建空树。
  3. 依次插入离散化后的题目输入数据。
  4. 做查询操作,反离散化输出结果。

主席树的代码实现细节:

struct {
    int l, r, sum;
} hjt[N<<6]; //主席树结点数组
//cnt表示当前有几个结点,rot表示上面图示中各个下标对应的主席树根节点编号
int cnt = 0, rot[N];

//建初始树
int build(int l, int r) {
    int rt = ++cnt;
    if (l == r) return rt;
    int mid = (l + r) >> 1;
    hjt[rt].l = build(l, mid);
    hjt[rt].r = build(mid + 1, r);
    return rt;
}

//加入一个数字k
//初始调用 insert(k, 1, n, 上一状态root);
int insert(int k, int l, int r, int root) {
    int idx = ++cnt;
    hjt[idx].l = hjt[root].l, hjt[idx].r = hjt[root].r, hjt[idx].sum = hjt[root].sum + 1;
    if (l == r) return idx;
    int mid = (l + r) >> 1;
    if (k <= mid) hjt[idx].l = insert(k, l, mid, hjt[idx].l);
    else hjt[idx].r = insert(k, mid + 1, r, hjt[idx].r);
    return idx;
}

//查询
//参数:u和v代表从哪棵树到哪棵树(根节点),l和r是当前搜索范围
//初始调用:查询[l,r]区间第k大: query(rot[l-1], rot[r], 1, len, k) 
int query(int u, int v, int l, int r, int k) {
    if (l >= r) return l;
    int mid = (l + r) >> 1;
    int x = hjt[hjt[v].l].sum - hjt[hjt[u].l].sum;
    if (k <= x) return query(hjt[u].l, hjt[v].l, l, mid, k);
    else return query(hjt[u].r, hjt[v].r, mid + 1, r, k - x);
}

当然,主席树用之前需要对数据离散化处理,这里需要离散化后第一个数据从1开始。

一个简单的离散化实现方案:

int a[N]; //源数据存这里,数据量为n,从1开始,方法调用完后存储为离散化后的结果
int p[N], len; //p数组是反离散化数组
//离散化有去重操作,len表示去重后还剩多少数据

void lisanhua() {
    sort(p + 1, p + n + 1);
    len = unique(p + 1, p + n + 1) - p - 1;
    for (int i = 1; i <= n; i++) {
        a[i] = lower_bound(p + 1, p + len + 1, a[i]) - p;
    }
}

以洛谷模板题P3834为例,上述代码这样用:

void solve() {
    //输入数据
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        p[i] = a[i];
    }

    //离散化
    lsh();

    //初始化主席树
    rot[0] = build(1, len);

    //插入数据
    for (int i = 1; i <= n; i++) {
        rot[i] = insert(a[i], 1, len, rot[i - 1]);
    }

    //查询操作
    while (m--) {
        int l, r, k;
        cin >> l >> r >> k;
        cout << p[query(rot[l - 1], rot[r], 1, len, k)] << '\n';
    }

}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    solve();
    return 0;
}
上一篇
下一篇