1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
#include <algorithm>
#include <iostream>
int idx = 0;   // highest node index handed out so far; node 0 is the null node
int root = 0;  // index of the current tree root (0 = empty tree)
int n;         // upper bound used by output()'s default filter; nothing in this file assigns it, so it stays 0
// Splay-tree node kept in the global pool tr[]. Index 0 serves as the null
// node: a child or parent value of 0 means "none".
struct node {
  int s[2];  // children: s[0] = left, s[1] = right
  int v;     // key
  int cnt;   // number of copies of v held by this node
  int size;  // total copies in this subtree (maintained by push_up)
  int p;     // parent index

  // Sets key, parent, and count/size = 1. Children are not touched; they are
  // zero for pool slots that have never been used.
  void init(int value, int parent) {
    v = value;
    p = parent;
    cnt = size = 1;
  }
} tr[100010];
// Recomputes tr[k].size as the node's own multiplicity plus the sizes of its
// two subtrees. A missing child is index 0, whose size is 0 (push_up is only
// ever called on real nodes in this file, so tr[0].size is never written).
void push_up(int k) {
  const int lc = tr[k].s[0];
  const int rc = tr[k].s[1];
  tr[k].size = tr[k].cnt + tr[lc].size + tr[rc].size;
}
// Rotates x one level up, above its parent, keeping the in-order sequence
// intact. The parent becomes x's child on the side opposite to where x was,
// and x's inner subtree is handed over to the parent. Sizes are then
// recomputed bottom-up (parent first, then x).
// Note: when the parent is the root, its parent index is 0, so the
// "grandparent" writes land in the null node tr[0].
void rotate(int x) {
  const int parent = tr[x].p;
  const int grand = tr[parent].p;
  const int side = tr[parent].s[1] == x;     // 1 if x is a right child
  const int inner = tr[x].s[side ^ 1];       // subtree that changes owner

  // x takes the parent's slot under the grandparent.
  tr[grand].s[tr[grand].s[1] == parent] = x;
  tr[x].p = grand;

  // The inner subtree moves under the old parent.
  tr[parent].s[side] = inner;
  tr[inner].p = parent;

  // The old parent hangs below x on the opposite side.
  tr[x].s[side ^ 1] = parent;
  tr[parent].p = x;

  push_up(parent);
  push_up(x);
}
// Splays node x upward until its parent is k; k == 0 means "make x the root",
// in which case the global root is updated.
//
// Each double step is chosen by comparing the sides on which x hangs from its
// parent y and y hangs from its grandparent z:
//  * different sides (zig-zag): rotate x twice;
//  * same side (zig-zig): rotate y first, then x.
// This is the ordering that gives splay trees their amortized O(log n) bound.
//
// Bug fix: the original compared tr[x].s[0] == y. That is never true, because
// y is x's parent, so the choice depended only on x's side. Half of the
// double-rotation cases were then done in the wrong order, which kept the tree
// a valid BST but voided the amortized bound.
void splay(int x, int k) {
  while (tr[x].p != k) {
    int y = tr[x].p, z = tr[y].p;
    if (z != k) {
      if ((tr[y].s[0] == x) != (tr[z].s[0] == y))
        rotate(x);  // zig-zag
      else
        rotate(y);  // zig-zig
    }
    rotate(x);
  }
  if (k == 0) root = x;
}
// Inserts one copy of v. If a node with key v already exists, its multiplicity
// is bumped. Otherwise a new node (index ++idx) is attached where the search
// fell off the tree. In both cases the node is then splayed to the root, which
// also refreshes the sizes along the path.
// NOTE(review): there is no capacity check against the size of tr[].
void insert(int v) {
  int cur = root;
  int parent = 0;
  for (; cur != 0 && tr[cur].v != v; cur = tr[cur].s[v > tr[cur].v]) parent = cur;

  if (cur != 0) {
    ++tr[cur].cnt;  // duplicate key: count it on the existing node
  } else {
    cur = ++idx;
    tr[parent].s[v > tr[parent].v] = cur;
    tr[cur].init(v, parent);
  }
  splay(cur, 0);
}
int get_val(int k)//查询排名k的节点 { int x = root; while (1) { int y = tr[x].s[0]; if (tr[y].size + tr[x].cnt < k) { k -= tr[y].size + tr[x].cnt; x = tr[x].s[1]; } else { if (tr[y].size >= k)x = y; else break; } } splay(x, 0); return tr[x].v; }
void find(int v) {//查找v所在的节点 并把节点转到根 int x = root; while (tr[x].s[v > tr[x].v] && v != tr[x].v) x = tr[x].s[v > tr[x].v]; splay(x, 0); }
// Returns the rank of v, defined as 1 + (number of stored copies < v), and
// leaves a neighbour of v at the root.
// main() stores a low sentinel that is always in the root's left subtree, so
// the left-subtree size already includes the "+1".
//
// Fix: when v is not stored, find() can stop at v's predecessor. All copies of
// that node are also < v, but the original returned only the left-subtree
// size, under-counting the rank by the predecessor's multiplicity.
int get_rank(int v) {
  find(v);
  int rank = tr[tr[root].s[0]].size;
  if (tr[root].v < v) rank += tr[root].cnt;  // root is v's predecessor
  return rank;
}
int get_pre(int v) {//获取v的前驱 find(v); int x = root; if (tr[x].v < v)return x; x = tr[x].s[0]; while (tr[x].s[1])x = tr[x].s[1]; return x; }
int get_suc(int v) {//获取x的后继 find(v); int x = root; if (tr[x].v > v) return x; x = tr[x].s[1]; while (tr[x].s[0])x = tr[x].s[0]; return x; }
void del(int v) {//删除数字v若有多个相同只删除一个 int pre = get_pre(v); int suc = get_suc(v); splay(pre, 0), splay(suc, pre); int del = tr[suc].s[0]; if (tr[del].cnt > 1) tr[del].cnt--, splay(del, 0); else tr[suc].s[0] = 0, splay(suc, 0);
}
void output(int k) {//中序遍历输出 if (tr[k].s[0]) output(tr[k].s[0]); if (tr[k].v <= n && tr[k].v >= 1) std::cout << tr[k].v << " "; if (tr[k].s[1]) output(tr[k].s[1]); }
// Reads the number of operations q, then q pairs "op x":
//   1: insert x                  2: delete one copy of x
//   3: print rank of x           4: print the x-th smallest value
//   5: print predecessor of x    any other op: print successor of x
// Two sentinels (+/-100000000) are inserted first. They give get_pre/get_suc
// and del() a neighbour on each side, and op 4 passes x + 1 to skip the low
// sentinel.
// NOTE(review): assumes every x lies strictly between the sentinels; confirm
// against the input constraints.
signed main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int q = 0;  // operation count (was a local `n` that shadowed the global n)
  std::cin >> q;

  insert(100000000);
  insert(-100000000);

  while (q--) {
    int op = 0, x = 0;
    // Stop on truncated or malformed input. Otherwise the zeroed op/x would
    // fall into the successor branch and print a bogus line.
    if (!(std::cin >> op >> x)) break;
    switch (op) {
      case 1: insert(x); break;
      case 2: del(x); break;
      case 3: std::cout << get_rank(x) << "\n"; break;
      case 4: std::cout << get_val(x + 1) << "\n"; break;
      case 5: std::cout << tr[get_pre(x)].v << "\n"; break;
      default: std::cout << tr[get_suc(x)].v << "\n";
    }
  }
}
|