Coverage for src/sensai/util/io.py: 26%

130 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1import io 

2import logging 

3import os 

4from typing import Sequence, Optional, Tuple, List, Any, TYPE_CHECKING 

5 

6if TYPE_CHECKING: 

7 from matplotlib import pyplot as plt 

8 import pandas as pd 

9 

10log = logging.getLogger(__name__) 

11 

12 

class ResultWriter:
    """
    Writes result artefacts (text files, data frames, figures, pickles) to a result directory,
    prepending a common prefix to the name of every file.

    If the writer is disabled, the ``write_*`` methods that return a path write nothing but still
    return the path that would have been written to.
    """
    log = log.getChild(__qualname__)

    def __init__(self, result_dir, filename_prefix="", enabled: bool = True, close_figures: bool = False):
        """
        :param result_dir: the directory in which all files are stored; it is created (including any missing
            parent directories) if the writer is enabled
        :param filename_prefix: the prefix that is prepended to the name of every file
        :param enabled: whether the result writer is enabled; if it is not, it will create neither files nor directories
        :param close_figures: whether to close figures that are passed by default
        """
        self.result_dir = result_dir
        self.filename_prefix = filename_prefix
        self.enabled = enabled
        self.close_figures_default = close_figures
        if self.enabled:
            os.makedirs(result_dir, exist_ok=True)

    def child_with_added_prefix(self, prefix: str) -> "ResultWriter":
        """
        Creates a derived result writer with an added prefix, i.e. the given prefix is appended to this
        result writer's prefix

        :param prefix: the prefix to append
        :return: a new writer instance
        """
        return ResultWriter(self.result_dir, filename_prefix=self.filename_prefix + prefix, enabled=self.enabled,
            close_figures=self.close_figures_default)

    def child_for_subdirectory(self, dir_name: str) -> "ResultWriter":
        """
        Creates a derived result writer which writes to a subdirectory of this writer's result directory,
        retaining this writer's prefix and settings

        :param dir_name: the name of the subdirectory (relative to this writer's result directory)
        :return: a new writer instance
        """
        result_dir = os.path.join(self.result_dir, dir_name)
        return ResultWriter(result_dir, filename_prefix=self.filename_prefix, enabled=self.enabled,
            close_figures=self.close_figures_default)

    def path(self, filename_suffix: str, extension_to_add=None, valid_other_extensions: Optional[Sequence[str]] = None) -> str:
        """
        :param filename_suffix: the suffix to add (which may or may not already include a file extension)
        :param extension_to_add: if not None, the file extension to add (without the leading ".") unless
            the extension to add or one of the extensions in valid_other_extensions is already present
        :param valid_other_extensions: a sequence of valid other extensions (without the "."), only
            relevant if extension_to_add is specified; a single string is treated as one extension
        :return: the full path
        """
        if extension_to_add is not None:
            if valid_other_extensions is None:
                accepted_extensions = set()
            elif isinstance(valid_other_extensions, str):
                # a bare string is itself a sequence of characters; treat it as a single extension
                # rather than splitting it into one-letter "extensions"
                accepted_extensions = {valid_other_extensions}
            else:
                accepted_extensions = set(valid_other_extensions)
            accepted_extensions.add(extension_to_add)
            accepted_endings = tuple("." + ext for ext in accepted_extensions)
            if not filename_suffix.endswith(accepted_endings):
                filename_suffix += "." + extension_to_add
        return os.path.join(self.result_dir, f"{self.filename_prefix}{filename_suffix}")

    def write_text_file(self, filename_suffix: str, content: str):
        """
        :param filename_suffix: the filename suffix; the extension "txt" is added if not already present
        :param content: the text to write
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        p = self.path(filename_suffix, extension_to_add="txt")
        if self.enabled:
            self.log.info(f"Saving text file {p}")
            with open(p, "w") as f:
                f.write(content)
        return p

    def write_text_file_lines(self, filename_suffix: str, lines: List[str]):
        """
        :param filename_suffix: the filename suffix; the extension "txt" is added if not already present
        :param lines: the lines to write (without a trailing newline, which will be added)
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        p = self.path(filename_suffix, extension_to_add="txt")
        if self.enabled:
            self.log.info(f"Saving text file {p}")
            # refers to the module-level function of the same name
            write_text_file_lines(lines, p)
        return p

    def write_data_frame_text_file(self, filename_suffix: str, df: "pd.DataFrame"):
        """
        Writes the string representation of the given data frame (as returned by ``df.to_string()``) to a text file

        :param filename_suffix: the filename suffix; the extension "df.txt" is added unless the suffix
            already ends with ".df.txt" or ".txt"
        :param df: the data frame to write
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        # NOTE: the other valid extensions must be passed as a sequence of extensions, not as a bare string
        p = self.path(filename_suffix, extension_to_add="df.txt", valid_other_extensions=("txt",))
        if self.enabled:
            self.log.info(f"Saving data frame text file {p}")
            with open(p, "w") as f:
                f.write(df.to_string())
        return p

    def write_data_frame_csv_file(self, filename_suffix: str, df: "pd.DataFrame", index=True, header=True):
        """
        :param filename_suffix: the filename suffix; the extension "csv" is added if not already present
        :param df: the data frame to write
        :param index: passed on to ``df.to_csv`` as ``index``
        :param header: passed on to ``df.to_csv`` as ``header``
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        p = self.path(filename_suffix, extension_to_add="csv")
        if self.enabled:
            self.log.info(f"Saving data frame CSV file {p}")
            df.to_csv(p, index=index, header=header)
        return p

    def write_figure(self, filename_suffix: str, fig: "plt.Figure", close_figure: Optional[bool] = None):
        """
        :param filename_suffix: the filename suffix, which may or may not include a file extension, valid extensions being {"png", "jpg"}
        :param fig: the figure to save
        :param close_figure: whether to close the figure after having saved it; if None, use default passed at construction
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        from matplotlib import pyplot as plt
        p = self.path(filename_suffix, extension_to_add="png", valid_other_extensions=("jpg",))
        if self.enabled:
            self.log.info(f"Saving figure {p}")
            fig.savefig(p, bbox_inches="tight")
            must_close_figure = close_figure if close_figure is not None else self.close_figures_default
            if must_close_figure:
                plt.close(fig)
        return p

    def write_figures(self, figures: Sequence[Tuple[str, "plt.Figure"]], close_figures=False):
        """
        Writes several figures via :meth:`write_figure`

        :param figures: a sequence of pairs (filename suffix, figure)
        :param close_figures: whether to close each figure after having saved it; since an explicit value is
            always passed on, this takes precedence over the default passed at construction
        """
        for name, fig in figures:
            self.write_figure(name, fig, close_figure=close_figures)

    def write_pickle(self, filename_suffix: str, obj: Any):
        """
        :param filename_suffix: the filename suffix; the extension "pickle" is added if not already present
        :param obj: the object to pickle
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        from .pickle import dump_pickle
        p = self.path(filename_suffix, extension_to_add="pickle")
        if self.enabled:
            self.log.info(f"Saving pickle {p}")
            dump_pickle(obj, p)
        return p

