BZOJ3622 已经没有什么好害怕的了(二项式反演,组合数学)

题目链接

洛谷

darkbzoj

题意简述

给你两个长为 $n$ 、无重复元素的数列 $a,b$,求将 $b$ 重排后 $\left(\sum\limits_{i=1}^n[a_i>b_i]\right)-\left(\sum\limits_{i=1}^n[a_i<b_i]\right)=k$($k$ 给定)的方案数。

$n\le2000$。

简要做法

首先,$a$ 大于 $b$ 比 $b$ 大于 $a$ 多 $k$ 对可以转化为 $a$ 大于 $b$ 有 $\frac{n+k}2$ 对。

考虑 $dp​$,设 $f(i,j)​$ 为 $a_{1..i}​$ 中选 $j​$ 个数,给这 $j​$ 个数每个数匹配一个小于它的 $b​$ 的方案数。为方便转移,一开始要先对 $a,b​$ 分别从小到大排序,这样的话若 $a_i>b_j​$,必然有 $a_{i+1}>b_j​$。

设比 $a_i$ 大的 $b$ 有 $cnt_i$ 个,转移方程就是 $f(i,j)=f(i-1,j)+(cnt_i-j+1)\times f(i-1,j-1)$。$cnt_i$ 可以二分查找/双指针计算。

考虑这样一个式子:$f(n,i)\times(n-i)!$,它的意义是,先从 $a$ 中选 $i$ 个数,给它们分别匹配一个小于它们的 $b$,再把剩下的 $n-i$ 个数随意匹配。这个式子并不是某种“方案数”,因为它可能会将相同的匹配方案算重。事实上,对于每种恰有 $j$ 对 $a>b$ 的匹配方案,它在 $f(n,i)\times(n-i)!$ 中被计算了 $\binom{j}{i}$ 次。令 $ans_i$ 表示恰好有 $i$ 对 $a$ 大于 $b$ 的方案数,就有 $ans_i=f(n,i)\times(n-i)!-\sum\limits_{j=i+1}^n\binom{j}{i}ans_j$。这样的话,就可以递推地计算答案。

这题还有另一种做法,叫二项式反演。我自己没有太理解清楚这种方法,所以不详细阐述。这种做法可以 $O(n)​$ 地计算某个 $ans_i​$,但不影响总复杂度($dp​$ 还是 $O(n^2)​$ 的):

$$ans_i=\sum\limits_{j=i}^n(-1)^{j-i}\binom j i f(n,j)\times(n-j)​$$

参考代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <iostream>
#include <cstdio>
#include <algorithm>

using namespace std;

typedef long long ll;

const int N=2010;
const int mod=1e9+9;

int qpow(int x,int y);

int n,k,a[N],b[N],f[N][N],ans[N],c[N][N],fac[N];

int main()
{
int i,j,cnt=0;

cin>>n>>k;

if ((n&1)!=(k&1))
{
cout<<0;
return 0;
}

k=(n+k)/2;

for (i=1;i<=n;++i) scanf("%d",a+i);
for (i=1;i<=n;++i) scanf("%d",b+i);

sort(a+1,a+n+1);
sort(b+1,b+n+1);

fac[0]=1;
for (i=1;i<=n;++i) fac[i]=(ll)fac[i-1]*i%mod;

c[0][0]=1;
for (i=1;i<=n;++i)
{
c[i][0]=1;
for (j=1;j<=i;++j) c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
}

f[0][0]=1;
for (i=1;i<=n;++i)
{
f[i][0]=1;
while (cnt<n&&b[cnt+1]<a[i]) ++cnt;
for (j=1;j<=cnt;++j)
{
f[i][j]=(f[i-1][j]+(ll)f[i-1][j-1]*(cnt-j+1))%mod;
}
}

for (i=n;i>=k;--i)
{
ans[i]=(ll)f[n][i]*fac[n-i]%mod;
for (j=n;j>i;--j) ans[i]=(ans[i]+mod-(ll)ans[j]*c[j][i]%mod)%mod;
}

cout<<ans[k];

return 0;
}

int qpow(int x,int y)
{
int out=1;
while (y)
{
if (y&1) out=(ll)out*x%mod;
x=(ll)x*x%mod;
y>>=1;
}
return out;
}