// Bottleneck assignment on a tree.
//
// Input (stdin):
//   n k
//   a_1 ... a_k          -- vertices of the first group (1-based)
//   b_1 ... b_k          -- vertices of the second group (1-based)
//   n - 1 lines "x y"    -- undirected tree edges
// Output (stdout):
//   the smallest D such that the two groups admit a perfect matching in which
//   every matched pair (a_i, b_j) is at most D edges apart in the tree.
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

namespace {

// Cost assigned to pairs with no connecting path. It exceeds every threshold
// the binary search tries, so such pairs are never used by the matching.
constexpr int kUnreachable = std::numeric_limits<int>::max();

// Unweighted single-source shortest paths (iterative, so no recursion-depth
// limit on long path-shaped trees).
// Returns dist[v] = number of edges between `source` and v, or -1 when v is
// not reachable from `source`.
std::vector<int> bfs_distances(const std::vector<std::vector<int>>& adj, int source) {
    std::vector<int> dist(adj.size(), -1);
    std::vector<int> order;  // visitation order; doubles as the FIFO queue
    order.reserve(adj.size());
    dist[source] = 0;
    order.push_back(source);
    for (std::size_t head = 0; head < order.size(); ++head) {
        const int u = order[head];
        for (const int w : adj[u]) {
            if (dist[w] == -1) {
                dist[w] = dist[u] + 1;
                order.push_back(w);
            }
        }
    }
    return dist;
}

// Maximum bipartite matching (Kuhn's augmenting paths, repeated in phases)
// restricted to the pairs (left, right) whose cost[left][right] does not
// exceed a given limit. `cost` must be square and must outlive the matcher.
class ThresholdMatcher {
public:
    explicit ThresholdMatcher(const std::vector<std::vector<int>>& cost)
        : cost_(cost), size_(static_cast<int>(cost.size())) {}

    // Returns the size of a maximum matching that only uses pairs with
    // cost <= limit.
    int max_matching(int limit) {
        limit_ = limit;
        match_of_left_.assign(size_, -1);
        match_of_right_.assign(size_, -1);
        int matched = 0;
        // Each phase clears the visited marks and tries to augment from every
        // free left vertex; stop once a whole phase finds no augmenting path.
        for (bool progress = true; progress;) {
            progress = false;
            visited_.assign(size_, 0);
            for (int left = 0; left < size_; ++left) {
                if (match_of_left_[left] == -1 && try_augment(left)) {
                    ++matched;
                    progress = true;
                }
            }
        }
        return matched;
    }

private:
    // Depth-first search for an augmenting path starting at `left`; on
    // success the path is flipped into the current matching.
    bool try_augment(int left) {
        if (visited_[left]) {
            return false;
        }
        visited_[left] = 1;
        for (int right = 0; right < size_; ++right) {
            if (cost_[left][right] > limit_) {
                continue;
            }
            const int owner = match_of_right_[right];
            if (owner == -1 || try_augment(owner)) {
                match_of_left_[left] = right;
                match_of_right_[right] = left;
                return true;
            }
        }
        return false;
    }

    const std::vector<std::vector<int>>& cost_;
    int size_;
    int limit_ = 0;
    std::vector<int> match_of_left_;   // right partner of each left vertex, or -1
    std::vector<int> match_of_right_;  // left partner of each right vertex, or -1
    std::vector<char> visited_;        // per-phase DFS marks on left vertices
};

int report_invalid_input() {
    std::fprintf(stderr, "invalid input\n");
    return 1;
}

}  // namespace

int main() {
    int n = 0;
    int k = 0;
    if (std::scanf("%d%d", &n, &k) != 2 || n <= 0 || k < 0) {
        return report_invalid_input();
    }

    // Reads one vertex label and checks it lies in [1, n] so it can index adj.
    const auto read_vertex = [n](int& v) {
        return std::scanf("%d", &v) == 1 && 1 <= v && v <= n;
    };

    std::vector<int> group_a(k);
    std::vector<int> group_b(k);
    for (int& v : group_a) {
        if (!read_vertex(v)) return report_invalid_input();
    }
    for (int& v : group_b) {
        if (!read_vertex(v)) return report_invalid_input();
    }

    std::vector<std::vector<int>> adj(n + 1);  // 1-based vertices
    for (int edge = 1; edge < n; ++edge) {
        int x = 0;
        int y = 0;
        if (!read_vertex(x) || !read_vertex(y)) return report_invalid_input();
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    // cost[i][j] = tree distance between group_a[i] and group_b[j].
    std::vector<std::vector<int>> cost(k, std::vector<int>(k));
    for (int i = 0; i < k; ++i) {
        const std::vector<int> dist = bfs_distances(adj, group_a[i]);
        for (int j = 0; j < k; ++j) {
            const int d = dist[group_b[j]];
            cost[i][j] = (d < 0) ? kUnreachable : d;
        }
    }

    // Binary search the smallest threshold admitting a perfect matching.
    // In a connected tree every distance is at most n - 1, so D = n is
    // always feasible.
    ThresholdMatcher matcher(cost);
    int lo = 0;
    int hi = n;
    int answer = 0;
    while (lo <= hi) {
        const int mid = lo + (hi - lo) / 2;
        if (matcher.max_matching(mid) == k) {
            answer = mid;
            hi = mid - 1;
        } else {
            lo = mid + 1;
        }
    }
    std::printf("%d\n", answer);
    return 0;
}