Yuanchieh's Blog

Yuanchieh's Blog

生命是長期而持續的累積

23 May 2022

【算法】Segment Tree 與 Binary Indexed Tree 解題整理

當需要在某一陣列中,求某一段區間的數值和或是最小值,如果是靜態資料,也就是陣列內容不會再改變,我們可以用 prefix sum 在 constant time 取得結果

但如果陣列的值會改變,就需要每次都重新計算 prefix sum,此時的時間複雜度會是 O(N),有沒有更快的方法呢?

這邊有兩個相似的樹狀結構 Segment Tree / Binary Indexed Tree (又稱 Fenwick Tree) 可以用 O(logN) 解決動態區間和的問題,其中 Segment Tree 可以更廣泛解決區間極值的問題

相關題目

  1. 307. Range Sum Query - Mutable
  2. 308. Range Sum Query 2D - Mutable
  3. 315. Count of Smaller Numbers After Self
  4. 327. Count of Range Sum

Segment Tree

教學影片:Segment Tree Data Structure - Min Max Queries

Segment Tree 與 BIT 的概念雷同,原本我用 prefix sum 遇到更新時要用 O(n) 整個重建,但如果我把區間切小,每次更新只要影響到部分區間,對應的讀取要篩選符合的區間讀取,妥協後 讀取與更新都控制在 log(N),但區間該怎麼切以及如何實作呢? 這就是 Segment Tree 與 BIT 不同之處

Segment Tree 用陣列儲存區間值,需要兩倍額外記憶體空間,原陣列放在新陣列的最後,接著往前跟新區間 (parent = idx/2),如下圖 (從影片截圖而來)
所以區間是 2 -> 4 -> 8 這樣往上疊加

初始化程式碼為

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class SegmentTree {
public:
    SegmentTree(const std::vector<int>& nums) {
        offset_ = nums.size();
        nodes_.resize(offset_ * 2, 0);
        // 先把原陣列放在最後
        for (int i = 0; i < offset_; i++) {
            nodes_[i + offset_] = nums[i];
        }
        // parent = 左 child + 右 child
        for (int i = offset_ - 1; i > 0; i--) {
            nodes_[i] = nodes_[i * 2] + nodes_[i * 2 + 1];
        }
    }

    void update(int index, int val) {
        // 因為有移動,所以要加上 offset_
        int nodeIdx = index + offset_;
        int diff = val - nodes_[nodeIdx];
        // 更新時要更新全部的 parent
        while (nodeIdx > 0) {
            nodes_[nodeIdx] += diff;
            nodeIdx /= 2;
        }
    }
private:
    int offset_;
    std::vector<int> nodes_;
};

透過 O(N) 即可完成初始化,我們將原陣列放在最後,並往上疊加出多個區間,Update 需要 O(logN),因為要往前把相關的區間都要更新一次

接著重點是 range query,傳入 left / right (閉區間) 如何在 O(logN) 解決

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
int sumRange(int left, int right) {
    int nodeLeftIndex = left + offset_;
    int nodeRightIndex = right + offset_;
    int count = 0;
    while (nodeLeftIndex <= nodeRightIndex) {
        // 如果是左指針,且指到區間右側,當下取值
        if ((nodeLeftIndex & 1) == 1) {
            count += nodes_[nodeLeftIndex];
            nodeLeftIndex++;
        }
        // 如果是右指針,且指到區間左側,當下取值
        if ((nodeRightIndex & 1) == 0) {
            count += nodes_[nodeRightIndex];
            nodeRightIndex--;
        }
        
        nodeLeftIndex /= 2;
        nodeRightIndex /= 2;
    } 
    
    return count;
}

這一段程式碼有點神奇,用圖示大致如下

左右分別原陣列長度為 2 , 3 時,如果想要查詢整段區間的指標移動,我們可以觀察出一個重點

偶數 index 都在區間的左側 / 奇數 index 是在區間的右側

