from imps import *


# %% myunet


def fct(data):
    """Build the multi-scale attention U-Net ("myunet") for ``data``'s shape.

    Only ``data.shape`` is read: ``data.shape[1:4]`` becomes the model's input
    shape (presumably ``(height, width, channels)`` -- TODO confirm against
    the callers).  The model is neither compiled nor trained here.

    Architecture, as assembled by ``create_model`` below:

    * Encoder stages 1-5: two 3x3 convs, dropout, 2x2 max-pooling, then an
      ``att`` block.  Stage 1 sees only the *last* input channel; stages 2-4
      additionally concatenate a conv of an average-pooled input pyramid.
    * Decoder stages 6-9: 2x upsampling + 2x2 conv, concatenation with the
      matching encoder output, two 3x3 convs, dropout, then an ``att`` block.
    * Deep supervision: stages 7, 8 and 9 are upsampled once more and mapped
      to 4-channel sigmoid outputs.

    Because of the five 2x2 poolings, height and width presumably need to be
    divisible by 32 for the skip concatenations to line up.  Shape
    information for most stages is printed while the graph is built.

    Args:
        data: Array-like with a ``shape`` of rank >= 4; only its shape is used.

    Returns:
        An uncompiled ``keras.Model`` with the three deep-supervision outputs
        (from stages 7, 8, 9, in that order).
    """
    from tensorflow.keras.layers import (
        Conv2D,
        DepthwiseConv2D,
        Dropout,
        Layer,
        MultiHeadAttention,
        UpSampling2D,
    )
    from tensorflow.keras.models import Sequential

    # Per-stage hyper-parameters: indices 0-4 are encoder, 5-8 decoder stages.
    att_heads = [2, 4, 8, 12, 16, 12, 8, 4, 2]
    filters = [16, 32, 64, 128, 384, 128, 64, 32, 16]
    blocks = len(filters)
    stochastic_depth_rate = 0.0
    image_size = data.shape[1]
    input_shape = (data.shape[1], data.shape[2], data.shape[3])
    # Dropout rate used inside every ``mlp``.  ``mlp`` previously hard-coded
    # 0.1 and ignored its ``dropout_rate`` argument; 0.1 is kept so the
    # effective regularisation stays the same.
    mlp_dropout_rate = 0.1

    class StochasticDepth(layers.Layer):
        """Stochastic depth: drop an entire residual branch per sample.

        In training mode each sample is kept with probability
        ``1 - drop_prob`` and kept samples are scaled by ``1 / keep_prob``;
        otherwise the input is returned unchanged.

        Args:
            drop_prop: Drop probability (stored as ``self.drop_prob``).
        """

        def __init__(self, drop_prop, **kwargs):
            super().__init__(**kwargs)
            self.drop_prob = drop_prop

        def call(self, x, training=None):
            # ``training=None`` lets Keras supply the mode from the enclosing
            # call (e.g. fit vs. predict) instead of defaulting to training.
            if not training:
                return x
            keep_prob = 1 - self.drop_prob
            # One draw per sample, broadcast over all remaining axes;
            # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob.
            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            mask = tf.floor(keep_prob + tf.random.uniform(shape, 0, 1))
            return (x / keep_prob) * mask

    def mlp(x, hidden_units, dropout_rate):
        """Convolutional feed-forward stack applied after each attention step.

        For every entry of ``hidden_units`` one dual-path unit is applied: a
        plain 3x3 conv and a dilated (rate 2) 3x3 conv run in parallel on the
        current features, are summed, and fused by a third 3x3 conv.  Every
        conv uses GELU and is followed by ``Dropout(dropout_rate)``.  Units are
        chained: each one consumes the previous unit's output.

        Args:
            x: Input Keras tensor (4-D feature map).
            hidden_units: Channel count of each unit, in order.
            dropout_rate: Rate for every dropout layer in the stack.

        Returns:
            The output of the last unit, or ``x`` itself if ``hidden_units``
            is empty.
        """
        for units in hidden_units:
            local = layers.Conv2D(units, 3, padding='same', activation=tf.nn.gelu)(x)
            local = layers.Dropout(dropout_rate)(local)
            dilated = layers.Conv2D(units, 3, padding='same', dilation_rate=2, activation=tf.nn.gelu)(x)
            dilated = layers.Dropout(dropout_rate)(dilated)
            fused = layers.Add()([local, dilated])
            x = layers.Conv2D(units, 3, padding='same', activation=tf.nn.gelu)(fused)
            x = layers.Dropout(dropout_rate)(x)
        return x

    class Attention(Layer):
        """Multi-head attention over convolutional Q/K/V projections.

        Q, K and V each come from their own depthwise-conv + LayerNorm
        projection of the same input (K/V may be strided via ``stride_kv``)
        and are combined by ``MultiHeadAttention`` with ``key_dim=dim_out``.

        Args:
            dim_out: Key dimension per head for ``MultiHeadAttention``.
            num_heads: Number of attention heads.
            proj_drop: Dropout rate applied to the attention output.
            kernel_size: Kernel size of the depthwise projections.
            stride_kv, stride_q: Strides of the K/V and Q projections.
            padding_kv, padding_q: Paddings of the K/V and Q projections.
            attention_bias: ``use_bias`` of ``MultiHeadAttention``.
            **kwargs: Forwarded to ``Layer`` (e.g. ``name``).
        """

        def __init__(self,
                     dim_out,
                     num_heads,
                     proj_drop=0.0,
                     kernel_size=3,
                     stride_kv=1,
                     stride_q=1,
                     padding_kv="same",
                     padding_q="same",
                     attention_bias=True,
                     **kwargs):
            super().__init__(**kwargs)
            self.stride_kv = stride_kv
            self.stride_q = stride_q
            self.dim = dim_out
            self.num_heads = num_heads
            # Not used by this layer; kept as a public attribute.
            self.scale = dim_out ** -0.5

            self.conv_proj_q = self._build_projection(kernel_size, stride_q, padding_q)
            self.conv_proj_k = self._build_projection(kernel_size, stride_kv, padding_kv)
            self.conv_proj_v = self._build_projection(kernel_size, stride_kv, padding_kv)

            self.attention = MultiHeadAttention(self.num_heads, dim_out, use_bias=attention_bias)
            self.proj_drop = Dropout(proj_drop)

        @staticmethod
        def _build_projection(kernel_size, stride, padding):
            """Return a bias-free depthwise conv followed by LayerNorm."""
            return Sequential([
                DepthwiseConv2D(kernel_size, padding=padding, strides=stride, use_bias=False),
                layers.LayerNormalization(),
            ])

        def call_conv(self, x, h, w):
            """Return the ``(q, k, v)`` projections of ``x``; ``h``/``w`` are unused."""
            q = self.conv_proj_q(x)
            k = self.conv_proj_k(x)
            v = self.conv_proj_v(x)
            return q, k, v

        def call(self, inputs, mask=None, training=None, h=1, w=1):
            """Attend from the Q projection of ``inputs`` to its K/V projections.

            ``mask``, ``h`` and ``w`` are accepted for interface compatibility
            but not used.  ``training`` is forwarded to the output dropout so
            it is only active in training mode.
            """
            q, k, v = self.call_conv(inputs, h, w)
            x = self.attention(q, v, key=k)
            return self.proj_drop(x, training=training)

    def att(x_in,
            num_heads,
            dpr,
            proj_drop=0.0,
            attention_bias=True,
            padding_q="same",
            padding_kv="same",
            stride_kv=2,
            stride_q=1):
        """Transformer-style block on a rank-4 feature map, preserving channels.

        ``Attention`` -> stochastic depth -> 3x3 ReLU conv, added back to
        ``x_in``; then LayerNorm -> ``mlp`` -> stochastic depth, added back to
        that sum.

        Args:
            x_in: Rank-4 Keras tensor (its shape is unpacked into four values).
            num_heads: Number of attention heads.
            dpr: Stochastic-depth drop probability for both residual branches.
            proj_drop, attention_bias, padding_q, padding_kv, stride_kv,
            stride_q: Forwarded to ``Attention``.

        Returns:
            Keras tensor with the same channel count as ``x_in``.
        """
        _, h, w, c = x_in.shape

        # ``training`` is deliberately not passed: an explicit value given
        # while building a functional model is recorded in the graph and
        # would keep the attention dropout in training mode at inference.
        attention_output = Attention(dim_out=c,
                                     num_heads=num_heads,
                                     proj_drop=proj_drop,
                                     attention_bias=attention_bias,
                                     padding_q=padding_q,
                                     padding_kv=padding_kv,
                                     stride_kv=stride_kv,
                                     stride_q=stride_q,
                                     )(x_in, h=h, w=w, mask=None)
        attention_output = StochasticDepth(dpr)(attention_output)
        attention_output = Conv2D(c, 3, 1, padding="same", activation="relu")(attention_output)
        x2 = layers.Add()([attention_output, x_in])

        x3 = layers.LayerNormalization(epsilon=1e-5)(x2)
        x3 = mlp(x3, hidden_units=[c, c], dropout_rate=mlp_dropout_rate)
        x3 = StochasticDepth(dpr)(x3)
        return layers.Add()([x3, x2])

    def create_model(
        image_size=image_size,
        input_shape=input_shape,
    ):
        """Assemble and return the U-Net ``keras.Model``.

        Args:
            image_size: Accepted for compatibility; not used.
            input_shape: Shape passed to ``layers.Input``.
        """
        inputs = layers.Input(input_shape)

        # Stochastic-depth probability per stage, linear from 0 to the max.
        dpr = list(np.linspace(0, stochastic_depth_rate, blocks))
        initializer = 'he_normal'

        def conv_relu(n_filters, kernel_size=3):
            """Return a new same-padded, stride-1 ReLU conv with He init."""
            return Conv2D(n_filters, kernel_size, 1, padding="same",
                          activation="relu", kernel_initializer=initializer)

        def down(x, idx, drop_rate):
            """Encoder body of stage ``idx``: two convs, dropout, pool, attention."""
            x = conv_relu(filters[idx])(x)
            x = conv_relu(filters[idx])(x)
            x = Dropout(drop_rate)(x)
            x = MaxPooling2D((2, 2))(x)
            return att(x, att_heads[idx], dpr[idx])

        def encoder_stage(x, idx, drop_rate, scale_img=None):
            """LayerNorm ``x``, optionally fuse a pyramid level, then ``down``.

            Returns ``(normalised input, stage output)``.
            """
            normed = layers.LayerNormalization(epsilon=1e-5)(x)
            y = normed
            if scale_img is not None:
                # The pyramid branch uses Keras' default kernel initializer.
                side = Conv2D(filters[idx - 1], 3, padding="same", activation="relu")(scale_img)
                y = concatenate([side, y], axis=3)
            return normed, down(y, idx, drop_rate)

        def decoder_stage(x, skip, idx):
            """LayerNorm, 2x upsample + 2x2 conv, concat ``skip``, convs, attention.

            Returns ``(normalised input, stage output)``.
            """
            normed = layers.LayerNormalization(epsilon=1e-5)(x)
            y = conv_relu(filters[idx], 2)(UpSampling2D(size=(2, 2))(normed))
            y = concatenate([skip, y], axis=3)
            y = conv_relu(filters[idx])(y)
            y = conv_relu(filters[idx])(y)
            y = Dropout(.3)(y)
            return normed, att(y, att_heads[idx], dpr[idx])

        # Average-pooled input pyramid fed into encoder stages 2-4.
        scale_img_2 = layers.AveragePooling2D(2, 2)(inputs)
        scale_img_3 = layers.AveragePooling2D(2, 2)(scale_img_2)
        scale_img_4 = layers.AveragePooling2D(2, 2)(scale_img_3)

        # Stage 1: only the last input channel enters the encoder.  LayerNorm
        # runs on the rank-3 slice (along its last axis) *before* the channel
        # axis is re-added; normalising after expand_dims would not be
        # equivalent (a size-1 axis normalises to zero).
        x1 = layers.LayerNormalization(epsilon=1e-5)(inputs[:, :, :, -1])
        skip1 = down(tf.expand_dims(x1, -1), 0, .3)
        print("\nBlock 1 -> input:", x1.shape, "output:", skip1.shape)

        # Encoder stages 2-5.
        x1, skip2 = encoder_stage(skip1, 1, .4, scale_img_2)
        print("Block 2 -> input:", x1.shape, "output:", skip2.shape)
        x1, skip3 = encoder_stage(skip2, 2, .3, scale_img_3)
        print("Block 3 -> input:", x1.shape, "output:", skip3.shape)
        x1, skip4 = encoder_stage(skip3, 3, .3, scale_img_4)
        print("Block 4 -> input:", x1.shape, "output:", skip4.shape)
        _, bottleneck = encoder_stage(skip4, 4, .3)

        # Decoder stages 6-9, each fused with the mirrored encoder output.
        _, out6 = decoder_stage(bottleneck, skip4, 5)
        x1, skip7 = decoder_stage(out6, skip3, 6)
        print("Block 7 -> input:", x1.shape, "output:", skip7.shape)
        x1, skip8 = decoder_stage(skip7, skip2, 7)
        print("Block 8 -> input:", x1.shape, "output:", skip8.shape)
        x1, skip9 = decoder_stage(skip8, skip1, 8)
        print("Block 9 -> input:", x1.shape, "output:", skip9.shape)

        # Deep supervision.  All 3x3 heads are created before the 1x1 sigmoid
        # classifiers: creation order determines Keras' auto-generated layer
        # names, so keep it stable.
        ds_inputs, ds_features = [], []
        for feat, idx in ((skip7, 6), (skip8, 7), (skip9, 8)):
            up = layers.LayerNormalization(epsilon=1e-5)(UpSampling2D(size=(2, 2))(feat))
            y = conv_relu(filters[idx])(up)
            y = conv_relu(filters[idx])(y)
            ds_inputs.append(up)
            ds_features.append(y)
        outputs = [Conv2D(4, (1, 1), activation="sigmoid")(y) for y in ds_features]

        print("\n")
        for i, (ds_in, ds_out) in enumerate(zip(ds_inputs, outputs), start=1):
            print(f"DS {i} -> input:", ds_in.shape, "output:", ds_out.shape)

        return keras.Model(inputs=inputs, outputs=outputs)

    return create_model()

