2009-04-25 35 views
12

Bạn có biết thực hiện tốt một (nhị phân) segment tree trong Java không?Thực hiện phân đoạn cây java

+0

tương tự bài đăng: [IntervalTree Java Implementation] (http://stackoverflow.com/questions/1418150). –

+0

Đây là một triển khai của tôi, đó là mã nguồn mở có phạm vi truy vấn min/max/sum và cũng có các truy vấn stabbing khoảng thời gian. http://github.com/phishman3579/java-algorithms-implementation/blob/master/src/com/jwetherell/algorithms/data_structures/SegmentTree.java – Justin

Trả lời

3

này đã được thực hiện trong mã nguồn mở Layout Management SW Package project

Đây là một link to the sub package

Bạn có thể tìm thấy mã hữu ích. Tôi đã không xác minh nó cũng không chạy nó và tôi không thể tìm thấy giấy phép mã được cung cấp theo từ một tìm kiếm nhanh chóng của mã và trang web để Caveat Emptor.

Bạn có thể liên lạc với tác giả nhưng hoạt động cuối cùng dường như đã được tháng Tám năm 2008.

+0

Thông tin giấy phép tại đây: http://code.google.com/p/ layout-managment-sw-package/ –

8
public class SegmentTree { 
    public static class STNode { 
     int leftIndex; 
     int rightIndex; 
     int sum; 
     STNode leftNode; 
     STNode rightNode; 
    } 

    static STNode constructSegmentTree(int[] A, int l, int r) { 
     if (l == r) { 
      STNode node = new STNode(); 
      node.leftIndex = l; 
      node.rightIndex = r; 
      node.sum = A[l]; 
      return node; 
     } 
     int mid = (l + r)/2; 
     STNode leftNode = constructSegmentTree(A, l, mid); 
     STNode rightNode = constructSegmentTree(A, mid+1, r); 
     STNode root = new STNode(); 
     root.leftIndex = leftNode.leftIndex; 
     root.rightIndex = rightNode.rightIndex; 
     root.sum = leftNode.sum + rightNode.sum; 
     root.leftNode = leftNode; 
     root.rightNode = rightNode; 
     return root; 
    } 

    static int getSum(STNode root, int l, int r) { 
     if (root.leftIndex >= l && root.rightIndex <= r) { 
      return root.sum; 
     } 
     if (root.rightIndex < l || root.leftIndex > r) { 
      return 0; 
     } 
     return getSum(root.leftNode, l, r) + getSum(root.rightNode, l, r); 
    } 

    /** 
    * 
    * @param root 
    * @param index index of number to be updated in original array 
    * @param newValue 
    * @return difference between new and old values 
    */ 
    static int updateValueAtIndex(STNode root, int index, int newValue) { 
     int diff = 0; 
     if(root.leftIndex==root.rightIndex && index == root.leftIndex) { 
      // We actually reached to the leaf node to be updated 
      diff = newValue-root.sum; 
      root.sum=newValue; 
      return diff; 
     } 
     int mid = (root.leftIndex + root.rightIndex)/2; 
     if (index <= mid) { 
      diff= updateValueAtIndex(root.leftNode, index, newValue); 
     } else { 
      diff= updateValueAtIndex(root.rightNode, index, newValue); 
     } 
     root.sum+=diff; 
     return diff; 
    } 
} 
+0

'STNode rightNode = constructSegmentTree (A, giữa, r);' Chỉ nên sử dụng 'mid + 1' sau đó hoạt động. –

0

Algo and unit tests:

public class NumArrayTest { 

     @Test 
     public void testUpdateSumRange_WithEmpty() throws Exception { 
      NumArray numArray = new NumArray(new int[]{}); 
      assertEquals(0, numArray.sumRange(0, 0)); 
     } 

     @Test 
     public void testUpdateSumRange_WithSingleton() throws Exception { 
      NumArray numArray = new NumArray(new int[]{1}); 
      assertEquals(1, numArray.sumRange(0, 0)); 
      numArray.update(0, 2); 
      assertEquals(2, numArray.sumRange(0, 0)); 
     } 

     @Test 
     public void testUpdateSumRange_WithPairElements() throws Exception { 
      NumArray numArray = new NumArray(new int[]{1,2,3,4,5,6}); 
      assertEquals(12, numArray.sumRange(2, 4)); 
      numArray.update(3, 2); 
      assertEquals(10, numArray.sumRange(2, 4)); 
     } 

     @Test 
     public void testUpdateSumRange_WithInPairElements() throws Exception { 
      NumArray numArray = new NumArray(new int[]{1,2,3,4,5,6,7}); 
      assertEquals(12, numArray.sumRange(2, 4)); 
      numArray.update(3, 2); 
      assertEquals(10, numArray.sumRange(2, 4)); 
     } 
    } 



public class NumArray { 

    private final Node root; 

    private static class Node { 
     private final int begin; 
     private final int end; 
     private final Node left; 
     private final Node right; 
     private int sum; 

     public Node(int begin, int end, int sum, Node left, Node right) { 
      this.begin = begin; 
      this.end = end; 
      this.sum = sum; 
      this.left = left; 
      this.right = right; 
     } 

     public boolean isSingle() { 
      return begin == end; 
     } 

     public boolean contains(int i) { 
      return i >= begin && i <= end; 
     } 

     public boolean inside(int i, int j) { 
      return i <= begin && j >= end; 
     } 

     public boolean outside(int i, int j) { 
      return i > end || j < begin; 
     } 

     public void setSum(int sum) { 
      this.sum = sum; 
     } 
    } 

    public NumArray(int[] nums) { 
     if (nums.length == 0) { 
      root = null; 
     } else { 
      root = buildNode(nums, 0, nums.length - 1); 
     } 
    } 

    private Node buildNode(int[] nums, int begin, int end) { 
     if (begin == end) { 
      return new Node(begin, end, nums[begin], null, null); 
     } else { 
      int mid = (begin + end)/2 + 1; 
      Node left = buildNode(nums, begin, mid - 1); 
      Node right = buildNode(nums, mid, end); 
      return new Node(begin, end, left.sum + right.sum, left, right); 
     } 
    } 

    public void update(int i, int val) { 
     if (root == null) { 
      return; 
     } 
     if (!root.contains(i)) { 
      throw new IllegalArgumentException("i not in range"); 
     } 
     update(root, i, val); 
    } 

    private int update(Node node, int i, int val) { 
     if (node.isSingle()) { 
      node.setSum(val); 
     } else { 
      Node nodeToUpdate = node.left.contains(i) ? node.left : node.right; 
      int withoutNode = node.sum - nodeToUpdate.sum; 
      node.setSum(withoutNode + update(nodeToUpdate, i, val)); 
     } 
     return node.sum; 
    } 

    public int sumRange(int i, int j) { 
     if (root == null) { 
      return 0; 
     } 
     return sumRange(root, i, j); 
    } 

    private int sumRange(Node node, int i, int j) { 
     if (node.outside(i, j)) { 
      return 0; 
     } else if (node.inside(i, j)) { 
      return node.sum; 
     } else { 
      return sumRange(node.left, i, j) + sumRange(node.right, i, j); 
     } 
    } 

} 
0

Ở đây là:

import java.util.Scanner; 

public class MinimumSegmentTree { 

    static Scanner in = new Scanner(System.in); 

    public static void main(String[] args) { 
     final int n = in.nextInt(); 
     int[] a = new int[n]; 

     for (int i = 0; i < n; i++) { 
      a[i] = in.nextInt(); 
     } 

     int sizeOfSegmentTree = (int) Math.pow(2, Math.ceil(Math.log10(n)/Math.log10(2))); 
     sizeOfSegmentTree = 2*sizeOfSegmentTree-1; 

//  System.out.println(sizeOfSegmentTree); 

     int[] segmentTree = new int[sizeOfSegmentTree]; 
     formSegmentTree(a, segmentTree, 0, n-1, 0); 

//  for(int i=0; i<sizeOfSegmentTree; i++){ 
//   System.out.print(segmentTree[i]+" "); 
//  } 
//  System.out.println(); 

     final int q = in.nextInt(); 
     for (int i = 0; i < q; i++) { 
      int s, e; 
      s = in.nextInt(); 
      e = in.nextInt(); 

      int minOverRange = getMinimumOverRange(segmentTree, s, e, 0, n-1, 0); 
      System.out.println(minOverRange); 
     } 
    } 

    private static int getMinimumOverRange(int[] segmentTree, int qs, int qe, int s, int e, int pos) { 
     if (qs <= s && qe >= e) { 
      return segmentTree[pos]; 
     } 
     if (qs > e || s > qe) { 
      return 10000000; 
     } 

     int mid = (s + e)/2; 
     return Math.min(getMinimumOverRange(segmentTree, qs, qe, s, mid, 2 * pos + 1), 
       getMinimumOverRange(segmentTree, qs, qe, mid+1, e, 2 * pos + 2)); 
    } 

    private static void formSegmentTree(int[] a, int[] segmentTree, int s, int e, int pos) { 
     if (e - s == 0) { 
      segmentTree[pos] = a[s]; 
      return; 

     } 

     int mid = (s + e)/2; 

     formSegmentTree(a, segmentTree, s, mid, 2 * pos + 1); 
     formSegmentTree(a, segmentTree, mid+1, e, 2 * pos + 2); 

     segmentTree[pos] = Math.min(segmentTree[2 * pos + 1], segmentTree[2 * pos + 2]); 

    } 

} 
Các vấn đề liên quan