from __future__ import annotations
__all__ = [
"fig_bar_chart",
"fig_bar_line_chart",
"fig_count_chart",
"fig_dual_stack_pyramid",
"fig_flowchart",
"fig_forest_plot",
"fig_frequency_chart",
"fig_heatmaps",
"fig_kaplan_meier",
"fig_line_chart",
"fig_pie",
"fig_placeholder",
"fig_sankey",
"fig_sunburst",
"fig_table",
"fig_text",
"fig_timelines",
"fig_upset",
"hex_to_rgb",
"hex_to_rgba",
"rgb_to_rgba",
]
# -- IMPORTS --
# -- Standard libraries --
import typing
# -- 3rd party libraries --
import numpy as np
import pandas
import plotly
import plotly.express as px
import plotly.graph_objs as go
from plotly.subplots import make_subplots
# -- Internal libraries --
pd = pandas # An alias to allow Pandas code refs to work independently
# of Pandas Intersphinx refs in type hinting and docstrings
############################################
############################################
# Figures
############################################
############################################
[docs]
def fig_placeholder(
    data: pandas.DataFrame | None,
    title: str = "Placeholder scatter plot",
    xlabel: str = "",
    ylabel: str = "",
    height: int = 450,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a placeholder scatter plot.

    Parameters
    ----------
    data : pandas.DataFrame, None
        Incoming data with ``"x"`` and ``"y"`` columns. If ``None``, five
        synthetic points are plotted (x = 1..5, y drawn uniformly from
        [10, 15)) and the y-axis is fixed to [10, 15].
    title : str, default="Placeholder scatter plot"
        Figure title.
    xlabel : str, default=""
        Figure x-axis label.
    ylabel : str, default=""
        Figure y-axis label.
    height : int, default=450
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    if data is None:
        x = [1, 2, 3, 4, 5]
        y = np.random.uniform(low=10, high=15, size=5)
    else:
        x = data["x"]
        y = data["y"]
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(x=x, y=y, mode="markers", marker={"size": 10, "color": "blue"})
    )
    fig.update_layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        xaxis_title=xlabel,
        yaxis_title=ylabel,
        height=height,
        minreducedwidth=500,
    )
    if data is None:
        # The fixed range matches the synthetic y-values drawn above; applying
        # it to caller data would hide every point outside [10, 15].
        fig.update_yaxes(range=[10, 15])
    return fig
[docs]
def fig_pie(
    data: pandas.DataFrame,
    title: str = "Pie chart",
    xlabel: str = "",
    ylabel: str = "",
    base_color_map: dict[str, str] | None = None,
    names: str | int | pd.Series | typing.Iterable = "",
    values: str | int | pd.Series | typing.Iterable = "",
    height: int = 450,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a pie chart figure.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Pie chart"
        Figure title.
    xlabel : str, default=""
        Figure x-axis label. Pie charts have no Cartesian axes, so this
        presumably has no visible effect; it is kept for a signature that is
        uniform with the other figure functions.
    ylabel : str, default=""
        Figure y-axis label (see ``xlabel``).
    base_color_map : dict, default=None
        Map of sector names and colours. Sectors are coloured by ``names``,
        so the keys should be sector names.
    names : str, int, pd.Series, typing.Iterable, default=""
        Sector name(s)/label(s).
    values : str, int, pd.Series, typing.Iterable, default=""
        Sector values.
    height : int, default=450
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    df = data.copy()
    fig = px.pie(
        df,
        values=values,
        names=names,
        title=title,
        # Colour by sector name so base_color_map keys line up with labels
        color=names,
        color_discrete_map=base_color_map,
    )
    fig.update_layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        xaxis_title=xlabel,
        yaxis_title=ylabel,
        height=height,
        minreducedwidth=500,
    )
    return fig
[docs]
def fig_timelines(
    data: pandas.DataFrame,
    title: str = "Timeline",
    label_col: str = "",
    group_col: str = "",
    start_date: str = "start_date",
    end_date: str = "end_date",
    size_col: str | None = None,
    min_width: int = 2,
    max_width: int = 10,
    height: int = 500,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a timeline figure.

    Each row becomes one horizontal line segment from its start date to its
    end date, placed at its ``label_col`` value and coloured by its
    ``group_col`` value. Rows with no end date are treated as ongoing: they
    extend to the latest end date in the data and end in an arrow marker.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Timeline"
        Figure title.
    label_col : str, default=""
        Label column (y-axis position of each segment).
    group_col : str, default=""
        Group column (segment colour and legend entry).
    start_date : str, default="start_date"
        Start date column.
    end_date : str, default="end_date"
        End date column; missing values mark ongoing entries.
    size_col : str, None, default=None
        Optional numeric column scaled linearly onto line widths between
        ``min_width`` and ``max_width`` (missing values count as 0). If not
        given, or entirely missing, every line has width 3.
    min_width : int, default=2
        Minimum line width used when scaling ``size_col``.
    max_width : int, default=10
        Maximum line width used when scaling ``size_col``.
    height : int, default=500
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    df = data.copy()
    df[start_date] = pd.to_datetime(df[start_date])
    df[end_date] = pd.to_datetime(df[end_date])
    # Ongoing entries are drawn up to the latest known end date
    max_end = df[end_date].max()
    if pd.isnull(max_end):
        # No entry has an end date: fall back to the latest start date so that
        # ongoing entries still get a finite end point instead of NaT.
        max_end = df[start_date].max()

    # Assign colours by group, cycling through the Plotly qualitative palette
    palette = px.colors.qualitative.Plotly
    color_map = {
        group: palette[i % len(palette)]
        for i, group in enumerate(df[group_col].unique())
    }

    # Line width per label: size_col rescaled onto [min_width, max_width]
    if size_col and df[size_col].notnull().any():
        values = df[size_col].fillna(0).astype(float)
        min_val, max_val = values.min(), values.max()
        if min_val == max_val:
            # Avoid division by zero when all sizes are equal
            widths = {label: (min_width + max_width) / 2 for label in df[label_col]}
        else:
            widths = {
                label: min_width
                + (val - min_val) / (max_val - min_val) * (max_width - min_width)
                for label, val in zip(df[label_col], values)
            }
    else:
        widths = {label: 3 for label in df[label_col]}

    fig = go.Figure()
    # Groups that already have a legend entry (one entry per group)
    shown_groups = set()
    for _, row in df.iterrows():
        label = row[label_col]
        group = row[group_col]
        ongoing = pd.isnull(row[end_date])
        x_end = max_end if ongoing else row[end_date]
        color = color_map[group]
        # Ongoing entries end in a larger arrow instead of a circle
        symbol = ["circle", "arrow-right"] if ongoing else ["circle", "circle"]
        size = [14, 20] if ongoing else [14, 14]
        fig.add_trace(
            go.Scatter(
                x=[row[start_date], x_end],
                y=[label, label],
                mode="lines+markers",
                line={"color": color, "width": widths[label]},
                marker={
                    "size": size,
                    "symbol": symbol,
                    "color": color,
                    "line": {"width": 1, "color": color},
                    "opacity": 1,
                },
                name=group,
                legendgroup=group,
                showlegend=group not in shown_groups,
            )
        )
        shown_groups.add(group)
    fig.update_layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        xaxis_title="Date",
        yaxis={"title": label_col, "tickfont": {"size": 10}},
        margin={"l": 250, "r": 20, "t": 40, "b": 40},
        height=height,
        minreducedwidth=500,
    )
    return fig
