树状数组&线段树入门

预置知识

引入:

现给你一个长度为 $n$ 的数列 $A$,再输入 $q$ 个询问,每个询问都给出两个整数l,r。对于每个询问都要求给出对于数列 $A$ 在区间 $[l,r]$上的和(假设下标从0开始)

很容易想到能够使用前缀和 $pre[i]$ 来维护,每次询问直接输出 $pre[r]-pre[l-1]$ 即可,这个问题就愉快地解决了。

但是问题又来了,如果是 $q$ 个操作,这些操作中既有查询又有修改 (每次修改指定位置的值) 呢?这时候单纯用前缀和复杂度就会很高,因为每一次修改都会影响前缀和,对于每次修改都得重新计算一遍前缀和数组,那如何处理这个问题呢?

一、树状数组

树状数组是一种利用二进制特征进行检索的树状结构。

  • 初识树状数组

    就【引入】中的问题来初始树状数组,我们把它形式化一下

    长度为 $n$ 的数组 $a_i(i=1 \dots n)$ ,进行以下操作:

    1. 修改元素 $add(k,x):$ 把 $a_k$ 加上 $x$。
    2. 询问区间 $[l,r]$ 的和。我们可以用树状数组求和 $sum(x): ~sum=a_1+a_2+ \dots +a_x$。那么所求即是 $sum(r)-sum(l-1)$

      代码:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      int lowbit(int x)
      {
      return x&(-x);
      }

      void add(int x,int d)
      {
      while(x<=n){
      tree[x]+=d;
      x+=lowbit(x);
      }
      }

      int sum(int x)
      {
      int sum=0;
      while(x>0){
      sum+=tree[x];
      x-=lowbit(x);
      }
      return sum;
      }

      注: add() 和 sum() 的复杂度都是 O(logn)

      使用方法:(重点)

    3. 初始化。先清空 $tree$​ 数组,然后读取 $a_1,a_2, \dots ,a_n$​ ,用 $add(i,a_i)$ 逐一处理,得到一开始的 $tree$ 数组;实际上并不需要 $a$ 数组,因为它隐含在了 $tree$​ 数组中。

    4. 求区间和,输出 $sum(r)-sum(l-1)$
    5. 修改,执行 $add(i,d)$​。
  • 原理(简单介绍一下)

    问:为什么任何一个正整数(10进制)都可以唯一表示成用2的整数次幂相加?

    答:世界上只有 $10$ 种人,一种人懂二进制,一种人不懂二进制。

    树状数组采用的就是这个原理

    因为任意一个十进制数都可以转换成 $2$ 进制,那么二进制转换成 $10$ 进制的过程就是任何一个正整数表示成2的整数次幂和的过程。

    img

    lowbit(x) 原理

    $lowbit(x)$​ 返回 $x$​ 二进制形式的最后一个 $1$​ 的位置对应的幂级数。假设 $x=12=00001100_2$​​, $lowbit(x)=4=2^2$​

    $x=00001100, -x=x_补=11110100$​

    最后一个 $1$​ 前面进行 & 操作后,全是 $0$​ ,最后一个 $1$​ 后面进行 & 操作后也全是 $0$​ 只剩下那一个 $1$ 。

    主要原理

    假如要求 $sum(12):$​ $sum(12)=sum(8)+sum(4)$ ;就相当于 $sum(12)-lowbit(12)=8, sum(8)-lowbit(8)=0$

    假如要更新 $12$​​ 这个点 :$12+lowbit(12)=16, 16+lowbit(16)=32, 64, 128…$​​直到大于 $n$​​​,我的理解就是找到离他最近的那个 $2^k$​​,就是上图中的 ①,②,④,⑧ $\dots$​​​ 结点中,这样查询的过程中就可以把加到这个点的值累积到答案中了;不会重复,你自己模拟一下(查询过程中不会重复)。

    蓝书上说的更专业一点:$tree[x]$ 数组保存的是序列 $a$ 的区间 $[x-lowbit(x)+1,x]$ 这个区间中所有数的和,即 $\sum_{i=lowbit(x)+1}^{x}a[i]$

    该结构满足以下性质:(结合上图来看)

    1. 每个内部结点 $c[x]$ 保存以它为根的子树中所有叶子节点的和。
    2. 每个内部结点 $c[x]$ 的子节点个数等于 $lowbit(x)$ 的位数。
    3. 除树根外,每个内部节点 $c[x]$ 的父节点是 $c[x+lowbit(x)]$。

