Source code for sensai.util.cache_mysql
import enum
import logging
import pickle

import pandas as pd

from .cache import PersistentKeyValueCache, DelayedUpdateHook

log = logging.getLogger(__name__)

class MySQLPersistentKeyValueCache(PersistentKeyValueCache):
    """
    Persistent key-value cache backed by a MySQL database. Keys are stored as strings of up to 255 characters;
    values are stored either as DOUBLE or as pickled BLOBs, depending on the chosen ValueType.
    """
    class ValueType(enum.Enum):
        DOUBLE = ("DOUBLE", False)  # (SQL data type, isCachedValuePickled)
        BLOB = ("BLOB", True)

    def __init__(self, host, db, user, pw, value_type: ValueType, table_name="cache", deferred_commit_delay_secs=1.0,
            in_memory=False):
        import MySQLdb  # imported locally so the MySQL dependency is only required when this cache is actually used
        self.conn = MySQLdb.connect(host=host, database=db, user=user, password=pw)
        self.table_name = table_name
        self.max_key_length = 255
        self._update_hook = DelayedUpdateHook(self._commit, deferred_commit_delay_secs)
        self._num_entries_to_be_committed = 0
        cache_value_sql_type, self.is_cache_value_pickled = value_type.value

        # create the cache table if it does not exist yet
        cursor = self.conn.cursor()
        cursor.execute("SHOW TABLES;")
        if table_name not in [r[0] for r in cursor.fetchall()]:
            cursor.execute(f"CREATE TABLE {table_name} (cache_key VARCHAR({self.max_key_length}) PRIMARY KEY, "
                f"cache_value {cache_value_sql_type});")
        cursor.close()

        self._in_memory_df = None if not in_memory else self._load_table_to_data_frame()

    def _load_table_to_data_frame(self):
        df = pd.read_sql(f"SELECT * FROM {self.table_name};", con=self.conn, index_col="cache_key")
        if self.is_cache_value_pickled:
            df["cache_value"] = df["cache_value"].apply(pickle.loads)
        return df

    def set(self, key, value):
        key = str(key)
        if len(key) > self.max_key_length:
            raise ValueError(f"Key too long, maximal key length is {self.max_key_length}")
        cursor = self.conn.cursor()
        cursor.execute(f"SELECT COUNT(*) FROM {self.table_name} WHERE cache_key=%s", (key,))
        stored_value = pickle.dumps(value) if self.is_cache_value_pickled else value
        if cursor.fetchone()[0] == 0:
            cursor.execute(f"INSERT INTO {self.table_name} (cache_key, cache_value) VALUES (%s, %s)",
                (key, stored_value))
        else:
            cursor.execute(f"UPDATE {self.table_name} SET cache_value=%s WHERE cache_key=%s", (stored_value, key))
        self._num_entries_to_be_committed += 1
        self._update_hook.handle_update()  # schedules the deferred commit
        cursor.close()
        if self._in_memory_df is not None:
            self._in_memory_df.loc[key, "cache_value"] = value  # .loc avoids chained assignment on the in-memory copy

    def get(self, key):
        value = self._get_from_in_memory_df(key)
        if value is None:
            value = self._get_from_table(key)
        return value

    def _get_from_table(self, key):
        cursor = self.conn.cursor()
        cursor.execute(f"SELECT cache_value FROM {self.table_name} WHERE cache_key=%s", (str(key),))
        row = cursor.fetchone()
        if row is None:
            return None
        stored_value = row[0]
        value = pickle.loads(stored_value) if self.is_cache_value_pickled else stored_value
        return value

    def _get_from_in_memory_df(self, key):
        if self._in_memory_df is None:
            return None
        try:
            return self._in_memory_df["cache_value"][str(key)]
        except Exception as e:
            log.debug(f"Unable to load value for key {str(key)} from in-memory dataframe: {e}")
            return None

    def _commit(self):
        log.info(f"Committing {self._num_entries_to_be_committed} cache entries to the database")
        self.conn.commit()
        self._num_entries_to_be_committed = 0
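
Below is a minimal usage sketch (not part of the module). The connection parameters, table name and cached values are hypothetical, and a reachable MySQL server with matching credentials is assumed.

# Hypothetical example: cache arbitrary Python objects as pickled BLOBs in database "mydb".
cache = MySQLPersistentKeyValueCache(host="localhost", db="mydb", user="cache_user", pw="secret",
    value_type=MySQLPersistentKeyValueCache.ValueType.BLOB, table_name="model_results")

cache.set("experiment-42", {"rmse": 0.17, "r2": 0.93})  # pickled and written; the commit itself is deferred
print(cache.get("experiment-42"))  # returns the stored dict, or None if the key is absent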