import numpy as np
import os
import torch
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.sampler import SequentialSampler
from tqdm import tqdm
from torchvision import transforms

mean_values = np.array([5.21073351e-03, 5.96052594e-03, 6.52974704e-03, 8.32975097e-03,
                        7.55425915e-03, 6.12404291e-03, 5.52672660e-03, 3.10601527e-03,
                        3.59511259e-03, 1.47226769e-02, 8.54850560e-03, 8.51667207e-03,
                        1.23925116e-02, 1.42803723e-02, 1.56251397e-02, 4.10747202e-03,
                        5.75026544e-03, 3.72449664e-04, 5.43951988e-04, 1.07747130e-03,
                        1.58554071e-03, 2.26363586e-03, 2.36696517e-03, 2.36265850e-03,
                        1.52725552e-03, 8.91395612e-04, 5.94564306e-04, 3.48962902e-04,
                        7.18741678e-03, 1.43822515e-02, 3.60196736e-03, 2.66160554e-04,
                        5.45355251e-05, 1.52354201e-04, 2.94964673e-04, 8.75117490e-04,
                        2.02896562e-03, 2.20067520e-03, 2.04490405e-03, 8.85738467e-04,
                        3.55792959e-04, 1.34371730e-04, 5.25345640e-05, 3.29168164e-04,
                        5.38312644e-03, 5.58903860e-03, 6.88277069e-04, 1.63398407e-04,
                        1.56542024e-04, 7.85473094e-04, 6.65057963e-03, 6.84232488e-02,
                        1.85290962e-01, 7.43060410e-02, 6.87413430e-03, 7.20020791e-04,
                        1.61498887e-04, 1.65625766e-04, 7.55524961e-04, 7.25265918e-03,
                        8.41938052e-03, 9.42632614e-04, 3.83383944e-04, 7.78001733e-04,
                        3.81070077e-02, 5.19615293e-01, 5.50220191e-01, 5.38794518e-01,
                        5.34448326e-01, 5.43235421e-01, 4.08528857e-02, 7.97818648e-04,
                        4.15629969e-04, 1.06752687e-03, 6.63061999e-03, 7.94411451e-03,
                        2.01964984e-03, 1.29357260e-03, 1.17227593e-02, 5.54740846e-01,
                        6.38364494e-01, 6.20892465e-01, 6.22644603e-01, 6.24383211e-01,
                        6.07711077e-01, 5.28471947e-01, 1.29190683e-02, 1.27536803e-03,
                        2.48343428e-03, 8.05027969e-03, 1.07689695e-02, 3.70697421e-03,
                        3.40944924e-03, 2.57935375e-01, 6.08753502e-01, 6.18356168e-01,
                        6.01383507e-01, 6.38412476e-01, 6.33368015e-01, 5.99831939e-01,
                        6.09605312e-01, 2.63052583e-01, 3.10447975e-03, 3.51518416e-03,
                        1.09766228e-02, 1.12014152e-02, 5.27000520e-03, 4.80492925e-03,
                        4.88818109e-01, 6.12844348e-01, 6.83741450e-01, 6.57797813e-01,
                        6.71097219e-01, 6.20189548e-01, 5.99870145e-01, 6.46305442e-01,
                        4.74788874e-01, 5.40654734e-03, 4.76881908e-03, 1.44362412e-02,
                        1.31467385e-02, 9.69267217e-04, 1.47422845e-03, 2.22093716e-01,
                        3.45182747e-01, 3.51607233e-01, 3.85757565e-01, 3.65479916e-01,
                        3.56725693e-01, 3.87904257e-01, 3.12016398e-01, 2.10684806e-01,
                        1.46087864e-03, 1.33699412e-03, 3.69213661e-03, 3.74368811e-03,
                        1.46066595e-03, 1.28145388e-03, 2.67840642e-02, 4.66476470e-01,
                        4.65729028e-01, 4.42802548e-01, 4.96096462e-01, 5.07753849e-01,
                        4.94994015e-01, 4.47156101e-01, 2.86377966e-02, 1.15813117e-03,
                        1.61565270e-03, 7.82474503e-03, 6.95808930e-03, 8.88462702e-04,
                        3.94368777e-04, 1.50357396e-03, 2.97121108e-01, 4.28565830e-01,
                        4.72367257e-01, 4.91698295e-01, 4.64758307e-01, 4.56683666e-01,
                        2.93488562e-01, 1.36640086e-03, 4.21684934e-04, 9.23170010e-04,
                        4.85366676e-03, 5.75961359e-03, 6.32391835e-04, 1.52664841e-04,
                        2.04042153e-04, 1.86329673e-03, 9.33151469e-02, 4.54616725e-01,
                        4.68144625e-01, 4.37754482e-01, 9.08695906e-02, 2.01791874e-03,
                        2.14816740e-04, 1.83709461e-04, 6.81631325e-04, 6.37149159e-03,
                        5.97607112e-03, 2.33276718e-04, 4.52165041e-05, 1.37156196e-04,
                        4.86189412e-04, 1.54880504e-03, 4.25410829e-03, 5.77845145e-03,
                        4.47372254e-03, 1.70160609e-03, 4.54137218e-04, 1.39682335e-04,
                        4.37503550e-05, 1.98516616e-04, 2.58366042e-03, 1.90112100e-03,
                        4.06355178e-03, 1.77261623e-04, 6.58498029e-04, 1.12581032e-03,
                        2.83944141e-03, 4.71789809e-03, 4.93934331e-03, 4.64486564e-03,
                        2.47649080e-03, 1.31045596e-03, 6.68830879e-04, 3.28144466e-04,
                        4.30168957e-03, 1.32345064e-02, 1.27680078e-02, 1.12667335e-02,
                        3.88616696e-03, 3.44652007e-03, 6.67260401e-03, 1.01783043e-02,
                        9.80525184e-03, 1.18775144e-02, 1.09655987e-02, 9.94069688e-03,
                        6.91671390e-03, 6.63809339e-03, 5.12432400e-03, 1.46340337e-02,
                        1.28765143e-02, 8.54816567e-03, 6.83800410e-03, 8.33844952e-03,
                        8.99346918e-03, 1.11768600e-02, 6.82592299e-03, 3.04342946e-03,
                        4.46045492e-03, 2.02413369e-03, 1.38024753e-02, 8.35280865e-03,
                        1.27791120e-02, 1.28278844e-02, 1.67672783e-02, 2.27736961e-02,
                        7.32761621e-03, 7.22986786e-03, 4.50134190e-04, 1.20252057e-03,
                        1.41426129e-03, 2.72427150e-03, 2.19693268e-03, 5.04542282e-03,
                        2.66610435e-03, 2.07267376e-03, 1.80255866e-03, 1.04959623e-03,
                        4.54697729e-04, 8.23426433e-03, 2.08205190e-02, 2.88085127e-03,
                        8.10872880e-04, 1.13475260e-04, 2.64488423e-04, 4.48195962e-04,
                        1.27643941e-03, 2.59804889e-03, 4.62584011e-03, 3.72147281e-03,
                        1.16310455e-03, 6.51090289e-04, 3.31958727e-04, 8.46849434e-05,
                        9.92278219e-04, 7.47290533e-03, 7.67276762e-03, 1.70946564e-03,
                        2.51111516e-04, 2.34064079e-04, 1.15025463e-03, 1.11541776e-02,
                        1.61570728e-01, 2.98416883e-01, 1.51965499e-01, 1.29125165e-02,
                        1.04488863e-03, 2.13135223e-04, 2.48505239e-04, 1.64008699e-03,
                        9.26411897e-03, 1.09298229e-02, 1.50834699e-03, 5.97650360e-04,
                        1.46997499e-03, 5.25715537e-02, 5.80468774e-01, 6.73690975e-01,
                        4.32842940e-01, 5.50267875e-01, 7.11299777e-01, 7.66917318e-02,
                        1.34765520e-03, 9.47053311e-04, 2.39004730e-03, 1.19933467e-02,
                        1.02696884e-02, 2.95389467e-03, 2.14847689e-03, 1.86664425e-02,
                        6.15558207e-01, 6.90045893e-01, 7.66083300e-01, 8.01129639e-01,
                        6.58120573e-01, 6.52828753e-01, 6.32849276e-01, 1.89496931e-02,
                        2.36307457e-03, 3.16214608e-03, 1.32529885e-02, 1.18729388e-02,
                        5.10423677e-03, 3.87612544e-03, 3.59947711e-01, 6.71840966e-01,
                        6.71275616e-01, 7.29919791e-01, 8.24683487e-01, 7.82505691e-01,
                        6.65502846e-01, 6.92349315e-01, 3.37825298e-01, 5.30025642e-03,
                        7.63042457e-03, 1.89402848e-02, 1.79172885e-02, 6.63117133e-03,
                        7.26527628e-03, 5.96868157e-01, 8.38199794e-01, 8.76678824e-01,
                        7.63823032e-01, 8.34678233e-01, 6.67513490e-01, 8.10814500e-01,
                        6.61260128e-01, 5.75417221e-01, 8.03123973e-03, 1.04896920e-02,
                        2.55016200e-02, 1.52412709e-02, 9.25937260e-04, 8.59897467e-04,
                        1.99403673e-01, 1.94671571e-01, 2.61967719e-01, 2.02071846e-01,
                        2.90098280e-01, 3.09797615e-01, 1.30816489e-01, 1.80421576e-01,
                        1.19511932e-01, 1.52520742e-03, 1.19531667e-03, 3.55644873e-03,
                        2.22127489e-03, 1.28676603e-03, 1.16098567e-03, 2.27860045e-02,
                        3.27795684e-01, 5.25207102e-01, 3.85442644e-01, 4.23855305e-01,
                        4.57308412e-01, 5.83417594e-01, 4.28043753e-01, 3.28853428e-02,
                        1.10662868e-03, 1.77618791e-03, 9.14906617e-03, 1.07000358e-02,
                        1.48068683e-03, 5.30234887e-04, 1.16718072e-03, 1.81294382e-01,
                        4.01161909e-01, 4.80768859e-01, 6.27825856e-01, 4.71505821e-01,
                        4.86086041e-01, 2.05618963e-01, 1.12740602e-03, 3.42363870e-04,
                        2.08853232e-03, 6.53719716e-03, 7.09779281e-03, 1.16651994e-03,
                        2.64847942e-04, 3.35728429e-04, 2.76589533e-03, 6.09722175e-02,
                        4.85109657e-01, 5.19991875e-01, 5.07046163e-01, 7.42680505e-02,
                        2.93106469e-03, 3.53883835e-04, 3.76977521e-04, 1.88940670e-03,
                        1.15323700e-02, 8.37552734e-03, 3.44594067e-04, 6.68186258e-05,
                        1.10201778e-04, 5.48808603e-04, 8.11738893e-04, 5.56171034e-03,
                        3.56817478e-03, 2.61327461e-03, 2.07904237e-03, 7.26235856e-04,
                        1.25079561e-04, 9.01089152e-05, 6.09919371e-04, 4.00945917e-03,
                        3.27911880e-03, 5.66318585e-03, 7.36747985e-04, 6.88301807e-04,
                        1.46568392e-03, 2.98794243e-03, 4.88448050e-03, 6.53933687e-03,
                        6.40165806e-03, 3.90086649e-03, 1.62590784e-03, 1.49702397e-03,
                        5.31952479e-04, 7.29273586e-03, 1.34410346e-02, 1.53368441e-02,
                        1.70969907e-02, 8.28063674e-03, 1.06186820e-02, 1.24047799e-02,
                        1.04878610e-02, 1.37758106e-02, 1.59156248e-02, 1.25700869e-02,
                        1.01127671e-02, 1.29586672e-02, 1.17351692e-02, 6.88386569e-03,
                        2.20175218e-02, 1.23268589e-02, 1.59345870e-03, 1.93019630e-03,
                        2.51414231e-03, 4.05482342e-03, 2.43645767e-03, 2.93860678e-03,
                        1.67947158e-03, 1.68556126e-03, 4.58753348e-04, 5.64524811e-03,
                        5.03191212e-03, 7.53232744e-03, 1.87556464e-02, 1.33160017e-02,
                        1.88028663e-02, 8.50014912e-04, 4.33128653e-03, 3.00359039e-04,
                        5.61996829e-04, 8.35156708e-04, 1.55571743e-03, 1.78498530e-03,
                        2.02922965e-03, 2.82443408e-03, 1.14573387e-03, 6.16178382e-04,
                        5.71472978e-04, 2.73021520e-04, 3.16460291e-03, 1.18379099e-02,
                        6.27483067e-04, 2.92281155e-04, 5.34110077e-05, 1.21749581e-04,
                        2.84319802e-04, 6.75221672e-04, 1.20604876e-03, 1.89756940e-03,
                        2.02244846e-03, 8.19853332e-04, 3.09789990e-04, 1.00870682e-04,
                        3.06079455e-05, 3.33528063e-04, 4.11921320e-03, 3.22430604e-03,
                        1.12686295e-03, 2.02045208e-04, 1.48915380e-04, 1.33449573e-03,
                        1.42971361e-02, 2.09922373e-01, 3.27757150e-01, 1.86940402e-01,
                        9.37532634e-03, 8.61637469e-04, 1.88830760e-04, 2.00314404e-04,
                        9.08196787e-04, 8.29314813e-03, 8.67725164e-03, 1.59604277e-03,
                        2.96744343e-04, 5.30609454e-04, 6.58579841e-02, 4.48129207e-01,
                        4.04519051e-01, 3.63419116e-01, 3.67473215e-01, 3.37553054e-01,
                        7.61844069e-02, 7.07922271e-04, 2.44694907e-04, 8.49747506e-04,
                        8.44838098e-03, 6.48280932e-03, 2.64404225e-03, 1.31498359e-03,
                        1.36348903e-02, 5.06942034e-01, 5.44347465e-01, 5.18801272e-01,
                        5.94374955e-01, 5.86825550e-01, 5.59473753e-01, 5.06055772e-01,
                        1.97970010e-02, 1.78285094e-03, 3.09713208e-03, 1.29154902e-02,
                        1.01072500e-02, 5.22186561e-03, 2.81176460e-03, 3.08652669e-01,
                        5.78351438e-01, 5.99249780e-01, 6.82491004e-01, 6.40014768e-01,
                        6.12095296e-01, 5.10813773e-01, 6.00676239e-01, 2.61414587e-01,
                        3.15945526e-03, 4.83301841e-03, 1.44198742e-02, 1.43367685e-02,
                        6.76493673e-03, 6.59597199e-03, 4.72649693e-01, 6.31339908e-01,
                        6.69047058e-01, 6.89712942e-01, 7.22025156e-01, 7.41040111e-01,
                        6.97285891e-01, 7.38171101e-01, 5.55930376e-01, 6.57503074e-03,
                        7.97260832e-03, 2.58687772e-02, 1.43119860e-02, 3.17366612e-05,
                        3.18469793e-05, 1.54468352e-02, 3.49357165e-02, 3.30448449e-02,
                        3.68764289e-02, 2.84136105e-02, 3.27512883e-02, 4.23680879e-02,
                        2.35325973e-02, 1.46437958e-02, 3.38052232e-05, 2.85949700e-05,
                        6.64191393e-05, 5.59874316e-05, 4.59934265e-04, 2.25458338e-04,
                        7.10259937e-03, 2.13034362e-01, 2.70803630e-01, 2.06331864e-01,
                        1.99850202e-01, 2.14906335e-01, 2.52297133e-01, 2.45496601e-01,
                        8.95974506e-03, 4.59220697e-04, 8.69376643e-04, 4.20223316e-03,
                        2.38550431e-03, 3.32562777e-04, 1.33323367e-04, 4.06890205e-04,
                        7.94254243e-02, 2.40473256e-01, 2.29808524e-01, 2.84039289e-01,
                        2.07059562e-01, 2.45421469e-01, 1.15181439e-01, 4.75331501e-04,
                        1.47548169e-04, 3.55359050e-04, 3.13949911e-03, 1.12738740e-03,
                        6.32496958e-04, 1.99309055e-04, 2.28555917e-04, 1.08263642e-03,
                        3.40383835e-02, 3.85946274e-01, 4.21295404e-01, 3.82916152e-01,
                        4.38181423e-02, 1.36619166e-03, 2.26158561e-04, 1.82773700e-04,
                        8.44233087e-04, 6.81386748e-03, 3.50091467e-03, 9.03273540e-05,
                        1.93806300e-05, 4.58813302e-05, 2.42419643e-04, 4.83188487e-04,
                        1.61352020e-03, 2.68079713e-03, 1.36998028e-03, 6.56081247e-04,
                        9.74724753e-05, 8.80579828e-05, 3.01282653e-05, 1.62141543e-04,
                        7.64094933e-04, 1.96419773e-03, 2.39577633e-03, 4.01879166e-04,
                        1.03631872e-03, 1.28338777e-03, 1.76755083e-03, 3.58600239e-03,
                        3.67859355e-03, 3.12745967e-03, 1.92313862e-03, 1.20908953e-03,
                        8.31874262e-04, 3.48516129e-04, 6.50074240e-03, 1.48388250e-02,
                        1.10633913e-02, 1.31731434e-02, 3.62910773e-03, 7.88423698e-03,
                        7.34688668e-03, 1.01467837e-02, 1.11299865e-02, 1.24588087e-02,
                        1.43662449e-02, 1.08202770e-02, 9.52879619e-03, 7.94390496e-03,
                        6.74496125e-03, 1.41655942e-02, 1.26175145e-02], dtype=np.float32)

