Add heterogeneity indices, printing indices/legend, fix index.html. #56
@@ -0,0 +1,171 @@

```python
from .sensitivity_indices import sensitivity_indices
```
Member: Import order matters (also sorting `m` before `n`, `p`); this import should go at the bottom, separated by a blank line, since it comes from our own library. The order is: std lib, 3rd-party deps, dependencies from our other packages, then the current package. Also prefer to always use absolute imports under our namespace.
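As a sketch, the grouping the comment describes (the namespaced import is commented out, since the project namespace isn't shown in this diff):

```python
# Suggested grouping: standard library first, then third-party dependencies,
# then our own packages, each group separated by a blank line.
import logging  # standard library

import matplotlib.pyplot as plt  # third-party
import numpy as np
import pandas as pd

# The absolute import under the project namespace would go last, e.g.:
# from <our_namespace>.sensitivity_indices import sensitivity_indices
```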
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


__all__ = ["heterogeneity_indices"]


def heterogeneity_indices(
    output: pd.Series,
    inputs: pd.DataFrame,
    split_variable: str | pd.Series,
    n_subdivisions: int | None = None,
    plot: bool = False,
) -> pd.DataFrame:
    """
    Compute sensitivity-based heterogeneity across subdivisions of a variable.
```
Comment on lines +16 to +17 — Member: A nit, but the format is actually to always write the summary as a single sentence on the first line, ending with a period. If too long, we make two sentences, leaving a blank line between them.
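A minimal sketch of that convention: the summary is one sentence on the first line, ending with a period, and any longer text follows a blank line (the function body is a stub for illustration):

```python
def heterogeneity_indices_stub():
    """Compute sensitivity-based heterogeneity across subdivisions of a variable.

    Any longer explanation goes here, after a blank line, rather than
    extending the summary sentence.
    """
    return None


# The first docstring line is a single sentence ending with a period.
summary_line = heterogeneity_indices_stub.__doc__.splitlines()[0]
```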
```python
    Parameters
    ----------
    output : pd.Series
        Model output vector.
    inputs : pd.DataFrame
        Input/feature matrix.
    split_variable : str or pd.Series
        Variable to split on. If string, must be a column in 'inputs'.
    n_subdivisions : int, optional
        Number of regions for continuous variables. Defaults to 4.
    plot : bool, default False
        If True, displays a stacked bar chart of regional sensitivities.

    Returns
    ----------
```
Member: Some nits.
```python
    summary : pd.Dataframe
        A summary of calculated heterogeneity indices.
    """
    y = pd.Series(output).reset_index(drop=True)
    X = pd.DataFrame(inputs).reset_index(drop=True)

    if isinstance(split_variable, str):
        if split_variable not in X.columns:
            raise ValueError(f"'{split_variable}' not found in inputs.")
        z = X[split_variable].reset_index(drop=True)
        split_name = split_variable
    else:
        z = pd.Series(split_variable).reset_index(drop=True)
        split_name = getattr(split_variable, "name", "split_variable")

    unique_vals = z.dropna().unique()
    n_unique = len(unique_vals)

    # Determine if variable is categorical/binary
    is_categorical = (
        pd.api.types.is_categorical_dtype(z)
        or pd.api.types.is_object_dtype(z)
        or pd.api.types.is_bool_dtype(z)
        or n_unique <= 2
    )

    if is_categorical:
        regions = z.astype("category")
    else:
        q = n_subdivisions if n_subdivisions is not None else 4
        try:
            regions = pd.qcut(z, q=q, duplicates="drop")
        except ValueError as e:
            raise ValueError(
                f"Failed to bin '{split_name}' into {q} quantiles: {e}"
            ) from e

    regional_profiles = []
    skipped = []

    for region in regions.cat.categories:
        mask = regions == region
        n_in_region = mask.sum()

        if n_in_region < 10:
            # Need enough samples for meaningful sensitivity indices
            skipped.append((region, n_in_region, "too few samples (< 10)"))
            continue

        X_sub = X.loc[mask]
        y_sub = y.loc[mask]

        # Skip if output has zero or near-zero variance in this region
        if y_sub.var() < 1e-12:
            skipped.append((region, n_in_region, "output variance ≈ 0"))
            continue

        try:
            res = sensitivity_indices(inputs=X_sub, output=y_sub)
            si_vals = np.asarray(res.si).ravel()

            # Guard against NaN/Inf from degenerate sensitivity computation
            if not np.all(np.isfinite(si_vals)):
                skipped.append((region, n_in_region, "non-finite SI values"))
                continue

            si_region = pd.Series(si_vals, index=X.columns, name=region)
            regional_profiles.append(si_region)

        except Exception as e:
            skipped.append((region, n_in_region, f"exception: {e}"))
            continue

    if skipped:
        print(
```
Member: Really not a fan of using any print. We should use logging if we must. People using the library have no control over prints, whereas logging they can properly capture and filter; that's the main reason.
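A sketch of the logging-based alternative the comment suggests (helper name and logger name are hypothetical):

```python
import logging

logger = logging.getLogger("heterogeneity_example")  # name assumed


def log_skipped_regions(skipped, split_name):
    # Emit skipped-region diagnostics via logging rather than print(),
    # so library users can capture, filter, or silence them.
    if not skipped:
        return 0
    logger.warning("Skipped %d region(s) of '%s'", len(skipped), split_name)
    for reg, n, reason in skipped:
        logger.warning("  - region=%r, n=%d, reason=%s", reg, n, reason)
    return len(skipped)


n_skipped = log_skipped_regions([("Q1", 4, "too few samples (< 10)")], "age")
```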
```python
            f"[heterogeneity_indices] Skipped {len(skipped)} region(s) of '{split_name}':"
        )
        for reg, n, reason in skipped:
            print(f"  - region={reg!r}, n={n}, reason={reason}")

    if len(regional_profiles) < 2:
        total_regions = len(regions.cat.categories)
        valid = len(regional_profiles)
        raise ValueError(
            f"Not enough valid subdivisions to compute heterogeneity: "
            f"{valid}/{total_regions} regions passed all checks for '{split_name}'.\n"
            f"Skipped regions:\n"
            + "\n".join(f"  {r!r}: n={n}, {reason}" for r, n, reason in skipped)
            + "\n\nTry: (1) reducing n_subdivisions, "
```
Member: Don't need the `+` there.
```python
            "(2) using a different split_variable, or "
            "(3) ensuring more samples per region."
        )

    regional_si = pd.concat(regional_profiles, axis=1)

    res_global = sensitivity_indices(inputs=X, output=y)
    overall_si = pd.Series(
        np.asarray(res_global.si).ravel(),
        index=X.columns,
        name="Overall_SI",
    )

    # Heterogeneity = 2 × population std dev across regions
    hetero_scores = 2 * regional_si.std(axis=1, ddof=0)
    total_hetero = hetero_scores.mean()

    hetero_col_name = f"Heterogeneity (across {split_name})"
    summary = pd.DataFrame(
        {"Overall_SI": overall_si, hetero_col_name: hetero_scores}
    ).sort_values(by=hetero_col_name, ascending=False)
    summary.loc["SUM / TOTAL"] = [overall_si.sum(), total_hetero]

    if plot:
```
Member: We could consider making a function for this so that we can call it externally; otherwise you are bound to use this big function. That's also good in terms of architecture, to separate concerns.
```python
        plot_order = summary.index[:-1]
        data_to_plot = regional_si.loc[plot_order].T

        cmap = plt.get_cmap("terrain")
        colors = [cmap(i) for i in np.linspace(0.05, 0.95, len(plot_order))]

        _ = data_to_plot.plot(
            kind="bar",
            stacked=True,
            figsize=(10, 6),
            color=colors,
            edgecolor="white",
            width=0.8,
        )

        plt.title(f"Sensitivity Profiles across {split_name}", fontsize=14)
```
Member: The best practice is really to work with `fig, ax` objects; I am not a fan of calling `show` directly vs just returning the figure.
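A sketch of the object-oriented pattern being suggested: build the plot on an explicit `fig, ax` pair and return the figure instead of calling `plt.show()` (function name hypothetical):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_regional_si(regional_si: pd.DataFrame, split_name: str):
    # Hypothetical standalone helper: the caller decides whether to
    # show, save, or embed the returned figure.
    fig, ax = plt.subplots(figsize=(10, 6))
    regional_si.T.plot(kind="bar", stacked=True, ax=ax, edgecolor="white")
    ax.set_title(f"Sensitivity Profiles across {split_name}")
    ax.set_ylabel("Variance Contribution")
    ax.set_xlabel(f"Regions of {split_name}")
    fig.tight_layout()
    return fig


demo_si = pd.DataFrame(np.ones((2, 3)) / 2, index=["x0", "x1"])
demo_fig = plot_regional_si(demo_si, "age")
```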
```python
        plt.ylabel("Variance Contribution", fontsize=12)
        plt.xlabel(f"Regions of {split_name}", fontsize=12)
        plt.legend(title="Input Variables", bbox_to_anchor=(1.05, 1), loc="upper left")
        plt.xticks(rotation=45)
        plt.grid(axis="y", linestyle="--", alpha=0.7)
        plt.tight_layout()
        plt.show()

    return summary
```
```diff
@@ -37,7 +37,9 @@ class SensitivityAnalysisResult:

 def sensitivity_indices(
-    inputs: pd.DataFrame | np.ndarray, output: pd.DataFrame | np.ndarray
+    inputs: pd.DataFrame | np.ndarray,
+    output: pd.DataFrame | np.ndarray,
+    print_indices: bool = False,
 ) -> SensitivityAnalysisResult:
     """Sensitivity indices.

@@ -50,6 +52,8 @@ def sensitivity_indices(
         Input variables.
     output : ndarray or DataFrame of shape (n_runs, 1)
         Target variable.
+    print_indices : bool, default False
+        If True, displays computed indices.

     Returns
     -------
@@ -97,11 +101,18 @@ def sensitivity_indices(
     """
     # Handle inputs conversion
     if isinstance(inputs, pd.DataFrame):
-        cat_columns = inputs.select_dtypes(["category", "O"]).columns
-        inputs[cat_columns] = inputs[cat_columns].apply(
-            lambda x: x.astype("category").cat.codes
-        )
+        var_names = inputs.columns.tolist()
+        cat_cols = inputs.select_dtypes(["category", "O"]).columns
+        if not cat_cols.empty:
+            inputs = inputs.copy()  # Avoid SettingWithCopyWarning
+            inputs[cat_cols] = inputs[cat_cols].apply(
+                lambda x: x.astype("category").cat.codes
+            )
         inputs = inputs.to_numpy()
     else:
         inputs = np.asarray(inputs)
+        # Fallback names if it's just a numpy array
+        var_names = [f"x{i}" for i in range(inputs.shape[1])]

     # Handle output conversion first, then flatten
     if isinstance(output, (pd.DataFrame, pd.Series)):

@@ -181,4 +192,14 @@ def sensitivity_indices(
     for k in range(n_factors):
         si[k] = foe[k] + (soe[:, k].sum() / 2)

+    if print_indices:
+        df_foe = pd.DataFrame(foe, index=var_names, columns=["First-order effect"])
+        df_soe = pd.DataFrame(soe, index=var_names, columns=var_names)
+        df_si = pd.DataFrame(si, index=var_names, columns=["Combined effect"])
+
+        df_indices = pd.concat([df_foe, df_soe, df_si], axis=1)
+        print(f"{'-'*69}")
+        print(df_indices)
+        print(f"{'-'*69}")
```
Comment on lines +201 to +203 — Member: Let's use one print with a single multi-line string instead of three separate calls.
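For instance, the three calls might collapse into one (a sketch with dummy data):

```python
import pandas as pd

df_indices = pd.DataFrame(
    {"First-order effect": [0.52, 0.31]}, index=["x0", "x1"]
)
sep = "-" * 69
# A single print of one formatted string, instead of three separate calls.
report = f"{sep}\n{df_indices}\n{sep}"
print(report)
```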
```diff

     return SensitivityAnalysisResult(si, foe, soe)
```
Member: Not a fan of adding this on just one parameter, as we are not consistent then. Either we allow this on both or on neither; this is more done on the user side IMO.

What we can consider in the future is to actually be dataframe-agnostic and use Narwhals. This is a great way to support multiple formats at once, as it supports the dataframe API.