In [None]:

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import re
import seaborn as sns
from matplotlib.transforms import offset_copy
import matplotlib.patches as mpatches

def column_mapping(cores):
    """Return the rename map from raw gem5 stat names to short column names.

    Parameters:
        cores: suffix inserted after ``cores`` / ``l1icaches`` / ``l1dcaches``
            in the stat paths (e.g. ``1`` -> ``board.processor.cores1.core.ipc``).
            Pass ``""`` for stat dumps whose paths carry no index.

    Returns:
        dict mapping each raw stat column name to its short alias. The
        ``finalTick``, ``benchmark`` and ``experiment`` columns map to
        themselves.
    """
    core = f'board.processor.cores{cores}.core.'
    hierarchy = 'board.cache_hierarchy.'

    pairs = [
        # Identification columns keep their names.
        ('finalTick', 'finalTick'),
        ('benchmark', 'benchmark'),
        ('experiment', 'experiment'),
        # Throughput and branch prediction.
        (core + 'ipc', 'ipc'),
        (core + 'branchPred.mispredictDueToBTBMiss_0::total', 'btb_mispredicts'),
        (core + 'branchPred.mispredictDueToPredictor_0::DirectCond', 'cond_mispredicts'),
        (core + 'branchPred.mispredictDueToPredictor_0::total', 'bp_mispredicts'),
        (core + 'branchPred.mispredicted_0::total', 'total_mispredicts'),
        # Top-down level 1. The aliases mix 'L1' and 'l1' casing; other code
        # in this file reads these exact spellings.
        (core + 'TopDownL1.frontendBound', 'TopDownL1_frontendBound'),
        (core + 'TopDownL1.badSpeculation', 'TopDownL1_badSpeculation'),
        (core + 'TopDownL1.backendBound', 'TopDownl1_backendBound'),
        (core + 'TopDownL1.retiring', 'TopDownl1_retiring'),
        # Rename-stage stall events and related counters.
        (core + 'rename.serializeStallCycles', 'serializingStallCycles'),
        (core + 'rename.ROBFullEvents', 'rob_full'),
        (core + 'rename.IQFullEvents', 'iq_full'),
        (core + 'rename.LQFullEvents', 'lq_full'),
        (core + 'rename.SQFullEvents', 'sq_full'),
        (core + 'rename.fullRegistersEvents', 'register_full'),
        (core + 'fetchStats0.icacheStallCycles', 'icache_stall_cycles'),
        (core + 'iew.memOrderViolationEvents', 'memOrderViolations'),
        # Top-down level 2.
        (core + 'TopDownL2_BackendBound.serializeStalls', 'TopDownL2_serializingStalls'),
        (core + 'TopDownL2_BackendBound.memoryBound', 'TopDownL2_memoryBound'),
        (core + 'TopDownL2_BackendBound.coreBound', 'TopDownL2_coreBound'),
        (core + 'TopDownL2_BadSpeculation.machineClears', 'TopDownL2_machineClears'),
        (core + 'TopDownL2_BadSpeculation.branchMissPredicts', 'TopDownL2_Mispredicts'),
        # Rates, instruction and cycle counts.
        (core + 'issueRate', 'issueRate'),
        (core + 'fetchStats0.fetchRate', 'fetchRate'),
        (core + 'commitStats0.numInsts', 'insts'),
        (core + 'cpi', 'cpi'),
        (core + 'idleCycles', 'idleCycles'),
        (core + 'numCycles', 'numCycles'),
        (core + 'robOccupancy::mean', 'ROB mean occupancy'),
        # L1 cache misses live under the cache hierarchy, not the core.
        (f'{hierarchy}l1icaches{cores}.demandMshrMisses::total', 'l1icacheMisses'),
        (f'{hierarchy}l1dcaches{cores}.demandMshrMisses::total', 'l1dcacheMisses'),
        # Distribution means and miscellaneous pipeline stats.
        (core + 'bac.ftSizeDist::mean', 'mean_ftSize'),
        (core + 'bac.ftNumber::mean', 'mean_ftNumber'),
        (core + 'numIssuedDist::mean', 'mean_issuedInsts'),
        (core + 'iew.dispInstDist::mean', 'mean_dispatchedInsts'),
        (core + 'iew.wbRate', 'wbRate'),
        (core + 'ftq.occupancy::mean', 'Mean FTQ Occupancy'),
        # NOTE(review): 'AexperimenttIssue' looks like a search-and-replace
        # accident; confirm the real stat name in the CSV headers.
        (core + 'instSquashedAexperimenttIssueDist::mean', 'mean_squashedInstsAtIssue'),
        # NOTE(review): alias is spelled 'Blockeb'; kept as-is for compatibility.
        (core + 'lsq0.blockedByCache', 'lsqBlockebByCache'),
        (core + 'rob.robSquashCycles', 'robSquashCycles'),
        (core + 'rob.independentInstDelta::mean', 'independentInstDelta'),
        (core + 'rob.independentInst', 'independentInst'),
    ]
    return dict(pairs)