127 

128 

def write_text_file_lines(lines: List[str], path):
    """
    Writes the given lines to a text file, terminating each of them with a newline character.

    :param lines: the lines to write (without a trailing newline, which will be added)
    :param path: the path of the text file to write to
    """
    with open(path, "w") as out_file:
        out_file.writelines(entry + "\n" for entry in lines)

138 

139 

def read_text_file_lines(path, strip=True, skip_empty=True) -> List[str]:
    """
    Reads a text file line by line.

    :param path: the path of the text file to read from
    :param strip: whether to strip each line, removing whitespace/newline characters
    :param skip_empty: whether to skip any lines that are empty (after stripping, if stripping is enabled)
    :return: the list of lines
    """
    with open(path, "r") as in_file:
        raw_lines = in_file.readlines()
    candidates = [entry.strip() for entry in raw_lines] if strip else raw_lines
    if not skip_empty:
        return list(candidates)
    # without stripping, a blank line still holds its newline character and is therefore retained
    return [entry for entry in candidates if entry != ""]

155 

156 

def is_s3_path(path: str):
    """
    :param path: the path to check
    :return: True if the path begins with the (lower-case) prefix "s3://", False otherwise
    """
    s3_prefix = "s3://"
    return path.startswith(s3_prefix)

159 

160 

class S3Object:
    """
    Represents an object in S3 which is identified by a path of the form ``s3://<bucket>/<object key>``
    and provides file-like access to its content via boto3.
    """
    def __init__(self, path):
        """
        :param path: the S3 path, which must start with "s3://" and must contain a "/" separating the
            bucket name from the object key (otherwise the unpacking below raises a ValueError)
        """
        assert is_s3_path(path)
        self.path = path
        # drop the leading "s3://" (5 characters); the first component is the bucket,
        # everything after the first "/" is the object key
        self.bucket, self.object = self.path[5:].split("/", 1)

    class OutputFile:
        """
        Write-only file-like object which buffers everything that is written in memory and
        uploads the buffered content to the S3 object when it is used as a context manager and
        the context is left without an exception. Nothing is uploaded if it is not used as a
        context manager.
        """
        def __init__(self, s3_object: "S3Object"):
            """
            :param s3_object: the object to which the buffered content shall be uploaded
            """
            self.s3Object = s3_object
            self.buffer = io.BytesIO()

        def write(self, obj: bytes):
            """
            :param obj: the bytes to append to the in-memory buffer
            """
            self.buffer.write(obj)

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Upload only if the body of the with-statement completed successfully; uploading after
            # an exception would replace the S3 object with incomplete content.
            # Returning None lets any exception propagate.
            if exc_type is None:
                self.s3Object.put(self.buffer.getvalue())

    def get_file_content(self):
        """
        :return: the content read from the 'Body' of the S3 object's ``get`` response
        """
        return self._get_s3_object().get()['Body'].read()

    def open_file(self, mode):
        """
        :param mode: either "rb" (read) or "wb" (write)
        :return: for "rb", an in-memory binary stream holding the downloaded content; for "wb", an
            :class:`OutputFile`, which is to be used as a context manager in order for the content to be uploaded
        """
        assert mode in ("wb", "rb")
        if mode == "rb":
            return io.BytesIO(self.get_file_content())
        elif mode == "wb":
            return self.OutputFile(self)
        else:
            # only reachable if assertions are disabled (python -O)
            raise ValueError(mode)

    def put(self, obj: bytes):
        """
        :param obj: the bytes to store as the content of the S3 object
        """
        self._get_s3_object().put(Body=obj)

    def _get_s3_object(self):
        """
        :return: the boto3 resource object for this S3 object, created from a new session on every call
        """
        # imported lazily such that boto3 is required only when S3 is actually accessed
        import boto3
        # the profile is taken from the environment variable AWS_PROFILE (None if it is not set)
        session = boto3.session.Session(profile_name=os.getenv("AWS_PROFILE"))
        s3 = session.resource("s3")
        return s3.Bucket(self.bucket).Object(self.object)