require 'torch'
require 'nn'


-- Epsilon used to turn stored batch-norm statistics back into variances.
-- 1e-5 is the nn.BatchNormalization default.
-- NOTE(review): confirm it matches the eps the pretrained statistics were produced with.
local BN_EPS = 1e-5

-- Append the convolutional feature extractor described by option.conv_architecture
-- to `net` and return the number of features it yields once flattened.
-- Each stage is SpatialConvolutionMM(kW=1, kH=conv_kernel) -> SpatialBatchNormalization
-- -> ReLU. Only the first two stages are followed by SpatialMaxPooling(1, conv_pool).
local function addConvFeatureExtractor(net, option)
  local arch = option.conv_architecture
  local planes = option.nInputFeature
  local length = option.inputSize -- extent along the axis the kernels slide over
  for i = 1, #arch do
    local nOut = arch[i]
    net:add(nn.SpatialConvolutionMM(planes, nOut, 1, option.conv_kernel))
    net:add(nn.SpatialBatchNormalization(nOut))
    net:add(nn.ReLU())
    -- stride 1, no padding: the length shrinks by conv_kernel - 1
    length = length - option.conv_kernel + 1
    if i <= 2 then
      net:add(nn.SpatialMaxPooling(1, option.conv_pool))
      -- pooling stride defaults to the pooling size, so the length is divided (floored)
      length = math.floor(length / option.conv_pool)
    end
    planes = nOut
  end
  assert(length > 0, string.format(
    'conv stack reduces inputSize %s to %s; use fewer or smaller conv/pool layers',
    tostring(option.inputSize), tostring(length)))
  -- output is planes x length x 1
  return planes * length
end

-- Append the fully connected head to `net`:
--   * one Linear -> BatchNormalization -> ReLU [-> Dropout] stage per entry of
--     option.mlp_architecture;
--   * then a final Linear -> BatchNormalization -> LogSoftMax over option.nTarget
--     outputs.
-- Dropout is added only when 0 < option.dropout_rate < 1; a missing rate means
-- no dropout.
local function addMlpHead(net, nIn, option)
  local dropout = option.dropout_rate or 0
  local useDropout = dropout > 0 and dropout < 1
  for _, nHidden in ipairs(option.mlp_architecture) do
    net:add(nn.Linear(nIn, nHidden))
    net:add(nn.BatchNormalization(nHidden))
    net:add(nn.ReLU())
    if useDropout then
      net:add(nn.Dropout(dropout))
    end
    nIn = nHidden
  end
  net:add(nn.Linear(nIn, option.nTarget))
  net:add(nn.BatchNormalization(option.nTarget))
  net:add(nn.LogSoftMax())
end

-- Return the direct children of `net` that carry running_mean and running_var
-- (the batch-normalization layers), in the order they were added.
local function collectBatchNorms(net)
  local found = {}
  for _, m in ipairs(net.modules) do
    if m.running_mean ~= nil and m.running_var ~= nil then
      found[#found + 1] = m
    end
  end
  return found
end

-- Copy pretrained tensors into `net`. Layout of `weights`:
--   weights[1]:
--     all parameters, flattened in net:getParameters() order.
--   weights[k .. k+3], k = 2, 6, 10, ...:
--     weight, bias, running mean, and a value the code treats as
--     1/sqrt(var + eps), for successive batch-norm layers.
-- The j-th set goes into the j-th batch-norm layer. Batch-norm layers beyond
-- the supplied sets keep their current statistics. The tensors in `weights`
-- are only read, never modified.
local function loadPretrained(net, weights)
  local flatParams = net:getParameters()
  flatParams:copy(weights[1]:double())

  local batchNorms = collectBatchNorms(net)
  local k, layer = 2, 1
  while weights[k] ~= nil do
    local bn = batchNorms[layer]
    assert(bn ~= nil,
      'pretrained weights hold more batch-norm sets than the model has batch-norm layers')
    bn.weight:copy(weights[k]:double())
    bn.bias:copy(weights[k + 1]:double())
    bn.running_mean:copy(weights[k + 2]:double())
    -- (1/sqrt(var + eps))^-2 - eps = var. torch.pow allocates a new tensor, so
    -- the caller's tensor is not overwritten (unlike an in-place :pow()).
    bn.running_var:copy(torch.pow(weights[k + 3]:double(), -2):add(-BN_EPS))
    k = k + 4
    layer = layer + 1
  end
end

--- Build the classifier network described by `option`.
-- Layout: [conv feature extractor] -> nn.Reshape -> MLP head
-- -> Linear -> BatchNormalization -> LogSoftMax.
-- If the global `rmodel` is set (defined outside this function — presumably by
-- the script that loads pretrained weights; verify), its tensors are copied
-- into the network before an optional move to CUDA.
-- @tparam table option configuration. Fields read:
--   net_init_seed, feature_ex_type ('conv' | 'mlp' | 'gauss'), inputSize,
--   mlp_architecture, dropout_rate, nTarget, cuda;
--   and for 'conv' also nInputFeature, conv_architecture, conv_kernel, conv_pool.
-- @return the nn.Sequential model (also assigned to the global `model`)
function buildModel(option)
  print '==> Construct Neural Nets model'

  torch.manualSeed(option.net_init_seed)

  -- Kept global for backward compatibility: other scripts may read `model` directly.
  model = nn.Sequential()

  local nFeatureOut
  if option.feature_ex_type == 'conv' then
    nFeatureOut = addConvFeatureExtractor(model, option)
  elseif option.feature_ex_type == 'mlp' or option.feature_ex_type == 'gauss' then
    nFeatureOut = option.inputSize
  else
    error('unknown option.feature_ex_type: ' .. tostring(option.feature_ex_type))
  end

  model:add(nn.Reshape(nFeatureOut))
  addMlpHead(model, nFeatureOut, option)

  if rmodel ~= nil then
    loadPretrained(model, rmodel)
  end

  if option.cuda then
    model:cuda()
  end

  print(model)
  return model
end
