transformer_jet_tagging
Submodules
preprocess module
dataset module
model module
train module
evaluate module
plotting module
plotting.py
Visualization module for the GN2 jet flavour tagging pipeline.
Generates plots of input and output variables, reading directly from the HDF5 file and/or a DataLoader.
- transformer_jet_tagging.plotting._load_jet_data(h5_path: str | Path, jet_indices: ndarray, jet_vars: list[str], jet_flavour: str, jet_flavour_map: dict[int, int]) dict[str, ndarray][source]
Load jet-level variables and labels from HDF5 for a subset of jets.
- Parameters:
h5_path (str | Path) – Path to HDF5 file.
jet_indices (np.ndarray) – Sorted jet indices to read.
jet_vars (list) – Jet variable names.
jet_flavour (str) – Name of the flavour field in HDF5.
jet_flavour_map (dict) – Mapping from raw label to class index.
- Returns:
var_name (np.ndarray): shape (n_jets,), one entry for each jet variable.
“label” (np.ndarray): shape (n_jets,), integer class index for each jet.
- Return type:
dict
- Raises:
FileNotFoundError – if the specified file does not exist.
KeyError – if expected datasets or fields are missing in the HDF5 file.
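The raw label to class index mapping carried by jet_flavour_map can be sketched as below. This is a minimal illustration, not the module's implementation: the helper name map_flavour_labels and the example raw labels (5, 4, 0, 15) are assumptions, and unmapped labels are set to -1 to match the convention documented for plot_label_distribution.

```python
import numpy as np


def map_flavour_labels(raw_labels, jet_flavour_map):
    """Hypothetical helper: map raw flavour labels to class indices.

    Labels that are absent from the map are assigned -1.
    """
    raw_labels = np.asarray(raw_labels)
    mapped = np.full(raw_labels.shape, -1, dtype=np.int64)
    for raw, cls in jet_flavour_map.items():
        mapped[raw_labels == raw] = cls
    return mapped


# Assumed example mapping; the real raw labels depend on the HDF5 file.
example_map = {5: 0, 4: 1, 0: 2, 15: 3}
raw = np.array([5, 4, 0, 15, 7])
mapped = map_flavour_labels(raw, example_map)
```

Here the last jet carries a raw label that is not in the map, so it ends up with class index -1.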
- transformer_jet_tagging.plotting._load_track_data(h5_path: str | Path, jet_indices: ndarray, track_vars: list[str], jet_flavour: str, jet_flavour_map: dict[int, int], max_jets: int = 50000) dict[str, ndarray][source]
Load valid track-level variables from HDF5, flattened across jets.
- Parameters:
h5_path (str | Path) – Path to HDF5 file.
jet_indices (np.ndarray) – Sorted jet indices.
track_vars (list) – Track variable names.
jet_flavour (str) – Flavour field name.
jet_flavour_map (dict) – Mapping from raw label to class index.
max_jets (int) – Cap on jets to read (memory guard).
- Returns:
var_name (np.ndarray): shape (n_tracks,), one entry for each track variable.
“label” (np.ndarray): shape (n_tracks,), integer class index of each track’s jet.
- Return type:
dict
- Raises:
FileNotFoundError – if the specified file does not exist.
KeyError – if expected datasets or fields are missing in the HDF5 file.
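Flattening valid tracks across jets, while keeping one label per track, might look like the sketch below. It assumes padded per-jet track arrays and a boolean validity mask; the helper name flatten_valid_tracks and the mask itself are hypothetical, since the actual HDF5 layout is not described here.

```python
import numpy as np


def flatten_valid_tracks(track_values, valid_mask, jet_labels):
    """Hypothetical helper: keep valid tracks and repeat each jet's label per track.

    track_values: (n_jets, max_tracks) padded array of one track variable.
    valid_mask:   (n_jets, max_tracks) boolean array, True for real tracks.
    jet_labels:   (n_jets,) integer class index of each jet.
    """
    track_values = np.asarray(track_values)
    valid_mask = np.asarray(valid_mask, dtype=bool)
    jet_labels = np.asarray(jet_labels)
    # Boolean indexing walks the array in row-major order, so tracks stay grouped by jet.
    flat_values = track_values[valid_mask]
    # Repeat each jet label once per valid track of that jet.
    flat_labels = np.repeat(jet_labels, valid_mask.sum(axis=1))
    return flat_values, flat_labels


values = np.array([[1.0, 2.0, 0.0], [3.0, 0.0, 0.0]])
valid = np.array([[True, True, False], [True, False, False]])
labels = np.array([0, 2])
flat_values, flat_labels = flatten_valid_tracks(values, valid, labels)
```

Both returned arrays have shape (n_tracks,), matching the return description above.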
- transformer_jet_tagging.plotting.plot_jet_variables(jet_data: dict[str, ndarray], jet_vars: list[str], output_dir: str | Path) None[source]
Plot per-flavour distributions of jet-level variables. pt is shown both raw (linear) and log-transformed.
- Parameters:
jet_data (dict) – Output of _load_jet_data().
jet_vars (list) – Variable names to plot.
output_dir (str | Path) – Directory where PNGs are saved.
- transformer_jet_tagging.plotting.plot_track_variables(track_data: dict[str, ndarray], track_vars: list[str], output_dir: str | Path, vars_per_page: int = 6) None[source]
Plot per-flavour distributions of track-level variables. Variables are split across multiple pages if needed.
- Parameters:
track_data (dict) – Output of _load_track_data().
track_vars (list) – Variable names to plot.
output_dir (str | Path) – Directory where PNGs are saved.
vars_per_page (int) – Max variables per figure (default 6).
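Splitting the variable list across pages can be sketched as a simple chunking step. The helper name split_into_pages is an assumption; the module may organise its pages differently.

```python
def split_into_pages(var_names, vars_per_page=6):
    """Hypothetical helper: split variable names into consecutive pages."""
    if vars_per_page < 1:
        raise ValueError("vars_per_page must be at least 1")
    return [
        var_names[start:start + vars_per_page]
        for start in range(0, len(var_names), vars_per_page)
    ]


# 14 placeholder variable names are split into pages of 6, 6 and 2.
pages = split_into_pages([f"var_{i}" for i in range(14)], vars_per_page=6)
```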
- transformer_jet_tagging.plotting.plot_label_distribution(labels: ndarray, output_dir: str | Path) None[source]
Plot the class-label distribution for a dataset split (train / val / test).
Unmapped jets (label -1) are silently ignored.
- Parameters:
labels (np.ndarray) – integer array of class indices. Jets with label -1 are dropped before plotting.
output_dir (str | Path) – directory where the PDF is saved.
- Raises:
ValueError – if labels is not a 1-D array.
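The counting step behind such a plot can be sketched as follows, including the 1-D check and the removal of -1 labels. The helper name count_labels is hypothetical and the plotting itself is omitted.

```python
import numpy as np


def count_labels(labels):
    """Hypothetical helper: count jets per class, ignoring unmapped jets (-1)."""
    labels = np.asarray(labels)
    if labels.ndim != 1:
        raise ValueError("labels must be a 1-D array")
    kept = labels[labels != -1]
    classes, counts = np.unique(kept, return_counts=True)
    return dict(zip(classes.tolist(), counts.tolist()))


counts = count_labels(np.array([0, 1, 1, -1, 2, 0, 0]))
```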
- transformer_jet_tagging.plotting._corr_matrix(data_dict, vars_list)[source]
Compute the correlation matrix for the specified variables. Non-finite values are replaced with the column mean before the correlation is computed.
- Parameters:
data_dict (dict) – dict of variable name to np.ndarray.
vars_list (list) – list of variable names to include in the matrix.
- Returns:
shape (len(vars_list), len(vars_list)), correlation matrix.
- Return type:
np.ndarray
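One way to realise the column-mean replacement followed by a Pearson correlation is sketched below. The function name corr_matrix_sketch is an assumption, as is the choice to compute each column mean over finite entries only; the module's own handling may differ.

```python
import numpy as np


def corr_matrix_sketch(data_dict, vars_list):
    """Sketch: Pearson correlation after filling non-finite values with column means."""
    columns = np.column_stack(
        [np.asarray(data_dict[name], dtype=float) for name in vars_list]
    )
    finite = np.isfinite(columns)
    # Column means over finite entries only (assumption); guard against empty columns.
    sums = np.where(finite, columns, 0.0).sum(axis=0)
    col_means = sums / np.maximum(finite.sum(axis=0), 1)
    filled = np.where(finite, columns, col_means)
    # rowvar=False treats each column as one variable.
    return np.atleast_2d(np.corrcoef(filled, rowvar=False))


data = {
    "pt": np.array([1.0, 2.0, 3.0, np.nan]),
    "eta": np.array([2.0, 4.0, 6.0, 8.0]),
}
corr = corr_matrix_sketch(data, ["pt", "eta"])
```

All arrays in data_dict must have the same length for the columns to stack.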
- transformer_jet_tagging.plotting._draw_heatmap(ax, corr, labels, title)[source]
Draw a heatmap of the correlation matrix with annotations.
- Parameters:
ax – matplotlib axis to draw on.
corr – 2D array of correlation coefficients.
labels – list of variable names for axes.
title – title of the plot.
- Returns:
im: image object from imshow (for colorbar).
- Return type:
matplotlib.image.AxesImage
- transformer_jet_tagging.plotting.plot_correlations(jet_data: dict[str, ndarray], track_data: dict[str, ndarray], jet_vars: list[str], track_vars: list[str], output_dir: str | Path) None[source]
Plot Pearson correlation matrices for jet and track variables.
- Parameters:
jet_data (dict) – Output of _load_jet_data().
track_data (dict) – Output of _load_track_data().
jet_vars (list) – Jet variable names.
track_vars (list) – Track variable names.
output_dir (str | Path) – Directory where PNGs are saved.
- transformer_jet_tagging.plotting.plot_statistics(h5_path: str | Path, jet_vars: list[str], track_vars: list[str], jet_flavour: str, jet_flavour_map: dict[int, int], jet_indices: ndarray, output_dir: str = 'outputs/plots', n_jets_track: int = 50000) None[source]
Generate all plots and save them to output_dir.
- Parameters:
h5_path (str | Path) – Path to HDF5 file.
jet_vars (list) – Jet variable names.
track_vars (list) – Track variable names.
jet_flavour (str) – Flavour field name in HDF5.
jet_flavour_map (dict) – Mapping from raw label to class index.
jet_indices (np.ndarray) – Jet indices to use (e.g. train_indices).
output_dir (str) – Directory for output PNGs.
n_jets_track (int) – Max jets for track plots (memory guard).
- transformer_jet_tagging.plotting.plot_learning_curves(history: dict[str, list[float]], output_dir: str | Path) None[source]
Plot training and validation loss curves together with the learning-rate (LR) schedule.
- Parameters:
history (dict) – keys "train_loss", "val_loss", "lr".
output_dir (str | Path) – Directory where the PDF is saved.
- transformer_jet_tagging.plotting.plot_score_distributions(proba: ndarray, labels: ndarray, output_dir: str | Path) None[source]
Plot softmax score distributions for each output node.
One figure is produced with one panel per class (P_b, P_c, P_u, P_tau). Inside each panel, the distribution is shown separately for every true-label class, allowing direct reading of signal/background separation.
- Parameters:
proba (np.ndarray) – shape (N, n_classes), softmax probabilities.
labels (np.ndarray) – shape (N,), true class labels.
output_dir (str | Path) – directory where the PDF is saved.
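The content of a single panel, one histogram of an output node per true-label class, can be sketched without any plotting code. The helper name score_histograms, the binning and the random toy inputs are assumptions made for illustration only.

```python
import numpy as np


def score_histograms(proba, labels, node, bins=50):
    """Sketch: density histogram of one output node for every true-label class."""
    edges = np.linspace(0.0, 1.0, bins + 1)
    hists = {}
    for cls in np.unique(labels):
        counts, _ = np.histogram(proba[labels == cls, node], bins=edges, density=True)
        hists[int(cls)] = counts
    return edges, hists


# Toy inputs: random probabilities that sum to one per jet, random true labels.
rng = np.random.default_rng(0)
toy_proba = rng.dirichlet(np.ones(4), size=1000)
toy_labels = rng.integers(0, 4, size=1000)
edges, hists = score_histograms(toy_proba, toy_labels, node=0)
```

Normalising each histogram to unit area makes classes with very different jet counts comparable within a panel.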
- transformer_jet_tagging.plotting.plot_discriminant(discriminant_scores: ndarray, labels: ndarray, discriminant_type: str, output_dir: str | Path) None[source]
Plot the distribution of a discriminant (D_b or D_c) per flavour class.
- Parameters:
discriminant_scores (np.ndarray) – discriminant values for all jets.
labels (np.ndarray) – true class labels.
discriminant_type (str) – name of the discriminant (e.g. “b” or “c”, used for axis label and filename).
output_dir (str | Path) – output directory.
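The definition of D_b is not given in this reference. A commonly used form in flavour tagging is a log-likelihood ratio of the b-jet probability over a weighted sum of the background probabilities, sketched below. The function name b_discriminant, the weights f_c and f_tau and their values are assumptions, and the column order (P_b, P_c, P_u, P_tau) is taken from the description of plot_score_distributions.

```python
import numpy as np


def b_discriminant(proba, f_c=0.2, f_tau=0.01, eps=1e-12):
    """Sketch of a log-likelihood-ratio b-tagging discriminant.

    f_c and f_tau are hypothetical background weights, not values from this package.
    """
    proba = np.asarray(proba, dtype=float)
    p_b, p_c, p_u, p_tau = proba[:, 0], proba[:, 1], proba[:, 2], proba[:, 3]
    background = f_c * p_c + f_tau * p_tau + (1.0 - f_c - f_tau) * p_u
    # eps guards against division by zero and log(0).
    return np.log((p_b + eps) / (background + eps))


toy_scores = b_discriminant(
    np.array([[0.7, 0.1, 0.1, 0.1], [0.1, 0.1, 0.7, 0.1]])
)
```

A b-like jet (first row) gets a positive score and a light-like jet (second row) a negative one.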
- transformer_jet_tagging.plotting.plot_confusion_matrix(labels: ndarray, preds: ndarray, output_dir: str | Path) None[source]
Plot and save a normalised confusion matrix.
- Parameters:
labels (np.ndarray) – true class labels.
preds (np.ndarray) – predicted class labels.
output_dir (str | Path) – output directory.
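The normalisation convention is not stated here. The sketch below assumes row normalisation, so that each row (true class) sums to one; the helper name normalised_confusion is hypothetical.

```python
import numpy as np


def normalised_confusion(labels, preds, n_classes):
    """Sketch: confusion matrix normalised per true class (rows sum to 1)."""
    labels = np.asarray(labels)
    preds = np.asarray(preds)
    cm = np.zeros((n_classes, n_classes), dtype=float)
    # Accumulate one count per (true, predicted) pair, including repeated pairs.
    np.add.at(cm, (labels, preds), 1.0)
    row_sums = cm.sum(axis=1, keepdims=True)
    # Rows without any jets are left as zeros instead of dividing by zero.
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)


cm = normalised_confusion(
    labels=[0, 0, 1, 1, 1, 2], preds=[0, 1, 1, 1, 0, 2], n_classes=3
)
```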
- transformer_jet_tagging.plotting._roc_rejection(scores: ndarray, labels: ndarray, signal_class: int, bg_class: int, n_points: int = 200, eff_min: float = 0.5) tuple[ndarray, ndarray][source]
Calculate signal efficiency and background rejection for ROC curve.
- Parameters:
scores (np.ndarray) – Discriminant scores for all jets.
labels (np.ndarray) – True class labels for all jets.
signal_class (int) – Class index of the signal (e.g. b-jets).
bg_class (int) – Class index of the background (e.g. c-jets).
n_points (int) – Number of points on the ROC curve.
eff_min (float) – Minimum signal efficiency to include (default 0.5).
- Returns:
eff (np.ndarray): Signal efficiency values.
rej (np.ndarray): Background rejection values (1 / bg efficiency).
- Return type:
tuple[np.ndarray, np.ndarray]
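A possible way to compute these curves is to scan target signal efficiencies, derive the matching score threshold from the signal distribution, and invert the background efficiency at that threshold. The sketch below follows that approach under the assumption that higher scores are more signal-like; the function name roc_rejection_sketch and the Gaussian toy data are hypothetical.

```python
import numpy as np


def roc_rejection_sketch(scores, labels, signal_class, bg_class,
                         n_points=200, eff_min=0.5):
    """Sketch: signal efficiency versus background rejection (1 / bg efficiency)."""
    sig = scores[labels == signal_class]
    bkg = scores[labels == bg_class]
    eff = np.linspace(eff_min, 1.0, n_points)
    # A fraction eff of the signal lies above the (1 - eff) quantile of its scores.
    thresholds = np.quantile(sig, 1.0 - eff)
    bg_eff = np.array([(bkg >= t).mean() for t in thresholds])
    # Rejection is infinite where no background jet passes the cut.
    rej = np.divide(1.0, bg_eff, out=np.full_like(bg_eff, np.inf), where=bg_eff > 0)
    return eff, rej


# Toy data: signal scores shifted upwards relative to background scores.
rng = np.random.default_rng(1)
toy_scores = np.concatenate([rng.normal(2.0, 1.0, 5000), rng.normal(0.0, 1.0, 5000)])
toy_labels = np.concatenate([np.zeros(5000, dtype=int), np.ones(5000, dtype=int)])
eff, rej = roc_rejection_sketch(toy_scores, toy_labels, signal_class=0, bg_class=1)
```

Rejection falls as efficiency rises, because a looser cut lets more background through.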
- transformer_jet_tagging.plotting._plot_roc(scores: ndarray, labels: ndarray, signal_class: int, bg_classes: list[tuple[int, str, str]], discriminant_type: str, output_dir: str | Path) None[source]
Plot a ROC curve (signal efficiency vs background rejection).
- Parameters:
scores (np.ndarray) – discriminant scores for all jets.
labels (np.ndarray) – true class labels.
signal_class (int) – index of the signal class.
bg_classes (list) – list of (class_index, linestyle, legend_label) tuples.
discriminant_type (str) – name of the discriminant (e.g. “b” or “c”, used for axis label and filename).
output_dir (str | Path) – output directory.
utils module
Module contents
__init__.py
Initialization module for the transformer_jet_tagging package.