# Source code for tstrends.label_tuning.remaining_value_tuner

"""
Remaining value change tuner for trend labels.

This module provides a tuner that enhances trend labels with information about
the remaining value change until the end of a continuous trend interval.
"""

from itertools import pairwise
from typing import Any
import numpy as np

from tstrends.label_tuning.base import BaseLabelTuner, BasePostprocessor


class RemainingValueTuner(BaseLabelTuner):
    """
    A tuner that calculates the remaining value change in intervals of continuous labels.

    For each point in the time series, it calculates the absolute (not relative)
    value change from the current position to the end of the current trend interval.

    Attributes:
        postprocessors (list[BasePostprocessor]): Steps applied in order after
            computing remaining values (e.g. :class:`~tstrends.label_tuning.shifting.Shifter`,
            :class:`~tstrends.label_tuning.filtering.ForwardLookingFilter`, smoothers).
            Each step's ``process`` is called with the current result plus the
            original ``time_series`` and ``labels``; filters use that context, while
            smoothers and shifters ignore it as documented on each class.
    """

    def __init__(self, postprocessors: list[BasePostprocessor] | None = None):
        """
        Initialize the RemainingValueTuner.

        Args:
            postprocessors (list[BasePostprocessor] | None, optional): Steps applied
                in order to the tuned values. ``None`` means no postprocessing.

        Raises:
            TypeError: If any element of ``postprocessors`` is not a
                BasePostprocessor, or if ``postprocessors`` is not iterable.
        """
        # Copy into a list: a one-shot iterable (e.g. a generator) would otherwise be
        # exhausted by the validation loop below and silently skipped in ``tune``.
        self.postprocessors = list(postprocessors) if postprocessors is not None else []
        for postprocessor in self.postprocessors:
            if not isinstance(postprocessor, BasePostprocessor):
                raise TypeError(
                    f"postprocessors must be a list of BasePostprocessor, got {type(postprocessor)}"
                )

    def tune(
        self,
        time_series: list[float],
        labels: list[int],
        enforce_monotonicity: bool = False,
        normalize_over_interval: bool = False,
        **kwargs: Any,
    ) -> list[float]:
        """
        Tune trend labels to provide information about remaining value change.

        For every point of a non-neutral interval, computes the difference between
        the interval's extreme value and the point's value. The extreme is the max
        for label 1 and the min otherwise. It is taken over the interval plus the
        first point of the following interval. Uptrend results are therefore >= 0
        and downtrend results <= 0, so the sign matches the original label. Before
        postprocessing, points in neutral (0) intervals and the final point of the
        series are left at 0.

        Args:
            time_series (list[float]): The price series used for trend detection.
            labels (list[int]): The original trend labels (-1, 1) or (-1, 0, 1).
            enforce_monotonicity (bool, optional): If True, each point is measured
                from the running max (uptrend) or running min (downtrend) reached so
                far in its interval. The remaining change then never reverses on
                uncaptured countertrends.
            normalize_over_interval (bool, optional): If True, each interval's values
                are divided by their largest magnitude, mapping them into [-1, 1].
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            list[float]: Tuned labels with the list of postprocessors applied.
        """
        self._verify_inputs(time_series, labels)

        ts_array = np.asarray(time_series)
        labels_array = np.asarray(labels)
        result = np.zeros(len(time_series))

        for start, end in pairwise(self._find_trend_intervals(labels)):
            # One label drives both the extreme and the accumulator, so the two
            # choices can never disagree (e.g. on an empty trailing interval).
            direction = labels_array[start]
            if direction == 0:
                continue

            # Include ``end``: for inner intervals it is the first point of the next
            # interval, i.e. where the current trend actually finishes.
            window = ts_array[start : end + 1]
            extreme_value = window.max() if direction == 1 else window.min()

            reference_values = ts_array[start:end]
            if enforce_monotonicity:
                accumulate = (
                    np.maximum.accumulate if direction == 1 else np.minimum.accumulate
                )
                reference_values = accumulate(reference_values)

            remaining = extreme_value - reference_values
            if normalize_over_interval:
                remaining = self._normalize_values(remaining)
            result[start:end] = remaining

        for step in self.postprocessors:
            result = step.process(result, time_series, labels)

        # ``np.asarray`` keeps this working even if a postprocessor returns a list.
        return np.asarray(result).tolist()

    def _find_trend_intervals(self, labels: list[int]) -> list[int]:
        """
        Find the boundaries of each continuous label interval in the label series.

        Args:
            labels (list[int]): The original trend labels (-1, 1) or (-1, 0, 1).

        Returns:
            list[int]: The start index of every interval (beginning with 0), followed
            by ``len(labels) - 1`` as a closing sentinel, so consecutive pairs form
            ``(start, end)`` boundaries. The sentinel is the last index rather than
            ``len(labels)``, so the final point never falls inside a ``start:end``
            slice. If the last label differs from its predecessor, the sentinel
            duplicates the final start index. Empty ``labels`` yield an empty list.
        """
        # ``len`` rather than truthiness so numpy arrays are accepted as well.
        if len(labels) == 0:
            return []

        change_indices = [0]
        # A new interval starts at the second element of every pair whose labels differ.
        change_indices.extend(
            index
            for index, (previous, current) in enumerate(pairwise(labels), start=1)
            if previous != current
        )
        change_indices.append(len(labels) - 1)
        return change_indices

    def _normalize_values(self, values: np.ndarray) -> np.ndarray:
        """
        Scale values by their largest magnitude so they fall in [-1, 1], preserving sign.

        Args:
            values (np.ndarray): Array of values to normalize.

        Returns:
            np.ndarray: Normalized values in the [-1, 1] range. All-zero or empty
            input is returned unchanged, since there is no magnitude to divide by.
        """
        if not values.any():
            return values
        max_abs = np.abs(values).max()
        return np.clip(values / max_abs if max_abs > 0 else values, -1.0, 1.0)