[docs]
def fig_sunburst(
    data: pandas.DataFrame,
    title: str = "Sunburst Chart",
    path: list[str | int] | pd.Series | typing.Iterable | None = ["level0", "level1"],
    values: str | int | pd.Series | typing.Iterable = "values",
    base_color_map: dict[str, str] | None = None,
    height: int = 430,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a sunburst plot.

    Sectors keep their input order (no largest-first sorting), labels are
    oriented radially and hovering shows the sector label and its value.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Sunburst Chart"
        Figure title.
    path : str, int, pd.Series, typing.Iterable, None, default=["level0", "level1"]
        Column names defining a hierarchy of sectors, from root to leaves.
    values : str, int, pd.Series, typing.Iterable, default="values"
        A column name in the data defining sector values, or a Pandas Series
        or an iterable containing sector values.
    base_color_map : dict, default=None
        Map of sector values/marks and colours. Note: this is currently not
        applied to the figure.
    height : int, default=430
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    sunburst = px.sunburst(data.copy(), path=path, values=values)
    sunburst.update_traces(
        selector={"type": "sunburst"},
        sort=False,
        insidetextorientation="radial",
        customdata=path,
        hovertemplate="%{label}<br>N=%{value}<extra></extra>",
    )
    sunburst.update_layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        height=height,
        minreducedwidth=500,
    )
    return sunburst
[docs]
def fig_bar_chart(
    data: pandas.DataFrame,
    title: str = "Bar Chart",
    xlabel: str = "",
    ylabel: str = "",
    index_column: str = "index",
    barmode: str = "stack",
    xaxis_tickformat: str = "%m-%Y",
    base_color_map: dict[str, str] | None = None,
    height: int = 340,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a bar chart.

    ``index_column`` supplies the x-axis positions. Every remaining column
    becomes one vertical bar trace, named after the column.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Bar Chart"
        Figure title.
    xlabel : str, default=""
        Figure x-axis label.
    ylabel : str, default=""
        Figure y-axis label.
    index_column : str, default="index"
        Column used as the x-axis (set as the frame index).
    barmode : str, default="stack"
        How bars with the same location coordinate are displayed: possible
        values are ``"stack"``, ``"relative"``, ``"group"``, ``"overlay"``.
        In ``"group"`` mode a 0.1 gap separates bar groups; other modes use
        no gap. For reference see the `Plotly documentation
        <https://plotly.com/python-api-reference/generated/plotly.graph_objects.Layout.html>`_.
    xaxis_tickformat : str, default="%m-%Y"
        x-axis tick format.
    base_color_map : dict, default=None
        Map of column names to bar colours. Columns absent from the map are
        drawn in ``"#000"``. If ``None``, colours cycle through the Plotly
        qualitative palette.
    height : int, default=340
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    frame = data.copy().set_index(index_column)
    series_names = frame.columns

    if base_color_map is None:
        palette = px.colors.qualitative.Plotly
        base_color_map = {
            name: palette[position % len(palette)]
            for position, name in enumerate(series_names)
        }

    # One vertical bar trace per data column
    traces = [
        go.Bar(
            x=frame.index,
            y=frame[name],
            name=name,
            orientation="v",
            marker={"color": base_color_map.get(name, "#000")},
        )
        for name in series_names
    ]

    bargap = 0.1 if barmode == "group" else 0
    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        barmode=barmode,
        bargap=bargap,
        xaxis={
            "title": xlabel,
            "tickformat": xaxis_tickformat,
            # One tick per index value
            "tickvals": frame.index,
        },
        yaxis={"title": ylabel},
        legend={"x": 1.05, "y": 1},
        margin={"l": 100, "r": 100, "t": 100, "b": 50},
        paper_bgcolor="white",
        plot_bgcolor="white",
        height=height,
        minreducedwidth=500,
    )
    return go.Figure(data=traces, layout=layout)
[docs]
def fig_upset(
    data: tuple[pd.DataFrame],
    title: str = "Upset Plot",
    height: int = 480,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns an upset plot.

    The figure is a 2x2 subplot grid:

    * row 1, col 1: intersection sizes as vertical bars;
    * row 2, col 1: the membership matrix (dots, joined by a line when an
      intersection involves more than one set);
    * row 2, col 2: set sizes as horizontal bars.

    Parameters
    ----------
    data : tuple
        Incoming data as two Pandas dataframes, the first for counts, and the
        second for intersections. The counts frame is read through its
        ``"label"``, ``"short_label"`` and ``"count"`` columns. The
        intersections frame is read through ``"label"`` (a collection of set
        labels, joined with ``"<br>"`` for hover text) and ``"count"``. Both
        frames' index values are used as integer positions (``ii % 5``,
        ``-1 - ii``), so a default 0..n-1 integer index is presumably
        expected (TODO confirm).
    title : str, default="Upset Plot"
        Figure title.
    height : int, default=480
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    counts = data[0].copy()
    intersections = data[1].copy()
    hlabel = "label"
    slabel = "short_label"
    # Full labels are used for hover text and membership tests; short labels
    # are used as matrix tick text
    hoverlabels = counts[hlabel].tolist()
    labels = counts[slabel].tolist()
    # Relative widths of the left (intersections) and right (sets) columns
    column_widths = [intersections.shape[0], counts.shape[0]]
    # Initialize subplots: columns share x-axes, rows share y-axes, so the
    # matrix lines up with both bar charts
    fig = make_subplots(
        rows=2,
        cols=2,
        shared_xaxes=True,
        shared_yaxes=True,
        column_widths=column_widths,
        vertical_spacing=0.02,  # Space between the plots
        horizontal_spacing=0.01,
    )
    # Create bar chart traces for intersection sizes; colours cycle through
    # the first five entries of the reversed Purples scale
    bar_traces = []
    for ii in intersections.index:
        color = rgb_to_rgba(px.colors.sequential.Purples_r[ii % 5], 1)
        hoverlabel = "<br>".join(intersections.loc[ii, hlabel])
        n = intersections.loc[ii, "count"]
        customdata = f"Intersection of<br>{hoverlabel}<br><br>Count = {n}"
        bar_traces.append(
            go.Bar(
                y=[n],
                x=[ii],
                orientation="v",
                name="",
                customdata=[customdata],
                hovertemplate="%{customdata}",
                width=0.9,
                offset=-0.45,
                marker={"color": color},
                showlegend=False,
            )
        )
    # Add intersection bar traces to the top-left subplot
    for trace in bar_traces:
        fig.add_trace(trace, row=1, col=1)
    # Set-size bars: horizontal, one per set, placed at y = -1 - index so they
    # align with the matrix rows below (y = -1 - position in hoverlabels)
    bar_traces = []
    for ii in counts.index:
        hoverlabel = counts.loc[ii, hlabel]
        color = rgb_to_rgba(px.colors.sequential.Oranges_r[ii % 5], 1)
        n = counts.loc[ii, "count"]
        bar_traces.append(
            go.Bar(
                y=[-1 - ii],
                x=[n],
                orientation="h",
                name="",
                customdata=[f"{hoverlabel}<br><br>Count = {n}"],
                hovertemplate="%{customdata}",
                width=0.9,
                offset=-0.45,
                marker={"color": color},
                showlegend=False,
            )
        )
    # Add set-size bar traces to the bottom-right subplot
    for trace in bar_traces:
        fig.add_trace(trace, row=2, col=2)
    # Create matrix scatter plot and lines: for each intersection (column x =
    # ii), mark every set whose full label is a member of that intersection
    for ii in intersections.index:
        intersection = intersections.loc[ii, hlabel]
        y_coords = [
            -1 - x for x in range(len(hoverlabels)) if hoverlabels[x] in intersection
        ]
        x_coords = [ii] * len(y_coords)
        # Add a line connecting the points
        # Only add a line if there are at least two points
        if len(y_coords) > 1:
            fig.add_trace(
                go.Scatter(
                    x=x_coords,
                    y=y_coords,
                    mode="lines",
                    line={"color": "black", "width": 2},
                    showlegend=False,
                    hovertemplate="%{x}",
                    name="",
                ),
                row=2,
                col=1,
            )
        # Add scatter plot for each point in the intersection
        fig.add_trace(
            go.Scatter(
                x=x_coords,
                y=y_coords,
                mode="markers",
                marker={"size": 10, "color": "black"},
                showlegend=False,
                customdata=["<br>".join(intersection)] * len(y_coords),
                hovertemplate="%{customdata}",
                name="",
            ),
            row=2,
            col=1,
        )
    # Update y-axis for the bar chart subplot
    fig.update_yaxes(title_text="Intersection Size", row=1, col=1)
    # Update x-axis for the set-size bar chart subplot (title above the axis)
    fig.update_xaxes(title_text="Set Size", side="top", row=2, col=2)
    # Update y-axis for the matrix subplot to show category names
    # instead of numeric positions; labelalias maps short labels to full ones
    fig.update_yaxes(
        tickvals=[-1 - i for i in range(len(labels))],
        ticktext=labels,
        showgrid=False,
        row=2,
        col=1,
        labelalias=dict(zip(labels, hoverlabels)),
    )
    # Hide x-axis line for the intersection size subplot
    fig.update_xaxes(showline=False, tickformat=",d", row=1, col=1)
    # Hide x-axis ticks and labels for the matrix subplot
    fig.update_xaxes(
        ticks="", showticklabels=False, showgrid=False, zeroline=False, row=2, col=1
    )
    # Hide y-axis ticks and labels for the set size subplot
    fig.update_yaxes(
        ticks="", showticklabels=False, showgrid=False, zeroline=False, row=2, col=2
    )
    # Set the overall layout properties. Every trace above sets
    # showlegend=False, so the legend settings have no visible entries.
    fig.update_layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        # showlegend=False,
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "right",
            "x": 1,
        },
        height=height,
        minreducedwidth=500,
    )
    return fig
