Untitled
unknown
plain_text
6 months ago
12 kB
39
Indexable
const std = @import("std");
const assert = std.debug.assert;

/// Returns an AVL tree (self-balancing binary search tree) type mapping `K` to `V`.
///
/// Keys are ordered with `std.mem.order(u8, ...)` when `K` is `[]const u8` or
/// `[]u8`, and with `std.math.order` for any other `K` (e.g. integers). Keys are
/// stored as given and never copied: for slice keys the caller must keep the
/// backing memory alive while the entry is in the tree.
pub fn Tree(comptime K: type, comptime V: type) type {
    return struct {
        /// Root node, or null when the tree is empty.
        root: ?*Node,
        /// Allocator used for every node of this tree.
        allocator: std.mem.Allocator,

        const Self = @This();

        /// Creates an empty tree. No memory is allocated until the first `insert`.
        pub fn init(allocator: std.mem.Allocator) Self {
            return .{ .root = null, .allocator = allocator };
        }

        /// Frees every node. The tree must not be used afterwards.
        pub fn deinit(self: *Self) void {
            if (self.root) |root| root.deinit();
            self.* = undefined;
        }

        /// Returns a pointer to the value stored under `key`, or null if absent.
        /// The pointer stays valid until `key` is removed or the tree is
        /// deinitialized; a later `insert` of the same key overwrites the pointee.
        pub fn get(self: *const Self, key: K) ?*const V {
            var current = self.root;
            while (current) |node| {
                current = switch (compareKeys(node.key, key)) {
                    .eq => return &node.value,
                    .lt => node.right,
                    .gt => node.left,
                };
            }
            return null;
        }

        /// Inserts `key` -> `value`. If `key` is already present its value is
        /// replaced and the previous value is returned; otherwise returns null.
        /// Panics on allocation failure, because this signature cannot carry an error.
        pub fn insert(self: *Self, key: K, value: V) ?V {
            const root = self.root orelse {
                self.root = Node.init(self.allocator, key, value);
                return null;
            };
            const res = root.insert(self.allocator, key, value);
            self.root = res.root;
            return res.older_value;
        }

        /// Removes `key` and returns its value, or null if `key` was not present.
        pub fn remove(self: *Self, key: K) ?V {
            const root = self.root orelse return null;
            const res = root.remove(key);
            self.root = res.root;
            return res.older_value;
        }

        /// Three-way key comparison: byte-wise for byte slices, numeric otherwise.
        /// The condition is comptime-known, so only the matching branch is analyzed.
        fn compareKeys(a: K, b: K) std.math.Order {
            return if (K == []const u8 or K == []u8)
                std.mem.order(u8, a, b)
            else
                std.math.order(a, b);
        }

        /// Which rotation (if any) restores the AVL invariant at a node.
        const BalanceState = enum { LL, LR, RR, RL, Balance };

        const Node = struct {
            key: K,
            value: V,
            /// Height of the subtree rooted here; a leaf has height 1.
            height: i32,
            left: ?*Node,
            right: ?*Node,
            allocator: std.mem.Allocator,

            const InsertResult = struct { older_value: ?V, root: *Node };
            const RemoveResult = struct { older_value: ?V, root: ?*Node };
            const SplitResult = struct { minimal: *Node, root: ?*Node };

            fn init(allocator: std.mem.Allocator, key: K, value: V) *Node {
                // `catch unreachable` would be undefined behavior in ReleaseFast;
                // panic explicitly instead since callers cannot receive an error.
                const node = allocator.create(Node) catch @panic("avl tree: out of memory");
                node.* = .{
                    .key = key,
                    .value = value,
                    .height = 1,
                    .left = null,
                    .right = null,
                    .allocator = allocator,
                };
                return node;
            }

            /// Recursively frees this node and its whole subtree.
            fn deinit(node: *Node) void {
                if (node.left) |left| left.deinit();
                if (node.right) |right| right.deinit();
                node.allocator.destroy(node);
            }

            /// Inserts into the subtree rooted at `node`; returns its new (rebalanced) root.
            fn insert(node: *Node, allocator: std.mem.Allocator, key: K, value: V) InsertResult {
                var older_value: ?V = null;
                switch (compareKeys(node.key, key)) {
                    .eq => {
                        // Existing key: hand back the old value and store the new one.
                        older_value = node.value;
                        node.value = value;
                    },
                    .lt => {
                        const res = insertInto(node.right, allocator, key, value);
                        node.right = res.root;
                        older_value = res.older_value;
                    },
                    .gt => {
                        const res = insertInto(node.left, allocator, key, value);
                        node.left = res.root;
                        older_value = res.older_value;
                    },
                }
                node.updateHeight();
                return .{ .older_value = older_value, .root = node.rebalance() };
            }

            /// Inserts into an optional child, creating a leaf when it is empty.
            fn insertInto(child: ?*Node, allocator: std.mem.Allocator, key: K, value: V) InsertResult {
                if (child) |c| return c.insert(allocator, key, value);
                return .{ .older_value = null, .root = Node.init(allocator, key, value) };
            }

            /// Removes `key` from the subtree rooted at `node`; returns the new
            /// (rebalanced) root, which is null when the subtree becomes empty.
            fn remove(node: *Node, key: K) RemoveResult {
                const res = node.removeInternal(key);
                const root = res.root orelse return res;
                root.updateHeight();
                return .{ .older_value = res.older_value, .root = root.rebalance() };
            }

            /// Unlinks `key` below `node` without rebalancing `node` itself.
            fn removeInternal(node: *Node, key: K) RemoveResult {
                switch (compareKeys(node.key, key)) {
                    .eq => {
                        const older_value = node.value;
                        const replacement: ?*Node = blk: {
                            const left = node.left orelse break :blk node.right;
                            const right = node.right orelse break :blk left;
                            break :blk mergeSubtrees(left, right);
                        };
                        // Free only this node: its children now live under `replacement`.
                        // (`Node.deinit` would recursively free them as well.)
                        node.allocator.destroy(node);
                        return .{ .older_value = older_value, .root = replacement };
                    },
                    .gt => {
                        const left = node.left orelse return .{ .older_value = null, .root = node };
                        const res = left.remove(key);
                        node.left = res.root;
                        return .{ .older_value = res.older_value, .root = node };
                    },
                    .lt => {
                        const right = node.right orelse return .{ .older_value = null, .root = node };
                        const res = right.remove(key);
                        node.right = res.root;
                        return .{ .older_value = res.older_value, .root = node };
                    },
                }
            }

            /// Applies the rotation needed at `node` (if any) and returns the subtree root.
            fn rebalance(node: *Node) *Node {
                return switch (node.balanceState()) {
                    .LL => node.rotateRight(),
                    .LR => node.rotateLeftRight(),
                    .RR => node.rotateLeft(),
                    .RL => node.rotateRightLeft(),
                    .Balance => node,
                };
            }

            fn balanceState(node: *Node) BalanceState {
                const diff = heightOf(node.left) - heightOf(node.right);
                if (diff >= 2) {
                    const l = node.left.?;
                    // `>=`: after a removal the heavy child may itself be balanced,
                    // and then only a single rotation restores the invariant.
                    return if (heightOf(l.left) >= heightOf(l.right)) .LL else .LR;
                }
                if (diff <= -2) {
                    const r = node.right.?;
                    return if (heightOf(r.right) >= heightOf(r.left)) .RR else .RL;
                }
                return .Balance;
            }

            fn rotateLeftRight(node: *Node) *Node {
                node.left = node.left.?.rotateLeft();
                return node.rotateRight();
            }

            fn rotateRightLeft(node: *Node) *Node {
                node.right = node.right.?.rotateRight();
                return node.rotateLeft();
            }

            fn rotateLeft(node: *Node) *Node {
                const pivot = node.right.?;
                node.right = pivot.left;
                pivot.left = node;
                node.updateHeight();
                pivot.updateHeight();
                return pivot;
            }

            fn rotateRight(node: *Node) *Node {
                const pivot = node.left.?;
                node.left = pivot.right;
                pivot.right = node;
                node.updateHeight();
                pivot.updateHeight();
                return pivot;
            }

            /// Recomputes `height` from the (already correct) child heights.
            fn updateHeight(node: *Node) void {
                const l = heightOf(node.left);
                const r = heightOf(node.right);
                node.height = 1 + (if (l > r) l else r);
            }

            fn heightOf(node: ?*Node) i32 {
                return if (node) |n| n.height else 0;
            }

            /// Joins two non-empty subtrees (every key in `left` < every key in
            /// `right`) using the in-order successor (minimum of `right`) as the root.
            fn mergeSubtrees(left: *Node, right: *Node) *Node {
                const split = right.splitMinimal();
                const successor = split.minimal;
                successor.left = left;
                successor.right = split.root;
                successor.updateHeight();
                return successor;
            }

            /// Detaches the minimum node of this subtree. Returns it together with
            /// the remaining, rebalanced subtree (its right child takes its place).
            fn splitMinimal(node: *Node) SplitResult {
                const left = node.left orelse return .{ .minimal = node, .root = node.right };
                const res = left.splitMinimal();
                node.left = res.root;
                node.updateHeight();
                return .{ .minimal = res.minimal, .root = node.rebalance() };
            }
        };
    };
}

