Coverage for src/sensai/util/io.py: 26%

141 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

1import io 

2import logging 

3import os 

4from typing import Sequence, Optional, Tuple, List, Any, TYPE_CHECKING 

5 

6if TYPE_CHECKING: 

7 from matplotlib import pyplot as plt 

8 import pandas as pd 

9 

# module-level logger, named after this module
log = logging.getLogger(name=__name__)

11 

12 

class ResultWriter:
    """
    Writes result artifacts (text files, data frames, figures, pickles) into a result directory,
    prepending a common prefix to the name of every file. A disabled writer computes the same paths
    but writes no files and creates no directories.
    """
    # Child logger "<module>.ResultWriter": inside a class body, __module__ holds the name of the
    # defining module, so this does not need to reference a module-level logger variable.
    log = logging.getLogger(__module__).getChild(__qualname__)

    def __init__(self, result_dir, filename_prefix="", enabled: bool = True, close_figures: bool = False):
        """
        :param result_dir: the directory into which files are written; if the writer is enabled, it is
            created (including missing parent directories) upon construction
        :param filename_prefix: the prefix to prepend to the name of every file that is written
        :param enabled: whether the result writer is enabled; if it is not, it will create neither files nor directories
        :param close_figures: whether :meth:`write_figure` closes figures after saving them by default
        """
        self.result_dir = result_dir
        self.filename_prefix = filename_prefix
        self.enabled = enabled
        self.close_figures_default = close_figures
        if self.enabled:
            os.makedirs(result_dir, exist_ok=True)

    def child_with_added_prefix(self, prefix: str) -> "ResultWriter":
        """
        Creates a derived result writer with an added prefix, i.e. the given prefix is appended to this
        result writer's prefix; directory, enabled state and figure-closing default are retained.

        :param prefix: the prefix to append
        :return: a new writer instance
        """
        return ResultWriter(self.result_dir, filename_prefix=self.filename_prefix + prefix, enabled=self.enabled,
                            close_figures=self.close_figures_default)

    def child_for_subdirectory(self, dir_name: str) -> "ResultWriter":
        """
        Creates a derived result writer which writes into a subdirectory of this writer's result directory;
        prefix, enabled state and figure-closing default are retained.

        :param dir_name: the name of the subdirectory (which is created if the writer is enabled)
        :return: a new writer instance
        """
        result_dir = os.path.join(self.result_dir, dir_name)
        return ResultWriter(result_dir, filename_prefix=self.filename_prefix, enabled=self.enabled,
                            close_figures=self.close_figures_default)

    def path(self, filename_suffix: str, extension_to_add: Optional[str] = None,
            valid_other_extensions: Optional[Sequence[str]] = None) -> str:
        """
        Computes the full path of a result file (without creating anything).

        :param filename_suffix: the suffix to add to the prefix (which may or may not already include a file extension);
            the characters ">=" and ">" are replaced by "gte" and "gt" respectively
        :param extension_to_add: if not None, the file extension to add (without the leading ".") unless
            the extension to add or one of the extensions in valid_other_extensions is already present
        :param valid_other_extensions: a sequence of valid other extensions (without the "."), only
            relevant if extension_to_add is specified; a single string is treated as one extension
        :return: the full path
        """
        # replace forbidden characters
        filename_suffix = filename_suffix.replace(">=", "gte").replace(">", "gt")

        if extension_to_add is not None:
            if isinstance(valid_other_extensions, str):
                # a bare string would otherwise be iterated character by character (e.g. "txt" -> {"t", "x"})
                valid_other_extensions = (valid_other_extensions,)
            valid_extensions = {extension_to_add, *(valid_other_extensions or ())}
            if not any(filename_suffix.endswith("." + ext) for ext in valid_extensions):
                filename_suffix += "." + extension_to_add

        return os.path.join(self.result_dir, f"{self.filename_prefix}{filename_suffix}")

    def write_text_file(self, filename_suffix: str, content: str) -> str:
        """
        Writes a text file (adding the extension "txt" unless already present).

        :param filename_suffix: the filename suffix, which may or may not include the extension
        :param content: the text to write
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        p = self.path(filename_suffix, extension_to_add="txt")
        if self.enabled:
            self.log.info(f"Saving text file {p}")
            with open(p, "w") as f:
                f.write(content)
        return p

    def write_text_file_lines(self, filename_suffix: str, lines: List[str]) -> str:
        """
        Writes a text file containing the given lines, each followed by a newline
        (adding the extension "txt" unless already present).

        :param filename_suffix: the filename suffix, which may or may not include the extension
        :param lines: the lines to write (without trailing newlines)
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        p = self.path(filename_suffix, extension_to_add="txt")
        if self.enabled:
            self.log.info(f"Saving text file {p}")
            # the module-level function of the same name (method names are not in scope here)
            write_text_file_lines(lines, p)
        return p

    def write_data_frame_text_file(self, filename_suffix: str, df: "pd.DataFrame") -> str:
        """
        Writes the string representation of a data frame to a text file. The extension "df.txt" is added
        unless the suffix already ends with ".df.txt" or ".txt".

        :param filename_suffix: the filename suffix, which may or may not include the extension
        :param df: the data frame to write
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        p = self.path(filename_suffix, extension_to_add="df.txt", valid_other_extensions=("txt",))
        if self.enabled:
            self.log.info(f"Saving data frame text file {p}")
            with open(p, "w") as f:
                f.write(df.to_string())
        return p

    def write_data_frame_csv_file(self, filename_suffix: str, df: "pd.DataFrame", index=True, header=True) -> str:
        """
        Writes a data frame to a CSV file (adding the extension "csv" unless already present).

        :param filename_suffix: the filename suffix, which may or may not include the extension
        :param df: the data frame to write
        :param index: passed on to ``df.to_csv``
        :param header: passed on to ``df.to_csv``
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        p = self.path(filename_suffix, extension_to_add="csv")
        if self.enabled:
            self.log.info(f"Saving data frame CSV file {p}")
            df.to_csv(p, index=index, header=header)
        return p

    def write_figure(self, filename_suffix: str, fig: "plt.Figure", close_figure: Optional[bool] = None) -> str:
        """
        :param filename_suffix: the filename suffix, which may or may not include a file extension, valid extensions being {"png", "jpg"}
        :param fig: the figure to save
        :param close_figure: whether to close the figure after having saved it; if None, use default passed at construction.
            Closing also happens if the writer is disabled.
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        from matplotlib import pyplot as plt
        p = self.path(filename_suffix, extension_to_add="png", valid_other_extensions=("jpg",))
        if self.enabled:
            self.log.info(f"Saving figure {p}")
            fig.savefig(p, bbox_inches="tight")
        must_close_figure = close_figure if close_figure is not None else self.close_figures_default
        if must_close_figure:
            plt.close(fig)
        return p

    def write_figures(self, figures: Sequence[Tuple[str, "plt.Figure"]], close_figures=False) -> List[str]:
        """
        Writes several figures via :meth:`write_figure`.

        :param figures: pairs (filename suffix, figure)
        :param close_figures: whether to close each figure after saving it.
            NOTE(review): unlike :meth:`write_figure`, the default is an explicit False, so the
            construction-time default is not applied here — confirm this is intended.
        :return: the paths to the files (written only if the writer is enabled), in the order of ``figures``
        """
        return [self.write_figure(name, fig, close_figure=close_figures) for name, fig in figures]

    def write_pickle(self, filename_suffix: str, obj: Any) -> str:
        """
        Pickles an object to a file (adding the extension "pickle" unless already present).

        :param filename_suffix: the filename suffix, which may or may not include the extension
        :param obj: the object to pickle
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        from .pickle import dump_pickle
        p = self.path(filename_suffix, extension_to_add="pickle")
        if self.enabled:
            self.log.info(f"Saving pickle {p}")
            dump_pickle(obj, p)
        return p