二、线段树

线段树是建立在区间基础上的树,树的每个结点代表一条线段 [l,r]

img

考查每个线段 $[l,r]$ ,$l$ 是左结点 ;$r$ 是右结点;

  1. $L=R$​ 说明这个结点代表的区间只有一个点,它就是一个子结点。
  2. $L<R$ 说明这个结点代表的不止一个点,他有两个儿子,左儿子代表的区间是 $[L,M]$,右儿子代表的区间是 $[M+1,R]$,其中 $M=(L+R)/2$。

线段树是二叉树,一个区间每次被折一半往下分,最多分 $log_2n$​​ 次到达最底层;当需要查找一个点或者区间的时候,顺着结点往下找,最多 $log_2n$​​​ 次就能找到。

  • 乘 $4$​​​ 原理 (下图不理解,周六给你们讲一下晖-辉定理🐕)

img

1
2
3
4
5
6
7
8
9
const int MAXM=1e5+10;
int sum[MAXN*4],L[MAXN*4],R[MAXN*4],MIN[MAXN*4],MAX[MAXN*4];//sum数组保存和,L区间左结点,R区间右结点,MIN区间最小值,MAX区间最大值。
//或者开一个结构体
struct node{
int sum,max,m
in,lr;
}tree[MAXN<<2];

int a[MAXM];

利用满二叉树建树

1
2
3
4
5
6
7
8
9
10
11
12
13
void build(int l,int r,int rt)
{
L[rt]=l,R[rt]=r;
if(l==r){
sum[rt]=a[l];MIN[rt]=a[l];MAX[rt]=a[l];return;
}
int mid=(l+r)>>1;
build(l,mid,rt<<1);
build(mid+1,r,rt<<1|1);
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
MIN[rt]=min(MIN[rt<<1],MIN[rt<<1|1]);
MAX[rt]=max(MAX[rt<<1],MAX[rt<<1|1]);
}

单点修改

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void change(int x,int v,int rt)
{
if(L[rt]==R[rt]){
sum[rt]=v;MIN[rt]=v;MAX[rt]=v;
return;
}
int mid=(L[rt]+R[rt])>>1;
if(x<=mid) change(x,v,rt<<1);//x属于左半区间
if(r>mid) change(x,v,rt<<1|1);//x属于右半区间
//先递归儿子,再更新自己的东西。
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
MIN[rt]=min(MIN[rt<<1],MIN[rt<<1|1]);
MAX[rt]=max(MAX[rt<<1],MAX[rt<<1|1]);
}

区间查询

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
int ans=0;
//【l,r】是要查询的区间。
int ask(int l,int r,int rt)
{
//rt这个结点区间全部包含在要查询的区间中。
if(l<=L[rt]&&r>=R[rt]){
return sum[rt];
}
int mid=(L[rt]+R[rt])>>1;
if(l<=mid) ans+=ask(l,r,rt<<1);
if(r>mid) ans+=ask(l,r,rt<<1|1);
return ans;
}

int ask(int l,int r,int rt)
{
if(l<=L[rt]&&r>=R[rt]){
return MIN[rt];
}
int minn=0x3f3f3f3f;
int mid=(L[rt]+R[rt])>>1;
if(l<=mid) minn=min(minn,ask(l,r,rt<<1));
if(r>mid) minn=min(minn,ask(l,r,rt<<1|1));
return minn;
}

int ask(int l,int r,int rt)
{
if(l<=L[rt]&&r>=R[rt]){
return MAX[rt];
}
int maxx=-0x3f3f3f3f;
int mid=(L[rt]+R[rt])>>1;
if(l<=mid) maxx=max(maxx,ask(l,r,rt<<1));
if(r>mid) maxx=max(maxx,ask(l,r,rt<<1|1));
return maxx;
}

到这里其实已经实现和树状数组完全一样的功能了,虽然常数比树状数组大,而且代码复杂度也 还可以
那线段树存在的意义在那里呢?没错,它可以区间修改,还可以各种神奇操作

区间修改

如果只是采用单点修改的话,时间复杂度非但没有减小,然而还会增加 $log_2n$ 的复杂度,变成 $nlog_2n$ 。

