# Inference Using the Expectation-Maximization Algorithm
In this notebook we demonstrate how to train a mixture of lymphatic progression models. We do this on a simple synthetic dataset and check whether, and how well, we can recover the parameters that were used to generate it.
## Imports
from collections import namedtuple
from typing import Literal, Any, Callable
import numpy as np
import pandas as pd
from lymixture import LymphMixture
from lymixture import utils
from lymph.models import Unilateral
# Seeded generator so that the synthetic data and the random initialisation
# further down are reproducible between runs.
rng = np.random.default_rng(seed=42)
## Synthetic Data
Define the parameters and configuration used to draw a synthetic dataset.
# Lightweight record holding a diagnostic modality's specificity and sensitivity.
Modality = namedtuple("Modality", "spec sens")

# Directed acyclic graph of the model: the tumor ("T") has edges to levels II
# and III, and level II has an additional edge to level III.
GRAPH_DICT = {
    ("tumor", "T"): ["II", "III"],
    ("lnl", "II"): ["III"],
    ("lnl", "III"): [],
}

# The single diagnostic modality "path", with specificity and sensitivity 0.9.
MODALITIES = {"path": Modality(0.9, 0.9)}
# Assumed distributions over the time to diagnosis, keyed by T-stage:
# "early" is ``utils.binom_pmf`` evaluated at k = 0..10 with n=10 and p=0.3;
# "late" is the callable ``utils.late_binomial`` (presumably parametrised by
# the ``late_p`` model parameter -- TODO confirm).
DISTRIBUTIONS = dict(
    early=utils.binom_pmf(k=np.arange(11), n=10, p=0.3),
    late=utils.late_binomial,
)
# Ground-truth parameters of component 1: weak spread T->II, stronger spread
# T->III and II->III.
PARAMS_C1 = dict(
    TtoII_spread=0.05,
    TtoIII_spread=0.25,
    IItoIII_spread=0.5,
    late_p=0.5,
)
# Ground-truth parameters of component 2: here the T->II spread dominates.
PARAMS_C2 = dict(
    TtoII_spread=0.25,
    TtoIII_spread=0.05,
    IItoIII_spread=0.1,
    late_p=0.5,
)
# Column under which each patient's subgroup label ("c1", "c2", "c3") is stored.
SUBSITE_COL = ("tumor", "1", "subsite")
# Mapping from modality name to its specificity/sensitivity record. The values
# must be ``Modality`` tuples, because ``create_model`` reads ``.spec`` and
# ``.sens`` from them (a plain dict has no such attributes).
ModalityDict = dict[str, Modality]
def create_model(
    model_kwargs: dict[str, Any] | None = None,
    modalities: ModalityDict | None = None,
    distributions: dict[str, list[float] | Callable] | None = None,
) -> Unilateral:
    """Create a model to draw patients from.

    Each argument falls back to its module-level default only when it is
    ``None``; an explicitly passed mapping (even an empty one) is used as is.

    Args:
        model_kwargs: Keyword arguments for ``Unilateral``. Defaults to
            ``{"graph_dict": GRAPH_DICT}``.
        modalities: Diagnostic modalities to set on the model, as ``Modality``
            records. Defaults to ``MODALITIES``.
        distributions: Distributions over diagnosis times, keyed by T-stage.
            Defaults to ``DISTRIBUTIONS``.

    Returns:
        The configured ``Unilateral`` model.
    """
    # Compare against ``None`` rather than relying on truthiness, so that an
    # empty mapping is not silently replaced by the defaults.
    if model_kwargs is None:
        model_kwargs = {"graph_dict": GRAPH_DICT}
    if modalities is None:
        modalities = MODALITIES
    if distributions is None:
        distributions = DISTRIBUTIONS

    model = Unilateral(**model_kwargs)
    for name, modality in modalities.items():
        model.set_modality(name, modality.spec, modality.sens)
    for t_stage, dist in distributions.items():
        model.set_distribution(t_stage, dist)
    return model
