"""Implements the `EM algorithm`_ for the mixture model.
Using the class :py:class:`.models.LymphMixture` and its methods, this module provides
functions to compute the expectation and maximization steps of the `EM algorithm`_.
.. _EM algorithm: https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm
"""
import logging
import os
from collections.abc import Callable, Sequence
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import current_process, Pool
import copy
import emcee
import numpy as np
import pandas as pd
from scipy import optimize as opt
from lymixture import models, utils
logger = logging.getLogger(__name__)
RNG = np.random.default_rng(seed=42)
"""Random number generator for reproducibility."""
def _get_params(model: models.LymphMixture) -> np.ndarray:
"""Return the params of ``model``.
.. seealso::
This function is very similar to the :py:meth:`.models.LymphMixture.get_params`
method. Except that it does not accept a dictionary of parameters, but only a 1D
array.
"""
params = []
if model.universal_p:
for comp in model.components:
params += list(comp.get_spread_params(as_dict=False))
params += list(model.get_distribution_params(as_dict=False))
else:
for comp in model.components:
params += list(comp.get_params(as_dict=False))
return params
def _set_params(model: models.LymphMixture, params: np.ndarray) -> None:
"""Set the params of ``model`` from ``params``.
.. seealso::
This function is very similar to the :py:meth:`.models.LymphMixture.set_params`
method. Except that it does not accept a dictionary of parameters, but only a 1D
array.
"""
if model.universal_p:
for comp in model.components:
params = comp.set_spread_params(*params)
params = np.array(model.set_distribution_params(*params))
else:
for comp in model.components:
params = comp.set_params(*params)
params = np.array(params)
def _is_in_parallel_context() -> bool:
"""Check if we're already running in a parallel worker process.
Returns True if:
- Running in a multiprocessing worker (process name is not MainProcess)
- SLURM_JOB_ID environment variable is set (running in a SLURM job)
- Any other common parallelization indicators
"""
# Check if we're in a multiprocessing worker
if current_process().name != 'MainProcess':
return True
# Check for SLURM environment
if 'SLURM_JOB_ID' in os.environ:
return True
# Check for common parallel environment variables
parallel_env_vars = ['PBS_JOBID', 'LSB_JOBID', 'JOB_ID']
if any(var in os.environ for var in parallel_env_vars):
return True
return False
def _optimize_single_component(args: tuple) -> tuple[int, np.ndarray]:
"""Optimize a single component. Used for parallel execution.
Args:
args: Tuple of (component_index, component_params, model, num_components)
Returns:
Tuple of (component_index, optimized_parameters)
"""
i, current_params, model, _, method = args
lb = np.zeros(shape=len(current_params))
ub = np.ones(shape=len(current_params))
result = opt.minimize(
fun=_neg_complete_component_llh,
args=(model, i),
x0=current_params,
bounds=opt.Bounds(lb=lb, ub=ub),
method=method,
callback=init_callback(),
)
if result.success:
return i, result.x
else:
msg = f"Optimization failed for component {i}: {result}"
raise RuntimeError(msg)
[docs]
def expectation(
model: models.LymphMixture,
params: dict[str, float],
*,
log: bool = False,
) -> np.ndarray:
"""Compute expected value of latent ``model`` variables given the ``params``.
This marks the E-step of the famous `EM algorithm`_. The returned expected values
are also often called responsibilities.
If ``log`` is set to ``True``, the function returns the logarithm of the
responsibilities.
.. _EM algorithm: https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm
"""
model.set_params(**params)
llhs = model.patient_mixture_likelihoods(log=log, marginalize=False)
if log:
return utils.log_normalize(llhs.T, axis=0).T
return utils.normalize(llhs.T, axis=0).T
[docs]
def init_callback() -> Callable:
"""Return a function that logs the optimization progress."""
iteration = 0
def log_optimization(xk) -> None: # noqa: ANN001
nonlocal iteration
logger.debug(f"Iteration {iteration} with params: {xk}")
iteration += 1
return log_optimization
def _neg_complete_component_llh(
params: np.ndarray,
model: models.LymphMixture,
component: int,
) -> float:
"""Return the negative complete log likelihood of ``component`` in ``model``.
This function is used in the M-step of the EM algorithm.
"""
# print(params)
try:
if model.split_midext:
filtered = {k: v for k, v in model.components[component].get_params().items() if k != 'midext_prob'}
param_names = list(filtered.keys())
params_dict_new = dict(zip(param_names, params))
model.components[component].set_params(**params_dict_new)
else:
model.components[component].set_params(*params)
except ValueError:
return np.inf
result = -model.complete_data_likelihood(component=component)
logger.debug(f"Component {component} with params {params} has llh {result}")
return result
def _neg_complete_component_llh_shared(
params: dict[str, float],
model: models.LymphMixture,
) -> float:
"""Return the negative complete log likelihood of ``component`` in ``model``.
This function is used in the M-step of the EM algorithm.
"""
try:
params_dict = model.get_params(model_params_only=True)
params_dict_new = dict(zip(params_dict.keys(), params))
model.set_params(**params_dict_new)
except ValueError:
return np.inf
result = -model.complete_data_likelihood()
logger.debug(f"Mixture model with params {params} has llh {result}")
return result
[docs]
def maximization(
model: models.LymphMixture,
log_resps: np.ndarray,
parallelize: bool = True,
method: str = "Powell",
) -> dict[str, float]:
"""Maximize ``model`` params given expectation of ``latent`` variables.
This is the corresponding M-step to the :py:func:`.expectation` of the
`EM algorithm`_. It first maximizes the mixture coefficients analytically and then
optimizes the model parameters of all components sequentially.
.. _EM algorithm: https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm
"""
log_maxed_mix_coefs = model.infer_mixture_coefs(new_resps=log_resps, log=True)
log_maxed_mix_coefs = utils.log_normalize(log_maxed_mix_coefs, axis=0)
model.set_mixture_coefs(np.exp(log_maxed_mix_coefs))
if model.shared_transmission or model.universal_p:
current_params = list(model.get_params(as_dict=False, model_params_only=True))
lb = np.zeros(shape=len(current_params))
ub = np.ones(shape=len(current_params))
result = opt.minimize(
fun=_neg_complete_component_llh_shared,
args=(model),
x0=current_params,
bounds=opt.Bounds(lb=lb, ub=ub),
method=method,
callback=init_callback(),
)
if result.success:
params_dict = model.get_params(model_params_only=True)
params_dict_new = dict(zip(params_dict.keys(), result.x))
model.set_params(**params_dict_new)
else:
msg = f"Optimization failed: {result}"
raise RuntimeError(msg)
else:
# Check if we should parallelize component optimization
# Only parallelize if: (1) we have multiple components AND (2) not already in parallel context
num_components = len(model.components)
use_parallel = num_components > 1 and not _is_in_parallel_context() and parallelize
if use_parallel:
# Parallel optimization of components
logger.debug(f"Parallelizing {num_components} component optimizations")
optimization_args = []
for i, component in enumerate(model.components):
if model.split_midext:
filtered = {k: v for k, v in component.get_params().items() if k != 'midext_prob'}
current_params = list(filtered.values())
else:
current_params = list(component.get_params(as_dict=False))
optimization_args.append((i, current_params, copy.deepcopy(model), num_components, method))
# Use ProcessPoolExecutor to parallelize
with ProcessPoolExecutor(max_workers=num_components) as executor:
results = list(executor.map(_optimize_single_component, optimization_args))
# Set optimized parameters back to components
for i, optimized_params in results:
if model.split_midext:
filtered = {k: v for k, v in model.components[i].get_params().items() if k != 'midext_prob'}
param_names = list(filtered.keys())
params_dict_new = dict(zip(param_names, optimized_params))
model.components[i].set_params(**params_dict_new)
else:
model.components[i].set_params(*optimized_params)
else:
if _is_in_parallel_context() and parallelize:
logger.debug("Already in parallel context, using sequential component optimization")
for i, component in enumerate(model.components):
if model.split_midext:
filtered = {k: v for k, v in component.get_params().items() if k != 'midext_prob'}
current_params = list(filtered.values())
else:
current_params = list(component.get_params(as_dict=False))
lb = np.zeros(shape=len(current_params))
ub = np.ones(shape=len(current_params))
result = opt.minimize(
fun=_neg_complete_component_llh,
args=(model, i),
x0=current_params,
bounds=opt.Bounds(lb=lb, ub=ub),
method=method,
callback=init_callback(),
)
if result.success:
if model.split_midext:
filtered = {k: v for k, v in component.get_params().items() if k != 'midext_prob'}
param_names = list(filtered.keys())
params_dict_new = dict(zip(param_names, result.x))
component.set_params(**params_dict_new)
else:
component.set_params(*result.x)
else:
msg = f"Optimization failed: {result}"
raise RuntimeError(msg)
return model.get_params(as_dict=True)
[docs]
def log_prob_fn_fixed_mixture(
theta: Sequence[float],
model: models.LymphMixture,
) -> float:
"""Compute the model's log-prob, given its params, excluding mixture coefficients.
This function calculates the log-probability of a mixture ``model`` based on the
provided parameters (``theta``), assuming that mixture coefficients remain fixed.
It ensures that the parameter values are within the valid range [0, 1], and returns
negative infinity (``-inf``) if any parameter is out of bounds.
Returns:
float: The log-probability of the model if parameters are valid, or ``-inf`` if
parameters are out of bounds.
.. note::
- This function does not modify or include mixture coefficients in ``theta``;
these are assumed to remain unchanged.
- The `_set_params` function is used to update the model parameters before
computing the likelihood.
"""
lower_bounds = np.zeros(len(theta))
upper_bounds = np.ones(len(theta))
# Check if the parameters are within bounds
if np.any(theta < lower_bounds) or np.any(theta > upper_bounds):
return -np.inf # Return -infinity if out of bounds
_set_params(model, theta)
return model.likelihood(log=True, use_complete=False)
[docs]
def log_prob_fn(theta: Sequence[float], model: models.LymphMixture) -> float:
"""Compute the log-probability of the model given its parameters.
This function returns the log-probability of the provided mixture ``model`` based
on the given parameter values (``theta``). It ensures that parameters stay within
predefined bounds (0 to 1). If any parameter is out of bounds, the function
returns negative infinity (``-inf``).
.. note::
The `theta` array includes mixture parameters, which are not sampled from a
simplex. This behavior could be extended to enforce simplex constraints if
required.
"""
lower_bounds = np.zeros(len(theta))
upper_bounds = np.ones(len(theta))
# Check if the parameters are within bounds
if np.any(theta < lower_bounds) or np.any(theta > upper_bounds):
return -np.inf # Return -infinity if out of bounds
model.set_params(*theta)
return model.likelihood(log=True, use_complete=True)
[docs]
def sample_fixed_mixture(
model: models.LymphMixture,
steps: int = 100,
latent: pd.DataFrame | None = None,
filename: str = "chain_fixed_mix.hdf5",
*,
continue_sampling: bool = False,
) -> tuple[emcee.backends.HDFBackend, np.ndarray]:
"""Sample the parameters of a mixture model, excluding mixture coefficients.
This function performs MCMC sampling for the parameters of a mixture ``model`` while
keeping the mixture coefficients fixed. It allows the specification of ``latent``
parameters and offers options to either start a new sampling session or
``continue_sampling`` from an existing HDF5 backend file (named ``filename``).
.. note::
- The model's responsibilities (``resps``) and mixture coefficients are updated
based on the provided or computed latent parameters.
- Mixture coefficients are fixed during the sampling process.
- The function initializes an :py:class:`emcee.EnsembleSampler` with a fixed
mixture coefficient log-probability function (``log_prob_fn_fixed_mixture``)
and uses multiprocessing to parallelize sampling.
"""
if latent is None:
latent = model.get_resps()
model.set_resps(latent)
maximized_mixture_coefs = model.infer_mixture_coefs(new_resps=latent)
model.set_mixture_coefs(maximized_mixture_coefs)
current_params = _get_params(model)
ndim = len(current_params)
nwalkers = 5 * ndim
perturbation = 1e-6 * RNG.randn(nwalkers, ndim)
backend = emcee.backends.HDFBackend(filename)
if continue_sampling is False:
starting_points = np.ones((nwalkers, ndim)) * current_params + perturbation
backend.reset(nwalkers, ndim)
else:
starting_points = None
# Pass model as an additional argument to log_prob_fn
with Pool() as pool:
logger.info(f"Number of cores used by the sampler: {pool._processes}") # noqa: SLF001
original_sampler = emcee.EnsembleSampler(
nwalkers,
ndim,
log_prob_fn_fixed_mixture,
args=(model,), # Pass model here
pool=pool,
backend=backend,
)
original_sampler.run_mcmc(
initial_state=starting_points,
nsteps=steps,
progress=True,
)
return backend, original_sampler.get_chain(discard=0, thin=1, flat=True)
[docs]
def sample_model_params(
model: models.LymphMixture,
steps: int = 100,
latent: pd.DataFrame | None = None,
filename: str = "chain_fixed_latent.hdf5",
*,
continue_sampling: bool = False,
) -> tuple[emcee.backends.HDFBackend, np.ndarray]:
"""Sample the parameters of a mixture model given expectations of latent variables.
This function performs Markov Chain Monte Carlo (MCMC) sampling of the parameters
of a provided mixture ``model``. It allows setting ``latent`` parameters and
provides options to either start sampling from scratch or ``continue_sampling``
from a previous state stored in an HDF5 file named ``filename``.
.. note::
- The model's responsibilities (``resps``) and mixture coefficients are
updated based on the provided or computed latent parameters.
- The function initializes an `emcee.EnsembleSampler` for MCMC sampling and
uses a multiprocessing pool to parallelize the computations.
"""
latent = latent or model.get_resps()
model.set_resps(latent)
model.set_mixture_coefs(model.infer_mixture_coefs())
current_params = list(model.get_params(as_dict=False))
ndim = len(current_params)
nwalkers = 5 * ndim
perturbation = 1e-6 * abs(RNG.randn(nwalkers, ndim))
backend = emcee.backends.HDFBackend(filename)
if continue_sampling is False:
starting_points = np.ones((nwalkers, ndim)) * current_params + perturbation
starting_points[starting_points > 1] = 1 - perturbation[starting_points > 1]
backend.reset(nwalkers, ndim)
else:
starting_points = None
with Pool() as pool:
original_sampler = emcee.EnsembleSampler(
nwalkers,
ndim,
log_prob_fn,
args=(model,),
pool=pool,
backend=backend,
)
original_sampler.run_mcmc(
initial_state=starting_points,
nsteps=steps,
progress=True,
)
return backend, original_sampler.get_chain(discard=0, thin=1, flat=True)
[docs]
def complete_latent_likelihood(
theta: Sequence[float],
model: models.LymphMixture,
) -> float:
"""Compute the complete data log-llh for mixture ``model``, given latent variables.
This function evaluates the log-likelihood of the mixture ``model`` using a
provided set of latent variable assignments (``theta``). The assignments are set
as the responsibilities (``resps``) of the model before computing the likelihood.
"""
resps = model.get_resps()
sampled_df = pd.DataFrame(theta, index=resps.index, columns=resps.columns)
model.set_resps(sampled_df)
return model.likelihood(log=True, use_complete=True)
[docs]
def mh_latent_sampler_per_patient_2_component(
model: models.LymphMixture,
temp: float | None = None,
) -> tuple[pd.DataFrame, float]:
"""Perform Metropolis-Hastings for latent variables per-patient for 2 components.
This function implements a basic Metropolis-Hastings (MH) sampler to update the
latent variables (responsibilities) of a mixture ``model`` for individual patients.
It swaps the latent variable assignments for two components, evaluates the
log-acceptance ratio, and accepts or rejects the proposed changes based on the
Metropolis criterion.
It returns the latent variable responsibilities before the sampling step and the
log-probability of the model before the sampling step.
.. note::
- The sampler works by proposing a swap of responsibilities between two
components for each patient and calculating the acceptance ratio using
the patient-specific mixture likelihoods.
- Accepted swaps are updated in the latent variable matrix under
the header ``accepted_position``.
- The current and new log-probabilities are computed using
the provided ``log_prob_fn``.
- This function is designed for a full AIP algorithm but is not used
due to long computation times.
"""
temp = temp or 0.5
current_position = model.get_resps()
new_position = current_position.copy()
accepted_position = current_position.copy()
current_log_prob = complete_latent_likelihood(current_position, model)
new_position.iloc[:, [0, 1]] = current_position.iloc[:, [1, 0]].to_numpy()
current_assignments = np.argmax(np.array(current_position), axis=1)
new_assignments = np.argmax(np.array(new_position), axis=1)
log_acceptance_ratio = (
model.patient_mixture_likelihoods(log=True)[
np.arange(len(new_assignments)),
new_assignments,
]
- model.patient_mixture_likelihoods(log=True)[
np.arange(len(current_assignments)),
current_assignments,
]
) / temp
accept_ratio = np.exp(log_acceptance_ratio)
accept_thresholds = RNG.rand(len(accept_ratio))
accepted_indices = np.where(accept_thresholds < accept_ratio)[0]
accepted_position.iloc[accepted_indices, [0, 1]] = current_position.iloc[
accepted_indices,
[1, 0],
].to_numpy()
model.set_resps(accepted_position)
logger.info(f"{len(accepted_indices)} swaps accepted")
return current_position, current_log_prob
[docs]
def aip_sampling_algorithm(
model: models.LymphMixture,
ip_rounds: int = 4000,
n_steps_params: int = 1,
temperature_schedule: Callable[[int], float] | None = None,
params_filename: str = "../../params_samples.hdf5",
) -> dict[str, list]:
"""Perform Alternating Iterative Posterior (AIP) sampling for a mixture model.
This function alternates between sampling latent variables and ``model`` parameters
to approximate the posterior distribution of a mixture model. The AIP algorithm
integrates Metropolis-Hastings (MH) sampling for latent variables and a parameter
sampler initialized with ``emcee``. This is computationally intensive and may take
a long time to converge and is therefore only used for toy problems.
Returns:
A dictionary containing:
- "params_samples" (list): Samples of model parameters.
- "latent_samples" (list): Samples of latent variables.
- "complete_likelihoods" (list): Complete data log-llhs across iterations.
- "incomplete_likelihoods" (list): Incomplete data log-llhs across iterations.
- "number_of_swaps" (list): Number of swaps in latent variables btw. iterations.
"""
# Initialization
n_dim_params = len(model.get_params())
# Lists to store results
params_samples = []
latent_samples = []
complete_likelihoods = []
incomplete_likelihoods = []
number_of_swaps = []
# Initialize latent variables
starting_latent = model.get_resps()
starting_latent.iloc[:, 0] = RNG.choice([0, 1], len(starting_latent))
starting_latent.iloc[:, 1] = 1 - starting_latent.iloc[:, 0]
model.set_resps(starting_latent)
# Initialize parameter sampler
backend_params, params_samples = sample_model_params(
model,
steps=1,
filename=params_filename,
continue_sampling=False,
)
# Initial samples
latent_samples.append(model.get_resps())
params_samples.append(model.get_params(as_dict=False))
for ip_round in range(ip_rounds):
# Determine temperature
if temperature_schedule is None:
temperature = 1 - ip_round / ip_rounds + 0.05
else:
temperature = temperature_schedule(ip_round)
# Latent sampling
new_latent, current_prob = mh_latent_sampler_per_patient_2_component(
model,
temperature,
)
latent_samples.append(new_latent)
# Parameter sampling
backend_params, params_samples = sample_model_params(
model,
steps=n_steps_params,
filename=params_filename,
continue_sampling=True,
)
new_params_samples = backend_params.get_chain(discard=0, thin=1, flat=False)
# Extract the last parameter sample and update
samples_flat = new_params_samples.reshape(-1, n_dim_params) # Flatten correctly
params_samples.append(samples_flat[-1])
model.set_params(*samples_flat[-1])
# Compute likelihoods for diagnostics
complete_likelihoods.append(model.likelihood(use_complete=True))
incomplete_likelihoods.append(model.likelihood(use_complete=False))
if ip_round != 0:
number_of_swaps.append(
abs(latent_samples[-1] - latent_samples[-2]).sum().sum() / 2,
)
logger.debug(
f"Complete likelihood: {complete_likelihoods[-1]}, "
f"Incomplete likelihood: {incomplete_likelihoods[-1]}",
)
return {
"params_samples": params_samples,
"latent_samples": latent_samples,
"complete_likelihoods": complete_likelihoods,
"incomplete_likelihoods": incomplete_likelihoods,
"number_of_swaps": number_of_swaps,
}