Coverage for src/sensai/util/aggregation.py: 44%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1import collections 

2from typing import Hashable, Dict, Optional 

3 

4from .string import ToStringMixin 

5 

6 

class RelativeFrequencyCounter(ToStringMixin):
    """
    Tracks how many of the observed events were relevant, keeping the absolute counts
    (relevant vs. total) from which the relative frequency can be derived.
    """
    def __init__(self):
        # denominator: number of events observed so far
        self.num_total = 0
        # numerator: number of observed events that were relevant
        self.num_relevant = 0

    def count(self, is_relevant_event) -> None:
        """
        Registers a single event.
        The total (denominator) is always incremented; the relevant count (numerator)
        is incremented only when the event is relevant.

        :param is_relevant_event: truthy if the event being registered is a relevant one
        """
        self.num_total += 1
        self.num_relevant += 1 if is_relevant_event else 0

    def _tostring_object_info(self):
        counts_str = f"{self.num_relevant}/{self.num_total}"
        # a percentage is only meaningful once at least one event has been counted
        if self.num_total <= 0:
            return counts_str
        percentage = 100 * self.num_relevant / self.num_total
        return f"{counts_str}, {percentage:.2f}%"

    def add(self, relative_frequency_counter: "RelativeFrequencyCounter") -> None:
        """
        Merges the counts of another counter into this one.

        :param relative_frequency_counter: the counter whose counts are added to this counter's counts
        """
        other = relative_frequency_counter
        self.num_total += other.num_total
        self.num_relevant += other.num_relevant

    def get_relative_frequency(self) -> Optional[float]:
        """
        :return: the fraction of relevant events (in [0, 1]), or None if no events have been counted yet
        """
        return None if self.num_total == 0 else self.num_relevant / self.num_total

49 

50 

class DistributionCounter(ToStringMixin):
    """
    Counts occurrences of (mutually exclusive) events in order to derive the empirical
    distribution over those events.
    """
    def __init__(self):
        # absolute count per event; events not yet seen default to 0 via _zero
        self.counts = collections.defaultdict(self._zero)
        # number of events counted overall (sum of all per-event counts)
        self.total_count = 0

    @staticmethod
    def _zero():
        return 0

    def count(self, event: Hashable) -> None:
        """
        Registers one occurrence of the given event.

        :param event: the event (used as a dictionary key, hence must be hashable) whose count to increase
        """
        self.total_count += 1
        self.counts[event] += 1

    def get_distribution(self) -> Dict[Hashable, float]:
        """
        :return: a mapping from each event passed to count to its relative frequency
            (its count divided by the total number of counted events)
        """
        distribution = {}
        for event, event_count in self.counts.items():
            distribution[event] = event_count / self.total_count
        return distribution

    def _tostring_object_info(self):
        entries = []
        for event, event_count in self.counts.items():
            entries.append(f"{str(event)}: {event_count} ({event_count / self.total_count:.3f})")
        return ", ".join(entries)

80 

81 

class WeightedMean(ToStringMixin):
    """
    Computes a weighted mean of values, i.e. sum(value * weight) / sum(weight),
    accumulated incrementally via add.
    """
    def __init__(self):
        # running sum of value * weight over all added values
        self.weighted_value_sum = 0
        # running sum of all weights
        self.weight_sum = 0

    def _tostring_object_info(self) -> str:
        # Without this guard, converting an empty (or zero-total-weight) instance to a string
        # raised ZeroDivisionError; string conversion should never fail.
        if self.weight_sum == 0:
            return "undefined (weight sum is 0)"
        return f"{self.get_weighted_mean()}"

    def add(self, value, weight=1) -> None:
        """
        Adds the given value with the given weight to the calculation

        :param value: the value
        :param weight: the weight with which to consider the value
        """
        self.weighted_value_sum += value * weight
        self.weight_sum += weight

    def get_weighted_mean(self):
        """
        :return: the weighted mean of all values that have been added
        :raises ZeroDivisionError: if the sum of weights is zero (e.g. no values have been added yet)
            and the weights are plain Python numbers
        """
        return self.weighted_value_sum / self.weight_sum