Source code for isaricanalytics.visualisation

from __future__ import annotations

# Public API of ``isaricanalytics.visualisation``: the names exported by
# ``from isaricanalytics.visualisation import *``.
__all__ = [
    "fig_bar_chart",
    "fig_bar_line_chart",
    "fig_count_chart",
    "fig_dual_stack_pyramid",
    "fig_flowchart",
    "fig_forest_plot",
    "fig_frequency_chart",
    "fig_heatmaps",
    "fig_kaplan_meier",
    "fig_line_chart",
    "fig_pie",
    "fig_placeholder",
    "fig_sankey",
    "fig_sunburst",
    "fig_table",
    "fig_text",
    "fig_timelines",
    "fig_upset",
    "hex_to_rgb",
    "hex_to_rgba",
    "rgb_to_rgba",
]

# -- IMPORTS --

# -- Standard libraries --
import typing

# -- 3rd party libraries --
import numpy as np
import pandas
import plotly
import plotly.express as px
import plotly.graph_objs as go
from plotly.subplots import make_subplots

# -- Internal libraries --


# ``pd`` is a short alias used for Pandas references in code, so they work
# independently of the full ``pandas`` name used in type hints and docstrings
# (which Sphinx Intersphinx resolves to the Pandas API documentation).
pd = pandas

############################################
############################################
# Figures
############################################
############################################


def fig_placeholder(
    data: pandas.DataFrame | None = None,
    title: str = "Placeholder scatter plot",
    xlabel: str = "",
    ylabel: str = "",
    height: int = 450,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a placeholder scatter plot.

    Parameters
    ----------
    data : pandas.DataFrame, None, default=None
        Incoming data with ``"x"`` and ``"y"`` columns. If ``None``, five
        points at x = 1..5 with random y-values drawn uniformly from
        ``[10, 15)`` are plotted instead.
    title : str, default="Placeholder scatter plot"
        Figure title.
    xlabel : str, default=""
        Figure x-axis label.
    ylabel : str, default=""
        Figure y-axis label.
    height : int, default=450
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    if data is None:
        x = [1, 2, 3, 4, 5]
        y = np.random.uniform(low=10, high=15, size=5)
    else:
        x = data["x"]
        y = data["y"]

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(x=x, y=y, mode="markers", marker={"size": 10, "color": "blue"})
    )
    fig.update_layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        xaxis_title=xlabel,
        yaxis_title=ylabel,
        height=height,
        minreducedwidth=500,
    )
    # The fixed y-range only matches the random placeholder values; real data
    # keeps Plotly's autorange so points outside [10, 15] are not clipped.
    if data is None:
        fig.update_layout(yaxis_range=[10, 15])
    return fig
def fig_pie(
    data: pandas.DataFrame,
    title: str = "Pie chart",
    xlabel: str = "",
    ylabel: str = "",
    base_color_map: dict[str, str] | None = None,
    names: str | int | pd.Series | typing.Iterable = "",
    values: str | int | pd.Series | typing.Iterable = "",
    height: int = 450,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a pie chart figure.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Pie chart"
        Figure title.
    xlabel : str, default=""
        Figure x-axis label (set on the layout; pie traces do not use
        cartesian axes).
    ylabel : str, default=""
        Figure y-axis label (set on the layout; pie traces do not use
        cartesian axes).
    base_color_map : dict, default=None
        Map of sector names and colours. Sectors are coloured by ``names``.
    names : str, int, pd.Series, typing.Iterable, default=""
        Sector name(s)/label(s), e.g. a column of ``data``.
    values : str, int, pd.Series, typing.Iterable, default=""
        Sector values, e.g. a column of ``data``.
    height : int, default=450
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    df = data.copy()
    fig = px.pie(
        df,
        values=values,
        names=names,
        title=title,
        # Colour by sector name so ``base_color_map`` keys are sector names.
        color=names,
        color_discrete_map=base_color_map,
    )
    fig.update_layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        xaxis_title=xlabel,
        yaxis_title=ylabel,
        height=height,
        minreducedwidth=500,
    )
    return fig
def fig_timelines(
    data: pandas.DataFrame,
    title: str = "Timeline",
    label_col: str = "",
    group_col: str = "",
    start_date: str = "start_date",
    end_date: str = "end_date",
    size_col: str | None = None,
    min_width: int = 2,
    max_width: int = 10,
    height: int = 500,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a timeline figure.

    Each row of ``data`` is drawn as a horizontal line from its start date to
    its end date at the y-position given by ``label_col``. Rows with a missing
    end date are treated as ongoing: they are drawn up to the latest end date
    in the data and finish with an arrow marker.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Timeline"
        Figure title.
    label_col : str, default=""
        Label column (y-axis values).
    group_col : str, default=""
        Group column; determines colour and legend entry.
    start_date : str, default="start_date"
        Start date column.
    end_date : str, default="end_date"
        End date column; missing values mark ongoing intervals.
    size_col : str, None, default=None
        Optional numeric column scaled linearly to line widths in
        ``[min_width, max_width]`` (missing values count as 0). Without it,
        every line has width 3.
    min_width : int, default=2
        Minimum line width when ``size_col`` is used.
    max_width : int, default=10
        Maximum line width when ``size_col`` is used.
    height : int, default=500
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    df = data.copy()
    df[start_date] = pd.to_datetime(df[start_date])
    df[end_date] = pd.to_datetime(df[end_date])
    # Ongoing (open-ended) intervals are drawn up to the latest end date.
    max_end = df[end_date].max()

    # Assign colours by group, cycling through Plotly's qualitative palette.
    palette = px.colors.qualitative.Plotly
    color_map = {
        group: palette[i % len(palette)]
        for i, group in enumerate(df[group_col].unique())
    }

    # One line width per row, in row order. Keying widths by label would give
    # every row that shares a label the width of the last such row.
    if size_col and df[size_col].notnull().any():
        sizes = df[size_col].fillna(0).astype(float).to_numpy()
        min_val, max_val = sizes.min(), sizes.max()
        if min_val == max_val:
            widths = [(min_width + max_width) / 2] * len(df)
        else:
            scaled = (sizes - min_val) / (max_val - min_val)
            widths = (min_width + scaled * (max_width - min_width)).tolist()
    else:
        widths = [3] * len(df)

    fig = go.Figure()
    # Groups that already have a legend entry. Tracking the raw group values
    # avoids comparing them against trace names, which Plotly stores as
    # strings (so non-string groups would otherwise repeat in the legend).
    legend_groups = set()
    for (_, row), width in zip(df.iterrows(), widths):
        y = row[label_col]
        group = row[group_col]
        ongoing = pd.isnull(row[end_date])
        x_end = max_end if ongoing else row[end_date]
        color = color_map[group]
        symbol = ["circle", "arrow-right"] if ongoing else ["circle", "circle"]
        size = [14, 20] if ongoing else [14, 14]
        fig.add_trace(
            go.Scatter(
                x=[row[start_date], x_end],
                y=[y, y],
                mode="lines+markers",
                line={"color": color, "width": width},
                marker={
                    "size": size,
                    "symbol": symbol,
                    "color": color,
                    "line": {"width": 1, "color": color},
                    "opacity": 1,
                },
                name=group,
                legendgroup=group,
                showlegend=group not in legend_groups,
            )
        )
        legend_groups.add(group)

    fig.update_layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        xaxis_title="Date",
        yaxis={"title": label_col, "tickfont": {"size": 10}},
        margin={"l": 250, "r": 20, "t": 40, "b": 40},
        height=height,
        minreducedwidth=500,
    )
    return fig
