import numpy as np
from typing import Any
from collections.abc import Callable, Sequence
from pathlib import Path
from ..inventory._inventory import Inventory
from ._tools import (
log_likelihood,
map_params,
read_param_file,
_default_params,
_evaluate_statistics_method,
_prepare_binning,
)
from ..inventory._plotting import display
from ._processes import (
recruitment,
mortality,
model_step,
growth,
_set_parameters,
_get_parameters,
)
import matplotlib.pyplot as plt
import pybobyqa
[docs]
class Model:
    """
    This is the individual-based model class. This class is used to create, run, analyse and optimize models.

    A model holds the inventory currently being simulated, a reference inventory
    used for plotting and optimization, and a snapshot of the initial inventory
    data that the simulation can be reset to. Model parameters are handed to the
    backend with ``_set_parameters`` and read back with ``_get_parameters``.

    NOTE(review): the backend parameter calls take no model handle, so parameters
    are presumably shared by every ``Model`` in the process -- confirm.
    :no-index:
    """
[docs]
def __init__(
    self,
    parameters: Path | dict[str, Any] | None = None,
    initial_inventory: np.ndarray | Path | Inventory | None = None,
    reference_inventory: np.ndarray | Path | None = None,
    threads: int = 1,
):
    """
    Initialize the model.

    :param parameters: Parameters to configure the model. A ``Path`` (or a path given
        as a string) is read with ``read_param_file`` and translated with ``map_params``;
        a dictionary is used as given. Anything else selects ``_default_params()``.
    :param initial_inventory: Initial inventory data or path to a file containing it.
        If omitted, a random inventory is generated from the parameters.
    :param reference_inventory: Reference inventory to compare the simulated inventory to,
        for optimization and analysis purposes. If omitted, an independent copy of the
        initial inventory is used.
    :param threads: The number of CPU-threads to be used to run the model.
    """
    self.threads = threads

    # Determine parameters from their source. They are pushed to the backend and
    # not stored on the instance, so there is no duplicate copy on the Python side.
    if isinstance(parameters, (str, Path)):
        # A plain string used to fall through silently to the defaults.
        params = map_params(read_param_file(Path(parameters)))
    elif isinstance(parameters, dict):
        params = parameters
    else:
        params = _default_params()

    # Apply parameters to the Cython/backend layer.
    _set_parameters(params)
    self.radius = params["neighborhood_radius"]

    # Set up the inventory that the simulation operates on.
    if initial_inventory is None:
        self._inventory = Inventory.from_random(
            radius=self.radius,
            n_species=params["num_species"],
            dim_x=params["quadrat_dim_x"] * params["cell_size"],
            dim_y=params["quadrat_dim_y"] * params["cell_size"],
            num_threads=self.threads,
        )
    else:
        # Paths and in-memory data take the same route through ``from_data``.
        self._inventory = Inventory.from_data(initial_inventory, radius=self.radius)

    # Baseline snapshot uses the initial simulation inventory, not the reference inventory.
    self._initial_inventory_data = self._inventory.data.copy()

    if reference_inventory is None:
        # Build the reference from a copy instead of aliasing the live inventory:
        # growth() and recruitment() modify the live inventory object in place and
        # would otherwise change the reference as well.
        self.reference_inventory = Inventory.from_data(
            data=self._initial_inventory_data.copy(), radius=self.radius
        )
    else:
        self.reference_inventory = Inventory.from_data(
            data=reference_inventory, radius=self.radius
        )
@property
def parameters(self) -> dict:
    """
    Current model parameters as reported by the backend.

    The backend is the source of truth; what is returned is a new top-level
    dictionary built from it, so assigning keys on the result does not reach
    the backend.
    """
    backend_params = _get_parameters()
    snapshot = dict(backend_params)
    return snapshot
[docs]
def update_parameters(self, parameters: dict) -> None:
    """
    Update the model parameters from a dictionary (merge with existing).
    Only keys that already exist in the model parameters may be updated; nothing
    is written to the backend if any key is unknown.

    :param parameters: Dictionary with the new parameter configurations.
    :raises KeyError: If a key is not an existing model parameter.
    """
    current = dict(_get_parameters())

    # Validate every key before touching anything so a bad key changes nothing.
    unknown = [k for k in parameters if k not in current]
    if unknown:
        raise KeyError(f"Parameter '{unknown[0]}' does not exist in the model.")

    current.update(parameters)
    _set_parameters(current)

    # ``self.radius`` is cached at construction and used whenever an inventory is
    # rebuilt; keep it in step with the backend value so it cannot go stale.
    if "neighborhood_radius" in parameters:
        self.radius = current["neighborhood_radius"]
