Solution of Not another LCA problem

Solution (Gabriel Robert Inelus)

First of all, for a fixed vertex x we need a way of counting the ordered pairs (v1,v2) such that LCA(v1,v2) = x, Value[v1] <= BigValue[x] and Value[v2] <= BigValue[x]. The pairs are ordered, so (v1,v2) and (v2,v1) are counted separately, and v1 = v2 is allowed.
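
To make the counted quantity concrete, here is a brute-force reference counter. It is a sketch meant only for checking a fast solution on small trees, and it runs in roughly O(N^2 * depth) time. It assumes 1-indexed std::vector inputs with parent[1] = 0 for the root, which is a convention of this sketch and not the original input format. The names naiveLca and bruteForce are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Walks the deeper vertex up until both vertices meet; the root's parent is never read.
int naiveLca(int a, int b, const std::vector<int>& parent, const std::vector<int>& depth) {
    while (depth[a] > depth[b]) a = parent[a];
    while (depth[b] > depth[a]) b = parent[b];
    while (a != b) {
        a = parent[a];
        b = parent[b];
    }
    return a;
}

// answer[x] = number of ordered pairs (v1, v2), v1 == v2 allowed, with LCA(v1, v2) == x,
// value[v1] <= bigValue[x] and value[v2] <= bigValue[x]. Vertices are 1..n, root is 1.
std::vector<long long> bruteForce(const std::vector<int>& parent,
                                  const std::vector<long long>& value,
                                  const std::vector<long long>& bigValue) {
    int n = (int)parent.size() - 1;
    assert((int)value.size() == n + 1 && (int)bigValue.size() == n + 1);
    std::vector<int> depth(n + 1, 0);
    for (int v = 1; v <= n; ++v)
        for (int u = v; u != 1; u = parent[u]) ++depth[v]; // distance to the root
    std::vector<long long> answer(n + 1, 0);
    for (int v1 = 1; v1 <= n; ++v1)
        for (int v2 = 1; v2 <= n; ++v2) {
            int x = naiveLca(v1, v2, parent, depth);
            if (value[v1] <= bigValue[x] && value[v2] <= bigValue[x]) ++answer[x];
        }
    return answer;
}
```

For the tree 1 -> {2, 3}, 2 -> 4 with Value = (1, 2, 3, 1) and BigValue = (3, 1, 3, 0), this gives 11, 0, 1, 0 for vertices 1 to 4.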

Consider a vertex x, and call a vertex good if its Value is less than or equal to BigValue[x]. If neither v1 nor v2 equals x, then LCA(v1,v2) = x exactly when v1 and v2 lie in the subtrees of two distinct children of x. Let S be the number of good vertices in the subtree of x, excluding x itself, and for a child nb of x let cnt(nb) be the number of good vertices in the subtree of nb, including nb. The number of pairs (v1,v2) with LCA(v1,v2) = x and both v1 and v2 different from x is the sum of cnt(nb)*(S-cnt(nb)) over all children nb of x: v1 is one of the cnt(nb) good vertices under nb, and v2 is one of the S-cnt(nb) good vertices under the other children. Now, if Value[x] <= BigValue[x], we also have to count the pairs (v,x), for which LCA(v,x) = x, and symmetrically the pairs (x,v). Thus, in this case we add 2*S pairs to the answer, plus one more for the pair (x,x).
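
The same count in formula form, where children(x) is the set of children of x and [P] is 1 when P holds and 0 otherwise (notation introduced here). The second line is a worked instance with two children whose subtrees hold 2 and 1 good vertices, and a good x.

```latex
\begin{aligned}
\mathrm{ans}(x) &= \sum_{nb \in \mathrm{children}(x)} \mathrm{cnt}(nb)\,\bigl(S - \mathrm{cnt}(nb)\bigr) + [\mathrm{Value}[x] \le \mathrm{BigValue}[x]]\,(2S + 1), \qquad S = \sum_{nb \in \mathrm{children}(x)} \mathrm{cnt}(nb) \\
\text{example: } S &= 2 + 1 = 3, \qquad \mathrm{ans}(x) = 2 \cdot 1 + 1 \cdot 2 + (2 \cdot 3 + 1) = 4 + 7 = 11
\end{aligned}
```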

We can compute the answer of each vertex x in O((1 + c(x)) log N) time, where c(x) is the number of children of x, by processing events. Initially no vertex is good. For every vertex we create two events, (Value[vertex], vertex, 1) and (BigValue[vertex], vertex, 2). The events are sorted increasingly by the first element of the tuple and, in case of equality, by the last element, so that an update comes before a query with the same key (the condition is <=, so a vertex whose Value equals BigValue[x] must already be counted when x is queried).

To answer subtree queries we use a reduced Euler tour of the rooted tree: we record a vertex only when the recursion enters it, which is a DFS preorder. We keep two arrays: poz[x], the position of vertex x in this order, and len[x], the length of the contiguous block of positions that starts at poz[x] and contains exactly the subtree rooted at x. Hence the interval [poz[x],poz[x]+len[x]-1] is the subtree of x, and [poz[x]+1,poz[x]+len[x]-1] is the subtree of x without x itself. A Fenwick tree or a segment tree over these positions counts, in O(log N), how many vertices of such an interval are currently marked as good.

We then iterate through the sorted events. A type 1 event is an update: position poz[vertex] changes from state 0 to state 1. A type 2 event is a query on the current (updated) marking. Because the events are sorted, when the query of x is processed, exactly the vertices with Value <= BigValue[x] have been marked with a 1 in the whole tree, so the sum over the interval of a subtree is precisely the number of good vertices in that subtree.
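
The reference solutions below fill poz/len with a recursive DFS. As a sketch of an alternative, assuming the children of each vertex are stored as a std::vector<std::vector<int>> indexed from 1 with root 1, the same arrays can be computed iteratively (so a path-shaped tree with about 10^5 vertices cannot exhaust the call stack on a judge with a small stack), and a small Fenwick tree answers the interval counts. The names preorder and Fenwick, and the methods mark, prefix and count, are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Fills poz[v] (1-based preorder position) and len[v] (subtree size) without recursion.
// Afterwards the subtree of x occupies positions [poz[x], poz[x] + len[x] - 1].
void preorder(const std::vector<std::vector<int>>& children,
              std::vector<int>& poz, std::vector<int>& len) {
    int n = (int)children.size() - 1;
    poz.assign(n + 1, 0);
    len.assign(n + 1, 0);
    std::vector<int> order;          // vertices in the order they are entered
    order.reserve(n);
    std::vector<int> stack(1, 1);    // start from the root, vertex 1
    while (!stack.empty()) {
        int v = stack.back();
        stack.pop_back();
        order.push_back(v);
        poz[v] = (int)order.size();
        // Push children in reverse so they are entered in their listed order.
        for (auto it = children[v].rbegin(); it != children[v].rend(); ++it)
            stack.push_back(*it);
    }
    // Children come after their parent in preorder, so a backward scan sees them first.
    for (int i = n - 1; i >= 0; --i) {
        int v = order[i];
        len[v] = 1;
        for (int c : children[v]) len[v] += len[c];
    }
}

// Fenwick tree over positions 1..n: mark(p) makes position p good,
// count(l, r) returns the number of good positions in [l, r] (0 when l > r).
struct Fenwick {
    std::vector<int> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void mark(int p) {
        assert(p >= 1 && p < (int)t.size());
        for (; p < (int)t.size(); p += p & -p) ++t[p];
    }
    int prefix(int p) const {
        int s = 0;
        for (; p > 0; p -= p & -p) s += t[p];
        return s;
    }
    int count(int l, int r) const { return l > r ? 0 : prefix(r) - prefix(l - 1); }
};
```

For children = {{}, {2, 3}, {4}, {}, {}} the preorder is 1, 2, 4, 3, so poz[4] = 3, poz[3] = 4 and len[2] = 2.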

Returning to our formula, when we reach a query event (V,x,2), we compute S as a query on the interval [poz[x]+1,poz[x]+len[x]-1]. Then, for each child nb of x, we let q be the result of a query on the interval [poz[nb],poz[nb]+len[nb]-1] and add q*(S-q) to the answer of x. After this sum, if Value[x] <= BigValue[x], we add 2*S+1 (or (S<<1)|1), which accounts for the remaining valid pairs (x,v), (v,x) and (x,x).
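
Putting the pieces together, here is a compact, self-contained sketch of the whole sweep as a function. The name solveSweep and its 1-indexed vector interface (parent[v] given for v >= 2) are assumptions of this sketch. Events are encoded as (key, type, vertex) tuples with type 0 for updates and 1 for queries, so the default lexicographic tuple ordering puts updates before queries on equal keys; the reference solutions get the same tie-breaking by numbering queries after updates.

```cpp
#include <algorithm>
#include <cassert>
#include <tuple>
#include <vector>

// parent[v] is the parent of v for v >= 2 (parent[0] and parent[1] are ignored);
// value and bigValue are 1-indexed. Returns answer[1..n] (answer[0] is unused).
std::vector<long long> solveSweep(const std::vector<int>& parent,
                                  const std::vector<long long>& value,
                                  const std::vector<long long>& bigValue) {
    int n = (int)parent.size() - 1;
    assert(n >= 1 && (int)value.size() == n + 1 && (int)bigValue.size() == n + 1);
    std::vector<std::vector<int>> children(n + 1);
    for (int v = 2; v <= n; ++v) children[parent[v]].push_back(v);

    // Iterative preorder: poz[v] = position of v, len[v] = size of v's subtree.
    std::vector<int> poz(n + 1, 0), len(n + 1, 0), order, stack(1, 1);
    while (!stack.empty()) {
        int v = stack.back();
        stack.pop_back();
        order.push_back(v);
        poz[v] = (int)order.size();
        for (int c : children[v]) stack.push_back(c);
    }
    for (int i = n - 1; i >= 0; --i) { // children appear after their parent in preorder
        int v = order[i];
        len[v] = 1;
        for (int c : children[v]) len[v] += len[c];
    }

    // Fenwick tree over preorder positions; a 1 at poz[v] means v is currently good.
    std::vector<int> bit(n + 1, 0);
    auto mark = [&](int p) { for (; p <= n; p += p & -p) ++bit[p]; };
    auto prefix = [&](int p) { int s = 0; for (; p > 0; p -= p & -p) s += bit[p]; return s; };
    auto rangeCount = [&](int l, int r) { return l > r ? 0 : prefix(r) - prefix(l - 1); };

    // (key, type, vertex): type 0 marks the vertex as good, type 1 answers the vertex.
    std::vector<std::tuple<long long, int, int>> events;
    for (int v = 1; v <= n; ++v) {
        events.emplace_back(value[v], 0, v);
        events.emplace_back(bigValue[v], 1, v);
    }
    std::sort(events.begin(), events.end());

    std::vector<long long> answer(n + 1, 0);
    for (const auto& e : events) {
        int type = std::get<1>(e), x = std::get<2>(e);
        if (type == 0) {
            mark(poz[x]);
            continue;
        }
        long long S = rangeCount(poz[x] + 1, poz[x] + len[x] - 1); // x itself excluded
        long long ans = 0;
        for (int c : children[x]) {
            long long q = rangeCount(poz[c], poz[c] + len[c] - 1);
            ans += q * (S - q); // one endpoint under c, the other under another child
        }
        if (value[x] <= bigValue[x]) ans += 2 * S + 1; // (x, v), (v, x) and (x, x)
        answer[x] = ans;
    }
    return answer;
}
```

On the tree 1 -> {2, 3}, 2 -> 4 with Value = (1, 2, 3, 1) and BigValue = (3, 1, 3, 0) it returns 11, 0, 1, 0, which matches a direct count of the pairs.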

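The final complexity is O(N log N) time and O(N) memory. Either a Fenwick tree or a segment tree can be used, since both handle an update or a range query in O(log N); the first solution below uses a Fenwick tree and the second one a segment tree.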
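
The bound can be checked by counting operations, writing c(x) for the number of children of x (notation introduced here). Every vertex triggers one update and at most one S query, and every vertex other than the root is queried once more as a child of its parent. So although the cost for a single vertex depends on its degree, the total number of tree operations stays linear:

```latex
\begin{aligned}
\#\text{updates} &= N, \qquad
\#\text{range queries} \le \sum_{x=1}^{N} \bigl(1 + c(x)\bigr) = N + (N - 1) = 2N - 1, \\
T(N) &= \underbrace{O(N \log N)}_{\text{sorting } 2N \text{ events}} + \bigl(N + 2N - 1\bigr) \cdot O(\log N) = O(N \log N).
\end{aligned}
```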

C++ solution (Cosmin Rusu)

    #include <iostream>
    #include <vector>
    #include <algorithm>

    using namespace std;

    const int maxn = 100005;

    // aib = Fenwick tree over DFS entry times; the subtree of v is [in[v], out[v]]
    int n, aib[maxn], in[maxn], out[maxn], k;
    long long value[maxn], bigvalue[maxn], answer[maxn];
    // (key, id): id <= n is an update for vertex id, id > n is a query for vertex id - n
    vector <pair<long long, int> > events;
    vector <int> g[maxn];

    // Preorder numbering: entry time in[node], last entry time in its subtree out[node].
    inline void dfs(int node) {
        in[node] = ++ k;
        for(auto it : g[node])
            dfs(it);
        out[node] = k;
    }

    inline int lsb(int x) {
        return x & (-x);
    }

    inline void update(int pos, int delta) {
        for(int i = pos ; i <= n ; i += lsb(i))
            aib[i] += delta;
    }

    // Number of good vertices among positions 1..pos.
    inline int query(int pos) {
        int sum = 0;
        for(int i = pos ; i > 0 ; i -= lsb(i))
            sum += aib[i];
        return sum;
    }

    int main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);
        cin >> n;
        for(int i = 2 ; i <= n ; ++ i) {
            int x;
            cin >> x; // parent of vertex i
            g[x].push_back(i);
        }
        dfs(1);
        for(int i = 1 ; i <= n ; ++ i) {
            cin >> value[i];
            events.push_back(make_pair(value[i], i));
        }
        for(int i = 1 ; i <= n ; ++ i) {
            cin >> bigvalue[i];
            events.push_back(make_pair(bigvalue[i], i + n));
        }
        // On equal keys, updates (id <= n) come before queries (id > n), as "<=" requires.
        sort(events.begin(), events.end());
        for(auto event : events) {
            int node = event.second;
            if(node <= n) { // update: node becomes good
                update(in[node], 1);
            }
            else { // query
                node -= n;
                long long ans = 0;
                // S: good vertices in the subtree of node, excluding node itself
                int sum = query(out[node]) - query(in[node]);
                if(value[node] <= bigvalue[node]) {
                    ans += sum; // (node, sons)
                    ans += sum; // (sons, node)
                    ++ ans; // (node, node)
                }
                for(auto it : g[node]) {
                    // good vertices in the subtree of son it
                    int act = query(out[it]) - query(in[it] - 1);
                    ans += 1LL * act * (sum - act);
                }
                answer[node] = ans;
            }
        }
        for(int i = 1 ; i <= n ; ++ i)
            cout << answer[i] << '\n';
    }

C++ solution (Gabriel Inelus)

    #include <bits/stdc++.h>
    #define Nmax 100005

    using namespace std;

    int N;
    long long Value[Nmax],BigValue[Nmax];
    long long rasp[Nmax];
    int poz[Nmax]; // position of the vertex in the preorder
    int len[Nmax]; // number of vertices in its subtree
    vector<int> G[Nmax];
    // events (key, i): i <= N updates vertex i, i > N queries vertex i - N
    vector<pair<long long,int> > P;

    void Read()
    {
        long long a;
        scanf("%lld",&a);
        N = a;
        for(int i = 2; i <= N; ++i){
            scanf("%lld",&a); // parent of vertex i
            G[a].push_back(i);
        }
        for(int i = 1; i <= N; ++i){
            scanf("%lld",&a);
            Value[i] = a;
            P.push_back(make_pair(a,i));
        }
        for(int i = 1; i <= N; ++i){
            scanf("%lld",&a);
            BigValue[i] = a;
            P.push_back(make_pair(a,i+N));
        }
        // on equal keys, updates (i <= N) come before queries (i > N)
        sort(P.begin(),P.end());
    }
    int eup;

    // Preorder numbering: the subtree of k is [poz[k], poz[k] + len[k] - 1].
    void DFS(int k){
        ++eup;
        poz[k] = eup;
        for(auto it : G[k])
            DFS(it);
        len[k] = eup - poz[k] + 1;
    }

    int A,B,pos;      // query interval [A,B] and update position pos
    long long answer; // Querry adds its result here

    class SegmentTree{
    public:
        vector<int> range;
        void Resize(int k){
            range.resize(1 <<( (int)ceil(log2( (double) k)) + 1 ) );
        }
        // Sets leaf pos to 1 (the vertex becomes good) and refreshes the sums above it.
        void Update(int li,int lf,int pz){
            if(li == lf){
                range[pz] = 1;
                return;
            }
            int m = li + ((lf - li) >> 1);
            if(pos <= m) Update(li,m,pz<<1);
            else Update(m+1,lf,(pz<<1)|1);
            range[pz] = range[pz<<1] + range[(pz<<1)|1];
        }
        // Adds the number of good positions in [A,B] to answer.
        void Querry(int li,int lf,int pz){
            if(A <= li && lf <= B){
                answer += range[pz];
                return;
            }
            int m = li + ((lf - li) >> 1);
            if(A <= m) Querry(li,m,pz<<1);
            if(B > m) Querry(m+1,lf,(pz<<1)|1);
        }
    };
    SegmentTree Aint;

    void Solve()
    {
        Aint.Resize(N);
        int crt;
        long long S;
        for( auto it : P )
        {
            crt = it.second;
            if(crt <= N){ // update
                pos = poz[crt];
                Aint.Update(1,N,1);
            }
            else
                {
                    crt -= N; // query for vertex crt
                    answer = 0;
                    A = poz[crt] + 1;
                    B = poz[crt] + len[crt] - 1;
                    if(A <= B){ // S = good vertices strictly below crt
                        Aint.Querry(1,N,1);
                        S = answer;
                    }
                    else
                        S = 0;
                    for(auto jt : G[crt]){
                        answer = 0;
                        A = poz[jt];
                        B = poz[jt] + len[jt] - 1;
                        Aint.Querry(1,N,1);
                        rasp[crt] += answer * (S - answer);
                    }
                    if(Value[crt] <= BigValue[crt])
                        rasp[crt] += ((S<<1)|1);
                }
        }
        for(int i = 1; i <= N; ++i)
            printf("%lld\n",rasp[i]);
    }

    int main()
    {
        Read();
        DFS(1);
        Solve();

        return 0;
    }