LOJ6519 魔力环(Burnside引理,容斥原理)

题目链接

LOJ

洛谷

题意简述

你需要给 $n$ 颗珠子的项链染 $m$ 颗黑色,$n-m$ 颗白色,不能有连续的一串黑色珠子长度超过 $k$,求旋转同构下本质不同的染色方案数。

$1\le m,k\le n\le10^5$

简要做法

首先套用 Burnside 引理,以及位移为 $r$ 的旋转周期为 $\gcd(r, n)$ 的结论,得到答案的式子:
$$
\begin{aligned}
answer&=\frac 1 n\sum\limits_{i=1}^nf\left(\frac n{\gcd(i,n)}\right)\\
&=\frac 1 n\sum\limits_{d|n}\varphi(d)f(d)
\end{aligned}
$$
其中 $f(x)$ 表示在一个长为 $\frac n x$ 的项链上,染 $\frac{m}{x}$ 个黑珠子,$\frac{n-m}x$ 个白珠子,不能有连续的一串黑色珠子长度超过 $k$ 的方案数(在不旋转的意义下计数)。

可以看出只有 $d|m$ 时 $f(d)$ 可能不为零,如果用 $f(x, y)$ 表示在一个长为 $x+y$ 的项链上,染 $x$ 个黑珠子,$y$ 个白珠子,不能有连续的一串黑色珠子长度超过 $k$ 的方案数(在不旋转的意义下计数),答案的式子可以写成:

$$
answer=\frac 1 n\sum\limits_{d|\gcd(n, m)}\varphi(d)f\left(\frac m d, \frac{n-m}d\right)
$$

现在的问题转化成了快速求 $f(x, y)$。

首先,特判掉两种情况:

  1. $k=n$
  2. $y\ne 0$ 且 $x\le k$

这两种情况下 $f(x, y)=\binom{x+y}x$

由于是在环上不好处理,枚举两侧的黑珠子个数,就可以转化为序列上的问题。

而序列上的问题,就相当于求方程 $x_1+x_2+\cdots+x_{y+1}=x (0\le x_i\le k)$ 的解的个数。

考虑容斥,枚举至少有 $i$ 个变量的值大于 $k$(实际上是枚举大小为 $i$ 的子集都大于 $k$),解的个数为 $\binom{x+y-i(k+1)}y$。

这样的话,枚举两侧黑珠子个数最多枚举到 $k$,容斥复杂度为 $O(\frac{x+y}k)$,计算 $f(x,y)$ 的复杂度为 $O(x+y)$,整道题的复杂度就是 $O(\text{预处理组合数}+\sigma(n))$,其中 $\sigma(n)$ 表示 $n$ 的所有约数之和,在数据范围内最大为 $403200$。

参考代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <iostream>
#include <cstdio>

using namespace std;

typedef long long ll;

const int N = 100005;
const int mod = 998244353;

int n, m, k, p[N], ptot, phi[N], fact[N], invf[N];
bool np[N];

int qpow(int x, int y)
{
int out = 1;
while (y)
{
if (y & 1) out = (ll) out * x % mod;
x = (ll) x * x % mod;
y >>= 1;
}
return out;
}

int c(int x, int y)
{
if (y > x || y < 0) return 0;
return (ll) fact[x] * invf[y] % mod * invf[x - y] % mod;
}

int calc(int x, int y)
{
int out = 0;
for (int i = 0; i * (k + 1) <= x + y; ++i)
{
out = (out + (i & 1 ? -1ll : 1ll) * c(x + y - (k + 1) * i, y) * c(y + 1, i) % mod + mod) % mod;
}
return out;
}

int f(int x, int y)
{
if (k == n || y != 0 && x <= k) return c(x + y, x);
int out = 0;
for (int i = 0; i <= x && i <= k; ++i)
{
out = (out + (ll) (i + 1) * calc(x - i, y - 2)) % mod;
}
return out;
}

int gcd(int x, int y)
{
return y ? gcd(y, x % y) : x;
}

int main()
{
cin >> n >> m >> k;

fact[0] = invf[0] = 1;
for (int i = 1; i <= n; ++i) fact[i] = (ll) fact[i - 1] * i % mod;
invf[n] = qpow(fact[n], mod - 2);
for (int i = n - 1; i >= 1; --i) invf[i] = (ll) invf[i + 1] * (i + 1) % mod;

phi[1] = 1;
for (int i = 2; i <= n; ++i)
{
if (!np[i])
{
p[++ptot] = i;
phi[i] = i - 1;
}
for (int j = 1; j <= ptot && i * p[j] <= n; ++j)
{
int x = i * p[j];
np[x] = true;
if (i % p[j]) phi[x] = phi[i] * (p[j] - 1);
else
{
phi[x] = phi[i] * p[j];
break;
}
}
}

int ans = 0;
int g = gcd(m, n);

for (int i = 1; i * i <= g; ++i)
{
if (g % i == 0)
{
if (i * i == g) ans = (ans + (ll) f(m / i, (n - m) / i) * phi[i]) % mod;
else ans = (ans + (ll) f(m / i, (n - m) / i) * phi[i] + (ll) f(m / (g / i), (n - m) / (g / i)) * phi[g / i]) % mod;
}
}

cout << (ll) ans * qpow(n, mod - 2) % mod;

return 0;
}