import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray
from spatiocoexistence.py_tools import (
    get_CI_CS_histogram,
    get_CI_HS_histogram,
    get_reduced_growth_histogram,
    get_survival_histogram,
    get_recruitment_histogram,
    reduced_growth,
    survival,
    reduced_recruitment,
    size_class,
    mean_count_size_class,
    get_BA_and_k,
)
from spatiocoexistence.crowding import crowding_indices
from spatiocoexistence.tools import count_saplings
[docs]
def plot_inventory_on_ax(
    ax: plt.Axes,
    inventory: np.ndarray,
    filter: float = 1.0,
    scale: float = 2.0,
    params: dict | None = None,
):
    """
    Plot the inventory scatter plot with axis settings based on data/parameters.

    The axes are cleared first, so repeated calls redraw rather than overlay.
    Trees with ``dbh >= filter`` are drawn at their ``(x, y)`` position, with
    marker area ``dbh * scale`` and a colour taken from ``species`` (``tab10``
    colormap). A label in the top-left corner reports the total number of
    records and how many of them passed the filter.

    Parameters:
    -----------
    ax: plt.Axes
        The matplotlib axes to plot on
    inventory: np.ndarray
        The inventory records; the fields "x", "y", "dbh" and "species" are read
    filter: float
        Minimum DBH to include in the plot
    scale: float
        Scaling factor for the point sizes
    params: dict, optional
        Model parameters. When it contains "quadrat_dim_x", "quadrat_dim_y" and
        "cell_size", the axis extent is ``quadrat_dim_* * cell_size``; otherwise
        the extent is 5 % beyond the largest coordinate found in the data.

    Returns:
    --------
    The artist returned by ``ax.scatter``, or None when no tree passes the filter.
    """
    ax.clear()
    # Filter inventory by dbh
    mask = inventory["dbh"] >= filter

    # Determine plot dimensions based on parameters or data
    dim_keys = ("quadrat_dim_x", "quadrat_dim_y", "cell_size")
    if params is not None and all(key in params for key in dim_keys):
        x_max = params["quadrat_dim_x"] * params["cell_size"]
        y_max = params["quadrat_dim_y"] * params["cell_size"]
    elif len(inventory) > 0:
        # Use data if parameters not available (5 % margin)
        x_max = np.max(inventory["x"]) * 1.05
        y_max = np.max(inventory["y"]) * 1.05
    else:
        # np.max would raise on an empty inventory
        x_max = y_max = 0.0
    # A zero, negative or NaN extent would give singular axis limits;
    # fall back to a unit extent instead. (`not v > 0` also catches NaN.)
    if not x_max > 0:
        x_max = 1.0
    if not y_max > 0:
        y_max = 1.0

    # Set basic plot properties
    ax.set_title("Inventory", fontsize=14, pad=10)
    ax.set_xlabel("x", fontsize=12, labelpad=10)
    ax.set_ylabel("y", fontsize=12, labelpad=10)
    # Force axis limits based on dimensions from parameters or data
    ax.set_xlim(0, x_max)
    ax.set_ylim(0, y_max)

    # Explicit, evenly spaced ticks with integer labels
    n_ticks = 6
    x_ticks = np.linspace(0, x_max, n_ticks)
    y_ticks = np.linspace(0, y_max, n_ticks)
    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    # Round rather than truncate: int() would label a tick at 199.9999 as "199".
    ax.set_xticklabels([f"{v:.0f}" for v in x_ticks], fontdict={"fontsize": 10})
    ax.set_yticklabels([f"{v:.0f}" for v in y_ticks], fontdict={"fontsize": 10})
    ax.grid(True, linestyle="--", alpha=0.3)
    # Keep the limits set above even after adding artists
    ax.autoscale(enable=False)

    if np.any(mask):
        gx_f = inventory["x"][mask]
        gy_f = inventory["y"][mask]
        dbh_f = inventory["dbh"][mask]
        species_f = inventory["species"][mask]
        sc = ax.scatter(gx_f, gy_f, dbh_f * scale, c=species_f, cmap="tab10", alpha=0.7)
    else:
        # No trees meet filter criterion - draw empty plot with warning
        ax.text(
            x_max / 2,
            y_max / 2,
            f"No trees with DBH ≥ {filter}",
            ha="center",
            va="center",
            fontsize=14,
            color="red",
            bbox=dict(facecolor="white", alpha=0.8, edgecolor="red"),
        )
        sc = None

    # Add count information (total records vs. records passing the filter)
    num_individuals = len(inventory)
    num_visible = int(np.count_nonzero(mask))
    ax.text(
        0.02,
        0.98,
        f"N: {num_individuals} (Visible: {num_visible})",
        transform=ax.transAxes,
        fontsize=12,
        va="top",
        ha="left",
        bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
    )
    # Artists with zorder below -1 are rasterized in vector output
    ax.set_rasterization_zorder(-1)
    return sc