/// Test helper: checks the AVL height/balance invariants below `node`
/// (a `?*Node` of some `Tree`) and returns the subtree height.
fn expectAvlInvariants(node: anytype) error{TestUnexpectedResult}!i32 {
    const n = node orelse return 0;
    const left_height = try expectAvlInvariants(n.left);
    const right_height = try expectAvlInvariants(n.right);
    const expected: i32 = 1 + (if (left_height > right_height) left_height else right_height);
    try std.testing.expect(n.height == expected);
    try std.testing.expect(left_height - right_height <= 1 and right_height - left_height <= 1);
    return n.height;
}

test "insert" {
    var avltree = Tree([]const u8, i32).init(std.testing.allocator);
    defer avltree.deinit();
    try std.testing.expect(avltree.insert("A", 1) == null);
    try std.testing.expect(avltree.insert("B", 1) == null);
    try std.testing.expect(avltree.insert("A", 2).? == 1);
    try std.testing.expect(avltree.insert("B", 2).? == 1);
}

test "insert replaces existing value" {
    var avl = Tree([]const u8, i32).init(std.testing.allocator);
    defer avl.deinit();
    try std.testing.expect(avl.insert("A", 1) == null);
    try std.testing.expect(avl.insert("A", 2).? == 1);
    try std.testing.expect(avl.get("A").?.* == 2);
}