线段树牛逼就牛逼到这个地方——它采用一种 $lazy-tag$ 方法。

lazy-tag

这个方法叫做 “ 懒标记 ” 。当修改的是一个整块区间的时候,只对这个线段区间进行整体上的修改,其内部每个元素的内容先不做修改,只有当这部分线段的一致性被破坏的时候才把变化值传递给子区间。那么,每次区间修改的复杂度是 $log_2n$​ 。做 $lazy$​ 操作的子区间需要记录状态。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
int add[MAXN<<2];

void push_down(int rt,int m)
{
if(add[rt]){
add[rt<<1]+=add[rt];//把懒惰标记传给左儿子
add[rt<<1|1]+=add[rt];//传给右儿子
sum[rt<<1]+=(m-(m>>1))*add[rt];//更新左儿子和右儿子的东西
sum[rt<<1|1]+=(m>>1)*add[rt];
add[rt]=0;
}
}

void update(int a,int b,int c,int l,int r,int rt)//给 [a,b] 区间加上 c ,[l,r]是rt结点代表的区间
{
if(a<=l&&b>=r){
sum[rt]+=(r-l+1)*c;
add[rt]+=c;
return;
}
push_down(rt,r-l+1);//下传懒惰标记
int mid=(l+r)>>1;
//把[a,b]区间分开计算。
if(a<=mid) update(a,b,c,l,mid,rt<<1);
if(b>mid) update(a,b,c,mid+1,r,rt<<1|1);
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}

int query(int a,int b,int l,int r,int rt)
{
if(a<=l&&b>=r) return sum[rt];
push_down(rt,r-l+1);//这里需要往下传懒惰标记
int mid=(l+r)>>1;
long long ans=0;
if(a<=mid) ans+=query(a,b,l,mid,rt<<1);
if(b>mid) ans+=query(a,b,mid+1,r,rt<<1|1);
return ans;
}

好了,我们来看一道例题吧

POJ 3468 “A Simple Problem with Integers”

给出 $N$ 个数,进行 $Q$ 个操作,$1 \leq N,Q \leq1e5 $.

有如下两种操作:

“C a b c”,对区间 $[a,b]$ 的每个数字加 c。

”Q a b”,查询区间 $[a,b]$ 数字和。

对每个查询操作,输出结果。

蒟蒻的代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include<cstdio>
#include<iostream>
#include<algorithm>

using namespace std;

const int MAXN=1e5+10;
typedef long long ll;
ll sum[MAXN<<2],add[MAXN<<2];

void push_down(int rt,int m)
{
if(add[rt]){
add[rt<<1]+=add[rt];
add[rt<<1|1]+=add[rt];
sum[rt<<1]+=(m-(m>>1))*add[rt];
sum[rt<<1|1]+=(m>>1)*add[rt];
add[rt]=0;
}
}

void build(int l,int r,int rt)
{
add[rt]=0;
if(l==r){
scanf("%lld",&sum[rt]);
return;
}
int mid=(l+r)>>1;
build(l,mid,rt<<1);
build(mid+1,r,rt<<1|1);
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}

void update(int a,int b,ll c,int l,int r,int rt)
{
if(a<=l&&b>=r){
sum[rt]+=(r-l+1)*c;
add[rt]+=c;
return;
}
push_down(rt,r-l+1);
int mid=(l+r)>>1;
if(a<=mid) update(a,b,c,l,mid,rt<<1);
if(b>mid) update(a,b,c,mid+1,r,rt<<1|1);
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}

ll query(int a,int b,int l,int r,int rt)
{
if(a<=l&&b>=r) return sum[rt];
push_down(rt,r-l+1);
int mid=(l+r)>>1;
long long ans=0;
if(a<=mid) ans+=query(a,b,l,mid,rt<<1);
if(b>mid) ans+=query(a,b,mid+1,r,rt<<1|1);
return ans;
}


int main()
{
int n,m;
scanf("%d%d",&n,&m);
build(1,n,1);
while(m--){
char str[2];
int a,b;long long c;
scanf("%s",str);
if(str[0]=='C'){
scanf("%d%d%lld",&a,&b,&c);
update(a,b,c,1,n,1);
}
else{
scanf("%d%d",&a,&b);
printf("%lld\n",query(a,b,1,n,1));
}
}
}

到这里,你已经入门了,快去刷题和继续学习各种和线段树相关的高级数据结构吧