bzoj3365(点分治)

题解

与上题做法相同：点分治。每次取当前连通块的重心，先统计块内到重心距离之和不超过 K 的点对数。再减去两端落在重心同一棵子树内的点对，因为它们的真实路径不经过重心。最后对各子树递归处理。

代码实现

#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <vector>
// Shorthand macros (mp, pb and pii are not used anywhere in the code shown).
#define mp make_pair
#define pb push_back
#define pii pair<int,int>
// Iterates the edge pointer j over the adjacency list of node x; relies on the
// global head array h[] and struct edge declared further down.
#define link(x) for(edge *j=h[x];j;j=j->next)
// Inclusive ascending / descending integer loops.
#define inc(i,l,r) for(int i=l;i<=r;i++)
#define dec(i,r,l) for(int i=r;i>=l;i--)
// Capacity of every per-node array (the edge pool uses twice this).
const int MAXN=3e5+10;
// Floating-point tolerance; not used in the code shown.
const double eps=1e-8;
#define ll long long
using namespace std;
// "Infinity" sentinel used to reset the centroid-search key before each search.
const int inf=1e9;
// Adjacency lists kept as singly linked lists inside a static pool:
// e[] is the pool, h[x] is the first edge leaving x, o is the next free slot.
// Every undirected road is inserted twice, hence MAXN<<1 slots.
struct edge{
    int t,v;     // target node, edge weight
    edge *next;  // next edge in the same adjacency list
};
edge e[MAXN<<1];
edge *h[MAXN];
edge *o=e;

// Pushes the directed edge x->y with weight vul onto the front of x's list.
void add(int x,int y,int vul){
    edge *slot=o++;
    slot->t=y;
    slot->v=vul;
    slot->next=h[x];
    h[x]=slot;
}
// Reads the next decimal integer from stdin.
// Any non-digit characters before it are skipped; a '-' among the skipped
// characters makes the result negative. Returns 0 if EOF is reached before any
// digit (previously this looped forever).
long long read(){
    long long x=0,f=1;
    // int, not char: getchar() must be able to report EOF distinctly, and
    // isdigit() is only defined for unsigned-char values and EOF.
    int ch=getchar();
    while(ch!=EOF&&!isdigit(ch)){
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch!=EOF&&isdigit(ch)){
        x=x*10+(ch-'0');
        ch=getchar();
    }
    return x*f;
}

int n;            // number of nodes
int K;            // distance threshold
int rt;           // centroid found by the latest get_root call
int base;         // size of the component get_root is currently searching
int sz[MAXN];     // subtree sizes computed by get_root
int maxx[MAXN];   // largest remaining piece if the node were removed
int key;          // best (smallest) maxx[] seen in the current search
bool vis[MAXN];   // nodes already used as centroids (removed from the tree)
ll ans;           // running answer

void get_root(int x,int pre){
sz[x]=1;maxx[x]=0;
link(x){
if(j->t==pre||vis[j->t])continue;
get_root(j->t,x);
sz[x]+=sz[j->t];maxx[x]=max(maxx[x],sz[j->t]);
}
maxx[x]=max(maxx[x],base-sz[x]);
if(maxx[x]<key)key=maxx[x],rt=x;
}

ll st[MAXN],dis[MAXN];int tot,num[MAXN];
void get_deep(int x,int pre){
num[x]=1;
if(!pre)st[++tot]=dis[x];
link(x){
if(j->t==pre||vis[j->t])continue;
dis[j->t]=dis[x]+j->v;st[++tot]=dis[j->t];
get_deep(j->t,x);
num[x]+=num[j->t];
}
}

// Seeds dis[x]=dist, gathers the distances of every node reachable from x
// (without crossing removed nodes), and returns how many unordered pairs of
// those nodes have distance sum <= K.
ll get_sum(int x,ll dist){
    dis[x]=dist;
    tot=0;
    get_deep(x,0);
    ll *first=st+1;
    ll *last=st+tot+1;
    sort(first,last);
    ll pairs=0;
    for(ll *p=first;p!=last;++p){
        // Valid partners of *p are the later entries q with *p + *q <= K.
        // The range is sorted, so they form a prefix of (p, last).
        pairs+=upper_bound(p+1,last,K-*p)-(p+1);
    }
    return pairs;
}

void solve(int x,int y){
vis[x]=1;ans+=get_sum(x,0);
link(x){
if(j->t==y||vis[j->t])continue;
key=inf;base=num[j->t];get_root(j->t,x);ans-=get_sum(j->t,j->v);
solve(rt,x);
}
}

int main(){
int x,y,z;
char ch;
n=read();int m=read();
inc(i,2,n)scanf("%d %d %d %c",&x,&y,&z,&ch),add(x,y,z),add(y,x,z);
K=read();
key=inf;base=n;get_root(1,0);
solve(rt,0);
printf("%lld\n",ans);
}

题目描述（原题链接见下）：给定一棵带边权的树和整数 K，求树上距离不超过 K 的无序点对个数。输入中每条边后附带的一个字符可以忽略。

https://www.lydsy.com/JudgeOnline/problem.php?id=3365