0%

[20200215模拟赛] 学习

一半看到这种代价为 \(k\) 次方的形式,且需要转移复杂度的题,都是用第二类斯特林数解决的

第二类斯特林数我们记作 \(\begin{Bmatrix} n \\\\ m \end{Bmatrix}\) 表示将 \(n\) 个相互作区分的元素放入 \(m\) 个不做区分的集合里,每个集合元素非空的方案数,我们很容易就能有多项式得到一个恒等式

我们不妨计 \(\begin{Bmatrix} n \\\\ m \end{Bmatrix}\times m!\)egf\(f\),即集合做区分的母函数,考虑另一个 egf\(g_n\) 满足

\[ g_n=\sum_{i}i^n\frac{x^i}{i!} \] 很显然 \([x^m]g_n\) 的意义就是我们要求做区分的集合可以为空的方案数的生成函数,那么显然有

\[ e^xf=g_n \]

我们就可以得到

\[ m^n=\sum_{i=0}^{\min(n,m)}\binom{m}{i}\begin{Bmatrix} n \\\\ i \end{Bmatrix}\times i! \]

所以代入这道题要求的式子

\[ \sum_{\sum_{i=1}^mx_i=n}\prod_{i=1}^mx_i^{k_i} \]

化简得到

\[ \sum_{\sum_{i=1}^mx_i=n}\prod_{i=1}^m\sum_{j=0}^{\min(x_i,k_i)}\binom{x_i}{j}\begin{Bmatrix} k_i \\\\ j \end{Bmatrix}\times j! \]

注意到 \(\sum_{i=1}^mk_i\le1\times10^5\),所以我们考虑求这样的多项式

\[ \sum_{\sum_{i=1}^mx_i=n}\prod_{i=1}^m\sum_{j=0}^{\min(x_i,k_i)}\binom{x_i}{j}\begin{Bmatrix} k_i \\\\ j \end{Bmatrix}\times j!\ x^j \]

那么我们就可以将最内侧的和式移到外面

\[ \sum_j\sum_{\sum_{i=1}^mc_i=j}\sum_{\sum_{i=1}^mx_i=n}\prod_{i=1}^m\binom{x_i}{c_i}\prod_{i=1}^m\begin{Bmatrix} k_i \\\\ c_i \end{Bmatrix}\times c_i!\ x^j \]

我们注意到中间的

\[ \sum_{\sum_{i=1}^mx_i=n}\prod_{i=1}^m\binom{x_i}{c_i} \]

实际上是有组合意义的,与隔板法非常类似,我们想象有个长度为 \(n+m-1\) 的序列,我们从中选出 \(m-1+\sum_{i=1}^mc_i\) 个元素,对于每一种选择方案我们这样构造,先是 \(c_1\) 个元素,再算上 \(1\) 个隔板依次类推,所以答案的多项式又可以简化为

\[ \sum_j\sum_{\sum_{i=1}^mc_i=j}\binom{n+m-1}{m-1+j}\prod_{i=1}^m\begin{Bmatrix} k_i \\\\ c_i \end{Bmatrix}\times c_i!\ x^j \]

而实际上我们只需要多项式

\[ \prod_{i=1}^m\sum_{j=0}^{k_i}\begin{Bmatrix} k_i \\\\ j \end{Bmatrix}\times j!\ x^j \]

再算上相应的系数即可

那么用最开始斯特林数的生成函数的等式

那么为

\[ \prod_{i=1}^m\sum_{j=0}^{k_i}j!\times[x^j]{(}e^{-x}\times g_{k_i}{)}\ x^j \]

实际上后面的 \(e^x\)\(g_{k_i}\) 只需要求到 \(x^{k_i}\) 即可,那么整个式子可以用分治 FFT 优化,用一个类似于线段树的结构求出

\[ \prod_{i=l}^r\sum_{j=0}^{k_i}j!\times[x^j]{(}e^{-x}\times g_{k_i}{)}\ x^j \]