因為我們是用 left / right 去找區間和,所以我們要找left 往右找與 right 往左找的共同區間,總共分成 4 種情況考慮

  1. left 指向偶數:
    代表我們會拿整個區間,因為區間的左側是偶數,當目前 left 就指向偶數,代表要取出整個區間
  2. left 指向奇數:
    代表我們不可以拿區間當作代表,因為奇數是區間的右側,再往下就到另一個區間,所以我們要直接取值
  3. right 指向偶數:
    right 跟 left 邏輯剛好相反,我們只能拿 right 往左的區間,因為偶數是 parent 左側,再往下移就到下個區間,所以要拿當前值
  4. right 指向奇數:
    因為奇數是區間的右側,代表我們可以拿整個區間為當前值

這部分非常 tricky,簡而言之就是當指針指向的位置可以包含整段區間,則往上指向區間,直到無法包含整段區間則取值,break down 成 2, 3 的案例就比較好理解

影片參考資料是左閉右開的計算方式,但這樣我覺得再取出區間和比較不好做,所以參考 leetcode 解答調整成目前的閉區間算法

307. Range Sum Query - Mutable 完整解法

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class SegmentTree {
public:
    SegmentTree(const std::vector<int>& nums) {
        offset_ = nums.size();
        nodes_.resize(offset_ * 2, 0);
        for (int i = 0; i < offset_; i++) {
            nodes_[i + offset_] = nums[i];
        }
        for (int i = offset_ - 1; i > 0; i--) {
            nodes_[i] = nodes_[i * 2] + nodes_[i * 2 + 1];
        }
    }
    
    void update(int index, int val) {
        int nodeIdx = index + offset_;
        int diff = val - nodes_[nodeIdx];
        while (nodeIdx > 0) {
            nodes_[nodeIdx] += diff;
            nodeIdx /= 2;
        }
    }
    
    int sumRange(int left, int right) {
        int nodeLeftIndex = left + offset_;
        int nodeRightIndex = right + offset_;
        int count = 0;
        while (nodeLeftIndex <= nodeRightIndex) {
            if ((nodeLeftIndex & 1) == 1) {
                count += nodes_[nodeLeftIndex];
                nodeLeftIndex++;
            }
            if ((nodeRightIndex & 1) == 0) {
                count += nodes_[nodeRightIndex];
                nodeRightIndex--;
            }
            
            nodeLeftIndex /= 2;
            nodeRightIndex /= 2;
        } 
        
        return count;
    }
private:
    int offset_;
    std::vector<int> nodes_;
};

class NumArray {
public:
    NumArray(vector<int>& nums) {
        tree_ = new SegmentTree(nums);
    }
    
    void update(int index, int val) {
        tree_->update(index, val);
    }
    
    int sumRange(int left, int right) {
        return tree_->sumRange(left, right);
    }
private:
    SegmentTree* tree_;
};

/**
 * Your NumArray object will be instantiated and called as such:
 * NumArray* obj = new NumArray(nums);
 * obj->update(index,val);
 * int param_2 = obj->sumRange(left,right);
 */

Binary Index Tree

1994 年的論文 A New Data Structure for Cumulative Frequency Tables / 我覺得講得很好的影片 Fenwick Tree (Binary Index Tree) - Quick Tutorial and Source Code Explanation

這張圖是從論文截圖而來,實作技巧非常巧妙,他利用 Last Significant Bit (LSB) 來決定區間的範圍,如果 LSB 是 xxx1,則只儲存當前一個數,如果 LSB 是 xx10,則儲存當前兩個數,以此類推,所以可以看到 1, 3 等只會儲存當前 1 個數、8 會儲存往前 8 個數

所以更新時會需要更新所有相關區間,上面這張圖是代表當你更新 index i 時,需要往上調整的 bit index,例如更新 idx 1 時,因為 bit[2]、bit[4]、bit[8] 都有包含 idx 1,所以都要一併更新

實作方面非常簡單,透過 2 補數 i & -i 即可取得 LSB