@property
def inventory(self) -> Inventory:
    """
    The inventory the simulation currently operates on.

    Defined as a property without a setter; the stored object itself is
    returned, not a copy.
    """
    current_inventory = self._inventory
    return current_inventory
[docs]
def step(self, n_steps: int):
    """
    Advance the simulation.

    A simulation step in the model consists mainly of three different processes:
    mortality, recruitment and growth. One step represents 5 years in real time.
    The current inventory is replaced by a new one built from the backend result.

    :param n_steps: The amount of model steps to be simulated.
    """
    advanced_data = model_step(self._inventory.data, n_steps, self.threads)
    self._inventory = Inventory.from_data(data=advanced_data)
[docs]
def plot(
    self,
    reference_plot: bool = False,
    threshold: int = 50,
    filter: float = 1.0,
    scale: float = 7.5,
    figsize=(22, 10),
):
    """
    Plot the current state of the model on a 5x5 grid of axes.

    :param reference_plot: If True, the reference inventory is plotted instead of
        the current inventory.
    :param threshold: Minimum abundance of conspecifics ToDo.
    :param filter: Filter out all trees below that size for inventory plot.
    :param scale: Forwarded unchanged to ``display`` as ``scale``.
    :param figsize: Figure size handed to ``plt.subplots``.
    """
    current = self.inventory
    shown = self.reference_inventory if reference_plot else current

    # Only the axes are needed; the figure handle is not used afterwards.
    _, axs = plt.subplots(5, 5, figsize=figsize)
    display(
        shown,
        threshold=threshold,
        filter=filter,
        scale=scale,
        # Abundances are always taken from the current inventory, also when the
        # reference inventory is the one being drawn.
        abundances=current.species_abundance(),
        axs=axs,
        ref=reference_plot,
    )
    plt.show()
[docs]
def run_with_plotting(
    self,
    n_steps: int = 1000,
    plot_interval: int = 50,
    threshold: int = 50,
    figsize=(24, 12),
):
    """
    Run the simulation with plotting enabled to see the development of the forest inventory in real time.
    ToDo: Extract this function to plotting.

    The reference inventory is drawn once; afterwards the model is advanced in
    chunks of ``plot_interval`` steps (the last chunk may be shorter) until
    ``n_steps`` steps have been simulated, redrawing after every chunk.

    :param n_steps: Total amount of steps to be simulated.
    :param plot_interval: The interval (in steps) at which the inventories shall be plotted.
    :param threshold: Minimum abundance of conspecifics ToDo.
    :param figsize: The size of the plot.
    :raises ValueError: If ``plot_interval`` is not positive.
    """
    if plot_interval <= 0:
        raise ValueError("plot_interval must be a positive number of steps.")

    plt.ion()
    try:
        plt.close("all")
        fig, axs = plt.subplots(5, 5, figsize=figsize)
        fig.subplots_adjust(hspace=0.4, wspace=0.4)

        # Cells outside the top-left 2x2 corner; that corner is the inventory
        # plot area and is never cleared or redrawn here.
        panel_cells = [
            (i, j) for i in range(5) for j in range(5) if not (i < 2 and j < 2)
        ]

        initial_plot_inventory = (
            self.reference_inventory
            if self.reference_inventory is not None
            else self.inventory
        )
        bins_dict = display(
            initial_plot_inventory,
            threshold=threshold,
            axs=axs,
            ref=True,
        )

        # Remember the reference artists so they can be restored after each clear().
        reference_artists = self._snapshot_reference_artists(axs, panel_cells)

        abundance_history = []
        simulated = 0
        while simulated < n_steps:
            # Advance by one plotting interval without overshooting n_steps.
            chunk = min(plot_interval, n_steps - simulated)
            self.step(chunk)
            simulated += chunk

            for i, j in panel_cells:
                axs[i, j].clear()
            self._redraw_reference_artists(axs, reference_artists)

            # Get current inventory and update abundance history.
            abundance_history.append(self.inventory.species_abundance())

            # Plot current state with different marker style.
            bins_dict = display(
                self.inventory,
                threshold=threshold,
                axs=axs,
                ref=False,
                bins_dict=bins_dict,
                abundances=abundance_history,
            )
            fig.suptitle(f"Step {simulated}")

            for i, j in panel_cells:
                ax = axs[i, j]
                if ax.get_legend() is None:
                    handles, _ = ax.get_legend_handles_labels()
                    if handles:
                        ax.legend()

            fig.canvas.draw()
            fig.canvas.flush_events()
            plt.pause(0.01)  # Small pause to allow GUI to update
    finally:
        # Leave interactive mode even if the simulation or drawing failed.
        plt.ioff()
    plt.show()

