Problem - C - Codeforces
给定两个序列a和b,a不能动,b可任意重排,然后把b插入a中得到新序列c,求c的最小逆序对个数。
观察可以得到一个结论,依次把b往a插入,插入第一个b中元素时,如果把bi插在ps(i)中最优,那么一定有bi <= bj时, ps(i) <= ps(j)
简单证明:假设bi增大,然后它的最优位置反而左移,那么左边一定存在一些比当前bi大的值,左移过了这些值才能更优,但是原来bi更小,因此左移过这些值对之前也更优,矛盾。
得到这个性质之后,有每次插入ab两序列之间产生的逆序对数最小,且b本身不会产生逆序对,故只要对于每个bi,求出它在a内最优位置产生的逆序对数和,加上a本身的逆序对数,一定是最优解。
之后有两个做法:
较常规做法:
把a数组、b数组都按值排序。从小到大枚举b,同时维护一棵线段树,节点i表示b插入ai 与 ai - 1之间时的代价,一开始a都没有加进来,意味着当前每个a都是大于b的,那么节点i的值 = i-1。然后b增大了,一些大于b的元素变成了等于b,那么把这些元素往后的插入空隙在线段树上对应节点值-1;一些等于b的元素变成了小于b,那么这些元素往前的插入空隙在线段树上对应节点值+1,更改完后查询整棵线段树的最小值就是当前这个b插入最优位置产生的逆序对数。
题解做法:
之前的常规做法其实没怎么用到bi最优位置随bi增加而增加的这一条性质。题解中用了一种分治求解,即每次求中间的bi的最优位置,然后递归处理下一层,递归下去时bi左侧的b元素最优可能位置+bi右侧的b元素最优可能位置才等于当前层的个数,因此复杂度也是nlogn
代码:(第二种做法)
#include <bits/stdc++.h>
#define pii pair<int,int>
#define fi first
#define sc second
#define pb push_back
#define ll long long
#define trav(v, x) for(auto v:x)
#define VI vector<int>
#define VLL vector<ll>
//define double long double
#define all(x) (x).begin(),(x).end()
using namespace std;
const double eps = 1e-10;//1e-12
const int N = 2e6 + 100;
const ll mod = 998244353;//1e9 + 7;
int n, m, a[N], b[N], ans[N];
vector<int> hav[N];
void calc(int l, int r, int lp, int rp)
{
// cerr << l << ' ' << r << ' ' << lp << ' ' << rp << '\n';
// system("pause");
int mid = (l + r) / 2;
VI buk(rp - lp + 2, 0);
for(int i = lp + 1; i <= rp; i++)
{
buk[i - lp] = buk[i - lp - 1];
if(a[i - 1] > b[mid])
++buk[i - lp];
}
//cerr << "!!" << '\n';
// for(int i = lp; i <= rp; i++)
// cerr << buk[i - lp] << ' ';
//cerr << '\n';
int mn = 1e9, res = -1, tmp = 0;
for(int i = rp; i >= lp; i--)
{
if(i < rp && a[i] < b[mid])
tmp++;
int nw = buk[i - lp];
nw += tmp;
// cerr << "??" << i << ' ' << nw << '\n';
if(nw < mn)
mn = nw, res = i;
}
ans[mid] = res;
if(l == r)
return;
if(l < mid)calc(l, mid - 1, lp, res);
if(mid < r)calc(mid + 1, r, res, rp);
}
int fen[N], num, val[N];
void upd(int x)
{
for(; x <= num; x += x & (-x))
fen[x]++;
}
int calc(int x)
{
int res = 0;
for(; x; x -= x & (-x))
res += fen[x];
return res;
}
void sol()
{
cin >> n >> m;
for(int i = 1; i <= n; i++)
cin >> a[i];
for(int i = 1; i <= m; i++)
cin >> b[i];
sort(b + 1, b + m + 1);
calc(1, m, 1, n + 1);
for(int i = 1; i <= n + 1; i++)
hav[i].clear();
for(int i = 1; i <= m; i++)
{
hav[ans[i]].pb(b[i]);
}
num = 0;
for(int i = 1; i <= n; i++)
{
trav(v, hav[i])
val[++num] = v;
val[++num] = a[i];
}
trav(v, hav[n + 1])
val[++num] = v;
// for(int i = 1; i <= num; i++)
// cerr << val[i] << ' ';
// cerr << '\n';
sort(val + 1, val + num + 1);
fill(fen, fen + num + 5, 0);
ll res = 0;
ll tot = 0;
for(int i = 1; i <= n; i++)
{
trav(v, hav[i])
{
v = lower_bound(val + 1, val + num + 1, v) - val;
res += tot - calc(v);
upd(v), ++tot;
}
a[i] = lower_bound(val + 1, val + num + 1, a[i]) - val;
res += tot - calc(a[i]);
upd(a[i]), ++tot;
}
trav(v, hav[n + 1])
{
v = lower_bound(val + 1, val + num + 1, v) - val;
res += tot - calc(v);
upd(v), ++tot;
}
cout << res << '\n';
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);
int tt;
cin >> tt;
while(tt--)
{
sol();
}
}
|