#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
2025-07-10 Phil Weber
    Bi-Gaussian calibration function.
    
    available from:
      https://forensic-data-science.net/calibration-and-validation/#biGauss

    see:
      Morrison G.S. (2024). Bi-Gaussianized calibration of likelihood ratios. Law, Probability & Risk, 23, mgae004.
      Preprint at https://geoff-morrison.net/#biGauss2024
      
    This function performs only the bi-Gaussianized calibration with logistic-regression method.
"""

import numpy as np
import matplotlib.pyplot as plt
from numbers import Number
from typing import Optional, Tuple, List
from scipy.interpolate import interp1d
from .logistic_regression import train_logreg_fusion_regularized, lin_fusion
from .utils import mix_norm_cdf, np_unique_last, cllr_to_sigma2, cllr
from .plots import bigaussian_calibration_cdfs, bigaussian_calibration_mapping_functions
    
def biGaussianized_calibration(
    scores_test:                 np.ndarray,
    scores_s_train:              np.ndarray,
    scores_d_train:              np.ndarray,
    target_Cllr:                 Optional[float]=None,
    cllr_to_sigma2_coefs:        Optional[Tuple]=(17.665396790464737, 0.009333834837656),
    logreg_regularization_coefs: Optional[Tuple]=(0.5, 0.01, None, 50000),
    plot_cdf:                    Optional[bool]=False,
    plot_mapping:                Optional[bool]=False,
    ax:                          Optional["plt.Axes"]=None,
    fig_size:                    Tuple[float, float]=(15, 6),
):
    """Calibrate scores to natural-log likelihood ratios with bi-Gaussianized calibration.

    available from:
      https://forensic-data-science.net/calibration-and-validation/#biGauss

    see:
      Morrison G.S. (2024). Bi-Gaussianized calibration of likelihood ratios. Law, Probability & Risk, 23, mgae004.
      Preprint at https://geoff-morrison.net/#biGauss2024

    This function performs only the bi-Gaussianized calibration with logistic-regression method.

    The training scores define an empirical cdf (same-source and different-source categories are
    weighted equally). Scores are then mapped, via their cdf value, onto the cdf of a two-component
    Gaussian mixture with means -sigma2/2 and +sigma2/2 and common variance sigma2 ("sigma2_target").

    All score inputs are converted to flattened float copies; the caller's arrays are not modified.

    :param scores_test:            scalar or NumPy array of scores to be calibrated (None is treated as empty).
                                     Any -inf value is replaced by the lower of the lowest non -inf test score
                                     and the lowest training score. For the bi-Gaussian mapping, test scores
                                     are additionally limited to the range of the training scores.
    :param scores_s_train:         NumPy array of same-source scores for training
    :param scores_d_train:         NumPy array of different-source scores for training
                                     any -inf scores will be converted to a point mass at the lowest non -inf
                                     (this assumes there are no -inf in scores_s_train)
    :param target_Cllr:            float: if supplied, value is interpreted as Cllr value, not calculated.
                                     Only values strictly between 0 and 1 are accepted. When supplied, no
                                     logistic regression is trained and the logistic-regression return
                                     values are None.
    :param cllr_to_sigma2_coefs:   coefficient values (b, c) passed to the cllr_to_sigma2 function
    :param logreg_regularization_coefs:
                                   parameters for regularized logistic regression.
                                   For details see train_logreg_fusion_regularized().
                                     (prior:    prior probability of same-speaker,
                                      kappa:    controls strength of the "prior" distributions
                                      df:       pseudo degrees of freedom; if None, the number of
                                                same-source training scores is used
                                      max_iter: maximum iterations of training)
    :param plot_cdf:               if true, plot cumulative density functions.
    :param plot_mapping:           if true, plot mapping functions.
    :param ax:                     Existing axis onto which to plot. If None, a new axis is created.
                                     If both plots are requested this must be a numpy.ndarray holding two axes.
    :param fig_size:               (width, height) of the figure created when ax is None and both plots are
                                     requested; half the width is used when only one plot is requested.

    :return: lnLR_test:            Array of bi-Gaussianized-calibrated natural-log likelihood ratios corresponding
                                   to scores_test
    :return: lnLR_logreg_test:     Array of logistic-regression-calibrated natural-log likelihood ratios
                                   corresponding to scores_test (None if target_Cllr was supplied)
    :return: sigma2_target:        Variance of perfectly-calibrated bi-Gaussian system
    :return: lnLR_s_train:         Array of bi-Gaussianized-calibrated natural-log likelihood ratios corresponding
                                   to scores_s_train
    :return: lnLR_d_train:         Array of bi-Gaussianized-calibrated natural-log likelihood ratios corresponding
                                   to scores_d_train
    :return: lnLR_s_logreg_train:  Array of logistic-regression-calibrated natural-log likelihood ratios
                                   corresponding to scores_s_train (None if target_Cllr was supplied)
    :return: lnLR_d_logreg_train:  Array of logistic-regression-calibrated natural-log likelihood ratios
                                   corresponding to scores_d_train (None if target_Cllr was supplied)

    If target_Cllr is supplied but is not strictly between 0 and 1, an error message is printed and
    (None, None, nan, None, None, None, None) is returned.

    :raises ValueError: if every value in scores_d_train is -inf.
    """
    # Default return values
    lnLR_s_train = None
    lnLR_d_train = None
    lnLR_test = None
    lnLR_logreg_test = None
    lnLR_s_logreg_train = None
    lnLR_d_logreg_train = None

    # Work on flattened float copies: a scalar becomes a 1-element vector, the caller's arrays are never
    # modified, and integer input cannot silently truncate the float values assigned further down.
    if scores_test is None:
        scores_test = []
    scores_test = np.array(scores_test, dtype=float).flatten()
    scores_s_train = np.array(scores_s_train, dtype=float).flatten()
    scores_d_train = np.array(scores_d_train, dtype=float).flatten()

    # If target Cllr value is supplied and strictly between 0 and 1 then use it to identify the sigma2 target.
    # (np.nan rather than np.NaN: the latter alias no longer exists in NumPy 2.)
    sigma2_target = np.nan
    if target_Cllr is not None:
        if 0 < target_Cllr < 1:
            # Identify sigma2 of perfectly-calibrated bi-Gaussian system with the Cllr corresponding to this value.
            sigma2_target = cllr_to_sigma2(
                cllr=target_Cllr,
                cllr_to_sigma2_coefs=cllr_to_sigma2_coefs
            )
        else:
            # NOTE: the message text says [0, 1] although the end points themselves are rejected above.
            print(
"""
Error calling biGaussianized_calibration.
If the target_Cllr argument is supplied it must be a numeric value in the range [0, 1] to interpret as target Cllr.
"""
            )
            return lnLR_test, lnLR_logreg_test, sigma2_target, \
                lnLR_s_train, lnLR_d_train, lnLR_s_logreg_train, lnLR_d_logreg_train

    # Create axes if none were supplied, then separate them to pass to the plotting functions
    ax_cdf, ax_mapping = None, None
    if ax is None and (plot_cdf or plot_mapping):
        if plot_cdf and plot_mapping:
            _, ax = plt.subplots(1, 2, figsize=fig_size)
        else:
            _, ax = plt.subplots(1, 1, figsize=[fig_size[0] / 2, fig_size[1]])

    if ax is not None:
        if plot_cdf and plot_mapping:
            if isinstance(ax, np.ndarray):
                ax_cdf, ax_mapping = ax
            else:
                # Both targets stay None; the plotting functions are then called with ax=None.
                print("ax must be an numpy.ndarray, or None, if more than one plot is requested")
        elif plot_cdf:
            ax_cdf = ax
        elif plot_mapping:
            ax_mapping = ax

    # Prepare the training data
    num_s_train = len(scores_s_train)
    num_d_train = len(scores_d_train)
    num_train = num_s_train + num_d_train

    # Adjust any -inf score values in the different-source training scores.
    # A boolean mask is required here: it is counted, negated with ~, and used for indexing.
    is_negInf_d_train = np.isneginf(scores_d_train)
    num_negInf_d_train = int(np.count_nonzero(is_negInf_d_train))
    point_mass = num_negInf_d_train > 0
    if point_mass:
        # Version of different-source training scores excluding any -inf
        scores_d_train_ex_negInf = scores_d_train[~is_negInf_d_train]
        num_d_train_ex_negInf = len(scores_d_train_ex_negInf)
        if num_d_train_ex_negInf == 0:
            raise ValueError("scores_d_train must contain at least one score that is not -inf")

        # Replace any -inf in training scores with lowest non -inf score value (this assumes there are no -inf in
        # scores_s_train).
        scores_d_train[is_negInf_d_train] = scores_d_train_ex_negInf.min()

        # Sorted training scores (different-source followed by same-source) excluding the -inf entries.
        # The concatenation order must match the order used for the proportions below.
        scores_train_ex_negInf = np.concatenate([scores_d_train_ex_negInf, scores_s_train])
        ID_sorted_ex_negInf = scores_train_ex_negInf.argsort(kind="stable")
        scores_train_ex_negInf_sorted = scores_train_ex_negInf[ID_sorted_ex_negInf]

    # Sort training scores
    scores_train = np.concatenate([scores_d_train, scores_s_train])
    ID_sorted = scores_train.argsort(kind="stable")
    scores_train_sorted = scores_train[ID_sorted]

    # Prepare the test data: convert any -inf test scores to lowest value non -inf score
    is_negInf_test = np.isneginf(scores_test)
    if is_negInf_test.any():
        replacement = scores_train_sorted[0]
        scores_test_ex_negInf = scores_test[~is_negInf_test]
        if scores_test_ex_negInf.size > 0:
            replacement = min(scores_test_ex_negInf.min(), replacement)
        scores_test[is_negInf_test] = replacement

    # If scores_test values are above max or below min values of scores_train_sorted, cdf_test values can be above 1 or
    # below 0. To resolve this, limit such values to max or min values of scores_train_sorted.
    # np.clip returns a new array, so scores_test itself stays unclipped for the logistic-regression output.
    scores_test_limited = np.clip(scores_test, scores_train_sorted[0], scores_train_sorted[-1])

    # Empirical cdfs

    # Empirical cdf for training data giving equal weighting to same-source and different-source categories
    # add 1 in denominator of empirical cdf so will not have to extrapolate to 1 for GMM cdf
    props_s = (np.ones(num_s_train) / (num_s_train + 1)) / 2
    if point_mass:
        props_d = (np.ones(num_d_train_ex_negInf) / (num_d_train + 1)) / 2
        props_sorted = np.concatenate([props_d, props_s])[ID_sorted_ex_negInf]

        # Proportion for point mass corresponding to -inf scores,
        # added at location of lowest value non -inf score
        point_mass_prop = (num_negInf_d_train / (num_d_train + 1)) / 2
        props_sorted[0] = props_sorted[0] + point_mass_prop

        scores_cdf_sorted = scores_train_ex_negInf_sorted
    else:
        props_d = (np.ones(num_d_train) / (num_d_train + 1)) / 2
        props_sorted = np.concatenate([props_d, props_s])[ID_sorted]

        scores_cdf_sorted = scores_train_sorted

    cdf_empirical = np.cumsum(props_sorted)

    # Interpolate cdf values for test scores
    # in case there are repeated score values, convert scores to unique values,
    # and use cdf of highest value with each unique score value
    scores_cdf_sorted_unique, ID_unique = np_unique_last(scores_cdf_sorted)
    cdf_empirical_unique = cdf_empirical[ID_unique]
    interp1_function = interp1d(
        scores_cdf_sorted_unique, cdf_empirical_unique, kind="linear", fill_value="extrapolate"
    )
    cdf_test = interp1_function(scores_test_limited)

    # cdf values for train scores (one entry per element of scores_train_sorted)
    if point_mass:
        # The former -inf scores take the cdf value of the lowest score.
        cdf_train = np.concatenate([np.full(num_negInf_d_train, cdf_empirical[0]), cdf_empirical])
    else:
        cdf_train = cdf_empirical

    # biGauss calibration: If Cllr has not been supplied, calculate Cllr, then sigma2_target.
    # Also perform LogReg calibration
    if target_Cllr is None:
        # Logistic regression defaults are set in the function signature, except for df which is dependent on the
        # training data.
        prior, kappa, df, max_iter = logreg_regularization_coefs
        if df is None:
            df = num_s_train

        weights = train_logreg_fusion_regularized(
            scores_s_train, scores_d_train, prior=prior, max_iter=max_iter, kappa=kappa, df=df
        )
        lnLR_d_logreg_train = lin_fusion(weights, scores_d_train)
        lnLR_s_logreg_train = lin_fusion(weights, scores_s_train)

        # Cllr of training data after logistic-regression calibration
        Cllr_LogReg_target = cllr(
            lnLRs_s=lnLR_s_logreg_train,
            lnLRs_d=lnLR_d_logreg_train
        )

        # sigma2 of the perfectly-calibrated bi-Gaussian system with the same Cllr.
        sigma2_target = cllr_to_sigma2(
            cllr=Cllr_LogReg_target,
            cllr_to_sigma2_coefs=cllr_to_sigma2_coefs
        )

    # Calculate target cdf given sigma2 of perfectly calibrated biGauss system,
    # then map test scores to biGauss-calibrated lnLRs via cdfs
    half_sigma2_target = sigma2_target / 2
    sigma_target = np.sqrt(sigma2_target)

    # GMM mixture of perfectly calibrated biGauss system
    # assume range of lnLRs will be within mu_d-4*sigma and mu_s+4*sigma
    lnLR_max = half_sigma2_target + 4 * sigma_target
    lnLR_step = lnLR_max / (4 * (num_train + 1))
    lnLR_target_grid = np.arange(-lnLR_max, lnLR_max, lnLR_step)

    cdf_target_grid = mix_norm_cdf(
        sample=lnLR_target_grid,
        weights=[0.5, 0.5],
        means=[-half_sigma2_target, half_sigma2_target],
        covars=[sigma2_target, sigma2_target]
    )

    # Convert target grids to unique values to avoid problems due to numerical constraints on calculation of
    # cdf target values
    cdf_target_grid_unique, ID_unique = np.unique(cdf_target_grid, return_index=True)
    lnLR_target_grid_unique = lnLR_target_grid[ID_unique]

    # Map test and training scores to biGauss-calibrated lnLRs via cdfs
    interp1_function = interp1d(cdf_target_grid_unique, lnLR_target_grid_unique, kind="linear",
                                fill_value="extrapolate")
    lnLR_test = interp1_function(cdf_test)
    lnLR_train_sorted = interp1_function(cdf_train)

    # Plot cdfs (with a point mass, the empirical cdf is plotted against the scores excluding -inf)
    if plot_cdf:
        bigaussian_calibration_cdfs(
            lnLR_target_grid=lnLR_target_grid,
            cdf_target_grid=cdf_target_grid,
            scores_train=scores_cdf_sorted,
            cdf=cdf_empirical,
            sigma_target=sigma_target,
            linewidth=1,
            fontsize=12,
            ax=ax_cdf
        )

    if target_Cllr is None:
        # Add LogReg to results
        lnLR_logreg_test = lin_fusion(weights, scores_test)

    # Return calibrated training scores in their original (different-source, then same-source) order
    ID_unsort = np.zeros(len(ID_sorted), dtype=int)
    ID_unsort[ID_sorted] = np.arange(len(ID_sorted))
    lnLR_train = lnLR_train_sorted[ID_unsort]
    lnLR_s_train = lnLR_train[num_d_train:]
    lnLR_d_train = lnLR_train[:num_d_train]

    # Plot mapping functions
    if plot_mapping:
        bigaussian_calibration_mapping_functions(
            scores_train_sorted=scores_train_sorted,
            lnLR_train_sorted=lnLR_train_sorted,
            methods="biGauss_LogReg",
            lnLR_d_logreg_train=lnLR_d_logreg_train,
            lnLR_s_logreg_train=lnLR_s_logreg_train,
            linewidth=1,
            fontsize=12,
            ax=ax_mapping
        )

    return lnLR_test, lnLR_logreg_test, sigma2_target, \
        lnLR_s_train, lnLR_d_train, lnLR_s_logreg_train, lnLR_d_logreg_train