def loadAndPrepare(fileNames, cores=None, experiment_marker=None):
    """Load stat CSVs, rename raw stat columns and add derived metrics.

    Parameters:
        fileNames: one CSV source per run set (anything ``pd.read_csv``
            accepts: path, buffer, ...).
        cores: optional per-file core index forwarded to ``column_mapping``.
            A file with no corresponding entry (including when ``cores`` is
            ``None`` or empty) uses ``""``.
        experiment_marker: optional per-file label written to a ``set``
            column. Files without an entry get no ``set`` value from here
            (NaN after concatenation).

    Returns:
        One DataFrame with all files concatenated (fresh RangeIndex) plus the
        derived columns ``mpki``, ``btb_mpki``, ``cond_mpki``, ``bp_mpki``,
        ``icache_mpki``, ``IPC``, ``m_cpi``, ``L1IStallCycleRate``,
        ``serializingStallCycleRate`` and ``ppc``. ``fetchRate`` is
        overwritten with a rescaled value. The long gcc benchmark name is
        shortened to ``502.gcc_r``.
    """
    cores = [] if cores is None else cores
    experiment_marker = [] if experiment_marker is None else experiment_marker

    dfs = []
    for idx, fileName in enumerate(fileNames):
        # NaN cells in the raw CSV are replaced with 0.
        tmp_df = pd.read_csv(fileName).fillna(0)
        if idx < len(experiment_marker):
            tmp_df['set'] = experiment_marker[idx]
        core_id = cores[idx] if idx < len(cores) else ""
        tmp_df.rename(columns=column_mapping(core_id), inplace=True)
        dfs.append(tmp_df)
    df = pd.concat(dfs, ignore_index=True)

    insts = df['insts']
    # Non-idle cycles: the denominator for every per-cycle rate below.
    active_cycles = df['numCycles'] - df['idleCycles']

    new_cols = {
        'mpki': df['total_mispredicts'] * 1000 / insts,
        'btb_mpki': df['btb_mispredicts'] * 1000 / insts,
        'cond_mpki': df['cond_mispredicts'] * 1000 / insts,
        'bp_mpki': df['bp_mispredicts'] * 1000 / insts,
        'icache_mpki': df['l1icacheMisses'] * 1000 / insts,
        'IPC': insts / active_cycles,
        'm_cpi': 1 / (insts / active_cycles),
        'L1IStallCycleRate': df['icache_stall_cycles'] / active_cycles,
        # Rescale by numCycles / active cycles. Presumably the raw stat is
        # normalised by numCycles; TODO confirm.
        'fetchRate': df['fetchRate'] * df['numCycles'] / active_cycles,
        'serializingStallCycleRate': df['serializingStallCycles'] / active_cycles,
        # expand=False returns a Series (single capture group) instead of a
        # one-column DataFrame; non-matching experiments get NaN.
        'ppc': df['experiment'].str.extract(r"(ppc\d*)", expand=False),
    }

    df['benchmark'] = df['benchmark'].replace('502.gcc_r.gcc-pp.opts-O3_-finline-limit_36000', '502.gcc_r')
    df = df.assign(**new_cols)

    return df

def removeWarmup(df, n=2):
    """Drop the first ``n`` rows of every (benchmark, experiment) group.

    Rows are counted in their current order within each group
    (``GroupBy.cumcount``), so the frame should already be in sample order.

    Parameters:
        df: frame with ``benchmark`` and ``experiment`` columns.
        n: number of leading rows per group to discard (default 2).

    Returns:
        A new frame with the remaining rows and a fresh RangeIndex.
    """
    keep = df.groupby(['benchmark', 'experiment']).cumcount() >= n
    return df[keep].reset_index(drop=True)


def filterBenchExperiment(df, experiment_like, exclude_benchs=[], only_benchs=[]):
    """Select rows by benchmark and by an experiment-name regex.

    Parameters:
        df: frame with ``benchmark`` and ``experiment`` columns.
        experiment_like: regex searched (``str.contains``) in ``experiment``;
            rows with a missing experiment are dropped.
        exclude_benchs: benchmarks to drop. Ignored when ``only_benchs`` is
            non-empty.
        only_benchs: if non-empty, keep only these benchmarks.

    Returns:
        The filtered frame (original index preserved).
    """
    bench = df['benchmark']
    if len(only_benchs) > 0:
        bench_ok = bench.isin(only_benchs)
    else:
        bench_ok = ~bench.isin(exclude_benchs)
    narrowed = df[bench_ok]

    exp_ok = narrowed['experiment'].str.contains(experiment_like, regex=True, na=False)
    return narrowed[exp_ok]

