Coverage for src/sensai/util/io.py: 26%
141 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
1import io
2import logging
3import os
4from typing import Sequence, Optional, Tuple, List, Any, TYPE_CHECKING
6if TYPE_CHECKING:
7 from matplotlib import pyplot as plt
8 import pandas as pd
# module-level logger, named after this module
log = logging.getLogger(name=__name__)
class ResultWriter:
    """
    Writes result files (text files, data frames, figures, pickled objects) to a result directory.
    Each file name is composed of the writer's filename prefix followed by a per-call suffix.
    All write methods return the path of the file that was written (or that would have been
    written had the writer been enabled).
    """
    # logger named "<module name>.ResultWriter", i.e. a child of the module's logger
    log = logging.getLogger(__name__).getChild(__qualname__)

    def __init__(self, result_dir, filename_prefix="", enabled: bool = True, close_figures: bool = False):
        """
        :param result_dir: the directory in which to store result files; if the writer is enabled, it is created
            (including any missing parent directories)
        :param filename_prefix: the prefix to prepend to the names of all files that are written
        :param enabled: whether the result writer is enabled; if it is not, it will create neither files nor directories
        :param close_figures: whether to close figures that are passed by default
        """
        self.result_dir = result_dir
        self.filename_prefix = filename_prefix
        self.enabled = enabled
        self.close_figures_default = close_figures
        if self.enabled:
            os.makedirs(result_dir, exist_ok=True)

    def child_with_added_prefix(self, prefix: str) -> "ResultWriter":
        """
        Creates a derived result writer with an added prefix, i.e. the given prefix is appended to this
        result writer's prefix

        :param prefix: the prefix to append
        :return: a new writer instance (writing to the same directory with the same settings)
        """
        return ResultWriter(self.result_dir, filename_prefix=self.filename_prefix + prefix, enabled=self.enabled,
            close_figures=self.close_figures_default)

    def child_for_subdirectory(self, dir_name: str) -> "ResultWriter":
        """
        Creates a derived result writer which writes to a subdirectory of this writer's result directory,
        retaining this writer's prefix and settings

        :param dir_name: the name of the subdirectory (created if the writer is enabled)
        :return: a new writer instance
        """
        result_dir = os.path.join(self.result_dir, dir_name)
        return ResultWriter(result_dir, filename_prefix=self.filename_prefix, enabled=self.enabled,
            close_figures=self.close_figures_default)

    def path(self, filename_suffix: str, extension_to_add=None, valid_other_extensions: Optional[Sequence[str]] = None) -> str:
        """
        Builds the full path of a result file.

        :param filename_suffix: the suffix to add (which may or may not already include a file extension);
            occurrences of ">=" and ">" are replaced with "gte" and "gt" respectively
        :param extension_to_add: if not None, the file extension to add (without the leading ".") unless
            the extension to add or one of the extensions in valid_other_extensions is already present
        :param valid_other_extensions: a sequence of valid other extensions (without the leading "."), only
            relevant if extension_to_add is specified; a single string is treated as a single extension
        :return: the full path
        """
        # replace forbidden characters
        filename_suffix = filename_suffix.replace(">=", "gte").replace(">", "gt")
        if extension_to_add is not None:
            if valid_other_extensions is None:
                valid_other_extensions = ()
            elif isinstance(valid_other_extensions, str):
                # a bare string would otherwise be iterated character by character
                valid_other_extensions = (valid_other_extensions,)
            valid_suffixes = tuple("." + ext for ext in (extension_to_add, *valid_other_extensions))
            if not filename_suffix.endswith(valid_suffixes):
                filename_suffix += "." + extension_to_add
        return os.path.join(self.result_dir, f"{self.filename_prefix}{filename_suffix}")

    def write_text_file(self, filename_suffix: str, content: str):
        """
        :param filename_suffix: the filename suffix ("txt" extension added unless present)
        :param content: the text to write
        :return: the path to the file
        """
        p = self.path(filename_suffix, extension_to_add="txt")
        if self.enabled:
            self.log.info(f"Saving text file {p}")
            with open(p, "w") as f:
                f.write(content)
        return p

    def write_text_file_lines(self, filename_suffix: str, lines: List[str]):
        """
        :param filename_suffix: the filename suffix ("txt" extension added unless present)
        :param lines: the lines to write (without trailing newlines, which are added)
        :return: the path to the file
        """
        p = self.path(filename_suffix, extension_to_add="txt")
        if self.enabled:
            self.log.info(f"Saving text file {p}")
            write_text_file_lines(lines, p)
        return p

    def write_data_frame_text_file(self, filename_suffix: str, df: "pd.DataFrame"):
        """
        Writes the string representation of a data frame to a text file.

        :param filename_suffix: the filename suffix; the extension "df.txt" is added unless the suffix already
            ends with ".df.txt" or ".txt"
        :param df: the data frame to write
        :return: the path to the file
        """
        # valid_other_extensions must be a sequence of extensions (passing the bare string "txt" would
        # be interpreted as the character set {"t", "x"})
        p = self.path(filename_suffix, extension_to_add="df.txt", valid_other_extensions=("txt",))
        if self.enabled:
            self.log.info(f"Saving data frame text file {p}")
            with open(p, "w") as f:
                f.write(df.to_string())
        return p

    def write_data_frame_csv_file(self, filename_suffix: str, df: "pd.DataFrame", index=True, header=True):
        """
        :param filename_suffix: the filename suffix ("csv" extension added unless present)
        :param df: the data frame to write
        :param index: whether to write the data frame's index
        :param header: whether to write the column names
        :return: the path to the file
        """
        p = self.path(filename_suffix, extension_to_add="csv")
        if self.enabled:
            self.log.info(f"Saving data frame CSV file {p}")
            df.to_csv(p, index=index, header=header)
        return p

    def write_figure(self, filename_suffix: str, fig: "plt.Figure", close_figure: Optional[bool] = None):
        """
        :param filename_suffix: the filename suffix, which may or may not include a file extension, valid extensions being {"png", "jpg"}
        :param fig: the figure to save
        :param close_figure: whether to close the figure after having saved it; if None, use default passed at construction
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        from matplotlib import pyplot as plt
        p = self.path(filename_suffix, extension_to_add="png", valid_other_extensions=("jpg",))
        if self.enabled:
            self.log.info(f"Saving figure {p}")
            fig.savefig(p, bbox_inches="tight")
        must_close_figure = close_figure if close_figure is not None else self.close_figures_default
        if must_close_figure:
            plt.close(fig)
        return p

    def write_figures(self, figures: Sequence[Tuple[str, "plt.Figure"]], close_figures: Optional[bool] = None) -> List[str]:
        """
        :param figures: a sequence of (filename suffix, figure) pairs
        :param close_figures: whether to close the figures after having saved them; if None, use default passed
            at construction (consistent with write_figure)
        :return: the paths to the files that were written (or would have been written if the writer was enabled)
        """
        return [self.write_figure(name, fig, close_figure=close_figures) for name, fig in figures]

    def write_pickle(self, filename_suffix: str, obj: Any):
        """
        :param filename_suffix: the filename suffix ("pickle" extension added unless present)
        :param obj: the object to pickle
        :return: the path to the file
        """
        from .pickle import dump_pickle
        p = self.path(filename_suffix, extension_to_add="pickle")
        if self.enabled:
            self.log.info(f"Saving pickle {p}")
            dump_pickle(obj, p)
        return p
def write_text_file_lines(lines: List[str], path):
    """
    Writes the given lines to a text file (overwriting any existing file), terminating each line
    with a newline character.

    :param lines: the lines to write (without a trailing newline, which will be added)
    :param path: the path of the text file to write to
    """
    with open(path, "w") as out_file:
        out_file.writelines(line + "\n" for line in lines)
def read_text_file_lines(path, strip=True, skip_empty=True) -> List[str]:
    """
    Reads the lines of a text file.

    :param path: the path of the text file to read from
    :param strip: whether to strip each line, removing leading/trailing whitespace (including the newline character)
    :param skip_empty: whether to skip any lines that are empty (after stripping); note that without stripping,
        lines retain their line terminator and are therefore never empty
    :return: the list of lines
    """
    with open(path, "r") as f:
        # iterate the file lazily rather than materialising all raw lines via readlines()
        lines = (line.strip() if strip else line for line in f)
        return [line for line in lines if line or not skip_empty]
def is_s3_path(path: str):
    """
    :param path: a path or URI
    :return: whether the path starts with the (case-sensitive) scheme prefix "s3://"
    """
    scheme = "s3://"
    return path.startswith(scheme)
class S3Object:
    """
    An object in an S3 bucket, identified by a path of the form "s3://<bucket>/<object key>",
    whose content can be read and written as bytes.
    """

    def __init__(self, path):
        """
        :param path: the S3 path, which must start with "s3://" and contain a "/" separating the bucket name
            from the object key
        :raises ValueError: if the path is not a valid S3 path
        """
        if not is_s3_path(path):
            raise ValueError(f"Not an S3 path: {path!r}")
        self.path = path
        # strip the 5-character "s3://" prefix, then split off the bucket name at the first "/"
        bucket, separator, object_key = self.path[5:].partition("/")
        if not separator:
            raise ValueError(f"S3 path lacks an object key: {path!r}")
        self.bucket, self.object = bucket, object_key

    class OutputFile:
        """
        A write-only binary file-like object which buffers all written bytes in memory and uploads them
        to the associated S3 object when a `with` block using it exits without an exception.
        """

        def __init__(self, s3_object: "S3Object"):
            self.s3Object = s3_object
            self.buffer = io.BytesIO()

        def write(self, obj: bytes) -> int:
            """
            :param obj: the bytes to append to the buffer
            :return: the number of bytes written
            """
            return self.buffer.write(obj)

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # only upload if the block completed successfully; otherwise partially written content
            # would overwrite the S3 object. Returning None lets any exception propagate.
            if exc_type is None:
                self.s3Object.put(self.buffer.getvalue())

    def get_file_content(self):
        """
        :return: the object's full content (as returned by reading the 'Body' of the boto3 object's get response)
        """
        return self._get_s3_object().get()['Body'].read()

    def open_file(self, mode):
        """
        :param mode: "rb" to read the object's current content, or "wb" to write new content
        :return: for "rb", an in-memory binary stream containing the content (downloaded eagerly);
            for "wb", an OutputFile, which uploads its content only upon exiting a `with` block
        :raises ValueError: if the mode is neither "rb" nor "wb"
        """
        if mode == "rb":
            return io.BytesIO(self.get_file_content())
        if mode == "wb":
            return self.OutputFile(self)
        raise ValueError(mode)

    def put(self, obj: bytes):
        """
        :param obj: the bytes to store as the object's content
        """
        self._get_s3_object().put(Body=obj)

    def _get_s3_object(self):
        # boto3 is imported lazily so that it is only required when S3 is actually accessed;
        # the session uses the AWS profile given by the AWS_PROFILE environment variable (if set)
        import boto3
        session = boto3.session.Session(profile_name=os.getenv("AWS_PROFILE"))
        s3 = session.resource("s3")
        return s3.Bucket(self.bucket).Object(self.object)
def create_path(root: str, *path_elems: str, is_dir: bool, make_dirs: bool = False) -> str:
    """
    Joins path elements and optionally creates the corresponding directory.

    :param root: the root path
    :param path_elems: further path elements to join onto root
    :param is_dir: whether the resulting path designates a directory (as opposed to a file)
    :param make_dirs: whether to create the directory (if is_dir) or the file's parent directory (if not is_dir),
        including missing intermediate directories; already existing directories are accepted
    :return: the joined path
    """
    path = os.path.join(root, *path_elems)
    if make_dirs:
        dir_path = path if is_dir else os.path.dirname(path)
        # a bare relative name has no directory component (dir_path == ""), which refers to the
        # current directory; os.makedirs("") would raise FileNotFoundError
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
    return path
def create_file_path(root, *path_elems, make_dirs: bool = False) -> str:
    """
    Joins path elements into a file path by delegating to create_path with is_dir=False.

    :param root: the root path
    :param path_elems: further path elements, the last of which is presumably the file name
    :param make_dirs: whether to create the file's parent directory
    :return: the joined file path
    """
    return create_path(root, *path_elems, make_dirs=make_dirs, is_dir=False)
def create_dir_path(root, *path_elems, make_dirs: bool = False) -> str:
    """
    Joins path elements into a directory path by delegating to create_path with is_dir=True.

    :param root: the root path
    :param path_elems: further path elements
    :param make_dirs: whether to create the directory
    :return: the joined directory path
    """
    return create_path(root, *path_elems, make_dirs=make_dirs, is_dir=True)