三维立体混元劲¶
题意¶
你身上一共有 n_1 + n_2 + \ldots + n_k 处穴位,其中一共有 n_i 处是第 i 维的穴位。打通某两个位于第 i 维和第 j 维的穴位的方案数为 a_{i,j} (同维的穴位之间也可以打通,但一个穴位不能与自己打通),求使得所有穴位连通的方案数。
解析¶
这道题要求计连通图的数量,我们考虑使用与 城市规划 类似的方法解决。
重申城市规划的做法。设 f_i 为 i 个点的有标号无向连通图的个数,g_i 为 i 个点的有标号无向图的个数,则有
之所以不是 \binom{n}{i} 是因为,我们用 f_i 枚举和 1 号点连通的子图,1 号点不参与编号的打乱。于是我们可以通过多项式求逆解出 f_i。当然也有更清真的做法:G(z)=\exp F(z)。
回到本题。
我们发现仅仅记录和 1 连通的穴位的个数,信息是不足的。但发现 \prod(n+1) 的范围较小,我们可以把每个和 1 连通的穴位状态压缩为一个 (n_1, n_2, \ldots, n_k) 进制数。设有 a_i 个 i 维穴位与 1 连通 (不包括 1 本身),则可以将状态表示为
并且我们发现,两个状态对应的数的和,恰好就是两个状态的并对应的数,因此可以直接照搬普通图的做法;但卷积时我们必须去掉进位的影响。设占位函数 \chi(S)=\sum\limits_{i=1}^{k-1} \dfrac{S}{\prod\limits_{j=1}^i(n_j+1)},则 \chi(A+B)=\chi(A)+\chi(B) 当且仅当 A+B 不进位。
我们设二元多项式 F(x, z)=f_ix^iz^{\chi(i)},则我们需要的不进位卷积满足 (F\ast G)(x, z)=\sum[z^{\chi(i)}]f_jg_{i-j}x^iz^{\chi(i)+\chi(i-j)}。这里 [z^i] 表示取多项式的 z^i 项。
由于我们只关心 \chi(i+j)-\chi(i)+\chi(j) 是否为 0,不关心它的具体值,而 \chi(i+j)-\chi(i)+\chi(j)\in[0, k) (每部分位最多通过增加 1),因此我们只需要计算 \bmod (z^{k}-1) 意义下的卷积,再手动取 z 的对应次数即可。
一种比较合理的实现是,对 i\in[0,k) 求 [z^i]F(x, z) 和 [z^i]G(x, z) DFT (也就是把二元多项式按 z 的次数拆开),然后在 z 这一维度上暴力卷积,得到 [z^i](F*G)(x, z),再合并回去。由于 z 这一维的的次数为 k-1,F(x,z) 只能拆为 k 个多项式分别 DFT,产生 O(kn\log n) 的复杂度;在 z 这一维卷积的复杂度为 O(k^2n),因此总复杂度为 O(kn\log n+k^2\log n)。
我们已经得到了一种优秀的 k 维多项式卷积。事实上这种卷积拓展性很好,参见 https://rushcheyo.blog.uoj.ac/blog/6547。
实现¶
#include <bits/stdc++.h>
typedef long long int ll;
typedef std::vector<int> poly;
const int maxn = 1 << 21, maxk = 19, mod = 998244353;
int qpow(int a, int b){
int res = 1;
while(b){
if(b & 1)
res = (ll)res * a % mod;
a = (ll)a * a % mod, b >>= 1;
}
return res;
}
int rev[maxn];
void make_rev(int n){
for(int i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? n >> 1 : 0);
}
int ntt_len;
void dft(int *f, int N, int b){
static int w[maxn];
ntt_len += N;
for(int i = 0; i < N; ++i)
if(i < rev[i])
std::swap(f[i], f[rev[i]]);
for(int i = 2; i <= N; i <<= 1){
w[0] = 1, w[1] = qpow(3, (mod - 1) / i);
if(b == -1) w[1] = qpow(w[1], mod - 2);
for(int j = 2; j < i / 2; ++j) w[j] = (ll)w[j - 1] * w[1] % mod;
for(int j = 0; j < N; j += i){
int *g = f + j, *h = f + j + i / 2;
for(int k = 0; k < i / 2; ++k){
int p = g[k], q = (ll)h[k] * w[k] % mod;
g[k] = (p + q) % mod, h[k] = (p - q) % mod;
}
}
}
if(b == -1){
for(int i = 0, inv = qpow(N, mod - 2); i < N; ++i)
f[i] = (ll)f[i] * inv % mod;
}
}
int N = 1, n[maxk], k, a[maxk][maxk];
int preproduct[maxk], chi[maxn];
int fact[maxn], ifact[maxn];
void init_fact(int n){
fact[0] = 1;
for(int i = 1; i <= n; ++i) fact[i] = (ll)fact[i - 1] * i % mod;
ifact[n] = qpow(fact[n], mod - 2);
for(int i = n - 1; i >= 0; --i) ifact[i] = (ll)ifact[i + 1] * (i + 1) % mod;
}
int SETEDSZ;
poly operator*(poly f, poly g){
int sz = 1;
if(SETEDSZ) sz = SETEDSZ;
else while(sz < (int)f.size() + (int)g.size() - 1) sz <<= 1;
make_rev(sz);
static int zf[maxk][maxn], zg[maxk][maxn], res[maxk][maxn];
for(int i = 0; i < k; ++i){
std::fill(zf[i], zf[i] + sz, 0);
std::fill(zg[i], zg[i] + sz, 0);
std::fill(res[i], res[i] + sz, 0);
}
for(int i = 0; i < (int)f.size(); ++i)
zf[chi[i]][i] = f[i];
for(int i = 0; i < (int)g.size(); ++i)
zg[chi[i]][i] = g[i];
for(int i = 0; i < k; ++i)
dft(zf[i], sz, 1), dft(zg[i], sz, 1);
for(int i = 0; i < k; ++i)
for(int j = 0; j < k; ++j){
int p = (i + j) % k;
for(int q = 0; q < sz; ++q)
res[p][q] = (res[p][q] + (ll)zf[i][q] * zg[j][q]) % mod;
}
for(int i = 0; i < k; ++i)
dft(res[i], sz, -1);
sz = std::min<int>(f.size() + g.size() - 1, N);
poly ans(sz);
for(int i = 0; i < sz; ++i)
ans[i] = res[chi[i]][i];
return ans;
}
poly inv(const poly &h){
poly f(1), g; f[0] = qpow(h[0], mod - 2);
for(int w = 2; w / 2 < N; w <<= 1){
g.resize(w);
for(int i = std::min(w, (int)h.size()) - 1; i >= 0; --i)
g[i] = h[i];
for(int i = w - 1; i >= (int)h.size(); --i)
g[i] = 0;
static poly t;
t = f * g; t.resize(w); t = t * f;
f.resize(w);
for(int i = 0; i < w; ++i) f[i] = (2ll * f[i] - t[i]) % mod;
}
f.resize(h.size());
return f;
}
poly deriv(const poly &f){
poly g(N);
for(int i = 0; i < N; ++i)
g[i] = (ll)f[i] * i % mod;
return g;
}
poly integ(const poly &f){
poly g(N); g[0] = 1;
for(int i = 0; i < N; ++i)
g[i] = (ll)f[i] * ifact[i] % mod * fact[i - 1] % mod;
return g;
}
poly G;
int main(){
std::scanf("%d", &k);
for(int i = 1; i <= k; ++i)
std::scanf("%d", n + i), N *= (n[i] + 1);
for(int i = 1; i <= k; ++i)
for(int j = 1; j <= k; ++j)
std::scanf("%d", &a[i][j]);
init_fact(N);
preproduct[0] = 1;
for(int i = 1; i <= k; ++i)
preproduct[i] = preproduct[i - 1] * (n[i] + 1);
G.resize(N);
for(int s = 0; s < N; ++s){
static int cnt[maxn];
for(int i = 1; i <= k; ++i) cnt[i] = s % preproduct[i] / preproduct[i - 1];
for(int i = 1; i < k; ++i)
chi[s] += s / preproduct[i];
chi[s] %= k, G[s] = 1;
if(s){
for(int i = 1; i <= k; ++i)
if(cnt[i]){
G[s] = (ll)G[s - preproduct[i - 1]] * qpow(a[i][i] + 1, cnt[i] - 1) % mod;
G[s] = (ll)G[s] * ifact[cnt[i]] % mod * fact[cnt[i] - 1] % mod;
for(int j = i + 1; j <= k; ++j)
G[s] = (ll)G[s] * qpow(a[i][j] + 1, cnt[j]) % mod;
break;
}
}
}
G = inv(G) * deriv(G), G.resize(N), G = integ(G);
int res = G[N - 1];
for(int i = 1; i <= k; ++i)
res = (ll)res * fact[n[i]] % mod;
std::printf("%d\n", ntt_len);
std::printf("%d\n", (res + mod) % mod);
return 0;
}