def draw_datasets(
    model: Unilateral,
    num_c1: int,
    num_c2: int,
    num_c3: int,
    tstage_ratio: float,
    mix: float,
    rng: np.random.Generator,
) -> pd.DataFrame:
    """Draw patients for the three datasets.

    Subgroup ``"c1"`` is drawn with ``PARAMS_C1``, subgroup ``"c2"`` with
    ``PARAMS_C2``, and subgroup ``"c3"`` is a mix of both: a fraction ``mix``
    of its ``num_c3`` patients comes from ``PARAMS_C1``, the rest from
    ``PARAMS_C2``. Each patient's subgroup label is stored in ``SUBSITE_COL``.

    Note that ``model.set_params`` is called on the passed ``model``, which is
    therefore left with ``PARAMS_C2`` set afterwards.

    Args:
        model: Model used to draw the patients.
        num_c1: Number of patients in subgroup ``"c1"``.
        num_c2: Number of patients in subgroup ``"c2"``.
        num_c3: Number of patients in the mixed subgroup ``"c3"``.
        tstage_ratio: Weight of the first T-stage in ``stage_dist``; the second
            one gets ``1 - tstage_ratio``. Presumably these follow the order of
            the model's distributions ("early", then "late") -- TODO confirm.
        mix: Fraction of ``"c3"`` patients drawn with ``PARAMS_C1``.
        rng: Random number generator passed on to ``draw_patients``.

    Returns:
        All patients of ``"c1"``, ``"c2"`` and ``"c3"`` (in that order) with a
        fresh ``0..N-1`` index.

    Raises:
        ValueError: If ``mix`` is not within ``[0, 1]``.
    """
    if not 0.0 <= mix <= 1.0:
        raise ValueError(f"mix must be within [0, 1], got {mix}")

    stage_dist = [tstage_ratio, 1 - tstage_ratio]
    # Round once and give the remainder to component 2, so that the two parts
    # of "c3" always add up to exactly ``num_c3``. Truncating both products
    # with ``int`` could drop patients (e.g. num_c3=3, mix=0.5 -> 1 + 1).
    num_c3_from_c1 = round(num_c3 * mix)
    num_c3_from_c2 = num_c3 - num_c3_from_c1

    model.set_params(**PARAMS_C1)
    c1_data = model.draw_patients(
        num=num_c1 + num_c3_from_c1,
        stage_dist=stage_dist,
        rng=rng,
    )
    model.set_params(**PARAMS_C2)
    c2_data = model.draw_patients(
        num=num_c2 + num_c3_from_c2,
        stage_dist=stage_dist,
        rng=rng,
    )

    # The surplus rows beyond num_c1 / num_c2 form the mixed subgroup "c3".
    c3_data = pd.concat(
        [c1_data.iloc[num_c1:], c2_data.iloc[num_c2:]],
        ignore_index=True,
        axis=0,
    )
    # ``.copy()`` so that the label assignment below writes to independent
    # frames instead of slices (avoids pandas' SettingWithCopyWarning).
    c1_data = c1_data.iloc[:num_c1].copy()
    c2_data = c2_data.iloc[:num_c2].copy()

    c1_data[SUBSITE_COL] = "c1"
    c2_data[SUBSITE_COL] = "c2"
    c3_data[SUBSITE_COL] = "c3"
    return pd.concat([c1_data, c2_data, c3_data], ignore_index=True, axis=0)
model = create_model()
synthetic_data = draw_datasets(
    model=model,
    num_c1=1000,
    num_c2=1000,
    num_c3=1000,
    tstage_ratio=0.4,
    mix=0.5,
    rng=rng,
)
# Show a few randomly chosen patients. ``rng.choice`` samples index *labels*,
# so look them up with ``.loc``. ``.iloc`` only happened to work because
# ``draw_datasets`` returns a fresh 0..N-1 index.
random_idx = rng.choice(synthetic_data.index, size=6, replace=False)
synthetic_data.loc[random_idx]
| | path, ipsi, II | path, ipsi, III | tumor, core, t_stage | tumor, 1, subsite |
|---|---|---|---|---|
| 737 | False | True | early | c1 |
| 2617 | True | True | late | c3 |
| 111 | False | False | early | c1 |
| 1466 | False | True | early | c2 |
| 2177 | True | True | early | c3 |
| 429 | True | True | late | c1 |
## Model Initialization
Now, we define the mixture model and load the just drawn data. Note that we use only two components, hoping that the "c3" subgroup can be described as a mixture of these two components.
# Same graph as used for the data generation, spelled out again so that the
# mixture model's definition reads on its own.
graph = {
    ("tumor", "T"): ["II", "III"],
    ("lnl", "II"): ["III"],
    ("lnl", "III"): [],
}
num_components = 2
mixture = LymphMixture(
    model_cls=Unilateral,
    model_kwargs={"graph_dict": graph},
    num_components=num_components,
    universal_p=False,
)
# Split the patients by the same column that ``draw_datasets`` wrote the
# "c1"/"c2"/"c3" labels to; reusing ``SUBSITE_COL`` keeps both in sync.
mixture.load_patient_data(
    synthetic_data,
    split_by=SUBSITE_COL,
    mapping=lambda x: x,  # identity: presumably keeps the labels unchanged
)
Set the diagnostic modality to be the same as in the generated dataset.
# ``Modality`` is a (spec, sens) namedtuple, so it can be unpacked directly.
for name, (spec, sens) in MODALITIES.items():
    mixture.set_modality(name=name, spec=spec, sens=sens)
