题目链接

洛谷

dark bzoj

题意简述

给你两个字符串,从中各取一个子串使这两个子串相同,求方案数。

简要做法

以某两个位置开头的相同子串数=这两个位置开头的后缀的 $lcp$

如果在同一个字符串中,求出 height 数组后使用单调栈求出每个位置作为最小值的贡献即可(单调栈部分与 P2659 美丽的序列[ZJOI2007]棋盘制作 等题类似,在此就不赘述了;求两两 $lcp​$ 之和这部分与 [AHOI2013]差异 类似,故没有写那题的题解)。

由于有两个字符串不太方便,考虑将它们拼接起来并在中间加上一个不存在的字符(如#)。这样求出拼接后的字符串的答案,减去两个原串的答案,就是最终的答案。

参考代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

const int N=400010;

typedef long long ll;

int sa[N],sa2[N<<1],rk[N<<1],px[N],cnt[N],sta[N],top,f[N],height[N];

struct Suffix_Array
{
    char s[N];
    ll calc()
    {
        ll out=0;
        int n,i,k,w,p,m=200;
        n=strlen(s+1);
        memset(sa2,0,sizeof(sa2));
        memset(cnt,0,sizeof(cnt));
        for (i=1;i<=n;++i) ++cnt[rk[i]=s[i]];
        for (i=1;i<=m;++i) cnt[i]+=cnt[i-1];
        for (i=n;i>=1;--i) sa[cnt[rk[i]]--]=i;
        for (w=1;w<n;w<<=1,m=p)
        {
            memset(cnt,0,sizeof(cnt));
            for (p=0,i=n;i>n-w;--i) sa2[++p]=i;
            for (i=1;i<=n;++i) if (sa[i]>w) sa2[++p]=sa[i]-w;
            for (i=1;i<=n;++i) ++cnt[px[i]=rk[sa2[i]]];
            for (i=1;i<=m;++i) cnt[i]+=cnt[i-1];
            for (i=n;i>=1;--i) sa[cnt[px[i]]--]=sa2[i];
            swap(rk,sa2);
            for (p=0,i=1;i<=n;++i) rk[sa[i]]=sa2[sa[i]]==sa2[sa[i-1]]&&sa2[sa[i]+w]==sa2[sa[i-1]+w]?p:++p;
        }
        for (i=1,k=0;i<=n;++i)
        {
            if (k) --k;
            while (s[i+k]==s[sa[rk[i]-1]+k]) ++k;
            height[rk[i]]=k;
        }
        for (i=1;i<=n;++i)
        {
            while (top&&height[sta[top]]>=height[i]) --top;
            f[i]=i-sta[top];
            sta[++top]=i;
        }
        top=0;
        sta[++top]=n+1;
        height[n+1]=0;
        for (i=n;i>=1;--i)
        {
            while (top&&height[sta[top]]>height[i]) --top;
            out+=(ll)f[i]*(sta[top]-i)*height[i];
            sta[++top]=i;
        }
        return out;
    }
} a,b,ab;

int main()
{
    int n,m,i;

    scanf("%s%s",a.s+1,b.s+1);

    n=strlen(a.s+1);
    m=strlen(b.s+1);
    for (i=1;i<=n;++i) ab.s[i]=a.s[i];
    ab.s[n+1]='#';
    for (i=1;i<=m;++i) ab.s[n+1+i]=b.s[i];

    cout<<ab.calc()-a.calc()-b.calc();

    return 0;
}