def fig_sunburst(
    data: pandas.DataFrame,
    title: str = "Sunburst Chart",
    path: list[str | int] | pd.Series | typing.Iterable | None = None,
    values: str | int | pd.Series | typing.Iterable = "values",
    base_color_map: dict[str, str] | None = None,
    height: int = 430,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a sunburst plot.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Sunburst Chart"
        Figure title.
    path : list, pd.Series, typing.Iterable, None, default=None
        Column names defining a hierarchy of sectors, from root to leaves.
        ``None`` uses ``["level0", "level1"]``.
    values : str, int, pd.Series, typing.Iterable, default="values"
        A column name in the data defining sector values, or a Pandas Series
        or an iterable containing sector values.
    base_color_map : dict, default=None
        Map of sector values/marks and colours. Currently not applied to the
        figure; accepted for interface consistency.
    height : int, default=430
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if path is None:
        path = ["level0", "level1"]

    df = data.copy()
    fig = px.sunburst(
        df,
        path=path,
        values=values,
    )
    fig.update_traces(
        # Keep sectors in data order rather than sorting by value.
        sort=False,
        selector={"type": "sunburst"},
        insidetextorientation="radial",
        # NOTE(review): customdata holds the path column names, not
        # per-sector data; the hovertemplate below does not reference it.
        customdata=path,
        hovertemplate="%{label}<br>N=%{value}<extra></extra>",
    )
    fig.update_layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        height=height,
        minreducedwidth=500,
    )
    return fig
def fig_bar_chart(
    data: pandas.DataFrame,
    title: str = "Bar Chart",
    xlabel: str = "",
    ylabel: str = "",
    index_column: str = "index",
    barmode: str = "stack",
    xaxis_tickformat: str = "%m-%Y",
    base_color_map: dict[str, str] | None = None,
    height: int = 340,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a bar chart.

    ``index_column`` supplies the x positions; every remaining column becomes
    one vertical bar trace named after that column.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Bar Chart"
        Figure title.
    xlabel : str, default=""
        Figure x-axis label.
    ylabel : str, default=""
        Figure y-axis label.
    index_column : str, default="index"
        Column used as the x-axis values.
    barmode : str, default="stack"
        How bars with the same location coordinate are displayed: possible
        values are ``"stack"``, ``"relative"``, ``"group"``, ``"overlay"``.
        For reference see the `Plotly documentation
        <https://plotly.com/python-api-reference/generated/plotly.graph_objects.Layout.html>`_.
        ``"group"`` uses a bar gap of 0.1; all other modes use no gap.
    xaxis_tickformat : str, default="%m-%Y"
        x-axis tick format.
    base_color_map : dict, default=None
        Map of column names to bar colours. If ``None``, colours cycle through
        Plotly's qualitative palette. Columns missing from a supplied map are
        drawn in ``"#000"``.
    height : int, default=340
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    frame = data.copy().set_index(index_column)
    series_names = frame.columns

    if base_color_map is None:
        palette = px.colors.qualitative.Plotly
        base_color_map = {
            name: palette[pos % len(palette)]
            for pos, name in enumerate(series_names)
        }

    traces = [
        go.Bar(
            x=frame.index,
            y=frame[name],
            name=name,
            orientation="v",
            marker={"color": base_color_map.get(name, "#000")},
        )
        for name in series_names
    ]

    bargap = 0.1 if barmode == "group" else 0

    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        barmode=barmode,
        bargap=bargap,
        xaxis={
            "title": xlabel,
            "tickformat": xaxis_tickformat,
            # One tick per index value.
            "tickvals": frame.index,
        },
        yaxis={"title": ylabel},
        legend={"x": 1.05, "y": 1},
        margin={"l": 100, "r": 100, "t": 100, "b": 50},
        paper_bgcolor="white",
        plot_bgcolor="white",
        height=height,
        minreducedwidth=500,
    )
    return go.Figure(data=traces, layout=layout)
def fig_upset(
    data: tuple[pd.DataFrame, pd.DataFrame],
    title: str = "Upset Plot",
    height: int = 480,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns an upset plot.

    Layout is a 2x2 grid: intersection sizes (top-left), the set-membership
    matrix (bottom-left) and set sizes (bottom-right).

    Parameters
    ----------
    data : tuple
        Incoming data as two Pandas dataframes, the first for counts, and the
        second for intersections. The counts frame needs ``"label"``,
        ``"short_label"`` and ``"count"`` columns (one row per set, placed
        top to bottom in row order). The intersections frame needs
        ``"label"`` (an iterable of set labels) and ``"count"`` columns; its
        index is used as the x position and for colour cycling, so it is
        expected to be integer-valued.
    title : str, default="Upset Plot"
        Figure title.
    height : int, default=480
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    counts = data[0].copy()
    intersections = data[1].copy()
    hlabel = "label"
    slabel = "short_label"
    hoverlabels = counts[hlabel].tolist()
    labels = counts[slabel].tolist()
    column_widths = [intersections.shape[0], counts.shape[0]]

    fig = make_subplots(
        rows=2,
        cols=2,
        shared_xaxes=True,
        shared_yaxes=True,
        column_widths=column_widths,
        vertical_spacing=0.02,  # Space between the plots
        horizontal_spacing=0.01,
    )

    # Intersection-size bars (top-left), one per intersection.
    for ii in intersections.index:
        color = rgb_to_rgba(px.colors.sequential.Purples_r[ii % 5], 1)
        hoverlabel = "<br>".join(intersections.loc[ii, hlabel])
        n = intersections.loc[ii, "count"]
        customdata = f"Intersection of<br>{hoverlabel}<br><br>Count = {n}"
        fig.add_trace(
            go.Bar(
                y=[n],
                x=[ii],
                orientation="v",
                name="",
                customdata=[customdata],
                hovertemplate="%{customdata}",
                width=0.9,
                offset=-0.45,
                marker={"color": color},
                showlegend=False,
            ),
            row=1,
            col=1,
        )

    # Set-size bars (bottom-right). Rows are placed by position so they line
    # up with the matrix rows and y tick labels below, which are positional
    # too; using index labels would misalign them for a non-default index.
    for pos, (hoverlabel, n) in enumerate(zip(hoverlabels, counts["count"])):
        color = rgb_to_rgba(px.colors.sequential.Oranges_r[pos % 5], 1)
        fig.add_trace(
            go.Bar(
                y=[-1 - pos],
                x=[n],
                orientation="h",
                name="",
                customdata=[f"{hoverlabel}<br><br>Count = {n}"],
                hovertemplate="%{customdata}",
                width=0.9,
                offset=-0.45,
                marker={"color": color},
                showlegend=False,
            ),
            row=2,
            col=2,
        )

    # Membership matrix (bottom-left): one dot per set in each intersection,
    # joined by a vertical line when the intersection has several sets.
    for ii in intersections.index:
        intersection = intersections.loc[ii, hlabel]
        y_coords = [
            -1 - pos
            for pos, name in enumerate(hoverlabels)
            if name in intersection
        ]
        x_coords = [ii] * len(y_coords)
        if len(y_coords) > 1:
            fig.add_trace(
                go.Scatter(
                    x=x_coords,
                    y=y_coords,
                    mode="lines",
                    line={"color": "black", "width": 2},
                    showlegend=False,
                    hovertemplate="%{x}",
                    name="",
                ),
                row=2,
                col=1,
            )
        fig.add_trace(
            go.Scatter(
                x=x_coords,
                y=y_coords,
                mode="markers",
                marker={"size": 10, "color": "black"},
                showlegend=False,
                customdata=["<br>".join(intersection)] * len(y_coords),
                hovertemplate="%{customdata}",
                name="",
            ),
            row=2,
            col=1,
        )

    fig.update_yaxes(title_text="Intersection Size", row=1, col=1)
    fig.update_xaxes(title_text="Set Size", side="top", row=2, col=2)
    # Matrix rows: short labels as tick text, aliased to the full labels.
    fig.update_yaxes(
        tickvals=[-1 - i for i in range(len(labels))],
        ticktext=labels,
        showgrid=False,
        row=2,
        col=1,
        labelalias=dict(zip(labels, hoverlabels)),
    )
    # Hide x-axis line for the intersection size subplot
    fig.update_xaxes(showline=False, tickformat=",d", row=1, col=1)
    # Hide x-axis ticks and labels for the matrix subplot
    fig.update_xaxes(
        ticks="", showticklabels=False, showgrid=False, zeroline=False, row=2, col=1
    )
    # Hide y-axis ticks and labels for the set size subplot
    fig.update_yaxes(
        ticks="", showticklabels=False, showgrid=False, zeroline=False, row=2, col=2
    )
    fig.update_layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "right",
            "x": 1,
        },
        height=height,
        minreducedwidth=500,
    )
    return fig