def plotYoverTime(df, y, experiment_like='', exclude_benchs=[], only_benchs=[]):
    """Show one figure per benchmark plotting ``y`` against ``finalTick``.

    Every experiment that passes the filters becomes one labelled
    line-plus-markers series. Figures are shown, not saved.

    Parameters:
        df: stats frame with ``benchmark``, ``experiment``, ``finalTick``
            and ``y`` columns.
        y: column to plot on the vertical axis.
        experiment_like, exclude_benchs, only_benchs: forwarded to
            ``filterBenchExperiment``.
    """
    tick_col = 'finalTick'
    selected = filterBenchExperiment(df, experiment_like, exclude_benchs, only_benchs)

    for bench_name, bench_rows in selected.groupby('benchmark'):
        plt.figure(figsize=(8, 5))
        plt.title(f"Benchmark: {bench_name}")

        for exp_name, exp_rows in bench_rows.groupby('experiment'):
            xs = exp_rows[tick_col]
            ys = exp_rows[y]
            plt.plot(xs, ys, label=exp_name)   # connecting line; carries the legend entry
            plt.scatter(xs, ys)                # per-sample markers

        plt.xlabel(tick_col)
        plt.ylabel(y)
        plt.legend(title="Experiment")
        plt.grid(True)
        plt.tight_layout()
        plt.show()



def barPlotY_B(
    dfs,
    y,
    label=None,
    experiment_like='',
    exclude_benchs=[],
    only_benchs=[],
    save_as="",
    ylim=None,
    legend_out=False,
    show_means=False,  # mean bars across benchmarks
    ref_value=None     # percentage ref value
):
    """Grouped bar plot of mean ``y`` per benchmark, split by ``set`` and ``ppc``.

    Each benchmark gets one cluster. Inside it, every ``set`` (told apart by
    hatch) gets a sub-group, and every ``ppc`` value (told apart by colour)
    gets one bar within that sub-group.

    Parameters:
        dfs: a DataFrame or a list of DataFrames. Each needs ``benchmark``,
            ``experiment``, ``set``, ``ppc`` and ``y`` columns. Sets are
            ordered by first appearance across ``dfs``.
        y: column to average per (benchmark, set, ppc).
        label: display name for ``y`` (defaults to ``y``).
        experiment_like, exclude_benchs, only_benchs: forwarded to
            ``filterBenchExperiment``.
        save_as: if non-empty, also save to ``bar_plots/<save_as>.png``.
        ylim: optional upper y-limit (lower limit 0).
        legend_out: place the legend to the right, outside the axes.
        show_means: append a "Mean" cluster averaging over benchmarks.
        ref_value: if given, annotate each positive bar with
            ``height / ref_value`` as a percentage.
    """
    if isinstance(dfs, pd.DataFrame):
        dfs = [dfs]

    df = pd.concat(dfs, ignore_index=True)

    # Set order = first appearance across the input frames (before filtering).
    ordered_sets = []
    for d in dfs:
        for s in d['set'].unique():
            if s not in ordered_sets:
                ordered_sets.append(s)

    df = filterBenchExperiment(df, experiment_like, exclude_benchs, only_benchs)

    # groupby drops rows with NaN in any key (e.g. no ppc value).
    grouped = df.groupby(['benchmark', 'set', 'ppc'])[y].mean().reset_index()

    if show_means:
        mean_rows = (
            grouped.groupby(['set', 'ppc'])[y]
            .mean()
            .reset_index()
            .assign(benchmark="Mean")
        )
        grouped = pd.concat([grouped, mean_rows], ignore_index=True)

    sets = ordered_sets
    if grouped.empty or not sets:
        print(f"No data to plot for {y}.")
        return

    benchmarks = grouped['benchmark'].unique()
    ppcs = grouped['ppc'].unique()
    n_sets, n_ppcs = len(sets), len(ppcs)

    x = np.arange(len(benchmarks))
    width = 0.8 / n_sets       # horizontal space for one set sub-group
    bar_w = width / n_ppcs     # one ppc bar inside a sub-group

    hatch_patterns = ["//", "...", "\\\\", "xx", "++"]  # extend if needed
    colors = sns.color_palette("tab10", n_ppcs)

    # (benchmark, set, ppc) -> mean value; the first occurrence wins.
    value_of = {}
    for b, s, p, v in grouped[['benchmark', 'set', 'ppc', y]].itertuples(index=False, name=None):
        value_of.setdefault((b, s, p), v)

    fig, ax = plt.subplots(figsize=(12, 6))

    for i, s in enumerate(sets):
        for j, p in enumerate(ppcs):
            vals = [value_of.get((b, s, p), 0) for b in benchmarks]

            bars = ax.bar(
                x + i * width + j * bar_w,
                vals,
                bar_w,
                color=colors[j],
                hatch=hatch_patterns[i % len(hatch_patterns)],
                edgecolor="black"
            )

            if ref_value is not None:
                for bar in bars:
                    height = bar.get_height()
                    if height > 0:
                        perc = (height / ref_value) * 100
                        ax.text(
                            bar.get_x() + bar.get_width() / 2,
                            height + (ylim * 0.01 if ylim else height * 0.01),
                            f"{perc:.1f}%",
                            ha="center", va="bottom", fontsize=8
                        )

    # Bars are centre-aligned at x + i*width + j*bar_w, so a cluster spans
    # [x - bar_w/2, x + n_sets*width - bar_w/2]; put the tick at its middle.
    ax.set_xticks(x + (n_sets * width - bar_w) / 2)
    ax.set_xticklabels(benchmarks, rotation=45, ha="right")

    if ylim:
        plt.ylim(0, ylim)

    ax.set_title(f'{label if label else y} per Benchmark')
    ax.set_xlabel('Benchmark')
    ax.set_ylabel(f'{label if label else y}')

    # Colour = ppc, hatch = set.
    color_patches = [mpatches.Patch(color=colors[j], label=f"{p}") for j, p in enumerate(ppcs)]
    hatch_patches = [
        mpatches.Patch(facecolor="white", hatch=hatch_patterns[i % len(hatch_patterns)],
                       edgecolor="black", label=f"{s}")
        for i, s in enumerate(sets)
    ]

    # With a point bbox_to_anchor, loc="best" can pick a corner that puts the
    # legend back inside the axes; "center left" keeps it to the right.
    ax.legend(
        handles=color_patches + hatch_patches,
        title="Legend",
        loc="center left" if legend_out else "best",
        bbox_to_anchor=(1.0, 0.5) if legend_out else None
    )

    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()

    if len(save_as) > 0:
        plt.savefig(f"bar_plots/{save_as}.png", bbox_inches='tight')

    plt.show()




