串串划分¶
题意¶
给定一个字符串 s,将其划分为 k (k > 1) 个子串 s=s_1s_2\ldots s_k,满足:
- \forall i\in[1,k],s_i 不是循环串
- \forall i\in [1,k-1],s_i \neq s_{i+1}
求方案数。
解析¶
如果我们每次强制将一个字符串拆成最小循环,那么最终得到的划分一定满足条件 1。而条件二我们可以通过容斥满足。
设 dp_i 表示 s 的长度为 i 的前缀的满足所有两种条件的划分方案,f(t) 表示 t 含有的最小循环的个数,则有转移方程
dp_i=\sum_{j=0}^{i-1}(-1)^{f(s[j+1,i])-1}dp_j
可以这样理解这个容斥:g_k=\sum\limits_{f(s[j+1,i])=k}dp_j 表示最后一个子串重复 k 次或 k+1 次的方案数 (由于 j 之前不存在重复串,只能和 j 尾部的子串相等,导致多重复 1 次)。设最后一个子串恰好重复 i 次的方案数为 h_i,根据推导有 g_k=h_k+h_{k+1},反演一下得到 h_1=\sum\limits_i(-1)^{i-1}g_i。
现在我们要求一个字符串包含的循环次数。参考 WC2019 课件,primitive square 的数量不超过 O(|s|\log |s|) 级别,记录前缀和转移即可。
实现¶
#include <bits/stdc++.h>
typedef unsigned long long int ull;
const int maxn = 2e5 + 19, mod = 998244353;
struct hash_engine{
ull key[maxn], basep[maxn], base;
void init(const char *s, int n){
basep[0] = 1ull, base = 233ull;
for(int i = 1; i <= n; ++i){
key[i] = key[i - 1] * base + (ull)s[i];
basep[i] = basep[i - 1] * base;
}
}
ull operator()(const int &l, const int &r){
if(l > r)
return 0ull;
return key[r] - key[l - 1] * basep[r - l + 1];
}
}mhash;
std::vector<int> x[maxn];
std::vector<std::vector<int> > g;
std::vector<std::pair<std::pair<int, int>, int> > runs;
char s[maxn];
int n, lyn[maxn], dp[maxn], tot;
int st[maxn], top;
int find_prev(int a, int b){
int l = 0, r = std::min(a, b);
while(l < r){
int mid = (l + r + 1) >> 1;
if(mhash(a - mid + 1, a) == mhash(b - mid + 1, b))
l = mid;
else
r = mid - 1;
}
return l;
}
int find_next(int a, int b){
int l = 0, r = std::min(n - a + 1, n - b + 1);
while(l < r){
int mid = (l + r + 1) >> 1;
if(mhash(a, a + mid - 1) == mhash(b, b + mid - 1))
l = mid;
else
r = mid - 1;
}
return l;
}
void lyndon(int opt){
if(opt) s[n + 1] = 'a' - 1;
else s[n + 1] = 'z' + 1;
st[0] = n + 1, top = 0;
for(int i = n; i >= 1; --i){
while(top){
int x = find_next(i, st[top]);
if(opt ? s[i + x] < s[st[top] + x] : s[i + x] > s[st[top] + x])
--top;
else
break;
}
lyn[i] = st[top] - 1, st[++top] = i;
}
}
int main(){
std::scanf("%s", s + 1), n = std::strlen(s + 1);
mhash.init(s, n);
for(int opt = 0; opt < 2; ++opt){
lyndon(opt);
for(int i = 1; i <= n; ++i){
int l = i, r = lyn[i], p = r - l + 1;
std::pair<int, int> mr(l - find_prev(l - 1, r), r + find_next(l, r + 1));
if(p * 2 <= mr.second - mr.first + 1)
runs.push_back(std::make_pair(mr, p));
}
}
std::sort(runs.begin(), runs.end());
runs.resize(std::unique(runs.begin(), runs.end()) - runs.begin());
g.resize(runs.size());
for(int i = 0; i < (int)runs.size(); ++i){
int l = runs[i].first.first, r = runs[i].first.second, p = runs[i].second;
g[i].resize(r - l - p * 2 + 2);
for(int j = l + p * 2 - 1; j <= r; ++j)
x[j].push_back(i);
}
dp[0] = 1, tot = 1;
for(int i = 1; i <= n; ++i){
dp[i] = tot;
for(int j = 0; j < (int)x[i].size(); ++j){
int k = x[i][j];
int l = runs[k].first.first, p = runs[k].second;
if(i - l - 4 * p + 1 >= 0)
g[k][i - l - 2 * p + 1] = (g[k][i - l - 4 * p + 1] + dp[i - 2 * p]) % mod;
else
g[k][i - l - 2 * p + 1] = dp[i - 2 * p];
dp[i] = (dp[i] - 2ll * g[k][i - l - 2 * p + 1]) % mod;
}
tot = (tot + dp[i]) % mod;
}
std::printf("%d\n", (dp[n] + mod) % mod);
return 0;
}