Untitled
unknown
plain_text
a year ago
12 kB
55
Indexable
const std = @import("std");
const assert = std.debug.assert;
/// Returns an AVL tree (height-balanced binary search tree) type that maps
/// keys of type `K` to values of type `V`.
///
/// Keys are ordered with `std.mem.order(u8, ...)`, so `K` must be a byte-slice
/// type such as `[]const u8`. Only the slice itself is stored in a node; the
/// bytes it points at are not copied, so they must outlive the entry.
///
/// Node allocation failure is treated as fatal (`@panic`) because `insert`
/// has no error channel in its return type.
pub fn Tree(comptime K: type, comptime V: type) type {
    return struct {
        /// Root node, or `null` while the tree is empty.
        root: ?*Node,
        /// Allocator used to create the nodes of this tree.
        allocator: std.mem.Allocator,

        const Self = @This();

        /// Creates an empty tree whose nodes will be allocated from `allocator`.
        pub fn init(allocator: std.mem.Allocator) Self {
            return Self{
                .root = null,
                .allocator = allocator,
            };
        }

        /// Frees every node still stored in the tree and leaves it empty.
        pub fn deinit(self: *Self) void {
            if (self.root) |root| {
                root.deinit();
            }
            // Drop the dangling pointer so a second `deinit` (or later use) is harmless.
            self.root = null;
        }

        /// Looks up `key` and returns a pointer to its value, or `null` if the
        /// key is absent. The pointer refers to storage inside the node, so it
        /// is invalidated when that entry is removed or the tree is deinitialized.
        pub fn get(self: *const Self, key: K) ?*const V {
            if (self.root) |root| {
                return root.get(key);
            }
            return null;
        }

        /// Stores `value` under `key`.
        /// Returns the value previously stored under `key`, or `null` if the
        /// key was not present before the call.
        pub fn insert(self: *Self, key: K, value: V) ?V {
            if (self.root) |root| {
                const res = root.insert(key, value, self.allocator);
                self.root = res.root;
                return res.older_value;
            }
            self.root = Node.init(self.allocator, key, value);
            return null;
        }

        /// Removes `key` from the tree.
        /// Returns the value that was stored under `key`, or `null` if the key
        /// was not present.
        pub fn remove(self: *Self, key: K) ?V {
            if (self.root) |root| {
                const res = root.remove(key);
                self.root = res.root;
                return res.older_value;
            }
            return null;
        }

        /// Three-way comparison of two keys as byte strings; `.lt` means `a`
        /// sorts before `b`.
        fn order_keys(a: K, b: K) std.math.Order {
            return std.mem.order(u8, a, b);
        }

        /// Shape of the imbalance found at a node, named after the path to the
        /// taller grandchild subtree (`Balance` means no rotation is needed).
        const BalanceState = enum {
            LL,
            LR,
            RR,
            RL,
            Balance,
        };

        const Node = struct {
            key: K,
            value: V,
            /// Height of the subtree rooted here; a leaf has height 1.
            height: i32,
            left: ?*Node,
            right: ?*Node,
            /// Allocator this node was created with; used to free it.
            allocator: std.mem.Allocator,

            const InsertResult = struct { older_value: ?V, root: *Node };
            const RemoveResult = struct { older_value: ?V, root: ?*Node };
            const SplitResult = struct { minimal: *Node, root: ?*Node };

            /// Allocates a leaf node holding `key`/`value`.
            /// Panics on allocation failure: the public `insert` cannot report
            /// errors, and a defined crash is safer than `unreachable`.
            fn init(allocator: std.mem.Allocator, key: K, value: V) *Node {
                const node = allocator.create(Node) catch @panic("Tree: out of memory while allocating a node");
                node.* = Node{
                    .key = key,
                    .value = value,
                    .height = 1,
                    .left = null,
                    .right = null,
                    .allocator = allocator,
                };
                return node;
            }

            /// Frees this node and, recursively, both of its subtrees.
            fn deinit(node: *Node) void {
                if (node.left) |left| {
                    left.deinit();
                }
                if (node.right) |right| {
                    right.deinit();
                }
                node.allocator.destroy(node);
            }

            /// Searches the subtree rooted at `node` for `key`.
            fn get(node: *Node, key: K) ?*const V {
                var current: ?*Node = node;
                while (current) |n| {
                    switch (order_keys(n.key, key)) {
                        .eq => return &n.value,
                        // n.key < key: the match, if any, is in the right subtree.
                        .lt => current = n.right,
                        .gt => current = n.left,
                    }
                }
                return null;
            }

            /// Inserts or replaces `key` in the subtree rooted at `node`.
            /// Returns the replaced value (if any) and the new root of this
            /// subtree, which may differ from `node` after a rotation.
            fn insert(node: *Node, key: K, value: V, allocator: std.mem.Allocator) InsertResult {
                var older_value: ?V = null;
                switch (order_keys(node.key, key)) {
                    .eq => {
                        // Hand back the old value and actually store the new one.
                        older_value = node.value;
                        node.value = value;
                    },
                    .lt => {
                        if (node.right) |child| {
                            const res = child.insert(key, value, allocator);
                            older_value = res.older_value;
                            node.right = res.root;
                        } else {
                            node.right = Node.init(allocator, key, value);
                        }
                    },
                    .gt => {
                        if (node.left) |child| {
                            const res = child.insert(key, value, allocator);
                            older_value = res.older_value;
                            node.left = res.root;
                        } else {
                            node.left = Node.init(allocator, key, value);
                        }
                    },
                }
                node.update_height();
                return .{ .older_value = older_value, .root = rotate_if_not_balance(node) };
            }

            /// Removes `key` from the subtree rooted at `node` and rebalances
            /// the resulting subtree root. `root` is `null` when the subtree
            /// became empty.
            fn remove(node: *Node, key: K) RemoveResult {
                const res = node.remove_internal(key);
                if (res.root) |root| {
                    root.update_height();
                    return .{ .older_value = res.older_value, .root = rotate_if_not_balance(root) };
                }
                return .{ .older_value = res.older_value, .root = null };
            }

            /// Performs the removal itself; the returned root still has to be
            /// re-measured and rebalanced by `remove`.
            fn remove_internal(node: *Node, key: K) RemoveResult {
                switch (order_keys(node.key, key)) {
                    .eq => {
                        // Copy out everything we need, then free ONLY this node.
                        // Its children stay alive: they become the replacement subtree.
                        const removed_value = node.value;
                        const left = node.left;
                        const right = node.right;
                        const allocator = node.allocator;
                        allocator.destroy(node);

                        const replacement: ?*Node = blk: {
                            const l = left orelse break :blk right;
                            const r = right orelse break :blk l;
                            break :blk merge_subtree(l, r);
                        };
                        return .{ .older_value = removed_value, .root = replacement };
                    },
                    .gt => {
                        if (node.left) |left| {
                            const res = left.remove(key);
                            node.left = res.root;
                            return .{ .older_value = res.older_value, .root = node };
                        }
                        return .{ .older_value = null, .root = node };
                    },
                    .lt => {
                        if (node.right) |right| {
                            const res = right.remove(key);
                            node.right = res.root;
                            return .{ .older_value = res.older_value, .root = node };
                        }
                        return .{ .older_value = null, .root = node };
                    },
                }
            }

            /// Applies the rotation matching the imbalance at `node`, if any,
            /// and returns the root of the resulting subtree.
            fn rotate_if_not_balance(node: *Node) *Node {
                return switch (node.get_balance_state()) {
                    .LL => node.rotate_right(),
                    .LR => node.rotate_left_right(),
                    .RR => node.rotate_left(),
                    .RL => node.rotate_right_left(),
                    .Balance => node,
                };
            }

            /// Classifies the imbalance at `node` from the stored child heights.
            fn get_balance_state(node: *Node) BalanceState {
                const diff = get_height(node.left) - get_height(node.right);
                if (diff >= 2) {
                    const l_node = node.left.?;
                    // `>=`: after a removal both grandchildren can have equal
                    // height, and that case needs the single rotation.
                    if (get_height(l_node.left) >= get_height(l_node.right)) {
                        return BalanceState.LL;
                    }
                    return BalanceState.LR;
                }
                if (diff <= -2) {
                    const r_node = node.right.?;
                    if (get_height(r_node.right) >= get_height(r_node.left)) {
                        return BalanceState.RR;
                    }
                    return BalanceState.RL;
                }
                return BalanceState.Balance;
            }

            /// Double rotation: rotate the left child left, then `node` right.
            fn rotate_left_right(node: *Node) *Node {
                const left = node.left.?;
                node.left = rotate_left(left);
                node.update_height();
                return rotate_right(node);
            }

            /// Double rotation: rotate the right child right, then `node` left.
            fn rotate_right_left(node: *Node) *Node {
                const right = node.right.?;
                node.right = rotate_right(right);
                node.update_height();
                return rotate_left(node);
            }

            /// Makes the right child the subtree root; `node` becomes its left child.
            fn rotate_left(node: *Node) *Node {
                const root = node.right.?;
                node.right = root.left;
                root.left = node;
                // `node` is now below `root`, so its height must be refreshed first.
                node.update_height();
                root.update_height();
                return root;
            }

            /// Makes the left child the subtree root; `node` becomes its right child.
            fn rotate_right(node: *Node) *Node {
                const root = node.left.?;
                node.left = root.right;
                root.right = node;
                node.update_height();
                root.update_height();
                return root;
            }

            /// Recomputes `height` from the stored heights of both children.
            fn update_height(node: *Node) void {
                const l = get_height(node.left);
                const r = get_height(node.right);
                node.height = (if (l > r) l else r) + 1;
            }

            /// Height of an optional subtree; an empty subtree has height 0.
            fn get_height(node: ?*Node) i32 {
                if (node) |n| {
                    return n.height;
                }
                return 0;
            }

            /// Joins two subtrees where every key in `left` sorts before every
            /// key in `right`: the smallest node of `right` is detached and
            /// becomes the new root. The caller rebalances the returned root.
            fn merge_subtree(left: *Node, right: *Node) *Node {
                const res = right.split_minimal_node();
                const minimal = res.minimal;
                minimal.left = left;
                minimal.right = res.root;
                minimal.update_height();
                return minimal;
            }

            /// Detaches the smallest node of the subtree rooted at `node`.
            /// Returns that node and the (rebalanced) remainder of the subtree.
            fn split_minimal_node(node: *Node) SplitResult {
                if (node.left) |left| {
                    const res = left.split_minimal_node();
                    node.left = res.root;
                    node.update_height();
                    // Losing a node on the left can make this node right-heavy.
                    return .{ .minimal = res.minimal, .root = rotate_if_not_balance(node) };
                }
                // The minimal node has no left child but may own a right
                // subtree, which must take its place rather than be dropped.
                return .{ .minimal = node, .root = node.right };
            }
        };
    };
}
// Inserting a new key yields null; inserting an existing key yields the value
// that was stored under it before the call.
test "insert" {
    var avltree = Tree([]const u8, i32).init(std.testing.allocator);
    defer avltree.deinit();
    // `std.testing.expect*` reports a test failure; `std.debug.assert` would
    // panic the whole test runner and is unchecked in ReleaseFast.
    try std.testing.expectEqual(@as(?i32, null), avltree.insert("A", 1));
    try std.testing.expectEqual(@as(?i32, null), avltree.insert("B", 1));
    try std.testing.expectEqual(@as(?i32, 1), avltree.insert("A", 2));
    try std.testing.expectEqual(@as(?i32, 1), avltree.insert("B", 2));
}
// Lookups return null for absent keys and the stored value for present ones.
test "get" {
    var avl = Tree([]const u8, i32).init(std.testing.allocator);
    defer avl.deinit();
    // Lookup on an empty tree.
    try std.testing.expect(avl.get("A") == null);
    try std.testing.expectEqual(@as(?i32, null), avl.insert("A", 1));
    try std.testing.expectEqual(@as(?i32, null), avl.insert("B", 2));
    try std.testing.expectEqual(@as(?i32, null), avl.insert("C", 3));
    try std.testing.expectEqual(@as(?i32, null), avl.insert("D", 4));
    try std.testing.expectEqual(@as(i32, 1), avl.get("A").?.*);
    try std.testing.expectEqual(@as(i32, 2), avl.get("B").?.*);
    try std.testing.expectEqual(@as(i32, 3), avl.get("C").?.*);
    try std.testing.expectEqual(@as(i32, 4), avl.get("D").?.*);
    // A key that was never inserted is still absent.
    try std.testing.expect(avl.get("E") == null);
}
// Removing an absent key yields null; removing a present key yields its value.
// `std.testing.allocator` additionally fails the test if any node is leaked.
test "remove" {
    var avl = Tree([]const u8, i32).init(std.testing.allocator);
    defer avl.deinit();
    // Removal from an empty tree.
    try std.testing.expectEqual(@as(?i32, null), avl.remove("A"));
    try std.testing.expectEqual(@as(?i32, null), avl.insert("A", 1));
    try std.testing.expectEqual(@as(?i32, null), avl.insert("B", 2));
    try std.testing.expectEqual(@as(?i32, null), avl.insert("C", 3));
    try std.testing.expectEqual(@as(?i32, null), avl.insert("D", 4));
    try std.testing.expectEqual(@as(?i32, 1), avl.remove("A"));
    try std.testing.expectEqual(@as(?i32, 2), avl.remove("B"));
    try std.testing.expectEqual(@as(?i32, 3), avl.remove("C"));
    try std.testing.expectEqual(@as(?i32, 4), avl.remove("D"));
}
Editor is loading...
Leave a Comment