# filename: distributed_example.py
# import some module
import argparse
from log import set_logger,get_logger
import torch
import random
import numpy as np
from re import S
from transformers import  SpeechEncoderDecoderConfig, SpeechEncoderDecoderModel
from transformers import (
    SpeechEncoderDecoderModel,
    Wav2Vec2FeatureExtractor,
    Speech2Text2Tokenizer,
    Wav2Vec2Processor,
    
)
from torch.utils.tensorboard import SummaryWriter   
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from transformers import Trainer, TrainingArguments,Seq2SeqTrainingArguments
import soundfile as sf
from datasets import load_metric
from torch.utils.data import DataLoader, Dataset
import datasets
from transformers import Seq2SeqTrainingArguments
import numpy as np
import logging
import sys
import transformers
# Logger for this script; its level is set later from the TrainingArguments'
# per-process log level, once logging.basicConfig has been configured.
logger = logging.getLogger(__name__)
from transformers import SpeechEncoderDecoderModel, Speech2Text2Processor
import random
from torch.utils.data.distributed import DistributedSampler
import argparse
import os  # needed here for the LOCAL_RANK fallback of --local_rank

# Command-line arguments: only the local rank (index of the GPU this process
# drives on its node) is read.
# torch.distributed.launch passes it as --local_rank (older PyTorch) or
# --local-rank (PyTorch >= 2.0); torchrun passes no flag and only sets the
# LOCAL_RANK environment variable, so that is used as the default.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--local_rank",
    "--local-rank",
    type=int,
    default=int(os.environ.get("LOCAL_RANK", 0)),
    help="Local rank of this process (set by the distributed launcher).",
)
args = parser.parse_args()
from datasets import load_from_disk
import datetime
# Processor bundling the feature extractor and tokenizer used by the data collator.
processor = Speech2Text2Processor.from_pretrained("facebook/s2t-wav2vec2-large-en-de")

def evaluate(trainer):
    """Evaluate ``trainer`` on its eval dataset, then log and save the metrics.

    Args:
        trainer: a transformers ``Trainer`` (or subclass) configured with an
            evaluation dataset.
    """
    logger.info("*** Evaluate ***")
    eval_metrics = trainer.evaluate()

    # Report the metrics under the "eval" prefix, first to the log, then to disk.
    for emit in (trainer.log_metrics, trainer.save_metrics):
        emit("eval", eval_metrics)

def set_random_seeds(random_seed=0):
    """Seed every random number generator used in training and make cuDNN deterministic.

    Args:
        random_seed: seed handed to Python's ``random``, NumPy and PyTorch
            (``torch.manual_seed`` and ``torch.cuda.manual_seed``). Defaults to 0.
    """
    # Turn off cuDNN auto-tuning and ask for deterministic kernels.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seed_fn in seeders:
        seed_fn(random_seed)

# distributed_step 1
# set random seed (the same seed, 0, on every process)
set_random_seeds(random_seed=0)
# distributed_step 2
# set target device: bind this process to GPU `args.local_rank`, so later
# CUDA calls without an explicit index (e.g. model.cuda()) use that GPU.
torch.cuda.set_device(args.local_rank)
# Initialize the default process group once, with the NCCL backend for GPU
# communication. Rank / world-size discovery is left to the launcher's
# environment (presumably MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE — confirm
# against the launch command).
torch.distributed.init_process_group(
backend="nccl")


# Preliminary TrainingArguments, created only to query the per-process log
# level below; `training_args` is rebound later in this script to the
# Seq2SeqTrainingArguments that are actually passed to the Trainer.
training_args = TrainingArguments(output_dir='random_trainer') # output folder (author's note: created automatically if it does not exist)

# Setup logging: timestamped records written to stdout.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)

# Give this script's logger and the datasets/transformers libraries the same
# log level. TrainingArguments picks the level per process (presumably quieter
# on non-main ranks — confirm for the installed transformers version).
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)


# Load the already-preprocessed dataset from disk. It is indexed by split name
# ("train" / "validation") further down. It was presumably produced by:
##tokenized_datasets=data.map(prepare_dataset,remove_columns=data.column_names["train"])
tokenized_datasets=load_from_disk("/workspace/yutengfei6/users/yutengfei6/docker-remote/train_model/dataset_test")


