# Source code for spatiocoexistence.model

import numpy as np
from typing import Any
from collections.abc import Callable
from pathlib import Path
from .py_tools import (
    create_initial_inventory,
    log_likelihood,
    map_params,
    read_initial_inventory,
    read_param_file,
    _default_params,
)
from spatiocoexistence.processes import (
    recruitment,
    mortality,
    model_step,
    growth,
    set_parameters,
    get_parameters,
)
from spatiocoexistence.crowding import (
    crowding_indices,
)
import matplotlib.pyplot as plt
from spatiocoexistence.plotting import display
from spatiocoexistence.tools import species_abundance
from matplotlib.axes import Axes
import pybobyqa


[docs] class SpatioCoexistenceModel:
[docs] def __init__( self, parameter_file: Path | None = None, params: dict[str, Any] | None = None, initial_inventory: np.ndarray | Path | None = None, use_initial_as_reference: bool = False, reference_inventory: np.ndarray | Path | None = None, threads: int = 1, ): """ Initialize the model with parameters from a dictionary or a file. If neither is provided, use default parameters. Parameters: ----------- parameter_file : Path, optional Path to a parameter file params : dict, optional Dictionary of model parameters initial_inventory : np.ndarray | Path, optional Initial inventory data or path to file containing it use_initial_as_reference : bool, default=False If True, the initial_inventory is only used for plotting comparison and a new inventory is created for simulation """ self.threads = threads self.use_initial_as_reference = use_initial_as_reference # Check for conflicting parameter sources if parameter_file and params: raise ValueError( "Cannot specify both parameter_file and params. Choose one method to initialize parameters." 
) # Initialize parameters based on the given source if parameter_file: # Initialize from parameter file read_params = read_param_file(parameter_file) self.params = map_params(read_params) elif params: # Initialize from params dictionary self.params = params else: # No parameters provided, use defaults self.params = _default_params() # Apply parameters to the Cython backend set_parameters(self.params) # Store initial inventory for plotting if needed if initial_inventory is not None and use_initial_as_reference: if isinstance(initial_inventory, Path): self.reference_inventory = read_initial_inventory(initial_inventory) else: self.reference_inventory = initial_inventory # Set up inventory for simulation if initial_inventory is None: # Create a new inventory for simulation x_dimension = self.params["quadrat_dim_x"] * self.params["cell_size"] y_dimension = self.params["quadrat_dim_y"] * self.params["cell_size"] self.inventory = create_initial_inventory( n_species=self.params["num_species"], dim_x=x_dimension, dim_y=y_dimension, radius=self.params["neighborhood_radius"], num_threads=self.threads, ) if use_initial_as_reference: self.reference_inventory = self.inventory elif isinstance(initial_inventory, Path): self.inventory = read_initial_inventory(initial_inventory) else: self.inventory = initial_inventory if reference_inventory is not None: if isinstance(reference_inventory, Path): self.reference_inventory = read_initial_inventory(reference_inventory) print("Path") else: self.reference_inventory = reference_inventory print("inv")
[docs] def step(self) -> None: """ One simulation timestep consists of these processes in the same order. """ deads = mortality( self.inventory["CI_CS"], self.inventory["CI_HS"], self.inventory["CI_CS_d"], self.inventory["CI_HS_d"], self.inventory["dbh"], self.inventory["status"], ) recruits = recruitment( self.inventory["x"], self.inventory["y"], self.inventory["dbh"], self.inventory["species"], self.inventory["status"], self.threads, ) mask = np.ones(len(self.inventory), dtype=bool) mask[deads] = False self.inventory = self.inventory[mask] self.inventory = np.concatenate([self.inventory, recruits]) self.inventory["dbh"] = growth( self.inventory["CI_CS"], self.inventory["CI_HS"], self.inventory["dbh"], self.inventory["status"], ) CI_CS, CI_HS, CI_CS_d, CI_HS_d = crowding_indices( self.inventory["x"], self.inventory["y"], self.inventory["species"], self.inventory["status"], self.params["neighborhood_radius"], dbh=self.inventory["dbh"], num_threads=self.threads, ) """ Other possibilities to calculate crowding indices. CI_CS, CI_HS = crowding_nano_kdTree( self.inventory["x"], self.inventory["y"], self.inventory["species"], self.inventory["dbh"], self.params["neighborhood_radius"], ) CI_CS, CI_HS = crowding_kdTree( self.inventory["x"], self.inventory["y"], self.inventory["species"], self.inventory["dbh"], self.params["neighborhood_radius"], ) """ self.inventory["CI_CS"] = CI_CS self.inventory["CI_HS"] = CI_HS self.inventory["CI_CS_d"] = CI_CS_d self.inventory["CI_HS_d"] = CI_HS_d
[docs] def cy_run(self, n_steps: int): self.inventory = model_step(self.inventory, n_steps, self.threads)
[docs] def run(self, n_steps: int) -> None: for _ in range(n_steps): self.step()
[docs] def get_inventory(self) -> np.ndarray: return self.inventory
[docs] def get_params(self) -> dict: """Return the current model parameters as a dictionary.""" return get_parameters()
[docs] def update_params(self, params: dict) -> None: """Set the model parameters from a dictionary and update radius if present.""" self.params.update(params) set_parameters(self.params)
[docs] def plot( self, title: str, initial_plot: bool = False, axs: Axes | None = None, threshold: int = 50, filter: float = 1.0, abundances: np.ndarray | None = None, figsize=(22, 10), ): """Plot the current state of the model.""" inventory = self.get_inventory() if axs is None: fig, axs = plt.subplots(5, 5, figsize=figsize) else: fig = axs[0, 0].figure # If we're plotting the initial state and have a reference inventory, use that instead if initial_plot and self.reference_inventory: plot_inventory = self.reference_inventory else: plot_inventory = inventory display( plot_inventory, self.params["neighborhood_radius"], threshold=threshold, filter=filter, abundances=abundances, axs=axs, initial=initial_plot, ) fig.suptitle(title) plt.show()
[docs] def run_with_plotting( self, n_steps: int = 1000, plot_interval: int = 50, threshold: int = 50, figsize=(24, 12), ): """Run simulation with interactive plotting.""" plt.ion() plt.close("all") fig, axs = plt.subplots(5, 5, figsize=figsize) fig.subplots_adjust(hspace=0.4, wspace=0.4) # Determine which inventory to use for initial plotting initial_plot_inventory = ( self.reference_inventory if self.reference_inventory is not None else self.get_inventory() ) # Calculate max_species for abundance tracking max_species = max(initial_plot_inventory["species"]) # Plot initial distribution and get the bins bins_dict = display( initial_plot_inventory, self.params["neighborhood_radius"], threshold=threshold, axs=axs, initial=True, params=self.params, ) # Store the initial line and scatter references initial_objects: dict = {} for i in range(5): for j in range(5): if not (i < 2 and j < 2): # Skip the inventory plot area initial_objects[(i, j)] = {"lines": [], "scatter": []} # Store line objects for line in axs[i, j].lines: if "Initial" in line.get_label(): initial_objects[(i, j)]["lines"].append( { "xdata": line.get_xdata().copy(), "ydata": line.get_ydata().copy(), "color": line.get_color(), "linewidth": line.get_linewidth(), "label": line.get_label(), } ) # Store scatter objects for scatter in axs[i, j].collections: if ( hasattr(scatter, "get_label") and "Initial" in scatter.get_label() ): initial_objects[(i, j)]["scatter"].append( { "offsets": scatter.get_offsets().copy(), "color": ( scatter.get_facecolors()[0] if len(scatter.get_facecolors()) > 0 else "red" ), "marker": ( scatter.get_paths()[0] if len(scatter.get_paths()) > 0 else "x" ), "size": ( scatter.get_sizes()[0] if len(scatter.get_sizes()) > 0 else 50 ), "label": scatter.get_label(), } ) # Now run simulation with normal histograms that will overlap the lines abundance_history = [] for step_num in range(1, n_steps + 1): # Clear all axes self.cy_run(plot_interval) for i in range(5): for j in range(5): if not (i 
< 2 and j < 2): # Skip the plot area axs[i, j].clear() # Redraw the initial objects for (i, j), objects in initial_objects.items(): # Redraw lines for line_data in objects["lines"]: axs[i, j].plot( line_data["xdata"], line_data["ydata"], color=line_data["color"], linewidth=line_data["linewidth"], label=line_data["label"], ) # Redraw scatter plots for scatter_data in objects["scatter"]: axs[i, j].scatter( scatter_data["offsets"][:, 0], scatter_data["offsets"][:, 1], c=scatter_data["color"], marker="x", # Force 'x' marker for initial data s=scatter_data["size"], label=scatter_data["label"], ) # Get current inventory and update abundance history inventory = self.get_inventory() current_abundance = species_abundance(inventory["species"], max_species) abundance_history.append(current_abundance) # Plot current state with different marker style bins_dict = display( self.get_inventory(), self.params["neighborhood_radius"], threshold=threshold, axs=axs, initial=False, bins_dict=bins_dict, abundances=abundance_history, params=self.params, ) fig.suptitle(f"Step {step_num}") for i in range(5): for j in range(5): if not (i < 2 and j < 2) and axs[i, j].get_legend() is None: handles, labels = axs[i, j].get_legend_handles_labels() if handles: axs[i, j].legend() fig.canvas.draw() fig.canvas.flush_events() plt.pause(0.01) # Small pause to allow GUI to update plt.ioff() plt.show()
[docs] def optimize_parameters( self, start_params: np.ndarray, lower: np.ndarray, upper: np.ndarray, rhobeg: float, rhoend: float, param_names: list[str], stat_funcs: list[tuple[Callable, dict]], stat_indices: list[slice | list[int]] | None = None, n_init: int = 50, n_iter: int = 10, iter_steps: int = 5, reference_inventory: np.ndarray | None = None, ): """ Generic optimizer for model parameters to fit arbitrary statistics using log_likelihood. Parameters ---------- start_params : np.ndarray Initial guess for parameters. lower : np.ndarray Lower bounds for parameters. upper : np.ndarray Upper bounds for parameters. param_names : list[str] Names of parameters to optimize (order matches start_params). stat_funcs : list of (func, func_kwargs) Each tuple: (statistic function, kwargs for function). stat_indices : list of slices or lists, optional For each stat_func, which indices to use (default: use all). n_init : int Initial run steps. n_iter : int Number of iterations for averaging. iter_steps : int Steps per iteration. 
""" if reference_inventory: ref_inv = reference_inventory elif reference_inventory is None and self.use_initial_as_reference: ref_inv = self.reference_inventory else: raise ValueError("No reference inventory available for optimization.") # Prepare reference statistics from initial inventory reference = [] for i, (func, func_kwargs) in enumerate(stat_funcs): vals = func(ref_inv, **func_kwargs) if stat_indices and stat_indices[i]: vals = vals[stat_indices[i]] reference.append(vals) def objective(params): update_dict = {k: v for k, v in zip(param_names, params)} self.update_params(update_dict) self.cy_run(n_init) sim_stats = [[] for _ in stat_funcs] for _ in range(n_iter): self.cy_run(iter_steps) inventory = self.get_inventory() for idx, (func, func_kwargs) in enumerate(stat_funcs): vals = func(inventory, **func_kwargs) if stat_indices and stat_indices[idx]: vals = vals[stat_indices[idx]] sim_stats[idx].append(vals) sim_means = [] sim_stds = [] for stat_list in sim_stats: arr = np.array(stat_list) sim_means.append(np.mean(arr, axis=0)) sim_stds.append(np.std(arr, axis=0)) return log_likelihood( np.concatenate(sim_means), np.concatenate(sim_stds), np.concatenate(reference), ) result = pybobyqa.solve( objective, start_params, bounds=(lower, upper), objfun_has_noise=True, rhobeg=rhobeg, rhoend=rhoend, print_progress=True, ) return result