# -*- coding:utf-8 -*-
# Time: 2019/3/18 15:15
# author: Tengfei Shen
# File:util.py

import json
import os
import pickle

import numpy as np


def load(dirname):
    """Return the object unpickled from ``<dirname>/preproc.bin``.

    Args:
        dirname: directory containing a ``preproc.bin`` file written by ``save``.

    Returns:
        Whatever object was pickled into that file.

    Note:
        Unpickling can execute arbitrary code; only point this at trusted
        directories.
    """
    with open(os.path.join(dirname, 'preproc.bin'), 'rb') as stream:
        return pickle.load(stream)


def save(preproc, dirname):
    """Pickle ``preproc`` into ``<dirname>/preproc.bin``.

    Any existing file at that path is overwritten. The directory itself must
    already exist.

    Args:
        preproc: any picklable object.
        dirname: destination directory.
    """
    target = os.path.join(dirname, 'preproc.bin')
    with open(target, 'wb') as stream:
        pickle.dump(preproc, stream)


def make_json(save_path, params):
    """Write ``params`` as JSON to ``save_path`` and terminate the file with a newline.

    Args:
        save_path: output file path; an existing file is overwritten.
        params: a JSON-serializable object.
    """
    with open(save_path, 'w') as out:
        json.dump(params, out)
        # A bare print() emits exactly one '\n' as the trailing newline.
        print(file=out)


def load_json(data_json):
    """Parse and return the JSON document stored at path ``data_json``.

    Args:
        data_json: path to a JSON file (e.g. one written by ``make_json``).

    Returns:
        The decoded Python object.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector (which also triggers
    # a ResourceWarning about an unclosed file).
    with open(data_json, 'r') as fid:
        return json.load(fid)


def _plot_metric(ax, logs, key, title, ylabel):
    """Draw the training curve for ``key`` on ``ax``, plus ``'val_' + key`` if it was logged."""
    train = logs[key]
    epochs = np.arange(len(train))
    ax.plot(epochs, train, label='Train')
    val_key = 'val_' + key
    # Validation metrics only exist when training was run with validation data.
    if val_key in logs:
        ax.plot(epochs, logs[val_key], label='Validation')
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.legend(loc='upper left')


def plot_model_history(history, savedir):
    """Plot accuracy and loss curves side by side and save them as ``train_history.pdf``.

    Args:
        history: object with a ``history`` attribute mapping metric names to
            per-epoch value sequences (a Keras ``History`` object is assumed).
            It must contain ``'acc'`` or ``'accuracy'``, and ``'loss'``. The
            ``'val_'``-prefixed counterparts are plotted when present.
        savedir: existing directory where ``train_history.pdf`` is written.

    Raises:
        KeyError: if neither accuracy key or the ``'loss'`` key is present.
    """
    import matplotlib.pyplot as plt

    logs = history.history
    # Older Keras logs accuracy as 'acc'; Keras >= 2.3 and tf.keras use 'accuracy'.
    acc_key = 'acc' if 'acc' in logs else 'accuracy'

    fig, axes = plt.subplots(figsize=(12, 6), ncols=2, nrows=1)
    try:
        # Left: training & validation accuracy.
        _plot_metric(axes[0], logs, acc_key, 'Model accuracy', 'Acc')
        # Right: training & validation loss.
        _plot_metric(axes[1], logs, 'loss', 'Model loss', 'Loss')
        fig.savefig(os.path.join(savedir, 'train_history.pdf'))
    finally:
        # Close this specific figure even on failure so figures don't accumulate.
        plt.close(fig)
