2023.1.18 PM 03:00 by CBJ
來源 : https://zerojudge.tw/ShowProblem?problemid=g277 出題者 : 2021年9月APCS 標籤 : 前綴和、排序、輸入優化 難易度 : 4
解題想法 : 由於輸入量大,所以需要做輸入優化(C++ 優化見 C++ Code 中 #define fastio 的內容, Python 則需引入 sys 套件,並使用stdin.readline()代替input()使用,另外Python仍會有超時問題,故需要將整份程式包成函式來加速(加速的原因主要是因為避免掉全域變數)) 此題提供了兩種解法,第一種(C++ Code)較直覺,第二種(Python Code)較需要思考但速度較快。 第一種 - 依照題目實作,並維護最小值數列 根據題目的要求進行實作,區間和要使用前綴和來快速求得[圖像化筆記連結],而最小值的求法可以一開始先把num和index綁成pair並存入num_index[],然後將num_index[]進行大到小的排序,之後若需要找最小值,只要對num_index[]做pop_back(),就可以得到最小值和它的位置,若這個最小值的位置並不在L~R中,那就忽略它繼續pop下一個直到找到為止,如此一來便能找到最小值。 第二種 - 用dict儲存數值和索引值,然後直接把原陣列排序 每次在L~R中取最小值,可以換個方式來想,對於當前的最小值,判斷它是否位在L~R中,若依照這樣的想法實作,就可以一開始就對原陣列排序(前提是已經記好每個數值對應的索引值且算好前綴和)。 排序後開始for val in 陣列,只要val位在L~R中,就開始計算區間和(取區間和的方法一樣需搭配前綴和)並收斂L和R的範圍,其中val的index就會是題目敘述中的m。 收斂後只要L一等於R,就可以輸出答案,而答案的求法由於原陣列順序已被打亂,所以要使用前綴和的方法(前綴和需在原陣列排序前就算好),求得[L,L]的區間和(原陣列L的值),也就是前綴和(L) - 前綴和(L-1)。 ※前綴和的值最大可以到1e12,因此若使用C/C++請記得開long long。
//C++ language
#include<iostream>
#include<vector> //vector
#include<utility> //pair
#include<algorithm> //sort()
#define fastio ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0)
#define num first
#define index second
using namespace std;
int n;
vector<int>people;
vector<pair<int,int>>num_index;
int find_min(int L,int R){
while(true){
pair<int,int> i=num_index.back();
if(i.index>=L && i.index<=R){
int ret=i.index;
num_index.pop_back();
return ret;
}
else num_index.pop_back();
}
}
int main(){
fastio;
cin>>n;
int a;
people.push_back(0);
for(int i=0;i<n;i++){
cin>>a;
people.push_back(a);
num_index.push_back({a,i+1});
}
sort(num_index.begin(),num_index.end(),greater<pair<int,int>>());
vector<long long>prefix_sum(n+1);
prefix_sum[0]=0;
for(int i=1;i<=n;i++){
prefix_sum[i]=prefix_sum[i-1]+people[i];
}
int L=1,R=n;
while(L<R){
int pivot=find_min(L,R);
if(prefix_sum[pivot-1]-prefix_sum[L-1]>prefix_sum[R]-prefix_sum[pivot]){
R=pivot-1;
}else L=pivot+1;
}
cout<<people[R]<<"\n";
return 0;
}

## Python language
def main():
from sys import stdin
def input(): return stdin.readline()
n=int(input())
data=[int(x) for x in input().split()]
index={x:y for y,x in enumerate(data,1)}
pre=[0]*(n+1)
for i in range(1,n+1): pre[i]=pre[i-1]+data[i-1]
data.sort()
L=1; R=n
for val in data:
idx=index[val]
if L<=idx<=R:
Lsum = pre[idx-1]-pre[L-1]
Rsum = pre[R]-pre[idx]
if Lsum<Rsum: L=idx+1
else: R=idx-1
if L==R:
print(pre[L]-pre[L-1])
break
main()
