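// Code: two alternative versions of the same program follow, first an O(n log n) two-heap solution, then a brute-force one. Each is a complete program on its own.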
#include <bits/stdc++.h>
using namespace std;
// Sentinel "negative infinity" (-10^15) for long long scores. It is written as an
// integer literal (no double-to-integer conversion) and is constexpr because it is
// never modified.
constexpr long long minus_inf = -1000000000000000LL;
// Returns the maximum, over split points k in [n, 2n], of
//   (sum of the n largest values in a[0..k-1]) - (sum of the n smallest values in a[k..3n-1]).
// Runs in O(n log n): one heap sweeps the prefix left-to-right, another sweeps the
// suffix right-to-left.
//
// Params:
//   a - sequence holding at least 3 * n values (only a[0..3n-1] are read).
//   n - size of each group; n <= 0 yields 0.
long long solve(const vector<long long> &a, int n) {
    if (n <= 0) {
        return 0;  // both groups are empty; also avoids indexing position n - 1 below
    }
    // end_before_or_at[i]:  sum of the n largest values among a[0..i],    i in [n-1, 2n-1].
    // start_at_or_after[i]: sum of the n smallest values among a[i..3n-1], i in [n, 2n].
    vector<long long> end_before_or_at(3 * n), start_at_or_after(3 * n);

    // Min-heap keeps the n largest prefix values; its top is the one to evict.
    priority_queue<long long, vector<long long>, greater<long long>> minHeap;
    long long largest_n_sum = 0;
    for (int i = 0; i < n; i++) {
        minHeap.push(a[i]);
        largest_n_sum += a[i];
    }
    end_before_or_at[n - 1] = largest_n_sum;
    for (int i = n; i < 2 * n; i++) {
        minHeap.push(a[i]);
        largest_n_sum += a[i];
        largest_n_sum -= minHeap.top();  // drop the smallest of the n + 1 candidates
        minHeap.pop();
        end_before_or_at[i] = largest_n_sum;
    }

    // Max-heap keeps the n smallest suffix values; its top is the one to evict.
    priority_queue<long long> maxHeap;
    long long smallest_n_sum = 0;
    for (int i = 3 * n - 1; i >= 2 * n; i--) {
        maxHeap.push(a[i]);
        smallest_n_sum += a[i];
    }
    start_at_or_after[2 * n] = smallest_n_sum;
    for (int i = 2 * n - 1; i >= n; i--) {
        maxHeap.push(a[i]);
        smallest_n_sum += a[i];
        smallest_n_sum -= maxHeap.top();  // drop the largest of the n + 1 candidates
        maxHeap.pop();
        start_at_or_after[i] = smallest_n_sum;
    }

    // Compare only the n + 1 valid split points. No sentinel is involved, so a
    // score below any fixed "minus infinity" is still returned exactly.
    long long best = end_before_or_at[n - 1] - start_at_or_after[n];
    for (int i = n; i < 2 * n; i++) {
        best = max(best, end_before_or_at[i] - start_at_or_after[i + 1]);
    }
    return best;
}
// Reads n followed by 3n integers from standard input and prints solve(a, n).
// Exits with status 1 if the input is malformed or n is out of range.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n = 0;
    // A failed read must not reach the 3n-sized allocation. A non-positive n
    // (or one where 3 * n would overflow int) cannot describe a valid input.
    if (!(cin >> n) || n <= 0 || n > numeric_limits<int>::max() / 3) {
        return 1;
    }
    vector<long long> a(3 * n);
    for (long long &value : a) {
        if (!(cin >> value)) {
            return 1;
        }
    }
    long long res = solve(a, n);
    cout << res << '\n';  // the stream is flushed at program exit; no endl needed
    return 0;
}
#include <bits/stdc++.h>
using namespace std;
// Sentinel "negative infinity" (-10^15) for long long scores. It is written as an
// integer literal (no double-to-integer conversion) and is constexpr because it is
// never modified.
constexpr long long minus_inf = -1000000000000000LL;
// Brute-force version: returns the maximum, over split points k in [n, 2n], of
//   (sum of the n largest values in a[0..k-1]) - (sum of the n smallest values in a[k..3n-1]).
// A fresh heap is built for every split point, so the cost is O(n^2 log n).
//
// Params:
//   a - sequence holding at least 3 * n values (only a[0..3n-1] are read).
//   n - size of each group; n <= 0 yields 0.
long long solve(const vector<long long> &a, int n) {
    if (n <= 0) {
        return 0;  // both groups are empty; also avoids out-of-range indexing below
    }
    vector<long long> end_before_or_at(3 * n), start_at_or_after(3 * n);

    // end_before_or_at[i]: sum of the n largest values among a[0..i], i in [n-1, 2n-1].
    for (int i = n - 1; i < 2 * n; i++) {
        priority_queue<long long> maxHeap;
        for (int j = 0; j <= i; j++) {
            maxHeap.push(a[j]);
        }
        long long largest_n_sum = 0;
        for (int k = 0; k < n; k++) {
            largest_n_sum += maxHeap.top();
            maxHeap.pop();
        }
        end_before_or_at[i] = largest_n_sum;
    }

    // start_at_or_after[i]: sum of the n smallest values among a[i..3n-1], i in [n, 2n].
    for (int i = n; i <= 2 * n; i++) {
        priority_queue<long long, vector<long long>, greater<long long>> minHeap;
        for (int j = i; j < 3 * n; j++) {
            minHeap.push(a[j]);
        }
        long long smallest_n_sum = 0;
        for (int k = 0; k < n; k++) {
            smallest_n_sum += minHeap.top();
            minHeap.pop();
        }
        start_at_or_after[i] = smallest_n_sum;
    }

    // Compare only the n + 1 valid split points. No sentinel is involved, so a
    // score below any fixed "minus infinity" is still returned exactly.
    long long best = end_before_or_at[n - 1] - start_at_or_after[n];
    for (int i = n; i < 2 * n; i++) {
        best = max(best, end_before_or_at[i] - start_at_or_after[i + 1]);
    }
    return best;
}
// Reads n followed by 3n integers from standard input and prints solve(a, n).
// Exits with status 1 if the input is malformed or n is out of range.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n = 0;
    // A failed read must not reach the 3n-sized allocation. A non-positive n
    // (or one where 3 * n would overflow int) cannot describe a valid input.
    if (!(cin >> n) || n <= 0 || n > numeric_limits<int>::max() / 3) {
        return 1;
    }
    vector<long long> a(3 * n);
    for (long long &value : a) {
        if (!(cin >> value)) {
            return 1;
        }
    }
    long long res = solve(a, n);
    cout << res << '\n';  // the stream is flushed at program exit; no endl needed
    return 0;
}