import numpy as np
import argparse
# Compute CMC rank-k re-identification scores for each dataset archive
# ``<name>.npz``. Each archive is read for the keys 'query' (names whose
# identity is the integer prefix before the first '_'), 'gallery' (gallery
# identities) and 'distmat' (query-by-gallery distances, smaller = closer).

MAX_RANK = 50  # the CMC curve is truncated to this many leading ranks

data_list = ['market1501', 'mars']


def compute_cmc(query_ids, gallery_ids, distmat, max_rank=MAX_RANK):
    """Return the mean CMC curve over queries whose identity is in the gallery.

    Args:
        query_ids: 1-D array-like of query identities, one per row of ``distmat``.
        gallery_ids: 1-D array-like of gallery identities, one per column of
            ``distmat``.
        distmat: 2-D array-like of distances (num_query x num_gallery);
            smaller values rank earlier.
        max_rank: number of leading ranks to keep.

    Returns:
        A float32 array ``cmc`` of length ``min(max_rank, num_gallery)``.
        ``cmc[k - 1]`` is the fraction of valid queries with a correct match
        within the top-k gallery entries. Returns ``None`` when no query
        identity appears in the gallery (there is nothing to average over).
    """
    query_ids = np.asarray(query_ids)
    gallery_ids = np.asarray(gallery_ids)
    indices = np.argsort(distmat, axis=1)
    # Compare the raw identities. The previous code cast the query IDs to
    # float32 *before* comparing, which merges distinct IDs above 2**24.
    matches = gallery_ids[indices] == query_ids[:, np.newaxis]
    # Queries whose identity never occurs in the gallery are excluded.
    valid = np.isin(query_ids, gallery_ids)
    num_valid_q = int(valid.sum())
    if num_valid_q == 0:
        return None
    # Once the first correct match is seen, every later rank counts as a hit.
    cmc = np.minimum(matches[valid].cumsum(axis=1), 1)
    cmc = cmc[:, :max_rank].astype(np.float32)
    return cmc.sum(axis=0) / num_valid_q


def main():
    """Parse ``--rank`` and print the rank-k score for every dataset archive."""
    print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
    print('This program automatically returns rank-1 score, if you want other rank scores, please use python calc_rank.py --rank k .')
    print('-------------------------------------------------------------------------------------------------------------------------')

    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--rank', default=1, type=int,
                        help='CMC rank k to report (1-{0})'.format(MAX_RANK))
    args = parser.parse_args()

    # Validate once, up front. Ranks < 1 used to index from the end of the
    # curve and silently report the wrong score.
    if not 1 <= args.rank <= MAX_RANK:
        print('Please input a rank value between 1 and {0}. Thanks'.format(MAX_RANK))
    else:
        for data in data_list:
            # The context manager closes the underlying archive file.
            with np.load(data + '.npz') as raw_data:
                query_ids = np.array([int(s.split('_')[0]) for s in raw_data['query']])
                gallery_ids = raw_data['gallery']
                distmat = raw_data['distmat']

            cmc = compute_cmc(query_ids, gallery_ids, distmat)
            if cmc is None:
                print('Dataset: {0:15} no query identity appears in the gallery'.format(data))
                continue
            if args.rank > len(cmc):
                # A gallery smaller than the requested rank has no rank-k entry.
                print('Dataset: {0:15} gallery has only {1} entries; Rank-{2} is undefined'.format(
                    data, len(cmc), args.rank))
                continue
            print('Dataset: {0:15} Rank-{1} score: {2:.1f}'.format(data, args.rank, cmc[args.rank - 1] * 100))
    print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')


if __name__ == '__main__':
    main()