std_values = np.array([0.01100672, 0.01272034, 0.0141785, 0.0156693, 0.01513652,
                       0.01319228, 0.01126058, 0.00818644, 0.00773956, 0.01160436,
                       0.00924025, 0.00907213, 0.01137917, 0.01216084, 0.01323059,
                       0.0087936, 0.00901011, 0.00362655, 0.004775, 0.00716196,
                       0.00943687, 0.01184711, 0.01203932, 0.0121731, 0.00943377,
                       0.00675638, 0.00500216, 0.00359341, 0.00913122, 0.01239697,
                       0.00755421, 0.00304553, 0.00145503, 0.00273271, 0.00410944,
                       0.00722235, 0.01157732, 0.01213369, 0.01170763, 0.00725041,
                       0.00441923, 0.00257548, 0.0014372, 0.00342638, 0.00839598,
                       0.00831958, 0.00531285, 0.00292249, 0.00293996, 0.00673439,
                       0.02016459, 0.06607992, 0.16065882, 0.07185291, 0.02002606,
                       0.00621371, 0.00278016, 0.00287805, 0.00538756, 0.01117312,
                       0.01144882, 0.00693521, 0.00471784, 0.00665318, 0.0495008,
                       0.16057129, 0.1366758, 0.13166575, 0.13283126, 0.16915733,
                       0.05143046, 0.00646907, 0.00462567, 0.00704163, 0.01181607,
                       0.01280712, 0.01056505, 0.00893032, 0.02613454, 0.16387396,
                       0.15177079, 0.15585545, 0.15709981, 0.15934902, 0.14818858,
                       0.16218421, 0.02723811, 0.00853977, 0.01123476, 0.01416031,
                       0.01653492, 0.01397828, 0.01460873, 0.22238702, 0.14834903,
                       0.15272394, 0.16390665, 0.17786762, 0.17304581, 0.1538684,
                       0.14915386, 0.23675698, 0.01431978, 0.01319908, 0.01759504,
                       0.01748068, 0.01626338, 0.01823601, 0.23607433, 0.14783563,
                       0.16855304, 0.17956127, 0.17753215, 0.17625119, 0.15870522,
                       0.15753704, 0.24007423, 0.01964404, 0.01453823, 0.02039932,
                       0.02019605, 0.00730308, 0.00908481, 0.13461885, 0.08501502,
                       0.09284301, 0.11207125, 0.11437231, 0.10608353, 0.1027153,
                       0.07785753, 0.13750836, 0.00889074, 0.00822287, 0.01087694,
                       0.01123324, 0.00883781, 0.00847987, 0.04166801, 0.11883701,
                       0.11306212, 0.11439908, 0.13671103, 0.13094759, 0.12146649,
                       0.11712251, 0.04321892, 0.00799101, 0.00873083, 0.01357353,
                       0.0134072, 0.0066868, 0.00441477, 0.00895153, 0.17270455,
                       0.1061577, 0.11295654, 0.11862757, 0.11079995, 0.11383872,
                       0.18609522, 0.00830179, 0.00450323, 0.00643008, 0.0107184,
                       0.01194367, 0.00505658, 0.00260158, 0.00305959, 0.01015972,
                       0.12212912, 0.16417994, 0.14336362, 0.16046605, 0.11539042,
                       0.01051001, 0.0031314, 0.00288068, 0.00529733, 0.0109459,
                       0.00990098, 0.00307746, 0.00133712, 0.00244568, 0.00501612,
                       0.00969026, 0.01721338, 0.0198645, 0.01786308, 0.01029711,
                       0.00483738, 0.00254873, 0.00129623, 0.00282337, 0.00727823,
                       0.00660993, 0.00868327, 0.00278485, 0.0056017, 0.0077403,
                       0.01189518, 0.01565462, 0.01616747, 0.01555215, 0.01155002,
                       0.00823737, 0.00553332, 0.00357853, 0.00869695, 0.01231851,
                       0.01231341, 0.01161248, 0.00864775, 0.0095705, 0.01353118,
                       0.0175193, 0.01870177, 0.02053992, 0.01943354, 0.01785081,
                       0.01372892, 0.01128078, 0.00871301, 0.01290778, 0.01004221,
                       0.01620477, 0.01614349, 0.01860127, 0.01952227, 0.02163078,
                       0.01656327, 0.01106036, 0.01157754, 0.00774336, 0.01392861,
                       0.01134515, 0.01388273, 0.01407501, 0.01669568, 0.02054376,
                       0.01352774, 0.01208776, 0.00426614, 0.00762958, 0.00903565,
                       0.01375056, 0.01295341, 0.01972357, 0.01435544, 0.0121974,
                       0.01050433, 0.00737417, 0.00454238, 0.01247105, 0.01927394,
                       0.00843132, 0.00557184, 0.00229666, 0.00402, 0.00566569,
                       0.01017778, 0.01520018, 0.02124906, 0.01872783, 0.00963047,
                       0.00687628, 0.00458187, 0.00202604, 0.00628206, 0.01241151,
                       0.01190158, 0.00895464, 0.00394743, 0.00396118, 0.00933409,
                       0.03008733, 0.17031208, 0.21781605, 0.15934049, 0.03202182,
                       0.00870786, 0.00361496, 0.00390661, 0.00872403, 0.01528356,
                       0.01600035, 0.00937883, 0.00658848, 0.0106707, 0.07009508,
                       0.16933252, 0.16084836, 0.10305399, 0.1338289, 0.19997129,
                       0.09792922, 0.01007129, 0.00826455, 0.01157966, 0.01922891,
                       0.01755729, 0.01391795, 0.01331661, 0.03945404, 0.17381018,
                       0.15835251, 0.172239, 0.1757778, 0.15821329, 0.15159924,
                       0.18058139, 0.03919337, 0.01406448, 0.01435093, 0.02196202,
                       0.02046996, 0.01881248, 0.01812077, 0.27865988, 0.15921278,
                       0.15771288, 0.17790446, 0.18736193, 0.1854178, 0.15853935,
                       0.16246974, 0.26947203, 0.02240364, 0.0223528, 0.0274928,
                       0.02696475, 0.02145351, 0.0263722, 0.2622944, 0.17690226,
                       0.17106454, 0.18655586, 0.1834839, 0.17976308, 0.17824483,
                       0.15465823, 0.27617958, 0.0281765, 0.02541274, 0.03217886,
                       0.02545806, 0.00798179, 0.00732402, 0.12615772, 0.04705365,
                       0.06490838, 0.05734298, 0.09018057, 0.08863317, 0.03337502,
                       0.04491676, 0.08440526, 0.01032608, 0.00899803, 0.01229195,
                       0.01052804, 0.00937482, 0.00919476, 0.03951871, 0.08490941,
                       0.12328635, 0.09259018, 0.10955741, 0.10990736, 0.13659173,
                       0.11411992, 0.05040215, 0.00890909, 0.01054776, 0.01732665,
                       0.02030392, 0.00957112, 0.00590135, 0.00870832, 0.13120908,
                       0.10228226, 0.11183137, 0.14523469, 0.11087333, 0.12483308,
                       0.15368763, 0.0083931, 0.0045577, 0.01071767, 0.01454975,
                       0.01593151, 0.00764722, 0.00391865, 0.00459683, 0.01508291,
                       0.07737335, 0.2050291, 0.17832981, 0.2145463, 0.09193982,
                       0.01552067, 0.00472566, 0.0048149, 0.00954335, 0.01757633,
                       0.01507111, 0.00392437, 0.00175339, 0.00235714, 0.00607827,
                       0.00740777, 0.02323925, 0.01691486, 0.01445788, 0.0131759,
                       0.00711928, 0.0025959, 0.00215942, 0.00524189, 0.0104421,
                       0.00989037, 0.01245589, 0.0058956, 0.00621867, 0.00984124,
                       0.01452957, 0.01939184, 0.02255636, 0.02196463, 0.01684133,
                       0.01020457, 0.00890036, 0.00496369, 0.01351925, 0.01554072,
                       0.01731609, 0.01904545, 0.01476294, 0.01863399, 0.02213822,
                       0.02115382, 0.02602896, 0.02819741, 0.02466447, 0.02132256,
                       0.02230133, 0.01839607, 0.01295659, 0.02120053, 0.01321658,
                       0.00691323, 0.00813108, 0.00962452, 0.01192506, 0.00965906,
                       0.01000771, 0.00732658, 0.00671347, 0.00372588, 0.00780296,
                       0.00763505, 0.00869346, 0.01461541, 0.0123567, 0.0157277,
                       0.00502624, 0.0082584, 0.00320149, 0.00465961, 0.00619018,
                       0.00916294, 0.01031037, 0.01114416, 0.01300317, 0.00802438,
                       0.00552652, 0.004785, 0.00308415, 0.00746125, 0.01160666,
                       0.00414271, 0.0030778, 0.00144788, 0.00243753, 0.00402974,
                       0.00645435, 0.00878584, 0.01140144, 0.01199831, 0.00717941,
                       0.00413396, 0.00215354, 0.00105886, 0.00327036, 0.00832783,
                       0.007318, 0.00640655, 0.00329453, 0.00283843, 0.00936778,
                       0.03275277, 0.17466259, 0.18210661, 0.16391434, 0.02328939,
                       0.00697368, 0.00302575, 0.00314074, 0.00569507, 0.01218818,
                       0.0120525, 0.0086128, 0.00416531, 0.00526223, 0.08771882,
                       0.12356329, 0.10143592, 0.09137215, 0.09269898, 0.0925658,
                       0.09741083, 0.00610342, 0.00346749, 0.00628214, 0.01375641,
                       0.01222175, 0.01183096, 0.00945413, 0.02997008, 0.13834533,
                       0.13407713, 0.14025484, 0.1600995, 0.15827356, 0.1416306,
                       0.13695908, 0.03839171, 0.01092416, 0.01247739, 0.01856911,
                       0.01642532, 0.01702644, 0.01414914, 0.23575127, 0.14205873,
                       0.15585305, 0.1866514, 0.19072215, 0.18144861, 0.13955651,
                       0.14653444, 0.2070298, 0.01525314, 0.01613806, 0.02053961,
                       0.02076213, 0.01959993, 0.02356215, 0.2263291, 0.15301941,
                       0.1739403, 0.19688965, 0.19262746, 0.20011221, 0.18029848,
                       0.17144282, 0.26870966, 0.02328263, 0.01998539, 0.02841141,
                       0.02184964, 0.00119995, 0.00090064, 0.01437086, 0.01006876,
                       0.01073301, 0.01357481, 0.01243331, 0.01252809, 0.01302717,
                       0.00751022, 0.01421031, 0.00091621, 0.00111301, 0.00211225,
                       0.00201998, 0.00509077, 0.00324239, 0.01689583, 0.05452484,
                       0.06655625, 0.05485385, 0.05869477, 0.05617215, 0.06200515,
                       0.06451524, 0.01894244, 0.00492615, 0.00665049, 0.01068482,
                       0.00934835, 0.00406722, 0.00250619, 0.00430068, 0.07110547,
                       0.06075451, 0.0551732, 0.07004147, 0.05074057, 0.06195797,
                       0.1009532, 0.00460881, 0.00257612, 0.00400995, 0.00926566,
                       0.00634614, 0.00505242, 0.00314721, 0.00349651, 0.00793455,
                       0.04782098, 0.19926804, 0.1576645, 0.2011195, 0.05699866,
                       0.00898759, 0.00337693, 0.00295474, 0.00575766, 0.01173661,
                       0.00936173, 0.00183763, 0.00078052, 0.00134192, 0.00359417,
                       0.00518398, 0.01026634, 0.01362705, 0.0092591, 0.00620606,
                       0.00206264, 0.00200078, 0.00104786, 0.0024861, 0.00468479,
                       0.00702678, 0.0078221, 0.00395609, 0.00682308, 0.00827341,
                       0.01032707, 0.01545988, 0.01585194, 0.01442751, 0.01090202,
                       0.00801559, 0.00605984, 0.0036064, 0.01119194, 0.01458186,
                       0.0126898, 0.01431467, 0.0091317, 0.01408327, 0.01514108,
                       0.01857145, 0.02108827, 0.02237517, 0.02380445, 0.01961635,
                       0.01678348, 0.01326806, 0.01123615, 0.0143675, 0.01208347],
                      dtype=np.float32)


