跳转至

异或

题意

给定一个序列 a_1, a_2, \ldots, a_n 和一个数 x。求序列有多少个非空子集,满足其中任意两个数的异或和大于等于 x

解析

建立序列的 01-Trie 树,从高位到低位考虑。设 dp_i 表示 01-Trie 树上,只选择 i 号节点的子树中的数时的方案数。

对于任意节点,左子树中的任何数和右子树中的任何数的异或和的这一位一定为 1。若 x 的这一位为 0,则左右子树的任意组合都是满足条件,答案为左右子树答案的积。

dp_i=(dp_{\text{left}}+1)(dp_\text{right}+1)-1

如果 x 的这一位为 1,则相同子树中的任何数的异或和的这一位一定为 0,小于 x。因此我们只能从左右子树中分别选至多 1 个数。直接查找即可。

实现

#include <bits/stdc++.h>

template <typename Tp>
void read(Tp &res){
    static char ch; ch = getchar(), res = 0;
    while(!std::isdigit(ch)) ch = getchar();
    while(std::isdigit(ch)) res = res * 10 + ch - 48, ch = getchar();
}

typedef long long int ll;
const int maxn = 3e5 + 19, maxsize = maxn * 60, mod = 998244353;

int T, n;
ll a[maxn], x;

int son[maxsize][2], size[maxsize], ind = 1;

void ins(ll x){
    int node = 1;
    ++size[node];
    for(int i = 59; i >= 0; --i){
        int &next = son[node][bool(x & (1ull << i))];
        if(!next) next = ++ind;
        node = next;
        ++size[node];
    }
}

ll st[maxn]; int top;
int szst[maxn];

void put_to_stack(int node, ll val, int b){
    if(!node) return;
    if(b == -1){
        st[++top] = val;
        szst[top] = size[node];
        return;
    }
    put_to_stack(son[node][1], val ^ (1ull << b), b - 1);
    put_to_stack(son[node][0], val, b - 1);
}

int query(int node, ll val, int b){
    if(!node) return 0;
    if(b == -1) return size[node];
    if(x & (1ull << b)) return query(son[node][1 ^ bool(val & (1ull << b))], val, b - 1);
    return query(son[node][bool(val & (1ull << b))], val, b - 1) + size[son[node][1 ^ bool(val & (1ull << b))]];
}

int dp[maxsize];

void dfs(int node, int b){
    if(!node) return;
    if(b == -1){
        dp[node] = size[node];
        return;
    }
    if(!(x & (1ull << b))){
        dfs(son[node][0], b - 1);
        dfs(son[node][1], b - 1);
        dp[node] = ((ll)(dp[son[node][0]] + 1) * (dp[son[node][1]] + 1) - 1) % mod;
        return;
    }
    dp[node] = size[node];
    put_to_stack(son[node][0], 0ll, b - 1);
    while(top){
        dp[node] = (dp[node] + (ll)szst[top] * query(son[node][1], st[top], b - 1)) % mod;
        --top;
    }
}

int main(){
    read(T), read(n), read(x);
    for(int i = 1; i <= n; ++i) read(a[i]);
    if(x == 0){
        int res = 1;
        for(int i = 1; i <= n; ++i) res = (res * 2) % mod;
        res = (res + mod - 1) % mod;
        std::printf("%d\n", res);
        return 0;
    }
    for(int i = 1; i <= n; ++i) ins(a[i]);

    dfs(1, 59);
    std::printf("%d\n", (dp[1] + mod) % mod);
}

来源

NOI.AC #2242 异或

评论