import numpy as np
from typing import Any
from collections.abc import Callable
from pathlib import Path
from .py_tools import (
    create_initial_inventory,
    log_likelihood,
    map_params,
    read_initial_inventory,
    read_param_file,
    _default_params,
)
from spatiocoexistence.processes import (
    recruitment,
    mortality,
    model_step,
    growth,
    set_parameters,
    get_parameters,
)
from spatiocoexistence.crowding import (
    crowding_indices,
)
import matplotlib.pyplot as plt
from spatiocoexistence.plotting import display
from spatiocoexistence.tools import species_abundance
from matplotlib.axes import Axes
import pybobyqa
[docs]
class SpatioCoexistenceModel:
[docs]
    def __init__(
        self,
        parameter_file: Path | None = None,
        params: dict[str, Any] | None = None,
        initial_inventory: np.ndarray | Path | None = None,
        use_initial_as_reference: bool = False,
        reference_inventory: np.ndarray | Path | None = None,
        threads: int = 1,
    ):
        """
        Initialize the model with parameters from a dictionary or a file.
        If neither is provided, use default parameters.
        Parameters:
        -----------
        parameter_file : Path, optional
            Path to a parameter file
        params : dict, optional
            Dictionary of model parameters
        initial_inventory : np.ndarray | Path, optional
            Initial inventory data or path to file containing it
        use_initial_as_reference : bool, default=False
            If True, the initial_inventory is only used for plotting comparison
            and a new inventory is created for simulation
        """
        self.threads = threads
        self.use_initial_as_reference = use_initial_as_reference
        # Check for conflicting parameter sources
        if parameter_file and params:
            raise ValueError(
                "Cannot specify both parameter_file and params. Choose one method to initialize parameters."
            )
        # Initialize parameters based on the given source
        if parameter_file:
            # Initialize from parameter file
            read_params = read_param_file(parameter_file)
            self.params = map_params(read_params)
        elif params:
            # Initialize from params dictionary
            self.params = params
        else:
            # No parameters provided, use defaults
            self.params = _default_params()
        # Apply parameters to the Cython backend
        set_parameters(self.params)
        # Store initial inventory for plotting if needed
        if initial_inventory is not None and use_initial_as_reference:
            if isinstance(initial_inventory, Path):
                self.reference_inventory = read_initial_inventory(initial_inventory)
            else:
                self.reference_inventory = initial_inventory
        # Set up inventory for simulation
        if initial_inventory is None:
            # Create a new inventory for simulation
            x_dimension = self.params["quadrat_dim_x"] * self.params["cell_size"]
            y_dimension = self.params["quadrat_dim_y"] * self.params["cell_size"]
            self.inventory = create_initial_inventory(
                n_species=self.params["num_species"],
                dim_x=x_dimension,
                dim_y=y_dimension,
                radius=self.params["neighborhood_radius"],
                num_threads=self.threads,
            )
            if use_initial_as_reference:
                self.reference_inventory = self.inventory
        elif isinstance(initial_inventory, Path):
            self.inventory = read_initial_inventory(initial_inventory)
        else:
            self.inventory = initial_inventory
        if reference_inventory is not None:
            if isinstance(reference_inventory, Path):
                self.reference_inventory = read_initial_inventory(reference_inventory)
                print("Path")
            else:
                self.reference_inventory = reference_inventory
                print("inv") 
[docs]
    def step(self) -> None:
        """
        One simulation timestep consists of these processes in the same order.
        """
        deads = mortality(
            self.inventory["CI_CS"],
            self.inventory["CI_HS"],
            self.inventory["CI_CS_d"],
            self.inventory["CI_HS_d"],
            self.inventory["dbh"],
            self.inventory["status"],
        )
        recruits = recruitment(
            self.inventory["x"],
            self.inventory["y"],
            self.inventory["dbh"],
            self.inventory["species"],
            self.inventory["status"],
            self.threads,
        )
        mask = np.ones(len(self.inventory), dtype=bool)
        mask[deads] = False
        self.inventory = self.inventory[mask]
        self.inventory = np.concatenate([self.inventory, recruits])
        self.inventory["dbh"] = growth(
            self.inventory["CI_CS"],
            self.inventory["CI_HS"],
            self.inventory["dbh"],
            self.inventory["status"],
        )
        CI_CS, CI_HS, CI_CS_d, CI_HS_d = crowding_indices(
            self.inventory["x"],
            self.inventory["y"],
            self.inventory["species"],
            self.inventory["status"],
            self.params["neighborhood_radius"],
            dbh=self.inventory["dbh"],
            num_threads=self.threads,
        )
        """
        Other possibilities to calculate crowding indices.
        CI_CS, CI_HS = crowding_nano_kdTree(
            self.inventory["x"],
            self.inventory["y"],
            self.inventory["species"],
            self.inventory["dbh"],
            self.params["neighborhood_radius"],
        )
        CI_CS, CI_HS = crowding_kdTree(
            self.inventory["x"],
            self.inventory["y"],
            self.inventory["species"],
            self.inventory["dbh"],
            self.params["neighborhood_radius"],
        )
        """
        self.inventory["CI_CS"] = CI_CS
        self.inventory["CI_HS"] = CI_HS
        self.inventory["CI_CS_d"] = CI_CS_d
        self.inventory["CI_HS_d"] = CI_HS_d 