def fig_count_chart(
    data: pandas.DataFrame,
    title: str = "Count Chart",
    xlabel: str = "Count",
    ylabel: str = "Variable",
    base_color_map: dict[str, str] | None = None,
    height: int = 350,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a count chart.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data with ``"label"`` (hover text), ``"count"`` (bar length)
        and ``"short_label"`` (y-axis tick label) columns. Any index is
        accepted; rows are used in positional order, first row at the top.
    title : str, default="Count Chart"
        Figure title.
    xlabel : str, default="Count"
        Figure x-axis label.
    ylabel : str, default="Variable"
        Figure y-axis label.
    base_color_map : dict, default=None
        Map of bar values and colours; only the ``"Yes"`` entry is used.
    height : int, default=350
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.

    Raises
    ------
    ValueError
        If a required column is missing.
    """
    df = data.copy()
    column_names = ["label", "count", "short_label"]

    # Error Handling
    if not all(col in df.columns for col in column_names):
        error_str = "Dataframe must contain the following columns: "
        error_str += f"{column_names}"
        raise ValueError(error_str)

    # Rows are looked up by position below; resetting the index makes that
    # work for any index (e.g. a filtered frame), not only a default one.
    df = df.reset_index(drop=True)

    default_color = "#007E71"
    yes_color = (
        base_color_map.get("Yes", default_color) if base_color_map else default_color
    )

    traces = []
    # Traces are added last row first, so the first row's category ends up at
    # the top of the bottom-up categorical y-axis.
    for ii in reversed(range(df.shape[0])):
        hoverlabel = df.loc[ii, column_names[0]]
        yes_count = df.loc[ii, column_names[1]]
        label = df.loc[ii, column_names[2]]
        traces.append(
            go.Bar(
                x=[yes_count],
                y=[label],
                name="Yes",
                orientation="h",
                width=0.9,
                offset=-0.45,
                marker={"color": yes_color},
                customdata=[hoverlabel],
                hovertemplate="%{customdata}: %{x:.2f}",
                # A single legend entry, from the first row's trace.
                showlegend=(ii == 0),
            )
        )

    xlim = [0, df[column_names[1]].max()]
    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        barmode="stack",
        xaxis={"title": xlabel, "range": xlim},
        yaxis={
            "title": ylabel,
            "automargin": True,
            "tickmode": "array",
            "tickvals": df[column_names[2]],
            "ticktext": df[column_names[2]],
        },
        bargap=0.1,  # Smaller gap between bars.
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "right",
            "x": 1,
        },
        margin={"l": 100, "r": 100, "t": 100, "b": 50},
        height=height,
        minreducedwidth=500,
    )
    return go.Figure(data=traces, layout=layout)
def fig_frequency_chart(
    data: pandas.DataFrame,
    title: str = "Frequency Chart",
    xlabel: str = "Proportion",
    ylabel: str = "Variable",
    base_color_map: dict[str, str] | None = None,
    height: int = 350,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a frequency chart.

    Each row is drawn as a stacked horizontal bar: the ``"Yes"`` part has
    length ``proportion`` and the ``"No"`` part fills the remainder up to 1.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data with ``"label"`` (hover text), ``"proportion"`` and
        ``"short_label"`` (y-axis tick label) columns. Any index is accepted;
        rows are used in positional order, first row at the top.
    title : str, default="Frequency Chart"
        Figure title.
    xlabel : str, default="Proportion"
        Figure x-axis label.
    ylabel : str, default="Variable"
        Figure y-axis label.
    base_color_map : dict, default=None
        Map of bar values and colours; the ``"Yes"`` and ``"No"`` entries are
        used, the latter at 50% opacity.
    height : int, default=350
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.

    Raises
    ------
    ValueError
        If a required column is missing.
    """
    df = data.copy()
    column_names = ["label", "proportion", "short_label"]

    # Error Handling
    if not all(col in df.columns for col in column_names):
        error_str = "Dataframe must contain the following columns: "
        error_str += f"{column_names}"
        raise ValueError(error_str)

    # Rows are looked up by position below; resetting the index makes that
    # work for any index (e.g. a filtered frame), not only a default one.
    df = df.reset_index(drop=True)

    default_color = "#007E71"
    yes_color = (
        base_color_map.get("Yes", default_color) if base_color_map else default_color
    )
    no_color = (
        hex_to_rgba(base_color_map.get("No", default_color), 0.5)
        if base_color_map
        else hex_to_rgba(default_color, 0.5)
    )

    traces = []
    # Traces are added last row first, so the first row's category ends up at
    # the top of the bottom-up categorical y-axis.
    for ii in reversed(range(df.shape[0])):
        hoverlabel = df.loc[ii, column_names[0]]
        yes_count = df.loc[ii, column_names[1]]
        label = df.loc[ii, column_names[2]]
        no_count = 1 - yes_count
        for name, count, color in (
            ("Yes", yes_count, yes_color),
            ("No", no_count, no_color),
        ):
            traces.append(
                go.Bar(
                    x=[count],
                    y=[label],
                    name=name,
                    orientation="h",
                    width=0.9,
                    offset=-0.45,
                    marker={"color": color},
                    customdata=[hoverlabel],
                    hovertemplate="%{customdata}: %{x:.2f}",
                    # One legend entry per name, from the first row's traces.
                    showlegend=(ii == 0),
                )
            )

    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        barmode="stack",
        xaxis={"title": xlabel, "range": [0, 1]},
        yaxis={
            "title": ylabel,
            "automargin": True,
            "tickmode": "array",
            "tickvals": df[column_names[2]],
            "ticktext": df[column_names[2]],
        },
        bargap=0.1,  # Smaller gap between bars.
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "right",
            "x": 1,
        },
        margin={"l": 100, "r": 100, "t": 100, "b": 50},
        height=height,
        minreducedwidth=500,
    )
    return go.Figure(data=traces, layout=layout)