class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing input images with label maps.

    Each image is read from ``data_x``, shifted by ``-0.5`` and, when an
    adjustment file was given, divided by the loaded adjustment array.
    Labels are always returned as ``float32``.

    The same preprocessing is applied whether or not caching is enabled;
    ``cache`` only controls whether the preprocessed image is kept in memory
    for later epochs.

    Args:
        data_x: indexable tensor of input images (first axis is the sample).
        data_y: indexable tensor of labels aligned with ``data_x``.
        random_flip: if true, image and label are flipped together
            horizontally and/or vertically, each with probability 0.5.
        cache: if true, preprocessed images are memoised in ``data_cache``.
        adjust: optional path to a ``.npy`` file loaded with ``np.load``;
            every image is divided by its contents.
    """

    def __init__(self, data_x, data_y, random_flip=False, cache=True, adjust=None):
        self.data_x = data_x
        self.data_y = data_y
        self.random_flip = random_flip
        self.cache = cache
        self.data_cache = {}
        # Kept for backward compatibility: labels are converted to float32
        # regardless of this flag.
        self.convert = data_y.dtype is torch.float16
        self.adjust_leds = np.load(adjust) if adjust is not None else None

    def __len__(self):
        """Return the number of samples (length of ``data_x``)."""
        return len(self.data_x)

    def _prepare_image(self, index):
        """Read sample ``index`` from ``data_x`` and apply the preprocessing."""
        # clone() detaches the sample from the backing storage (which may be
        # a memory-mapped file) before it is modified or cached.
        image = self.data_x[index].clone() - 0.5
        if self.adjust_leds is not None:
            # Convert explicitly so the division does not rely on
            # Tensor/ndarray operator interop, and match the image dtype so
            # the cached tensor is not silently promoted.
            image = image / torch.as_tensor(self.adjust_leds, dtype=image.dtype)
        return image

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index`` as float32 tensors.

        Tensors are moved to the GPU when CUDA is available and are left on
        the CPU otherwise.
        """
        if self.cache:
            image = self.data_cache.get(index)
            if image is None:
                image = self._prepare_image(index)
                self.data_cache[index] = image
        else:
            image = self._prepare_image(index)

        label = self.data_y[index].float()
        if torch.cuda.is_available():
            image, label = image.cuda(), label.cuda()

        if self.random_flip:
            # Image and label must receive the same flips to stay aligned.
            if np.random.rand() < 0.5:
                image = transforms.functional.hflip(image)
                label = transforms.functional.hflip(label)
            if np.random.rand() < 0.5:
                image = transforms.functional.vflip(image)
                label = transforms.functional.vflip(label)

        return image.float(), label