130 

131 

def write_text_file_lines(lines: List[str], path):
    """
    Writes the given lines to a text file, replacing any previous content.

    :param lines: the lines to write, each without a trailing newline (a newline is appended to every
        line, including the last one)
    :param path: the path of the text file to write to
    """
    with open(path, "w") as text_file:
        text_file.writelines(line + "\n" for line in lines)

141 

142 

def read_text_file_lines(path, strip=True, skip_empty=True) -> List[str]:
    """
    Reads all lines of a text file.

    :param path: the path of the text file to read from
    :param strip: whether to strip each line, removing leading/trailing whitespace including the newline
    :param skip_empty: whether to skip lines that are empty (after stripping); without stripping, lines
        retain their line terminator and are therefore never considered empty
    :return: the list of lines
    """
    with open(path, "r") as text_file:
        result = text_file.readlines()
    if strip:
        result = [entry.strip() for entry in result]
    if skip_empty:
        result = [entry for entry in result if entry != ""]
    return result

158 

159 

def is_s3_path(path: str) -> bool:
    """
    :param path: a local path or URL
    :return: True if the path is an S3 URL, i.e. starts with the (case-sensitive) prefix "s3://"
    """
    s3_scheme_prefix = "s3://"
    return path.startswith(s3_scheme_prefix)