[docs]
def plot_BA_dist(dbh: NDArray[np.float64], n_bins: int = 25) -> None:
    """
    Show a histogram of basal areas on logarithmically spaced bins.

    Basal area is computed as ``pi * (dbh / 2) ** 2``. Only finite, strictly
    positive basal areas are plotted: zero or NaN values cannot be placed on a
    log axis and would turn the log-spaced bin edges into -inf/NaN.

    Parameters:
    -----------
    dbh: NDArray[np.float64]
        Diameters at breast height.
    n_bins: int
        Number of histogram bins (``n_bins + 1`` log-spaced edges spanning the
        smallest to the largest plotted basal area).

    Raises:
    -------
    ValueError
        If no finite, positive basal area remains to be plotted.
    """
    basal_area = (np.asarray(dbh, dtype=float) / 2) ** 2 * np.pi
    basal_area = basal_area[np.isfinite(basal_area) & (basal_area > 0)]
    if basal_area.size == 0:
        raise ValueError(
            "plot_BA_dist needs at least one finite, positive dbh value to plot."
        )
    low, high = basal_area.min(), basal_area.max()
    if low == high:
        # A single distinct value gives a zero-width range; widen it by a
        # factor of two on each side so the log-spaced edges stay increasing.
        low, high = low / 2, high * 2
    log_bins = np.logspace(np.log10(low), np.log10(high), n_bins + 1)
    plt.figure()
    plt.hist(basal_area, bins=log_bins)
    plt.xscale("log")
    plt.show()