def load_progress(path, desc='', block_size=2):
    """Load a ``.npy`` file fully into memory while showing a progress bar.

    The file is opened memory-mapped and copied into a freshly allocated
    in-memory array ``block_size`` rows (along the first axis) at a time, so
    that ``tqdm`` can report progress.

    Args:
        path: path of the ``.npy`` file to read.
        desc: description passed to the ``tqdm`` progress bar.
        block_size: number of first-axis rows copied per step.

    Returns:
        An in-memory array with the same shape and dtype as the file contents.

    Raises:
        ValueError: if ``block_size`` is smaller than 1.
        OSError: propagated unchanged from ``np.load`` (e.g. missing file).
    """
    if block_size < 1:
        raise ValueError(f'block_size must be >= 1, got {block_size}')

    # Open the file before entering the try block: if np.load itself fails
    # there is nothing to clean up, and the original error must reach the
    # caller instead of being masked by the cleanup code.
    mmap_array = np.load(path, mmap_mode='r')
    try:
        # Every element is overwritten below, so no zero-fill is needed.
        array = np.empty(mmap_array.shape, dtype=mmap_array.dtype)
        n_rows = mmap_array.shape[0]
        for start in tqdm(range(0, n_rows, block_size), desc=desc):
            stop = start + block_size
            array[start:stop] = mmap_array[start:stop]
    finally:
        # Drop the reference to the memory map as soon as the copy is done.
        del mmap_array
    return array


