Inference Using the Expectation-Maximization Algorithm

Inference Using the Expectation-Maximization Algorithm#

In this notebook we demonstrate how to train a mixture of lymphatic progression models. We do this for a simple set of synthetic data and check whether, and how well, we can recover the original parameters that we set.

Imports#

from collections import namedtuple
from typing import Literal, Any, Callable

import numpy as np
import pandas as pd

from lymixture import LymphMixture
from lymixture import utils
from lymph.models import Unilateral

# Seeded random generator reused by the cells below, so that the synthetic
# data and the random starting point of the inference are reproducible.
rng = np.random.default_rng(42)

Synthetic Data#

Define the parameters and configuration used to draw a set of synthetic patients.

# A diagnostic modality is described by its specificity and sensitivity.
Modality = namedtuple("Modality", ["spec", "sens"])

# definition of the directed acyclic graph
GRAPH_DICT = {
    ("tumor", "T"): ["II", "III"],
    ("lnl", "II"): ["III"],
    ("lnl", "III"): [],
}
# definition of the diagnostic modality
MODALITIES = {
    "path": Modality(spec=0.9, sens=0.9),
}
# assumed distributions over the time to diagnosis
DISTRIBUTIONS = {
    "early": utils.binom_pmf(k=np.arange(11), n=10, p=0.3),
    "late": utils.late_binomial,
}

# params of component 1
PARAMS_C1 = {
    "TtoII_spread": 0.05,
    "TtoIII_spread": 0.25,
    "IItoIII_spread": 0.5,
    "late_p": 0.5,
}
# params of component 2
PARAMS_C2 = {
    "TtoII_spread": 0.25,
    "TtoIII_spread": 0.05,
    "IItoIII_spread": 0.1,
    "late_p": 0.5,
}
# column (three-level key) that stores which subgroup a patient belongs to
SUBSITE_COL = ("tumor", "1", "subsite")
# Modalities are handed around as `Modality` tuples keyed by their name (see
# `MODALITIES`); consumers read their `.spec` and `.sens` attributes.
ModalityDict = dict[str, Modality]

def create_model(
    model_kwargs: dict[str, Any] | None = None,
    modalities: ModalityDict | None = None,
    distributions: dict[str, list[float] | Callable] | None = None,
) -> Unilateral:
    """Build the unilateral model from which synthetic patients are drawn.

    Every argument is optional. A missing (or empty) argument falls back to
    the module-level default: ``GRAPH_DICT`` as the model's ``graph_dict``,
    ``MODALITIES`` for the diagnostic modalities, and ``DISTRIBUTIONS`` for
    the per-T-stage diagnosis-time distributions.

    Each modality value must expose ``spec`` and ``sens`` attributes.
    """
    kwargs = model_kwargs or {"graph_dict": GRAPH_DICT}
    chosen_modalities = modalities or MODALITIES
    chosen_distributions = distributions or DISTRIBUTIONS

    model = Unilateral(**kwargs)

    for name, modality in chosen_modalities.items():
        model.set_modality(name, modality.spec, modality.sens)

    for t_stage, dist in chosen_distributions.items():
        model.set_distribution(t_stage, dist)

    return model