每一层总的项数和为 \(S=\sum_{i=1}^mk_i\) 所以复杂度为 \(\operatorname{O}(S\log^2S)\)

其他类似题目可参照 [BZOJ 5093] 图的价值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
using namespace std;
typedef long long int64;
inline int read(int f = 1, int x = 0, char ch = ' ')
{
while(!isdigit(ch = getchar())) if(ch == '-') f = -1;
while(isdigit(ch)) x = x*10+ch-'0', ch = getchar();
return f*x;
}
const int N = 8e5+5, M = 1e7+5+1e5, P = 998244353;
int64 w[2][N], fac[M], ifac[M]; int lim, rev[N];
int64 qpow(int64 a, int b)
{
int64 ret = 1;
for( ; b; b >>= 1, a = a*a%P) if(b&1) ret = ret*a%P;
return ret;
}
void prepare(int ti)
{
for(lim = 1; lim <= ti; lim <<= 1);
int64 g = qpow(3, (P-1)/lim);
w[0][0] = w[1][0] = w[0][lim] = w[1][lim] = 1;
for(int i = 1, j = lim>>1; i < lim; ++i)
{
w[0][i] = w[1][lim-i] = w[0][i-1]*g%P, rev[i] = j;
for(int k = lim>>1; (j ^= k) < k; k >>= 1);
}
}
struct Poly
{
vector<int> A;
int& operator [] (const int i) { return A[i]; }
void set(int ti) { A.resize(ti+1); }
int ti() { return A.size()-1; }
void NTT(int t)
{
if(!t) A.resize(lim);
for(int i = 0; i < lim; ++i) if(rev[i] > i) swap(A[rev[i]], A[i]);
for(int mid = 1; mid < lim; mid <<= 1)
for(int len = mid<<1, j = 0; j < lim; j += len)
for(int k = 0, p = 0, q = lim/len; k < mid; ++k, p += q)
{
int x = A[j+k], y = A[j+k+mid]*w[t][p]%P;
A[j+k] = (x+y)%P, A[j+k+mid] = (x-y+P)%P;
}
if(!t) return; int64 v = qpow(lim, P-2);
for(int i = 0; i < lim; ++i) A[i] = A[i]*v%P;
}
friend Poly operator * (Poly A, Poly B)
{
int n = A.ti(), m = B.ti(); prepare(n+m), A.NTT(0), B.NTT(0);
for(int i = 0; i < lim; ++i) A[i] = 1ll*A[i]*B[i]%P;
return A.NTT(1), A.set(n+m), A;
}
}F;
int n, m, s, a[N], ans;
int64 C(int n, int m) { return 0 <= m&&m <= n?fac[n]*ifac[m]%P*ifac[n-m]%P:0; }
Poly solve(int l, int r)
{
if(l == r)
{
Poly A, B; A.set(a[l]), B.set(a[l]);
for(int i = 0; i <= a[l]; ++i) A[i] = i&1?P-ifac[i]:ifac[i], B[i] = qpow(i, a[l])*ifac[i]%P;
A = A*B, A.set(a[l]); for(int i = 0; i <= a[l]; ++i) A[i] = A[i]*fac[i]%P; return A;
}
int mid = (l+r)>>1; return solve(l, mid)*solve(mid+1, r);
}
int main()
{
freopen("b.in", "r", stdin), freopen("b.out", "w", stdout);
m = read(), n = read(), fac[0] = 1;
for(int i = 1; i <= m; ++i) a[i] = read(), s += a[i];
for(int i = 1; i <= n+m-1; ++i) fac[i] = fac[i-1]*i%P;
ifac[n+m-1] = qpow(fac[n+m-1], P-2);
for(int i = n+m-1; i; --i) ifac[i-1] = ifac[i]*i%P;
F = solve(1, m); for(int i = 0; i <= s; ++i) ans = (ans+C(n+m-1, m-1+i)*F[i]%P)%P;
printf("%d\n", ans);
return 0;
}