def shift_data(shift_code, data):
    """Trim rows or columns of the 3 x 15 x 15 channel grid of ``data``.

    The second axis of ``data`` (675 = 3 * 15 * 15 entries) is viewed as a
    ``(3, 15, 15)`` grid and ``amount`` slices are removed from one side:

    * ``S<amount>``: drop the first ``amount`` entries of the first 15-axis
    * ``N<amount>``: drop the last ``amount`` entries of the first 15-axis
    * ``W<amount>``: drop the first ``amount`` entries of the second 15-axis
    * ``E<amount>``: drop the last ``amount`` entries of the second 15-axis

    Args:
        shift_code: ``''`` for no shift, otherwise a direction letter
            (case-insensitive) followed by an integer amount, e.g. ``'N3'``.
        data: array reshapeable to ``(N, 3, 15, 15, H, W)``. ``H`` and ``W``
            are read from the input when it has shape ``(N, 675, H, W)``;
            any other layout is treated as 256 x 256 frames.

    Returns:
        ``data`` itself when nothing is trimmed (empty code or amount 0),
        otherwise an array of shape ``(N, 675 - amount * 45, H, W)``.

    Raises:
        ValueError: if the direction is not one of N/S/E/W, the amount is not
            an integer, or the amount is outside ``0 <= amount < 15``.
    """
    if shift_code == '':
        return data

    channels, rows, cols = 3, 15, 15

    direction = shift_code[0].upper()
    # Explicit check instead of ``assert`` so that it is not stripped under
    # ``python -O`` (which would silently treat a bad letter as 'W').
    if direction not in ('N', 'S', 'E', 'W'):
        raise ValueError(
            f"shift direction must be one of N, S, E, W; got {shift_code[0]!r}")
    amnt = int(shift_code[1:])
    if not 0 <= amnt < rows:
        raise ValueError(
            f'shift amount must satisfy 0 <= amount < {rows}; got {amnt}')
    if amnt == 0:
        # A ``[:-0]`` slice would select nothing, so handle this explicitly.
        return data

    if data.ndim == 4 and data.shape[1] == channels * rows * cols:
        height, width = data.shape[2], data.shape[3]
    else:
        # Legacy behaviour for inputs given in another (flattened) layout.
        height, width = 256, 256

    n_samples = data.shape[0]
    grid = np.reshape(data, [n_samples, channels, rows, cols, height, width])

    # Cut the requested side off the relevant grid axis.
    if direction == 'S':
        grid = grid[:, :, amnt:]
    elif direction == 'N':
        grid = grid[:, :, :rows - amnt]
    elif direction == 'E':
        grid = grid[:, :, :, :cols - amnt]
    else:  # 'W'
        grid = grid[:, :, :, amnt:]

    # Flatten the trimmed grid back into a single channel axis.
    kept = grid.shape[1] * grid.shape[2] * grid.shape[3]
    return np.reshape(grid, [n_samples, kept, height, width])