[docs]
def fig_count_chart(
    data: pandas.DataFrame,
    title: str = "Count Chart",
    xlabel: str = "Count",
    ylabel: str = "Variable",
    base_color_map: dict[str, str] | None = None,
    height: int = 350,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a count chart.

    One horizontal bar per row, labelled with ``"short_label"`` and sized by
    ``"count"``. Hover text shows the full ``"label"``. The x-axis spans 0 to
    the largest count.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data with ``"label"``, ``"count"`` and ``"short_label"``
        columns. Rows are used in positional order; the index is ignored.
    title : str, default="Count Chart"
        Figure title.
    xlabel : str, default="Count"
        Figure x-axis label.
    ylabel : str, default="Variable"
        Figure y-axis label.
    base_color_map : dict, default=None
        Map of bar values and colours; only the ``"Yes"`` entry is used.
        Defaults to ``"#007E71"``.
    height : int, default=350
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.

    Raises
    ------
    ValueError
        If a required column is missing.
    """
    df = data.copy()
    column_names = ["label", "count", "short_label"]
    # Error Handling
    if not all(col in df.columns for col in column_names):
        error_str = "Dataframe must contain the following columns: "
        error_str += f"{column_names}"
        raise ValueError(error_str)
    # The loop below looks rows up by position 0..n-1; normalise the index so
    # a filtered or re-indexed frame neither raises KeyError nor mixes rows.
    df = df.reset_index(drop=True)
    # Prepare Data Traces
    traces = []
    default_color = "#007E71"
    yes_color = (
        base_color_map.get("Yes", default_color) if base_color_map else default_color
    )
    # Bars are added last-row-first
    for ii in reversed(range(df.shape[0])):
        hoverlabel = df.loc[ii, column_names[0]]
        yes_count = df.loc[ii, column_names[1]]
        label = df.loc[ii, column_names[2]]
        # Add 'Yes' bar
        traces.append(
            go.Bar(
                x=[yes_count],
                y=[label],
                name="Yes",
                orientation="h",
                width=0.9,
                offset=-0.45,
                marker={"color": yes_color},
                customdata=[hoverlabel],
                hovertemplate="%{customdata}: %{x:.2f}",
                # Show legend only for the first row
                showlegend=(ii == 0),
            )
        )
    xlim = [0, df[column_names[1]].max()]
    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        barmode="stack",
        xaxis={"title": xlabel, "range": xlim},
        yaxis={
            "title": ylabel,
            "automargin": True,
            "tickmode": "array",
            "tickvals": df[column_names[2]],
            "ticktext": df[column_names[2]],
        },
        bargap=0.1,  # Smaller gap between bars. Adjust this value as needed.
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "right",
            "x": 1,
        },
        margin={"l": 100, "r": 100, "t": 100, "b": 50},
        height=height,
        minreducedwidth=500,
    )
    fig = go.Figure(data=traces, layout=layout)
    return fig
[docs]
def fig_frequency_chart(
    data: pandas.DataFrame,
    title: str = "Frequency Chart",
    xlabel: str = "Proportion",
    ylabel: str = "Variable",
    base_color_map: dict[str, str] | None = None,
    height: int = 350,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a frequency chart.

    One stacked horizontal bar per row, labelled with ``"short_label"``. The
    bar is split into a "Yes" part (``"proportion"``) and a "No" part
    (``1 - proportion``). The x-axis is fixed to [0, 1].

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data with ``"label"``, ``"proportion"`` and
        ``"short_label"`` columns. Rows are used in positional order; the
        index is ignored.
    title : str, default="Frequency Chart"
        Figure title.
    xlabel : str, default="Proportion"
        Figure x-axis label.
    ylabel : str, default="Variable"
        Figure y-axis label.
    base_color_map : dict, default=None
        Map of bar values and colours; the ``"Yes"`` and ``"No"`` entries
        are used. Both default to ``"#007E71"``, and the "No" colour is drawn
        at 50% opacity.
    height : int, default=350
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.

    Raises
    ------
    ValueError
        If a required column is missing.
    """
    df = data.copy()
    column_names = ["label", "proportion", "short_label"]
    # Error Handling
    if not all(col in df.columns for col in column_names):
        error_str = "Dataframe must contain the following columns: "
        error_str += f"{column_names}"
        raise ValueError(error_str)
    # The loop below looks rows up by position 0..n-1; normalise the index so
    # a filtered or re-indexed frame neither raises KeyError nor mixes rows.
    df = df.reset_index(drop=True)
    # Prepare Data Traces
    traces = []
    default_color = "#007E71"
    yes_color = (
        base_color_map.get("Yes", default_color) if base_color_map else default_color
    )
    no_color = (
        hex_to_rgba(base_color_map.get("No", default_color), 0.5)
        if base_color_map
        else hex_to_rgba(default_color, 0.5)
    )
    # Bars are added last-row-first
    for ii in reversed(range(df.shape[0])):
        hoverlabel = df.loc[ii, column_names[0]]
        yes_count = df.loc[ii, column_names[1]]
        label = df.loc[ii, column_names[2]]
        # Remainder of the bar up to 1
        no_count = 1 - yes_count
        # Add 'Yes' bar
        traces.append(
            go.Bar(
                x=[yes_count],
                y=[label],
                name="Yes",
                orientation="h",
                width=0.9,
                offset=-0.45,
                marker={"color": yes_color},
                customdata=[hoverlabel],
                hovertemplate="%{customdata}: %{x:.2f}",
                # Show legend only for the first row
                showlegend=(ii == 0),
            )
        )
        # Add 'No' bar
        traces.append(
            go.Bar(
                x=[no_count],
                y=[label],
                name="No",
                orientation="h",
                width=0.9,
                offset=-0.45,
                marker={"color": no_color},
                customdata=[hoverlabel],
                hovertemplate="%{customdata}: %{x:.2f}",
                # Show legend only for the first row
                showlegend=(ii == 0),
            )
        )
    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        barmode="stack",
        xaxis={"title": xlabel, "range": [0, 1]},
        yaxis={
            "title": ylabel,
            "automargin": True,
            "tickmode": "array",
            "tickvals": df[column_names[2]],
            "ticktext": df[column_names[2]],
        },
        bargap=0.1,  # Smaller gap between bars. Adjust this value as needed.
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "right",
            "x": 1,
        },
        margin={"l": 100, "r": 100, "t": 100, "b": 50},
        height=height,
        minreducedwidth=500,
    )
    fig = go.Figure(data=traces, layout=layout)
    return fig
