MOCK NOIP 20220920 B – 骰子 另解题解
形式化题意
给定 $A+B$ 个本质相同的随机变量,在 $\{1,2,\cdots,6\}$ 内等概率取值。若两个随机变量 $a \in A, b \in B, a=b$,则二者相消为 $0$。求在所有可能的消除完成后,存在 $a \leq t, a \in A, t \in \{1,2,\cdots,6\}$ 的概率。对 $10^9+7$ 取模。
朴素DP
令 $c_i$ 表示随机结果中数值为 $i$ 的 $A$ 内变量数;$d_i$ 同理。原题意可以转化为求出 $1-P(\text{不存在小于等于}t\text{的}a)$。由概率的定义,我们可以用能造成后者局面的方案总数除以 $6^{A+B}$ 既是结果。换句话说,原命题的否命题应当满足 $\forall i \in \{1,2,\cdots,t\}, c_i \leq d_i$。
又因为这些随机变量本质相同,则由“可重集的排列数”得,若有确定的 $c_{1},c_2,\cdots,c_6, \sum_{i=1}^{6}c_i=A$,则掷出对应结果的方案数为
$$\dbinom{A}{c_1}\dbinom{A-c_1}{c_2}\cdots\dbinom{A-\sum_{i=1}^{5}c_i}{c_6}=\dfrac{\color{blue}A!}{\prod_{i=1}^{6}c_i!}$$
由此,我们逐个考虑 $c_i, d_i, c_i\leq d_i$ 并统计所有方案中 $1/\prod_{i=1}^{t}c_i, 1/\prod_{i=1}^{t}d_i$ 之和。设 $f(i,x,y)$ 表示考虑了数字 $1,\cdots,i, \sum_{j=1}^{i}c_j=x, \sum_{j=1}^{i}d_j=y$ 之方案的权值之和。则存在如下转移:
$$f(i+1,x+c_{i+1}, y+d_{i+1})\leftarrow f(i,x,y)\frac{1}{c_{i+1}!}\frac{1}{d_{i+1}!}, c_{i+1}\leq d_{i+1}$$
由于我们并不关心最后 $(6-t)$ 个数字对应的 $A, B$ 变量数,故我们可以将它们看作一类数字并任意取 $\{t+1,\cdots,6\}$ 内值,所以满足原命题的否命题的实际方案数为
$$\sum_{x=0}^{A}\sum_{y=x}^{B}f(t,x,y)\dfrac{\color{blue}A!B!}{(A-x)!(B-y)!}(6-t)^{A-x+B-y}$$
直接计算的时间复杂度为 $O(tA^2B^2)$,不能接受。
利用卷积优化
我们发现对于确定的 $c_{i+1},d_{i+1}$,其转移系数完全相同且基本独立(除 $c_{i+1}\leq d_{i+1}$ 这一条件外)。因此似乎可用“卷积”优化。
将 DP 数组一维化。更具体地,令 $f'(i,x(B+1)+y)=f(i,x,y)$,相当于在 $B+1$ 进制下填充 $x, y$ 作为次低位和最低位。于是下标的加减具有了结合律:
$$\begin{aligned}
&f(i+1,x+c_{i+1}, y+d_{i+1})\leftarrow f(i,x,y)\frac{1}{c_{i+1}!}\frac{1}{d_{i+1}!}, c_{i+1}\leq d_{i+1}\\
\Longrightarrow &f(i+1,(x+c_{i+1})(B+1)+(y+d_{i+1}))\leftarrow f(i,x(B+1)+y)\frac{1}{c_{i+1}!d_{i+1}!}
\end{aligned}$$
但存在不合法转移的问题:当 $x’=x+c_{i+1}>A$ 时,必然有 $x'(B+1)+y’\leq (A+1)(B+1)$,在数列的更高位,不去收集即可;但对于 $y’=y+d_{i+1}>B$,它产生进位而转移到原属于 $x’+1$ 的位置上,而我们束手无策。因此只好采用 $2B+1$ 进制以解决该问题。由此,每一类合法转移的结果均能映射为唯一的 $2B+1$ 进制数,同时每一轮转移完后应当及时清除非法项(第 $x(2B+1)+y, y\in (B,2B] \lor x>A$ 项)的系数。
时间复杂度 $\operatorname{O}(tAB\log (AB))$,常数极大(需采用任意模数下的数论变换)。
#include <bits/stdc++.h>
using namespace std;
/* 快读已省略。 */
#define inl inline
#define newl putchar('\n')
#define scd second
#define fst first
using ll = long long;
using ull = unsigned long long;
using pint = pair <int, int>;
constexpr int N = 256, G = 1<<18, P = 1e9 + 7, gen = 3,
P1 = 998244353, P2 = 1004535809, P3 = 985661441,
invP1P2 = 669690699, invP1P2P3 = 401569863;
int t, a, b, l, f[G], _f[3][G], g[3][G]; ll fact[N], finv[N], tot;
inl ll fpow (ll a, ll b, const int mod) {
ll res = 1; a %= mod;
for (; b; b >>= 1) {
if (b & 1) (res *= a) %= mod;
(a *= a) %= mod;
}
return res;
}
inl void calc_inv (int n) {
fact[0] = 1;
for (ll x = 1; x <= n; ++x)
fact[x] = fact[x - 1] * x % P;
finv[n] = fpow (fact[n], P-2, P);
for (int x = n - 1; ~x; --x)
finv[x] = finv[x + 1] * (x + 1ll) % P;
}
inl ll C (int x, int y) { return x < y || x < 0 ? 0 :
fact[x] * finv[y] % P * finv[x - y] % P; }
inl void henkan (int f[], int l) {
static int tr[G], lst = tr[0] = 0;
if (lst != l) { lst = l;
for (int i = 1; i < 1<<l; ++i)
tr[i] = tr[i>>1]>>1|((i&1)*(1<<l-1)); }
for (int i = 1; i < 1<<l; ++i)
if (tr[i] < i) swap (f[tr[i]], f[i]);
}
#define clog2(x) ceil (log2 (x))
template <const int P>
inl void NTT (int f[], int l, bool rev) {
henkan (f, l);
for (int len = 2; len <= 1<<l; len <<= 1) {
const int w_n = fpow (gen, (P-1)/len, P);
ll w = 1, g, h;
for (int st = 0; st < 1<<l; st += len, w = 1)
for (int i = st; i < st + len/2; ++i, (w *= w_n) %= P)
g = f[i], h = f[i + len/2] * w % P,
f[i] = (g + h) % P,
f[i + len/2] = (g + P - h) % P;
}
if (!rev) return;
const ll invl = fpow (1<<l, P-2, P);
for (int i = 0; i < 1<<l; ++i)
f[i] = f[i] * invl % P;
reverse (f + 1, f + (1<<l));
}
inl int conv (int x, int y) { return x * (b<<1|1) + y; }
int main () {
/* */
read (t, a, b); calc_inv (max (a, b));
f[0] = 1; l = clog2 ((a + 1) * (b<<1|1) * 2 - 1);
for (int x = 0; x <= a; ++x)
for (int y = x; y <= b; ++y)
g[0][conv (x, y)] = finv[x] * finv[y] % P;
memcpy (g[1], g[0], (1<<l)<<2);
memcpy (g[2], g[0], (1<<l)<<2);
for (int i = 0; i < 1<<l; ++i)
g[0][i] %= P1, g[2][i] %= P3;
NTT <P1> (g[0], l, 0);
NTT <P2> (g[1], l, 0);
NTT <P3> (g[2], l, 0);
for (int kai = 0; kai < t; ++kai) {
memcpy (_f[0], f, (1<<l)<<2);
NTT <P1> (_f[0], l, 0);
for (int i = 0; i < 1<<l; ++i)
_f[0][i] = (ll) _f[0][i] * g[0][i] % P1;
NTT <P1> (_f[0], l, 1);
memcpy (_f[1], f, (1<<l)<<2);
NTT <P2> (_f[1], l, 0);
for (int i = 0; i < 1<<l; ++i)
_f[1][i] = (ll) _f[1][i] * g[1][i] % P2;
NTT <P2> (_f[1], l, 1);
memcpy (_f[2], f, (1<<l)<<2);
NTT <P3> (_f[2], l, 0);
for (int i = 0; i < 1<<l; ++i)
_f[2][i] = (ll) _f[2][i] * g[2][i] % P3;
NTT <P3> (_f[2], l, 1);
memset (f, 0, (1<<l)<<2);
static ll k1, k4, x_1, x_2, x_3, __x, id;
for (int x = 0; x <= a; ++x)
for (int y = x; y <= b; ++y)
x_1 = _f[0][id = conv (x, y)],
x_2 = _f[1][id], x_3 = _f[2][id],
k1 = (x_2 - x_1 + P2) * invP1P2 % P2,
__x = k1 * P1 + x_1,
k4 = (x_3 - __x % P3 + P3) * invP1P2P3 % P3,
f[id] = (k4 * P1 % P * P2 + __x) % P;
}
for (int x = 0; x <= a; ++x)
for (int y = x; y <= b; ++y)
tot += f[conv (x, y)] * C (a, x) % P * C (b, y)
% P * fact[x] % P * fact[y] % P
* fpow (6 - t, a - x + b - y, P) % P;
print ((P + 1 - tot % P * fpow (fpow (6, a + b, P), P - 2, P) % P) % P);
return 0;
}
近期评论