Source code for tstrends.trend_labelling.binary_CTL

from dataclasses import dataclass
from typing import overload, Literal

from .base_labeller import BaseLabeller
from .label_scaling import Labels, extract_label_values


@dataclass
class TrendState:
    """Mutable bookkeeping for a single trend-labelling run.

    Stores the latest recorded high and low (price and index), the direction
    currently being followed, and the index of the most recently recorded
    extreme.
    """

    # Price of the latest high, written by set_upwards_trend.
    current_high: float = 0.0
    # Price of the latest low, written by set_downwards_trend.
    current_low: float = 0.0
    # Index at which current_high was recorded.
    curr_high_time: int = 0
    # Index at which current_low was recorded.
    curr_low_time: int = 0
    # Direction currently followed; NEUTRAL until one of the setters runs.
    current_direction: Labels = Labels.NEUTRAL
    # Index of the most recently recorded extreme, regardless of direction.
    extreme_point_idx: int = 0

    def set_upwards_trend(self, price: float, time_idx: int) -> None:
        """Record a new high and mark the direction as upward.

        The low-side fields are left unchanged.

        Args:
            price: Price to store as the current high.
            time_idx: Index of that price in the series.
        """
        self.current_high, self.curr_high_time = price, time_idx
        self.extreme_point_idx, self.current_direction = time_idx, Labels.UP

    def set_downwards_trend(self, price: float, time_idx: int) -> None:
        """Record a new low and mark the direction as downward.

        The high-side fields are left unchanged.

        Args:
            price: Price to store as the current low.
            time_idx: Index of that price in the series.
        """
        self.current_low, self.curr_low_time = price, time_idx
        self.extreme_point_idx, self.current_direction = time_idx, Labels.DOWN