def draw_datasets(
    model: Unilateral,
    num_c1: int,
    num_c2: int,
    num_c3: int,
    tstage_ratio: float,
    mix: float,
    rng: np.random.Generator,
) -> pd.DataFrame:
    """Draw patients for the three datasets "c1", "c2" and "c3".

    "c1" and "c2" are drawn from ``model`` with ``PARAMS_C1`` and
    ``PARAMS_C2`` respectively. "c3" is a blend: a share ``mix`` of its
    patients is drawn with ``PARAMS_C1`` and the rest with ``PARAMS_C2``.

    Args:
        model: Model to draw from. Its parameters are overwritten, and it is
            left with ``PARAMS_C2`` set when the function returns.
        num_c1: Number of patients in subgroup "c1".
        num_c2: Number of patients in subgroup "c2".
        num_c3: Number of patients in subgroup "c3".
        tstage_ratio: First entry of the ``stage_dist`` handed to
            ``draw_patients``; the second entry is ``1 - tstage_ratio``.
            Presumably the share of "early" patients -- confirm against the
            order of the model's T-stage distributions.
        mix: Fraction of "c3" patients drawn with ``PARAMS_C1``; must lie in
            the interval [0, 1].
        rng: Random generator forwarded to ``draw_patients``.

    Returns:
        All patients in one frame with a fresh integer index, ordered c1, c2,
        c3, and labelled by subgroup in the ``SUBSITE_COL`` column.

    Raises:
        ValueError: If ``mix`` is outside [0, 1].
    """
    if not 0.0 <= mix <= 1.0:
        raise ValueError(f"mix must be within [0, 1], got {mix}")

    # Derive the second share from the first so both always add up to exactly
    # num_c3; truncating each share separately could drop a patient.
    num_c3_from_c1 = int(round(num_c3 * mix))
    num_c3_from_c2 = num_c3 - num_c3_from_c1
    stage_dist = [tstage_ratio, 1 - tstage_ratio]

    model.set_params(**PARAMS_C1)
    c1_drawn = model.draw_patients(
        num=num_c1 + num_c3_from_c1,
        stage_dist=stage_dist,
        rng=rng,
    )
    model.set_params(**PARAMS_C2)
    c2_drawn = model.draw_patients(
        num=num_c2 + num_c3_from_c2,
        stage_dist=stage_dist,
        rng=rng,
    )

    # The surplus rows of both draws together form the mixed subgroup c3.
    c3_data = pd.concat(
        [
            c1_drawn.iloc[num_c1:],
            c2_drawn.iloc[num_c2:],
        ],
        ignore_index=True,
        axis=0,
    )
    # Copy the slices so the label assignment below writes to independent
    # frames instead of views of the drawn data.
    c1_data = c1_drawn.iloc[:num_c1].copy()
    c2_data = c2_drawn.iloc[:num_c2].copy()

    c1_data[SUBSITE_COL] = "c1"
    c2_data[SUBSITE_COL] = "c2"
    c3_data[SUBSITE_COL] = "c3"

    return pd.concat([c1_data, c2_data, c3_data], ignore_index=True, axis=0)
# Draw 1000 patients per subgroup; "c3" is an even blend of both components.
model = create_model()
synthetic_data = draw_datasets(
    model=model,
    num_c1=1000,
    num_c2=1000,
    num_c3=1000,
    tstage_ratio=0.4,
    mix=0.5,
    rng=rng,
)

# Show six random patients. The indices are sampled from the index *labels*,
# so the rows are selected by label (`.loc`) rather than by position.
random_idx = rng.choice(synthetic_data.index, size=6, replace=False)
synthetic_data.loc[random_idx]
path tumor
ipsi core 1
II III t_stage subsite
737 False True early c1
2617 True True late c3
111 False False early c1
1466 False True early c2
2177 True True early c3
429 True True late c1

Model Initialization#

Now, we define the mixture model and load the data we just drew. Note that we use only two components, hoping that the "c3" subgroup can be described as a mixture of these two components.

# Reuse the graph and the subsite column from the data generation, so the
# inference setup cannot drift away from the setup that produced the data.
graph = GRAPH_DICT
num_components = 2

mixture = LymphMixture(
    model_cls=Unilateral,
    model_kwargs={"graph_dict": graph},
    num_components=num_components,
    universal_p=False,
)
# Split the patients into subgroups by their subsite label; the identity
# mapping keeps the labels ("c1", "c2", "c3") unchanged.
mixture.load_patient_data(
    synthetic_data,
    split_by=SUBSITE_COL,
    mapping=lambda x: x,
)

Set the diagnostic modality to be the same as in the generated dataset.

# Give the mixture the same diagnostic modalities that generated the data.
for name, (spec, sens) in MODALITIES.items():
    mixture.set_modality(name=name, spec=spec, sens=sens)

Fix the distribution over diagnosis times. Again, we set this to be the same as during the synthetic data generation.

# Reuse the diagnosis-time distributions of the data-generating model.
for t_stage in DISTRIBUTIONS:
    mixture.set_distribution(t_stage, DISTRIBUTIONS[t_stage])

Inference#

from lymixture.em import expectation, maximization

The iterative steps of computing the expectation over the latent variables (E-step) and maximizing the model parameters (M-step) can be initialized with an arbitrary set of starting parameters.

# Random starting point: one uniform draw per parameter the mixture reports.
params = {}
for key in mixture.get_params():
    params[key] = rng.uniform()
mixture.set_params(**params)
mixture.normalize_mixture_coefs()