def get_train_val_loader(config, pin_memory, num_workers=0):
    """Create the training and validation ``DataLoader`` objects.

    Input arrays are opened memory-mapped (read-only) and wrapped as tensors;
    label arrays are loaded fully and converted to ``float32``.

    Args:
        config: object providing ``batch_size``, ``flip`` and ``task``. When
            ``str(config.task).lower()`` equals ``'hela'`` the HeLa files are
            used, otherwise the ``ctc`` pan files.
        pin_memory: accepted for interface compatibility but not used; both
            loaders are always created with ``pin_memory=False``.
        num_workers: number of worker processes handed to both loaders.

    Returns:
        Tuple ``(train_loader, valid_loader, num_leds, padded_data)`` where
        ``num_leds`` is the size of axis 1 of the training inputs and
        ``padded_data`` is always ``True``.
    """
    batch_size = config.batch_size
    random_flip = config.flip

    # Resolve the four file locations for the selected task.
    if str(config.task).lower() == 'hela':
        root = '/scratch/data/colin/hela_vf'
        train_x_path = os.path.join(root, 'train_x_norm.npy')
        train_y_path = os.path.join(root, 'new_nuc_train_kb8.npy')
        val_x_path = os.path.join(root, 'val_x_norm.npy')
        val_y_path = os.path.join(root, 'new_nuc_val_kb8.npy')
    else:
        root = os.path.join('/hddraid5/data/colin/', 'ctc')
        train_x_path = os.path.join(root, 'pan_train_x.npy')
        train_y_path = os.path.join(root, 'pan_train_6_y.npy')
        val_x_path = os.path.join(root, 'pan_val_x.npy')
        val_y_path = os.path.join(root, 'pan_val_6_y.npy')
    # Identical for both tasks.
    padded_data = True
    adjust = None

    def _labels_as_tensor(path):
        """Load a label file into memory as a float32 tensor."""
        return torch.from_numpy(np.load(path).astype(np.float32))

    # Inputs stay memory-mapped; the dataset reads samples on demand.
    print("Loading train_x data")
    train_x_npy = np.load(train_x_path, mmap_mode='r')
    train_x = torch.from_numpy(train_x_npy)
    print("Loading train_y data")
    train_y = _labels_as_tensor(train_y_path)
    print("Loading val_x data")
    val_x = torch.from_numpy(np.load(val_x_path, mmap_mode='r'))
    print("Loading val_y data")
    val_y = _labels_as_tensor(val_y_path)

    num_leds = train_x_npy.shape[1]

    # Only the training set receives random flips.
    train_dataset = CustomDataset(train_x, train_y, random_flip, adjust=adjust)
    val_dataset = CustomDataset(val_x, val_y, adjust=adjust)

    # Training samples are drawn in random order, validation in file order.
    train_sampler = SubsetRandomSampler(list(range(train_x.shape[0])))
    valid_sampler = SequentialSampler(list(range(val_x.shape[0])))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=False,
    )
    valid_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        sampler=valid_sampler,
        num_workers=num_workers,
        pin_memory=False,
    )

    return train_loader, valid_loader, num_leds, padded_data