# distributed_step 3
# initialize process group: already done once (NCCL backend) right after
# torch.cuda.set_device(...) near the top of this script; it must not be
# initialized a second time here.

# distributed_step 4
# Distributed sampler over the *training split*. The loaded object is keyed by
# split name, so sampling it as a whole would iterate over the split names, not
# the examples. A validation sampler can be built the same way if needed.
# NOTE(review): this sampler is not passed to the Trainer constructed below —
# confirm whether it is still needed.
sampler = DistributedSampler(tokenized_datasets["train"])

# step 5
# initialize model

# Model config; only needed by the random-initialization alternative that is
# kept commented out below.
configuration=SpeechEncoderDecoderConfig.from_pretrained("facebook/s2t-wav2vec2-large-en-de")
tokenizer = Speech2Text2Tokenizer.from_pretrained("facebook/s2t-wav2vec2-large-en-de")
# Alternative: start from a randomly initialized speech2text model.
#model = SpeechEncoderDecoderModel(config=configuration)
# Continue from a previously saved checkpoint.
model = SpeechEncoderDecoderModel.from_pretrained("/workspace/yutengfei6/users/yutengfei6/docker-remote/train_model/results2/checkpoint-52500")

# Align the model config with the tokenizer's special tokens: the pad id, and
# the BOS id as the decoder start token.
model.config.pad_token_id = tokenizer.pad_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
# NOTE(review): not referenced elsewhere in this script; the data collator pads
# audio with processor.feature_extractor instead.
feature_extractor = Wav2Vec2FeatureExtractor()
# Move the model to the current CUDA device (selected earlier via
# torch.cuda.set_device(args.local_rank)).
model.cuda()

# step 6
# wrap model with DistributedDataParallel
# Each process wraps its replica on its own GPU: process n <-> GPU n, where
# n = args.local_rank. find_unused_parameters=True lets DDP cope with
# parameters that receive no gradient in a step, at some per-iteration cost.
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank,find_unused_parameters=True)
# Abandoned attempt to wrap the processor as well (kept for reference only):
#processor= torch.nn.parallel.DistributedDataParallel(processor, device_ids=[args.local_rank], output_device=args.local_rank)
# The wrapped model is reachable as model.module, e.g. for generation:
#model.module.generate()

# normal-----------------------
# SacreBLEU metric, consumed by compute_metrics.
# NOTE(review): datasets.load_metric is deprecated and dropped in recent
# `datasets` releases (replacement: evaluate.load("sacrebleu")) — confirm the
# pinned `datasets` version.
metric = load_metric("sacrebleu")


def postprocess_text(preds, labels):
    """Trim surrounding whitespace from decoded predictions and references.

    Args:
        preds: iterable of decoded prediction strings.
        labels: iterable of decoded reference strings.

    Returns:
        ``(preds, labels)`` where ``preds`` is a list of stripped strings and
        ``labels`` is a list of one-element lists (a single reference per
        prediction), the form passed to the sacrebleu metric.
    """
    cleaned = [text.strip() for text in preds]
    references = [[text.strip()] for text in labels]
    return cleaned, references