Fix the distribution over diagnosis times. Again, we set this to be the same as during the synthetic data generation.
# Reuse the diagnosis-time distribution of every T-stage from data generation.
for stage_name, time_dist in DISTRIBUTIONS.items():
    mixture.set_distribution(stage_name, time_dist)
## Inference
from lymixture.em import expectation, maximization
The iterative steps of computing the expectation over the latent variables (E-step) and maximizing the model parameters (M-step) can be initialized with an arbitrary set of starting parameters.
# Random starting point: draw every parameter from ``rng.uniform()`` and then
# normalise the mixture coefficients. (A random initial ``latent`` is not
# needed, because the first E-step computes it from these parameters.)
params = {k: rng.uniform() for k in mixture.get_params()}
mixture.set_params(**params)
mixture.normalize_mixture_coefs()
# NOTE(review): ``params`` still holds the *un-normalised* coefficients, and it
# is what the first ``expectation(mixture, params)`` call receives. Confirm that
# this does not undo the normalisation (or pass ``mixture.get_params()``).
Then we define a function to check the convergence of the algorithm.
def is_converged(
    history: list[dict[str, float]],
    rtol: float = 1e-4,
) -> bool:
    """Check if the EM algorithm has converged.

    Convergence is declared once the ``"llh"`` values of the two most recent
    snapshots in ``history`` agree within the relative tolerance ``rtol``
    (compared with ``np.isclose`` and its default absolute tolerance).

    Args:
        history: Snapshots in chronological order, each with an ``"llh"`` key.
        rtol: Relative tolerance for the comparison.

    Returns:
        ``False`` while fewer than two snapshots exist; otherwise whether the
        last two likelihood values are close. Always a plain ``bool``.
    """
    if len(history) < 2:
        return False

    old, new = history[-2]["llh"], history[-1]["llh"]
    # ``np.isclose`` returns ``np.bool_``; cast to honour the ``bool`` annotation.
    return bool(np.isclose(old, new, rtol=rtol))
Finally, we can iterate the computation of the expectation value of the latent variables (E-step) and the maximization of the (complete) data log-likelihood w.r.t. the model parameters (M-step).
Until the algorithm has converged, we record the incomplete-data likelihood (together with the current parameters) after each round.
def take_snapshot(model) -> dict[str, float]:
    """Return the model's incomplete-data likelihood and its flat parameters."""
    return {
        "llh": model.incomplete_data_likelihood(),
        **model.get_params(as_dict=True, as_flat=True),
    }


# Upper bound on EM rounds, so that the loop cannot spin forever if the
# likelihood keeps fluctuating (e.g. once it is NaN, which never compares close).
MAX_ITER = 1_000

count = 0
history = [take_snapshot(mixture)]
while not is_converged(history, rtol=1e-4):
    if count >= MAX_ITER:
        print(f"stopping: no convergence after {MAX_ITER} iterations")
        break
    print(f"iteration {count:>3d}: {history[-1]['llh']:.3f}")
    count += 1
    # E-step: expected values of the latent component assignments, one row
    # per patient; every row must sum to one.
    latent = expectation(mixture, params)
    assert np.allclose(latent.sum(axis=1), 1.)
    # M-step: update the model parameters given ``latent``.
    # NOTE(review): the installed ``lymixture.em.maximization`` (see the
    # traceback in the output) names this argument ``log_resps``, while
    # ``latent`` is in linear space (its rows sum to 1). Confirm which one is
    # expected; the mismatch may be behind the NaN failure shown below.
    params = maximization(mixture, latent)
    history.append(take_snapshot(mixture))