162 

163 

class S3Object:
    """
    Represents an object in AWS S3 addressed by a URL of the form "s3://<bucket>/<key>".
    boto3 is imported lazily, i.e. only when the object is actually accessed.
    """
    def __init__(self, path):
        """
        :param path: the S3 URL, which must start with "s3://" and contain a "/" separating bucket and key
        """
        assert is_s3_path(path)
        self.path = path
        # "s3://bucket/some/key" -> bucket "bucket", key "some/key"; a URL without a "/" after the
        # bucket name cannot be unpacked into two parts and raises ValueError
        self.bucket, self.object = self.path[5:].split("/", 1)

    class OutputFile:
        """
        Write-only, in-memory binary file object. Its content is uploaded to the associated S3 object
        when a ``with`` block using it exits without an exception.
        """
        def __init__(self, s3_object: "S3Object"):
            self.s3Object = s3_object
            self.buffer = io.BytesIO()

        def write(self, obj: bytes) -> int:
            """
            Appends data to the in-memory buffer.

            :param obj: the bytes to write
            :return: the number of bytes written
            """
            return self.buffer.write(obj)

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Upload only if the block completed successfully: uploading after a failure would replace
            # the S3 object with partially written (and likely corrupt) content.
            if exc_type is None:
                self.s3Object.put(self.buffer.getvalue())
            # implicitly returns None, so exceptions raised within the block propagate

    def get_file_content(self):
        """
        :return: the object's content as read from the response body of an S3 ``get`` request
        """
        return self._get_s3_object().get()['Body'].read()

    def open_file(self, mode):
        """
        Opens the object as a binary file object.

        :param mode: "rb" to read (the entire content is downloaded immediately and wrapped in a BytesIO),
            or "wb" to write (content is buffered in memory and uploaded when the ``with`` block exits
            without an exception)
        :return: the file object
        """
        assert mode in ("wb", "rb")
        if mode == "rb":
            content = self.get_file_content()
            return io.BytesIO(content)

        elif mode == "wb":
            return self.OutputFile(self)

        else:
            # only reachable when assertions are disabled (python -O)
            raise ValueError(mode)

    def put(self, obj: bytes):
        """
        Uploads the given content as the object's body.

        :param obj: the content to upload
        """
        self._get_s3_object().put(Body=obj)

    def _get_s3_object(self):
        import boto3
        # the AWS profile is taken from the AWS_PROFILE environment variable (None if unset, leaving
        # profile selection to boto3); a new session is created on every call
        session = boto3.session.Session(profile_name=os.getenv("AWS_PROFILE"))
        s3 = session.resource("s3")
        return s3.Bucket(self.bucket).Object(self.object)

207 

208 

def create_path(root: str, *path_elems: str, is_dir: bool, make_dirs: bool = False) -> str:
    """
    Joins path elements and optionally creates the corresponding directory.

    :param root: the root path
    :param path_elems: further path elements to join onto ``root``
    :param is_dir: whether the resulting path refers to a directory (as opposed to a file)
    :param make_dirs: whether to create (including missing parents) the directory itself if ``is_dir``,
        or else the directory containing the file; the file itself is never created
    :return: the joined path
    """
    path = os.path.join(root, *path_elems)
    if make_dirs:
        dir_path = path if is_dir else os.path.dirname(path)
        # A path without a directory component (e.g. a bare relative file name) yields an empty
        # dir_path, which os.makedirs rejects; it refers to the current working directory, so
        # there is nothing to create.
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
    return path

215 

216 

def create_file_path(root, *path_elems, make_dirs: bool = False) -> str:
    """
    Joins path elements into the path of a file.

    :param root: the root path
    :param path_elems: further path elements to join onto ``root``
    :param make_dirs: whether to create the directory containing the file (not the file itself)
    :return: the joined path
    """
    return create_path(root, *path_elems, make_dirs=make_dirs, is_dir=False)

219 

220 

def create_dir_path(root, *path_elems, make_dirs: bool = False) -> str:
    """
    Joins path elements into the path of a directory.

    :param root: the root path
    :param path_elems: further path elements to join onto ``root``
    :param make_dirs: whether to create the directory (including missing parents)
    :return: the joined path
    """
    return create_path(root, *path_elems, make_dirs=make_dirs, is_dir=True)