我不会说我是为了学博弈才去做这题的……
题意:
给定一棵N个点的树,1号点为根,每个节点是白色或者黑色。双方轮流操作,每次选择一个白色节点,将从这个点到根的路径上的点全部染成黑色。问先手是否必胜,以及第一步可选节点有哪些。N<=100000。
分析:
首先是博弈方面的分析。令SG[x]为,只考虑以x为根的子树时的SG值。令g[x]为,只考虑以x为根的子树时,所有后继局面的SG值的集合。那么SG[x]=mex{g[x]}。
我们考虑怎么计算g[x]。假设x的儿子为v1,v2,...,vk,令sum[x]=SG[v1] xor SG[v2] xor .. xor SG[vk]。考虑两种情况:
1、x为黑色。不难发现以x的每个儿子为根的子树是互相独立的。假设这一步选择了vi子树的某一个节点,那么转移到的局面的SG值就是sum[x] xor SG[vi] xor (在g[vi]中的某个值)。那么我们只需将每个g[vi]整体xor上sum[x] xor SG[vi]再合并到g[x]即可。
2、x为白色。这时候我们多了一种选择,即选择x点。可以发现,选择x点之后x点变成黑色,所有子树仍然独立,而转移到的局面的SG值就是sum[x]。如果此时不选择x而是选择x子树里的某个白色节点,那么x一样会被染成黑色,所有子树依然独立。所以x为白色时只是要向g[x]中多插入一个值sum[x]。
这样我们就有一个自底向上的DP了。朴素的复杂度是O(N^2)的。
接下来再考虑第一步可选的节点。我们要考虑选择哪些节点之后整个局面的SG值会变成0。假设我们选择了x点,那么从x到根的路径都会被染黑,将原来的树分成了一堆森林。我们令up[x]为,不考虑以x为根的子树,将从x到根的路径染黑,剩下的子树的SG值的xor和。那么up[x]=up[fa[x]] xor sum[fa[x]] xor sg[x],其中fa[x]为x的父亲节点编号。那么如果点x初始颜色为白色且up[x] xor sum[x]=0,那么这个点就是第一步可选的节点。这一步是O(N)的。
剩下的就是优化求SG了。我们需要一个可以快速整体xor并合并的数据结构。整体xor可以用二进制Trie打标记实现,至于合并,用启发式合并是O(Nlog^2N)的,而用线段树合并的方法可以做到O(NlogN)。不过还需要注意各种常数的问题……比如不要用指针,Trie的节点不用记大小,只要记是否满了……
做这题的时候先去膜拜了主席的题解……然后又去膜拜了主席冬令营的讲课……最后还去膜拜了翱犇的代码……然后几乎是照着抄了一遍……
代码:(SPOJ上排到了倒数第三……)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146//SPOJ11414; COT3; Game Theory + Trie Merging #include <cstdio> #include <iostream> #include <algorithm> #include <ctime> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef unsigned int uint; typedef long double ld; #define pair(x, y) make_pair(x, y) #define runtime() ((double)clock() / CLOCKS_PER_SEC) #define N 100000 #define LOG 17 struct edge { int next, node; } e[N << 1 | 1]; int head[N + 1], tot = 0; inline void addedge(int a, int b) { e[++tot].next = head[a]; head[a] = tot, e[tot].node = b; } #define SIZE 2000000 struct Node { int l, r; bool full; int d; } tree[SIZE + 1]; #define l(x) tree[x].l #define r(x) tree[x].r #define d(x) tree[x].d #define full(x) tree[x].full int root[N + 1], tcnt = 0; int n, col[N + 1], sg[N + 1], sum[N + 1], up[N + 1]; bool v[N + 1]; inline int newnode() { return ++tcnt; } inline void update(int x) { full(x) = full(l(x)) && full(r(x)); } inline void push(int x) { if (d(x)) { if (l(x)) d(l(x)) ^= d(x) >> 1; if (r(x)) d(r(x)) ^= d(x) >> 1; if (d(x) & 1) swap(l(x), r(x)); d(x) = 0; } } int merge(int l, int r) { if (!l || full(r)) return r; if (!r || full(l)) return l; push(l), push(r); int ret = newnode(); l(ret) = merge(l(l), l(r)); r(ret) = merge(r(l), r(r)); update(ret); return ret; } inline int rev(int x) { int r = 0; for (int i = LOG; i > 0; --i) if (x >> i - 1 & 1) r += 1 << LOG - i; return r; } void insert(int x, int v, int p) { push(x); if (v >> p - 1 & 1) { if (!r(x)) r(x) = newnode(); if (p != 1) insert(r(x), v, p - 1); else full(r(x)) = true; } else { if (!l(x)) l(x) = newnode(); if (p != 1) insert(l(x), v, p - 1); else full(l(x)) = true; } update(x); } int mex(int x) { int r = 0; for (int i = LOG; x; --i) { push(x); if (full(l(x))) r += 1 << i - 1, x = r(x); else x = l(x); } return r; } void calc(int x) { v[x] = true; int xorsum = 0; for (int i = head[x]; i; i = e[i].next) { int node = e[i].node; if (v[node]) continue; calc(node); v[node] = false; xorsum ^= sg[node]; } for (int i = head[x]; i; i = e[i].next) { int node = e[i].node; if (v[node]) continue; d(root[node]) ^= rev(xorsum ^ sg[node]); root[x] = merge(root[x], root[node]); } if (!col[x]) insert(root[x], xorsum, LOG); sg[x] = mex(root[x]); sum[x] = xorsum; } int ans[N + 1], cnt = 0; void find(int x) { v[x] = true; if ((up[x] ^ sum[x]) == 0 && col[x] == 0) ans[++cnt] = x; for (int i = head[x]; i; i = e[i].next) { int node = e[i].node; if (v[node]) continue; up[node] = up[x] ^ sum[x] ^ sg[node]; find(node); } } int main(int argc, char* argv[]) { #ifdef KANARI freopen("input.txt", "r", stdin); freopen("output.txt", "w", stdout); #endif scanf("%d", &n); for (int i = 1; i <= n; ++i) scanf("%d", col + i); for (int i = 1; i < n; ++i) { static int x, y; scanf("%d%d", &x, &y); addedge(x, y), addedge(y, x); } for (int i = 1; i <= n; ++i) root[i] = newnode(); calc(1); for (int i = 1; i <= n; ++i) v[i] = false; find(1); if (cnt == 0) printf("-1n"); else { sort(ans + 1, ans + cnt + 1); for (int i = 1; i <= cnt; ++i) printf("%dn", ans[i]); } // cerr << runtime() << endl; // for (int i = 1; i <= n; ++i) printf("%d ", sg[i]); fclose(stdin); fclose(stdout); return 0; }
顺便贴一个指针的,感觉长得更好看:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173// #include <cstdio> #include <cstdlib> #include <cstring> #include <iostream> #include <algorithm> #include <climits> #include <cmath> #include <utility> #include <set> #include <map> #include <queue> #include <ios> #include <iomanip> #include <ctime> #include <numeric> #include <functional> #include <fstream> #include <sstream> #include <string> #include <vector> #include <bitset> #include <cstdarg> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef unsigned int uint; typedef long double ld; #define pair(x, y) make_pair(x, y) #define runtime() ((double)clock() / CLOCKS_PER_SEC) inline int read() { static int r; static char c; r = 0, c = getchar(); while (c < '0' || c > '9') c = getchar(); while (c >= '0' && c <= '9') r = r * 10 + (c - '0'), c = getchar(); return r; } template <typename T> inline void print(T *a, int n) { for (int i = 1; i < n; ++i) cout << a[i] << " "; cout << a[n] << endl; } #define PRINT(__l, __r, __begin, __end) { for (int __i = __begin; __i != __end; ++__i) cout << __l __i __r << " "; cout << endl; } #define N 100000 #define LOG 17 struct edge { int next, node; } e[N << 1 | 1]; int head[N + 1], tot = 0; inline void addedge(int a, int b) { e[++tot].next = head[a]; head[a] = tot, e[tot].node = b; } struct Node { Node *l, *r; bool full; int d; Node() { l = r = NULL, full = false, d = 0; } } *root[N + 1]; int n, col[N + 1], sg[N + 1], sum[N + 1], up[N + 1]; bool v[N + 1]; inline void update(Node *x) { if (x->l && x->r) x->full = x->l->full && x->r->full; else x->full = false; } inline void applyDelta(Node *x, int v) { x->d ^= v; } inline void push(Node *x) { if (x->d) { if (x->l) applyDelta(x->l, x->d >> 1); if (x->r) applyDelta(x->r, x->d >> 1); if (x->d & 1) swap(x->l, x->r); x->d = 0; } } Node* merge(Node *l, Node *r) { if (l == NULL || (r != NULL && r->full)) return r; if (r == NULL || (l != NULL && l->full)) return l; push(l), push(r); Node *ret = new Node(); ret->l = merge(l->l, r->l); ret->r = merge(l->r, r->r); update(ret); return ret; } inline int rev(int x) { int r = 0; for (int i = LOG; i > 0; --i) if (x >> i - 1 & 1) r += 1 << LOG - i; return r; } void insert(Node *x, int v, int p) { push(x); if (v >> p - 1 & 1) { if (x->r == NULL) x->r = new Node(); if (p != 1) insert(x->r, v, p - 1); else x->r->full = true; } else { if (x->l == NULL) x->l = new Node(); if (p != 1) insert(x->l, v, p - 1); else x->l->full = true; } update(x); } int mex(Node *x) { int r = 0; for (int i = LOG; x != NULL; --i) { push(x); if (x->l && x->l->full) r += 1 << i - 1, x = x->r; else x = x->l; } return r; } void calc(int x) { v[x] = true; int xorsum = 0; for (int i = head[x]; i; i = e[i].next) { int node = e[i].node; if (v[node]) continue; calc(node); v[node] = false; xorsum ^= sg[node]; } for (int i = head[x]; i; i = e[i].next) { int node = e[i].node; if (v[node]) continue; applyDelta(root[node], rev(xorsum ^ sg[node])); root[x] = merge(root[x], root[node]); } if (!col[x]) insert(root[x], xorsum, LOG); sg[x] = mex(root[x]); sum[x] = xorsum; } int ans[N + 1], cnt = 0; void find(int x) { v[x] = true; if ((up[x] ^ sum[x]) == 0 && col[x] == 0) ans[++cnt] = x; for (int i = head[x]; i; i = e[i].next) { int node = e[i].node; if (v[node]) continue; up[node] = up[x] ^ sum[x] ^ sg[node]; find(node); } } int main(int argc, char* argv[]) { #ifdef KANARI freopen("input.txt", "r", stdin); freopen("output.txt", "w", stdout); #endif scanf("%d", &n); for (int i = 1; i <= n; ++i) scanf("%d", col + i); for (int i = 1; i < n; ++i) { static int x, y; scanf("%d%d", &x, &y); addedge(x, y), addedge(y, x); } for (int i = 1; i <= n; ++i) root[i] = new Node(); calc(1); for (int i = 1; i <= n; ++i) v[i] = false; find(1); if (cnt == 0) printf("-1n"); else { sort(ans + 1, ans + cnt + 1); for (int i = 1; i <= cnt; ++i) printf("%dn", ans[i]); } // for (int i = 1; i <= n; ++i) printf("%d ", sg[i]); fclose(stdin); fclose(stdout); return 0; }
最后
以上就是欢呼花生最近收集整理的关于SPOJ COT3的全部内容,更多相关SPOJ内容请搜索靠谱客的其他文章。
发表评论 取消回复