def fig_table(
    data: pandas.DataFrame,
    table_key: str = "",
    columnwidth: typing.Iterable[int | float] | None = None,
    height: int = 500,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a table figure.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data. Each column becomes a table column with a bold header;
        missing values are shown as empty cells. The first column is
        left-aligned and the others right-aligned.
    table_key : str, default=""
        Table key, shown as a small title placed near the bottom-right.
    columnwidth : typing.Iterable, default=None
        An iterable of relative column widths, normalised to sum to 1. If
        ``None``, the first column takes 30% of the width and the remaining
        columns share the rest equally.
    height : int, default=500
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    df = data.copy()
    n = df.shape[1]
    # f-string rather than "+" so non-string column labels (e.g. ints) work.
    df.columns = [f"<b>{col}</b>" for col in df.columns]
    df = df.fillna("")

    default_firstwidth = 0.3
    if columnwidth is not None:
        # Materialise once: supports one-shot iterables and sums only once.
        widths = list(columnwidth)
        total = sum(widths)
        columnwidth = [w / total for w in widths]
    elif n < 2:
        columnwidth = [1]
    else:
        columnwidth = [default_firstwidth] + [(1 - default_firstwidth) / (n - 1)] * (
            n - 1
        )

    fig = go.Figure(
        data=[
            go.Table(
                header={
                    "values": list(df.columns),
                    "fill_color": "#bbbbbb",
                    "align": "left",
                    "font": {"size": 13},
                },
                cells={
                    "values": [df[col] for col in df.columns],
                    "align": ["left"] + ["right"] * (n - 1),
                    "font": {"size": 12},
                },
                columnwidth=columnwidth,
            ),
        ]
    )
    fig.update_layout(
        title={"text": table_key, "x": 0.95, "y": 0.08, "font": {"size": 12}},
        height=height,
        minreducedwidth=500,
    )
    return fig
def fig_dual_stack_pyramid(
    data: pandas.DataFrame,
    title: str = "Dual-Sided Stacked Pyramid Chart",
    xlabel: str = "Count",
    ylabel: str = "Category",
    base_color_map: dict[str, str] | None = None,
    height: int = 430,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a dual-sided stacked pyramid chart.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data with ``"y_axis"`` (category), ``"side"`` (side label),
        ``"stack_group"``, ``"value"`` and ``"left_side"`` (1 for rows drawn
        left of zero, 0 for rows drawn to the right) columns.
    title : str, default="Dual-Sided Stacked Pyramid Chart"
        Figure title.
    xlabel : str, default="Count"
        Figure x-axis label.
    ylabel : str, default="Category"
        Figure y-axis label.
    base_color_map : dict, default=None
        Map of ``stack_group`` values to colour codes (converted with
        ``hex_to_rgba``, so presumably hex strings). If ``None``, colours are
        cycled from Plotly's qualitative palette.
    height : int, default=430
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.

    Raises
    ------
    ValueError
        If a required column is missing, ``data`` is empty, or
        ``base_color_map`` is not a dict.
    """  # noqa : E501
    df = data.copy()

    # Error Handling ("left_side" is read below, so it is required too).
    required_columns = ["y_axis", "side", "stack_group", "value", "left_side"]
    if any(x not in tuple(df.columns) for x in required_columns):
        error_str = "Dataframe must contain the following columns: "
        error_str += f"{required_columns}"
        raise ValueError(error_str)
    if df.empty:
        raise ValueError("The DataFrame is empty.")
    if (base_color_map is not None) and (not isinstance(base_color_map, dict)):
        error_str = "base_color_map must be a dictionary with stack_group "
        error_str += "as keys and color codes as values."
        raise ValueError(error_str)

    # Side titles shown above each half of the chart.
    left_side_label = ""
    right_side_label = ""
    if (df["left_side"] == 1).any():
        left_side_label = df.loc[(df["left_side"] == 1), "side"].values[0]
    if (df["left_side"] == 0).any():
        right_side_label = df.loc[(df["left_side"] == 0), "side"].values[0]

    sides = df["side"].unique()
    stack_groups = df["stack_group"].unique()

    # Dynamic colour mapping when no map is supplied.
    if base_color_map is None:
        palette = px.colors.qualitative.Plotly
        base_color_map = {
            group: palette[i % len(palette)] for i, group in enumerate(stack_groups)
        }

    # The first side in the data is drawn at 75% opacity and every other side
    # fully opaque, so the two halves of each stack group are distinguishable.
    color_map = {}
    for stack_group, color in base_color_map.items():
        for side in sides:
            alpha = 0.75 if side == sides[0] else 1
            color_map[(side, stack_group)] = hex_to_rgba(color, alpha)

    traces = []
    for side in sides:
        for stack_group in stack_groups:
            subset = df[(df["side"] == side) & (df["stack_group"] == stack_group)]
            if subset.empty:
                continue
            # Left-side rows are negated so their bars extend left of x=0.
            x_val = -subset["value"] if subset["left_side"].any() else subset["value"]
            traces.append(
                go.Bar(
                    y=subset["y_axis"],
                    x=x_val,
                    name=f"{side} {stack_group}",
                    orientation="h",
                    marker={"color": color_map.get((side, stack_group))},
                )
            )

    # Axis half-width: the largest stacked total for any (side, category),
    # rounded up to an even integer so the half-way tick is a whole number.
    max_value = df.groupby(["side", "y_axis"], observed=True)["value"].sum().max()
    max_value = int(np.ceil(max_value))
    max_value += max_value % 2
    half_value = max_value // 2
    tickvals = [-max_value, -half_value, 0, half_value, max_value]

    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        barmode="relative",
        xaxis={
            "title": xlabel,
            "range": [-max_value, max_value],
            "automargin": True,
            "tickvals": tickvals,
            # Labels as positive numbers on both sides.
            "ticktext": [str(abs(x)) for x in tickvals],
        },
        yaxis={
            "title": ylabel,
            "automargin": True,
            "categoryorder": "array",
            "categoryarray": df["y_axis"],
        },
        annotations=[
            {
                "x": 0.2,  # 20% from the left edge of the plotting area
                "y": 1.1,  # Just above the top of the graph
                "xref": "paper",
                "yref": "paper",
                "text": left_side_label,
                "showarrow": False,
                "align": "center",
            },
            {
                "x": 0.8,  # 80% from the left edge of the plotting area
                "y": 1.1,  # Just above the top of the graph
                "xref": "paper",
                "yref": "paper",
                "text": right_side_label,
                "showarrow": False,
                "align": "center",
            },
        ],
        shapes=[
            # Vertical reference line at x=0, spanning the full plot height.
            {
                "type": "line",
                "x0": 0,
                "y0": 0,
                "x1": 0,
                "y1": 1,
                "xref": "x",
                "yref": "paper",
                "line": {"color": "black", "width": 2},
            }
        ],
        legend={"x": 1.05, "y": 1},
        margin={"l": 100, "r": 100, "t": 100, "b": 50},
        paper_bgcolor="white",
        plot_bgcolor="white",
        height=height,
        minreducedwidth=500,
    )
    return go.Figure(data=traces, layout=layout)
def fig_flowchart(
    data: pandas.DataFrame,
    height: int = 430,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a flowchart.

    Every row of ``data`` becomes a boxed annotation, and arrows connect rows
    according to ``"arrow_to"``.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data with ``"x"`` and ``"y"`` columns (positions on axes that
        span [0, 1]) and an ``"arrow_to"`` column of comma-separated index
        labels of the target rows (empty or missing for no arrows). All other
        columns are passed through as annotation properties.
    height : int, default=430
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.

    Raises
    ------
    KeyError
        If an ``"arrow_to"`` target is not an index label of ``data``.
    """
    df = data.copy()

    arrow_metadata = {
        "showarrow": True,
        "arrowwidth": 1.5,
        "arrowcolor": "rgba(100, 100, 100, 0.5)",
        "axref": "x",
        "ayref": "y",
        "xref": "x",
        "yref": "y",
        "text": "",
    }
    arrows = []
    for start, raw_targets in df["arrow_to"].items():
        # Missing cells mean "no outgoing arrows".
        if pd.isnull(raw_targets):
            continue
        for token in str(raw_targets).replace(" ", "").split(","):
            # Skip empty tokens from empty cells or trailing commas.
            if not token:
                continue
            # float() first so numeric cells read as floats (e.g. "3.0") work.
            end = int(float(token))
            x0, y0 = df.loc[start, "x"], df.loc[start, "y"]
            x1, y1 = df.loc[end, "x"], df.loc[end, "y"]
            xm, ym = (x0 + x1) / 2, (y0 + y1) / 2
            # Each connection is two annotation segments (head at x/y, tail at
            # ax/ay): start -> midpoint with arrowhead style 1, then
            # midpoint -> end with arrowhead style 0.
            arrows.append(
                {"x": xm, "y": ym, "ax": x0, "ay": y0, "arrowhead": 1, **arrow_metadata}
            )
            arrows.append(
                {"x": x1, "y": y1, "ax": xm, "ay": ym, "arrowhead": 0, **arrow_metadata}
            )

    df = df.drop(columns="arrow_to")
    annotation_metadata = {
        "showarrow": False,
        "xanchor": "center",
        "yanchor": "middle",
        "bgcolor": "rgba(150,150,150,1)",
        "bordercolor": "rgba(100,100,100,0.5)",
        "borderwidth": 1,
        "borderpad": 5,
    }
    annotations = [
        {**annotation, **annotation_metadata} for annotation in df.to_dict("records")
    ]

    layout = go.Layout(
        # Arrows first so the boxes are drawn on top of them.
        annotations=arrows + annotations,
        xaxis={"visible": False, "showgrid": False, "range": [0, 1]},
        yaxis={"visible": False, "showgrid": False, "range": [0, 1]},
        plot_bgcolor="rgba(0, 0, 0, 0)",
        height=height,
        minreducedwidth=500,
    )
    return go.Figure(layout=layout)
def fig_forest_plot(
    data: pandas.DataFrame,
    title: str = "Forest Plot",
    xlabel: str = "Odds Ratio (95% CI)",
    ylabel: str = "",
    reorder: bool = True,
    labels: typing.Iterable[str] = ("Variable", "OddsRatio", "LowerCI", "UpperCI"),
    marker: dict[str, typing.Any] | None = None,
    noeffect_line: bool | dict[str, typing.Any] | None = True,
    height: int = 600,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a forest plot.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Forest Plot"
        Figure title.
    xlabel : str, default="Odds Ratio (95% CI)"
        Figure x-axis label.
    ylabel : str, default=""
        Figure y-axis label.
    reorder : bool, default=True
        If ``True``, sort rows by the point estimate so the largest is drawn
        at the top; otherwise keep the input order with the first row at the
        top.
    labels : typing.Iterable, default=("Variable", "OddsRatio", "LowerCI", "UpperCI")
        Column names of, in order, the row label, the point estimate, and the
        lower and upper confidence bounds.
    marker : dict, default=None
        Marker properties for the point estimates. Defaults to
        ``{"color": "blue", "size": 10}``; a dict without ``"color"`` gets
        ``"blue"``. The confidence-interval lines use the same colour.
    noeffect_line : bool, dict, None, default=True
        Whether to draw a vertical line of no effect at x=1. A dict is used as
        the line style; any other truthy value draws a red line of width 2;
        ``False`` or ``None`` omits the line.
    height : int, default=600
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.

    Raises
    ------
    ValueError
        If any of ``labels`` is not a column of ``data``.
    """
    df = data.copy()
    labels = list(labels)

    # Validate before sorting so a missing column raises this ValueError
    # rather than a KeyError from sort_values.
    if not set(labels).issubset(df.columns):
        error_str = f"Dataframe must contain the following columns: {labels}"
        raise ValueError(error_str)
    label_col, estimate_col, lower_col, upper_col = labels[:4]

    if reorder:
        # Ascending: categorical y values are stacked from the bottom up, so
        # the largest estimate ends up at the top.
        df = df.sort_values(by=estimate_col, ascending=True)
    else:
        # Reverse so the first input row is drawn at the top.
        df = df.loc[::-1]

    # Build a new dict (never mutate the caller's) that always has a colour,
    # which the confidence-interval lines below reuse.
    if marker is None:
        marker = {"color": "blue", "size": 10}
    else:
        marker = {"color": "blue", **marker}

    # Point estimates as markers, then one horizontal CI line per row.
    traces = [
        go.Scatter(
            x=df[estimate_col],
            y=df[label_col],
            mode="markers",
            name="Odds Ratio",
            marker=marker,
        )
    ]
    traces.extend(
        go.Scatter(
            x=[row[lower_col], row[upper_col]],
            y=[row[label_col], row[label_col]],
            mode="lines",
            showlegend=False,
            line={"color": marker["color"], "width": 2},
        )
        for _, row in df.iterrows()
    )

    if isinstance(noeffect_line, dict):
        line_style = noeffect_line
    elif noeffect_line:
        line_style = {"color": "red", "width": 2}
    else:
        line_style = None

    add_shape = None
    if line_style is not None:
        # Line of no effect at x=1, spanning every row.
        add_shape = [
            {
                "type": "line",
                "x0": 1,
                "y0": -0.5,
                "x1": 1,
                "y1": len(df) - 0.5,
                "line": line_style,
            }
        ]

    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        xaxis={"title": xlabel},
        yaxis={
            "title": ylabel,
            "automargin": True,
            "tickmode": "array",
            "tickvals": df[label_col].tolist(),
            "ticktext": df[label_col].tolist(),
            "range": [-1, len(df)],
        },
        shapes=add_shape,
        margin={"l": 100, "r": 100, "t": 100, "b": 50},
        height=height,
        minreducedwidth=500,
        showlegend=False,
    )
    return go.Figure(data=traces, layout=layout)
def fig_text(
    data: pandas.DataFrame,
    height: int = 430,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a figure with an annotation.

    The values of the ``"paragraphs"`` column are joined with HTML line
    breaks and drawn as one annotation at the centre of an otherwise empty
    figure with a transparent background and hidden axes.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data with a ``"paragraphs"`` column of strings.
    height : int, default=430
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    transparent = "rgba(0,0,0,0)"
    figure = go.Figure()
    body = "<br>".join(data["paragraphs"].values)
    # (0, 0) is the centre of the [-1, 1] x [-1, 1] axes set below.
    figure.add_annotation(x=0, y=0, text=body, showarrow=False)
    figure.update_layout(
        paper_bgcolor=transparent,
        plot_bgcolor=transparent,
        xaxis={"visible": False, "range": [-1, 1]},
        yaxis={"visible": False, "range": [-1, 1]},
        height=height,
        minreducedwidth=500,
    )
    return figure
[docs] def fig_kaplan_meier( data: tuple[pd.DataFrame], title: str = "Kaplan-Meier Plot", xlabel: str = "Time (days)", ylabel: str = "Survival Probability", groups: typing.Iterable[str] | None = None, index_column: str = "index", base_color_map: dict[str, str] | None = None, xlim: typing.Iterable[float | int] | None = None, p_value: float | None = None, height: int = 480, ) -> plotly.graph_objs.Figure: """:py:class:`plotly.graph_objs.Figure` : Returns a Kaplan-Meier plot. Parameters ---------- data : tuple Incoming data as two Pandas dataframes, the first for the plot, and the second for the risk table. title : str, default="Kaplan-Meier Plot" Figure title. xlabel : str, default="Time (days)" Figure x-axis label. ylabel : str, default="Survival Probability" Figure y-axis label. groups : typing.Iterable, default=None Groups. index_column : str, default="index" Index column. base_color_map : dict, default=None Colour map. xlim : typing.Iterable, default=None xlim. p_value : float, default=None p-value. height : int, default=480 Figure height. Returns ------- plotly.graph_objs.Figure The Plotly figure. 
""" df_km = data[0].copy() risk_table = data[1].copy() if "Group" not in risk_table.columns: raise KeyError("risk table must contain column 'Groups'") if "timeline" not in df_km.columns: raise KeyError("KM curves table must contain column 'timeline'") df_km = df_km.set_index("timeline") if groups is None: groups = [c for c in df_km.columns if "_lower_" not in c and "_upper_" not in c] if base_color_map is None: colors = [ f"hsl({i * (360 / len(groups))}, 70%, 50%)" for i in range(len(groups)) ] else: colors = list(base_color_map.values()) # unique_groups = ["HIV positive", "HIV negative"] # Create the figure with two rows: one for the plot and one for risk table fig = make_subplots( rows=2, cols=1, # shared_xaxes=True, row_heights=[0.7, 0.3], vertical_spacing=0.1, subplot_titles=[title, ""], ) for group, color in zip(groups, colors): ci_lower_column = [ col for col in df_km.columns if col.startswith(group + "_lower") ][0] ci_upper_column = [ col for col in df_km.columns if col.startswith(group + "_upper") ][0] ci = df_km[[col for col in df_km.columns if col.startswith(group + "_")]] ci = ci.dropna() ci_lower = ci[ci_lower_column] ci_upper = ci[ci_upper_column] # Add confidence interval as shaded area ci_x = ci_upper.index.tolist() + ci_lower.index[::-1].tolist() ci_y = ci_upper.tolist() + ci_lower[::-1].tolist() fig.add_trace( go.Scatter( x=ci_x, y=ci_y, fill="toself", fillcolor=color.replace("hsl", "hsla") .replace("rgb", "rgba") .replace(")", ",0.2)"), line={"color": "rgba(255,255,255,0)", "shape": "hv"}, name=f"CI {group}", showlegend=False, hoverinfo="text", text=[f"CI {group}" for _ in range(len(ci_upper) + len(ci_lower))], ), row=1, col=1, ) for group, color in zip(groups, colors): survival = df_km[group] # Add survival curve fig.add_trace( go.Scatter( x=survival.index, y=survival.values, mode="lines", name=str(group), line={"color": color, "shape": "hv"}, ), row=1, col=1, ) # Add p-value annotation to the plot if p_value is not None: p_value_text = ( 
"p-value: <0.001" if p_value < 0.001 else f"p-value: {p_value:.3f}" ) fig.add_annotation( text=p_value_text, x=0.95, y=95, xref="paper", yref="y1", showarrow=False, font={"size": 12, "color": "black"}, bgcolor="white", bordercolor="black", borderwidth=1, ) # Add risk table as second row for ii in range(risk_table.shape[0]): yval = np.arange(risk_table.shape[0])[::-1][ii] color = dict(zip(groups, colors)).get(risk_table.loc[ii, index_column], "black") fig.add_trace( go.Scatter( x=risk_table.drop(columns=index_column).columns.astype(float), y=np.repeat(yval, risk_table.shape[1] - 1), mode="text", text=risk_table.drop(columns=index_column).loc[ii], textposition="middle center", textfont={"color": color}, showlegend=False, ), row=2, col=1, ) # Configure axes and layout fig.update_yaxes( title_text=ylabel, range=[-1, 101], # tickvals=np.arange(0, 110, 10), # ticktext=[f'{i}%' for i in range(0, 110, 10)], row=1, col=1, ) fig.update_xaxes( title_text=xlabel, # range=[risk_table.] tickvals=risk_table.columns[1:], row=1, col=1, ) if xlim is not None: xlim[0] = min((xlim[0], xlim[0] - (xlim[1] - xlim[0]) * 0.02)) xlim[1] = max((xlim[1], xlim[1] + (xlim[1] - xlim[0]) * 0.02)) else: xlim = [df_km.index.min(), df_km.index.max()] xlim[0] = min((xlim[0], xlim[0] - (xlim[1] - xlim[0]) * 0.02)) xlim[1] = max((xlim[1], xlim[1] + (xlim[1] - xlim[0]) * 0.02)) fig.add_trace( go.Scatter( x=[xlim[0], xlim[0]], y=[-0.5, risk_table.shape[0] - 0.5], mode="lines", line={"color": "black", "width": 1}, showlegend=False, ), row=2, col=1, ) fig.add_trace( go.Scatter( x=[xlim[0], xlim[1]], y=[risk_table.shape[0] - 0.5, risk_table.shape[0] - 0.5], mode="lines", line={"color": "black", "width": 1}, showlegend=False, ), row=2, col=1, ) fig.update_xaxes(range=xlim, row=1, col=1) fig.update_xaxes(visible=False, showgrid=False, range=xlim, row=2, col=1) fig.update_yaxes( range=[-0.5, risk_table.shape[0] - 0.5], showgrid=False, title_text="Number at risk", tickvals=np.arange(risk_table.shape[0])[::-1], 
ticktext=risk_table[index_column], row=2, col=1, ) fig.update_layout( shapes=[ dict( type="rect", xref="x domain", yref="y domain", x0=0, y0=0, x1=1, y1=1, fillcolor="grey", opacity=0.1, layer="below", line_width=0.5, ) ] ) fig.update_layout( height=height, plot_bgcolor="rgba(0, 0, 0, 0)", showlegend=True, minreducedwidth=500, ) return fig
def fig_line_chart(
    data: pandas.DataFrame,
    title: str = "Line chart",
    xlabel: str = "",
    ylabel: str = "",
    height: int = 480,
    line_column: str = "",
    index_column: str = "index",
    lower_column: str | None = None,
    upper_column: str | None = None,
    line_color: str | None = None,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a line chart.

    ``line_column`` is plotted against ``index_column``. When both
    ``lower_column`` and ``upper_column`` are given, the region between them
    is shaded. The y-axis always includes zero and is padded by 10% beyond
    the most extreme plotted value (so negative values are not clipped).

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Line chart"
        Figure title.
    xlabel : str, default=""
        Figure x-axis label.
    ylabel : str, default=""
        Figure y-axis label.
    height : int, default=480
        Figure height.
    line_column : str, default=""
        Column plotted as the line; also used (title-cased, underscores to
        spaces) as the trace name.
    index_column : str, default="index"
        Column used for the x-axis values and tick labels.
    lower_column : str, default=None
        Lower-bound column of the shaded band.
    upper_column : str, default=None
        Upper-bound column of the shaded band.
    line_color : str, default=None
        Line and marker colour.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    df = data.set_index(index_column)  # returns a new frame; input untouched
    line_values = df[line_column]

    line_trace = go.Scatter(
        x=df.index,
        y=line_values,
        mode="lines+markers",
        name=line_column.replace("_", " ").title(),
        marker={"color": line_color},
        line={"color": line_color, "width": 2, "dash": "solid"},
    )
    traces = [line_trace]
    # Keep zero on the axis and extend past it only for negative data.
    yaxis_min = min(0, line_values.min() * 1.1)
    yaxis_max = max(0, line_values.max() * 1.1)

    if upper_column is not None and lower_column is not None:
        lower_values = df[lower_column]
        upper_values = df[upper_column]
        bounds_trace = go.Scatter(
            # Closed outline: along the upper bound, then back along the lower
            # bound, so fill="toself" shades the region between them.
            x=df.index.tolist() + df.index[::-1].tolist(),
            y=upper_values.tolist() + lower_values[::-1].tolist(),
            fill="toself",
            fillcolor="rgba(150,150,150,0.2)",
            line={"color": "rgba(255,255,255,0)"},
            showlegend=False,
        )
        traces.append(bounds_trace)
        yaxis_min = min(0, lower_values.min() * 1.1)
        yaxis_max = max(0, upper_values.max() * 1.1)

    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        xaxis={
            "title": xlabel,
            # One tick per index value, labelled with the value itself.
            "tickmode": "array",
            "tickvals": df.index,
            "ticktext": df.index,
            "tickangle": -30,
            "tickfont": {"size": 10},
        },
        yaxis={"title": ylabel, "range": [yaxis_min, yaxis_max]},
        legend={"x": 0.85, "y": 1, "bgcolor": "rgba(255,255,255,0.5)"},
        margin={"l": 60, "r": 60, "t": 50, "b": 80},
        paper_bgcolor="white",
        plot_bgcolor="white",
        height=height,
        minreducedwidth=500,
    )
    return go.Figure(data=traces, layout=layout)
def fig_bar_line_chart(
    data: pandas.DataFrame,
    title: str = "Combined bar line chart",
    xlabel: str = "",
    ylabel_left: str = "",
    ylabel_right: str = "",
    bar_column: str = "",
    line_column: str = "",
    index_column: str = "index",
    lower_column: str | None = None,
    upper_column: str | None = None,
    bar_color: str | None = None,
    line_color: str | None = None,
    height: int = 500,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a bar-line chart.

    Bars (``bar_column``) use the left y-axis; the line (``line_column``)
    and its optional shaded band use an overlaid right y-axis. Each y-axis
    always includes zero and is padded by 10% beyond its most extreme value
    (so negative values are not clipped).

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Combined bar line chart"
        Figure title.
    xlabel : str, default=""
        Figure x-axis label.
    ylabel_left : str, default=""
        Left y-axis (bar) label.
    ylabel_right : str, default=""
        Right y-axis (line) label.
    bar_column : str, default=""
        Column plotted as bars; also used (title-cased) as the trace name.
    line_column : str, default=""
        Column plotted as the line; also used (title-cased) as the trace name.
    index_column : str, default="index"
        Column used for the x-axis values and tick labels.
    lower_column : str, default=None
        Lower-bound column of the line's shaded band.
    upper_column : str, default=None
        Upper-bound column of the line's shaded band.
    bar_color : str, default=None
        Bar colour.
    line_color : str, default=None
        Line and marker colour.
    height : int, default=500
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    df = data.set_index(index_column)  # returns a new frame; input untouched
    bar_values = df[bar_column]
    line_values = df[line_column]

    bar_trace = go.Bar(
        x=df.index,
        y=bar_values,
        name=bar_column.replace("_", " ").title(),
        marker={"color": bar_color},
        yaxis="y",
    )
    line_trace = go.Scatter(
        x=df.index,
        y=line_values,
        mode="lines+markers",
        name=line_column.replace("_", " ").title(),
        marker={"color": line_color},
        line={"color": line_color, "width": 2, "dash": "solid"},
        yaxis="y2",
    )
    traces = [bar_trace, line_trace]
    # Keep zero on each axis and extend past it only for negative data.
    yaxis_range = [min(0, bar_values.min() * 1.1), max(0, bar_values.max() * 1.1)]
    yaxis2_min = min(0, line_values.min() * 1.1)
    yaxis2_max = max(0, line_values.max() * 1.1)

    if upper_column is not None and lower_column is not None:
        lower_values = df[lower_column]
        upper_values = df[upper_column]
        bounds_trace = go.Scatter(
            # Closed outline: along the upper bound, then back along the lower
            # bound, so fill="toself" shades the region between them.
            x=df.index.tolist() + df.index[::-1].tolist(),
            y=upper_values.tolist() + lower_values[::-1].tolist(),
            fill="toself",
            fillcolor="rgba(150,150,150,0.2)",
            line={"color": "rgba(255,255,255,0)"},
            showlegend=False,
            yaxis="y2",
        )
        traces.append(bounds_trace)
        yaxis2_min = min(0, lower_values.min() * 1.1)
        yaxis2_max = max(0, upper_values.max() * 1.1)

    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        barmode="stack",
        bargap=0.3,
        xaxis={
            "title": xlabel,
            # One tick per index value, labelled with the value itself.
            "tickmode": "array",
            "tickvals": df.index,
            "ticktext": df.index,
            "tickangle": -30,
            "tickfont": {"size": 10},
        },
        yaxis={"title": ylabel_left, "range": yaxis_range},
        yaxis2={
            "title": ylabel_right,
            "overlaying": "y",
            "side": "right",
            "showgrid": False,
            "range": [yaxis2_min, yaxis2_max],
        },
        legend={"x": 0.85, "y": 1, "bgcolor": "rgba(255,255,255,0.5)"},
        margin={"l": 60, "r": 60, "t": 50, "b": 80},
        paper_bgcolor="white",
        plot_bgcolor="white",
        height=height,
        minreducedwidth=500,
    )
    return go.Figure(data=traces, layout=layout)
def fig_heatmaps(
    data: pandas.DataFrame | tuple[pandas.DataFrame, ...] | list[pandas.DataFrame],
    title: str = "",
    subplot_titles: list[str] | None = None,
    ylabel: str = "",
    xlabel: str = "",
    colorbar_label: str = "",
    index_column: str = "index",
    zmin: float | None = None,
    zmax: float | None = None,
    include_annotations: bool = False,
    base_color_map: str | list | None = None,
    height: int = 750,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a heatmaps chart.

    One heatmap per dataframe, stacked vertically in a single column and
    sharing one colour scale (only the first heatmap shows the colour bar).
    Only the bottom heatmap shows x tick labels and the x-axis title.

    Parameters
    ----------
    data : pandas.DataFrame or tuple or list
        One dataframe, or a tuple/list of dataframes. In each, ``index_column``
        gives the row labels and every other column is a heatmap column.
    title : str, default=""
        Figure title.
    subplot_titles : list, default=None
        Subplot titles.
    ylabel : str, default=""
        Shared y-axis label.
    xlabel : str, default=""
        X-axis label, shown under the bottom heatmap.
    colorbar_label : str, default=""
        Colour bar label.
    index_column : str, default="index"
        Index column.
    zmin : float, default=None
        Lower colour-scale bound; defaults to the minimum over all frames.
    zmax : float, default=None
        Upper colour-scale bound; defaults to the maximum over all frames.
    include_annotations : bool, default=False
        Write each cell's value as text on the heatmap.
    base_color_map : str or list, default=None
        Plotly colorscale (name or list); defaults to ``"viridis"``.
    height : int, default=750
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    frames = list(data) if isinstance(data, (tuple, list)) else [data]
    n_plots = len(frames)

    fig = make_subplots(
        rows=n_plots,
        cols=1,
        subplot_titles=subplot_titles,
        vertical_spacing=0.1,
        y_title=ylabel,
    )

    # A common zmin/zmax keeps colours comparable across all heatmaps.
    if zmin is None:
        zmin = min(frame.drop(columns=index_column).min().min() for frame in frames)
    if zmax is None:
        zmax = max(frame.drop(columns=index_column).max().max() for frame in frames)
    colorscale = "viridis" if base_color_map is None else base_color_map

    for position, frame in enumerate(frames):
        # Plotly draws the first y value at the bottom, so reverse the rows
        # to show the frame's first row at the top.
        df = frame.set_index(index_column).iloc[::-1]
        is_first = position == 0
        fig.add_trace(
            go.Heatmap(
                z=df.values,
                x=df.columns,
                y=df.index,
                text=df.astype(str).values if include_annotations else None,
                texttemplate="%{text}" if include_annotations else None,
                zmin=zmin,
                zmax=zmax,
                colorscale=colorscale,
                colorbar={"title": colorbar_label} if is_first else None,
                showscale=is_first,
            ),
            row=position + 1,
            col=1,
        )

    fig.update_layout(
        height=height,
        title_text=title,
        title_x=0.5,
        title_xref="paper",
        showlegend=False,
    )
    # Hide x tick labels on every heatmap except the bottom one (row n_plots).
    for row in range(1, n_plots):
        fig.update_xaxes(showticklabels=False, row=row, col=1)
    fig.update_xaxes(
        showticklabels=True,
        tickangle=0,
        title=xlabel,
        row=n_plots,
        col=1,
    )
    return fig
def fig_sankey(
    data: tuple[pandas.DataFrame, pandas.DataFrame, pandas.DataFrame],
    height: int = 500,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a Sankey plot.

    Parameters
    ----------
    data : tuple
        Three Pandas dataframes:

        1. Nodes -- each column becomes a list-valued ``go.Sankey`` node
           property. It presumably includes ``customdata``, which the node
           and link hover templates display -- confirm against the caller.
        2. Links -- each column becomes a list-valued ``go.Sankey`` link
           property (e.g. ``source``, ``target``, ``value``).
        3. Annotations -- each row becomes one layout annotation.
    height : int, default=500
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    nodes, links, annotations = data[0], data[1], data[2]

    # Fixed styling merged over the column-derived node/link properties.
    node_metadata = {
        "hovertemplate": "%{customdata}",
        "pad": 15,
        "thickness": 20,
        "line": {"color": "black", "width": 1.2},
    }
    link_metadata = {
        "hovertemplate": "%{source.customdata} to %{target.customdata}",
        "line": {"color": "rgba(0,0,0,0.3)", "width": 0.3},
    }

    sankey = go.Sankey(
        arrangement="snap",
        valueformat=".0f",
        node={**nodes.to_dict(orient="list"), **node_metadata},
        link={**links.to_dict(orient="list"), **link_metadata},
    )
    layout = go.Layout(
        annotations=annotations.to_dict(orient="records"),
        height=height,
        minreducedwidth=500,
    )
    return go.Figure(data=[sankey], layout=layout)
############################################ ############################################ # Formatting: colours ############################################ ############################################
def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
    """:py:class:`tuple` : Converts a hex colour to an RGB colour tuple.

    Accepts ``#RRGGBB`` and the CSS shorthand ``#RGB`` (each digit doubled).
    For ``#RRGGBBAA`` / ``#RGBA`` the alpha digits are ignored. The leading
    ``#`` is optional.

    Parameters
    ----------
    hex_color : str
        Hex colour string.

    Returns
    -------
    tuple
        ``(red, green, blue)`` integers in the range 0-255.

    Raises
    ------
    ValueError
        If ``hex_color`` is not a 3, 4, 6 or 8 digit hexadecimal colour.
    """
    digits = hex_color.lstrip("#")
    if len(digits) in (3, 4):
        # CSS shorthand: "#f08" means "#ff0088".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8) or not all(
        ch in "0123456789abcdefABCDEF" for ch in digits
    ):
        raise ValueError(f"Invalid hex colour: {hex_color!r}")
    return tuple(int(digits[i : i + 2], 16) for i in (0, 2, 4))
def hex_to_rgba(hex_color: str, opacity: float) -> str:
    """:py:class:`str` : Converts a hex colour to an RGBA colour string.

    Accepts ``#RRGGBB`` and the CSS shorthand ``#RGB`` (each digit doubled).
    For ``#RRGGBBAA`` / ``#RGBA`` the alpha digits are ignored in favour of
    ``opacity``. The leading ``#`` is optional.

    Parameters
    ----------
    hex_color : str
        Hex colour string.
    opacity : float
        Opacity/transparency, a value between 0.0 (fully transparent) and
        1.0 (fully opaque).

    Returns
    -------
    str
        An RGBA colour string such as ``"rgba(255, 0, 128, 0.5)"``.

    Raises
    ------
    ValueError
        If ``hex_color`` is not a 3, 4, 6 or 8 digit hexadecimal colour.
    """
    digits = hex_color.lstrip("#")
    if len(digits) in (3, 4):
        # CSS shorthand: "#f08" means "#ff0088".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8) or not all(
        ch in "0123456789abcdefABCDEF" for ch in digits
    ):
        raise ValueError(f"Invalid hex colour: {hex_color!r}")
    channels = ", ".join(str(int(digits[i : i + 2], 16)) for i in (0, 2, 4))
    return f"rgba({channels}, {opacity})"
def rgb_to_rgba(rgb_color: tuple[int, int, int], alpha: float) -> str:
    """:py:class:`str` : Converts an RGB colour tuple and alpha to an RGBA string.

    Parameters
    ----------
    rgb_color : tuple
        RGB colour tuple; only the first three items are used.
    alpha : float
        Opacity/transparency value between 0.0 (fully transparent) and
        1.0 (fully opaque).

    Returns
    -------
    str
        RGBA colour string such as ``"rgba(255, 0, 128, 0.5)"``.
    """
    red, green, blue = rgb_color[0], rgb_color[1], rgb_color[2]
    return f"rgba({red}, {green}, {blue}, {alpha})"