class BinaryCTL(BaseLabeller):
    """Binary Continuous Trend Labeller.

    This class implements a binary trend labelling algorithm based on the paper by
    Wu, D., Wang, X., Su, J., Tang, B., & Wu, S.
    "A Labeling Method for Financial Time Series Prediction Based on Trends".

    The algorithm identifies two distinct states in price movements:
    - Upward trends (label: Labels.UP or 1)
    - Downward trends (label: Labels.DOWN or -1)

    If no price ever moves more than ``omega`` (relative) away from the first
    price, no trend is detected and every point keeps the neutral label
    (Labels.NEUTRAL).

    Example:
        >>> labeller = BinaryCTL(omega=0.1)
        >>> prices = [1.0, 1.15, 1.2, 1.0]
        >>> labels = labeller.get_labels(prices)
        >>> print(labels)  # [1, 1, 1, -1]

    Note:
        The omega parameter determines how significant a price movement must be
        to be considered a trend change. Higher values result in fewer trend
        changes being identified.
    """

    def __init__(self, omega: float) -> None:
        """
        Initialize the continuous trend labeller.

        Args:
            omega (float): The proportion threshold parameter of the trend
                definition. A trend starts or reverses only when the price moves
                by more than ``omega`` relative to the reference price (the
                first price, or the current high/low).

        Raises:
            TypeError: If ``omega`` is not a float.
        """
        if not isinstance(omega, float):
            raise TypeError("omega must be a float.")

        self.omega = omega
        self._state: TrendState = TrendState()
        self._labels: list[Labels] = list()

    def _initialize_labels(self, length: int) -> None:
        """Initialize the labels list with neutral values.

        Args:
            length: Length of the time series
        """
        self._labels = [Labels.NEUTRAL] * length

    def _update_labels(self, start_idx: int, end_idx: int, label_value: Labels) -> None:
        """Update a range of labels with the specified value.

        An empty range (``end_idx < start_idx``) is a no-op.

        Args:
            start_idx: Start index (inclusive)
            end_idx: End index (inclusive)
            label_value: The label value to set
        """
        for i in range(start_idx, end_idx + 1):
            self._labels[i] = label_value

    def _detect_initial_trend(self, time_series_list: list[float]) -> None:
        """
        Detect the initial trend direction by finding the first significant price movement.

        Scans from the first price until one rises above
        ``first_price * (1 + omega)`` (uptrend) or falls below
        ``first_price * (1 - omega)`` (downtrend). All points before that index
        are labelled with the detected direction, and the state records that
        price as the current extreme. If no such price exists (or the series is
        empty), the state is left unchanged.

        Args:
            time_series_list: The input time series data
        """
        if not time_series_list:
            return

        first_price = time_series_list[0]
        # Thresholds depend only on the first price, so compute them once.
        upper_threshold = first_price * (1 + self.omega)
        lower_threshold = first_price * (1 - self.omega)

        for i, price in enumerate(time_series_list):
            if price > upper_threshold:
                self._state.set_upwards_trend(price, i)
                self._update_labels(0, i - 1, Labels.UP)
                return
            if price < lower_threshold:
                self._state.set_downwards_trend(price, i)
                self._update_labels(0, i - 1, Labels.DOWN)
                return

    def _handle_uptrend(self, price: float, time_idx: int) -> None:
        """
        Handle the uptrend case in continuous trend detection.

        A new high extends the uptrend. A drop of more than ``omega`` below the
        current high confirms a reversal: the points after the last recorded
        low, up to and including the high, are labelled UP, and the state
        switches to a downtrend starting at ``time_idx``.

        Args:
            price: Current price value
            time_idx: Current time index
        """
        if price > self._state.current_high:
            self._state.set_upwards_trend(price, time_idx)
        elif price < self._state.current_high * (1 - self.omega):
            self._update_labels(
                self._state.curr_low_time + 1,
                self._state.curr_high_time,
                Labels.UP,
            )
            self._state.set_downwards_trend(price, time_idx)

    def _handle_downtrend(self, price: float, time_idx: int) -> None:
        """
        Handle the downtrend case in continuous trend detection.

        A new low extends the downtrend. A rise of more than ``omega`` above the
        current low confirms a reversal: the points after the last recorded
        high, up to and including the low, are labelled DOWN, and the state
        switches to an uptrend starting at ``time_idx``.

        Args:
            price: Current price value
            time_idx: Current time index
        """
        if price < self._state.current_low:
            self._state.set_downwards_trend(price, time_idx)
        elif price > self._state.current_low * (1 + self.omega):
            self._update_labels(
                self._state.curr_high_time + 1,
                self._state.curr_low_time,
                Labels.DOWN,
            )
            self._state.set_upwards_trend(price, time_idx)

    @overload
    def get_labels(
        self, time_series_list: list[float], return_labels_as_int: Literal[True] = True
    ) -> list[int]: ...

    @overload
    def get_labels(
        self, time_series_list: list[float], return_labels_as_int: Literal[False]
    ) -> list[Labels]: ...

    def get_labels(
        self, time_series_list: list[float], return_labels_as_int: bool = True
    ) -> list[int] | list[Labels]:
        """Auto-labels a price time series based on the provided algorithm.

        Each call is independent: the internal trend state is reset before
        labelling, so reusing the same labeller on several series gives the
        same results as using a fresh labeller for each.

        Parameters
        ----------
        time_series_list : list[float]
            The original time series data X = [x1, x2, ..., xN]
        return_labels_as_int : bool, optional
            If True, returns integer labels (-1, 1), if False returns Labels enum values.
            Defaults to True.

        Returns
        -------
        Union[list[int], list[Labels]]
            The label vector Y. If return_labels_as_int is True, returns integers (-1, 1),
            otherwise returns Labels enum values (Labels.DOWN, Labels.UP).
            If no price ever moves more than ``omega`` away from the first price,
            every label is Labels.NEUTRAL (or its integer value).
        """
        self._verify_time_series(time_series_list)
        n_points = len(time_series_list)

        # Reset per-call state. Without this, the extremes, their indices and
        # the direction left over from a previous call leak into this one and
        # corrupt the label ranges (e.g. an empty UP range on reversal).
        self._state = TrendState()
        self._initialize_labels(n_points)

        # Detect initial trend direction
        self._detect_initial_trend(time_series_list)

        # Continue trend detection for the rest of the series
        for i in range(self._state.extreme_point_idx + 1, n_points):
            if self._state.current_direction == Labels.UP:
                self._handle_uptrend(time_series_list[i], i)
            elif self._state.current_direction == Labels.DOWN:
                self._handle_downtrend(time_series_list[i], i)

        # Label the last interval: everything after the most recent confirmed
        # turning point gets the direction currently being followed. When no
        # trend was detected both indices are still 0 and nothing is labelled.
        if self._state.curr_low_time != self._state.curr_high_time:
            self._update_labels(
                min(self._state.curr_low_time, self._state.curr_high_time) + 1,
                n_points - 1,
                self._state.current_direction,
            )

        return (
            extract_label_values(self._labels) if return_labels_as_int else self._labels
        )