cf616F(SAM+树dp)

题意

给你n个串 每个串都有一个价值$c_i$ 然后让你构造一个串$S$ $S$的价值定义为:

$F(S)=\sum_{i=1}^{n}c_ip_{S,i}|S|$

$p_{S,i}$ 表示在第$T_i$中$S$出现的次数

$|S|$表示串的长度

问 如何选择$S$能得到最大价值 输出最大价值

题解

我们考虑把所有串建广义的后缀自动机

众所周知的 $dis[fa[x]]+1$ 到$dis[x]$是本周不同子串 然后求出其在所有串中的价值和 这个可以在后缀上统计得到

复杂度 O$(\sum_{i-1}^{n}|T_i|)$

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <stack>
#include <queue>
#include <cmath>
#include <set>
#include <map>
#define mp make_pair
#define pb push_back
#define pii pair<int,int>
#define link(x) for(edge *j=h[x];j;j=j->next)
#define inc(i,l,r) for(int i=l;i<=r;i++)
#define dec(i,r,l) for(int i=r;i>=l;i--)
const int MAXN=1e6+10;
const double eps=1e-8;
const int mod=1e9+9;
#define ll long long
const ll inf=1e18;
using namespace std;
struct edge{int t;edge*next;}e[MAXN<<1],*h[MAXN],*o=e;
void add(int x,int y){o->t=y;o->next=h[x];h[x]=o++;}
ll read(){
ll x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return x*f;
}

int n;
int ch[MAXN][26],fa[MAXN],dis[MAXN];
ll sum[MAXN];
int rt,cur,cnt;
void built(int x,int vul){
int last=cur;cur=++cnt;dis[cur]=dis[last]+1;sum[cur]=vul;int p=last;
for(;p&&!ch[p][x];p=fa[p])ch[p][x]=cur;
if(!p)fa[cur]=rt;
else{
int q=ch[p][x];
if(dis[q]==dis[p]+1)fa[cur]=q;
else{
int nt=++cnt;dis[nt]=dis[p]+1;
memcpy(ch[nt],ch[q],sizeof(ch[q]));
fa[nt]=fa[q];fa[q]=fa[cur]=nt;
for(;ch[p][x]==q;p=fa[p])ch[p][x]=nt;
}
}
}

string str[100005];
int a[MAXN];
ll maxx;
void dfs(int x){
link(x){
dfs(j->t);sum[x]+=sum[j->t];
}
if(x!=rt&&dis[x]>dis[fa[x]])maxx=max(maxx,max(sum[x]*(dis[fa[x]]+1),sum[x]*dis[x]));
}
int main(){
ios::sync_with_stdio(false);
cin>>n;
inc(i,1,n)cin>>str[i];
inc(i,1,n)cin>>a[i];
rt=cnt=cur=1;
inc(i,1,n){
int sz=str[i].size();cur=1;
for(int j=0;j<sz;j++)built(str[i][j]-'a',a[i]);
}
inc(i,1,cnt)add(fa[i],i);
maxx=0;dfs(rt);
cout<<maxx<<endl;
return 0;
}