异或¶
题意¶
给定一个序列 a_1, a_2, \ldots, a_n 和一个数 x。求序列有多少个非空子集,满足其中任意两个数的异或和大于等于 x。
解析¶
建立序列的 01-Trie 树,从高位到低位考虑。设 dp_i 表示 01-Trie 树上,只选择 i 号节点的子树中的数时的方案数。
对于任意节点,左子树中的任何数和右子树中的任何数的异或和的这一位一定为 1。若 x 的这一位为 0,则左右子树的任意组合都是满足条件,答案为左右子树答案的积。
dp_i=(dp_{\text{left}}+1)(dp_\text{right}+1)-1
如果 x 的这一位为 1,则相同子树中的任何数的异或和的这一位一定为 0,小于 x。因此我们只能从左右子树中分别选至多 1 个数。直接查找即可。
实现¶
#include <bits/stdc++.h>
template <typename Tp>
void read(Tp &res){
static char ch; ch = getchar(), res = 0;
while(!std::isdigit(ch)) ch = getchar();
while(std::isdigit(ch)) res = res * 10 + ch - 48, ch = getchar();
}
typedef long long int ll;
const int maxn = 3e5 + 19, maxsize = maxn * 60, mod = 998244353;
int T, n;
ll a[maxn], x;
int son[maxsize][2], size[maxsize], ind = 1;
void ins(ll x){
int node = 1;
++size[node];
for(int i = 59; i >= 0; --i){
int &next = son[node][bool(x & (1ull << i))];
if(!next) next = ++ind;
node = next;
++size[node];
}
}
ll st[maxn]; int top;
int szst[maxn];
void put_to_stack(int node, ll val, int b){
if(!node) return;
if(b == -1){
st[++top] = val;
szst[top] = size[node];
return;
}
put_to_stack(son[node][1], val ^ (1ull << b), b - 1);
put_to_stack(son[node][0], val, b - 1);
}
int query(int node, ll val, int b){
if(!node) return 0;
if(b == -1) return size[node];
if(x & (1ull << b)) return query(son[node][1 ^ bool(val & (1ull << b))], val, b - 1);
return query(son[node][bool(val & (1ull << b))], val, b - 1) + size[son[node][1 ^ bool(val & (1ull << b))]];
}
int dp[maxsize];
void dfs(int node, int b){
if(!node) return;
if(b == -1){
dp[node] = size[node];
return;
}
if(!(x & (1ull << b))){
dfs(son[node][0], b - 1);
dfs(son[node][1], b - 1);
dp[node] = ((ll)(dp[son[node][0]] + 1) * (dp[son[node][1]] + 1) - 1) % mod;
return;
}
dp[node] = size[node];
put_to_stack(son[node][0], 0ll, b - 1);
while(top){
dp[node] = (dp[node] + (ll)szst[top] * query(son[node][1], st[top], b - 1)) % mod;
--top;
}
}
int main(){
read(T), read(n), read(x);
for(int i = 1; i <= n; ++i) read(a[i]);
if(x == 0){
int res = 1;
for(int i = 1; i <= n; ++i) res = (res * 2) % mod;
res = (res + mod - 1) % mod;
std::printf("%d\n", res);
return 0;
}
for(int i = 1; i <= n; ++i) ins(a[i]);
dfs(1, 59);
std::printf("%d\n", (dp[1] + mod) % mod);
}