Coverage for functions \ flipdare \ analysis \ plotter.py: 92%

131 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-05-08 12:22 +1000

1#!/usr/bin/env python 

2# Copyright (c) 2026 Flipdare Pty Ltd. All rights reserved. 

3# 

4# This file is part of Flipdare's proprietary software and contains 

5# confidential and copyrighted material. Unauthorised copying, 

6# modification, distribution, or use of this file is strictly 

7# prohibited without prior written permission from Flipdare Pty Ltd. 

8# 

9# This software includes third-party components licensed under MIT, 

10# BSD, and Apache 2.0 licences. See THIRD_PARTY_NOTICES for details. 

11# 

12 

13from typing import ClassVar, TypeIs 

14from dataclasses import dataclass 

15from enum import StrEnum 

16import io 

17import pandas as pd 

18import seaborn as sns 

19import numpy as np 

20import matplotlib.pyplot as plt 

21from matplotlib.axes import Axes 

22from flipdare.app_log import LOG 

23from flipdare.app_types import AnalysisArrayType, AnalysisDataType 

24from flipdare.constants import GRAPH_LOG_SCALE_THRESHOLD, IS_TRACE 

25 

# Names exported by ``from ... import *``.
__all__ = [
    "Plotter",
    "ScatterData",
]

# Set seaborn's global default theme once, at import time.
sns.set_theme(style="whitegrid")

30 

31 

@dataclass
class ScatterData:
    """Scatter overlay for a single plotted series.

    ``points`` and ``indices`` are parallel lists and must have the same
    length; ``label`` is the text attached to the overlay.
    """

    points: list[float]
    indices: list[int]
    label: str

    def __post_init__(self) -> None:
        """Validate that ``points`` and ``indices`` have equal length.

        Raises:
            ValueError: If the two lists differ in length.
        """
        n_points = len(self.points)
        n_indices = len(self.indices)
        if n_points == n_indices:
            return

        raise ValueError(
            f"Inconsistent array lengths: {n_points} points != {n_indices} indices."
        )

42 

43 

44class PlotStrategy(StrEnum): 

45 LOG_SCALE = "log_scale" 

46 NORMAL_SCALE = "normal_scale" 

47 

48 

def _is_nested(
    val: AnalysisDataType | AnalysisArrayType,
) -> TypeIs[AnalysisDataType]:
    """Return True when ``val`` is non-empty and its first element is a list.

    Declared with ``TypeIs`` so that type checkers narrow ``val`` in both
    the ``if`` and the ``else`` branch of a caller. A falsy ``val`` is
    reported as not nested.
    """
    if not val:
        return False
    return isinstance(val[0], list)

54 

55 

56class Plotter: 

57 __slots__ = ( 

58 "_data", 

59 "_graph_title", 

60 "_legend_labels", 

61 "_log_scale_threshold", 

62 "_scatter_data", 

63 "_x_label", 

64 "_x_values", 

65 "_y_label", 

66 ) 

67 

68 # fmt: off 

69 _MARKERS:ClassVar[list[str]] = ["s", "o", "^", "D", "v", "p", "*", "X", "P"] 

70 

71 _COLORS:ClassVar[list[str]] = [ 

72 "orange", "teal", "green", 

73 "purple", "brown", "pink", 

74 "gray", "olive", "cyan"] 

75 

76 _SCATTER_COLOR:ClassVar[list[str]] = [ 

77 "coral", "mediumturquoise", "lightgreen", 

78 "mediumpurple", "rosybrown", "lightpink", 

79 "silver", "olivedrab", "lightcyan"] 

80 # fmt: on 

81 

82 def __init__( 

83 self, 

84 title: str, 

85 x_label: str, 

86 y_label: str, 

87 data: list[AnalysisArrayType] | AnalysisArrayType, 

88 legend_labels: list[str], 

89 x_values: list[str] | None = None, 

90 scatter_data: ScatterData | list[ScatterData] | None = None, 

91 log_scale_threshold: float = GRAPH_LOG_SCALE_THRESHOLD, 

92 ) -> None: 