# Random initial latent variables with the same shape as the mixture's
# responsibilities; presumably normalized so each row sums to one -- confirm
# against `utils.normalize`.
resps_shape = mixture.get_resps().shape
random_resps = rng.uniform(size=resps_shape)
latent = utils.normalize(random_resps.T, axis=0).T

Then we define a function to check the convergence of the algorithm.

def is_converged(
    history: list[dict[str, float]],
    rtol: float = 1e-4,
) -> bool:
    """Check if the EM algorithm has converged.

    Convergence is declared when the ``"llh"`` values of the two most recent
    entries in ``history`` agree within the relative tolerance ``rtol``.
    With fewer than two entries there is nothing to compare, so the function
    returns ``False``.

    Args:
        history: Snapshots of the iterations, oldest first. Each snapshot
            must contain the key ``"llh"``.
        rtol: Relative tolerance passed on to ``np.isclose``.

    Returns:
        ``True`` if the last two ``"llh"`` values are close, else ``False``.
    """
    if len(history) < 2:
        return False

    old, new = history[-2]["llh"], history[-1]["llh"]
    # np.isclose yields a NumPy boolean; convert it so that the function
    # really returns the built-in `bool` its signature promises.
    return bool(np.isclose(old, new, rtol=rtol))

Finally, we can iterate the computation of the expectation value of the latent variables (E-step) and the maximization of the (complete) data log-likelihood w.r.t. the model parameters (M-step).

While the algorithm runs, we record the incomplete-data log-likelihood after each iteration.

# Upper bound on the number of EM rounds, so that a likelihood which never
# settles (e.g. it oscillates or turns into NaN) cannot loop forever.
MAX_ITERATIONS = 1000


def _take_snapshot() -> dict[str, float]:
    """Return the mixture's current likelihood and flat parameters."""
    return {
        "llh": mixture.incomplete_data_likelihood(),
        **mixture.get_params(as_dict=True, as_flat=True),
    }


count = 0
snapshot = _take_snapshot()
history = [snapshot]

while not is_converged(history, rtol=1e-4):
    if count >= MAX_ITERATIONS:
        print(f"stopped after {count} iterations without convergence")
        break

    print(f"iteration {count:>3d}: {history[-1]['llh']:.3f}")
    count += 1

    # E-step: expected value of the latent variables given the parameters.
    latent = expectation(mixture, params)
    assert np.allclose(latent.sum(axis=1), 1.)
    # M-step: parameters that maximize the likelihood given the latent
    # variables.
    # NOTE(review): the traceback recorded for this cell shows the installed
    # `maximization` naming its second parameter `log_resps`, while the
    # assertion above treats `latent` as linear-scale values that sum to one.
    # Confirm which scale the installed version expects.
    params = maximization(mixture, latent)

    snapshot = _take_snapshot()
    history.append(snapshot)
iteration   0: -4912.985
---------------------------------------------------------------------------
_RemoteTraceback                          Traceback (most recent call last)
_RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/home/docs/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/process.py", line 205, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/home/docs/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/process.py", line 205, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/home/docs/checkouts/readthedocs.org/user_builds/lymixture/envs/latest/lib/python3.10/site-packages/lymixture/em.py", line 119, in _optimize_single_component
    raise RuntimeError(msg)
RuntimeError: Optimization failed for component 0:  message: NaN result encountered.
 success: False
  status: 3
     fun: nan
       x: [ 3.820e-01  3.820e-01  3.820e-01  3.820e-01]
     nit: 1
   direc: [[ 1.000e+00  0.000e+00  0.000e+00  0.000e+00]
           [ 0.000e+00  1.000e+00  0.000e+00  0.000e+00]
           [ 0.000e+00  0.000e+00  1.000e+00  0.000e+00]
           [ 0.000e+00  0.000e+00  0.000e+00  1.000e+00]]
    nfev: 81