[docs]
    def cy_run(self, n_steps: int):
        self.inventory = model_step(self.inventory, n_steps, self.threads) 
[docs]
    def run(self, n_steps: int) -> None:
        for _ in range(n_steps):
            self.step() 
[docs]
    def get_inventory(self) -> np.ndarray:
        return self.inventory 
[docs]
    def get_params(self) -> dict:
        """Return the current model parameters as a dictionary."""
        return get_parameters() 
[docs]
    def update_params(self, params: dict) -> None:
        """Set the model parameters from a dictionary and update radius if present."""
        self.params.update(params)
        set_parameters(self.params) 
[docs]
    def plot(
        self,
        title: str,
        initial_plot: bool = False,
        axs: Axes | None = None,
        threshold: int = 50,
        filter: float = 1.0,
        abundances: np.ndarray | None = None,
        figsize=(22, 10),
    ):
        """Plot the current state of the model."""
        inventory = self.get_inventory()
        if axs is None:
            fig, axs = plt.subplots(5, 5, figsize=figsize)
        else:
            fig = axs[0, 0].figure
        # If we're plotting the initial state and have a reference inventory, use that instead
        if initial_plot and self.reference_inventory:
            plot_inventory = self.reference_inventory
        else:
            plot_inventory = inventory
        display(
            plot_inventory,
            self.params["neighborhood_radius"],
            threshold=threshold,
            filter=filter,
            abundances=abundances,
            axs=axs,
            initial=initial_plot,
        )
        fig.suptitle(title)
        plt.show() 
[docs]
    def run_with_plotting(
        self,
        n_steps: int = 1000,
        plot_interval: int = 50,
        threshold: int = 50,
        figsize=(24, 12),
    ):
        """Run simulation with interactive plotting."""
        plt.ion()
        plt.close("all")
        fig, axs = plt.subplots(5, 5, figsize=figsize)
        fig.subplots_adjust(hspace=0.4, wspace=0.4)
        # Determine which inventory to use for initial plotting
        initial_plot_inventory = (
            self.reference_inventory
            if self.reference_inventory is not None
            else self.get_inventory()
        )
        # Calculate max_species for abundance tracking
        max_species = max(initial_plot_inventory["species"])
        # Plot initial distribution and get the bins
        bins_dict = display(
            initial_plot_inventory,
            self.params["neighborhood_radius"],
            threshold=threshold,
            axs=axs,
            initial=True,
            params=self.params,
        )
        # Store the initial line and scatter references
        initial_objects: dict = {}
        for i in range(5):
            for j in range(5):
                if not (i < 2 and j < 2):  # Skip the inventory plot area
                    initial_objects[(i, j)] = {"lines": [], "scatter": []}
                    # Store line objects
                    for line in axs[i, j].lines:
                        if "Initial" in line.get_label():
                            initial_objects[(i, j)]["lines"].append(
                                {
                                    "xdata": line.get_xdata().copy(),
                                    "ydata": line.get_ydata().copy(),
                                    "color": line.get_color(),
                                    "linewidth": line.get_linewidth(),
                                    "label": line.get_label(),
                                }
                            )
                    # Store scatter objects
                    for scatter in axs[i, j].collections:
                        if (
                            hasattr(scatter, "get_label")
                            and "Initial" in scatter.get_label()
                        ):
                            initial_objects[(i, j)]["scatter"].append(
                                {
                                    "offsets": scatter.get_offsets().copy(),
                                    "color": (
                                        scatter.get_facecolors()[0]
                                        if len(scatter.get_facecolors()) > 0
                                        else "red"
                                    ),
                                    "marker": (
                                        scatter.get_paths()[0]
                                        if len(scatter.get_paths()) > 0
                                        else "x"
                                    ),
                                    "size": (
                                        scatter.get_sizes()[0]
                                        if len(scatter.get_sizes()) > 0
                                        else 50
                                    ),
                                    "label": scatter.get_label(),
                                }
                            )
        # Now run simulation with normal histograms that will overlap the lines
        abundance_history = []
        for step_num in range(1, n_steps + 1):
            # Clear all axes
            self.cy_run(plot_interval)
            for i in range(5):
                for j in range(5):
                    if not (i < 2 and j < 2):  # Skip the plot area
                        axs[i, j].clear()
            # Redraw the initial objects
            for (i, j), objects in initial_objects.items():
                # Redraw lines
                for line_data in objects["lines"]:
                    axs[i, j].plot(
                        line_data["xdata"],
                        line_data["ydata"],
                        color=line_data["color"],
                        linewidth=line_data["linewidth"],
                        label=line_data["label"],
                    )
                # Redraw scatter plots
                for scatter_data in objects["scatter"]:
                    axs[i, j].scatter(
                        scatter_data["offsets"][:, 0],
                        scatter_data["offsets"][:, 1],
                        c=scatter_data["color"],
                        marker="x",  # Force 'x' marker for initial data
                        s=scatter_data["size"],
                        label=scatter_data["label"],
                    )
            # Get current inventory and update abundance history
            inventory = self.get_inventory()
            current_abundance = species_abundance(inventory["species"], max_species)
            abundance_history.append(current_abundance)
            # Plot current state with different marker style
            bins_dict = display(
                self.get_inventory(),
                self.params["neighborhood_radius"],
                threshold=threshold,
                axs=axs,
                initial=False,
                bins_dict=bins_dict,
                abundances=abundance_history,
                params=self.params,
            )
            fig.suptitle(f"Step {step_num}")
            for i in range(5):
                for j in range(5):
                    if not (i < 2 and j < 2) and axs[i, j].get_legend() is None:
                        handles, labels = axs[i, j].get_legend_handles_labels()
                        if handles:
                            axs[i, j].legend()
            fig.canvas.draw()
            fig.canvas.flush_events()
            plt.pause(0.01)  # Small pause to allow GUI to update
        plt.ioff()
        plt.show() 