@staticmethod
def _snapshot_reference_artists(axs, panel_cells) -> dict:
    """Record data and styling of every artist labelled 'Reference' in the given cells."""
    snapshot: dict = {}
    for cell in panel_cells:
        ax = axs[cell]
        lines = [
            {
                "xdata": line.get_xdata().copy(),
                "ydata": line.get_ydata().copy(),
                "color": line.get_color(),
                "linewidth": line.get_linewidth(),
                "label": line.get_label(),
            }
            for line in ax.lines
            if "Reference" in line.get_label()
        ]
        scatters = []
        for collection in ax.collections:
            if not (
                hasattr(collection, "get_label")
                and "Reference" in collection.get_label()
            ):
                continue
            facecolors = collection.get_facecolors()
            sizes = collection.get_sizes()
            scatters.append(
                {
                    "offsets": collection.get_offsets().copy(),
                    "color": facecolors[0] if len(facecolors) > 0 else "red",
                    "size": sizes[0] if len(sizes) > 0 else 50,
                    "label": collection.get_label(),
                }
            )
        snapshot[cell] = {"lines": lines, "scatter": scatters}
    return snapshot

@staticmethod
def _redraw_reference_artists(axs, snapshot) -> None:
    """Draw the recorded reference lines and scatter points back onto their axes."""
    for cell, artists in snapshot.items():
        ax = axs[cell]
        for line_data in artists["lines"]:
            ax.plot(
                line_data["xdata"],
                line_data["ydata"],
                color=line_data["color"],
                linewidth=line_data["linewidth"],
                label=line_data["label"],
            )
        for scatter_data in artists["scatter"]:
            ax.scatter(
                scatter_data["offsets"][:, 0],
                scatter_data["offsets"][:, 1],
                # ``color`` rather than ``c``: a single RGBA value passed as ``c``
                # can be mistaken for per-point values to be colour-mapped.
                color=scatter_data["color"],
                marker="x",  # Force 'x' marker for initial data
                s=scatter_data["size"],
                label=scatter_data["label"],
            )
[docs]
def optimize_parameters(
    self,
    start_params: np.ndarray,
    lower: np.ndarray,
    upper: np.ndarray,
    rhobeg: float,
    rhoend: float,
    param_names: list[str],
    methods_to_evaluate: Sequence[
        tuple[Callable[..., Any], Callable[..., Any] | property]
    ],
    stat_indices: list[slice | list[int]] | None = None,
    n_init: int = 50,
    n_iter: int = 10,
    iter_steps: int = 5,
    reference_inventory: np.ndarray | None = None,
    maxfun: int | None = None,
    verbose: bool = False,
    objfun_has_noise: bool = True,
) -> Any:
    """
    The optimizer uses the PYBOBYQA algorithm to fit arbitrary statistics to the model parameters using the log likelihood function.

    Every objective evaluation writes the candidate values into the model
    parameters, resets the inventory to the initial snapshot, runs ``n_init``
    steps and then samples the statistics ``n_iter`` times, ``iter_steps`` steps
    apart. The inventory is reset again when the optimizer finishes or fails.
    The model parameters are left at the last evaluated candidate; nothing in
    this method writes the optimizer's result back.

    :param start_params: Initial guess for parameters.
    :param lower: Lower bounds for parameters.
    :param upper: Upper bounds for parameters.
    :param rhobeg: Passed to ``pybobyqa.solve`` as ``rhobeg``.
    :param rhoend: Passed to ``pybobyqa.solve`` as ``rhoend``.
    :param param_names: Names of parameters to optimize (order matches start_params).
    :param methods_to_evaluate: Functions to calculate the statistics that shall be matched,
        handed to ``_prepare_binning``. NOTE(review): the annotation and the earlier
        description ("statistic function, kwargs for function") disagree about the
        second tuple element -- confirm.
    :param stat_indices: For each stat_func, which indices of the values to use (default: use all).
    :param n_init: Initial run steps.
    :param n_iter: Number of iterations for averaging.
    :param iter_steps: Steps per iteration.
    :param reference_inventory: Reference to fit against. Raw data is wrapped in an
        ``Inventory``; if omitted, the model's own reference inventory is used.
    :param maxfun: Maximum number of objective evaluations, passed through to
        ``pybobyqa.solve`` (``None`` leaves the choice to PYBOBYQA).
    :param verbose: If True, print intermediate diagnostics.
    :param objfun_has_noise: Pass-through to PYBOBYQA's objfun_has_noise.
    :returns: The result of the PYBOBYQA algorithm
    :raises ValueError: If ``param_names`` and ``start_params`` differ in length or
        no reference inventory is available.
    """
    if len(param_names) != len(start_params):
        # zip() below would otherwise drop the surplus entries without notice.
        raise ValueError("param_names and start_params must have the same length.")

    # Compare against None explicitly: the truth value of a multi-element
    # array is ambiguous and raises.
    if reference_inventory is not None:
        if isinstance(reference_inventory, Inventory):
            ref_inv = reference_inventory
        else:
            # Same conversion the constructor applies to reference data.
            ref_inv = Inventory.from_data(
                data=reference_inventory, radius=self.radius
            )
    else:
        ref_inv = self.reference_inventory
    if ref_inv is None:
        raise ValueError("No reference inventory available for optimization.")

    # presumably 100 is the number of bins -- TODO confirm against _prepare_binning
    stat_funcs = _prepare_binning(methods_to_evaluate, 100, ref_inv)
    # Prepare reference statistics from the reference inventory.
    reference = _evaluate_statistics_method(ref_inv, stat_funcs, stat_indices)

    def objective(params):
        self.update_parameters(dict(zip(param_names, params)))
        # Reset to baseline initial inventory (not the reference) before each evaluation.
        self._reset_inventory()
        self.step(n_init)

        sim_stats = [[] for _ in stat_funcs]
        for _ in range(n_iter):
            self.step(iter_steps)
            stats = _evaluate_statistics_method(
                self.inventory, stat_funcs, stat_indices
            )
            for idx, stat_val in enumerate(stats):
                sim_stats[idx].append(stat_val)

        # Mean and spread of every statistic over the sampled iterations.
        stacked = [np.array(stat_list) for stat_list in sim_stats]
        sim_means = [np.mean(arr, axis=0) for arr in stacked]
        sim_stds = [np.std(arr, axis=0) for arr in stacked]
        if verbose:
            print(sim_means, sim_stds, reference)

        # NOTE(review): the optimizer minimizes the returned value; confirm that
        # ``log_likelihood`` yields a quantity where smaller means a better fit.
        log_like = log_likelihood(
            np.concatenate(sim_means),
            np.concatenate(sim_stds),
            np.concatenate(reference),
        )
        if verbose:
            print(log_like)
        return log_like

    try:
        result = pybobyqa.solve(
            objective,
            start_params,
            bounds=(lower, upper),
            objfun_has_noise=objfun_has_noise,
            rhobeg=rhobeg,
            rhoend=rhoend,
            print_progress=verbose,
            maxfun=maxfun,
        )
    finally:
        # Do not leave the model on a half-simulated inventory, even on failure.
        self._reset_inventory()
    return result