1
2
3
4
5
6
7
8
9
// 1: 0001
//-1: 1111
// 1 & -1 => 0001
// 6: 0110
//-6: 1010
// 6 & -6 => 0010
int getParent(int i) {
    return i + (i & -i);
}

讓我們看查詢會變得如何: 圖片表達如果你要某個 prefix sum,你必須往前輪詢的 index,例如要找 idx 1~9 的 prefix sum,則需要 bit[9] + bit[8],搭配上一張圖 bit[9] 只有儲存 idx 9 這個元素,而 bit[8] 儲存了 idx 1-8 個元素

實作方面同樣透過 2 補數,只是變成往下減

1
2
3
int getNextInterval(int i) {
    return i - (i & -i);
}

整體實作

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class BIT {
public:
    BIT(const std::vector<int>& nums) {
        int size = nums.size();
        // index 從 1 開始比較好計算
        bit_.resize(size + 1, 0);
        arr_.resize(size, 0);
        
        // 初始化只要往下一個 parent 加
        // 後面 iterate 會疊上去
        for (int i = 0; i < size; i++) {
            int bitIdx = i + 1;
            bit_[bitIdx] += nums[i];
            arr_[i] = nums[i];
            int parent = getParent(bitIdx);
            if (parent < bit_.size()) {
                bit_[parent] += bit_[bitIdx];
            }
        }
    }
    
    void update(int index, int val) {
        int diff = val - arr_[index];
        arr_[index] = val;

        index++;
        // 更新記得要全部包含的區間都更新
        while (index < bit_.size()) {
            bit_[index] += diff;
            index = getParent(index);
        }
    }
    
    int prefixSum(int index) {
        int count = 0;
        index++;
        // 取值要往前推
        while (index > 0) {
            count += bit_[index];
            index = getNextInterval(index);
        }
        return count;
    }
private:
    std::vector<int> bit_;
    std::vector<int> arr_;
    
    int getParent(int i) {
        return i + (i & -i);
    }
    
    int getNextInterval(int i) {
        return i - (i & -i);
    }
};

class NumArray {
public:
    NumArray(vector<int>& nums) {
        tree_ = new BIT(nums);
    }
    
    void update(int index, int val) {
        tree_->update(index, val);
    }
    
    int sumRange(int left, int right) {
        return tree_->prefixSum(right) - tree_->prefixSum(left - 1);
    }
private:
    BIT* tree_;
};

小結:BIT vs Segment Tree

整理一下兩者

比較
操作 Segment Tree Binary Index Tree
記憶體空間 2 * n n
初始化 O(n) O(n)
查詢 O(logN) O(logN)
更新 O(logN) O(logN)

此外,使用上兩者有共同侷限 新增 / 移除元素需要重新初始化

兩者差異

兩者都可以解決 #307 這道題目,但看似相同但還是有差異之處,簡單來說 Segment Tree 用途更廣,BIT 只能解決 Prefix Sum 計算

例如 Segment Tree 還可以解決區間最小值/區間最大值,而 BIT 是做不到的,為什麼?

因為 BIT 並不是儲存每一個值,而是在初始化就以區間的形式保存,如果是加法這種 invertible 算法,意即我區間儲存 (a+b), 我可以透過 (a+b) - a 還原 b 的值;
但求極值是 non invertable,如果要用 BIT 求極值,那麼在區間計算時就用儲存極值,這樣更新時就會出錯

1
2
3
4
5
原陣列:[3, 2]
BIT: [, 3, 3(保存區間極值)]
update (0, 1)

BIT: [, 1, ?] => 無法計算

所以 BIT 用途比較侷限,但優點是記憶體空間小,而且 bitwise 的計算速度會快更多

有一篇論文寫可以用兩個 BIT 實作區間極值的查詢,結果還是比 Segment Tree 快上許多,可以參考看看 Efficient Range Minimum Queries - using Binary Indexed Trees

Categories