def barPlotY(
    df,
    y,
    label=None,
    experiment_like='',
    exclude_benchs=[],
    only_benchs=[],
    save_as="",
    ylim=None,
    legend_out=False,
    show_means=False,  # add a "Mean" group across benchmarks
    ref_value=None     # reference value for percentage labels
):
    """Bar plot of mean ``y`` per benchmark, one bar per experiment.

    Parameters:
        df: stats frame with ``benchmark``, ``experiment`` and ``y`` columns.
        y: column to average per (benchmark, experiment).
        label: display name for ``y`` (defaults to ``y``).
        experiment_like, exclude_benchs, only_benchs: forwarded to
            ``filterBenchExperiment``.
        save_as: if non-empty, also save to ``bar_plots/<save_as>.png``.
        ylim: optional upper y-limit; the lower limit is always 0.
        legend_out: place the legend to the right, outside the axes.
        show_means: append a "Mean" row that averages each experiment over
            benchmarks (NaNs skipped, the pandas default).
        ref_value: if given, annotate each positive bar with
            ``height / ref_value`` as a percentage.
    """
    df = filterBenchExperiment(df, experiment_like, exclude_benchs, only_benchs)

    # Rows: benchmarks; columns: experiments.
    grouped = df.groupby(['benchmark', 'experiment'])[f'{y}'].mean().unstack()

    if show_means:
        grouped.loc["Mean"] = grouped.mean(axis=0)

    ax = grouped.plot(kind='bar', figsize=(10, 6))

    if ylim:
        plt.ylim(0, ylim)
    plt.ylim(bottom=0)

    display = label if label else y
    plt.title(f'{display} per Benchmark per Experiment')
    plt.xlabel('Benchmark')
    plt.ylabel(f'{display}')
    # With a point bbox_to_anchor, loc="best" can pick a corner that puts the
    # legend back inside the axes; "center left" keeps it to the right.
    plt.legend(
        title='Experiment',
        loc="center left" if legend_out else "best",
        bbox_to_anchor=(1.0, 0.5) if legend_out else None
    )
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    if ref_value is not None:
        for container in ax.containers:  # one container per experiment
            for bar in container:
                height = bar.get_height()
                if height > 0:
                    perc = (height / ref_value) * 100
                    ax.text(
                        bar.get_x() + bar.get_width() / 2,
                        height + (ylim * 0.01 if ylim else height * 0.01),  # small offset above the bar
                        f"{perc:.1f}%",
                        ha="center", va="bottom", fontsize=8
                    )

    # Lay out after all artists (including the percentage labels) exist.
    plt.tight_layout()

    if len(save_as) > 0:
        plt.savefig(f"bar_plots/{save_as}.png", bbox_inches='tight')
    plt.show()



def plotMeanY(df, y, experiment_like='', exclude_benchs=[], only_benchs=[], save_as=""):
    """Bar plot of the mean of ``y`` per benchmark, one bar per experiment.

    The legend is always placed to the right of the axes.

    Parameters:
        df: stats frame with ``benchmark``, ``experiment`` and ``y`` columns.
        y: column to average per (benchmark, experiment).
        experiment_like, exclude_benchs, only_benchs: forwarded to
            ``filterBenchExperiment``.
        save_as: if non-empty, also save to ``bar_plots/<save_as>.png``.
    """
    df = filterBenchExperiment(df, experiment_like, exclude_benchs, only_benchs)

    # Rows: benchmarks; columns: experiments.
    grouped = df.groupby(['benchmark', 'experiment'])[f'{y}'].mean().unstack()

    grouped.plot(kind='bar', figsize=(10, 6))

    plt.title(f'Mean of {y} per Benchmark per Experiment')
    plt.xlabel('Benchmark')
    plt.ylabel(f'Mean {y}')
    # Anchor the legend's centre-left at the axes' right edge. The default
    # loc ("best") with a point anchor can still land the legend inside.
    plt.legend(title='Experiment', loc='center left', bbox_to_anchor=(1.0, 0.5))
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    if len(save_as) > 0:
        plt.savefig(f"bar_plots/{save_as}.png", bbox_inches='tight')
    plt.show()