[docs]
    def optimize_parameters(
        self,
        start_params: np.ndarray,
        lower: np.ndarray,
        upper: np.ndarray,
        rhobeg: float,
        rhoend: float,
        param_names: list[str],
        stat_funcs: list[tuple[Callable, dict]],
        stat_indices: list[slice | list[int]] | None = None,
        n_init: int = 50,
        n_iter: int = 10,
        iter_steps: int = 5,
        reference_inventory: np.ndarray | None = None,
    ):
        """
        Generic optimizer for model parameters to fit arbitrary statistics using log_likelihood.
        Parameters
        ----------
        start_params : np.ndarray
            Initial guess for parameters.
        lower : np.ndarray
            Lower bounds for parameters.
        upper : np.ndarray
            Upper bounds for parameters.
        param_names : list[str]
            Names of parameters to optimize (order matches start_params).
        stat_funcs : list of (func, func_kwargs)
            Each tuple: (statistic function, kwargs for function).
        stat_indices : list of slices or lists, optional
            For each stat_func, which indices to use (default: use all).
        n_init : int
            Initial run steps.
        n_iter : int
            Number of iterations for averaging.
        iter_steps : int
            Steps per iteration.
        """
        if reference_inventory:
            ref_inv = reference_inventory
        elif reference_inventory is None and self.use_initial_as_reference:
            ref_inv = self.reference_inventory
        else:
            raise ValueError("No reference inventory available for optimization.")
        # Prepare reference statistics from initial inventory
        reference = []
        for i, (func, func_kwargs) in enumerate(stat_funcs):
            vals = func(ref_inv, **func_kwargs)
            if stat_indices and stat_indices[i]:
                vals = vals[stat_indices[i]]
            reference.append(vals)
        def objective(params):
            update_dict = {k: v for k, v in zip(param_names, params)}
            self.update_params(update_dict)
            self.cy_run(n_init)
            sim_stats = [[] for _ in stat_funcs]
            for _ in range(n_iter):
                self.cy_run(iter_steps)
                inventory = self.get_inventory()
                for idx, (func, func_kwargs) in enumerate(stat_funcs):
                    vals = func(inventory, **func_kwargs)
                    if stat_indices and stat_indices[idx]:
                        vals = vals[stat_indices[idx]]
                    sim_stats[idx].append(vals)
            sim_means = []
            sim_stds = []
            for stat_list in sim_stats:
                arr = np.array(stat_list)
                sim_means.append(np.mean(arr, axis=0))
                sim_stds.append(np.std(arr, axis=0))
            return log_likelihood(
                np.concatenate(sim_means),
                np.concatenate(sim_stds),
                np.concatenate(reference),
            )
        result = pybobyqa.solve(
            objective,
            start_params,
            bounds=(lower, upper),
            objfun_has_noise=True,
            rhobeg=rhobeg,
            rhoend=rhoend,
            print_progress=True,
        )
        return result