#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
2025-07-10 Phil Weber
    Plotting functions in support of BiGaussianized calibration.
"""

import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KernelDensity
from scipy.stats import norm
from typing import Optional, Tuple, List

FLOAT_ROUND = 6  # Number of decimal places to which some float results are rounded, to absorb precision errors


# Utility functions
def _style_params_for_plots(
        **kwargs
):
    """Extract the supported Matplotlib style settings from keyword arguments.

    Any setting that is not supplied takes its default value; keywords other than those
    listed below are ignored.

    :return: tuple (linestyle, linewidth, marker, markersize, markeredgewidth, fillstyle, fontsize)
    """
    # (keyword, default) pairs, in the order in which the values are returned
    supported = (
        ("linestyle", '-'),
        ("linewidth", 1),
        ("marker", None),
        ("markersize", 12),
        ("markeredgewidth", 1),
        ("fillstyle", "none"),
        ("fontsize", 12),
    )

    return tuple(kwargs.get(keyword, default) for keyword, default in supported)

    
def num2str_fractions_and_commas(array):
    """Convert a sequence of numbers to pretty-format strings.

    Numbers greater than or equal to 1 are converted to (e.g.) "10,000".
    Numbers less than 1 are written as a reciprocal, e.g., 0.0001 becomes "1/10,000".

    Floating-point values are rounded to FLOAT_ROUND decimal places and then truncated to
    integers, so 999.9999999999 is written "1,000" while 2.7 is written "2".
    Values are expected to be positive: zero has no finite reciprocal and is not handled.

    :param array: 1-d list or NumPy array of numeric values
    :return: list of strings, one per element of the input
    """
    values = np.array(array)  # np.array copies, so the caller's data is never modified
    is_fraction = values < 1
    values[is_fraction] = np.round(1 / values[is_fraction], FLOAT_ROUND)
    if values.dtype.kind == "f":
        # Round the values >= 1 as well; astype(int) truncates, so a float that sits just
        # below an integer because of representation error would otherwise lose 1.
        values[~is_fraction] = np.round(values[~is_fraction], FLOAT_ROUND)
    whole_numbers = values.astype(int)

    return [
        f"1/{number:,d}" if fraction else f"{number:,d}"
        for number, fraction in zip(whole_numbers.tolist(), is_fraction.tolist())
    ]
    

# Plots
def demo_biGaussian(
        sigma:          float,
        LR:             float,
        x_tick_base:    Optional[float]=None,
        y_tick_step:    Optional[float]=None,
        y_tick_lab_int: Optional[bool]=None,
        ax:             Optional[plt.Axes]=None,
        fig_size:       List[float]=[8, 5],
        grid_lines:     bool=True,
        colors:         Tuple=('b', 'r', 'g'),
        x_label:        str='Likelihood ratio',
        y_label:        str='Relative likelihood',
        **kwargs
    ):
    """Display a perfectly-calibrated bi-Gaussian system and a likelihood-ratio value

    The same-source and different-source Gaussians are drawn on a ln(LR) axis whose ticks are
    labelled with the corresponding likelihood ratios (e.g., "1/100", "1", "100"). The y axis
    is labelled in multiples of the lower of the two likelihoods at the requested LR.

    :param sigma:          standard deviation of perfectly-calibrated Gaussian distributions (scaled as lnLR).
                           Must be greater than 0.
    :param LR:             likelihood-ratio value to display. Must be greater than 0.
    :param x_tick_base:    Optional: base for logarithmic scale of x axis, e.g., 10 or 2. Must be greater than 1.
                           If None, base 10 is used when the plotted range extends beyond ln(1024), otherwise base 2.
    :param y_tick_step:    Optional: size of steps for scale of y axis as a multiplier of the lowest of the two
                           likelihoods, e.g., enter "1" to show every integer multiple of the lowest likelihood.
                           Must be greater than 0. If None, a whole-number step of at least 1 is chosen.
    :param y_tick_lab_int: Set to "True" for y-tick labels to be written as integers, any other value except None
                           for them to be written as floats. Default behaviour (None) writes integers if LR >= 10,
                           floats if LR < 10.
    :param ax:             Existing axis onto which to plot. If None, a new axis is created.
    :param fig_size:       Size of figure. Only relevant if a new axis is created.
    :param grid_lines:     If True (default), draw grid lines.
    :param colors:         Tuple of Matplotlib color character codes for lines and symbols, e.g., default ('b', 'r', 'g')
                           results in a blue same-source curve, red different-source curve, and green lines and symbols
                           for the likelihoods. RGB codes can also be specified, e.g., ((1, 0, 0), (0, 0, 1), (0, 1, 0))
    :param x_label:        String for x-axis label.
    :param y_label:        String for y-axis label.
    :param **kwargs:       Allows Matplotlib parameters to be passed to control linewidth, markersize, etc.
        limited to the following:
          linewidth:       Width in points (default 1).
          markersize:      Size in points (default 12).
          markeredgewidth: Width in points (default 1).
          fillstyle:       e.g. "full", "none" (default).
          fontsize:        Size in points (default 12) for axis labels.

    :return: fig, ax       fig is None when an existing axis was supplied.
    :raises ValueError:    if sigma, LR, x_tick_base or y_tick_step is outside its valid range.
    """
    # Reject values that would otherwise give NaN curves, missing ticks or obscure NumPy errors.
    if sigma <= 0:
        raise ValueError("sigma must be greater than 0")
    if LR <= 0:
        raise ValueError("LR must be greater than 0")
    if x_tick_base is not None and x_tick_base <= 1:
        raise ValueError("x_tick_base must be greater than 1")
    if y_tick_step is not None and y_tick_step <= 0:
        raise ValueError("y_tick_step must be greater than 0")

    # Handle additional parameters and defaults for the plots (linestyle and marker are not used here).
    _, linewidth, _, markersize, markeredgewidth, fillstyle, fontsize = \
        _style_params_for_plots(**kwargs)
    color_s, color_d, color_LR = colors

    # Natural logarithm of likelihood ratio
    lnLR = np.log(LR)

    # Parameters for perfectly-calibrated bi-Gaussian system (scaled as lnLR)
    sigma2 = sigma ** 2
    mu_s = sigma2 / 2
    mu_d = -mu_s

    # x (lnLR) plotting range, widened if necessary so that the requested LR is visible
    x_plot_max = mu_s + 3 * sigma
    if np.abs(lnLR) > x_plot_max:
        x_plot_max = abs(lnLR)*1.1

    # linspace rather than a float-step arange: the number of points is fixed and the range is symmetric
    x = np.linspace(-x_plot_max, x_plot_max, 400)

    # y (pdf) for different-source and same-source distributions
    y_s = norm(loc=mu_s, scale=sigma).pdf(x)
    y_d = norm(loc=mu_d, scale=sigma).pdf(x)

    # create figure and adjust aspect ratio
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=fig_size)
    else:
        fig = None

    # plot different-source and same-source distributions
    ax.plot(x, y_d, '-', color=color_d, linewidth=linewidth)
    ax.plot(x, y_s, '-', color=color_s, linewidth=linewidth)
    ax.set_xlim([x[0], x[-1]])
    ax.set_xlabel(x_label, fontsize=fontsize)
    ax.set_ylabel(y_label, fontsize=fontsize)

    # Capture the automatic upper y limit before any further lines are drawn, then pin the axis to [0, top]
    lim_y = ax.get_ylim()
    ax.plot([0, 0], [0, lim_y[1]], '-k', linewidth=linewidth/2)
    ax.set_ylim([0, lim_y[1]])
    if grid_lines:
        ax.grid(visible=True, alpha=0.5)

    # Choose the base of the logarithmic x scale
    if x_tick_base is None:
        base = 10 if x_plot_max > np.log(1024) else 2
    else:
        base = x_tick_base

    # Whole-number exponents of the base that fit inside the plotted range
    lim_x_tick = np.floor(x[-1]/np.log(base))
    x_scale_log_base = np.arange(-lim_x_tick, lim_x_tick+1)

    # y (pdf) for likelihood ratio
    likelihood_s = norm(mu_s, sigma).pdf(lnLR)
    likelihood_d = norm(mu_d, sigma).pdf(lnLR)

    # plot likelihood ratio: vertical line at lnLR, a marker on each curve, and dashed guides to the y axis
    ax.plot([lnLR, lnLR], [0, np.max([likelihood_s, likelihood_d])], '-', color=color_LR, linewidth=linewidth)
    ax.plot(
        [lnLR, lnLR],
        [likelihood_s, likelihood_d],
        'o',
        fillstyle=fillstyle,
        color=color_LR,
        linewidth=linewidth,
        markersize=markersize,
        markeredgewidth=markeredgewidth
    )
    ax.plot([x[0], lnLR], [likelihood_s, likelihood_s], '--', color=color_LR, linewidth=linewidth)
    ax.plot([x[0], lnLR], [likelihood_d, likelihood_d], '--', color=color_LR, linewidth=linewidth)

    # The y scale is expressed relative to the lower of the two likelihoods
    likelihood_min = np.min([likelihood_s, likelihood_d])
    lim_y_rel = lim_y[1] / likelihood_min

    # Limit the number of ticks to prevent undue run time
    if len(x_scale_log_base) > 50:
        print("Error determining scales to 'prettify' the x axis")
    else:
        # Write scale on x axis: ticks at whole powers of the base, labelled as likelihood ratios
        x_scale_ln = x_scale_log_base * np.log(base)
        x_scale_linear = base ** x_scale_log_base
        x_scale_str = num2str_fractions_and_commas(x_scale_linear)

        ax.set_xticks(x_scale_ln)
        ax.set_xticklabels(x_scale_str, rotation=90)

        # Write scale on y axis
        if y_tick_step is None:
            num_y_ticks_default = len(ax.get_yticks())
            # floor() gives 0 whenever the axis spans fewer multiples than there are default ticks
            # (typical for LR close to 1); a zero step makes np.arange fail, so use at least 1.
            y_step_rel = max(np.floor(lim_y_rel / num_y_ticks_default), 1.0)
        else:
            y_step_rel = y_tick_step

        y_scale_rel = np.arange(0, lim_y_rel, y_step_rel)
        y_scale_density = y_scale_rel * likelihood_min
        ax.set_yticks(y_scale_density)

        if y_tick_lab_int is None:
            use_int_labels = LR >= 10
        else:
            use_int_labels = y_tick_lab_int is True

        if use_int_labels:
            ax.set_yticklabels([int(multiple) for multiple in y_scale_rel])
        else:
            ax.set_yticklabels(y_scale_rel)

    return fig, ax
    

def pdfs(
        comparison_pairs: Optional["ComparisonPairSet"] = None,
        scores_s_train:   Optional[np.ndarray] = None,
        scores_d_train:   Optional[np.ndarray] = None,
        ax:               Optional[plt.Axes]=None,
        colors:           Tuple=('b', 'r'),
        grid_lines:       bool=True,
        fig_size:         List[float]=[12, 8],
        **kwargs
    ):
    """Plot PDFs for same-source and different-source scores

    The PDFs are Gaussian kernel density estimates evaluated on 400 points spanning the range
    of all the scores, with a bandwidth of 1/20 of that range.

    May be called as e3fs3.calibration.BiGaussianCalibration.plot_pdfs or directly.

    Supply either comparison_pairs, or both scores_s_train and scores_d_train.

    :param comparison_pairs: object whose get_same_and_different_source_scores() method returns
                             (same-source scores, different-source scores)
    :param scores_s_train:   same-source scores (lnLR)
    :param scores_d_train:   different-source scores (lnLR)
    :param ax:               Existing axis onto which to plot. If None, a new axis is created.
    :param colors:           Pair of Matplotlib colors (same-source, different-source).
    :param grid_lines:       If True (default), draw grid lines.
    :param fig_size:         Size of figure. Only relevant if a new axis is created.
    :params Matplotlib arguments as **kwargs (linestyle, linewidth and fontsize are used)

    :return fig, ax          fig is None when an existing axis was supplied.
    :raises ValueError:      if the score arguments are not supplied as described above, or if the
                             scores are empty or all identical (no range over which to estimate a PDF).
    """
    # Handle additional parameters and defaults for the plots.
    linestyle, linewidth, _, _, _, _, fontsize = _style_params_for_plots(**kwargs)
    color_ss, color_ds = colors

    have_pairs = comparison_pairs is not None
    have_any_scores = scores_s_train is not None or scores_d_train is not None
    have_both_scores = scores_s_train is not None and scores_d_train is not None
    if (have_pairs and have_any_scores) or (not have_pairs and not have_both_scores):
        raise ValueError(
            "Plot must be called with either a ComparisonPair set or both scores_s_train and scores_d_train"
        )
    if have_pairs:
        scores_s_train, scores_d_train = comparison_pairs.get_same_and_different_source_scores()

    # Column vectors, as required by KernelDensity
    scores_s_train = np.asarray(scores_s_train).reshape(-1, 1)
    scores_d_train = np.asarray(scores_d_train).reshape(-1, 1)
    scores_train = np.concatenate((scores_s_train, scores_d_train))
    if scores_train.size == 0:
        raise ValueError("No scores were supplied; there is nothing to plot")

    min_s, max_s = scores_train.min(), scores_train.max()
    if not max_s > min_s:
        # A zero (or NaN) range would give an empty evaluation grid and a zero KDE bandwidth
        raise ValueError("At least two distinct, finite score values are required to plot PDFs")
    score_range = np.linspace(min_s, max_s, 400).reshape(-1, 1)

    # KDE parameters. For "silverman" or "scott" bandwidth estimation, sklearn >= 1.2 is required, but pdf is not smooth
    kernel = "gaussian"
    bandwidth = (max_s - min_s) / 20

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=fig_size)
    else:
        fig = None

    if len(scores_d_train) > 0:
        kde_ds = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(scores_d_train)
        # score_samples returns log density
        pdf_ds = np.exp(kde_ds.score_samples(score_range))
        ax.plot(score_range, pdf_ds, linestyle=linestyle, color=color_ds, linewidth=linewidth)
    else:
        print("Warning: 'scores_d_train' is empty; no pdf will be plotted")

    if len(scores_s_train) > 0:
        kde_ss = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(scores_s_train)
        pdf_ss = np.exp(kde_ss.score_samples(score_range))
        ax.plot(score_range, pdf_ss, linestyle=linestyle, color=color_ss, linewidth=linewidth)
    else:
        print("Warning: 'scores_s_train' is empty; no pdf will be plotted")

    plt.tight_layout()

    if grid_lines:
        ax.grid(visible=True)

    ax.set_xlim([min_s, max_s])

    # Raw string: "\L" is not a valid escape sequence in an ordinary string literal
    ax.set_xlabel(r"ln$(\Lambda)$", fontsize=fontsize)
    ax.set_ylabel("probability density", fontsize=fontsize)

    return fig, ax


def bigaussian_calibration_cdfs(
        lnLR_target_grid:    np.ndarray,
        cdf_target_grid:     np.ndarray,
        scores_train:        np.ndarray,
        cdf:                 np.ndarray,
        sigma_target:        float,
        ax:                  Optional[plt.Axes]=None,
        colors:              Tuple=('b', 'r'),
        grid_lines:          bool=True,
        fig_size:            List[float]=[12, 8],
        **kwargs
    ):
    """Plot CDFs for a trained bi-Gaussian calibration.

    Two curves are drawn: "target" (cdf_target_grid against lnLR_target_grid, solid line) and
    "empirical" (cdf against scores_train, dotted line). The x axis spans the combined range of
    lnLR_target_grid and scores_train, and the y axis is fixed to [0, 1].

    May be called as e3fs3.calibration.BiGaussianCalibration.plot_cdfs or directly.

    :param lnLR_target_grid: x values (lnLR) of the target curve
    :param cdf_target_grid:  cumulative probabilities of the target curve, one per lnLR_target_grid value
    :param scores_train:     x values (scores) of the empirical curve; they are plotted in the order given
    :param cdf:              cumulative probabilities of the empirical curve, one per scores_train value
    :param sigma_target:     value shown in the plot title as the target sigma
    :param ax:               Existing axis onto which to plot. If None, a new axis is created.
    :param colors:           Pair of Matplotlib colors (target, empirical).
    :param grid_lines:       If True (default), draw grid lines.
    :param fig_size:         Size of figure. Only relevant if a new axis is created.
    :params Matplotlib arguments as **kwargs (linewidth and fontsize are used)

    :return fig, ax          fig is None when an existing axis was supplied.
    """
    # Handle additional parameters and defaults for the plots.
    _, linewidth, _, _, _, _, fontsize = _style_params_for_plots(**kwargs)
    color_target, color_empirical = colors

    # Work out the x range before creating a figure, so that bad input does not leave an empty figure behind.
    try:
        all_x = np.concatenate((lnLR_target_grid, scores_train))
        minmax_x = [np.min(all_x), np.max(all_x)]
    except Exception as err:
        # Report the problem for the user, then let the original exception propagate
        print("There is a problem with the parameters:")
        print(type(err))
        print(err)
        raise

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=fig_size)
    else:
        fig = None

    # Vertical reference line at lnLR = 0
    ax.plot([0, 0], [0, 1], '-k', linewidth=linewidth/2)

    h1, = ax.plot(
        lnLR_target_grid,
        cdf_target_grid,
        linestyle='-',
        color=color_target,
        linewidth=linewidth
    )
    h2, = ax.plot(
        scores_train,
        cdf,
        linestyle=':',
        color=color_empirical,
        linewidth=linewidth
    )

    plt.tight_layout()

    if grid_lines:
        ax.grid(visible=True)

    ax.set_xlim(minmax_x)
    ax.set_ylim([0, 1])

    # Raw strings: "\L" and "\s" are not valid escape sequences in ordinary string literals
    ax.set_xlabel(r"score or ln$(\Lambda)$", fontsize=fontsize)
    ax.set_ylabel("cumulative probability", fontsize=fontsize)

    ax.set_title(rf"$\sigma$ target = {sigma_target:.2f}")
    # Explicit handles keep the reference line out of the legend
    ax.legend([h1, h2], ['target', 'empirical'], loc="upper left", fontsize=fontsize)

    return fig, ax


def bigaussian_calibration_mapping_functions(
        scores_train_sorted:      np.ndarray,
        lnLR_train_sorted:        np.ndarray,
        methods:                  Optional[list]=["biGauss_LogReg"],
        lnLR_d_logreg_train:      Optional[np.ndarray]=None,
        lnLR_s_logreg_train:      Optional[np.ndarray]=None,
        ax:                       Optional[plt.Axes]=None,
        colors:                   Tuple=('k', 'r', 'b', 'g', 'm'),
        grid_lines:               bool=True,
        fig_size:                 List[float]=[12, 8],
        **kwargs
    ):
    """Plot mapping functions for trained bi-Gaussian calibrations.

    lnLR_train_sorted is a 2d array: training points x methods

    One curve of lnLR against score is drawn per method. If both logistic-regression arrays are
    supplied, a dashed "LogReg" line is also drawn, as the straight line joining
    (lowest score, lowest logistic-regression lnLR) to (highest score, highest logistic-regression lnLR).

    May be called as e3fs3.calibration.BiGaussianCalibration.plot_mapping_functions or directly.

    :param scores_train_sorted: training scores; the first and last elements are used as the x-axis
                                limits, so the scores are expected to be in ascending order
    :param lnLR_train_sorted: 2d np.ndarray scores x methods (a 1d array is treated as a single method)
    :param methods: names of the methods (for the legend), one per column of lnLR_train_sorted.
                    A single string is treated as one name; None selects the default name.
    :param lnLR_d_logreg_train: Set to None if not to be plotted
    :param lnLR_s_logreg_train: Set to None if not to be plotted
    :param ax: Existing axis onto which to plot. If None, a new axis is created.
    :param colors: list of color codes: the first is used for the horizontal line at lnLR = 0, the
                   second for the LogReg line, and the remainder for the methods in order (re-used
                   cyclically if there are more methods than colors). At least two are required.
    :param grid_lines: If True (default), draw grid lines.
    :param fig_size: Size of figure. Only relevant if a new axis is created.
    :params Matplotlib arguments as **kwargs (linewidth and fontsize are used)

    :return fig, ax  fig is None when an existing axis was supplied.
    :raises ValueError: if fewer than two colors are supplied.
    """

    # Handle additional parameters and defaults for the plots.
    _, linewidth, _, _, _, _, fontsize = _style_params_for_plots(**kwargs)

    # Honour the caller's colors (they used to be overwritten by a hard-coded list; the default
    # in the signature now equals that list, so plots made with default colors are unchanged).
    if len(colors) < 2:
        raise ValueError("colors must contain at least two entries (zero line, LogReg line)")
    color_x, color_logreg = colors[:2]
    method_colors = list(colors[2:])

    if methods is None:
        methods = ["biGauss_LogReg"]
    elif isinstance(methods, str):
        methods = [methods]

    lnLR_train_sorted = np.asarray(lnLR_train_sorted)
    if lnLR_train_sorted.ndim == 1:
        lnLR_train_sorted = lnLR_train_sorted[:, np.newaxis]

    minmax_x = [scores_train_sorted[0], scores_train_sorted[-1]]

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=fig_size)
    else:
        fig = None

    # Horizontal reference line at lnLR = 0
    ax.plot(minmax_x, [0, 0], '-', color=color_x, linewidth=linewidth/2)

    legend_h = []
    legend_text = []

    # Logistic regression
    if lnLR_s_logreg_train is not None and lnLR_d_logreg_train is not None:
        lnLR_logreg = np.concatenate((lnLR_d_logreg_train, lnLR_s_logreg_train))
        minmax_lnLR_logreg = [np.min(lnLR_logreg), np.max(lnLR_logreg)]
        handle, = ax.plot(
            minmax_x,
            minmax_lnLR_logreg,
            color=color_logreg,
            linewidth=linewidth,
            linestyle="--"
        )
        legend_h.append(handle)
        legend_text.append("LogReg")

    # biGaussianized
    for column, method in enumerate(methods):
        # Cycle through the method colors; with none available, Matplotlib chooses (color=None)
        color = method_colors[column % len(method_colors)] if method_colors else None
        handle, = ax.plot(
            scores_train_sorted,
            lnLR_train_sorted[:, column],
            linewidth=linewidth,
            color=color
        )
        legend_h.append(handle)
        legend_text.append(method)

    ax.set_xlabel("score", fontsize=fontsize)
    # Raw string: "\L" is not a valid escape sequence in an ordinary string literal
    ax.set_ylabel(r"ln$(\Lambda)$", fontsize=fontsize)
    ax.set_xlim(minmax_x)

    plt.tight_layout()

    if grid_lines:
        ax.grid(visible=True)

    ax.legend(legend_h, legend_text, loc="upper left", fontsize=fontsize)

    return fig, ax

def tippett(
        lnLR_s:                     Optional[List[float]]=None,
        lnLR_d:                     Optional[List[float]]=None,
        sort_values:                bool=True,
        add_one:                    bool=False,
        ax:                         Optional[plt.Axes]=None,
        fig_size:                   List[float]=[12, 8],
        zero_line:                  bool=True,
        grid_lines:                 bool=True,
        colors:                     Tuple=('b', 'r'),
        title:                      Optional[str]=None,
        x_label:                    str=r'log$_{{10}}(\Lambda)$',
        y_label:                    str='Cumulative Proportion',
        legend_text:                Optional[List[str]]=None,
        xlim:                       Optional[List[float]]=None,
        **kwargs
    ):
    """Output a Tippett plot.

    The Tippett plot displays two curves:

    Same-source curve:          Proportion of same-source comparisons with log10(LR) values equal to or less than
                                the value indicated on x-axis.
    Different-source curve:     Proportion of different-source comparisons with log10(LR) values equal to or greater
                                than the value indicated on x-axis.

    :param lnLR_s, lnLR_d:      Lists of same-source and different-source ln(LR) values.
                                None is treated as an empty list (no curve is drawn for that set).
    :param sort_values:         If True (default), sort the ln(LR) values. If False, the values must
                                already be in ascending order for the curves to be meaningful.
    :param add_one:             If True, add 1 to the denominator for the calculation of cumulative proportion.
                                This will prevent the cumulative proportion from reaching 1
                                (appropriate when plotting CDF of a perfectly calibrated bi-Gaussian system).
                                (Default is false.)
    :param ax:                  Existing axis onto which to plot. If None, a new axis is created.
    :param fig_size:            Size of figure. Only relevant if a new axis is created.
    :param zero_line:           If True (default), draw vertical line at log10LR=0.
    :param grid_lines:          If True (default), draw grid lines.
    :param colors:              Pair of Matplotlib color character codes for lines & points.
                                (same-source, different-source)
    :param title:               Title for the Tippett plot (None = no title).
    :param x_label:             String for x-axis label.
    :param y_label:             String for y-axis label.
    :param legend_text:         List of strings for the legend, e.g., ['same source', 'different source'].
    :param xlim:                List or tuple [min-x, max-x] to restrict x-axis. Defaults to show all points.
    :param **kwargs:            Allows Matplotlib parameters to be passed to control linestyle, markers, etc.
        limited to the following:
          linestyle:  e.g. '-': solid (default), '--': dashed, ':': dotted, '-.': dashed-dotted, None: no line)
          linewidth:  Width in points (default 1).
          marker:     Character specification for points on the primary lines (e.g. 'o', default None: no marker).
          markersize: Size in points (default 12).
          markeredgewidth: points  (default 1).
          fillstyle:  e.g. "full", "none" (default).

    :return: fig, ax            fig is None when an existing axis was supplied.
    :raises ValueError:         if xlim is not a pair of numbers.
    """
    # None (the default) means "no values"; flatten so that column vectors are accepted too
    lnLR_s = np.asarray([] if lnLR_s is None else lnLR_s, dtype=float).ravel()
    lnLR_d = np.asarray([] if lnLR_d is None else lnLR_d, dtype=float).ravel()

    # Handle additional parameters and defaults for the plots (fontsize is not used here).
    linestyle, linewidth, marker, markersize, markeredgewidth, fillstyle, _ = \
        _style_params_for_plots(**kwargs)

    color_same_source, color_diff_source = colors

    if xlim is not None:
        numeric_types = (int, float, np.integer, np.floating)
        is_pair = isinstance(xlim, (list, tuple)) and len(xlim) == 2
        if not is_pair or not all(isinstance(limit, numeric_types) for limit in xlim):
            raise ValueError("xlim must be [float, float]")

    # Sort the lnLRs if requested.
    if sort_values:
        lnLR_s = np.sort(lnLR_s)
        lnLR_d = np.sort(lnLR_d)

    # Convert natural LRs to base 10 for the Tippett plot.
    log10LR_s = lnLR_s / np.log(10)
    log10LR_d = lnLR_d / np.log(10)

    num_s = len(log10LR_s)
    num_d = len(log10LR_d)

    # Same-source proportions rise from 1/n to n/n; different-source proportions fall from n/n to 1/n.
    # With add_one the denominator is n + 1, so neither curve reaches 1.
    denominator_offset = 1 if add_one else 0
    cumulative_proportion_s = (1 + np.arange(num_s)) / (num_s + denominator_offset)
    cumulative_proportion_d = np.flip(1 + np.arange(num_d)) / (num_d + denominator_offset)

    # Make a new figure if axis has not been specified, and draw zero line if requested
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=fig_size)
    else:
        fig = None
    if zero_line:
        # The '_nolegend_' label keeps this line out of the legend
        ax.plot([0, 0], [0, 1], "k", linewidth=0.5, label='_nolegend_')

    # Plot cumulative proportions against log10(LR) values
    curve_style = dict(
        linestyle=linestyle,
        linewidth=linewidth,
        marker=marker,
        markeredgewidth=markeredgewidth,
        markersize=markersize,
        fillstyle=fillstyle,
        clip_on=True
    )
    ax.plot(log10LR_s, cumulative_proportion_s, color=color_same_source, **curve_style)
    ax.plot(log10LR_d, cumulative_proportion_d, color=color_diff_source, **curve_style)

    # Set the x-axis limits if not overridden by parameter
    if xlim is None:
        # An empty set contributes 0 as both its minimum and its maximum
        bounds = [
            (values.min(), values.max()) if values.size else (0, 0)
            for values in (log10LR_d, log10LR_s)
        ]
        xlim = [min(low for low, _ in bounds), max(high for _, high in bounds)]
    ax.set_xlim(xlim)
    ax.set_ylim([0, 1])

    if title is not None:
        ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    if legend_text is not None:
        ax.legend(legend_text)

    if grid_lines:
        ax.grid(visible=True)

    return fig, ax