test "get" {
    var avl = Tree([]const u8, i32).init(std.testing.allocator);
    defer avl.deinit();
    try std.testing.expect(avl.get("A") == null);
    try std.testing.expect(avl.insert("A", 1) == null);
    try std.testing.expect(avl.insert("B", 2) == null);
    try std.testing.expect(avl.insert("C", 3) == null);
    try std.testing.expect(avl.insert("D", 4) == null);
    try std.testing.expect(avl.get("A").?.* == 1);
    try std.testing.expect(avl.get("B").?.* == 2);
    try std.testing.expect(avl.get("C").?.* == 3);
    try std.testing.expect(avl.get("D").?.* == 4);
}

test "remove" {
    var avl = Tree([]const u8, i32).init(std.testing.allocator);
    defer avl.deinit();
    try std.testing.expect(avl.remove("A") == null);
    try std.testing.expect(avl.insert("A", 1) == null);
    try std.testing.expect(avl.insert("B", 2) == null);
    try std.testing.expect(avl.insert("C", 3) == null);
    try std.testing.expect(avl.insert("D", 4) == null);
    try std.testing.expect(avl.remove("A").? == 1);
    try std.testing.expect(avl.remove("B").? == 2);
    try std.testing.expect(avl.remove("C").? == 3);
    try std.testing.expect(avl.remove("D").? == 4);
    try std.testing.expect(avl.root == null);
}

test "remove keeps the children of the removed node" {
    var avl = Tree([]const u8, i32).init(std.testing.allocator);
    defer avl.deinit();
    try std.testing.expect(avl.insert("A", 1) == null);
    try std.testing.expect(avl.insert("B", 2) == null);
    try std.testing.expect(avl.insert("C", 3) == null);
    try std.testing.expect(avl.insert("D", 4) == null);
    // "C" has a single child "D".
    try std.testing.expect(avl.remove("C").? == 3);
    try std.testing.expect(avl.get("D").?.* == 4);
    // "B" (the root) has two children.
    try std.testing.expect(avl.remove("B").? == 2);
    try std.testing.expect(avl.get("A").?.* == 1);
    try std.testing.expect(avl.get("D").?.* == 4);
    try std.testing.expect(avl.get("B") == null);
}

test "integer keys stay balanced through inserts and removes" {
    var avl = Tree(u32, u32).init(std.testing.allocator);
    defer avl.deinit();
    var i: u32 = 0;
    while (i < 200) : (i += 1) {
        // 37 is coprime with 200, so this visits every key exactly once.
        const key = (i * 37) % 200;
        try std.testing.expect(avl.insert(key, key * 10) == null);
        _ = try expectAvlInvariants(avl.root);
    }
    i = 0;
    while (i < 200) : (i += 3) {
        try std.testing.expect(avl.remove(i).? == i * 10);
        _ = try expectAvlInvariants(avl.root);
    }
    i = 0;
    while (i < 200) : (i += 1) {
        if (i % 3 == 0) {
            try std.testing.expect(avl.get(i) == null);
        } else {
            try std.testing.expect(avl.get(i).?.* == i * 10);
        }
    }
}
Editor is loading...
Leave a Comment