"""

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[12], line 14
     12 latent = expectation(mixture, params)
     13 assert np.allclose(latent.sum(axis=1), 1.)
---> 14 params = maximization(mixture, latent)
     16 snapshot = {
     17     "llh": mixture.incomplete_data_likelihood(),
     18     **mixture.get_params(as_dict=True, as_flat=True),
     19 }
     20 history.append(snapshot)

File ~/checkouts/readthedocs.org/user_builds/lymixture/envs/latest/lib/python3.10/site-packages/lymixture/em.py:262, in maximization(model, log_resps, parallelize, method)
    260 # Use ProcessPoolExecutor to parallelize
    261 with ProcessPoolExecutor(max_workers=num_components) as executor:
--> 262     results = list(executor.map(_optimize_single_component, optimization_args))
    264 # Set optimized parameters back to components
    265 for i, optimized_params in results:

File ~/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/process.py:575, in _chain_from_iterable_of_lists(iterable)
    569 def _chain_from_iterable_of_lists(iterable):
    570     """
    571     Specialized implementation of itertools.chain.from_iterable.
    572     Each item in *iterable* should be a list.  This function is
    573     careful not to keep references to yielded objects.
    574     """
--> 575     for element in iterable:
    576         element.reverse()
    577         while element:

File ~/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/_base.py:621, in Executor.map.<locals>.result_iterator()
    618 while fs:
    619     # Careful not to keep a reference to the popped future
    620     if timeout is None:
--> 621         yield _result_or_cancel(fs.pop())
    622     else:
    623         yield _result_or_cancel(fs.pop(), end_time - time.monotonic())

File ~/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/_base.py:319, in _result_or_cancel(***failed resolving arguments***)
    317 try:
    318     try:
--> 319         return fut.result(timeout)
    320     finally:
    321         fut.cancel()

File ~/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/_base.py:458, in Future.result(self, timeout)
    456     raise CancelledError()
    457 elif self._state == FINISHED:
--> 458     return self.__get_result()
    459 else:
    460     raise TimeoutError()

File ~/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/_base.py:403, in Future.__get_result(self)
    401 if self._exception:
    402     try:
--> 403         raise self._exception
    404     finally:
    405         # Break a reference cycle with the exception in self._exception
    406         self = None

RuntimeError: Optimization failed for component 0:  message: NaN result encountered.
 success: False
  status: 3
     fun: nan
       x: [ 3.820e-01  3.820e-01  3.820e-01  3.820e-01]
     nit: 1
   direc: [[ 1.000e+00  0.000e+00  0.000e+00  0.000e+00]
           [ 0.000e+00  1.000e+00  0.000e+00  0.000e+00]
           [ 0.000e+00  0.000e+00  1.000e+00  0.000e+00]
           [ 0.000e+00  0.000e+00  0.000e+00  1.000e+00]]
    nfev: 81

Results#

After convergence, we can look at how the likelihood and the parameters evolved over the iterations. Ideally, the likelihood never decreases from one iteration to the next.

# One panel for the likelihood and one shared panel for the tumor-to-II
# spread parameter of both components.
spread_columns = ["0_TtoII_spread", "1_TtoII_spread"]
history_df = pd.DataFrame(history)
history_df.plot(
    y=["llh", *spread_columns],
    subplots=[("llh",), tuple(spread_columns)],
    sharex=True,
    xlim=(0, None),
);

And, more importantly, let’s also see if the learned parameters reproduce what we put into the model.

# Prefix the generating parameters with their component index so the keys
# line up with the flat parameter names of the mixture.
# NOTE(review): this pairs component 0 with PARAMS_C1 and component 1 with
# PARAMS_C2; the fit may well converge with the two components swapped.
fixed_params = {
    f"{idx}_{name}": value
    for idx, component_params in enumerate((PARAMS_C1, PARAMS_C2))
    for name, value in component_params.items()
}

learned_params = mixture.get_params(as_dict=True, as_flat=True)

for name in fixed_params:
    fixed, learned = fixed_params[name], learned_params[name]
    print(f"{name:>16s}: {fixed = :.3f}, {learned = :.3f}")

Sample Parameter Distribution#

To get a fully probabilistic distribution over the parameters and mixture coefficients, we can draw some parameter samples and then infer the optimal mixture coefficients for each sample.

from lymixture.em import complete_samples, sample_model_params

samples = sample_model_params(mixture, steps=20)
# Keep a random subset of at most 50 samples. Drawing the indices from the
# notebook's seeded `rng` (instead of NumPy's global state) keeps the
# selection reproducible, and the cap avoids an error when fewer than 50
# samples are available.
num_kept = min(50, len(samples))
indices = rng.choice(len(samples), size=num_kept, replace=False)
reduced_set = samples[indices]
completed_samples = complete_samples(mixture, reduced_set)