93 self._graph_title = title 

94 self._x_label = x_label 

95 self._y_label = y_label 

96 if not data: 

97 self._data = [] 

98 elif _is_nested(data): 

99 self._data = data 

100 else: 

101 self._data = [data] 

102 

103 self._scatter_data = scatter_data 

104 self._x_values = x_values 

105 self._legend_labels = legend_labels 

106 self._log_scale_threshold = log_scale_threshold 

107 

108 @property 

109 def plot_strategy(self) -> PlotStrategy: 

110 # Calculate the max value for each individual line 

111 data = self._data 

112 threshold = self._log_scale_threshold 

113 

114 arr = np.array(data, dtype=float) 

115 min_v, max_v = np.nanmin(arr), np.nanmax(arr) 

116 # Calculate the ratio between the absolute highest and lowest peaks 

117 if IS_TRACE: 

118 LOG().trace(f"Data range for plot: min={min_v}, max={max_v}") 

119 

120 min_v = 1 if min_v == 0 else min_v # Avoid division by zero in ratio calculation 

121 scale_ratio = max_v / min_v 

122 strategy = PlotStrategy.LOG_SCALE if scale_ratio > threshold else PlotStrategy.NORMAL_SCALE 

123 

124 if IS_TRACE: 

125 msg = f"Data scale ratio {scale_ratio:.2f} exceeds threshold {threshold}. Using {strategy.value} strategy." 

126 LOG().trace(msg) 

127 

128 return strategy 

129 

130 def create(self) -> io.BytesIO: 

131 raw_data = self._data 

132 if len(raw_data) == 0: 

133 raise ValueError("Data list is empty. Cannot create plot.") 

134 

135 x_label = self._x_label 

136 y_label = self._y_label 

137 fig, ax = plt.subplots(figsize=(12, 6)) 

138 

139 data = self._data 

140 for idx in range(len(data)): 

141 data_entry = np.array(data[idx], dtype=float) 

142 

143 # 1. Construct DataFrame with BOTH columns to satisfy Seaborn's lookup 

144 # We use range(len()) to create the X-axis points (0, 1, 2...) 

145 df = pd.DataFrame({x_label: range(len(data_entry)), y_label: data_entry}) 

146 

147 # 2. Interpolate NaNs to bridge gaps in the line 

148 # 'linear' creates a straight line between the points you DO have 

149 df[y_label] = df[y_label].interpolate(method="linear") 

150 

151 # Note: interpolate() doesn't fill NaNs at the very start or very end. 

152 # If you want to fill those too, chain it with bfill() and ffill() 

153 df[y_label] = df[y_label].bfill().ffill() 

154 

155 color_idx = idx % len(self._COLORS) 

156 marker = self._MARKERS[color_idx] 

157 line_color = self._COLORS[color_idx] 

158 scatter_color = self._SCATTER_COLOR[color_idx] 

159 

160 # plotting 

161 self._plot_main_trend(ax, df, idx, marker, line_color) 

162 self._plot_scatter_points(ax, idx, scatter_color) 

163 # self._plot_missing_markers(ax, df) 

164 

165 self._set_plot_defaults(ax) 

166 fig.tight_layout() 

167 

168 buffer = io.BytesIO() 

169 plt.savefig(buffer, format="png") 

170 plt.close() 

171 buffer.seek(0) 

172 

173 return buffer 

174 

175 def _plot_main_trend( 

176 self, 

177 ax: Axes, 

178 df: pd.DataFrame, 

179 idx: int, 

180 marker_color: str, 

181 line_color: str, 

182 ) -> None: 

183 # 1. Plot the main trend 

184 x_label = self._x_label 

185 y_label = self._y_label 

186 legend_labels = self._legend_labels 

187 

188 line_label = legend_labels[idx] if legend_labels and idx < len(legend_labels) else y_label 

