1/* Based on musl's src/search/tsearch.c, by Szabolcs Nagy.
2 * See LICENSE file for copyright details. */
3#include <stdlib.h>
4#include <string.h>
5#include "tree.h"
6#include "util.h"
7
/*
 * Capacity of the link stack in treeinsert(), which records one entry for
 * the root link plus one per level descended.  This is 1.5 times the number
 * of bits in a pointer (assuming 8-bit bytes).  balance() keeps the tree
 * AVL-balanced, and AVL height grows as roughly 1.44 * log2(nodes), so this
 * exceeds the height of any tree whose nodes fit in the address space.
 */
#define MAXH (sizeof(void *) * 8 * 3 / 2)
9
10void
11deltree(struct treenode *n, void delkey(void *), void delval(void *))
12{
13 if (!n)
14 return;
15 if (delkey)
16 delkey(n->key);
17 if (delval)
18 delval(n->value);
19 deltree(n->child[0], delkey, delval);
20 deltree(n->child[1], delkey, delval);
21 free(n);
22}
23
24static inline int
25height(struct treenode *n)
26{
27 return n ? n->height : 0;
28}
29
/*
 * Rebalance the subtree rooted at x, which is linked from *p, by rotating
 * toward dir, the side of x that is deeper (balance() calls this only when
 * x's children differ in height by more than one).  y is x's child on that
 * side and z is y's inner child.  When z is taller than y's outer child a
 * double rotation lifts z to the root; otherwise a single rotation lifts y.
 * The heights of the moved nodes are recomputed, the new subtree root is
 * stored in *p, and the return value is the new root's height minus x's
 * height before the rotation.
 */
static int
rot(struct treenode **p, struct treenode *x, int dir /* deeper side */)
{
	struct treenode *y = x->child[dir];
	struct treenode *z = y->child[!dir];
	int hx = x->height;	/* saved to report the height change */
	int hz = height(z);

	/* Inner grandchild taller than the outer one: double rotation. */
	if (hz > height(y->child[dir])) {
		/*
		 *   x
		 *  / \ dir          z
		 * A   y            / \
		 *    / \   -->    x   y
		 *   z   D        /|   |\
		 *  / \          A B   C D
		 * B   C
		 */
		x->child[dir] = z->child[!dir];
		y->child[!dir] = z->child[dir];
		z->child[!dir] = x;
		z->child[dir] = y;
		x->height = hz;
		y->height = hz;
		z->height = hz + 1;
	} else {
		/*
		 *   x               y
		 *  / \             / \
		 * A   y    -->    x   D
		 *    / \         / \
		 *   z   D       A   z
		 */
		x->child[dir] = z;
		y->child[!dir] = x;
		x->height = hz + 1;
		y->height = hz + 2;
		/* Reuse z to name the new subtree root for the common tail. */
		z = y;
	}
	*p = z;
	return z->height - hx;
}
72
73static int
74balance(struct treenode **p)
75{
76 struct treenode *n = *p;
77 int h0 = height(n->child[0]);
78 int h1 = height(n->child[1]);
79
80 if (h0 - h1 + 1u < 3u) {
81 int old = n->height;
82 n->height = h0 < h1 ? h1 + 1 : h0 + 1;
83 return n->height - old;
84 }
85 return rot(p, n, h0 < h1);
86}
87
88struct treenode *
89treefind(struct treenode *n, const char *key)
90{
91 int c;
92
93 while (n) {
94 c = strcmp(key, n->key);
95 if (c == 0)
96 return n;
97 n = n->child[c > 0];
98 }
99 return NULL;
100}
101
102void *
103treeinsert(struct treenode **rootp, char *key, void *value)
104{
105 struct treenode **a[MAXH], *n = *rootp, *r;
106 void *old;
107 int i = 0, c;
108
109 a[i++] = rootp;
110 while (n) {
111 c = strcmp(key, n->key);
112 if (c == 0) {
113 old = n->value;
114 n->value = value;
115 return old;
116 }
117 a[i++] = &n->child[c > 0];
118 n = n->child[c > 0];
119 }
120 r = xmalloc(sizeof(*r));
121 r->key = key;
122 r->value = value;
123 r->child[0] = r->child[1] = NULL;
124 r->height = 1;
125 /* insert new node, rebalance ancestors. */
126 *a[--i] = r;
127 while (i && balance(a[--i]))
128 ;
129 return NULL;
130}