iteration 0: -4912.985
---------------------------------------------------------------------------
_RemoteTraceback Traceback (most recent call last)
_RemoteTraceback:
"""
Traceback (most recent call last):
File "/home/docs/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
File "/home/docs/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/process.py", line 205, in _process_chunk
return [fn(*args) for args in chunk]
File "/home/docs/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/process.py", line 205, in <listcomp>
return [fn(*args) for args in chunk]
File "/home/docs/checkouts/readthedocs.org/user_builds/lymixture/envs/latest/lib/python3.10/site-packages/lymixture/em.py", line 119, in _optimize_single_component
raise RuntimeError(msg)
RuntimeError: Optimization failed for component 0: message: NaN result encountered.
success: False
status: 3
fun: nan
x: [ 3.820e-01 3.820e-01 3.820e-01 3.820e-01]
nit: 1
direc: [[ 1.000e+00 0.000e+00 0.000e+00 0.000e+00]
[ 0.000e+00 1.000e+00 0.000e+00 0.000e+00]
[ 0.000e+00 0.000e+00 1.000e+00 0.000e+00]
[ 0.000e+00 0.000e+00 0.000e+00 1.000e+00]]
nfev: 81
"""
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
Cell In[12], line 14
12 latent = expectation(mixture, params)
13 assert np.allclose(latent.sum(axis=1), 1.)
---> 14 params = maximization(mixture, latent)
16 snapshot = {
17 "llh": mixture.incomplete_data_likelihood(),
18 **mixture.get_params(as_dict=True, as_flat=True),
19 }
20 history.append(snapshot)
File ~/checkouts/readthedocs.org/user_builds/lymixture/envs/latest/lib/python3.10/site-packages/lymixture/em.py:262, in maximization(model, log_resps, parallelize, method)
260 # Use ProcessPoolExecutor to parallelize
261 with ProcessPoolExecutor(max_workers=num_components) as executor:
--> 262 results = list(executor.map(_optimize_single_component, optimization_args))
264 # Set optimized parameters back to components
265 for i, optimized_params in results:
File ~/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/process.py:575, in _chain_from_iterable_of_lists(iterable)
569 def _chain_from_iterable_of_lists(iterable):
570 """
571 Specialized implementation of itertools.chain.from_iterable.
572 Each item in *iterable* should be a list. This function is
573 careful not to keep references to yielded objects.
574 """
--> 575 for element in iterable:
576 element.reverse()
577 while element:
File ~/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/_base.py:621, in Executor.map.<locals>.result_iterator()
618 while fs:
619 # Careful not to keep a reference to the popped future
620 if timeout is None:
--> 621 yield _result_or_cancel(fs.pop())
622 else:
623 yield _result_or_cancel(fs.pop(), end_time - time.monotonic())
File ~/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/_base.py:319, in _result_or_cancel(***failed resolving arguments***)
317 try:
318 try:
--> 319 return fut.result(timeout)
320 finally:
321 fut.cancel()
File ~/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/_base.py:458, in Future.result(self, timeout)
456 raise CancelledError()
457 elif self._state == FINISHED:
--> 458 return self.__get_result()
459 else:
460 raise TimeoutError()
File ~/.asdf/installs/python/3.10.17/lib/python3.10/concurrent/futures/_base.py:403, in Future.__get_result(self)
401 if self._exception:
402 try:
--> 403 raise self._exception
404 finally:
405 # Break a reference cycle with the exception in self._exception
406 self = None
RuntimeError: Optimization failed for component 0: message: NaN result encountered.
success: False
status: 3
fun: nan
x: [ 3.820e-01 3.820e-01 3.820e-01 3.820e-01]
nit: 1
direc: [[ 1.000e+00 0.000e+00 0.000e+00 0.000e+00]
[ 0.000e+00 1.000e+00 0.000e+00 0.000e+00]
[ 0.000e+00 0.000e+00 1.000e+00 0.000e+00]
[ 0.000e+00 0.000e+00 0.000e+00 1.000e+00]]
nfev: 81
## Results
After convergence, we can have a look at the likelihood and the parameters during the iterations. Ideally, the likelihood increases strictly monotonically.
# Likelihood in the upper panel, T->II spread of both components in the lower
# panel, sharing the iteration axis.
spread_cols = ["0_TtoII_spread", "1_TtoII_spread"]
history_df = pd.DataFrame(history)
history_df.plot(
    y=["llh"] + spread_cols,
    subplots=[("llh",), tuple(spread_cols)],
    sharex=True,
    xlim=(0, None),
);
And, more importantly, let’s also see if the learned parameters reproduce what we put into the model.
# Ground truth keyed like the mixture's flat parameters: "0_" prefixes the
# first component, "1_" the second.
fixed_params = {
    **{f"0_{name}": value for name, value in PARAMS_C1.items()},
    **{f"1_{name}": value for name, value in PARAMS_C2.items()},
}
learned_params = mixture.get_params(as_dict=True, as_flat=True)
# ``fixed`` and ``learned`` keep their names: the f-string ``=`` specifier
# prints them.
for name in fixed_params:
    fixed, learned = fixed_params[name], learned_params[name]
    print(f"{name:>16s}: {fixed = :.3f}, {learned = :.3f}")
## Sample Parameter Distribution
To get a fully probabilistic distribution over the parameters and mixture coefficients, we can draw some parameter samples and then infer the optimal mixture coefficients for each sample.
from lymixture.em import complete_samples, sample_model_params
# Draw parameter samples, keep a random subset of (at most) 50 of them and
# complete each with inferred mixture coefficients.
samples = sample_model_params(mixture, steps=20)
num_keep = min(50, len(samples))
# Use the seeded ``rng`` so the subset is reproducible (the global
# ``np.random`` state is never seeded in this notebook). Also, never request
# more samples than exist, which ``replace=False`` would reject.
indices = rng.choice(len(samples), size=num_keep, replace=False)
reduced_set = samples[indices]
completed_samples = complete_samples(mixture, reduced_set)