def plotYDensity(df, y, cores="", experiment_like='', exclude_benchs=[], only_benchs=[], label=None, save_as=""):
    """Plot, per benchmark, a frequency-weighted KDE of a bucketed stat.

    For each (experiment, benchmark), only the row with the largest
    ``finalTick`` is used. Its bucket columns (``...::<start>`` or
    ``...::<start>-<end>``) become (bucket start, frequency) samples. These
    are drawn with ``sns.kdeplot`` weighted by frequency, one hue per
    experiment.

    Parameters:
        df: stats frame that still contains the raw bucket columns.
        y: one of ``robOccupancy``, ``ftNumber``, ``BB_Size``,
            ``issuedInsts``, ``ftqOccupancy``. Any other value prints a
            message and returns.
        cores: core index embedded in the stat path (``""`` for none).
        experiment_like, exclude_benchs, only_benchs: forwarded to
            ``filterBenchExperiment``.
        label: display name for ``y`` in titles and axis labels.
        save_as: if non-empty, save each figure under ``density_plots/``.
            With one benchmark the file is ``<save_as>.png``. With several it
            is ``<save_as>_<benchmark>.png``, so figures do not overwrite
            each other.
    """
    variable_supported = {
        'robOccupancy': fr'board\.processor\.cores{cores}\.core\.robOccupancy::\d+(-\d+)?$',
        'ftNumber': fr'board\.processor\.cores{cores}\.core\.bac\.ftNumber::\d+(-\d+)?$',
        'BB_Size': fr'board\.processor\.cores{cores}\.core\.bac\.ftSizeDist::\d+(-\d+)?$',
        'issuedInsts': fr'board\.processor\.cores{cores}\.core\.numIssuedDist::\d+(-\d+)?$',
        'ftqOccupancy': fr'board\.processor\.cores{cores}\.core\.ftq\.occupancy::\d+(-\d+)?$'
    }

    if y not in variable_supported:
        print("The variable entered is not yet supported!")
        return

    df = filterBenchExperiment(df, experiment_like, exclude_benchs, only_benchs)

    # Bucket columns only: the pattern requires a numeric (range) suffix, so
    # summary entries such as ::mean or ::total do not match.
    bucket_re = re.compile(variable_supported[y])
    bucket_cols = [col for col in df.columns if bucket_re.match(col)]

    if not bucket_cols:
        print(f"No {y} bucket columns found.")
        return

    df_plot = df[['benchmark', 'experiment', 'finalTick'] + bucket_cols].copy()
    # Keep only the last sample (largest finalTick) per experiment/benchmark.
    df_plot = df_plot.loc[df_plot.groupby(['experiment', 'benchmark'])['finalTick'].idxmax()]

    # Long format: one row per (benchmark, experiment, bucket).
    df_melted = df_plot.melt(
        id_vars=['benchmark', 'experiment'],
        value_vars=bucket_cols,
        var_name=f'{y} Bucket',
        value_name='Frequency'
    )

    # "...::12-15" -> 12 and "...::7" -> 7
    df_melted['Bucket Start'] = df_melted[f'{y} Bucket'].apply(lambda x: int(x.split('::')[-1].split('-')[0]))

    display = label if label else y
    x_max = df_melted['Bucket Start'].max()
    several_benchmarks = df_melted['benchmark'].nunique() > 1

    for benchmark, group_df in df_melted.groupby('benchmark'):
        plt.figure(figsize=(10, 6))
        sns.kdeplot(
            data=group_df,
            x='Bucket Start',
            weights='Frequency',
            hue='experiment',
            common_norm=False,
            fill=True,
            alpha=0.4,
            linewidth=1.5,
            cut=0,
            clip=(0, x_max)
        )

        plt.title(f" {display} Density by Experiment benchmark: {benchmark}")
        plt.xlabel(f'{display} Bucket Start')
        plt.ylabel('Density (weighted)')
        plt.tight_layout()
        if len(save_as) > 0:
            # One file per benchmark; without the suffix every iteration
            # would overwrite the same image.
            suffix = f"_{benchmark}" if several_benchmarks else ""
            plt.savefig(f"density_plots/{save_as}{suffix}.png", bbox_inches='tight')
        plt.show()


