diff --git a/src/datastructures/min_max.rs b/src/datastructures/min_max.rs index e58890f861fadc3ec51906dfd81375aac02bf2d5..70336b6e9dc967a0ba107b41fc40bf7aafd5cb92 100644 --- a/src/datastructures/min_max.rs +++ b/src/datastructures/min_max.rs @@ -159,15 +159,15 @@ impl MinMax { } } - fn parent(index: u64) -> u64 { + fn parent(&self, index: usize) -> usize { (index - 1) / 2 } - fn left_child(index: u64) -> u64 { + fn left_child(&self, index: usize) -> usize { 2 * index + 1 } - fn right_child(index: u64) -> u64 { + fn right_child(&self, index: usize) -> usize { 2 * index + 2 } @@ -350,6 +350,48 @@ impl MinMax { pub fn enclose(&self, index: u64) -> Result<u64, NodeError> { self.bwd_search(index, 2) } + + pub fn rank_1(&self, index: u64) -> Result<u64, NodeError> { + if index >= self.bits.len() { + Err(NodeError::NotANodeError) + } else { + let block_no = (index / self.block_size); + let begin_of_block = block_no * self.block_size; + let mut rank = 0; + + // Count 1s in the last block + for k in begin_of_block..=index { + if self.bits[k] { + rank += 1; + } + } + + let mut current_node = ((self.heap.len() / 2) as u64 + block_no) as usize; + // multiplier * block_size: number of bits belonging to heap node + let mut multiplier = 1; + + while current_node > 0 { + let old_node = current_node; + current_node = self.parent(current_node); + if self.left_child(current_node) != old_node { + // (excess of node + number of bits for node)/2 = number of 1-bits for node + rank += (self.heap[self.left_child(current_node)].excess + + (multiplier * self.block_size) as i64) / 2; + } + multiplier *= 2; + } + + Ok(rank as u64) + } + } + + pub fn rank_0(&self, index: u64) -> Result<u64, NodeError> { + let result = (index - self.rank_1(index).unwrap()) as i64; + if result < 0 { + return Err(NodeError::NotANodeError); + } + Ok(index - self.rank_1(index).unwrap() + 1) + } } #[derive(Clone, Debug, Default, Serialize, Deserialize)] @@ -487,4 +529,28 @@ mod tests { assert_eq!(min_max.enclose(6).unwrap(), 1); } + #[test] + // #[ignore] + fn test_rank_1() { + let bits1 = bit_vec![ + true, true, true, false, true, false, true, true, false, false, false, true, false, + true, true, true, false, true, false, false, false, false + ]; + let min_max = MinMax::new(bits1, 4); + assert_eq!(min_max.rank_1(11).unwrap(), 7); + assert_eq!(min_max.rank_1(21).unwrap(), 11); + } + + #[test] + fn test_rank_0() { + let bits1 = bit_vec![ + true, true, true, false, true, false, true, true, false, false, false, true, false, + true, true, true, false, true, false, false, false, false + ]; + let min_max = MinMax::new(bits1, 4); + assert_eq!(min_max.rank_0(12).unwrap(), 6); + assert_eq!(min_max.rank_0(17).unwrap(), 7); + assert_eq!(min_max.rank_0(21).unwrap(), 11); + } + }