Coverage for src/sensai/util/cache_mysql.py: 24%

70 statements  

coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

import enum
import logging
import pickle

import pandas as pd

from .cache import PersistentKeyValueCache, DelayedUpdateHook

log = logging.getLogger(__name__)


class MySQLPersistentKeyValueCache(PersistentKeyValueCache):
    """A key-value cache backed by a MySQL table, with deferred commits and an optional in-memory mirror"""

    class ValueType(enum.Enum):
        DOUBLE = ("DOUBLE", False)  # (SQL data type, is_cache_value_pickled)
        BLOB = ("BLOB", True)

    def __init__(self, host, db, user, pw, value_type: ValueType, table_name="cache", deferred_commit_delay_secs=1.0, in_memory=False):
        import MySQLdb  # imported lazily so the MySQL driver is only required when this cache is actually used
        self.conn = MySQLdb.connect(host=host, database=db, user=user, password=pw)
        self.table_name = table_name
        self.max_key_length = 255
        self._update_hook = DelayedUpdateHook(self._commit, deferred_commit_delay_secs)
        self._num_entries_to_be_committed = 0

        cache_value_sql_type, self.is_cache_value_pickled = value_type.value

        # create the cache table if it does not exist yet
        cursor = self.conn.cursor()
        cursor.execute("SHOW TABLES;")
        if table_name not in [r[0] for r in cursor.fetchall()]:
            cursor.execute(f"CREATE TABLE {table_name} (cache_key VARCHAR({self.max_key_length}) PRIMARY KEY, "
                f"cache_value {cache_value_sql_type});")
        cursor.close()

        self._in_memory_df = None if not in_memory else self._load_table_to_data_frame()

    def _load_table_to_data_frame(self):
        df = pd.read_sql(f"SELECT * FROM {self.table_name};", con=self.conn, index_col="cache_key")
        if self.is_cache_value_pickled:
            df["cache_value"] = df["cache_value"].apply(pickle.loads)
        return df

    def set(self, key, value):
        key = str(key)
        if len(key) > self.max_key_length:
            raise ValueError(f"Key too long, maximal key length is {self.max_key_length}")
        cursor = self.conn.cursor()
        cursor.execute(f"SELECT COUNT(*) FROM {self.table_name} WHERE cache_key=%s", (key,))
        stored_value = pickle.dumps(value) if self.is_cache_value_pickled else value
        if cursor.fetchone()[0] == 0:
            cursor.execute(f"INSERT INTO {self.table_name} (cache_key, cache_value) VALUES (%s, %s)",
                (key, stored_value))
        else:
            cursor.execute(f"UPDATE {self.table_name} SET cache_value=%s WHERE cache_key=%s", (stored_value, key))
        self._num_entries_to_be_committed += 1
        self._update_hook.handle_update()  # schedules a deferred commit
        cursor.close()
        if self._in_memory_df is not None:
            # use .loc rather than chained indexing, which would not enlarge the frame for new keys
            self._in_memory_df.loc[key, "cache_value"] = value

    def get(self, key):
        value = self._get_from_in_memory_df(key)
        if value is None:
            value = self._get_from_table(key)
        return value

    def _get_from_table(self, key):
        cursor = self.conn.cursor()
        cursor.execute(f"SELECT cache_value FROM {self.table_name} WHERE cache_key=%s", (str(key),))
        row = cursor.fetchone()
        if row is None:
            return None
        stored_value = row[0]
        value = pickle.loads(stored_value) if self.is_cache_value_pickled else stored_value
        return value

    def _get_from_in_memory_df(self, key):
        if self._in_memory_df is None:
            return None
        try:
            return self._in_memory_df["cache_value"][str(key)]
        except Exception as e:
            log.debug(f"Unable to load value for key {str(key)} from in-memory dataframe: {e}")
            return None

    def _commit(self):
        log.info(f"Committing {self._num_entries_to_be_committed} cache entries to the database")
        self.conn.commit()
        self._num_entries_to_be_committed = 0
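
For illustration, a minimal usage sketch of the class above. The connection parameters and table name are hypothetical placeholders, the MySQLdb module is provided by the mysqlclient package, and the import path is assumed from the file path shown in the header:

# Minimal usage sketch; credentials and table name are hypothetical placeholders.
from sensai.util.cache_mysql import MySQLPersistentKeyValueCache

cache = MySQLPersistentKeyValueCache(host="localhost", db="experiments", user="cache_user", pw="secret",
    value_type=MySQLPersistentKeyValueCache.ValueType.DOUBLE,  # store plain numeric values, no pickling
    table_name="score_cache", deferred_commit_delay_secs=1.0)
cache.set("model1-fold3", 0.873)   # INSERT/UPDATE executes immediately; COMMIT is deferred
print(cache.get("model1-fold3"))   # the same connection sees its own uncommitted write

The deferred commit via DelayedUpdateHook batches bursts of set() calls into a single transaction, trading durability within deferred_commit_delay_secs for fewer commits.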