"""Ternary Continuous Trend Labeller (module ``tstrends.trend_labelling.ternary_CTL``)."""

from typing import overload, Literal

from .base_labeller import BaseLabeller
from .label_scaling import Labels, extract_label_values


[docs] class TernaryCTL(BaseLabeller): """Ternary Continuous Trend Labeller. This class implements an adaptation of the Continuous Trend Labeller (CTL) algorithm to a three-state labelling approach. A somewhat not so different approach is proposed in the second pass of the labelling algorithm outlined in the paper by Dezhkam et al. "A Bayesian-based classification framework for financial time series trend prediction." The algorithm identifies three distinct states in price movements: - Upward trends (label: Labels.UP or 1) - Neutral trends (label: Labels.NEUTRAL or 0) - Downward trends (label: Labels.DOWN or -1) Example: >>> labeller = TernaryCTL(marginal_change_thres=0.1, window_size=3) >>> prices = [1.0, 1.15, 1.2, 1.18, 1.0] >>> labels = labeller.get_labels(prices) >>> print(labels) # [-1, 1, 1, 0, -1] Note: The window_size parameter helps prevent the algorithm from getting stuck in prolonged sideways movements by forcing a state transition to NEUTRAL after the window is exceeded. It can artificially cut ongoing trends short, so it must be set carefully. """
[docs] def __init__(self, marginal_change_thres: float, window_size: int) -> None: """ Initialize the ternary trend labeller. Args: marginal_change_thres (float): The threshold for significant price movements as a percentage. window_size (int): The maximum window to look for trend confirmation before resetting state. """ if not isinstance(marginal_change_thres, float): raise TypeError("marginal_change_thres must be a float.") if not isinstance(window_size, int): raise TypeError("window_size must be an integer.") self.marginal_change_thres = marginal_change_thres self.window_size = window_size self.labels: list[Labels] = list()
def _get_first_label(self, time_series_list: list[float]) -> list[Labels]: """ Find upward trends in a time series of closing prices. This is the first step of the ternary trend labelling algorithm. Args: time_series_list (list[float]): List of closing prices. """ if time_series_list[0] > time_series_list[1]: return [Labels.DOWN] return [Labels.UP] def _is_significant_upward_move(self, current: float, reference: float) -> bool: """ Check if a current price is a significant upward move compared to a reference price. Args: current (float): The current price. reference (float): The reference price. Returns: bool: True if the current price is a significant upward move, False otherwise. """ return current >= reference * (1 + self.marginal_change_thres) def _is_significant_downward_move(self, current: float, reference: float) -> bool: """ Check if a current price is a significant downward move compared to a reference price. Args: current (float): The current price. reference (float): The reference price. Returns: bool: True if the current price is a significant downward move, False otherwise. """ return current <= reference * (1 - self.marginal_change_thres) def _generate_label_values(self) -> list[int]: """Convert Labels enum to their integer values""" return [label.value for label in self.labels] def _right_pad_labels(self, total_length: int) -> None: """ Right pad the labels list by duplicating the last element. Args: total_length (int): The target length of the padded list. Returns: list[Labels]: Padded list of label values with length equal to target_length. """ if len(self.labels) == 0: return None self.labels += [self.labels[-1]] * (total_length - len(self.labels)) def _update_labels(self, trend_start: int, current_idx: int, label: Labels) -> None: """ Update the labels list with a new label value. Args: trend_start (int): The starting index of the trend. current_idx (int): The current index of the price. label (Labels): The new label value to be added. 
""" self.labels += [label] * (current_idx - trend_start) def _has_price_crossed_reference_price( self, previous_price: float, current_price: float, reference_price: float ) -> bool: """ Check if the price has crossed the reference price. """ return (previous_price - reference_price) * ( current_price - reference_price ) <= 0 @overload def get_labels( self, time_series_list: list[float], return_labels_as_int: Literal[True] = True ) -> list[int]: ... @overload def get_labels( self, time_series_list: list[float], return_labels_as_int: Literal[False] ) -> list[Labels]: ...
[docs] def get_labels( self, time_series_list: list[float], return_labels_as_int: bool = True ) -> list[int] | list[Labels]: """Labels trends in a time series of closing prices using a ternary classification approach. The method identifies three distinct states in price movements: - Upward trends (label: Labels.UP) - Downward trends (label: Labels.DOWN) - No-action (label: Labels.NEUTRAL) The algorithm uses two key parameters: - marginal_change_thres: Defines the threshold for significant price movements as a percentage - window_size: Maximum window to look for trend confirmation before resetting state The labeling process works by tracking the current state and transitioning between states when price movements exceed thresholds, while using the window_size parameter to avoid getting stuck in prolonged sideways movements. Parameters ---------- time_series_list : list[float] List of closing prices. return_labels_as_int : bool, optional If True, returns integer labels (-1, 0, 1), if False returns Labels enum values. Defaults to True. Returns ------- Union[list[int], list[Labels]] List of labels. If return_labels_as_int is True, returns integers (-1, 0, 1), otherwise returns Labels enum values. 
""" self._verify_time_series(time_series_list) # Initialize labels self.labels = self._get_first_label(time_series_list) # Initialize trend start index trend_start = 0 # Iterate over prices starting from the second price for current_idx, current_price in enumerate(time_series_list[1:], start=1): reference_price = time_series_list[trend_start] window_exceeded = current_idx - trend_start > self.window_size match self.labels[-1]: case Labels.UP: # Upward trend if current_price > reference_price: self._update_labels(trend_start, current_idx, Labels.UP) elif self._is_significant_downward_move( current_price, reference_price ): self._update_labels(trend_start, current_idx, Labels.DOWN) elif window_exceeded: self._update_labels(trend_start, current_idx, Labels.NEUTRAL) else: continue trend_start = current_idx case Labels.DOWN: # Downward trend if current_price < reference_price: self._update_labels(trend_start, current_idx, Labels.DOWN) elif self._is_significant_upward_move( current_price, reference_price ): self._update_labels(trend_start, current_idx, Labels.UP) elif window_exceeded: self._update_labels(trend_start, current_idx, Labels.NEUTRAL) else: continue trend_start = current_idx case Labels.NEUTRAL: # No trend if self._is_significant_upward_move(current_price, reference_price): self._update_labels(trend_start, current_idx, Labels.UP) elif self._is_significant_downward_move( current_price, reference_price ): self._update_labels(trend_start, current_idx, Labels.DOWN) elif window_exceeded: self._update_labels(trend_start, current_idx, Labels.NEUTRAL) else: continue trend_start = current_idx self._right_pad_labels(len(time_series_list)) return ( extract_label_values(self.labels) if return_labels_as_int else self.labels )