def _reset_inventory(self) -> None:
    """Rebuild the simulation inventory from a copy of the snapshot taken at construction."""
    baseline = self._initial_inventory_data.copy()
    self._inventory = Inventory.from_data(data=baseline, radius=self.radius)
[docs]
def mortality(self) -> Inventory:
    """
    Apply mortality process to the current inventory and return the dead individuals.
    Dead individuals are removed from the inventory.

    :returns: Inventory built from the rows whose status is ``-2`` (the individuals
        that died). It holds no rows when nobody died or the inventory was empty.
    """
    status = mortality(
        self.inventory.CI_CS,
        self.inventory.CI_HS,
        self.inventory.CI_CS_d,
        self.inventory.CI_HS_d,
        self.inventory.dbh,
        self.inventory.species,
        self.inventory.status,
    )
    data = self.inventory.data

    # Build the result unconditionally: it used to be assigned only when the
    # status array was non-empty, leaving nothing to return for an empty one.
    deads = Inventory.from_data(data=data[status == -2])

    if len(status) > 0:
        # Keep only alive individuals (True = keep, False = remove).
        alive_mask = status > -2
        self._inventory = Inventory.from_data(
            data=data[alive_mask], radius=self.radius
        )
    return deads
[docs]
def growth(self) -> np.ndarray:
    """
    Apply growth process to the current inventory, updating dbh values in place.

    :returns: The change in dbh per individual (new dbh minus dbh before the call).
    """
    inventory = self.inventory

    # Take a real copy: if ``dbh`` is a view onto the inventory data, or the
    # assignment below writes into the same buffer, a plain reference would
    # follow the update and the difference would come out as zero.
    dbh_before = np.array(inventory.dbh, copy=True)

    inventory.dbh = growth(
        inventory.CI_CS,
        inventory.CI_HS,
        inventory.dbh,
        inventory.status,
    )
    return inventory.dbh - dbh_before
[docs]
def recruitment(self) -> Inventory:
    """
    Run the recruitment process on the current inventory.

    The recruits are wrapped in their own Inventory, appended to the current
    inventory via ``extend_inventory`` and also handed back to the caller.
    ToDo test it.

    :returns: Inventory containing only the new recruits.
    """
    current = self.inventory
    new_individuals = recruitment(
        current.x,
        current.y,
        current.dbh,
        current.species,
        current.status,
        self.threads,
    )
    recruits_inventory = Inventory.from_data(new_individuals)
    self._inventory.extend_inventory(recruits_inventory)
    return recruits_inventory