BZOJ4027: Rabbits and Sakura [HEOI2015]

Problem Description: You are given a sakura tree, which is an actual rooted tree: every node has a weight and a list of children. The rabbits want to remove as many nodes as possible (nobody knows why) while following a few rules. When a node is removed, its children are attached to its nearest surviving ancestor (its current parent), and its weight is added to that parent as well. Every node that remains must keep the sum of its own weight and its number of children at most M. The root cannot be removed. Output the maximum number of nodes the rabbits can remove.

Input Format:
The first line contains two integers: N (N ≤ 2000000), the number of nodes, and M (M ≤ 100000), the limit.
The second line contains N integers, the weight of each node.
The next N lines describe the nodes in order. The i-th of these lines starts with an integer k, the number of children of node i, followed by the k children's indices. Nodes are numbered from 0 to N-1, and node 0 is the root.

Output Format:
One number, representing the maximum number of nodes the rabbits can remove.

Sample Input:
10 4
0 2 2 2 4 1 0 4 1 1
3 6 2 3
1 9
1 8
1 1
0
0
2 7 4
0
1 5
0

Sample Output:
4

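Solution: This is a greedy problem. For a node u, let $c_u = w_u + son(u)$ be its load, where $son(u)$ is its current number of children; every surviving node needs $c_u \le M$. Removing a child v of u moves v's weight and v's children onto u but takes v itself off u's child list, so $c_u$ grows by $c_v - 1$, and no other node's load changes. Nodes closer to the leaves are the better candidates for removal, so we decide the tree bottom-up, children before parents (e.g. in DFS post-order). At each node u, once its children's loads are final, we sort the children by $c_v$ and remove them from cheapest to most expensive while $c_u + c_v - 1 \le M$; as soon as one child does not fit, no later (more expensive) child fits either.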

#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

const int maxn = 2e6 + 10;
vector<int> G[maxn];  // G[u]: children of node u (nodes shifted to 1..N, root = 1)
int c[maxn];          // c[u]: weight of u, later u's load w_u + son(u)
int N, M, cnt;        // cnt: number of removed nodes

// Order children by load so the cheapest removals are tried first.
bool cmp(int a, int b) {
    return c[a] < c[b];
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> N >> M;
    for (int i = 1; i <= N; i++) cin >> c[i];
    for (int i = 1; i <= N; i++) {
        int num;
        cin >> num;
        for (int j = 0; j < num; j++) {
            int nxt;
            cin >> nxt;
            G[i].push_back(nxt + 1);  // input is 0-indexed; shift to 1-indexed
        }
    }

    // A recursive DFS can go N = 2e6 levels deep on a path-shaped tree and
    // overflow the stack. Instead, list the nodes in BFS order from the root;
    // scanning that list backwards handles every child before its parent.
    vector<int> order;
    order.reserve(N);
    order.push_back(1);
    for (size_t h = 0; h < order.size(); h++)
        for (int v : G[order[h]]) order.push_back(v);

    for (int k = (int)order.size() - 1; k >= 0; k--) {
        int u = order[k];
        sort(G[u].begin(), G[u].end(), cmp);
        c[u] += (int)G[u].size();  // load = own weight + number of children
        for (int v : G[u]) {
            // Removing v adds c[v] to u's load but drops v as a child: net c[v] - 1.
            if (c[u] + c[v] - 1 <= M) {
                c[u] += c[v] - 1;
                cnt++;
            } else {
                break;  // children are sorted, so no later child fits either
            }
        }
    }
    cout << cnt << endl;
    return 0;
}