#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
2025-07-10 Phil Weber
    Logistic regression function.

    Auxiliary functions include functions from, or adapted from, Niko Brümmer's FoCal Toolkit, adapted from
    a Matlab implementation by Geoffrey Stewart Morrison.
"""

import numpy as np

def lin_fusion(weights, scores):
    """Return scores multiplied by the logistic regression weights.

    A row of ones is appended below the scores so that the last weight acts
    as the additive bias: the result is ``weights[:-1] . scores + weights[-1]``
    for every trial.

    :param weights: array of weights trained by logistic regression; one weight
                    per score row followed by the bias weight
    :param scores:  speaker-comparison scores (e.g. from PLDA); either a 1-D
                    array with one score per trial, or a 2-D array with one row
                    per system and one column per trial
    :return:        the fused/calibrated output, one value per trial
    """
    scores = np.asarray(scores)
    # The bias row must match the number of trials, which is the LAST axis:
    # for a 1-D vector that is its length, for a (systems, trials) matrix it
    # is the column count (shape[0] would be the number of systems instead).
    bias_row = np.ones(scores.shape[-1])
    return weights.T.dot(np.vstack((scores, bias_row)))


def train_logreg_fusion_regularized(tar, non, prior: float=.5, kappa=0.0, df=None, max_iter: int=50000):
    """Train a regularized logistic-regression model.

    Train linear fusion (see 'lin_fusion') with prior-weighted logistic regression objective.
    The fusion output is encouraged by this objective to be a well-calibrated log-likelihood-ratio,
    i.e., this is simultaneous fusion and calibration.

    Regularization is either to prevent numerical problems in cases of complete (or near complete)
    separation in the training data, or to deliberately shrink the ln(LR) output toward 0.
    An explanation is provided in:
        Morrison G.S., Poh N. (2018). Avoiding overstating the strength of forensic evidence:
        Shrunk likelihood ratios / Bayes factors. Science & Justice, 58, 200–218.
        https://doi.org/10.1016/j.scijus.2017.12.005

    A maximally uninformative uniform distribution is added like an uninformative Bayesian prior.
    The strength of the "prior" distributions is controlled by 'kappa' which is scaled in pseudodata points,
    e.g., setting 'kappa' to 5 would mean that the strength of the prior is equivalent to 5 data points.
    Heuristic: use kappa <= 0.1 to avoid numerical problems, kappa >= 1 to induce shrinkage.
    For no regularization, set 'kappa' to 0.

    'df' is pseudo degrees of freedom, for example, if 100 scores were created by comparing recordings from
    10 speakers 'df' could be set to 10.
    If 'df' is not set, the default value is the total number of (same-source and different-source) scores
    used for training (100 in the example above).
    If 'kappa' is 0, the value of 'df' is irrelevant.

    Code is adapted from Geoffrey Stewart Morrison's Matlab code 'train_llr_fusion_regularized.m'
    at https://geoff-morrison.net/#shrunk_LRs, which was adapted from Niko Brümmer's 'train_llr_fusion.m'
    function in his FoCal toolkit https://sites.google.com/site/nikobrummer/focal, which was adapted from
    Tom Minka's 'train_cg.m' at https://tminka.github.io/papers/logreg/.

    Notes on degenerate input: an empty 'tar' or 'non' makes the class proportion 0 or 1, so the class
    re-weighting divides by zero. If the optimisation reaches a point with zero curvature along the search
    direction (e.g. all sigmoids saturated under complete separation with kappa == 0), training stops
    and the last weights are returned instead of continuing with NaN values.

    :param tar:      scores from same-source (target) comparisons; array whose shape[0] is the number of scores
    :param non:      scores from different-source (non-target) comparisons; same layout as 'tar'
    :param prior:    prior probability of same-source
    :param kappa:    controls strength of the "prior" distributions (see explanation above)
    :param df:       pseudo degrees of freedom (see explanation above)
    :param max_iter: maximum iterations of training
    :return:         1-D array of trained weights; the last element multiplies the appended row of ones
                     (the bias), matching the layout expected by 'lin_fusion'
    """
    nt = tar.shape[0]  # number of same-source scores
    nn = non.shape[0]  # number of different-source scores
    prop = nt / (nn+nt)  # proportion of same-source scores

    # Per-score weights rebalance the two classes so that their total weight
    # reflects 'prior' rather than the proportion found in the training data.
    weights = np.hstack(((prior/prop) * np.ones(nt), ((1-prior)/(1-prop))*np.ones(nn)))
    # Each column is [score; 1]. Non-target columns (and offsets) are negated
    # so that one sigmoid expression serves both classes.
    x = np.hstack(
        (np.vstack((tar, np.ones(nt),)),
        -np.vstack((non, np.ones(nn),))
        )
    )
    offset = logit(prior) * np.hstack((np.ones(nt), -np.ones(nn)))

    if kappa != 0.0:
        if df is None:
            df = nn + nt  # default: one degree of freedom per score

        # Pseudo-data: every training point is added once more with its own
        # sign and once with the opposite sign, each with the same small
        # weight, which pulls the solution toward ln(LR) = 0.
        weight_flat_prior = kappa / (2*df)
        weights = np.hstack((weights, weight_flat_prior*np.ones((nt+nn)*2)))
        x = np.hstack((x, x, -x))
        offset = np.hstack((offset, offset, -offset))

    weights = weights[np.newaxis,]  # row vector, shape (1, number of columns of x)

    w = np.zeros(x.shape[0])
    old_g = np.zeros_like(w)
    u = None
    # np.arange (not range) is kept so a non-integer max_iter behaves as before.
    for iteration in np.arange(max_iter):
        old_w = w
        s1 = 1. / (1. + np.exp(w.T.dot(x) + offset))
        g = x.dot((s1 * weights).flatten())  # weighted gradient
        # First step follows the gradient; later steps use conjugate directions.
        u = g if iteration == 0 else cg_dir(u, g, old_g)

        ug = u.T.dot(g)
        ux = u.T.dot(x)
        a = weights * s1 * (1 - s1)
        curvature = (ux**2).dot(a.T)
        if np.any(curvature == 0):
            # Zero search direction or fully saturated sigmoids: the step
            # would be a division by zero, so no further progress is possible.
            break
        w = w + (ug / curvature) * u
        old_g = g
        if not np.all(np.isfinite(w)):
            # Non-finite weights can never satisfy the convergence test below;
            # stop now rather than iterating to max_iter.
            break
        if np.max(abs(w - old_w)) < 1e-5:
            break
    return w


def cg_dir(old_dir, grad, old_grad):
    """Return the next conjugate-gradient search direction (Hestenes-Stiefel formula).

    With ``delta = grad - old_grad`` the direction is
    ``grad - beta * old_dir`` where ``beta = (grad . delta) / (old_dir . delta)``
    is a scalar. If the denominator is zero, a zero direction with the shape
    of 'grad' is returned.

    :param old_dir:  search direction used in the previous iteration
    :param grad:     gradient at the current point
    :param old_grad: gradient at the previous point
    :return:         new search direction
    """
    flat_grad = grad.flatten()
    delta = flat_grad - old_grad.flatten()
    den = old_dir.T.dot(delta)
    if den == 0:
        return grad * 0
    # beta must be the inner product of the gradient with delta (a scalar);
    # an element-wise product here would scale each coordinate differently.
    beta = flat_grad.dot(delta) / den
    return grad - beta * old_dir


def logit(p):
    """Return the logit (log-odds) of the input, ln(p / (1 - p)).

    :param p: probability (scalar or array), expected to lie strictly between
              0 and 1; p == 0 yields -inf and p == 1 divides by zero.
    :return:  the natural logarithm of the odds p / (1 - p)
    """
    odds = p / (1 - p)
    return np.log(odds)