def plot_incremental_from_dfs(
    dfs, labels,
    benchmarks_col="benchmark", experiment_col="experiment", value_col="IPC",
    ylabel="Performance (IPC)", title=None, save_as="", ymax=None,
    show_means=False
):
    """Plot grouped, incrementally stacked bars across performance scenarios.

    ``dfs[k]`` holds scenario ``labels[k]``, and scenarios are stacked in
    that order. Each stacked segment shows the increase of ``value_col``
    over the previous scenario. Segments where the value did not increase
    are not drawn, and the next segment still starts from the previous
    scenario's value. Experiments (taken from ``dfs[0]``) are told apart by
    hatch, scenarios by colour.

    Parameters:
        dfs (list[pd.DataFrame]): scenario frames in stacking order. Each
            has ``benchmarks_col``, ``experiment_col`` and ``value_col``.
        labels (list[str]): one scenario name per frame.
        benchmarks_col, experiment_col, value_col: column names to use.
        ylabel, title: axis label and optional title.
        save_as: if non-empty, save to ``IPC_plots/<save_as>.png``.
        ymax: optional upper y-limit.
        show_means: append a "Mean" benchmark averaging each experiment.

    Raises:
        ValueError: if ``dfs`` and ``labels`` differ in length.
    """
    labels = list(labels)
    if len(dfs) != len(labels):
        raise ValueError(f"got {len(dfs)} DataFrames but {len(labels)} labels")

    experiments = dfs[0][experiment_col].unique()

    # Mean value per (benchmark, experiment) for each scenario.
    grouped_dfs = [
        df.groupby([benchmarks_col, experiment_col])[value_col].mean().reset_index()
        for df in dfs
    ]

    # One wide table with a column per scenario. The merge is inner, so only
    # (benchmark, experiment) pairs present in every scenario survive.
    merged = grouped_dfs[0].rename(columns={value_col: labels[0]})
    for lab, g in zip(labels[1:], grouped_dfs[1:]):
        merged = merged.merge(g.rename(columns={value_col: lab}),
                              on=[benchmarks_col, experiment_col])

    if show_means:
        mean_rows = []
        for exp in experiments:
            sub = merged[merged[experiment_col] == exp]
            row = {benchmarks_col: "Mean", experiment_col: exp}
            row.update(sub[labels].mean().to_dict())
            mean_rows.append(row)
        merged = pd.concat([merged, pd.DataFrame(mean_rows)], ignore_index=True)

    benchmarks = merged[benchmarks_col].unique()
    x = np.arange(len(benchmarks))
    width = 0.75 / len(experiments)
    fig, ax = plt.subplots(figsize=(16, 7))

    colors = ["#183ab8", "#bf8970", "#FF2400", "#ffd700", "#50C878", "#C0C0C0", "#7851a9", "#f77b07"]
    hatches = ["//", "\\\\", "xx", "oo", "..", "++"]

    for i, exp in enumerate(experiments):
        # Align this experiment's rows with the shared benchmark axis. A
        # benchmark missing for this experiment becomes NaN and is not drawn.
        sub = (merged[merged[experiment_col] == exp]
               .set_index(benchmarks_col)
               .reindex(benchmarks))
        xpos = x - 0.375 + i * width + width / 2

        prev_vals = np.zeros(len(benchmarks))
        for j, lab in enumerate(labels):
            vals = sub[lab].to_numpy(dtype=float)
            inc = vals - prev_vals  # true difference to the previous scenario

            pos_mask = inc > 0      # NaN compares False, so gaps are skipped
            if np.any(pos_mask):
                ax.bar(
                    xpos[pos_mask], inc[pos_mask], width,
                    bottom=prev_vals[pos_mask],
                    color=colors[j % len(colors)],
                    hatch=hatches[i % len(hatches)],
                    edgecolor="black"
                )

            prev_vals = vals  # always advance so the next baseline is correct

    ax.set_ylabel(ylabel, fontsize=12)
    if title:
        ax.set_title(title, fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(benchmarks, rotation=45, ha="right")
    ax.yaxis.grid(True, which="both", linestyle="--", alpha=0.6)
    ax.margins(x=0.12)
    if ymax:
        ax.set_ylim(0, ymax)

    # Same modulo indexing as the bars, so long lists don't raise IndexError.
    scenario_patches = [mpatches.Patch(color=colors[j % len(colors)], label=lab)
                        for j, lab in enumerate(labels)]
    hatch_patches = [mpatches.Patch(facecolor="white", edgecolor="black",
                                    hatch=hatches[i % len(hatches)], label=exp)
                     for i, exp in enumerate(experiments)]
    ax.legend(handles=scenario_patches + hatch_patches, title="Legend", ncol=2)

    # Lay out before saving so the saved image matches the displayed one.
    fig.tight_layout()
    if len(save_as) > 0:
        plt.savefig(f"IPC_plots/{save_as}.png", bbox_inches='tight')
    plt.show()

def plotMPKIStack(
    df,
    mean=False,
    experiment_like='',
    exclude_benchs=[],
    only_benchs=[],
    rm_components=[],
    save_as="",
    legend_out=False,
    show_means=False,
    reverse_experiments=False
):
    """Plot stacked MPKI contributions per benchmark, one bar per experiment.

    The stacked components are ``icache_mpki`` (L1ICache), ``btb_mpki`` (BTB)
    and ``bp_mpki`` (Branch Predictor).

    Parameters:
        df: stats frame with the MPKI columns, plus ``finalTick`` when
            ``mean`` is False.
        mean: average each (experiment, benchmark) over all samples instead
            of using the sample with the largest ``finalTick``.
        experiment_like, exclude_benchs, only_benchs: forwarded to
            ``filterBenchExperiment``.
        rm_components: component column names (e.g. ``"btb_mpki"``) to omit.
        save_as: if non-empty, save to ``mpki_plots/<save_as>.png``.
        legend_out: place the legend to the right, outside the axes.
        show_means: append a "Mean" benchmark for each experiment.
        reverse_experiments: reverse the (sorted) experiment order.
    """
    latest = filterBenchExperiment(df, experiment_like, exclude_benchs, only_benchs)

    if mean:
        latest = latest.groupby(['experiment', 'benchmark'])[
            ['btb_mpki', 'cond_mpki', 'mpki', 'bp_mpki', 'icache_mpki']
        ].mean().reset_index(drop=False)
    else:
        # Last sample per experiment/benchmark. The copy makes the column
        # update below act on an independent frame.
        latest = latest.loc[latest.groupby(['experiment', 'benchmark'])['finalTick'].idxmax()].copy()

    component_cols = {
        "icache_mpki": "L1ICache",
        "btb_mpki": "BTB",
        "bp_mpki": "Branch Predictor"
    }
    for comp in rm_components:
        component_cols.pop(comp, None)

    if show_means:
        mean_rows = []
        for exp in latest['experiment'].unique():
            sub = latest[latest['experiment'] == exp]
            row = {"benchmark": "Mean", "experiment": exp}
            row.update(sub[list(component_cols.keys())].mean().to_dict())
            mean_rows.append(row)
        latest = pd.concat([latest, pd.DataFrame(mean_rows)], ignore_index=True)

    latest['benchmark'] = latest['benchmark'].astype(str)
    benchmarks = list(latest['benchmark'].unique())
    if "Mean" in benchmarks:
        # Always draw the aggregate cluster last.
        benchmarks = [b for b in benchmarks if b != "Mean"] + ["Mean"]

    experiments = sorted(latest['experiment'].unique())
    if reverse_experiments:
        experiments = experiments[::-1]

    colors = {
        "L1ICache": "#14a31b",
        "BTB": "#3734eb",
        "Branch Predictor": "#eb5334"
    }
    hatches = ['', '*', 'o', '-', 'o+*', '-+.', '.', '+']

    fig, ax = plt.subplots(figsize=(16, 6))
    bar_width = 0.6 / len(experiments)
    x = np.arange(len(benchmarks))

    for i, exp in enumerate(experiments):
        subset = latest[latest['experiment'] == exp].set_index('benchmark')
        bottom = np.zeros(len(benchmarks))
        for col, name in component_cols.items():
            values = np.array(
                [subset.loc[b, col] if b in subset.index else 0 for b in benchmarks],
                dtype=float,
            )
            ax.bar(
                x + i * bar_width, values, bar_width,
                bottom=bottom,
                color=colors[name],
                hatch=hatches[i % len(hatches)],
                edgecolor='black'
            )
            bottom += values

    ax.set_xticks(x + bar_width * (len(experiments) - 1) / 2)
    ax.set_xticklabels(benchmarks, rotation=45, ha='right')
    ax.set_ylabel("MPKI")
    ax.set_title("MPKI Breakdown per Benchmark (Per Experiment)")

    # Colour = component, hatch = experiment.
    component_handles = [plt.Rectangle((0, 0), 1, 1, color=colors[name]) for name in component_cols.values()]
    experiment_handles = [
        plt.Rectangle((0, 0), 1, 1, facecolor='white', edgecolor='black', hatch=hatches[i % len(hatches)])
        for i, _ in enumerate(experiments)
    ]

    # With a point bbox_to_anchor, loc="best" can pick a corner that puts the
    # legend back inside the axes; "center left" keeps it to the right.
    ax.legend(
        component_handles + experiment_handles,
        list(component_cols.values()) + list(experiments),
        title="Component / Experiment",
        loc='center left' if legend_out else 'best',
        bbox_to_anchor=(1.0, 0.5) if legend_out else None
    )

    plt.tight_layout()
    if len(save_as) > 0:
        plt.savefig(f"mpki_plots/{save_as}.png", bbox_inches='tight')
    plt.show()


 

def plotCPIStack(df, mean=False, experiment_like='', exclude_benchs=[], only_benchs=[], m_cpi=False, detailed=True, save_as="", ymax=None):
    """Plot stacked top-down CPI breakdowns per benchmark, one bar per experiment.

    Each top-down component column is multiplied by the CPI column (``cpi``,
    or ``m_cpi`` when ``m_cpi=True``), and the products are stacked. A stack
    is clipped at the total CPI. Negative or NaN components are drawn as zero
    so they cannot pull the following components below earlier ones.

    Parameters:
        df: stats frame with the TopDown columns, the CPI column and, when
            ``mean`` is False, ``finalTick``.
        mean: average each (experiment, benchmark) over all samples instead
            of using the sample with the largest ``finalTick``.
        experiment_like, exclude_benchs, only_benchs: forwarded to
            ``filterBenchExperiment``.
        m_cpi: use ``m_cpi`` instead of ``cpi`` as the total.
        detailed: stack the level-2 breakdown (True) or level-1 only (False).
        save_as: if non-empty, save to ``CPI_Stacks/<save_as>.png``.
        ymax: optional upper y-limit.
    """
    latest = filterBenchExperiment(df, experiment_like, exclude_benchs, only_benchs)

    cpi_var = 'm_cpi' if m_cpi else 'cpi'

    if mean:
        stat_cols = ['TopDownl1_retiring', 'TopDownL1_frontendBound', 'TopDownL1_badSpeculation',
                     'TopDownl1_backendBound', "TopDownL2_Mispredicts", "TopDownL2_machineClears",
                     "TopDownL2_serializingStalls", "TopDownL2_coreBound", "TopDownL2_memoryBound",
                     cpi_var]
        latest = latest.groupby(['experiment', 'benchmark'])[stat_cols].mean().reset_index(drop=False)
    else:
        # Last sample per experiment/benchmark. The copy makes the in-place
        # scaling below act on an independent frame.
        latest = latest.loc[latest.groupby(['experiment', 'benchmark'])['finalTick'].idxmax()].copy()

    # Component columns and their display names.
    component_cols_detailed = {
        "TopDownl1_retiring": "Retiring",
        "TopDownL1_frontendBound": "Frontend Bound",
        "TopDownL2_Mispredicts": "BadSpec Mispredicts",
        "TopDownL2_machineClears": "BadSpec memOrderViolations",
        "TopDownL2_serializingStalls": "Backend SerializingStalls",
        "TopDownL2_coreBound": "Backend coreBound",
        "TopDownL2_memoryBound": "Backend memoryBound"
    }

    component_cols_L1 = {
        "TopDownl1_retiring": "Retiring",
        "TopDownL1_frontendBound": "Frontend Bound",
        "TopDownL1_badSpeculation": "Bad Speculation",
        "TopDownl1_backendBound": "Backend Bound"
    }

    component_cols = component_cols_detailed if detailed else component_cols_L1

    # Scale each component share by the total CPI. A missing column is
    # skipped here and drawn as 0 below.
    for col in component_cols:
        if col in latest.columns:
            latest[col] = latest[col] * latest[cpi_var]

    latest['benchmark'] = latest['benchmark'].astype(str)
    benchmarks = sorted(latest['benchmark'].unique())
    experiments = sorted(latest['experiment'].unique())

    colors = {
        "Retiring": "#7B3294",
        "Frontend Bound": "#F0E442",
        "Bad Speculation": "#999999",
        "BadSpec Mispredicts": "#999999",
        "BadSpec memOrderViolations": "#0dd617",
        "Backend Bound": "#D7191C",
        "Backend SerializingStalls": "#f7760c",
        "Backend coreBound": "#1939d7",
        "Backend memoryBound": "#D7191C"
    }
    hatches = ['', '*', 'oo', 'X', 'x', '.', '.', '+']  # indexed by experiment

    fig, ax = plt.subplots(figsize=(16, 6))
    bar_width = 0.6 / len(experiments)
    x = np.arange(len(benchmarks))

    for i, exp in enumerate(experiments):
        subset = latest[latest['experiment'] == exp].set_index('benchmark')

        for j, b in enumerate(benchmarks):
            if b not in subset.index:
                continue

            total_cpi = subset.loc[b, cpi_var]
            drawn = 0.0  # height stacked so far; also the next segment's bottom

            for col, name in component_cols.items():
                if drawn >= total_cpi:
                    break  # stack already full; nothing left to allocate

                value = subset.loc[b, col] if col in subset.columns else 0.0
                # A negative or NaN component would move the baseline down
                # (or poison it), making later segments overlap earlier ones.
                if pd.isna(value) or value < 0:
                    value = 0.0
                # Clip so the stack never overshoots the total CPI.
                value = min(value, total_cpi - drawn)

                if value > 0:
                    ax.bar(
                        x[j] + i * bar_width,
                        value,
                        bar_width,
                        bottom=drawn,
                        color=colors[name],
                        hatch=hatches[i % len(hatches)],
                        edgecolor='black'
                    )
                drawn += value

    ax.set_xticks(x + bar_width * (len(experiments) - 1) / 2)
    ax.set_xticklabels(benchmarks, rotation=45, ha='right')
    ax.set_ylabel("CPI")
    ax.set_title("CPI Breakdown per Benchmark (Per Experiment)")
    if ymax:
        ax.set_ylim(0, ymax)

    # Colour = component, hatch (on white) = experiment.
    component_handles = [
        plt.Rectangle((0, 0), 1, 1, color=colors[name])
        for name in component_cols.values()
    ]
    experiment_handles = [
        plt.Rectangle((0, 0), 1, 1, facecolor='white', edgecolor='black', hatch=hatches[i % len(hatches)])
        for i, _ in enumerate(experiments)
    ]

    ax.legend(
        component_handles + experiment_handles,
        list(component_cols.values()) + experiments,
        title="Component / Experiment",
        loc='center left',
        bbox_to_anchor=(1.0, 0.5)  # x=1.0: just outside the axes on the right
    )

    plt.tight_layout()
    if len(save_as) > 0:
        plt.savefig(f"CPI_Stacks/{save_as}.png", bbox_inches='tight')
    plt.show()