bzoj4598(点分治+hash)

题解

点对问题考虑点分

我们考虑子树合并 有两种情况 分别是当前子树链作为开头或者结尾 对于长度大于m和小于等于m再分情况讨论下 然后分别维护已经合并完的子树在长度为x时 分别作为开头和结尾的情况下的方案数 统计贡献的话 就直接用hash判当前的子串是否合法即可

hash我们考虑自然溢出

时间复杂度 O(nlogn)

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <stack>
#include <queue>
#include <cmath>
#include <set>
#include <map>
#define mp make_pair
#define pb push_back
#define pii pair<int,int>
#define link(x) for(edge *j=H1[x];j;j=j->next)
#define inc(i,l,r) for(int i=l;i<=r;i++)
#define dec(i,r,l) for(int i=r;i>=l;i--)
const int MAXN=1e6+10;
const double eps=1e-8;
#define ll long long
const int inf=1e9;
using namespace std;
struct edge{int t;edge*next;}e[MAXN<<1],*H1[MAXN],*o=e;
void add(int x,int y){o->t=y;o->next=H1[x];H1[x]=o++;}
ll read(){
ll x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return x*f;
}
char str[MAXN],s[MAXN];
int n,m;
int sz[MAXN],maxx[MAXN],key,rt,base;
bool vis[MAXN];
void get_root(int x,int pre){
sz[x]=1;maxx[x]=0;
link(x){
if(vis[j->t]||j->t==pre)continue;
get_root(j->t,x);
sz[x]+=sz[j->t];maxx[x]=max(maxx[x],sz[j->t]);
}
maxx[x]=max(maxx[x],base-sz[x]);
if(key>maxx[x])key=maxx[x],rt=x;
}

int g[MAXN],h[MAXN],g1[MAXN],h1[MAXN];//g 表示开头 h表示结尾
unsigned long long sum[MAXN],dis[MAXN],Dep[MAXN],ma[MAXN],sum1[MAXN];
int dep[MAXN];
ll ans;
bool G[MAXN],H[MAXN];
int st[MAXN],tot,St[MAXN],tot1;
void dfs(int x,int pre){
Dep[dep[x]]=dis[x];
if(dep[x]<=m){
if(dis[x]==sum[dep[x]])G[dep[x]]=1,ans+=h[m-dep[x]],g1[dep[x]-1]++;
if(dis[x]*ma[m-dep[x]]==sum1[m]-sum1[m-dep[x]])H[dep[x]]=1,ans+=g[m-dep[x]],h1[dep[x]-1]++;
st[++tot]=dep[x]-1;St[++tot1]=dep[x]-1;
}
else{
if(G[dep[x]-m]&&dis[x]-Dep[dep[x]-m]==ma[dep[x]-m]*sum[m])G[dep[x]]=1,ans+=h[(m-dep[x]%m)%m],g1[(dep[x]%m-1+m)%m]++;
if(H[dep[x]-m]&&dis[x]-Dep[dep[x]-m]==ma[dep[x]-m]*sum1[m])H[dep[x]]=1,ans+=g[(m-dep[x]%m)%m],h1[(dep[x]%m-1+m)%m]++;
st[++tot]=(dep[x]%m-1+m)%m;St[++tot1]=(dep[x]%m-1+m)%m;
}
link(x){
if(vis[j->t]||j->t==pre)continue;
dep[j->t]=dep[x]+1;
dis[j->t]=dis[x]+ma[dep[j->t]-1]*str[j->t];
dfs(j->t,x);
}
G[dep[x]]=H[dep[x]]=0;
}

void solve(int x){
vis[x]=1;tot1=0;g[0]=h[0]=1;dis[x]=str[x];Dep[1]=dis[x];
if(str[x]==s[m])H[1]=1;
if(str[x]==s[1])G[1]=1;
link(x){
if(vis[j->t])continue;
tot=0;dep[j->t]=2;dis[j->t]=dis[x]+ma[dep[j->t]-1]*str[j->t];dfs(j->t,x);
inc(i,1,tot)g[st[i]]+=g1[st[i]],g1[st[i]]=0,h[st[i]]+=h1[st[i]],h1[st[i]]=0;
}
inc(i,1,tot1)g[St[i]]=h[St[i]]=0;
g[0]=h[0]=0;
link(x){
if(vis[j->t])continue;
key=inf;base=sz[j->t];get_root(j->t,0);
solve(rt);
}
}

int main(){
int _=read();ma[0]=1;
inc(i,1,1e6)ma[i]=ma[i-1]*131;
while(_--){
memset(H1,0,sizeof(H1));o=e;
n=read();m=read();
inc(i,1,n)vis[i]=0;
scanf("%s",str+1);
int x,y;ans=0;
inc(i,2,n)x=read(),y=read(),add(x,y),add(y,x);
scanf("%s",s+1);
inc(i,1,m)sum[i]=sum[i-1]*131+s[i],sum1[i]=sum1[i-1]+ma[i-1]*s[i];
key=inf;base=n;get_root(1,0);
solve(rt);
printf("%lld\n",ans);
}
return 0;
}

题目描述

给出n个结点的树结构T,其中每一个结点上有一个字符,这里我们所说的字符只考虑大写字母AZ,再给出长度为m的模式串s,其中每一位仍然是AZ的大写字母。Alice希望知道,有多少对结点<u,v>满足T上从UV的最短路径形成的字符串可以由模式串S重复若干次得到?这里结点对<u,v>是有序的,也就是说<u,v><v,u>需要被区分.所谓模式串的重复,是将若干个模式串S依次相接(不能重叠).例如当S=PLUS的时候,重复两次会得到PLUSPLUS,重复三次会得到PLUSPLUSPLUS,同时要注恿,重复必须是整数次的。例如当S=XYXY时,因为必须重复整数次,所以XYXYXY不能看作是S重复若干次得到的。

Input

每一个数据有多组测试,

第一行输入一个整数C,表示总的测试个数。

对于每一组测试来说:

第一行输入两个整数,分别表示树T的结点个数n与模式长度m。结点被依次编号为1n,之后一行,依次给出了n个大写字母(以一个长度为n的字符串的形式给出),依次对应树上每一个结点上的字符(第i个字符对应了第i个结点).之后n1行,每行有两个整数uv表示树上的一条无向边,之后一行给定一个长度为m的由大写字母组成的字符串,为模式串S

1<=C<=10,3<=N<=106,3<=M<=106

Output

给出C行,对应C组测试。每一行输出一个整数,表示有多少对节点<u,v>满足从uv的路径形成的字符串恰好是模式串的若干次重复.