import os
import matplotlib.pyplot as plt
import numpy as np
def find_max_relative_error(preds, yval):
    """Find the maximum absolute error per component and its relative value.

    Both inputs are reshaped to 2D as ``(shape[0], -1)``. For every column the
    row with the largest absolute deviation ``|preds - yval|`` is located; the
    deviation at that row is returned together with its ratio to ``|yval|`` at
    the same row.

    Args:
        preds (np.array): Prediction array. Any array-like (e.g. nested lists)
            is accepted and converted with ``np.asarray``.
        yval (np.array): Validation array, subtracted elementwise from the
            flattened predictions. Any array-like is accepted.

    Returns:
        tuple: ``(pred_err, prelm)``

            - pred_err (np.array): Flattened maximum absolute error along axis=0.
            - prelm (np.array): Flattened relative maximum error along axis=0.
              Division warnings are suppressed, so entries are ``inf`` or
              ``nan`` where the validation value at that position is zero.
    """
    # Coerce first so that plain lists work; ndarrays pass through unchanged.
    preds = np.asarray(preds)
    yval = np.asarray(yval)
    pred = np.reshape(preds, (preds.shape[0], -1))
    flat_yval = np.reshape(yval, (yval.shape[0], -1))

    # Compute the absolute error once and reuse it for both the index lookup
    # and the error value itself.
    abs_err = np.abs(pred - flat_yval)
    maxerr_ind = np.expand_dims(np.argmax(abs_err, axis=0), axis=0)
    pred_err = np.take_along_axis(abs_err, maxerr_ind, axis=0)
    ref_at_max = np.abs(np.take_along_axis(flat_yval, maxerr_ind, axis=0))

    # A zero reference value is tolerated on purpose: the result is inf/nan
    # instead of a runtime warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        prelm = pred_err / ref_at_max

    return pred_err.flatten(), prelm.flatten()
def plot_error_vec_mean(
        y_pred,
        y_true,
        label_curves="Vector",
        unit_predicted="#",
        filename='fit',
        dir_save="",
        save_plot_to_file=False,
        filetypeout='.png',
        x_label="Vector components",
        plot_title="Component mean error"
):
    """Plot the mean absolute error per vector component.

    For each prediction/reference pair, ``|y_pred - y_true|`` is averaged over
    axis 0 and the flattened result is drawn as one curve over the component
    index.

    Args:
        y_pred (np.array, list): Prediction array, or a list of such arrays
            to draw several curves.
        y_true (np.array, list): Reference array, or a list of arrays indexed
            in parallel with `y_pred`.
        label_curves (str, list): Legend label or list of labels. Curves
            without a matching entry are labelled "Vector". Default is "Vector".
        unit_predicted (str): Unit shown in brackets in the y-axis label.
            Default is "#".
        filename (str): File name stem; "_mean" and `filetypeout` are appended.
            Default is 'fit'.
        dir_save (str): Directory the output file name is joined onto.
            Default is "".
        save_plot_to_file (bool): Whether to save the figure. Default is False.
        filetypeout (str): File extension appended to the name. Default is '.png'.
        x_label (str): Label of the x-axis. Default is "Vector components".
        plot_title (str): Title of the plot. Default is "Component mean error".

    Returns:
        matplotlib.figure.Figure: The figure that was created.
    """
    if not isinstance(y_pred, list):
        y_pred = [y_pred]
    if not isinstance(y_true, list):
        y_true = [y_true]
    if isinstance(label_curves, str):
        label_curves = [label_curves]

    # Draw on an explicit Axes so the result does not depend on which figure
    # pyplot currently treats as "active".
    fig = plt.figure()
    ax = fig.add_subplot(111)
    for i, pred in enumerate(y_pred):
        mean_err = np.mean(np.abs(pred - y_true[i]), axis=0).flatten()
        curve_label = label_curves[i] if i < len(label_curves) else "Vector"
        ax.plot(np.arange(len(mean_err)), mean_err, label=curve_label)

    ax.set_ylabel('Mean absolute error ' + "[" + unit_predicted + "]")
    ax.legend(loc='upper right')
    ax.set_xlabel(x_label)
    ax.set_title(plot_title)

    if save_plot_to_file:
        outname = os.path.join(dir_save, filename + "_mean" + filetypeout)
        fig.savefig(outname)
    return fig
def plot_error_vec_max(y_pred,
                       y_true,
                       label_curves="Vector",
                       unit_predicted="#",
                       filename='fit',
                       dir_save="",
                       save_plot_to_file=False,
                       filetypeout='.png',
                       x_label="Vector components",
                       plot_title="Component max error"):
    """Plot the maximum absolute error per component and its relative value.

    For each prediction/reference pair the errors are obtained from
    `find_max_relative_error`. Absolute maxima are drawn against a left
    y-axis, relative maxima against a second y-axis on the right; both share
    the x-axis.

    Args:
        y_pred (np.array, list): Prediction array, or a list of such arrays
            to draw several curves.
        y_true (np.array, list): Reference array, or a list of arrays indexed
            in parallel with `y_pred`.
        label_curves (str, list): Legend label or list of labels. Curves
            without a matching entry are labelled "Vector". Labels are
            prefixed with "Max " and "Rel. " in the legends.
            Default is "Vector".
        unit_predicted (str): Unit shown in brackets in the left y-axis label.
            Default is "#".
        filename (str): File name stem; "_max" and `filetypeout` are appended.
            Default is 'fit'.
        dir_save (str): Directory the output file name is joined onto.
            Default is "".
        save_plot_to_file (bool): Whether to save the figure. Default is False.
        filetypeout (str): File extension appended to the name. Default is '.png'.
        x_label (str): Label of the x-axis. Default is "Vector components".
        plot_title (str): Title of the plot. Default is "Component max error".

    Returns:
        matplotlib.figure.Figure: The figure that was created.
    """
    if not isinstance(y_pred, list):
        y_pred = [y_pred]
    if not isinstance(y_true, list):
        y_true = [y_true]
    if isinstance(label_curves, str):
        label_curves = [label_curves]

    # Errors are computed before any figure is created.
    err_max = []
    err_rel = []
    for i, pred in enumerate(y_pred):
        temp_err, temp_rel = find_max_relative_error(pred, y_true[i])
        err_max.append(temp_err)
        err_rel.append(temp_rel)

    # Resolve the legend label of every curve once; it is used on both axes.
    labels = [
        label_curves[i] if i < len(label_curves) else "Vector"
        for i in range(len(err_max))
    ]

    fig1 = plt.figure()

    # Left axis: absolute maximum error.
    ax1 = fig1.add_subplot(111)
    for curve, curve_label in zip(err_max, labels):
        ax1.plot(np.arange(len(curve)), curve, label="Max " + curve_label)
    ax1.set_ylabel('Max absolute error ' + "[" + unit_predicted + "]")
    ax1.legend(loc='upper left')

    # Right axis: a frameless second Axes in the same subplot position that
    # shares the x-axis with ax1 and has its y-axis moved to the right.
    ax2 = fig1.add_subplot(111, sharex=ax1, frameon=False)
    for curve, curve_label in zip(err_rel, labels):
        ax2.plot(np.arange(len(curve)), curve, label='Rel. ' + curve_label)
    ax2.yaxis.tick_right()
    ax2.yaxis.set_label_position("right")
    ax2.set_ylabel("Relative max error")
    ax2.legend(loc='upper right')
    ax2.set_xlabel(x_label)
    ax2.set_title(plot_title)

    if save_plot_to_file:
        outname = os.path.join(dir_save, filename + "_max" + filetypeout)
        fig1.savefig(outname)
    return fig1