189 sns.lineplot( 

190 data=df, 

191 ax=ax, 

192 x=x_label, 

193 y=y_label, 

194 marker=marker_color, 

195 color=line_color, 

196 label=line_label, 

197 linewidth=max(3 - (idx * 0.8), 1.4), 

198 alpha=0.6, 

199 sort=False, # Preserve original order of x-values 

200 estimator=None, # Don't aggregate; plot raw data 

201 ) 

202 

203 def _plot_scatter_points(self, ax: Axes, idx: int, scatter_color: str) -> None: 

204 scatter_data: list[ScatterData] = [] 

205 if self._scatter_data: 

206 if isinstance(self._scatter_data, list): 

207 scatter_data = self._scatter_data 

208 else: 

209 scatter_data = [self._scatter_data] 

210 

211 scatter_values = scatter_data[idx] if scatter_data and idx < len(scatter_data) else None 

212 if scatter_values is None: 

213 return 

214 

215 # 2. Plot scatter points if provided (e.g. outliers) 

216 sns.scatterplot( 

217 x=scatter_values.indices, 

218 y=scatter_values.points, 

219 ax=ax, # Target the specific axes 

220 color=scatter_color, 

221 s=120, # Size 

222 label=scatter_values.label, 

223 zorder=5, # Ensures markers sit on top of the line 

224 marker="o", # Or any marker style you prefer 

225 ) 

226 

227 def _plot_missing_markers(self, ax: Axes, df: pd.DataFrame) -> None: 

228 # 1.b generate a scatter plot for missing indices. 

229 x_label = self._x_label 

230 y_label = self._y_label 

231 

232 missing_mask = df[y_label].isna() 

233 

234 if missing_mask.any(): 

235 # 2. Create a temporary Series where NaNs are filled by 

236 # drawing a straight line between the surrounding points 

237 y_interpolated = df[y_label].interpolate(method="linear") 

238 

239 # 3. Create a DataFrame containing ONLY the missing points 

240 # but with their new interpolated Y values 

241 df_missing = pd.DataFrame( 

242 {x_label: df.loc[missing_mask, x_label], y_label: y_interpolated[missing_mask]} 

243 ) 

244 

245 sns.scatterplot( 

246 data=df_missing, 

247 x=x_label, 

248 y=y_label, 

249 ax=ax, 

250 marker="x", 

251 color="black", 

252 s=50, 

253 linewidth=1.4, 

254 legend=False, 

255 zorder=3, # Ensures it sits on top of the line 

256 ) 

257 

258 def _set_plot_defaults(self, ax: Axes) -> None: 

259 # 1. Handle X-Axis Ticks and Rotation 

260 if self._x_values: 

261 # Set the numeric positions first, then the string labels 

262 ax.set_xticks(range(len(self._x_values))) 

263 ax.set_xticklabels(self._x_values, rotation=45) 

264 else: 

265 # Use tick_params for rotation if using default numeric indices 

266 ax.tick_params(axis="x", labelrotation=45) 

267 

268 # 2. Handle Scale and Labels via Match Case 

269 graph_title = self._graph_title 

270 x_label = self._x_label 

271 y_label = self._y_label 

272 

273 match self.plot_strategy: 

274 case PlotStrategy.LOG_SCALE: 

275 graph_title += " (Log Scale)" 

276 ax.set_yscale("log") 

277 ax.set_ylabel(f"{y_label} (Log Scale)") 

278 case PlotStrategy.NORMAL_SCALE: 

279 ax.set_yscale("linear") 

280 ax.set_ylabel(y_label) 

281 

282 ax.set_title(graph_title, fontsize=14) 

283 ax.set_xlabel(x_label) 

284 ax.set_ylabel(y_label) 

285 

286 # Note: lineplot sets labels automatically from the DataFrame columns 

287 ax.legend( 

288 loc="upper left", # Align the TOP LEFT of the legend... 

289 bbox_to_anchor=(1.02, 1), # ...to just right (1.02) and top (1) of the plot 

290 borderaxespad=0, # Remove padding between anchor and legend 

291 frameon=False, # Clean look 

292 fontsize="small", # Keeps the box compact 

293 )