Coverage for src/sensai/util/io.py: 26%
130 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1import io
2import logging
3import os
4from typing import Sequence, Optional, Tuple, List, Any, TYPE_CHECKING
6if TYPE_CHECKING:
7 from matplotlib import pyplot as plt
8 import pandas as pd
# module-level logger, named after this module
log = logging.getLogger(__name__)
class ResultWriter:
    """
    Writes result artifacts (text files, data frames, figures, pickles) to a result directory,
    prepending a common prefix to every filename.

    If the writer is disabled, the ``write_*`` methods create nothing on disk but still return the path
    that would have been written.
    """
    # same logger as the module-level logger's child for this class, obtained without rebinding a
    # global name inside the class body
    log = logging.getLogger(__name__).getChild(__qualname__)

    def __init__(self, result_dir, filename_prefix="", enabled: bool = True, close_figures: bool = False):
        """
        :param result_dir: the directory in which to store results; it is created if the writer is enabled
        :param filename_prefix: the prefix which is prepended to every filename
        :param enabled: whether the result writer is enabled; if it is not, it will create neither files nor directories
        :param close_figures: whether to close figures that are passed by default
        """
        self.result_dir = result_dir
        self.filename_prefix = filename_prefix
        self.enabled = enabled
        self.close_figures_default = close_figures
        if self.enabled:
            os.makedirs(result_dir, exist_ok=True)

    def child_with_added_prefix(self, prefix: str) -> "ResultWriter":
        """
        Creates a derived result writer with an added prefix, i.e. the given prefix is appended to this
        result writer's prefix

        :param prefix: the prefix to append
        :return: a new writer instance
        """
        return ResultWriter(self.result_dir, filename_prefix=self.filename_prefix + prefix, enabled=self.enabled,
            close_figures=self.close_figures_default)

    def child_for_subdirectory(self, dir_name: str) -> "ResultWriter":
        """
        Creates a derived result writer which writes to a subdirectory of this writer's result directory,
        retaining the prefix and all other settings

        :param dir_name: the name of the subdirectory
        :return: a new writer instance
        """
        result_dir = os.path.join(self.result_dir, dir_name)
        return ResultWriter(result_dir, filename_prefix=self.filename_prefix, enabled=self.enabled,
            close_figures=self.close_figures_default)

    def path(self, filename_suffix: str, extension_to_add=None, valid_other_extensions: Optional[Sequence[str]] = None) -> str:
        """
        :param filename_suffix: the suffix to add (which may or may not already include a file extension)
        :param extension_to_add: if not None, the file extension to add (without the leading ".") unless
            the extension to add or one of the extensions in valid_other_extensions is already present
        :param valid_other_extensions: a sequence of valid other extensions (without the "."), only
            relevant if extension_to_add is specified; a single string is treated as one extension
        :return: the full path
        """
        if extension_to_add is not None:
            if valid_other_extensions is None:
                other_extensions = ()
            elif isinstance(valid_other_extensions, str):
                # a bare string is one extension, not a sequence of one-character extensions
                other_extensions = (valid_other_extensions,)
            else:
                other_extensions = tuple(valid_other_extensions)
            accepted_endings = tuple("." + ext for ext in (extension_to_add, *other_extensions))
            if not filename_suffix.endswith(accepted_endings):
                filename_suffix += "." + extension_to_add
        return os.path.join(self.result_dir, f"{self.filename_prefix}{filename_suffix}")

    def write_text_file(self, filename_suffix: str, content: str):
        """
        :param filename_suffix: the filename suffix; the extension "txt" is added if not already present
        :param content: the text to write
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        p = self.path(filename_suffix, extension_to_add="txt")
        if self.enabled:
            self.log.info(f"Saving text file {p}")
            with open(p, "w") as f:
                f.write(content)
        return p

    def write_text_file_lines(self, filename_suffix: str, lines: List[str]):
        """
        :param filename_suffix: the filename suffix; the extension "txt" is added if not already present
        :param lines: the lines to write (without trailing newlines)
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        p = self.path(filename_suffix, extension_to_add="txt")
        if self.enabled:
            self.log.info(f"Saving text file {p}")
            # resolves to the module-level function of the same name, not to this method
            write_text_file_lines(lines, p)
        return p

    def write_data_frame_text_file(self, filename_suffix: str, df: "pd.DataFrame"):
        """
        Writes the string representation of the given data frame (``df.to_string()``) to a text file

        :param filename_suffix: the filename suffix; the extension "df.txt" is added unless the suffix
            already ends with ".df.txt" or ".txt"
        :param df: the data frame
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        # the other extension must be passed as a sequence; a bare "txt" would be split into characters
        p = self.path(filename_suffix, extension_to_add="df.txt", valid_other_extensions=("txt",))
        if self.enabled:
            self.log.info(f"Saving data frame text file {p}")
            with open(p, "w") as f:
                f.write(df.to_string())
        return p

    def write_data_frame_csv_file(self, filename_suffix: str, df: "pd.DataFrame", index=True, header=True):
        """
        :param filename_suffix: the filename suffix; the extension "csv" is added if not already present
        :param df: the data frame
        :param index: passed on to ``df.to_csv``
        :param header: passed on to ``df.to_csv``
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        p = self.path(filename_suffix, extension_to_add="csv")
        if self.enabled:
            self.log.info(f"Saving data frame CSV file {p}")
            df.to_csv(p, index=index, header=header)
        return p

    def write_figure(self, filename_suffix: str, fig: "plt.Figure", close_figure: Optional[bool] = None):
        """
        :param filename_suffix: the filename suffix, which may or may not include a file extension, valid extensions being {"png", "jpg"}
        :param fig: the figure to save
        :param close_figure: whether to close the figure after having saved it; if None, use default passed at construction
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        p = self.path(filename_suffix, extension_to_add="png", valid_other_extensions=("jpg",))
        if self.enabled:
            self.log.info(f"Saving figure {p}")
            fig.savefig(p, bbox_inches="tight")
            must_close_figure = close_figure if close_figure is not None else self.close_figures_default
            if must_close_figure:
                # imported only when needed, so that a disabled writer does not require matplotlib
                from matplotlib import pyplot as plt
                plt.close(fig)
        return p

    def write_figures(self, figures: Sequence[Tuple[str, "plt.Figure"]], close_figures=False):
        """
        Writes several figures via write_figure

        :param figures: a sequence of pairs (filename suffix, figure)
        :param close_figures: whether to close each figure after having saved it; this value is always
            passed on explicitly, so the default passed at construction is not applied here
        """
        for name, fig in figures:
            self.write_figure(name, fig, close_figure=close_figures)

    def write_pickle(self, filename_suffix: str, obj: Any):
        """
        :param filename_suffix: the filename suffix; the extension "pickle" is added if not already present
        :param obj: the object to pickle
        :return: the path to the file that was written (or would have been written if the writer was enabled)
        """
        from .pickle import dump_pickle
        p = self.path(filename_suffix, extension_to_add="pickle")
        if self.enabled:
            self.log.info(f"Saving pickle {p}")
            dump_pickle(obj, p)
        return p
def write_text_file_lines(lines: List[str], path):
    """
    Writes the given lines to a text file, terminating each of them with a newline.

    :param lines: the lines to write (without a trailing newline, which will be added)
    :param path: the path of the text file to write to
    """
    with open(path, "w") as out_file:
        out_file.writelines(entry + "\n" for entry in lines)
def read_text_file_lines(path, strip=True, skip_empty=True) -> List[str]:
    """
    Reads a text file into a list of lines.

    :param path: the path of the text file to read from
    :param strip: whether to strip each line, removing whitespace/newline characters
    :param skip_empty: whether to skip any lines that are empty (after stripping)
    :return: the list of lines
    """
    with open(path, "r") as in_file:
        entries = in_file.readlines()
    if strip:
        entries = [entry.strip() for entry in entries]
    if skip_empty:
        # without stripping, a line retains its newline character and is thus never considered empty
        entries = [entry for entry in entries if entry != ""]
    return entries
def is_s3_path(path: str):
    """
    :param path: the path to check
    :return: True if the path starts with the prefix "s3://", False otherwise
    """
    s3_prefix = "s3://"
    return path.startswith(s3_prefix)
class S3Object:
    """
    Handle for a single object in S3, addressed by a path of the form ``s3://<bucket>/<object key>``.
    """
    def __init__(self, path):
        """
        :param path: the S3 path, which must start with "s3://" and contain a "/" after the bucket name
        """
        assert is_s3_path(path)
        self.path = path
        # drop the 5-character "s3://" prefix, then split at the first "/" into bucket and object key
        self.bucket, self.object = self.path[5:].split("/", 1)

    class OutputFile:
        """
        Minimal binary file-like object for use in a with-statement: written data is collected in
        memory and stored in the S3 object when the with-block is left without an exception.
        """
        def __init__(self, s3_object: "S3Object"):
            """
            :param s3_object: the object in which the written content shall be stored
            """
            self.s3Object = s3_object
            self.buffer = io.BytesIO()

        def write(self, obj: bytes):
            """
            :param obj: the bytes to append to the in-memory buffer
            """
            self.buffer.write(obj)

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # store the content only on a clean exit: if the with-block raised, the buffer may hold
            # truncated data, which must not be uploaded; the exception itself is not suppressed
            if exc_type is None:
                self.s3Object.put(self.buffer.getvalue())

    def get_file_content(self):
        """
        :return: the result of reading the body of the S3 object
        """
        return self._get_s3_object().get()['Body'].read()

    def open_file(self, mode):
        """
        :param mode: either "rb" or "wb"
        :return: for "rb", an in-memory binary stream holding the object's content; for "wb", an
            OutputFile, which stores the written content in the object when used in a with-statement
        :raises ValueError: if the mode is neither "rb" nor "wb"
        """
        if mode == "rb":
            content = self.get_file_content()
            return io.BytesIO(content)
        if mode == "wb":
            return self.OutputFile(self)
        # raised explicitly rather than via assert, so the check also holds when running with -O
        raise ValueError(mode)

    def put(self, obj: bytes):
        """
        :param obj: the bytes to store as the content of the S3 object
        """
        self._get_s3_object().put(Body=obj)

    def _get_s3_object(self):
        """
        :return: the boto3 resource object for this bucket and object key, obtained from a session
            which uses the profile given in the environment variable AWS_PROFILE (if set)
        """
        import boto3
        session = boto3.session.Session(profile_name=os.getenv("AWS_PROFILE"))
        s3 = session.resource("s3")
        return s3.Bucket(self.bucket).Object(self.object)