Source code for sensai.util.cache_mysql
import enum
import logging
import pickle
import pandas as pd
from .cache import PersistentKeyValueCache, DelayedUpdateHook
log = logging.getLogger(__name__)
class MySQLPersistentKeyValueCache(PersistentKeyValueCache):
"""
    Caches arbitrary values in a MySQL database.
    Keys are always strings at the database level; a key that is not a string is converted to one using str().
"""
    class ValueType(enum.Enum):
"""
The value type to use within the MySQL database.
Note that the binary BLOB types can be used for all Python types that can be pickled, so the lack
of specific types (e.g. for strings) is not a problem.
"""
        # enum values are (SQL data type, is_cached_value_pickled)
        DOUBLE = ("DOUBLE", False)
        """
        for scalar numeric values, which are stored natively (not pickled)
        """
        BLOB = ("BLOB", True)
"""
for Python data types whose pickled representation is up to 64 KB
"""
MEDIUMBLOB = ("MEDIUMBLOB", True)
"""
for Python data types whose pickled representation is up to 16 MB
"""
def __init__(self, host: str, db: str, user: str, pw: str, value_type: ValueType, table_name="cache",
connect_params: dict | None = None, in_memory=False, max_key_length: int = 255, port=3306):
"""
:param host:
:param db:
:param user:
:param pw:
:param value_type: the type of value to store in the cache
:param table_name:
:param connect_params: additional parameters to pass to the pymysql.connect() function (e.g. ssl, etc.)
:param in_memory:
:param max_key_length: maximal length of the cache key string (keys are always strings) stored in the DB
(i.e. the MySQL type is VARCHAR[max_key_length])
:param port: the MySQL server port to connect to
"""
import pymysql
if connect_params is None:
connect_params = {}
self._connect = lambda: pymysql.connect(host=host, database=db, user=user, password=pw, port=port, autocommit=True,
**connect_params)
self._conn = self._connect()
self.table_name = table_name
self.max_key_length = max_key_length
cache_value_sql_type, self.is_cache_value_pickled = value_type.value
cursor = self._conn.cursor()
cursor.execute(f"SHOW TABLES;")
if table_name not in [r[0] for r in cursor.fetchall()]:
log.debug(f"Creating table {table_name}")
cursor.execute(f"CREATE TABLE {table_name} (cache_key VARCHAR({self.max_key_length}) PRIMARY KEY, "
f"cache_value {cache_value_sql_type});")
cursor.close()
self._in_memory_df = None if not in_memory else self._load_table_to_data_frame()
def _cursor(self):
try:
self._conn.ping(reconnect=True)
except Exception as e:
log.error(f"Error while pinging MySQL server: {e}; Reconnecting ...")
self._conn = self._connect()
return self._conn.cursor()
def _load_table_to_data_frame(self):
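        # Note: passing a raw DBAPI connection to pd.read_sql works, but newer
        # pandas versions emit a warning recommending a SQLAlchemy connectable.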
df = pd.read_sql(f"SELECT * FROM {self.table_name};", con=self._conn, index_col="cache_key")
if self.is_cache_value_pickled:
df["cache_value"] = df["cache_value"].apply(pickle.loads)
return df
    def set(self, key, value):
key = str(key)
if len(key) > self.max_key_length:
raise ValueError(f"Key too long, maximal key length is {self.max_key_length}")
cursor = self._cursor()
cursor.execute(f"SELECT COUNT(*) FROM {self.table_name} WHERE cache_key=%s", (key,))
stored_value = pickle.dumps(value) if self.is_cache_value_pickled else value
if cursor.fetchone()[0] == 0:
from pymysql.err import IntegrityError
try:
cursor.execute(f"INSERT INTO {self.table_name} (cache_key, cache_value) VALUES (%s, %s)",
(key, stored_value))
except IntegrityError as e:
if e.args[0] == 1062: # Duplicate entry
                    # this can only happen when two clients insert a value for the same key almost simultaneously (race condition)
args = list(e.args)
args[1] = f"{args[1]}; The duplicate entry is due to quasi-simultaneous insertions for the same key; " \
"Check your application logic!"
raise IntegrityError(*args)
else:
raise
else:
cursor.execute(f"UPDATE {self.table_name} SET cache_value=%s WHERE cache_key=%s", (stored_value, key))
cursor.close()
if self._in_memory_df is not None:
self._in_memory_df["cache_value"][str(key)] = value
    def get(self, key):
value = self._get_from_in_memory_df(key)
if value is None:
value = self._get_from_table(key)
return value
def _get_from_table(self, key):
cursor = self._cursor()
cursor.execute(f"SELECT cache_value FROM {self.table_name} WHERE cache_key=%s", (str(key),))
row = cursor.fetchone()
cursor.close()
if row is None:
return None
stored_value = row[0]
value = pickle.loads(stored_value) if self.is_cache_value_pickled else stored_value
return value
def _get_from_in_memory_df(self, key):
if self._in_memory_df is None:
return None
try:
return self._in_memory_df["cache_value"][str(key)]
except Exception as e:
log.debug(f"Unable to load value for key {str(key)} from in-memory dataframe: {e}")
return None
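
if __name__ == "__main__":
    # A minimal, hedged demo (the credentials and database name below are hypothetical;
    # this assumes a reachable MySQL server with a database named "test"):
    cache = MySQLPersistentKeyValueCache("localhost", "test", "user", "secret",
        value_type=MySQLPersistentKeyValueCache.ValueType.MEDIUMBLOB)
    cache.set("df", pd.DataFrame({"a": [1, 2, 3]}))  # the DataFrame is pickled into a MEDIUMBLOB
    print(cache.get("df"))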