[docs]
def fig_table(
    data: pandas.DataFrame,
    table_key: str = "",
    columnwidth: typing.Iterable[int | float] | None = None,
    height: int = 500,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a table figure.

    Column headers are rendered in bold and missing values as empty strings.
    The first column is left-aligned and the others are right-aligned.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    table_key : str, default=""
        Table key, shown as a small caption near the bottom right.
    columnwidth : typing.Iterable, default=None
        An iterable of relative column widths, normalised to sum to 1. If
        ``None``, the first column takes 0.3 and the rest share the remainder
        equally.
    height : int, default=500
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    df = data.copy()
    # Bold header labels; the f-string also accepts non-string labels (e.g.
    # integer column names), which plain "+" concatenation would reject.
    df.columns = [f"<b>{col}</b>" for col in df.columns]
    df = df.fillna("")
    n = df.shape[1]
    default_firstwidth = 0.3
    if columnwidth is not None:
        # Materialise once: a generator would otherwise be exhausted by sum()
        # before the normalisation pass, and the total is computed only once.
        widths = list(columnwidth)
        total = sum(widths)
        columnwidth = [w / total for w in widths]
    elif n < 2:
        columnwidth = [1]
    else:
        columnwidth = [default_firstwidth]
        columnwidth += [(1 - default_firstwidth) / (n - 1)] * (n - 1)
    fig = go.Figure(
        data=[
            go.Table(
                header={
                    "values": list(df.columns),
                    "fill_color": "#bbbbbb",
                    "align": "left",
                    "font": {"size": 13},
                },
                cells={
                    "values": [df[col] for col in df.columns],
                    "align": ["left"] + ["right"] * (n - 1),
                    "font": {"size": 12},
                },
                columnwidth=columnwidth,
            ),
        ]
    )
    fig.update_layout(
        title={"text": table_key, "x": 0.95, "y": 0.08, "font": {"size": 12}},
        height=height,
        minreducedwidth=500,
    )
    return fig
[docs]
def fig_dual_stack_pyramid(
    data: pandas.DataFrame,
    title: str = "Dual-Sided Stacked Pyramid Chart",
    xlabel: str = "Count",
    ylabel: str = "Category",
    base_color_map: dict[str, str] | None = None,
    height: int = 430,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a dual-sided stacked pyramid chart.

    Rows flagged with a truthy ``"left_side"`` are drawn to the left of zero
    (values negated), and the others to the right. Bars are stacked per
    ``"stack_group"``. The x-axis is symmetric, and its tick labels show
    absolute values.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data with ``"y_axis"``, ``"side"``, ``"stack_group"``,
        ``"value"`` and ``"left_side"`` columns.
    title : str, default="Dual-Sided Stacked Pyramid Chart"
        Figure title.
    xlabel : str, default="Count"
        Figure x-axis label.
    ylabel : str, default="Category"
        Figure y-axis label.
    base_color_map : dict, default=None
        Map of ``stack_group`` values to hex colours. The first side
        encountered is drawn at 75% opacity and the other side fully opaque.
        If ``None``, colours cycle through the Plotly qualitative palette.
    height : int, default=430
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.

    Raises
    ------
    ValueError
        If a required column is missing, the frame is empty, or
        ``base_color_map`` is not a dict.
    """  # noqa : E501
    df = data.copy()
    # Error Handling ("left_side" is read below, so it is required too)
    required_columns = ["y_axis", "side", "stack_group", "value", "left_side"]
    if any(x not in tuple(df.columns) for x in required_columns):
        error_str = "Dataframe must contain the following columns: "
        error_str += f"{required_columns}"
        raise ValueError(error_str)
    if df.empty:
        raise ValueError("The DataFrame is empty.")
    # TODO: validate that there are exactly two distinct sides.
    # Side captions: the "side" value of the first left (1) / right (0) row
    left_side_label = ""
    right_side_label = ""
    if (df["left_side"] == 1).any():
        left_side_label = df.loc[(df["left_side"] == 1), "side"].values[0]
    if (df["left_side"] == 0).any():
        right_side_label = df.loc[(df["left_side"] == 0), "side"].values[0]
    # Dynamic Color Mapping
    if (base_color_map is not None) and (not isinstance(base_color_map, dict)):
        error_str = "base_color_map must be a dictionary with stack_group "
        error_str += "as keys and color codes as values."
        raise ValueError(error_str)
    if base_color_map is None:
        # Default: cycle through the Plotly qualitative palette (hex strings,
        # presumably the format hex_to_rgba expects) instead of crashing on
        # None.items().
        palette = px.colors.qualitative.Plotly
        base_color_map = {
            group: palette[i % len(palette)]
            for i, group in enumerate(df["stack_group"].unique())
        }
    sides = df["side"].unique()
    first_side = sides[0]
    color_map = {}
    for stack_group, color in base_color_map.items():
        for side in sides:
            # The first side is drawn at 75% opacity, the other fully opaque
            alpha = 0.75 if side == first_side else 1
            color_map[(side, stack_group)] = hex_to_rgba(color, alpha)
    # Prepare Data Traces
    traces = []
    for side in sides:
        for stack_group in df["stack_group"].unique():
            subset = df[((df["side"] == side) & (df["stack_group"] == stack_group))]
            if subset.empty:
                continue
            # Groups missing from a user map get None, i.e. Plotly's default
            color = color_map.get((side, stack_group))
            # Left-side values are negated so their bars extend left of zero
            x_val = -subset["value"] if subset["left_side"].any() else subset["value"]
            traces.append(
                go.Bar(
                    y=subset["y_axis"],
                    x=x_val,
                    name=f"{side} {stack_group}",
                    orientation="h",
                    marker={"color": color},
                )
            )
    # Largest stacked total per (side, category) sets the symmetric x-range;
    # only the "value" column is summed
    max_value = df.groupby(["side", "y_axis"], observed=True)["value"].sum().max()
    # Bump odd integer totals up to an even number so the half-way ticks are whole
    max_value += max_value % 2
    # Layout settings
    tickvals = [
        -int(max_value),
        -int(max_value / 2),
        0,
        int(max_value / 2),
        int(max_value),
    ]
    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        barmode="relative",
        xaxis={
            "title": xlabel,
            "range": [-max_value, max_value],
            "automargin": True,
            "tickvals": tickvals,
            # Labels as positive numbers
            "ticktext": [str(abs(x)) for x in tickvals],
        },
        yaxis={
            "title": ylabel,
            "automargin": True,
            "categoryorder": "array",
            "categoryarray": df["y_axis"],
        },
        annotations=[
            {
                "x": 0.2,  # 20% from the left edge of the plot (paper coords)
                "y": 1.1,  # Just above the top of the plot
                "xref": "paper",
                "yref": "paper",
                "text": left_side_label,
                "showarrow": False,
                "align": "center",
            },
            {
                "x": 0.8,  # 80% from the left edge of the plot (paper coords)
                "y": 1.1,  # Just above the top of the plot
                "xref": "paper",
                "yref": "paper",
                "text": right_side_label,
                "showarrow": False,
                "align": "center",
            },
        ],
        shapes=[
            # Vertical reference line at x=0 spanning the full plot height
            {
                "type": "line",
                "x0": 0,
                "y0": 0,
                "x1": 0,
                "y1": 1,
                "xref": "x",
                "yref": "paper",
                "line": {"color": "black", "width": 2},
            }
        ],
        legend={"x": 1.05, "y": 1},
        margin={"l": 100, "r": 100, "t": 100, "b": 50},
        paper_bgcolor="white",
        plot_bgcolor="white",
        height=height,
        minreducedwidth=500,
    )
    fig = go.Figure(data=traces, layout=layout)
    return fig
[docs]
def fig_flowchart(
    data: pandas.DataFrame,
    height: int = 430,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a flowchart.

    Each row is a node, drawn as a boxed annotation. All columns except
    ``"arrow_to"`` are passed as annotation properties, so ``"x"``, ``"y"``
    (in [0, 1]) and presumably ``"text"`` are expected. ``"arrow_to"`` holds
    comma-separated index labels of the target nodes (empty or missing for
    none). Each edge is drawn as two half-segments, and the first half
    carries an arrowhead at the edge midpoint.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    height : int, default=430
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    df = data.copy()
    arrow_metadata = {
        "showarrow": True,
        "arrowwidth": 1.5,
        "arrowcolor": "rgba(100, 100, 100, 0.5)",
        "axref": "x",
        "ayref": "y",
        "xref": "x",
        "yref": "y",
        "text": "",
    }
    # Normalise target lists: missing values become "" and spaces are
    # dropped, e.g. " 2, 3" -> "2,3". Non-string entries are stringified.
    targets = df["arrow_to"].fillna("").astype(str).str.replace(" ", "", regex=False)
    # Edges are built directly as annotation dicts, so a chart with no edges
    # simply has no arrows (instead of failing on an empty concatenation).
    arrows = []
    for start, target_list in targets.items():
        if not target_list:
            continue
        start_x = df.loc[start, "x"]
        start_y = df.loc[start, "y"]
        for end in target_list.split(","):
            if not end:
                continue  # tolerate stray commas such as "2,"
            end_x = df.loc[int(end), "x"]
            end_y = df.loc[int(end), "y"]
            mid_x = (end_x + start_x) / 2
            mid_y = (end_y + start_y) / 2
            # First half: from the start node to the midpoint (arrowhead 1)
            arrows.append(
                {"x": mid_x, "y": mid_y, "ax": start_x, "ay": start_y,
                 "arrowhead": 1, **arrow_metadata}
            )
            # Second half: from the midpoint to the end node (arrowhead 0)
            arrows.append(
                {"x": end_x, "y": end_y, "ax": mid_x, "ay": mid_y,
                 "arrowhead": 0, **arrow_metadata}
            )
    df.drop(columns="arrow_to", inplace=True)
    annotation_metadata = {
        "showarrow": False,
        "xanchor": "center",
        "yanchor": "middle",
        "bgcolor": "rgba(150,150,150,1)",
        "bordercolor": "rgba(100,100,100,0.5)",
        "borderwidth": 1,
        "borderpad": 5,
    }
    annotations = [
        {**annotation, **annotation_metadata} for annotation in df.to_dict("records")
    ]
    layout = go.Layout(
        annotations=arrows + annotations,
        xaxis={"visible": False, "showgrid": False, "range": [0, 1]},
        yaxis={"visible": False, "showgrid": False, "range": [0, 1]},
        plot_bgcolor="rgba(0, 0, 0, 0)",
        height=height,
        minreducedwidth=500,
    )
    fig = go.Figure(layout=layout)
    return fig
[docs]
def fig_forest_plot(
    data: pandas.DataFrame,
    title: str = "Forest Plot",
    xlabel: str = "Odds Ratio (95% CI)",
    ylabel: str = "",
    reorder: bool = True,
    labels: typing.Iterable[str] = ["Variable", "OddsRatio", "LowerCI", "UpperCI"],
    marker: dict[str, typing.Any] | None = None,
    noeffect_line: bool | dict[str, typing.Any] = True,
    height: int = 600,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a forest plot.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Forest Plot"
        Figure title.
    xlabel : str, default="Odds Ratio (95% CI)"
        Figure x-axis label.
    ylabel : str, default=""
        Figure y-axis label.
    reorder : bool, default=True
        Sort rows by the point estimate (ascending). If ``False``, the row
        order is reversed instead.
    labels : typing.Iterable, default=["Variable", "OddsRatio", "LowerCI", "UpperCI"]
        Column names, in order: variable label, point estimate, lower CI
        bound and upper CI bound.
    marker : dict, default=None
        Marker properties dict; its ``"color"`` is also used for the CI lines.
        Defaults to blue markers of size 10; if a dict without ``"color"`` is
        given, ``"blue"`` is used.
    noeffect_line : bool, dict, default=True
        Add a vertical line of no effect at x=1. A dict is used as the line
        style; other truthy values give a red line of width 2; ``False`` or
        ``None`` disable it.
    height : int, default=600
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.

    Raises
    ------
    ValueError
        If any column in ``labels`` is missing from ``data``.
    """
    # Positional access below needs a sequence, not just an iterable
    label_cols = list(labels)
    var_col, est_col, lower_col, upper_col = (
        label_cols[0], label_cols[1], label_cols[2], label_cols[3]
    )
    df = data.copy()
    # Error Handling: validate before sorting so a missing column gives the
    # descriptive error rather than a KeyError from sort_values
    if not set(label_cols).issubset(df.columns):
        error_str = f"Dataframe must contain the following columns: {label_cols}"
        raise ValueError(error_str)
    # Plotly presumably stacks categorical y-values bottom-up in order of
    # appearance, so an ascending sort puts the largest estimate at the top,
    # and reversing keeps the original first row at the top.
    if reorder:
        df = df.sort_values(by=est_col, ascending=True)
    else:
        df = df.loc[::-1]
    if marker is None:
        marker = {"color": "blue", "size": 10}
    elif "color" not in marker:
        # CI lines reuse the marker colour; copy rather than mutate caller's dict
        marker = {**marker, "color": "blue"}
    # Prepare Data Traces
    traces = []
    # Add the point estimates as scatter plot points
    traces.append(
        go.Scatter(
            x=df[est_col],
            y=df[var_col],
            mode="markers",
            name="Odds Ratio",
            marker=marker,
        )
    )
    # Add the confidence intervals as horizontal lines
    for _, row in df.iterrows():
        traces.append(
            go.Scatter(
                x=[row[lower_col], row[upper_col]],
                y=[row[var_col], row[var_col]],
                mode="lines",
                showlegend=False,
                line={"color": marker["color"], "width": 2},
            )
        )
    if isinstance(noeffect_line, dict) or noeffect_line:
        line_style = (
            noeffect_line
            if isinstance(noeffect_line, dict)
            else {"color": "red", "width": 2}
        )
        add_shape = [
            {
                # Line of no effect at x=1, spanning every category row
                "type": "line",
                "x0": 1,
                "y0": -0.5,
                "x1": 1,
                "y1": len(df[var_col]) - 0.5,
                "line": line_style,
            }
        ]
    else:
        add_shape = None
    # Define layout
    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        xaxis={"title": xlabel},
        yaxis={
            "title": ylabel,
            "automargin": True,
            "tickmode": "array",
            "tickvals": df[var_col].tolist(),
            "ticktext": df[var_col].tolist(),
            "range": [-1, len(df[var_col])],
        },
        shapes=add_shape,
        margin={"l": 100, "r": 100, "t": 100, "b": 50},
        height=height,
        minreducedwidth=500,
        showlegend=False,
    )
    fig = go.Figure(data=traces, layout=layout)
    return fig
[docs]
def fig_text(
    data: pandas.DataFrame,
    height: int = 430,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a figure with an annotation.

    The ``"paragraphs"`` column is joined with ``<br>`` line breaks and shown
    as a single annotation at the centre of hidden axes spanning [-1, 1], on
    a fully transparent background.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    height : int, default=430
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    figure = go.Figure()
    annotation_text = "<br>".join(data["paragraphs"].values)
    figure.add_annotation(x=0, y=0, text=annotation_text, showarrow=False)
    transparent = "rgba(0,0,0,0)"
    hidden_axis = {"visible": False, "range": [-1, 1]}
    figure.update_layout(
        paper_bgcolor=transparent,
        plot_bgcolor=transparent,
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        height=height,
        minreducedwidth=500,
    )
    return figure
[docs]
def fig_kaplan_meier(
data: tuple[pd.DataFrame],
title: str = "Kaplan-Meier Plot",
xlabel: str = "Time (days)",
ylabel: str = "Survival Probability",
groups: typing.Iterable[str] | None = None,
index_column: str = "index",
base_color_map: dict[str, str] | None = None,
xlim: typing.Iterable[float | int] | None = None,
p_value: float | None = None,
height: int = 480,
) -> plotly.graph_objs.Figure:
""":py:class:`plotly.graph_objs.Figure` : Returns a Kaplan-Meier plot.
Parameters
----------
data : tuple
Incoming data as two Pandas dataframes, the first for the plot, and the
second for the risk table.
title : str, default="Kaplan-Meier Plot"
Figure title.
xlabel : str, default="Time (days)"
Figure x-axis label.
ylabel : str, default="Survival Probability"
Figure y-axis label.
groups : typing.Iterable, default=None
Groups.
index_column : str, default="index"
Index column.
base_color_map : dict, default=None
Colour map.
xlim : typing.Iterable, default=None
xlim.
p_value : float, default=None
p-value.
height : int, default=480
Figure height.
Returns
-------
plotly.graph_objs.Figure
The Plotly figure.
"""
df_km = data[0].copy()
risk_table = data[1].copy()
if "Group" not in risk_table.columns:
raise KeyError("risk table must contain column 'Groups'")
if "timeline" not in df_km.columns:
raise KeyError("KM curves table must contain column 'timeline'")
df_km = df_km.set_index("timeline")
if groups is None:
groups = [c for c in df_km.columns if "_lower_" not in c and "_upper_" not in c]
if base_color_map is None:
colors = [
f"hsl({i * (360 / len(groups))}, 70%, 50%)" for i in range(len(groups))
]
else:
colors = list(base_color_map.values())
# unique_groups = ["HIV positive", "HIV negative"]
# Create the figure with two rows: one for the plot and one for risk table
fig = make_subplots(
rows=2,
cols=1,
# shared_xaxes=True,
row_heights=[0.7, 0.3],
vertical_spacing=0.1,
subplot_titles=[title, ""],
)
for group, color in zip(groups, colors):
ci_lower_column = [
col for col in df_km.columns if col.startswith(group + "_lower")
][0]
ci_upper_column = [
col for col in df_km.columns if col.startswith(group + "_upper")
][0]
ci = df_km[[col for col in df_km.columns if col.startswith(group + "_")]]
ci = ci.dropna()
ci_lower = ci[ci_lower_column]
ci_upper = ci[ci_upper_column]
# Add confidence interval as shaded area
ci_x = ci_upper.index.tolist() + ci_lower.index[::-1].tolist()
ci_y = ci_upper.tolist() + ci_lower[::-1].tolist()
fig.add_trace(
go.Scatter(
x=ci_x,
y=ci_y,
fill="toself",
fillcolor=color.replace("hsl", "hsla")
.replace("rgb", "rgba")
.replace(")", ",0.2)"),
line={"color": "rgba(255,255,255,0)", "shape": "hv"},
name=f"CI {group}",
showlegend=False,
hoverinfo="text",
text=[f"CI {group}" for _ in range(len(ci_upper) + len(ci_lower))],
),
row=1,
col=1,
)
for group, color in zip(groups, colors):
survival = df_km[group]
# Add survival curve
fig.add_trace(
go.Scatter(
x=survival.index,
y=survival.values,
mode="lines",
name=str(group),
line={"color": color, "shape": "hv"},
),
row=1,
col=1,
)
# Add p-value annotation to the plot
if p_value is not None:
p_value_text = (
"p-value: <0.001" if p_value < 0.001 else f"p-value: {p_value:.3f}"
)
fig.add_annotation(
text=p_value_text,
x=0.95,
y=95,
xref="paper",
yref="y1",
showarrow=False,
font={"size": 12, "color": "black"},
bgcolor="white",
bordercolor="black",
borderwidth=1,
)
# Add risk table as second row
for ii in range(risk_table.shape[0]):
yval = np.arange(risk_table.shape[0])[::-1][ii]
color = dict(zip(groups, colors)).get(risk_table.loc[ii, index_column], "black")
fig.add_trace(
go.Scatter(
x=risk_table.drop(columns=index_column).columns.astype(float),
y=np.repeat(yval, risk_table.shape[1] - 1),
mode="text",
text=risk_table.drop(columns=index_column).loc[ii],
textposition="middle center",
textfont={"color": color},
showlegend=False,
),
row=2,
col=1,
)
# Configure axes and layout
fig.update_yaxes(
title_text=ylabel,
range=[-1, 101],
# tickvals=np.arange(0, 110, 10),
# ticktext=[f'{i}%' for i in range(0, 110, 10)],
row=1,
col=1,
)
fig.update_xaxes(
title_text=xlabel,
# range=[risk_table.]
tickvals=risk_table.columns[1:],
row=1,
col=1,
)
if xlim is not None:
xlim[0] = min((xlim[0], xlim[0] - (xlim[1] - xlim[0]) * 0.02))
xlim[1] = max((xlim[1], xlim[1] + (xlim[1] - xlim[0]) * 0.02))
else:
xlim = [df_km.index.min(), df_km.index.max()]
xlim[0] = min((xlim[0], xlim[0] - (xlim[1] - xlim[0]) * 0.02))
xlim[1] = max((xlim[1], xlim[1] + (xlim[1] - xlim[0]) * 0.02))
fig.add_trace(
go.Scatter(
x=[xlim[0], xlim[0]],
y=[-0.5, risk_table.shape[0] - 0.5],
mode="lines",
line={"color": "black", "width": 1},
showlegend=False,
),
row=2,
col=1,
)
fig.add_trace(
go.Scatter(
x=[xlim[0], xlim[1]],
y=[risk_table.shape[0] - 0.5, risk_table.shape[0] - 0.5],
mode="lines",
line={"color": "black", "width": 1},
showlegend=False,
),
row=2,
col=1,
)
fig.update_xaxes(range=xlim, row=1, col=1)
fig.update_xaxes(visible=False, showgrid=False, range=xlim, row=2, col=1)
fig.update_yaxes(
range=[-0.5, risk_table.shape[0] - 0.5],
showgrid=False,
title_text="Number at risk",
tickvals=np.arange(risk_table.shape[0])[::-1],
ticktext=risk_table[index_column],
row=2,
col=1,
)
fig.update_layout(
shapes=[
dict(
type="rect",
xref="x domain",
yref="y domain",
x0=0,
y0=0,
x1=1,
y1=1,
fillcolor="grey",
opacity=0.1,
layer="below",
line_width=0.5,
)
]
)
fig.update_layout(
height=height,
plot_bgcolor="rgba(0, 0, 0, 0)",
showlegend=True,
minreducedwidth=500,
)
return fig
[docs]
def fig_line_chart(
    data: pandas.DataFrame,
    title: str = "Line chart",
    xlabel: str = "",
    ylabel: str = "",
    height: int = 480,
    line_column: str = "",
    index_column: str = "index",
    lower_column: str | None = None,
    upper_column: str | None = None,
    line_color: str | None = None,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a line chart.

    ``line_column`` is plotted against ``index_column``. If both
    ``lower_column`` and ``upper_column`` are given, a grey shaded band is
    drawn between them. The y-range always includes zero and is padded by
    10% beyond the plotted extremes, so negative values are not clipped.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Line chart"
        Figure title.
    xlabel : str, default=""
        Figure x-axis label.
    ylabel : str, default=""
        Figure y-axis label.
    height : int, default=480
        Figure height.
    line_column : str, default=""
        Column plotted as the line; its title-cased name is the trace name.
    index_column : str, default="index"
        Column used as the x values (and as x tick labels).
    lower_column : str, default=None
        Lower-bound column of the shaded band.
    upper_column : str, default=None
        Upper-bound column of the shaded band.
    line_color : str, default=None
        Line and marker colour.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    # ``set_index`` returns a new frame, so the caller's data is untouched.
    df = data.set_index(index_column)
    line_trace = go.Scatter(
        x=df.index,
        y=df[line_column],
        mode="lines+markers",
        name=line_column.replace("_", " ").title(),
        marker={"color": line_color},
        line={"color": line_color, "width": 2, "dash": "solid"},
    )
    traces = [line_trace]
    # Include zero and pad by 10%; a fixed lower bound of 0 would hide
    # negative values.
    yaxis_min = min(0, df[line_column].min() * 1.1)
    yaxis_max = max(0, df[line_column].max() * 1.1)
    if upper_column is not None and lower_column is not None:
        # Closed polygon: forward along the upper bound, back along the lower.
        bounds_trace = go.Scatter(
            x=df.index.tolist() + df.index[::-1].tolist(),
            y=df[upper_column].tolist() + df[lower_column][::-1].tolist(),
            fill="toself",
            fillcolor="rgba(150,150,150,0.2)",
            line={"color": "rgba(255,255,255,0)"},
            showlegend=False,
        )
        traces.append(bounds_trace)
        yaxis_min = min(0, df[lower_column].min() * 1.1)
        yaxis_max = max(0, df[upper_column].max() * 1.1)
    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        xaxis={
            "title": xlabel,
            "tickmode": "array",
            "tickvals": df.index,
            "ticktext": df.index,  # One label per index value (e.g. years).
            "tickangle": -30,
            "tickfont": {"size": 10},
        },
        yaxis={"title": ylabel, "range": [yaxis_min, yaxis_max]},
        legend={"x": 0.85, "y": 1, "bgcolor": "rgba(255,255,255,0.5)"},
        margin={"l": 60, "r": 60, "t": 50, "b": 80},
        paper_bgcolor="white",
        plot_bgcolor="white",
        height=height,
        minreducedwidth=500,
    )
    return go.Figure(data=traces, layout=layout)
[docs]
def fig_bar_line_chart(
    data: pandas.DataFrame,
    title: str = "Combined bar line chart",
    xlabel: str = "",
    ylabel_left: str = "",
    ylabel_right: str = "",
    bar_column: str = "",
    line_column: str = "",
    index_column: str = "index",
    lower_column: str | None = None,
    upper_column: str | None = None,
    bar_color: str | None = None,
    line_color: str | None = None,
    height: int = 500,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a bar-line chart.

    Bars (``bar_column``) use the left y-axis; the line (``line_column``)
    and its optional shaded band use an overlaid right y-axis. Each axis
    range includes zero and is padded by 10% beyond its extremes, so
    negative values are not clipped.

    Parameters
    ----------
    data : pandas.DataFrame
        Incoming data.
    title : str, default="Combined bar line chart"
        Figure title.
    xlabel : str, default=""
        Figure x-axis label.
    ylabel_left : str, default=""
        Left (bar) y-axis label.
    ylabel_right : str, default=""
        Right (line) y-axis label.
    bar_column : str, default=""
        Column plotted as bars.
    line_column : str, default=""
        Column plotted as the line.
    index_column : str, default="index"
        Column used as the x values (and as x tick labels).
    lower_column : str, default=None
        Lower-bound column of the line's shaded band.
    upper_column : str, default=None
        Upper-bound column of the line's shaded band.
    bar_color : str, default=None
        Bar colour.
    line_color : str, default=None
        Line colour.
    height : int, default=500
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    # ``set_index`` returns a new frame, so the caller's data is untouched.
    df = data.set_index(index_column)
    bar_trace = go.Bar(
        x=df.index,
        y=df[bar_column],
        name=bar_column.replace("_", " ").title(),
        marker={"color": bar_color},
        yaxis="y",
    )
    line_trace = go.Scatter(
        x=df.index,
        y=df[line_column],
        mode="lines+markers",
        name=line_column.replace("_", " ").title(),
        marker={"color": line_color},
        line={"color": line_color, "width": 2, "dash": "solid"},
        yaxis="y2",
    )
    traces = [bar_trace, line_trace]
    # Include zero and pad by 10%; a fixed lower bound of 0 would hide
    # negative values.
    yaxis_range = [
        min(0, df[bar_column].min() * 1.1),
        max(0, df[bar_column].max() * 1.1),
    ]
    yaxis2_min = min(0, df[line_column].min() * 1.1)
    yaxis2_max = max(0, df[line_column].max() * 1.1)
    if upper_column is not None and lower_column is not None:
        # Closed polygon: forward along the upper bound, back along the lower.
        bounds_trace = go.Scatter(
            x=df.index.tolist() + df.index[::-1].tolist(),
            y=df[upper_column].tolist() + df[lower_column][::-1].tolist(),
            fill="toself",
            fillcolor="rgba(150,150,150,0.2)",
            line={"color": "rgba(255,255,255,0)"},
            showlegend=False,
            yaxis="y2",
        )
        traces.append(bounds_trace)
        yaxis2_min = min(0, df[lower_column].min() * 1.1)
        yaxis2_max = max(0, df[upper_column].max() * 1.1)
    layout = go.Layout(
        title={"text": title, "x": 0.5, "xanchor": "center"},
        barmode="stack",
        bargap=0.3,
        xaxis={
            "title": xlabel,
            "tickmode": "array",
            "tickvals": df.index,
            "ticktext": df.index,  # One label per index value (e.g. years).
            "tickangle": -30,
            "tickfont": {"size": 10},
        },
        yaxis={"title": ylabel_left, "range": yaxis_range},
        yaxis2={
            "title": ylabel_right,
            "overlaying": "y",
            "side": "right",
            "showgrid": False,
            "range": [yaxis2_min, yaxis2_max],
        },
        legend={"x": 0.85, "y": 1, "bgcolor": "rgba(255,255,255,0.5)"},
        margin={"l": 60, "r": 60, "t": 50, "b": 80},
        paper_bgcolor="white",
        plot_bgcolor="white",
        height=height,
        minreducedwidth=500,
    )
    return go.Figure(data=traces, layout=layout)
[docs]
def fig_heatmaps(
    data: pandas.DataFrame,
    title: str = "",
    subplot_titles: list[str] | None = None,
    ylabel: str = "",
    xlabel: str = "",
    colorbar_label: str = "",
    index_column: str = "index",
    zmin: float | None = None,
    zmax: float | None = None,
    include_annotations: bool = False,
    base_color_map: dict[str, str] | None = None,
    height: int = 750,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a heatmaps chart.

    One heatmap per input frame, stacked vertically in a single column and
    sharing one colour scale (``zmin``/``zmax``). Only the first heatmap
    shows a colour bar; only the bottom panel shows x tick labels and the
    x-axis title.

    Parameters
    ----------
    data : pandas.DataFrame or tuple of pandas.DataFrame
        Incoming data. A single frame is treated as a 1-tuple.
    title : str, default=""
        Figure title.
    subplot_titles : list, default=None
        Subplot titles.
    xlabel : str, default=""
        Figure x-axis label (bottom panel).
    ylabel : str, default=""
        Figure y-axis label (shared).
    colorbar_label : str, default=""
        Colour bar label.
    index_column : str, default="index"
        Column used as the heatmap rows (y values).
    zmin : float, default=None
        Colour-scale minimum; defaults to the minimum over all frames.
    zmax : float, default=None
        Colour-scale maximum; defaults to the maximum over all frames.
    include_annotations : bool, default=False
        Write each cell's value as text inside the cell.
    base_color_map : dict, default=None
        Plotly colour scale; defaults to ``"viridis"``.
    height : int, default=750
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    if not isinstance(data, tuple):
        data = (data,)
    n_panels = len(data)
    fig = make_subplots(
        rows=n_panels,
        cols=1,
        subplot_titles=subplot_titles,
        vertical_spacing=0.1,
        y_title=ylabel,
    )
    if zmin is None:
        zmin = min(frame.drop(columns=index_column).min().min() for frame in data)
    if zmax is None:
        zmax = max(frame.drop(columns=index_column).max().max() for frame in data)
    if base_color_map is None:
        base_color_map = "viridis"
    for panel, frame in enumerate(data):
        # Rows are reversed before plotting (presumably so the frame's first
        # row appears at the top of the heatmap).
        df = frame.set_index(index_column).loc[::-1]
        if include_annotations:
            text = df.astype(str).values
            texttemplate = "%{text}"
        else:
            text = None
            texttemplate = None
        fig.add_trace(
            go.Heatmap(
                z=df.values,
                x=df.columns,
                y=df.index,
                text=text,
                texttemplate=texttemplate,
                zmin=zmin,
                zmax=zmax,
                colorscale=base_color_map,
                colorbar=({"title": colorbar_label} if panel == 0 else None),
                showscale=(panel == 0),
            ),
            row=panel + 1,
            col=1,
        )
    fig.update_layout(
        height=height,
        title_text=title,
        title_x=0.5,
        title_xref="paper",
        showlegend=False,
    )
    # Hide tick labels on every panel except the bottom one (rows are 1-based).
    for row in range(1, n_panels):
        fig.update_xaxes(showticklabels=False, row=row, col=1)
    # The bottom panel is row ``n_panels``; targeting ``n_panels + 1`` would
    # match no subplot and silently drop the x-axis title.
    fig.update_xaxes(
        showticklabels=True,
        tickangle=0,
        title=xlabel,
        row=n_panels,
        col=1,
    )
    return fig
[docs]
def fig_sankey(
    data: pandas.DataFrame,
    height: int = 500,
) -> plotly.graph_objs.Figure:
    """:py:class:`plotly.graph_objs.Figure` : Returns a Sankey plot.

    Parameters
    ----------
    data : tuple of pandas.DataFrame
        Incoming data as three frames, indexed ``0``, ``1`` and ``2``:

        * nodes - each column becomes a list-valued ``node`` property
          (``to_dict(orient="list")``);
        * links - each column becomes a list-valued ``link`` property;
        * annotations - each row becomes one layout annotation
          (``to_dict(orient="records")``).

        The hover templates reference ``customdata``, so the nodes frame
        presumably provides a ``customdata`` column (verify against callers).
    height : int, default=500
        Figure height.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure.
    """
    nodes_df, links_df, annotations_df = (data[i].copy() for i in range(3))

    # Fixed styling merged over the data-derived node/link properties.
    node_style = {
        "hovertemplate": "%{customdata}",
        "pad": 15,
        "thickness": 20,
        "line": {"color": "black", "width": 1.2},
    }
    link_style = {
        "hovertemplate": "%{source.customdata} to %{target.customdata}",
        "line": {"color": "rgba(0,0,0,0.3)", "width": 0.3},
    }

    sankey = go.Sankey(
        arrangement="snap",
        valueformat=".0f",
        node=dict(nodes_df.to_dict(orient="list"), **node_style),
        link=dict(links_df.to_dict(orient="list"), **link_style),
    )
    layout = go.Layout(
        annotations=annotations_df.to_dict(orient="records"),
        height=height,
        minreducedwidth=500,
    )
    return go.Figure(data=[sankey], layout=layout)
############################################
############################################
# Formatting: colours
############################################
############################################
[docs]
def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
    """:py:class:`tuple` : Converts a hex colour to an RGB colour tuple.

    Accepts an optional leading ``#`` and the CSS forms ``rgb``, ``rgba``,
    ``rrggbb`` and ``rrggbbaa``. Shorthand digits are doubled (``f`` ->
    ``ff``), and any alpha component is ignored.

    Parameters
    ----------
    hex_color : str
        Hex colour string.

    Returns
    -------
    tuple
        RGB colour tuple of three ints in ``0..255``.

    Raises
    ------
    ValueError
        If ``hex_color`` is not a 3, 4, 6 or 8 digit hex colour.
    """
    digits = hex_color.lstrip("#")
    if len(digits) in (3, 4):
        # CSS shorthand: "#abc" means "#aabbcc".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f"Invalid hex colour: {hex_color!r}")
    return tuple(int(digits[i : i + 2], 16) for i in (0, 2, 4))
[docs]
def hex_to_rgba(hex_color: str, opacity: float) -> str:
    """:py:class:`str` : Converts a hex colour to an RGBA (red-green-blue-alpha)
    colour string.

    Accepts an optional leading ``#`` and the CSS forms ``rgb``, ``rgba``,
    ``rrggbb`` and ``rrggbbaa``. Shorthand digits are doubled (``f`` ->
    ``ff``); any alpha already in the hex string is replaced by ``opacity``.

    Parameters
    ----------
    hex_color : str
        Hex colour string.
    opacity : float
        Opacity/transparency, a value between 0.0 (fully transparent) and
        1.0 (fully opaque).

    Returns
    -------
    str
        An RGBA colour string (RGB + opacity/transparency), e.g.
        ``"rgba(31, 119, 180, 0.5)"``.

    Raises
    ------
    ValueError
        If ``hex_color`` is not a 3, 4, 6 or 8 digit hex colour.
    """
    digits = hex_color.lstrip("#")
    if len(digits) in (3, 4):
        # CSS shorthand: "#abc" means "#aabbcc" (not single-digit channels).
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f"Invalid hex colour: {hex_color!r}")
    red, green, blue = (int(digits[i : i + 2], 16) for i in (0, 2, 4))
    return f"rgba({red}, {green}, {blue}, {opacity})"
[docs]
def rgb_to_rgba(rgb_color: tuple[int], alpha: float) -> str:
    """:py:class:`str` : Converts an RGB colour tuple and alpha value to an RGBA colour string.

    Only the first three items of ``rgb_color`` are used.

    Parameters
    ----------
    rgb_color : tuple
        RGB color tuple.
    alpha : float
        Opacity/transparency value between 0.0 (fully transparent) and 1.0
        (fully opaque).

    Returns
    -------
    str
        RGBA colour string, e.g. ``"rgba(1, 2, 3, 0.5)"``.
    """  # noqa: E501
    red, green, blue = rgb_color[0], rgb_color[1], rgb_color[2]
    return f"rgba({red}, {green}, {blue}, {alpha})"