import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import zscore
class transform_normalize_log10():
    """Element-wise base-10 logarithm with a small additive offset.

    The ``1e-6`` offset keeps exact zeros finite (``log10(1e-6) == -6``)
    instead of producing ``-inf``.
    """

    def __init__(self):
        """This transform takes no configuration."""

    def __call__(self, x):
        """Return ``log10(x + 1e-6)`` computed with ``torch.log10``."""
        return torch.log10(x + 1e-6)

class transform_normalize_zscore():
    """Standardise a tensor to zero mean / unit standard deviation along one axis."""

    def __init__(self, axis=-1):
        """
        Args:
            axis: dimension along which the mean and standard deviation are
                computed (default ``-1``, the last dimension).
        """
        self.axis = axis

    def __call__(self, x):
        """Return ``(x - mean) / std`` computed along ``self.axis``.

        ``std`` uses torch's default estimator. Slices whose std is zero or
        undefined produce NaN; those entries are replaced with 0 by
        ``torch.nan_to_num``.

        Args:
            x: torch tensor.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Previously dim=-1 was hard-coded, silently ignoring the `axis` argument.
        mu = x.mean(dim=self.axis, keepdim=True)
        std = x.std(dim=self.axis, keepdim=True)
        x = (x - mu) / std
        # Constant slices give 0/0 -> NaN; map NaN to 0 (and +/-inf to finite).
        return torch.nan_to_num(x)

class transform_clip_Faxis():
    """Keep only rows ``min:max`` along the first axis of a (>= 2-D) array.

    Presumably the first axis is frequency ("Faxis") -- TODO confirm with callers.
    """

    def __init__(self, min, max):
        # The parameter names shadow the builtins but are kept so that
        # keyword callers (min=..., max=...) keep working.
        self.min, self.max = min, max

    def __call__(self, x):
        """Return ``x[self.min:self.max, :]``."""
        rows = slice(self.min, self.max)
        return x[rows, :]


class transform_CS():
    """Log-weighted correlation between each pair of adjacent columns.

    For every column index ``j`` in ``0 .. ncols-2`` this computes the Pearson
    correlation, taken down axis 0, between column ``j`` and column ``j + 1``.
    It then multiplies that value by ``log10(sum(column j) + 1e-6)``.
    """

    def __init__(self):
        """This transform takes no configuration."""

    def __call__(self, x):
        """Compute the weighted adjacent-column correlation of ``x``.

        Args:
            x: 2-D torch tensor. It is detached and copied to CPU/NumPy before
                use. Presumably a (freq, time) spectrogram -- TODO confirm.

        Returns:
            CPU torch tensor of shape ``(1, ncols - 1)``. Column pairs where
            either column has zero variance get a correlation of 0 instead of NaN.
        """
        x = x.data.cpu().numpy()
        A = x[:, :-1]  # columns 0 .. n-2
        B = x[:, 1:]   # columns 1 .. n-1, so A[:, j] and B[:, j] are neighbours
        Amm = A - A.mean(axis=0, keepdims=True)
        Bmm = B - B.mean(axis=0, keepdims=True)

        top = np.sum(Amm * Bmm, axis=0, keepdims=True)
        Abot = np.sqrt(np.sum(Amm ** 2, axis=0, keepdims=True))
        Bbot = np.sqrt(np.sum(Bmm ** 2, axis=0, keepdims=True))
        bottom = Abot * Bbot
        # A constant column makes numerator and denominator both 0 (0/0 -> NaN
        # plus a RuntimeWarning); define the correlation as 0 there instead.
        y = np.divide(top, bottom, out=np.zeros_like(top), where=bottom != 0)
        # NOTE(review): a negative column sum makes log10 NaN; assumes x >= 0.
        CS = y * np.log10(np.sum(A, axis=0, keepdims=True) + 1e-6)
        return torch.from_numpy(CS)