def compute_metrics(eval_preds):
    """Compute corpus BLEU (SacreBLEU) and the mean prediction length.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair from the Trainer.
            ``predictions`` may be a tuple, in which case its first element is
            used. It may hold token ids or, when the Trainer returns raw model
            outputs, per-token logits (assumed (batch, seq, vocab) — confirm).
            Ignored positions are marked with -100.

    Returns:
        dict with ``"bleu"`` and ``"gen_len"``, each rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    # A floating 3-D array holds logits, not token ids: keep the most likely
    # token at each position so the tokenizer can decode it.
    if preds.ndim == 3 and np.issubdtype(preds.dtype, np.floating):
        preds = preds.argmax(axis=-1)

    # Replace -100 (ignored/padding positions) with the pad id in both
    # predictions and labels, since the tokenizer can't decode -100.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    # Mean number of non-pad tokens per prediction.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result


@dataclass
class DataCollatorWithPadding:
    """Collate speech examples into a padded batch for the encoder-decoder model.

    Each feature is a dict with ``"inputs"`` (its first element is padded as the
    audio ``input_values`` — presumably the waveform samples, confirm against
    the saved dataset) and ``"labels"`` (target token ids).

    Attributes:
        processor: processor whose ``feature_extractor`` pads the audio inputs
            and whose ``tokenizer`` pads the label ids.
        padding: padding strategy forwarded to both pad calls (``True`` pads to
            the longest sequence in the batch).
        max_length: maximum length for the padded audio inputs.
        max_length_labels: maximum length for the padded label ids.
        pad_to_multiple_of: pad audio inputs to a multiple of this value.
        pad_to_multiple_of_labels: pad label ids to a multiple of this value.
    """

    processor: Speech2Text2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``features`` and return a batch with ``inputs``, ``labels`` and,
        when the feature extractor provides one, ``attention_mask``."""
        # Inputs and labels have different lengths and need different padding
        # methods, so split them and pad each with its own component.
        input_features = [{"input_values": feature["inputs"][0]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # The tokenizer is called directly, so no as_target_processor() context
        # is needed.
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            max_length=self.max_length_labels,
            pad_to_multiple_of=self.pad_to_multiple_of_labels,
            return_tensors="pt",
        )

        # Replace label padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # The model is called with `inputs=`, so rename the padded audio key.
        # attention_mask (present only if the feature extractor returned one)
        # is left in the batch unchanged.
        batch["inputs"] = batch.pop("input_values")
        batch["labels"] = labels

        return batch



# Arguments for the actual training run. This rebinds `training_args`, which
# was first created only for log-level setup.
training_args = Seq2SeqTrainingArguments(
    output_dir="./results3",
    do_train=True,
    do_eval=True,
    # Optimization: 3 examples per device x 16 accumulation steps = 48 examples
    # per optimizer step on each process, with fp16 mixed precision.
    learning_rate=5e-5,
    label_smoothing_factor=0.1,
    per_device_train_batch_size=3,
    gradient_accumulation_steps=16,
    num_train_epochs=1,
    fp16=True,
    # Evaluation every 100 steps, one example per device.
    evaluation_strategy="steps",
    eval_steps=100,
    per_device_eval_batch_size=1,
    # Checkpointing: keep at most 10 checkpoints, reload the best at the end.
    save_total_limit=10,
    load_best_model_at_end=True,
    # NOTE(review): trainer.train() is called without resume_from_checkpoint,
    # and transformers documents this field as not consumed by Trainer itself —
    # confirm whether resuming was intended.
    resume_from_checkpoint=True,
    # Data loading.
    dataloader_pin_memory=False,
    dataloader_num_workers=8,
    remove_unused_columns=True,
)

class CustomTrainer(Trainer):
    """Trainer that calls the model with the keyword names produced by the collator.

    The data collator emits ``"inputs"``, ``"labels"`` and (when available)
    ``"attention_mask"``; they are forwarded as ``inputs=``, ``labels=`` and
    ``attention_mask=``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run a forward pass and return the loss.

        Args:
            model: the (possibly DDP-wrapped) model being trained.
            inputs: collated batch dict; missing keys are passed as ``None``.
            return_outputs: when True, also return the raw model outputs.
            **kwargs: extra arguments that newer transformers versions pass
                (e.g. ``num_items_in_batch``); accepted for compatibility and
                not used.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        labels = inputs.get("labels")
        outputs = model(
            inputs=inputs.get("inputs"),
            labels=labels,
            attention_mask=inputs.get("attention_mask"),
        )
        # The base Trainer applies label smoothing (`self.label_smoother`, built
        # from `label_smoothing_factor`) inside its own compute_loss; overriding
        # it skips that, so apply the smoother here when one is configured.
        label_smoother = getattr(self, "label_smoother", None)
        if label_smoother is not None and labels is not None:
            loss = label_smoother(outputs, labels)
        else:
            loss = outputs.loss
        return (loss, outputs) if return_outputs else loss


# Collator that pads audio inputs and label ids batch by batch using `processor`.
data_collator = DataCollatorWithPadding(processor=processor, padding=True)

# Train on the "train" split, evaluating on "validation" per `training_args`.
# NOTE(review): the `tokenizer=` argument is deprecated in favour of
# `processing_class=` in recent transformers releases — confirm against the
# pinned version.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=processor,
    compute_metrics=compute_metrics,
)
trainer.train()


















