题目链接

洛谷

题意简述

出现次数不为 $1$ 的子串的 出现次数 $\times$ 长度 的最大值。

SAM 做法

简要做法

一个状态的出现次数可以这么计算:

插入一个字符时,$np$ 的 $cnt$ 设为 $1$,一个状态的出现次数就是它在 $parent$ 树上的子树的 $cnt$ 之和。

证明..简要说一下:因为 $np$ 的 $right$ 集合为 $\{L\}$ 。

所以,插入整个字符串后 dfs 一遍 $parent$ 树算一算就好了。

参考代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

typedef long long ll;

const int N=1000010;
const int DELTA=26;

struct Node
{
    int len,ch[DELTA],par,cnt;
    Node(){ memset(ch,0,sizeof(ch)); }
} sam[N<<1];

void insert(int x);
void add(int u,int v);
void dfs(int u);

char s[N];
int p,tot,head[N<<1],nxt[N<<1],to[N<<1],cnt;
ll ans;

int main()
{
    int i;

    scanf("%s",s);

    for (i=0;s[i];++i) insert(s[i]-'a');

    for (i=1;i<=tot;++i) add(sam[i].par,i);

    dfs(0);

    cout<<ans;

    return 0;
}

void insert(int x)
{
    int np=++tot;
    sam[np].len=sam[p].len+1;
    sam[np].cnt=1;
    while (p&&!sam[p].ch[x])
    {
        sam[p].ch[x]=np;
        p=sam[p].par;
    }
    if (sam[p].ch[x])
    {
        int q=sam[p].ch[x];
        if (sam[q].len==sam[p].len+1) sam[np].par=q;
        else
        {
            int nq=++tot;
            sam[nq].len=sam[p].len+1;
            memcpy(sam[nq].ch,sam[q].ch,sizeof(sam[q].ch));
            sam[nq].par=sam[q].par;
            sam[q].par=sam[np].par=nq;
            while (sam[p].ch[x]==q)
            {
                sam[p].ch[x]=nq;
                p=sam[p].par;
            }
        }
    }
    else
    {
        sam[p].ch[x]=np;
        sam[np].par=0;
    }
    p=np;
}

void add(int u,int v)
{
    nxt[++cnt]=head[u];
    head[u]=cnt;
    to[cnt]=v;
}

void dfs(int u)
{
    int i,v;
    for (i=head[u];i;i=nxt[i])
    {
        v=to[i];
        dfs(v);
        sam[u].cnt+=sam[v].cnt;
    }
    if (sam[u].cnt>1) ans=max(ans,(ll)sam[u].cnt*sam[u].len);
}

后缀数组做法

简要做法

一个长度为 $h$ 的子串出现 $k$ 次就是有 $k-1$ 个连续的 $height\ge h$。单调栈维护即可。

然而..卡常卡不过去QAQ

参考代码(最高80分)

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

const int N=1000010;

char s[N];
int n,sa[N],rk[N<<1],id[N<<1],px[N],cnt[N],ht[N],l[N],sta[N],top;
long long ans;

bool cmp(int x,int y,int w){ return id[x]==id[y]&&id[x+w]==id[y+w]; }

int main()
{
    int i,w,p,m=300,k;

    scanf("%s",s+1);
    n=strlen(s+1);
    for (i=1;i<=n;++i) ++cnt[rk[i]=s[i]];
    for (i=1;i<=m;++i) cnt[i]+=cnt[i-1];
    for (i=n;i>=1;--i) sa[cnt[rk[i]]--]=i;

    for (w=1;w<n;w<<=1,m=p)
    {
        memset(cnt,0,sizeof(cnt));
        for (p=0,i=n;i>n-w;--i) id[++p]=i;
        for (i=1;i<=n;++i) if (sa[i]>w) id[++p]=sa[i]-w;
        for (i=1;i<=n;++i) ++cnt[px[i]=rk[id[i]]];
        for (i=1;i<=m;++i) cnt[i]+=cnt[i-1];
        for (i=n;i>=1;--i) sa[cnt[px[i]]--]=id[i];
        swap(id,rk);
        for (p=0,i=1;i<=n;++i) rk[sa[i]]=cmp(sa[i],sa[i-1],w)?p:++p;
    }

    for (i=1,k=0;i<=n;++i)
    {
        if (k) --k;
        while (s[i+k]==s[sa[rk[i]-1]+k]) ++k;
        ht[rk[i]]=k;
    }

    for (i=1;i<=n;++i)
    {
        while (top&&ht[sta[top]]>=ht[i]) --top;
        l[i]=sta[top];
        sta[++top]=i;
    }

    sta[top=1]=n+1;

    for (i=n;i>=1;--i)
    {
        while (top&&ht[sta[top]]>ht[i]) --top;
        if (sta[top]-l[i]>1) ans=max(ans,(long long)ht[i]*(sta[top]-l[i]));
        sta[++top]=i;
    }

    cout<<ans;

    return 0;
}