#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <vector>
#define mp make_pair
#define pb push_back
#define pii pair<int,int>
// Loop over the adjacency list of node x; the current edge is named j.
#define link(x) for(edge *j=H1[x];j;j=j->next)
#define inc(i,l,r) for(int i=l;i<=r;i++)
#define dec(i,r,l) for(int i=r;i>=l;i--)
const int MAXN=1e6+10;
const double eps=1e-8;
#define ll long long
const int inf=1e9;
using namespace std;

// Adjacency lists stored as singly linked edge records carved from the static
// pool e: H1[x] is the head of node x's list and o is the next free record.
struct edge{int t;edge*next;}e[MAXN<<1],*H1[MAXN],*o=e;

// Prepends the directed edge x -> y to x's adjacency list.
void add(int x,int y){
    o->t=y;
    o->next=H1[x];
    H1[x]=o++;
}

// Reads an integer from stdin, skipping any non-digit characters before it;
// a '-' seen while skipping makes the result negative.
// Returns 0 if end of input is reached before any digit (instead of looping forever).
ll read(){
    ll x=0,f=1;
    int ch=getchar();   // int, not char: must be able to hold EOF and be valid for isdigit
    while(ch!=EOF&&!isdigit(ch)){
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    return x*f;
}

char str[MAXN],s[MAXN];   // read 1-indexed: str[1..n] holds a character per node, s[1..m] the query string
int n,m;
// Centroid-search state: per-node subtree size (sz) and largest remaining piece
// (maxx); the best maxx found so far (key), the node achieving it (rt), and the
// node count assumed for the component being searched (base).
int sz[MAXN],maxx[MAXN],key,rt,base;
bool vis[MAXN];   // nodes already removed as centroids

// Searches the unvisited component containing x for a centroid.
// The component is walked as a tree rooted at x, with `pre` treated as x's parent
// (0 for none). For every node it fills sz[] with its subtree size and maxx[] with
// the size of the largest piece left if that node were removed, taking the whole
// component to have `base` nodes. The node with the smallest maxx (the first one in
// post-order on ties) is written to rt and its maxx to key. The caller must reset
// key (e.g. to inf) beforehand.
// Implemented with an explicit stack so that long paths cannot overflow the call stack.
void get_root(int x,int pre){
    struct Frame{int node;int parent;edge*next_edge;};
    static std::vector<Frame> stk;   // reused between calls to keep its capacity
    stk.clear();
    sz[x]=1;maxx[x]=0;
    stk.push_back({x,pre,H1[x]});
    while(!stk.empty()){
        Frame &top=stk.back();
        edge *&it=top.next_edge;
        // Skip removed nodes and the edge back to the parent.
        while(it&&(vis[it->t]||it->t==top.parent))it=it->next;
        if(it){
            const int u=top.node,v=it->t;
            it=it->next;
            sz[v]=1;maxx[v]=0;
            stk.push_back({v,u,H1[v]});   // may invalidate top/it; neither is used again
            continue;
        }
        // All children done: finish u exactly as the post-order step of a recursive DFS.
        const int u=top.node;
        stk.pop_back();
        maxx[u]=std::max(maxx[u],base-sz[u]);
        if(key>maxx[u])key=maxx[u],rt=u;
        if(!stk.empty()){
            const int p=stk.back().node;
            sz[p]+=sz[u];
            maxx[p]=std::max(maxx[p],sz[u]);
        }
    }
}
int g[MAXN],h[MAXN],g1[MAXN],h1[MAXN]; unsigned long long sum[MAXN],dis[MAXN],Dep[MAXN],ma[MAXN],sum1[MAXN]; int dep[MAXN]; ll ans; bool G[MAXN],H[MAXN]; int st[MAXN],tot,St[MAXN],tot1; void dfs(int x,int pre){ Dep[dep[x]]=dis[x]; if(dep[x]<=m){ if(dis[x]==sum[dep[x]])G[dep[x]]=1,ans+=h[m-dep[x]],g1[dep[x]-1]++; if(dis[x]*ma[m-dep[x]]==sum1[m]-sum1[m-dep[x]])H[dep[x]]=1,ans+=g[m-dep[x]],h1[dep[x]-1]++; st[++tot]=dep[x]-1;St[++tot1]=dep[x]-1; } else{ if(G[dep[x]-m]&&dis[x]-Dep[dep[x]-m]==ma[dep[x]-m]*sum[m])G[dep[x]]=1,ans+=h[(m-dep[x]%m)%m],g1[(dep[x]%m-1+m)%m]++; if(H[dep[x]-m]&&dis[x]-Dep[dep[x]-m]==ma[dep[x]-m]*sum1[m])H[dep[x]]=1,ans+=g[(m-dep[x]%m)%m],h1[(dep[x]%m-1+m)%m]++; st[++tot]=(dep[x]%m-1+m)%m;St[++tot1]=(dep[x]%m-1+m)%m; } link(x){ if(vis[j->t]||j->t==pre)continue; dep[j->t]=dep[x]+1; dis[j->t]=dis[x]+ma[dep[j->t]-1]*str[j->t]; dfs(j->t,x); } G[dep[x]]=H[dep[x]]=0; }
void solve(int x){ vis[x]=1;tot1=0;g[0]=h[0]=1;dis[x]=str[x];Dep[1]=dis[x]; if(str[x]==s[m])H[1]=1; if(str[x]==s[1])G[1]=1; link(x){ if(vis[j->t])continue; tot=0;dep[j->t]=2;dis[j->t]=dis[x]+ma[dep[j->t]-1]*str[j->t];dfs(j->t,x); inc(i,1,tot)g[st[i]]+=g1[st[i]],g1[st[i]]=0,h[st[i]]+=h1[st[i]],h1[st[i]]=0; } inc(i,1,tot1)g[St[i]]=h[St[i]]=0; g[0]=h[0]=0; link(x){ if(vis[j->t])continue; key=inf;base=sz[j->t];get_root(j->t,0); solve(rt); } }
int main(){ int _=read();ma[0]=1; inc(i,1,1e6)ma[i]=ma[i-1]*131; while(_--){ memset(H1,0,sizeof(H1));o=e; n=read();m=read(); inc(i,1,n)vis[i]=0; scanf("%s",str+1); int x,y;ans=0; inc(i,2,n)x=read(),y=read(),add(x,y),add(y,x); scanf("%s",s+1); inc(i,1,m)sum[i]=sum[i-1]*131+s[i],sum1[i]=sum1[i-1]+ma[i-1]*s[i]; key=inf;base=n;get_root(1,0); solve(rt); printf("%lld\n",ans); } return 0; }