[docs]
def display(
    inventory: np.ndarray,
    radius: float,
    threshold: int = 50,
    filter: float = 5.0,
    abundances: np.ndarray | list | None = None,
    axs: plt.axes = None,
    initial: bool = False,
    bins_dict: dict[str, np.ndarray] | None = None,
    params: dict | None = None,
    trees: str = "rep",
) -> dict:
    """
    Display forest inventory statistics.
    If initial=True, use 'x' markers and red lines for all plots.
    Otherwise, use 'o' markers and standard colors.
    """
    status = inventory["status"]
    dbh = inventory["dbh"]
    # Choose Recruit / Reproductive / Sapling / Dead
    if trees == "rep":
        focus = np.asarray(dbh) >= 1  # reproductive_size
    elif trees == "rec":
        focus = np.asarray(status) == 1
    elif trees == "sap":
        focus = np.asarray(dbh) < 1  # reproductive_size
    elif trees == "dead":
        focus = np.asarray(status) == -1
    x = inventory["x"][focus]
    y = inventory["y"][focus]
    species = inventory["species"][focus]
    dbh = inventory["dbh"][focus]
    CI_CS = inventory["CI_CS"][focus]
    CI_HS = inventory["CI_HS"][focus]
    CI_CS_d = inventory["CI_CS_d"][focus]
    CI_HS_d = inventory["CI_HS_d"][focus]
    status = inventory["status"][focus]
    # ToDo, use focus for all values plotted!
    CI_C, CI_H, _, _ = crowding_indices(x, y, species, status, radius)
    BA_ff, BA_fh, n_BA_ff, n_BA_fh, k_ff, k_fh, abundance_con, abundance_het = (
        get_BA_and_k(
            CI_CS, CI_C, CI_HS, CI_H, species, dbh, np.max(x), np.max(y), radius
        )
    )
    if axs is None:
        fig, axs = plt.subplots(5, 5, figsize=(22, 10))
    else:
        fig = axs[0, 0].figure
    # Create bins_dict if not provided (for initial call)
    if bins_dict is None:
        bins_dict = {
            "CI_CS": np.linspace(0, np.percentile(CI_CS, 99), 25),
            "CI_HS": np.linspace(0, np.percentile(CI_HS, 99), 50),
            "reduced_growth": np.linspace(0, 1, 31),
            "survival": np.linspace(0, 1, 31),
            "recruitment": np.linspace(0, 1, 31),
            "size_class": size_class(),
            "abundance_class": np.arange(0, 33) * 37,  # Added for SAD
        }
    # Inventory scatter plot - Create a dedicated subplot with more space
    gs = axs[0, 0].get_gridspec()
    for i in range(2):
        for j in range(2):
            ax = axs[i, j]
            if ax in fig.axes:
                fig.delaxes(ax)
    big_ax = fig.add_subplot(gs[0:2, 0:2])
    plot_inventory_on_ax(big_ax, inventory, filter=filter, scale=7.5, params=params)
    # Create Text Field
    num_total = len(inventory)
    num_recruits = np.sum(inventory["status"] == 1)
    num_saplings = count_saplings(inventory["dbh"], inventory["status"], 1)
    d_sap = len(inventory[(inventory["status"] == -2) & (inventory["dbh"] < 1)])
    d_rep = len(inventory[(inventory["status"] == -2) & (inventory["dbh"] >= 1)])
    num_rep = np.sum(inventory["dbh"] >= 1)
    fig.text(
        0.875,
        0.15,
        f"Total: {num_total}\nRecruits: {num_recruits}\nSaplings: {num_saplings}\nDead_rep: {d_rep}\nDead_sap: {d_sap}\nReproductive: {num_rep}",
        fontsize=13,
        va="top",
        ha="left",
        bbox=dict(facecolor="white", alpha=0.8, edgecolor="black"),
    )
    # Apply tight layout to the figure to avoid overlapping
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    # Add individual count info
    num_individuals = len(inventory)
    big_ax.text(
        0.02,
        0.98,
        f"N: {num_individuals}",
        transform=big_ax.transAxes,
        fontsize=12,
        va="top",
        ha="left",
        bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
    )
    # Ensure proper spacing and formatting
    fig.tight_layout()
    if abundances is not None:
        if isinstance(abundances, list):
            abundances = np.vstack(abundances)
    # Abundances at run_time
    if axs is not None and abundances is not None:
        timesteps = np.arange(0, len(abundances)) * 50 * 5
        for species_id in range(abundances.shape[1]):
            axs[4, 0].plot(
                timesteps,
                abundances[:, species_id],
                alpha=0.7,
            )
        axs[4, 0].set_xlabel("Year")
        axs[4, 0].set_ylabel("Abundance")
        axs[4, 0].set_title("Abundance per species (time series)")
        # Add sum of abundances as a legend/text
        total_abundance = int(abundances[-1].sum())
        axs[4, 0].text(
            0.98,
            0.98,
            f"Total: {total_abundance}",
            transform=axs[4, 0].transAxes,
            fontsize=12,
            va="top",
            ha="right",
            bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
        )
    mask = abundance_con >= threshold
    # Summary statistics scatter plots with marker distinction
    scatter_marker = "x" if initial else "o"
    scatter_color = "red" if initial else None
    scatter_label = "Initial" if initial else "Current"
    scatter_size = 50 if initial else 30
    scatter_alpha = 1.0 if initial else 0.7
    # k_ff vs abundance
    axs[3, 0].scatter(
        np.log(abundance_con[mask]),
        np.log(k_ff[mask]),
        marker=scatter_marker,
        color=scatter_color,
        s=scatter_size,
        alpha=scatter_alpha,
        label=scatter_label,
    )
    axs[3, 0].set_ylabel("log(k_ff)")
    axs[3, 0].set_xlabel("log(Abundance)")
    # k_fh vs abundance
    axs[3, 1].scatter(
        np.log(abundance_con[mask]),
        np.log(k_fh[mask]),
        marker=scatter_marker,
        color=scatter_color,
        s=scatter_size,
        alpha=scatter_alpha,
        label=scatter_label,
    )
    axs[3, 1].set_ylabel("log(k_fh)")
    axs[3, 1].set_xlabel("log(Abundance)")
    # BA_ff vs abundance
    axs[3, 2].scatter(
        np.log(abundance_con[mask]),
        np.log(BA_ff[mask]),
        marker=scatter_marker,
        color=scatter_color,
        s=scatter_size,
        alpha=scatter_alpha,
        label=scatter_label,
    )
    axs[3, 2].set_ylabel("log(BA_ff)")
    axs[3, 2].set_xlabel("log(Abundance)")
    # BA_fh vs abundance
    axs[3, 3].scatter(
        np.log(abundance_con[mask]),
        np.log(BA_fh[mask]),
        marker=scatter_marker,
        color=scatter_color,
        s=scatter_size,
        alpha=scatter_alpha,
        label=scatter_label,
    )
    axs[3, 3].set_ylabel("log(BA_fh)")
    axs[3, 3].set_xlabel("log(Abundance)")
    # Cumulative k_ff with line style distinction
    sorted_k_ff = np.sort(k_ff[mask])
    cumulative = np.arange(1, len(sorted_k_ff) + 1) / len(sorted_k_ff)
    line_style = "-"
    line_color = "red" if initial else "blue"
    line_width = 2 if initial else 1.5
    axs[0, 2].plot(
        np.log(sorted_k_ff),
        cumulative,
        linestyle=line_style,
        color=line_color,
        linewidth=line_width,
        label=scatter_label,
    )
    axs[0, 2].set_title("Cumulative k_ff")
    axs[0, 2].set_xlabel("log(k_ff)")
    axs[0, 2].set_ylabel("Cumulative fraction")
    # Cumulative k_fh with line style distinction
    sorted_k_fh = np.sort(k_fh[mask])
    cumulative_fh = np.arange(1, len(sorted_k_fh) + 1) / len(sorted_k_fh)
    line_style = "-"
    line_color = "red" if initial else "green"
    line_width = 2 if initial else 1.5
    axs[1, 2].plot(
        np.log(sorted_k_fh),
        cumulative_fh,
        linestyle=line_style,
        color=line_color,
        linewidth=line_width,
        label=scatter_label,
    )
    axs[1, 2].set_title("Cumulative k_fh")
    axs[1, 2].set_xlabel("log(k_fh)")
    axs[1, 2].set_ylabel("Cumulative fraction")
    if axs[1, 2].get_legend() is None:
        axs[1, 2].legend()
    # CI_CS Histogram or Line Plot
    hist_cs = get_CI_CS_histogram(inventory, bins_dict["CI_CS"])
    bin_edges_cs = bins_dict["CI_CS"]
    bin_centers_cs = (bin_edges_cs[:-1] + bin_edges_cs[1:]) / 2
    if initial:
        axs[0, 3].plot(bin_centers_cs, hist_cs, "r-", linewidth=2, label="Initial")
    else:
        axs[0, 3].hist(
            CI_CS,
            bins=bins_dict["CI_CS"],
            color="mediumorchid",
            edgecolor="black",
            label="Current",
        )
    axs[0, 3].set_title("CI_CS distribution")
    axs[0, 3].set_xlabel("CI_CS")
    axs[0, 3].set_ylabel("Frequency")
    axs[0, 3].set_xlim(left=0, right=np.max(bins_dict["CI_CS"]))
    if axs[0, 3].get_legend() is None:
        axs[0, 3].legend()
    # CI_HS Histogram or Line Plot
    hist_hs = get_CI_HS_histogram(inventory[focus], bins_dict["CI_HS"])
    bin_edges_hs = bins_dict["CI_HS"]
    bin_centers_hs = (bin_edges_hs[:-1] + bin_edges_hs[1:]) / 2
    if initial:
        axs[1, 3].plot(bin_centers_hs, hist_hs, "r-", linewidth=2, label="Initial")
    else:
        axs[1, 3].hist(
            CI_HS,
            bins=bins_dict["CI_HS"],
            color="goldenrod",
            edgecolor="black",
            label="Current",
        )
    axs[1, 3].set_title("CI_HS distribution")
    axs[1, 3].set_xlabel("CI_HS")
    axs[1, 3].set_ylabel("Frequency")
    axs[1, 3].set_xlim(left=0, right=np.max(bins_dict["CI_HS"]))
    # Reduced Growth Histogram or Line Plot
    hist_gr = get_reduced_growth_histogram(
        inventory[focus], bins_dict["reduced_growth"]
    )
    bin_edges_gr = bins_dict["reduced_growth"]
    bin_centers_gr = (bin_edges_gr[:-1] + bin_edges_gr[1:]) / 2
    if initial:
        axs[2, 0].plot(bin_centers_gr, hist_gr, "r-", linewidth=2, label="Initial")
    else:
        axs[2, 0].hist(
            reduced_growth(CI_CS, CI_HS, beta_gr=0.084),
            bins=bins_dict["reduced_growth"],
            color="skyblue",
            edgecolor="black",
            label="Current",
        )
    axs[2, 0].set_title("Reduced Growth")
    axs[2, 0].set_xlabel("reduced growth")
    axs[2, 0].set_ylabel("Frequency")
    # Reduced Survival Histogram or Line Plot
    hist_surv = get_survival_histogram(inventory[focus], bins_dict["survival"])
    bin_edges_surv = bins_dict["survival"]
    bin_centers_surv = (bin_edges_surv[:-1] + bin_edges_surv[1:]) / 2
    if initial:
        axs[2, 1].plot(bin_centers_surv, hist_surv, "r-", linewidth=2, label="Initial")
    else:
        axs[2, 1].hist(
            survival(CI_CS, CI_HS, CI_CS_d, CI_HS_d, dbh),
            bins=bins_dict["survival"],
            color="salmon",
            edgecolor="black",
            label="Current",
        )
    axs[2, 1].set_title("Reduced Survival")
    axs[2, 1].set_xlabel("reduced survival")
    axs[2, 1].set_ylabel("Frequency")
    # Reduced Recruitment Histogram or Line Plot
    hist_rec = get_recruitment_histogram(inventory[focus], bins_dict["recruitment"])
    bin_edges_rec = bins_dict["recruitment"]
    bin_centers_rec = (bin_edges_rec[:-1] + bin_edges_rec[1:]) / 2
    if initial:
        axs[2, 2].plot(bin_centers_rec, hist_rec, "r-", linewidth=2, label="Initial")
    else:
        axs[2, 2].hist(
            reduced_recruitment(CI_CS, CI_HS, CI_CS_d, CI_HS_d, dbh, species),
            bins=bins_dict["recruitment"],
            color="grey",
            edgecolor="black",
            label="Current",
        )
    axs[2, 2].set_title("Reduced Recruitment")
    axs[2, 2].set_xlabel("reduced recruitment")
    axs[2, 2].set_ylabel("Frequency")
    # Size-class plots
    sc = bins_dict["size_class"]
    # Mean crowding by size class
    mean_crowding = mean_count_size_class(sc, CI_CS, dbh * 10)
    valid_indices = mean_crowding > 0
    x_values = np.log(sc[:-1][valid_indices])
    y_values = mean_crowding[valid_indices]
    if initial:
        axs[0, 4].plot(x_values, y_values, "r-", linewidth=2, label="Initial")
    else:
        axs[0, 4].bar(
            x_values,
            y_values,
            edgecolor="black",
            linewidth=1.5,
            width=0.175,
            label="Current",
        )
    axs[0, 4].set_xlabel("size class")
    axs[0, 4].set_ylabel("Mean CI_CS")
    # Mean heterospecific crowding by size class
    mean_crowding = mean_count_size_class(sc, CI_HS, dbh * 10)
    valid_indices = mean_crowding > 0
    x_values = np.log(sc[:-1][valid_indices])
    y_values = mean_crowding[valid_indices]
    if initial:
        axs[1, 4].plot(x_values, y_values, "r-", linewidth=2, label="Initial")
    else:
        axs[1, 4].bar(
            x_values,
            y_values,
            edgecolor="black",
            linewidth=1.5,
            width=0.175,
            label="Current",
        )
    axs[1, 4].set_xlabel("size class")
    axs[1, 4].set_ylabel("Mean CI_HS")
    # Size distribution
    counts = mean_count_size_class(sc, dbh * 10, dbh * 10, count=True)
    valid_indices = counts > 0
    x_values = np.log(sc[:-1][valid_indices])
    y_values = np.log(counts[valid_indices])
    if initial:
        axs[2, 4].plot(x_values, y_values, "r-", linewidth=2, label="Initial")
    else:
        axs[2, 4].bar(
            x_values,
            y_values,
            edgecolor="black",
            linewidth=1.5,
            width=0.175,
            label="Current",
        )
    axs[2, 4].set_xlabel("dbh size class")
    axs[2, 4].set_ylabel("Frequency")
    axs[2, 4].set_title("Size distribution")
    # Mortality by size class
    surv = mean_count_size_class(
        sc, survival(CI_CS, CI_HS, CI_CS_d, CI_HS_d, dbh, reduced=False), dbh * 10
    )
    valid_indices = surv > 0
    x_values = np.log(sc[:-1][valid_indices])
    y_values = surv[valid_indices]
    if initial:
        axs[3, 4].plot(x_values, y_values, "r-", linewidth=2, label="Initial")
    else:
        axs[3, 4].bar(
            x_values,
            y_values,
            edgecolor="black",
            linewidth=1.5,
            width=0.175,
            label="Current",
        )
    axs[3, 4].set_xlabel("dbh size class")
    axs[3, 4].set_ylabel("Mortality")
    # Add Species Abundance Distribution (SAD) plot
    unique_species = np.unique(species, return_counts=True)
    species_counts = unique_species[1]
    abundance_class = bins_dict["abundance_class"]
    if initial:
        # For initial distribution, use line plot
        sad_hist, sad_bins = np.histogram(
            species_counts,
            bins=abundance_class,
            weights=np.ones_like(species_counts) * 1 / len(unique_species[0]),
        )
        sad_bin_centers = (sad_bins[:-1] + sad_bins[1:]) / 2
        axs[2, 3].plot(
            sad_bin_centers,
            sad_hist,
            "r-",
            linewidth=2,
            label="Initial",
        )
    else:
        # For regular display, use histogram with same bins
        axs[2, 3].hist(
            species_counts,
            bins=abundance_class,
            weights=np.ones_like(species_counts) * 1 / len(unique_species[0]),
            color="forestgreen",
            edgecolor="black",
            label="Current",
        )
    axs[2, 3].set_title("Species Abundance Distribution (SAD)")
    axs[2, 3].set_xlabel("Abundance class")
    axs[2, 3].set_ylabel("Fraction of species")
    if axs[2, 3].get_legend() is None:
        axs[2, 3].legend()
    if axs is None:
        plt.tight_layout()
        plt.show()
    else:
        plt.tight_layout()
        axs[0, 0].figure.canvas.draw_idle()
    return bins_dict