diff --git a/src/memmachine/common/relational_table_storage/relational_data_store.py b/src/memmachine/common/relational_table_storage/relational_data_store.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/memmachine/configuration/episodic_config.py b/src/memmachine/configuration/episodic_config.py
new file mode 100644
index 000000000..a3a92b2ef
--- /dev/null
+++ b/src/memmachine/configuration/episodic_config.py
@@ -0,0 +1,128 @@
+import string
+from typing import Annotated
+
+from pydantic import BaseModel, Field, InstanceOf, field_validator, model_validator
+
+from memmachine.common.embedder import Embedder
+from memmachine.common.language_model import LanguageModel
+from memmachine.common.metrics_factory import MetricsFactory
+from memmachine.common.reranker import Reranker
+from memmachine.common.vector_graph_store import VectorGraphStore
+from memmachine.session_manager_interface import SessionDataManager
+
+
+class ShortTermMemoryParams(BaseModel):
+    """
+    Parameters for configuring the short-term memory.
+
+    Attributes:
+        session_key (str): The unique identifier for the session.
+        llm_model (LanguageModel): The language model to use for summarization.
+        data_manager (SessionDataManager): The session data manager.
+        summary_prompt_system (str): The system prompt for summarization.
+        summary_prompt_user (str): The user prompt for summarization.
+        message_capacity (int): The maximum total length, in characters, of
+            the stored messages and summary.
+        enabled (bool): Whether the short-term memory is enabled.
+    """
+
+    session_key: Annotated[str, Field(..., min_length=1)]
+    llm_model: InstanceOf[LanguageModel] = Field(..., description="The language model to use for summarization")
+    data_manager: InstanceOf[SessionDataManager] | None = Field(default=None, description="The session data manager")
+    summary_prompt_system: Annotated[str, Field(..., min_length=1, description="The system prompt for summarization")]
+    summary_prompt_user: Annotated[str, Field(..., min_length=1, description="The user prompt for summarization")]
+    message_capacity: Annotated[int, Field(default=64000, gt=0, description="The maximum length of short-term memory in characters")]
+    enabled: bool = True
+
+    @field_validator("summary_prompt_user")
+    def validate_summary_user_prompt(cls, v):
+        fields = [fname for _, fname, _, _ in string.Formatter().parse(v) if fname]
+        if len(fields) != 3:
+            raise ValueError(f"Expected exactly 3 format fields in {v}")
+        if "episodes" not in fields:
+            raise ValueError(f"Expected an 'episodes' field in {v}")
+        if "summary" not in fields:
+            raise ValueError(f"Expected a 'summary' field in {v}")
+        if "max_length" not in fields:
+            raise ValueError(f"Expected a 'max_length' field in {v}")
+        return v
+
+
+class LongTermMemoryParams(BaseModel):
+    """
+    Parameters for DeclarativeMemory.
+
+    Attributes:
+        session_id (str):
+            Session identifier.
+ max_chunk_length (int): + Maximum length of a chunk in characters + (default: 1000). + vector_graph_store (VectorGraphStore): + VectorGraphStore instance + for storing and retrieving memories. + embedder (Embedder): + Embedder instance for creating embeddings. + reranker (Reranker): + Reranker instance for reranking search results. + """ + + session_id: str = Field( + ..., + description="Session identifier", + ) + max_chunk_length: int = Field( + 1000, + description="Maximum length of a chunk in characters.", + gt=0, + ) + vector_graph_store: InstanceOf[VectorGraphStore] = Field( + ..., + description="VectorGraphStore instance for storing and retrieving memories", + ) + embedder: InstanceOf[Embedder] = Field( + ..., + description="Embedder instance for creating embeddings", + ) + reranker: InstanceOf[Reranker] = Field( + ..., + description="Reranker instance for reranking search results", + ) + enabled: bool = True + + +class EpisodicMemoryParams(BaseModel): + """ + Parameters for configuring the EpisodicMemory. + Attributes: + session_key (str): The unique identifier for the session. + metrics_factory (MetricsFactory): The metrics factory. + short_term_memory (ShortTermMemoryParams): The short-term memory parameters. + long_term_memory (LongTermMemoryParams): The long-term memory parameters. + enabled (bool): Whether the episodic memory is enabled. + """ + session_key: Annotated[str, Field(..., min_length=1, description="The unique identifier for the session")] + metrics_factory: InstanceOf[MetricsFactory] = Field(..., description="The metrics factory") + short_term_memory: ShortTermMemoryParams | None = Field(default=None, description="The short-term memory parameters") + long_term_memory: LongTermMemoryParams | None = Field(default=None, description="The long-term memory parameters") + enabled: bool = True + + @model_validator(mode="after") + def validate_memory_params(self): + if self.enabled is False: + return self + if self.short_term_memory is None and self.long_term_memory is None: + raise ValueError( + "At least one of short_term_memory or long_term_memory must be provided." + ) + return self + + +class EpisodicMemoryManagerParam(BaseModel): + """ + Parameters for configuring the EpisodicMemoryManager. + Attributes: + instance_cache_size (int): The maximum number of instances to cache. + max_life_time (int): The maximum idle lifetime of an instance in seconds. + session_storage (SessionDataManager): The session storage. + """ + instance_cache_size: Annotated[int, Field(default=100, gt=0, description="The maximum number of instances to cache")] + max_life_time: Annotated[int, Field(default=600, gt=0, description="The maximum idle lifetime of an instance in seconds")] + session_storage: InstanceOf[SessionDataManager] = Field(..., description="Session storage") diff --git a/src/memmachine/episodic_memory/data_types.py b/src/memmachine/episodic_memory/data_types.py index fd06ce347..f3dacdd61 100644 --- a/src/memmachine/episodic_memory/data_types.py +++ b/src/memmachine/episodic_memory/data_types.py @@ -14,6 +14,12 @@ class ContentType(Enum): STRING = "string" # Other content types like 'vector', 'image' could be added here. 
+class EpisodeType(Enum):
+    """Enumeration for the type of an Episode."""
+
+    MESSAGE = "message"
+    THOUGHT = "thought"
+    ACTION = "action"
+
 
 @dataclass
 class SessionInfo:
@@ -92,7 +98,11 @@ class Episode:
     uuid: UUID
     """A unique identifier (UUID) for the episode."""
-    episode_type: str
+    sequence_num: int
+    """Sequence number of the episode."""
+    session_key: str
+    """The identifier for the session to which this episode belongs."""
+    episode_type: EpisodeType
     """
-    A string indicating the type of the episode
-    (e.g., 'message', 'thought', 'action').
+    The type of the episode
+    (e.g., MESSAGE, THOUGHT, ACTION).
@@ -103,12 +113,10 @@ class Episode:
     """The actual data of the episode, which can be of any type."""
     timestamp: datetime
     """The date and time when the episode occurred."""
-    group_id: str
-    """Identifier for the group (e.g., a specific chat room or DM)."""
-    session_id: str
-    """Identifier for the session to which this episode belongs."""
     producer_id: str
     """The identifier of the user or agent that created this episode."""
+    producer_role: str
+    """The role of the producer (e.g., 'HR', 'agent', 'engineer')."""
     produced_for_id: str | None = None
    """The identifier of the intended recipient, if any."""
     user_metadata: JSONValue = None
diff --git a/src/memmachine/episodic_memory/episodic_memory.py b/src/memmachine/episodic_memory/episodic_memory.py
index 02cc3b20c..30043cc09 100644
--- a/src/memmachine/episodic_memory/episodic_memory.py
+++ b/src/memmachine/episodic_memory/episodic_memory.py
@@ -17,22 +17,18 @@
 """
 
 import asyncio
-import copy
 import logging
+import time
 import uuid
 from datetime import datetime
 from typing import cast
 
-from memmachine.common.language_model.language_model_builder import (
-    LanguageModelBuilder,
-)
-from memmachine.common.metrics_factory.metrics_factory_builder import (
-    MetricsFactoryBuilder,
-)
+from memmachine.configuration.episodic_config import EpisodicMemoryParams
 
-from .data_types import ContentType, Episode, MemoryContext
+from .data_types import ContentType, Episode, EpisodeType
 from .long_term_memory.long_term_memory import LongTermMemory
-from .short_term_memory.session_memory import SessionMemory
+from .short_term_memory.short_term_memory import ShortTermMemory
 
 logger = logging.getLogger(__name__)
@@ -40,81 +36,39 @@
 
 class EpisodicMemory:
     # pylint: disable=too-many-instance-attributes
     """
-    Represents a single, isolated memory instance for a specific context.
+    Represents a single, isolated memory instance for a specific session.
 
     This class orchestrates the interaction between short-term (session)
     memory and long-term (declarative) memory. It manages the lifecycle
     of the memory, handles adding new information (episodes), and provides
     methods to retrieve contextual information for queries.
 
-    Each instance is tied to a unique `MemoryContext` (defined by group, agent,
-    user, and session IDs) and is managed by a central
-    `EpisodicMemoryManager`.
+    Each instance is tied to a unique session key.
     """
 
-    def __init__(self, manager, config: dict, memory_context: MemoryContext):
+    def __init__(
+        self,
+        param: EpisodicMemoryParams,
+        session_memory: ShortTermMemory | None = None,
+        long_term_memory: LongTermMemory | None = None,
+    ):
         # pylint: disable=too-many-instance-attributes
         """
-        Initializes a EpisodicMemory instance.
+        Initializes an EpisodicMemory instance.
 
         Args:
-            manager: The EpisodicMemoryManager that created this instance.
-            config: A dictionary containing the configuration for this memory
-                    instance.
-            memory_context: The unique context for this memory instance.
-        """
-        self._memory_context = memory_context
-        self._manager = manager  # The manager that created this instance
-        self._lock = asyncio.Lock()  # Lock for thread-safe operations
-
-        model_config = config.get("model")
-        short_config = config.get("sessionmemory", {})
-        long_term_config = config.get("long_term_memory", {})
-
-        self._ref_count = 1  # For reference counting to manage lifecycle
-        self._session_memory: SessionMemory | None = None
-        self._long_term_memory: LongTermMemory | None = None
-        metrics_manager = MetricsFactoryBuilder.build("prometheus", {}, {})
-
-        if len(short_config) > 0 and short_config.get("enabled") != "false":
-            model_name = short_config.get("model_name")
-            if model_name is None or len(model_name) < 1:
-                raise ValueError("Invalid model name")
-
-            if model_config is None or model_config.get(model_name) is None:
-                raise ValueError("Invalid model configuration")
-
-            model_config = copy.deepcopy(model_config.get(model_name))
-            """
-            only support prometheus now.
-            TODO: support different metrics and make it configurable
-            """
-            model_config["metrics_factory_id"] = "prometheus"
-            model_vendor = model_config.pop("model_vendor")
-            metrics_injection = {}
-            metrics_injection["prometheus"] = metrics_manager
-
-            llm_model = LanguageModelBuilder.build(
-                model_vendor,
-                model_config,
-                metrics_injection,
-            )
-
-            # Initialize short-term session memory
-            self._session_memory = SessionMemory(
-                llm_model,
-                config.get("prompts", {}).get("episode_summary_prompt_system"),
-                config.get("prompts", {}).get("episode_summary_prompt_user"),
-                short_config.get("message_capacity", 1000),
-                short_config.get("max_message_length", 128000),
-                short_config.get("max_token_num", 65536),
-                self._memory_context,
-            )
-
-        if len(long_term_config) > 0 and long_term_config.get("enabled") != "false":
-            # Initialize long-term declarative memory
-            self._long_term_memory = LongTermMemory(config, self._memory_context)
-        if self._session_memory is None and self._long_term_memory is None:
+            param: The parameters required to initialize the episodic memory.
+            session_memory: An optional pre-built short-term memory instance.
+            long_term_memory: An optional pre-built long-term memory instance.
+        """
+        self._session_key = param.session_key
+        self._closed = False
+        self._short_term_memory: ShortTermMemory | None = session_memory
+        self._long_term_memory: LongTermMemory | None = long_term_memory
+        metrics_manager = param.metrics_factory
+        self._enabled = param.enabled
+        if not self._enabled:
+            return
+        if self._short_term_memory is None and self._long_term_memory is None:
             raise ValueError("No memory is configured")
 
         # Initialize metrics
@@ -131,24 +85,34 @@ def __init__(self, manager, config: dict, memory_context: MemoryContext):
             "query_count", "Count of query processing"
         )
 
+    @classmethod
+    async def create(cls, param: EpisodicMemoryParams) -> 'EpisodicMemory':
+        """
+        Builds an EpisodicMemory instance, constructing the configured
+        short-term and long-term memories from the given parameters.
+        """
+        session_memory: ShortTermMemory | None = None
+        if param.short_term_memory and param.short_term_memory.enabled:
+            session_memory = await ShortTermMemory.create(param.short_term_memory)
+        long_term_memory: LongTermMemory | None = None
+        if param.long_term_memory and param.long_term_memory.enabled:
+            long_term_memory = LongTermMemory(param.long_term_memory)
+        return cls(param, session_memory, long_term_memory)
+
     @property
-    def short_term_memory(self) -> SessionMemory | None:
+    def short_term_memory(self) -> ShortTermMemory | None:
         """
         Get the short-term memory of the episodic memory instance
         Returns:
            The short-term memory of the episodic memory instance.
""" - return self._session_memory + return self._short_term_memory @short_term_memory.setter - def short_term_memory(self, value: SessionMemory | None): + def short_term_memory(self, value: ShortTermMemory | None): """ Set the short-term memory of the episodic memory instance This makes the short term memory can be injected Args: value: The new short-term memory of the episodic memory instance. """ - self._session_memory = value + self._short_term_memory = value @property def long_term_memory(self) -> LongTermMemory | None: @@ -169,48 +133,18 @@ def long_term_memory(self, value: LongTermMemory | None): """ self._long_term_memory = value - def get_memory_context(self) -> MemoryContext: - """ - Get the memory context of the episodic memory instance - Returns: - The memory context of the episodic memory instance. - """ - return self._memory_context - - def get_reference_count(self) -> int: - """ - Get the reference count of the episodic memory instance - Returns: - The reference count of the episodic memory instance. - """ - return self._ref_count - - async def reference(self) -> bool: + @property + def session_key(self) -> str: """ - Increments the reference count for this instance. - - Used by the manager to track how many clients are actively using this - memory instance. - + Get the session key of the episodic memory instance Returns: - True if the reference was successfully added, False if the instance - is already closed. + The session key of the episodic memory instance. """ - async with self._lock: - if self._ref_count <= 0: - return False - self._ref_count += 1 - return True + return self._session_key async def add_memory_episode( self, - producer: str, - produced_for: str, - episode_content: str | list[float], - episode_type: str, - content_type: ContentType, - timestamp: datetime | None = None, - metadata: dict | None = None, + episode: Episode ): # pylint: disable=too-many-arguments # pylint: disable=too-many-positional-arguments @@ -232,60 +166,25 @@ async def add_memory_episode( Returns: True if the episode was added successfully, False otherwise. 
""" - # Validate that the producer and recipient are part of this memory - # context - if ( - producer not in self._memory_context.user_id - and producer not in self._memory_context.agent_id - ): - logger.error("The producer %s does not belong to the session", producer) - raise ValueError(f"The producer {producer} does not belong to the session") - - if ( - produced_for not in self._memory_context.user_id - and produced_for not in self._memory_context.agent_id - ): - logger.error( - "The produced_for %s does not belong to the session", - produced_for, - ) - raise ValueError( - f"""The produced_for {produced_for} does not belong to - the session""" - ) - - start_time = datetime.now() - - # Create a new Episode object - episode = Episode( - uuid=uuid.uuid4(), - episode_type=episode_type, - content_type=content_type, - content=episode_content, - timestamp=timestamp if timestamp else datetime.now(), - group_id=self._memory_context.group_id, - session_id=self._memory_context.session_id, - producer_id=producer, - produced_for_id=produced_for, - user_metadata=metadata, - ) + if not self._enabled: + return + start_time = time.monotonic_ns() + if self._closed: + raise RuntimeError(f"Memory is closed {self._session_key}") # Add the episode to both memory stores concurrently tasks = [] - if self._session_memory: - tasks.append(self._session_memory.add_episode(episode)) + if self._short_term_memory: + tasks.append(self._short_term_memory.add_episode(episode)) if self._long_term_memory: tasks.append(self._long_term_memory.add_episode(episode)) await asyncio.gather( *tasks, ) - end_time = datetime.now() - delta = end_time - start_time - self._ingestion_latency_summary.observe( - delta.total_seconds() * 1000 + delta.microseconds / 1000 - ) + end_time = time.monotonic_ns() + delta = (end_time - start_time) / 1000000 + self._ingestion_latency_summary.observe(delta) self._ingestion_counter.increment() - return True async def close(self): """ @@ -296,21 +195,30 @@ async def close(self): stores and notifies the manager to remove this instance from its registry. """ - async with self._lock: - self._ref_count -= 1 - if self._ref_count > 0: - return - - # If no more references, proceed with closing - logger.info("Closing context memory: %s", str(self._memory_context)) - tasks = [] - if self._session_memory: - tasks.append(self._session_memory.close()) - await asyncio.gather( - *tasks, - ) - await self._manager.delete_context_memory(self._memory_context) + self._closed = True + if not self._enabled: + return + tasks = [] + if self._short_term_memory: + tasks.append(self._short_term_memory.close()) + if self._long_term_memory: + tasks.append(self._long_term_memory.close()) + await asyncio.gather(*tasks) + return + + + + async def delete_episode(self, uuid: uuid.UUID): + """Delete one episode by uuid""" + if not self._enabled: return + tasks = [] + if self._short_term_memory: + tasks.append(self._short_term_memory.delete_episode(uuid)) + if self._long_term_memory: + tasks.append(self._long_term_memory.delete_episode(uuid)) + await asyncio.gather(*tasks) + return async def delete_data(self): """ @@ -318,14 +226,15 @@ async def delete_data(self): context. This is a destructive operation. 
""" - async with self._lock: - tasks = [] - if self._session_memory: - tasks.append(self._session_memory.clear_memory()) - if self._long_term_memory: - tasks.append(self._long_term_memory.forget_session()) - await asyncio.gather(*tasks) + if not self._enabled: return + tasks = [] + if self._short_term_memory: + tasks.append(self._short_term_memory.clear_memory()) + if self._long_term_memory: + tasks.append(self._long_term_memory.forget_session()) + await asyncio.gather(*tasks) + return async def query_memory( self, @@ -353,40 +262,38 @@ async def query_memory( a list of long term memory Episode objects, and a list of summary strings. """ - - start_time = datetime.now() + if not self._enabled: + return [], [], [] + start_time = time.monotonic_ns() search_limit = limit if limit is not None else 20 if property_filter is None: property_filter = {} - # By default, always allow cross session search - property_filter["group_id"] = self._memory_context.group_id - - async with self._lock: - if self._session_memory is None: - short_episode: list[Episode] = [] - short_summary = "" - long_episode = await cast( - LongTermMemory, self._long_term_memory - ).search( - query, - search_limit, - property_filter, - ) - elif self._long_term_memory is None: - session_result = await self._session_memory.get_session_memory_context( + + if self._short_term_memory is None: + short_episode: list[Episode] = [] + short_summary = "" + long_episode = await cast( + LongTermMemory, self._long_term_memory + ).search( + query, + search_limit, + property_filter, + ) + elif self._long_term_memory is None: + session_result = await self._short_term_memory.get_session_memory_context( + query, limit=search_limit, filter=property_filter + ) + long_episode = [] + short_episode, short_summary = session_result + else: + # Concurrently search both memory stores + session_result, long_episode = await asyncio.gather( + self._short_term_memory.get_session_memory_context( query, limit=search_limit - ) - long_episode = [] - short_episode, short_summary = session_result - else: - # Concurrently search both memory stores - session_result, long_episode = await asyncio.gather( - self._session_memory.get_session_memory_context( - query, limit=search_limit - ), - self._long_term_memory.search(query, search_limit, property_filter), - ) - short_episode, short_summary = session_result + ), + self._long_term_memory.search(query, search_limit, property_filter), + ) + short_episode, short_summary = session_result # Deduplicate episodes from both memory stores, prioritizing # short-term memory @@ -398,11 +305,9 @@ async def query_memory( uuid_set.add(episode.uuid) unique_long_episodes.append(episode) - end_time = datetime.now() - delta = end_time - start_time - self._query_latency_summary.observe( - delta.total_seconds() * 1000 + delta.microseconds / 1000 - ) + end_time = time.monotonic_ns() + delta = (end_time - start_time) / 1000000 + self._query_latency_summary.observe(delta) self._query_counter.increment() return short_episode, unique_long_episodes, [short_summary] @@ -455,41 +360,3 @@ async def formalize_query_with_context( finalized_query += f"\n{query}\n" return finalized_query - - -class AsyncEpisodicMemory: - """ - Asynchronous context manager for EpisodicMemory instances. - - This class provides an `async with` interface for `EpisodicMemory` objects, - ensuring that `reference()` is called upon entry and `close()` is called - upon exit, handling the lifecycle management automatically. 
-    """
-
-    def __init__(self, episodic_memory_instance: EpisodicMemory):
-        """
-        Initializes the AsyncEpisodicMemory context manager.
-
-        Args:
-            episodic_memory_instance: The EpisodicMemory instance to manage.
-        """
-        self.episodic_memory_instance = episodic_memory_instance
-
-    async def __aenter__(self) -> EpisodicMemory:
-        """
-        Enters the asynchronous context.
-
-        Returns:
-            The EpisodicMemory instance.
-
-        """
-        return self.episodic_memory_instance
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        """
-        Exits the asynchronous context.
-
-        Decrements the reference count of the managed EpisodicMemory instance,
-        triggering its closure if the count reaches zero.
-        """
-        await self.episodic_memory_instance.close()
diff --git a/src/memmachine/episodic_memory/short_term_memory/session_memory.py b/src/memmachine/episodic_memory/short_term_memory/short_term_memory.py
similarity index 51%
rename from src/memmachine/episodic_memory/short_term_memory/session_memory.py
rename to src/memmachine/episodic_memory/short_term_memory/short_term_memory.py
index c42dc7287..0c5be1ecc 100644
--- a/src/memmachine/episodic_memory/short_term_memory/session_memory.py
+++ b/src/memmachine/episodic_memory/short_term_memory/short_term_memory.py
@@ -4,93 +4,94 @@
-This module provides the `SessionMemory` class, which is responsible for
+This module provides the `ShortTermMemory` class, which is responsible for
 storing and managing a sequence of conversational turns (episodes) within a
-single session. It uses a deque with a fixed capacity and evicts older
-episodes when memory limits (number of episodes, message length, or token
-count) are reached. Evicted episodes are summarized asynchronously to maintain
-context over a longer conversation.
+single session. It uses a deque and evicts older episodes when the memory
+limit (total message length) is reached. Evicted episodes are summarized
+asynchronously to maintain context over a longer conversation.
 """
 
 import asyncio
 import logging
+import uuid
 from collections import deque
 
 from memmachine.common.data_types import ExternalServiceAPIError
-
-from ..data_types import Episode, MemoryContext
+from memmachine.common.language_model import LanguageModel
+from memmachine.configuration.episodic_config import ShortTermMemoryParams
+from memmachine.session_manager_interface import SessionDataManager
+from ..data_types import Episode
 
 logger = logging.getLogger(__name__)
 
 
-class SessionMemory:
+class ShortTermMemory:
     # pylint: disable=too-many-instance-attributes
     """
-    Manages the short-term memory of conversion context.
+    Manages the short-term memory of a conversation context.
 
-    This class stores a sequence of recent events (episodes) in a deque with a
-    fixed capacity. When the memory becomes full (based on the number of
-    events, total message length, or total token count), older events are
-    evicted and summarized.
+    This class stores a sequence of recent events (episodes) in a deque.
+    When the memory becomes full (based on the total message length),
+    older events are evicted and summarized.
     """
 
     def __init__(
         self,
-        model,
-        summary_system_prompt: str,
-        summary_user_prompt: str,
-        capacity: int,
-        max_message_len: int,
-        max_token_num: int,
-        memory_context: MemoryContext,
+        param: ShortTermMemoryParams,
+        summary: str = "",
+        episodes: list[Episode] | None = None,
     ):
-        # pylint: disable=too-many-arguments
-        # pylint: disable=too-many-positional-arguments
         """
         Initializes the ShortTermMemory instance.
-
-        Args:
-            model: The language model API for generating summaries.
-            storage: The memory storage API.
-            summary_system_prompt: The system prompt for creating the initial
-                summary.
-            summary_user_prompt: The user prompt for creating the initial
-                summary.
-            capacity: The maximum number of episodes to store.
-        max_message_len: The maximum total length of all messages in
-            characters.
-        max_token_num: The maximum total number of tokens for all
-            messages.
-        memory_context: The context (group, agent, user, session) for the
-            memory.
+
+        Args:
+            param: The short-term memory parameters.
+            summary: An optional summary restored from persistent storage.
+            episodes: Optional episodes to preload into the memory.
         """
-        self._model = model
-        self._summary_user_prompt = summary_user_prompt
-        self._summary_system_prompt = summary_system_prompt
-        self._memory: deque[Episode] = deque(maxlen=capacity)
-        self._capacity = capacity
+        self._model: LanguageModel = param.llm_model
+        self._data_manager: SessionDataManager | None = param.data_manager
+        self._summary_user_prompt = param.summary_prompt_user
+        self._summary_system_prompt = param.summary_prompt_system
+        self._memory: deque[Episode] = deque()
         self._current_episode_count = 0
-        self._max_message_len = max_message_len
-        self._max_token_num = max_token_num
+        self._max_message_len = param.message_capacity
         self._current_message_len = 0
-        self._current_token_num = 0
-        self._summary = ""
-        self._memory_context = memory_context
+        self._summary = summary
+        self._session_key = param.session_key
         self._summary_task = None
+        self._closed = False
         self._lock = asyncio.Lock()
+        if episodes is not None:
+            self._memory.extend(episodes)
+            self._current_episode_count = len(episodes)
+            for e in episodes:
+                self._current_message_len += len(e.content)
+
+    @classmethod
+    async def create(cls, param: ShortTermMemoryParams) -> "ShortTermMemory":
+        """
+        Creates a new ShortTermMemory instance, restoring any previously
+        persisted summary from the session data manager.
+        """
+        if param.data_manager is not None:
+            try:
+                await param.data_manager.create_tables()
+            except ValueError:
+                pass
+            try:
+                # The data manager persists only the summary and sequence
+                # bookkeeping, not the episode contents.
+                summary, _, _ = await param.data_manager.get_short_term_memory(
+                    param.session_key
+                )
+                return cls(param, summary)
+            except ValueError:
+                pass
+        return cls(param)
 
     def _is_full(self) -> bool:
         """
         Checks if the short-term memory has reached its capacity.
 
-        Memory is considered full if the number of events, total message
-        length, or total token count exceeds their respective maximums.
+        Memory is considered full if the combined length of the stored
+        messages and the summary exceeds the configured maximum.
 
         Returns:
             True if the memory is full, False otherwise.
         """
-        result = (
-            self._current_episode_count >= self._capacity
-            or self._current_message_len >= self._max_message_len
-            or self._current_token_num >= self._max_token_num
-        )
+        result = self._current_message_len + len(self._summary) > self._max_message_len
         return result
 
     async def add_episode(self, episode: Episode) -> bool:
@@ -105,11 +106,12 @@
             otherwise.
         """
         async with self._lock:
+            if self._closed:
+                raise RuntimeError(f"Memory is closed {self._session_key}")
             self._memory.append(episode)
             self._current_episode_count += 1
             self._current_message_len += len(episode.content)
-            self._current_token_num += self._compute_token_num(self._memory[-1])
             full = self._is_full()
             if full:
                 await self._do_evict()
@@ -122,40 +124,72 @@
         as possible for current capacity.
         """
         result = []
-        # do not clear the episode memory here so rolling episode can be
-        # used as context
-        # just remove the episode that left over from previous evition.
-        while len(self._memory) > self._current_episode_count:
+        # Remove old messages that have already been summarized
+        while (
+            len(self._memory) > self._current_episode_count
+            and self._current_message_len + len(self._summary) > self._max_message_len
+        ):
+            self._current_message_len -= len(self._memory[0].content)
             self._memory.popleft()
+        if (
+            len(self._memory) == 0
+            or self._current_message_len + len(self._summary) <= self._max_message_len
+        ):
+            return
+
         for e in self._memory:
             result.append(e)
+        # Reset the count so only new episodes are counted
         self._current_episode_count = 0
-        self._current_message_len = 0
-        self._current_token_num = 0
 
         # if previous summary task is still running, wait for it
         if self._summary_task is not None:
             await self._summary_task
         self._summary_task = asyncio.create_task(self._create_summary(result))
 
-    async def clear_memory(self):
+    async def close(self):
         """
-        Clears all events and the summary from the short-term memory.
-        Resets the capacity, message length, and token count to zero.
+        Closes the memory: waits for any pending summary task, then clears
+        all events and the summary.
         """
         async with self._lock:
+            if self._closed:
+                return
+            self._closed = True
             if self._summary_task is not None:
-                self._summary_task.cancel()
+                await self._summary_task
+                self._summary_task = None
             self._memory.clear()
             self._current_episode_count = 0
             self._current_message_len = 0
-            self._current_token_num = 0
             self._summary = ""
 
-    async def close(self):
-        """Closes the memory, which currently just involves clearing it."""
-        await self.clear_memory()
+    async def clear_memory(self):
+        """
+        Clear all events and the summary. Reset the message length to zero.
+        """
+        async with self._lock:
+            if self._closed:
+                return
+            if self._summary_task is not None:
+                await self._summary_task
+                self._summary_task = None
+            self._memory.clear()
+            self._current_episode_count = 0
+            self._current_message_len = 0
+            self._summary = ""
+
+    async def delete_episode(self, episode_uuid: uuid.UUID):
+        """Delete one episode by UUID."""
+        async with self._lock:
+            for e in self._memory:
+                if e.uuid == episode_uuid:
+                    self._current_episode_count -= 1
+                    self._current_message_len -= len(e.content)
+                    self._memory.remove(e)
+                    return True
+            return False
 
     async def _create_summary(self, episodes: list[Episode]):
         """
@@ -181,12 +215,19 @@
             meta = repr(entry.user_metadata)
             episode_content += f"[{str(entry.uuid)} : {meta} : {entry.content}]"
         msg = self._summary_user_prompt.format(
-            episodes=episode_content, summary=self._summary
+            episodes=episode_content,
+            summary=self._summary,
+            max_length=self._max_message_len // 2,
+        )
         result = await self._model.generate_response(
             system_prompt=self._summary_system_prompt, user_prompt=msg
         )
         self._summary = result[0]
+        if self._data_manager is not None:
+            # Persist the summary and sequence bookkeeping; episode contents
+            # are not stored by the session data manager.
+            await self._data_manager.save_short_term_memory(
+                self._session_key,
+                self._summary,
+                episodes[-1].sequence_num,
+                len(episodes),
+            )
+
             logger.debug("Summary: %s\n", self._summary)
         except ExternalServiceAPIError:
             logger.info("External API error when creating summary")
@@ -196,49 +237,63 @@
             logger.info("Runtime error when creating summary")
 
     async def get_session_memory_context(
-        self, query, limit: int = 0, max_token_num: int = 0
+        self,
+        query,
+        limit: int = 0,
+        max_message_length: int = 0,
+        filter: dict[str, str] | None = None,
     ) -> tuple[list[Episode], str]:
         """
         Retrieves context from short-term memory for a given query.

         This includes the current summary and as many recent episodes as can
-        fit within a specified token limit.
+        fit within a specified message length limit.
 
         Args:
             query: The user's query string.
-            max_token_num: The maximum number of tokens for the context. If 0,
+            limit: The maximum number of episodes to return. If 0, no limit
+                is applied.
+            max_message_length: The maximum character length of the context. If 0,
                 no limit is applied.
+            filter: Optional equality filters on producer and user metadata.
+
+        Returns:
+            A tuple containing a list of episodes and the current summary.
         """
+        logger.debug("Get session for %s", query)
         async with self._lock:
+            if self._closed:
+                raise RuntimeError(f"Memory is closed {self._session_key}")
             if self._summary_task is not None:
                 await self._summary_task
                 self._summary_task = None
-            length = (
-                self._compute_token_num(self._summary)
-                if self._summary is not None
-                else 0
-            )
+            length = 0 if self._summary is None else len(self._summary)
             episodes: deque[Episode] = deque()
+
             for e in reversed(self._memory):
-                if length >= max_token_num > 0:
+                if length >= max_message_length > 0:
                     break
                 if len(episodes) >= limit > 0:
                     break
+                # Skip episodes that do not match the filter
+                if filter is not None:
+                    if "producer" in filter and filter["producer"] != e.producer_id:
+                        continue
+                    metadata = e.user_metadata if isinstance(e.user_metadata, dict) else {}
+                    matched = True
+                    for key, value in filter.items():
+                        if key == "producer":
+                            continue
+                        if metadata.get(key) != value:
+                            matched = False
+                            break
+                    if not matched:
+                        continue
+
-                token_num = self._compute_token_num(e)
-                if length + token_num > max_token_num > 0:
+                msg_len = self._compute_episode_length(e)
+                if length + msg_len > max_message_length > 0:
                     break
                 episodes.appendleft(e)
-                length += token_num
+                length += msg_len
 
             return list(episodes), self._summary
 
-    def _compute_token_num(self, episode: Episode | str) -> int:
+    def _compute_episode_length(self, episode: Episode) -> int:
         """
-        Computes the total number of tokens in an episodes.
+        Computes the length of an episode's content and metadata in characters.
         """
         result = 0
-        if isinstance(episode, str):
-            return int(len(episode) / 4)  # 4 character per token
         if episode.content is None:
             return 0
         if isinstance(episode.content, str):
@@ -246,7 +301,7 @@
         else:
             result += len(repr(episode.content))
         if episode.user_metadata is None:
-            return int(result / 4)  # 4 character per token
+            return result
         if isinstance(episode.user_metadata, str):
             result += len(episode.user_metadata)
         elif isinstance(episode.user_metadata, dict):
@@ -255,4 +310,4 @@
                     result += len(v)
                 else:
                     result += len(repr(v))
-        return int(result / 4)  # 4 character per token
+        return result
diff --git a/src/memmachine/episodic_memory_manager.py b/src/memmachine/episodic_memory_manager.py
new file mode 100644
index 000000000..21b3508d6
--- /dev/null
+++ b/src/memmachine/episodic_memory_manager.py
@@ -0,0 +1,197 @@
+import asyncio
+from contextlib import asynccontextmanager
+from typing import AsyncIterator
+
+from memmachine.configuration.episodic_config import (
+    EpisodicMemoryManagerParam,
+    EpisodicMemoryParams,
+)
+from memmachine.episodic_memory.episodic_memory import EpisodicMemory
+
+from .instance_lru_cache import MemoryInstanceCache
+
+
+class EpisodicMemoryManager:
+    """
+    Manages the lifecycle and access of episodic memory instances.
+
+    This class is responsible for creating, retrieving, and closing
+    `EpisodicMemory` instances based on a session key. It uses a
+    reference counting mechanism to manage the lifecycle of each memory
+    instance, ensuring that resources are properly released when no
+    longer needed.
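+
+    Example (illustrative sketch; assumes a configured
+    `EpisodicMemoryManagerParam` and `EpisodicMemoryParams`; the manager must
+    be constructed inside a running event loop, since it spawns a background
+    cleanup task):
+
+        manager = EpisodicMemoryManager(manager_param)
+        async with manager.create_episodic_memory(
+            "session-1", memory_param, "demo session", {}
+        ) as memory:
+            await memory.add_memory_episode(episode)
+        async with manager.open_episodic_memory("session-1") as memory:
+            short, long_term, summaries = await memory.query_memory("greeting")
+        await manager.close()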
+    """
+
+    def __init__(self, param: EpisodicMemoryManagerParam):
+        """
+        Initializes the EpisodicMemoryManager.
+
+        Args:
+            param: The manager configuration parameters.
+        """
+        self._instance_cache: MemoryInstanceCache = MemoryInstanceCache(
+            param.instance_cache_size, param.max_life_time
+        )
+        self._session_storage = param.session_storage
+        self._lock = asyncio.Lock()
+        self._closed = False
+        self._check_instance_task = asyncio.create_task(self._check_instance_life_time())
+
+    async def _check_instance_life_time(self):
+        while not self._closed:
+            await asyncio.sleep(2)
+            async with self._lock:
+                await self._instance_cache.clean_old_instance()
+
+    @asynccontextmanager
+    async def open_episodic_memory(self, session_key: str) -> AsyncIterator[EpisodicMemory]:
+        """
+        Asynchronously provides an EpisodicMemory instance for a given session key.
+
+        This is an asynchronous context manager. It will create a new
+        `EpisodicMemory` instance if one doesn't exist for the given session key,
+        or return an existing one. It manages a reference count for each instance.
+
+        Args:
+            session_key: The unique identifier for the session.
+
+        Yields:
+            An EpisodicMemory instance.
+
+        Raises:
+            ValueError: If no session with the given session_key exists in storage.
+        """
+        instance: EpisodicMemory | None = None
+        async with self._lock:
+            if self._closed:
+                raise RuntimeError(f"Memory is closed {session_key}")
+
+            # Check if the instance is in the cache and in use
+            instance = self._instance_cache.get(session_key)
+            if instance is None:
+                # load from the database
+                _, _, _, param = await self._session_storage.get_session_info(session_key)
+                # TODO: callback to instantiate the param
+                instance = await EpisodicMemory.create(param)
+                await self._instance_cache.add(session_key, instance)
+        try:
+            yield instance
+        finally:
+            if instance is not None:
+                async with self._lock:
+                    self._instance_cache.put(session_key)
+
+    @asynccontextmanager
+    async def create_episodic_memory(self, session_key: str, param: EpisodicMemoryParams, description: str, metadata: dict, config: dict | None = None) -> AsyncIterator[EpisodicMemory]:
+        """
+        Creates a new episodic memory instance and stores its configuration.
+
+        Args:
+            session_key: The unique identifier for the session.
+            param: The parameters for configuring the episodic memory.
+            description: A brief description of the session.
+            metadata: User-defined metadata for the session.
+            config: An optional raw configuration dictionary to persist
+                alongside the session.
+
+        Raises:
+            ValueError: If a session with the given session_key already exists.
+        """
+        instance: EpisodicMemory | None = None
+        if config is None:
+            config = {}
+        async with self._lock:
+            if self._closed:
+                raise RuntimeError(f"Memory is closed {session_key}")
+
+            await self._session_storage.create_new_session(session_key, config, param, description, metadata)
+            instance = await EpisodicMemory.create(param)
+            await self._instance_cache.add(session_key, instance)
+        try:
+            yield instance
+        finally:
+            if instance is not None:
+                async with self._lock:
+                    self._instance_cache.put(session_key)
+
+    async def delete_episodic_session(self, session_key: str):
+        """
+        Deletes an episodic memory instance and its associated data.
+
+        Args:
+            session_key: The unique identifier of the session to delete.
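+
+        Raises:
+            RuntimeError: If the session is still referenced by an open
+                context manager.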
+        """
+        instance: EpisodicMemory | None = None
+        async with self._lock:
+            if self._closed:
+                raise RuntimeError(f"Memory is closed {session_key}")
+            # Check if the instance is in the cache and in use
+            ref_count = self._instance_cache.get_ref_count(session_key)
+            instance = self._instance_cache.get(session_key)
+            if instance and ref_count > 0:
+                raise RuntimeError(f"Session {session_key} is still in use {ref_count}")
+            if instance:
+                self._instance_cache.put(session_key)
+                self._instance_cache.erase(session_key)
+            if instance is None:
+                # Not cached; instantiate it from the stored session parameters
+                _, _, _, param = await self._session_storage.get_session_info(session_key)
+                instance = await EpisodicMemory.create(param)
+            await instance.delete_data()
+            await instance.close()
+            await self._session_storage.delete_session(session_key)
+
+    async def get_episodic_memory_keys(self, filter: dict[str, str] | None) -> list[str]:
+        """
+        Retrieves a list of all available episodic memory session keys.
+
+        Args:
+            filter: Optional user-metadata equality filters to match sessions.
+
+        Returns:
+            A list of session keys.
+        """
+        return await self._session_storage.get_sessions(filter)
+
+    async def get_session_configuration(self, session_key: str) -> tuple[dict, str, dict, EpisodicMemoryParams]:
+        """
+        Retrieves the configuration, description, metadata, and parameters
+        for a given session.
+        """
+        return await self._session_storage.get_session_info(session_key)
+
+    async def close_session(self, session_key: str):
+        """
+        Closes an idle episodic memory instance and evicts it from the cache.
+
+        Args:
+            session_key: The unique identifier of the session to close.
+        """
+        async with self._lock:
+            if self._closed:
+                raise RuntimeError(f"Memory is closed {session_key}")
+            ref_count = self._instance_cache.get_ref_count(session_key)
+            if ref_count < 0:
+                return
+            if ref_count > 0:
+                raise RuntimeError(f"Session {session_key} is busy")
+            instance = self._instance_cache.get(session_key)
+            if instance is not None:
+                await instance.close()
+                self._instance_cache.put(session_key)
+                self._instance_cache.erase(session_key)
+
+    async def close(self):
+        """
+        Closes all open episodic memory instances and the session storage.
+        """
+        tasks = []
+        async with self._lock:
+            if self._closed:
+                return
+            for key in self._instance_cache.keys():
+                instance = self._instance_cache.get(key)
+                if instance is not None:
+                    tasks.append(instance.close())
+            await asyncio.gather(*tasks)
+            await self._session_storage.close()
+            self._instance_cache.clear()
+            self._closed = True
+
+        if hasattr(self, "_check_instance_task"):
+            await self._check_instance_task
diff --git a/src/memmachine/instance_lru_cache.py b/src/memmachine/instance_lru_cache.py
new file mode 100644
index 000000000..a5ba32f5e
--- /dev/null
+++ b/src/memmachine/instance_lru_cache.py
@@ -0,0 +1,168 @@
+from datetime import datetime
+from typing import cast
+
+from memmachine.episodic_memory.episodic_memory import EpisodicMemory
+
+
+class Node:
+    """
+    Node for the doubly linked list.
+    Each node stores a key-value pair.
+    """
+
+    def __init__(self, key: str | None, value: EpisodicMemory | None):
+        self.key = key
+        self.value = value
+        self.ref_count = 1
+        self.last_access = datetime.now()
+        self.prev: Node = self
+        self.next: Node = self
+
+
+class MemoryInstanceCache:
+    """
+    A Least Recently Used (LRU) cache that manages memory instances.
+
+    Attributes:
+        capacity (int): The maximum number of items the cache can hold.
+        cache (dict): A dictionary mapping keys to Node objects for O(1) lookups.
+        head (Node): A sentinel head node for the doubly linked list.
+        tail (Node): A sentinel tail node for the doubly linked list.
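+
+    Example (illustrative; assumes an ``EpisodicMemory`` instance ``memory``):
+
+        cache = MemoryInstanceCache(capacity=2, max_lifetime=600)
+        await cache.add("session-1", memory)  # new entries start with ref_count 1
+        cache.put("session-1")                # release the initial reference
+        instance = cache.get("session-1")     # increments ref_count again
+        cache.put("session-1")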
+    """
+
+    def __init__(self, capacity: int, max_lifetime: int):
+        if capacity <= 0:
+            raise ValueError("Capacity must be a positive integer")
+        self.capacity = capacity
+        self.max_lifetime = max_lifetime
+        self.cache: dict[str, Node] = {}  # Stores key -> Node
+
+        # Initialize sentinel head and tail nodes for the doubly linked list.
+        # head.next points to the most recently used item.
+        # tail.prev points to the least recently used item.
+        self.head = Node(None, None)
+        self.tail = Node(None, None)
+        self.head.next = self.tail
+        self.tail.prev = self.head
+
+    def _remove_node(self, node: Node) -> None:
+        """Removes a node from the doubly linked list."""
+        if node.prev and node.next:
+            prev_node = node.prev
+            next_node = node.next
+            prev_node.next = next_node
+            next_node.prev = prev_node
+
+    def _add_to_front(self, node: Node) -> None:
+        """Adds a node to the front of the doubly linked list (right after head)."""
+        node.prev = self.head
+        node.next = self.head.next
+        if self.head.next:
+            self.head.next.prev = node
+        self.head.next = node
+
+    def clear(self) -> None:
+        """
+        Removes all items from the cache.
+        """
+        self.cache.clear()
+        self.head.next = self.tail
+        self.tail.prev = self.head
+
+    def keys(self) -> list[str]:
+        """
+        Returns a list of all keys in the cache.
+        """
+        return list(self.cache.keys())
+
+    def erase(self, key: str) -> None:
+        """
+        Removes an item from the cache.
+        """
+        if key in self.cache:
+            node = self.cache[key]
+            if node.ref_count > 0:
+                raise RuntimeError(f"Key {key} is still in use {node.ref_count}")
+            self._remove_node(node)
+            del self.cache[key]
+
+    def get(self, key: str) -> EpisodicMemory | None:
+        """
+        Retrieves an item from the cache.
+        Returns the value and increments its reference count if the key
+        exists; otherwise returns None.
+        Moves the accessed item to the front (most recently used).
+        """
+        if key in self.cache:
+            node = self.cache[key]
+            node.ref_count += 1
+            # Move accessed node to the front
+            self._remove_node(node)
+            self._add_to_front(node)
+            node.last_access = datetime.now()
+            return node.value
+        return None
+
+    def get_ref_count(self, key: str) -> int:
+        """
+        Retrieves the reference count of an item in the cache.
+        Returns the reference count if the key exists, otherwise -1.
+        """
+        if key in self.cache:
+            return self.cache[key].ref_count
+        return -1
+
+    async def add(self, key: str, value: EpisodicMemory):
+        """
+        Adds a new item to the cache, evicting unreferenced least recently
+        used items if the cache is at capacity.
+        """
+        if key in self.cache:
+            raise ValueError(f"Key {key} already exists")
+
+        # Evict from the least recently used end, skipping in-use entries
+        lru_node = self.tail.prev
+        while len(self.cache) >= self.capacity and lru_node != self.head:
+            if lru_node.ref_count > 0:
+                lru_node = cast(Node, lru_node.prev)
+                continue
+            tmp = lru_node.prev
+            self._remove_node(lru_node)
+            if lru_node.value is not None:
+                await lru_node.value.close()
+            del self.cache[cast(str, lru_node.key)]
+            lru_node = tmp
+
+        new_node = Node(key, value)
+        self.cache[key] = new_node
+        self._add_to_front(new_node)
+
+    def put(self, key: str) -> None:
+        """
+        Release one reference to the cached object.
+        """
+        if key in self.cache:
+            # Decrement the reference count for this key
+            node = self.cache[key]
+            assert node.ref_count > 0
+            node.ref_count -= 1
+        else:
+            raise ValueError(f"Key {key} does not exist")
+
+    async def clean_old_instance(self) -> None:
+        """
+        Remove idle instances whose last access is older than
+        ``max_lifetime`` seconds.
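+
+        Instances that are still referenced (ref_count > 0) are skipped and
+        reconsidered on a later sweep.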
+        """
+        now = datetime.now()
+        lru_node = self.tail.prev
+        while lru_node != self.head:
+            if lru_node.ref_count > 0:
+                lru_node = cast(Node, lru_node.prev)
+                continue
+            tmp = lru_node.prev
+            if (now - lru_node.last_access).total_seconds() > self.max_lifetime:
+                self._remove_node(lru_node)
+                if lru_node.value is not None:
+                    await lru_node.value.close()
+                del self.cache[cast(str, lru_node.key)]
+            lru_node = tmp
diff --git a/src/memmachine/session_manager.py b/src/memmachine/session_manager.py
new file mode 100644
index 000000000..4ee5beed3
--- /dev/null
+++ b/src/memmachine/session_manager.py
@@ -0,0 +1,297 @@
+"""Manages database persistence for session configuration and short-term data."""
+
+import os
+import pickle
+from typing import Annotated
+
+from sqlalchemy import (
+    JSON,
+    ForeignKeyConstraint,
+    Integer,
+    LargeBinary,
+    PrimaryKeyConstraint,
+    String,
+    and_,
+    func,
+    insert,
+    select,
+    update,
+)
+from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
+from sqlalchemy.orm import (
+    DeclarativeBase,
+    Mapped,
+    mapped_column,
+    relationship,
+)
+
+from memmachine.configuration.episodic_config import EpisodicMemoryParams
+
+from .session_manager_interface import SessionDataManager
+
+
+class Base(DeclarativeBase):  # pylint: disable=too-few-public-methods
+    """
+    Base class for declarative class definitions.
+    """
+
+
+IntColumn = Annotated[int, mapped_column(Integer)]
+StringKeyColumn = Annotated[str, mapped_column(String, primary_key=True)]
+StringColumn = Annotated[str, mapped_column(String)]
+JSONColumn = Annotated[dict, mapped_column(JSON)]
+BinaryColumn = Annotated[bytes, mapped_column(LargeBinary)]
+
+
+class SessionDataManagerImpl(SessionDataManager):
+    """
+    Handles persistence of session-related data.
+    """
+
+    class SessionConfig(Base):  # pylint: disable=too-few-public-methods
+        """ORM model for a session configuration.
+        session_key is the primary key.
+        """
+
+        __tablename__ = "sessions"
+        session_key: Mapped[StringKeyColumn]
+        timestamp: Mapped[IntColumn]
+        configuration: Mapped[JSONColumn]
+        param_data: Mapped[BinaryColumn]
+        description: Mapped[StringColumn]
+        user_metadata: Mapped[JSONColumn]
+        __table_args__ = (PrimaryKeyConstraint("session_key"),)
+        short_term_memory_data = relationship(
+            "ShortTermMemoryData", cascade="all, delete-orphan"
+        )
+
+    class ShortTermMemoryData(Base):  # pylint: disable=too-few-public-methods
+        """ORM model for short-term memory data.
+        session_key is the primary key.
+        """
+
+        __tablename__ = "short_term_memory_data"
+        session_key: Mapped[StringKeyColumn]
+        summary: Mapped[StringColumn]
+        last_seq: Mapped[IntColumn]
+        episode_num: Mapped[IntColumn]
+        timestamp: Mapped[IntColumn]
+        __table_args__ = (
+            PrimaryKeyConstraint("session_key"),
+            ForeignKeyConstraint(["session_key"], ["sessions.session_key"]),
+        )
+
+    def __init__(self, engine: AsyncEngine, schema: str | None = None):
+        """Initializes the SessionDataManagerImpl.
+
+        Args:
+            engine: The SQLAlchemy async engine to use for database connections.
+            schema: The database schema to use for the tables.
+        """
+        self._engine = engine
+        self._async_session = async_sessionmaker(bind=self._engine, expire_on_commit=False)
+        if schema:
+            for table in Base.metadata.tables.values():
+                table.schema = schema
+
+    async def create_tables(self) -> None:
+        """Creates the necessary tables in the database.
+
+        This method connects to the database and creates all tables defined in the
+        Base metadata if they don't already exist.
+        """
+        async with self._engine.begin() as conn:
+            await conn.run_sync(Base.metadata.create_all)
+
+    async def close(self):
+        """Closes the database connection engine."""
+        if hasattr(self, "_engine"):
+            await self._engine.dispose()
+
+    async def create_new_session(
+        self, session_key: str, configuration: dict, param: EpisodicMemoryParams, description: str, metadata: dict
+    ):
+        """
+        Creates a new session entry in the database.
+
+        Args:
+            session_key: The unique identifier for the session.
+            configuration: A dictionary containing the session's configuration.
+            param: The episodic memory parameters.
+            description: A brief description of the session.
+            metadata: A dictionary for user-defined metadata.
+
+        Raises:
+            ValueError: If a session with the given session_key already exists.
+        """
+        param_data = pickle.dumps(param)
+        async with self._async_session() as dbsession:
+            # Query for an existing session with the same ID
+            sessions = (
+                await dbsession.execute(select(self.SessionConfig)
+                    .where(self.SessionConfig.session_key == session_key))
+            )
+            session = sessions.first()
+            if session is not None:
+                raise ValueError(f"Session {session_key} already exists")
+            # Create a new entry
+            new_session = self.SessionConfig(
+                session_key=session_key,
+                # os.times()[4] is elapsed real time since an arbitrary
+                # fixed point in the past
+                timestamp=int(os.times()[4]),
+                configuration=configuration,
+                param_data=param_data,
+                description=description,
+                user_metadata=metadata,
+            )
+            dbsession.add(new_session)
+            await dbsession.commit()
+
+    async def delete_session(self, session_key: str):
+        """Deletes a session and its related data from the database.
+
+        Args:
+            session_key: The unique identifier of the session to delete.
+        """
+        async with self._async_session() as dbsession:
+            row = await dbsession.get(self.SessionConfig, session_key)
+            if row is None:
+                raise ValueError(f"Session {session_key} does not exist")
+            await dbsession.delete(row)
+            await dbsession.commit()
+
+    async def get_session_info(self, session_key: str) -> tuple[dict, str, dict, EpisodicMemoryParams]:
+        """Retrieves a session's data from the database.
+
+        Args:
+            session_key: The unique identifier of the session to retrieve.
+
+        Returns:
+            A tuple containing the configuration dictionary, description string,
+            user metadata dictionary, and the EpisodicMemoryParams object.
+
+        Raises:
+            ValueError: If the session with the given session_key does not exist.
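+
+        Example (sketch; assumes the session was created earlier):
+
+            config, description, metadata, param = await data_manager.get_session_info(
+                "session-1"
+            )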
+        """
+        async with self._async_session() as dbsession:
+            sessions = await dbsession.execute(select(self.SessionConfig).where(
+                self.SessionConfig.session_key == session_key))
+            session = sessions.scalars().first()
+            if session is None:
+                raise ValueError(f"Session {session_key} does not exist")
+            param: EpisodicMemoryParams = pickle.loads(session.param_data)
+            return session.configuration, session.description, session.user_metadata, param
+
+    def _json_contains(self, column, filter):
+        if self._engine.dialect.name == "mysql":
+            return func.json_contains(column, func.json_quote(func.json(filter)))
+
+        elif self._engine.dialect.name == "postgresql":
+            return column.op("@>")(filter)
+
+        elif self._engine.dialect.name == "sqlite":
+            # SQLite has no JSON_CONTAINS; emulate using json_extract
+            if not isinstance(filter, dict):
+                raise ValueError("SQLite emulation only supports dict values")
+            conditions = [
+                func.json_extract(column, f'$.{k}') == v
+                for k, v in filter.items()
+            ]
+            return and_(*conditions)
+
+        else:
+            raise NotImplementedError(f"json_contains not supported for dialect '{self._engine.dialect.name}'")
+
+    async def get_sessions(self, filter: dict[str, str] | None = None) -> list[str]:
+        """Retrieves a list of all session keys from the database.
+
+        Args:
+            filter: Optional user-metadata equality filters; only sessions
+                whose metadata contains all given key-value pairs are returned.
+
+        Returns:
+            A list of session keys.
+        """
+        if filter is None:
+            stmt = select(self.SessionConfig.session_key)
+        else:
+            stmt = select(self.SessionConfig.session_key).where(self._json_contains(self.SessionConfig.user_metadata, filter))
+        async with self._async_session() as dbsession:
+            sessions = await dbsession.execute(stmt)
+            return list(sessions.scalars().all())
+
+    async def save_short_term_memory(self, session_key: str, summary: str, last_seq: int, episode_num: int):
+        """Saves or updates the short-term memory data for a session.
+
+        Args:
+            session_key: The unique identifier for the session.
+            summary: The summary of the short-term memory.
+            last_seq: The last sequence number of the episodes in the short-term memory.
+            episode_num: The number of episodes in the short-term memory.
+        """
+        async with self._async_session() as dbsession:
+            # Query for an existing session with the same ID
+            sessions = (
+                await dbsession.execute(select(self.SessionConfig)
+                    .where(self.SessionConfig.session_key == session_key)
+                )
+            )
+            session = sessions.first()
+            if session is None:
+                raise ValueError(f"Session {session_key} does not exist")
+            short_term_datas = (
+                await dbsession.execute(select(self.ShortTermMemoryData)
+                    .where(self.ShortTermMemoryData.session_key == session_key))
+            )
+            short_term_data = short_term_datas.scalars().first()
+            if short_term_data is not None:
+                update_stmt = update(self.ShortTermMemoryData).where(
+                    self.ShortTermMemoryData.session_key == session_key
+                ).values(
+                    summary=summary,
+                    last_seq=last_seq,
+                    episode_num=episode_num,
+                    timestamp=int(os.times()[4]),
+                )
+                await dbsession.execute(update_stmt)
+            else:
+                insert_stmt = insert(self.ShortTermMemoryData).values(
+                    session_key=session_key,
+                    summary=summary,
+                    last_seq=last_seq,
+                    episode_num=episode_num,
+                    timestamp=int(os.times()[4]),
+                )
+                await dbsession.execute(insert_stmt)
+            await dbsession.commit()
+
+    async def get_short_term_memory(self, session_key: str) -> tuple[str, int, int]:
+        """Retrieves the short-term memory data for a session.
+
+        Args:
+            session_key: The unique identifier for the session.
+
+        Returns:
+            A tuple containing the summary string, the number of episodes,
+            and the last sequence number.
+
+        Raises:
+            ValueError: If no short-term memory data exists for the given session_key.
+        """
+        async with self._async_session() as dbsession:
+            short_term_data = (
+                await dbsession.execute(select(self.ShortTermMemoryData)
+                    .where(self.ShortTermMemoryData.session_key == session_key))
+            ).scalars().first()
+            if short_term_data is None:
+                raise ValueError(
+                    f"Session {session_key} does not have short-term memory data"
+                )
+            return short_term_data.summary, short_term_data.episode_num, short_term_data.last_seq
diff --git a/src/memmachine/session_manager_interface.py b/src/memmachine/session_manager_interface.py
new file mode 100644
index 000000000..e5a29c26d
--- /dev/null
+++ b/src/memmachine/session_manager_interface.py
@@ -0,0 +1,106 @@
+from abc import ABC, abstractmethod
+
+
+class SessionDataManager(ABC):
+    """
+    Interface for managing session data, including session configurations and
+    short-term memory.
+    """
+
+    @abstractmethod
+    async def close(self):
+        """
+        Closes the database connection.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def create_tables(self):
+        """
+        Creates the necessary tables in the database.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def create_new_session(
+        self, session_key: str, configuration: dict, param: 'EpisodicMemoryParams', description: str, metadata: dict
+    ):
+        """
+        Creates a new session entry in the database.
+
+        Args:
+            session_key: The unique identifier for the session.
+            configuration: A dictionary containing the session's configuration.
+            param: The episodic memory parameters.
+            description: A brief description of the session.
+            metadata: A dictionary for user-defined metadata.
+
+        Raises:
+            ValueError: If a session with the given session_key already exists.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def delete_session(self, session_key: str):
+        """
+        Deletes a session entry from the database.
+
+        Args:
+            session_key: The unique identifier of the session to delete.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def get_session_info(self, session_key: str) -> tuple[dict, str, dict, 'EpisodicMemoryParams']:
+        """
+        Retrieves the configuration, description, and metadata for a given
+        session.
+
+        Args:
+            session_key: The unique identifier of the session.
+
+        Returns:
+            A tuple containing the configuration dictionary, description string,
+            metadata dictionary, and the EpisodicMemoryParams.
+
+        Raises:
+            ValueError: If the session with the given session_key does not exist.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def get_sessions(self, filter: dict[str, str] | None = None) -> list[str]:
+        """
+        Retrieves a list of all session keys from the database.
+
+        Returns:
+            A list of session keys.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def save_short_term_memory(self, session_key: str, summary: str, last_seq: int, episode_num: int):
+        """
+        Saves or updates the short-term memory data for a session.
+
+        Args:
+            session_key: The unique identifier for the session.
+            summary: The summary of the short-term memory.
+            last_seq: The last sequence number of the episodes in the short-term memory.
+            episode_num: The number of episodes in the short-term memory.
+
+        Raises:
+            ValueError: If the session with the given session_key does not exist.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def get_short_term_memory(self, session_key: str) -> tuple[str, int, int]:
+        """
+        Retrieves the short-term memory data for a session.
diff --git a/src/memmachine/test_episodic_memory_manager.py b/src/memmachine/test_episodic_memory_manager.py
new file mode 100644
index 000000000..d4963c0a9
--- /dev/null
+++ b/src/memmachine/test_episodic_memory_manager.py
@@ -0,0 +1,314 @@
+import pytest
+import pytest_asyncio
+from sqlalchemy.ext.asyncio import create_async_engine
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from memmachine.configuration.episodic_config import (
+    EpisodicMemoryManagerParam,
+    EpisodicMemoryParams,
+)
+from memmachine.common.metrics_factory import MetricsFactory
+from memmachine.episodic_memory.episodic_memory import EpisodicMemory
+from memmachine.episodic_memory_manager import EpisodicMemoryManager
+from memmachine.session_manager import SessionDataManagerImpl
+
+
+@pytest_asyncio.fixture
+async def db_engine():
+    """Fixture for an in-memory SQLite async engine."""
+    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
+    yield engine
+    await engine.dispose()
+
+@pytest_asyncio.fixture
+async def mock_session_storage(db_engine):
+    """Fixture for a SessionDataManagerImpl backed by in-memory SQLite."""
+    storage = SessionDataManagerImpl(engine=db_engine)
+    await storage.create_tables()
+    return storage
+
+@pytest.fixture
+def mock_metrics_factory():
+    """Fixture for a mocked MetricsFactory."""
+    # Declared global so the class is importable at module scope, which
+    # pickle requires when EpisodicMemoryParams objects are serialized.
+    global MockMetricsFactory
+    class MockMetricsFactory(MetricsFactory):
+        def __init__(self):
+            self.counters = MagicMock()
+            self.gauge = MagicMock()
+            self.histogram = MagicMock()
+            self.summaries = MagicMock()
+
+        def get_counter(self, name, description, label_names=...):
+            return self.counters
+
+        def get_summary(self, name, description, label_names=...):
+            return self.summaries
+
+        def get_gauge(self, name, description, label_names=...):
+            return self.gauge
+
+        def get_histogram(self, name, description, label_names=...):
+            return self.histogram
+
+        def reset(self):
+            pass
+
+        def __getstate__(self):
+            return {}
+
+        def __setstate__(self, state):
+            pass
+
+    factory = MockMetricsFactory()
+    return factory
+
+@pytest.fixture
+def mock_episodic_memory_params(mock_metrics_factory):
+    """Fixture for a dummy EpisodicMemoryParams object."""
+    return EpisodicMemoryParams(
+        session_key="test_session",
+        metrics_factory=mock_metrics_factory,
+        enabled=False
+    )
+
+
+@pytest.fixture
+def mock_episodic_memory_manager_param(mock_session_storage):
+    """Fixture for EpisodicMemoryManagerParam."""
+    return EpisodicMemoryManagerParam(
+        session_storage=mock_session_storage,
+        instance_cache_size=10,
+        max_life_time=3600,
+    )
+
+
+@pytest_asyncio.fixture
+async def manager(mock_episodic_memory_manager_param):
+    """Fixture for an EpisodicMemoryManager instance."""
+    return EpisodicMemoryManager(param=mock_episodic_memory_manager_param)
+
+
+@pytest.mark.asyncio
+@patch("memmachine.episodic_memory_manager.EpisodicMemory.create", new_callable=AsyncMock)
+async def test_create_episodic_memory_success(
+    mock_create, manager: EpisodicMemoryManager, mock_session_storage, mock_episodic_memory_params
+):
+    """Test successfully creating a new episodic memory instance."""
+    session_key = "new_session"
+    description = "A new test session"
+    metadata = {"owner": "tester"}
+    mock_instance = AsyncMock(spec=EpisodicMemory)
+    mock_create.return_value = mock_instance
+
+    async with manager.create_episodic_memory(
+        session_key, 
mock_episodic_memory_params, description, metadata + ) as instance: + assert instance is mock_instance + mock_create.assert_awaited_once_with(mock_episodic_memory_params) + assert manager._instance_cache.get_ref_count(session_key) == 1 # 1 from add + + assert manager._instance_cache.get_ref_count(session_key) == 0 # put is called + + +@pytest.mark.asyncio +async def test_create_episodic_memory_already_exists( + manager: EpisodicMemoryManager, mock_session_storage, mock_episodic_memory_params +): + """Test that creating a session that already exists raises an error.""" + session_key = "existing_session" + async with manager.create_episodic_memory( + session_key, mock_episodic_memory_params, "", {} + ): + + with pytest.raises(ValueError, match=f"Session {session_key} already exists"): + async with manager.create_episodic_memory( + session_key, mock_episodic_memory_params, "", {} + ): + pass # This part should not be reached + + +@pytest.mark.asyncio +@patch("memmachine.episodic_memory_manager.EpisodicMemory.create", new_callable=AsyncMock) +async def test_open_episodic_memory_new_instance( + mock_create, manager: EpisodicMemoryManager, mock_session_storage, mock_episodic_memory_params +): + """Test opening a session for the first time, loading it from storage.""" + session_key = "session_to_open" + mock_instance = AsyncMock(spec=EpisodicMemory) + mock_create.return_value = mock_instance + async with manager.create_episodic_memory( + session_key, mock_episodic_memory_params, "", {} + ) as instance: + assert instance is mock_instance + await manager.close_session(session_key) + + async with manager.open_episodic_memory(session_key) as instance: + assert instance is mock_instance + mock_create.assert_awaited_once_with(mock_episodic_memory_params) + assert manager._instance_cache.get_ref_count(session_key) == 1 + + assert manager._instance_cache.get_ref_count(session_key) == 0 + + +@pytest.mark.asyncio +@patch("memmachine.episodic_memory_manager.EpisodicMemory.create", new_callable=AsyncMock) +async def test_open_episodic_memory_cached_instance( + mock_create, manager: EpisodicMemoryManager, mock_session_storage, mock_episodic_memory_params +): + """Test opening a session that is already in the cache.""" + session_key = "cached_session" + mock_instance = AsyncMock(spec=EpisodicMemory) + mock_create.return_value = mock_instance + + # Pre-populate the cache + async with manager.create_episodic_memory(session_key, mock_episodic_memory_params, "", {}): + pass + + mock_create.assert_awaited_once() + mock_create.reset_mock() + + # Open it again + async with manager.open_episodic_memory(session_key) as instance: + assert instance is mock_instance + # Should not call storage or create again + mock_create.assert_not_awaited() + assert manager._instance_cache.get_ref_count(session_key) == 1 + + assert manager._instance_cache.get_ref_count(session_key) == 0 + + +@pytest.mark.asyncio +async def test_delete_episodic_session_not_in_use( + manager: EpisodicMemoryManager, mock_session_storage, mock_episodic_memory_params +): + """Test deleting a session that is not currently in use.""" + session_key = "session_to_delete" + mock_instance = AsyncMock(spec=EpisodicMemory) + + with patch("memmachine.episodic_memory_manager.EpisodicMemory.create", return_value=mock_instance): + # Create and release the session so it's in cache but not in use + async with manager.create_episodic_memory(session_key, mock_episodic_memory_params, "", {}): + pass + + assert manager._instance_cache.get_ref_count(session_key) == 0 + + 
await manager.delete_episodic_session(session_key) + + # Verify it's gone from cache and storage + assert manager._instance_cache.get(session_key) is None + mock_instance.delete_data.assert_awaited_once() + mock_instance.close.assert_awaited_once() + +@pytest.mark.asyncio +async def test_delete_episodic_session_in_use_raises_error( + manager: EpisodicMemoryManager, mock_episodic_memory_params +): + """Test that deleting a session currently in use raises a RuntimeError.""" + session_key = "session_in_use" + mock_instance = AsyncMock(spec=EpisodicMemory) + + with patch("memmachine.episodic_memory_manager.EpisodicMemory.create", return_value=mock_instance): + async with manager.create_episodic_memory(session_key, mock_episodic_memory_params, "", {}): + with pytest.raises(RuntimeError, match=f"Session {session_key} is still in use"): + await manager.delete_episodic_session(session_key) + + +@pytest.mark.asyncio +@patch("memmachine.episodic_memory_manager.EpisodicMemory.create", new_callable=AsyncMock) +async def test_delete_episodic_session_not_in_cache( + mock_create, manager: EpisodicMemoryManager, mock_session_storage, mock_episodic_memory_params +): + """Test deleting a session that exists in storage but not in the cache.""" + session_key = "not_in_cache_session" + mock_instance = AsyncMock(spec=EpisodicMemory) + mock_create.return_value = mock_instance + async with manager.create_episodic_memory(session_key, mock_episodic_memory_params, "", {}): + pass + await manager.close_session(session_key) + mock_create.assert_awaited_once_with(mock_episodic_memory_params) + mock_instance.close.assert_awaited_once() + mock_create.reset_mock() + mock_instance.reset_mock() + + await manager.delete_episodic_session(session_key) + + # Should load from storage to delete + mock_instance.delete_data.assert_awaited_once() + mock_instance.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_close_session_not_in_use(manager: EpisodicMemoryManager, mock_episodic_memory_params): + """Test closing a session that is cached but not in use.""" + session_key = "session_to_close" + mock_instance = AsyncMock(spec=EpisodicMemory) + + with patch("memmachine.episodic_memory_manager.EpisodicMemory.create", return_value=mock_instance): + async with manager.create_episodic_memory(session_key, mock_episodic_memory_params, "", {}): + pass # Enters and exits context, ref_count becomes 1 + + await manager.close_session(session_key) + + mock_instance.close.assert_awaited_once() + assert manager._instance_cache.get(session_key) is None + + +@pytest.mark.asyncio +async def test_close_session_in_use_raises_error(manager: EpisodicMemoryManager, mock_episodic_memory_params): + """Test that closing a session in use raises a RuntimeError.""" + session_key = "busy_session" + with patch("memmachine.episodic_memory_manager.EpisodicMemory.create", new_callable=AsyncMock): + async with manager.create_episodic_memory(session_key, mock_episodic_memory_params, "", {}): + with pytest.raises(RuntimeError, match=f"Session {session_key} is busy"): + await manager.close_session(session_key) + + +@pytest.mark.asyncio +async def test_manager_close(manager: EpisodicMemoryManager, mock_session_storage, mock_episodic_memory_params): + """Test the main close method of the manager.""" + session_key1 = "s1" + session_key2 = "s2" + mock_instance1 = AsyncMock(spec=EpisodicMemory) + mock_instance2 = AsyncMock(spec=EpisodicMemory) + + with patch("memmachine.episodic_memory_manager.EpisodicMemory.create", side_effect=[mock_instance1, 
mock_instance2]): + # Create two sessions and leave them in the cache + async with manager.create_episodic_memory(session_key1, mock_episodic_memory_params, "", {}): + pass + async with manager.create_episodic_memory(session_key2, mock_episodic_memory_params, "", {}): + pass + + await manager.close() + + # Verify instances were closed and removed from cache + mock_instance1.close.assert_awaited_once() + mock_instance2.close.assert_awaited_once() + assert manager._instance_cache.get(session_key1) is None + assert manager._instance_cache.get(session_key2) is None + + # Verify manager is in a closed state + with pytest.raises(RuntimeError, match="Memory is closed"): + async with manager.open_episodic_memory("any_session"): + pass + + +@pytest.mark.asyncio +async def test_manager_methods_after_close_raise_error(manager: EpisodicMemoryManager): + """Test that all public methods raise RuntimeError after the manager is closed.""" + await manager.close() + + with pytest.raises(RuntimeError, match="Memory is closed"): + async with manager.create_episodic_memory("s", MagicMock(), "", {}): + pass + + with pytest.raises(RuntimeError, match="Memory is closed"): + async with manager.open_episodic_memory("s"): + pass + + with pytest.raises(RuntimeError, match="Memory is closed"): + await manager.delete_episodic_session("s") + + with pytest.raises(RuntimeError, match="Memory is closed"): + await manager.close_session("s") \ No newline at end of file diff --git a/src/memmachine/test_instance_lru_cache.py b/src/memmachine/test_instance_lru_cache.py new file mode 100644 index 000000000..f49163abb --- /dev/null +++ b/src/memmachine/test_instance_lru_cache.py @@ -0,0 +1,257 @@ +""" +Unit test for the MemoryInstanceCache. +""" +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from memmachine.instance_lru_cache import MemoryInstanceCache + + +@pytest.fixture +def mock_episodic_memory(): + """Fixture to create a mock EpisodicMemory object with an async close method.""" + + def _create_mock(name: str): + mock_memory = MagicMock(name=name) + mock_memory.close = AsyncMock() + return mock_memory + + return _create_mock + + +def test_init_invalid_capacity(): + """Test that initializing with zero or negative capacity raises ValueError.""" + with pytest.raises(ValueError, match="Capacity must be a positive integer"): + MemoryInstanceCache(capacity=0, max_lifetime=60) + with pytest.raises(ValueError, match="Capacity must be a positive integer"): + MemoryInstanceCache(capacity=-1, max_lifetime=60) + + +def test_init_valid_capacity(): + """Test successful initialization.""" + cache = MemoryInstanceCache(capacity=2, max_lifetime=60) + assert cache.capacity == 2 + assert len(cache.cache) == 0 + + +@pytest.mark.asyncio +async def test_add_and_get(mock_episodic_memory): + """Test adding an item and then getting it.""" + cache = MemoryInstanceCache(capacity=2, max_lifetime=60) + mem1 = mock_episodic_memory("mem1") + + await cache.add("key1", mem1) + + assert "key1" in cache.keys() + assert cache.get_ref_count("key1") == 1 + + retrieved_mem = cache.get("key1") + assert retrieved_mem is mem1 + assert cache.get_ref_count("key1") == 2 + + +@pytest.mark.asyncio +async def test_add_existing_key_raises_error(mock_episodic_memory): + """Test that adding a key that already exists raises a ValueError.""" + cache = MemoryInstanceCache(capacity=2, max_lifetime=60) + mem1 = mock_episodic_memory("mem1") + await cache.add("key1", mem1) + + with pytest.raises(ValueError, match="Key key1 already exists"): + 
await cache.add("key1", mem1) + + +def test_get_nonexistent_key(): + """Test that getting a non-existent key returns None.""" + cache = MemoryInstanceCache(capacity=2, max_lifetime=60) + assert cache.get("nonexistent") is None + + +@pytest.mark.asyncio +async def test_put(mock_episodic_memory): + """Test the put method to decrease the reference count.""" + cache = MemoryInstanceCache(capacity=2, max_lifetime=60) + mem1 = mock_episodic_memory("mem1") + + await cache.add("key1", mem1) + assert cache.get_ref_count("key1") == 1 + + _ = cache.get("key1") + assert cache.get_ref_count("key1") == 2 + + cache.put("key1") + assert cache.get_ref_count("key1") == 1 + + cache.put("key1") + assert cache.get_ref_count("key1") == 0 + + +def test_put_nonexistent_key(): + """Test that calling put on a non-existent key raises a ValueError.""" + cache = MemoryInstanceCache(capacity=2, max_lifetime=60) + with pytest.raises(ValueError, match="Key key1 does not exist"): + cache.put("key1") + + +@pytest.mark.asyncio +async def test_put_below_zero_raises_assertion_error(mock_episodic_memory): + """Test that put raises an error if ref_count goes below zero.""" + cache = MemoryInstanceCache(capacity=2, max_lifetime=60) + mem1 = mock_episodic_memory("mem1") + await cache.add("key1", mem1) + + cache.put("key1") # ref_count becomes 0 + assert cache.get_ref_count("key1") == 0 + + with pytest.raises(AssertionError): + cache.put("key1") # Should fail as ref_count is already 0 + + +@pytest.mark.asyncio +async def test_lru_eviction(mock_episodic_memory): + """Test that the least recently used item is evicted when capacity is full.""" + cache = MemoryInstanceCache(capacity=2, max_lifetime=60) + mem1 = mock_episodic_memory("mem1") + mem2 = mock_episodic_memory("mem2") + mem3 = mock_episodic_memory("mem3") + + # Add two items + await cache.add("key1", mem1) + await cache.add("key2", mem2) + assert sorted(cache.keys()) == ["key1", "key2"] + + # Release both items + cache.put("key1") + cache.put("key2") + assert cache.get_ref_count("key1") == 0 + assert cache.get_ref_count("key2") == 0 + + # Add a third item, which should evict the LRU item ('key1') + await cache.add("key3", mem3) + + # Check that key1 is gone and its close method was called + assert "key1" not in cache.keys() + assert sorted(cache.keys()) == ["key2", "key3"] + mem1.close.assert_awaited_once() + mem2.close.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_lru_eviction_with_in_use_item(mock_episodic_memory): + """Test that an in-use (ref_count > 0) item is not evicted.""" + cache = MemoryInstanceCache(capacity=2, max_lifetime=60) + mem1 = mock_episodic_memory("mem1") + mem2 = mock_episodic_memory("mem2") + mem3 = mock_episodic_memory("mem3") + + await cache.add("key1", mem1) # LRU + await cache.add("key2", mem2) # MRU + + # key1 is in use, key2 is not + cache.put("key2") + assert cache.get_ref_count("key1") == 1 + assert cache.get_ref_count("key2") == 0 + + # Try to add key3. It should evict key2, not key1. 
+ await cache.add("key3", mem3) + + assert "key2" not in cache.keys() + assert sorted(cache.keys()) == ["key1", "key3"] + mem1.close.assert_not_awaited() + mem2.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_lru_order_on_get(mock_episodic_memory): + """Test that `get` moves an item to the most recently used position.""" + cache = MemoryInstanceCache(capacity=2, max_lifetime=60) + mem1 = mock_episodic_memory("mem1") + mem2 = mock_episodic_memory("mem2") + mem3 = mock_episodic_memory("mem3") + + await cache.add("key1", mem1) + await cache.add("key2", mem2) + + # Access key1, making it the MRU + _ = cache.get("key1") + + # Release all references + cache.put("key1") + cache.put("key1") + cache.put("key2") + + # Add key3. This should evict key2 (the new LRU) + await cache.add("key3", mem3) + + assert "key2" not in cache.keys() + assert sorted(cache.keys()) == ["key1", "key3"] + mem2.close.assert_awaited_once() + mem1.close.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_erase(mock_episodic_memory): + """Test the erase method.""" + cache = MemoryInstanceCache(capacity=2, max_lifetime=60) + mem1 = mock_episodic_memory("mem1") + await cache.add("key1", mem1) + + # Cannot erase while in use + with pytest.raises(RuntimeError, match="Key key1 is still in use 1"): + cache.erase("key1") + + # Release and then erase + cache.put("key1") + assert cache.get_ref_count("key1") == 0 + cache.erase("key1") + + assert "key1" not in cache.keys() + assert cache.get("key1") is None + + +def test_get_ref_count_nonexistent(): + """Test get_ref_count for a non-existent key returns -1.""" + cache = MemoryInstanceCache(capacity=2, max_lifetime=60) + assert cache.get_ref_count("nonexistent") == -1 + + +@pytest.mark.asyncio +async def test_keys(mock_episodic_memory): + """Test the keys method.""" + cache = MemoryInstanceCache(capacity=3, max_lifetime=60) + assert cache.keys() == [] + + await cache.add("key1", mock_episodic_memory("mem1")) + await cache.add("key2", mock_episodic_memory("mem2")) + + assert sorted(cache.keys()) == ["key1", "key2"] + + cache.put("key1") + cache.erase("key1") + + assert cache.keys() == ["key2"] + +@pytest.mark.asyncio +async def test_clean_old_instance(mock_episodic_memory): + """Test the clean_old_instance method.""" + cache = MemoryInstanceCache(capacity=4, max_lifetime=1) + mem1 = mock_episodic_memory("mem1") + mem2 = mock_episodic_memory("mem2") + + await cache.add("key1", mem1) + await cache.add("key2", mem2) + assert sorted(cache.keys()) == ["key1", "key2"] + assert cache.get_ref_count("key1") == 1 + assert cache.get_ref_count("key2") == 1 + await asyncio.sleep(2) + await cache.clean_old_instance() + # Would not delete item because of the reference + assert sorted(cache.keys()) == ["key1", "key2"] + cache.put("key1") + assert cache.get_ref_count("key1") == 0 + assert cache.get_ref_count("key2") == 1 + # Should remove key1 now + await cache.clean_old_instance() + assert sorted(cache.keys()) == ["key2"] diff --git a/src/memmachine/test_session_manager.py b/src/memmachine/test_session_manager.py new file mode 100644 index 000000000..5efc66082 --- /dev/null +++ b/src/memmachine/test_session_manager.py @@ -0,0 +1,249 @@ +from unittest.mock import MagicMock + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +from memmachine.common.metrics_factory import MetricsFactory +from memmachine.configuration.episodic_config import EpisodicMemoryParams +from memmachine.session_manager import 
SessionDataManagerImpl + + +@pytest.fixture +def mock_metrics_factory(): + """Fixture for a mocked MetricsFactory.""" + global MockMetricsFactory + class MockMetricsFactory(MetricsFactory): + def __init__(self): + self.counters = MagicMock() + self.gauge = MagicMock() + self.histogram = MagicMock() + self.summaries = MagicMock() + + def get_counter(self, name, description, label_names=...): + return self.counters + + def get_summary(self, name, description, label_names=...): + return self.summaries + + def get_gauge(self, name, description, label_names=...): + return self.gauge + + def get_histogram(self, name, description, label_names=...): + return self.histogram + + def __getstate__(self): + return {} + + def __setstate__(self, state): + pass + + + factory = MockMetricsFactory() + return factory + +@pytest.fixture +def episodic_memory_params(mock_metrics_factory): + """Fixture for a dummy EpisodicMemoryParams object.""" + return EpisodicMemoryParams( + session_key="test_session", + metrics_factory=mock_metrics_factory, + enabled=False + ) + + +@pytest_asyncio.fixture +async def db_engine(): + """Fixture for an in-memory SQLite async engine.""" + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def session_manager(db_engine: AsyncEngine): + """Fixture for SessionDataManagerImpl, with tables created.""" + manager = SessionDataManagerImpl(engine=db_engine) + await manager.create_tables() + yield manager + await manager.close() + + +@pytest.mark.asyncio +async def test_create_tables(db_engine: AsyncEngine): + """Test that create_tables creates the expected tables.""" + manager = SessionDataManagerImpl(engine=db_engine) + await manager.create_tables() + +@pytest.mark.asyncio +async def test_create_new_session(session_manager: SessionDataManagerImpl, episodic_memory_params: EpisodicMemoryParams): + """Test creating a new session successfully.""" + session_key = "session1" + config = {"key": "value"} + description = "A test session" + metadata = {"user": "tester"} + + await session_manager.create_new_session(session_key, config, episodic_memory_params, description, metadata) + + ret_config, ret_desc, ret_meta, ret_param = await session_manager.get_session_info(session_key) + + assert ret_config == config + assert ret_desc == description + assert ret_meta == metadata + assert ret_param.session_key == episodic_memory_params.session_key + + +@pytest.mark.asyncio +async def test_create_existing_session_raises_error(session_manager: SessionDataManagerImpl, episodic_memory_params: EpisodicMemoryParams): + """Test that creating a session that already exists raises a ValueError.""" + session_key = "session1" + await session_manager.create_new_session(session_key, {}, episodic_memory_params, "", {}) + + with pytest.raises(ValueError, match=f"Session {session_key} already exists"): + await session_manager.create_new_session(session_key, {}, episodic_memory_params, "", {}) + + +@pytest.mark.asyncio +async def test_delete_session(session_manager: SessionDataManagerImpl, episodic_memory_params: EpisodicMemoryParams): + """Test deleting an existing session.""" + session_key = "session_to_delete" + await session_manager.create_new_session(session_key, {}, episodic_memory_params, "", {}) + + # Verify it exists + await session_manager.get_session_info(session_key) + + # Delete it + await session_manager.delete_session(session_key) + + # Verify it's gone + with pytest.raises(ValueError, match=f"Session {session_key} does not 
exists"): + await session_manager.get_session_info(session_key) + + +@pytest.mark.asyncio +async def test_delete_nonexistent_session_raises_error(session_manager: SessionDataManagerImpl): + """Test that deleting a non-existent session raises a ValueError.""" + session_key = "nonexistent_session" + with pytest.raises(ValueError, match=f"Session {session_key} does not exists"): + await session_manager.delete_session(session_key) + + +@pytest.mark.asyncio +async def test_get_session_info_nonexistent_raises_error(session_manager: SessionDataManagerImpl): + """Test that getting info for a non-existent session raises a ValueError.""" + session_key = "nonexistent_session" + with pytest.raises(ValueError, match=f"Session {session_key} does not exists"): + await session_manager.get_session_info(session_key) + + +@pytest.mark.asyncio +async def test_get_sessions(session_manager: SessionDataManagerImpl, episodic_memory_params: EpisodicMemoryParams): + """Test retrieving session keys with and without filters.""" + # Create some sessions + await session_manager.create_new_session("session1", {}, episodic_memory_params, "", {"tag": "A", "user": "1"}) + await session_manager.create_new_session("session2", {}, episodic_memory_params, "", {"tag": "B", "user": "1"}) + await session_manager.create_new_session("session3", {}, episodic_memory_params, "", {"tag": "A", "user": "2"}) + + # Get all sessions + all_sessions = await session_manager.get_sessions() + assert sorted(all_sessions) == ["session1", "session2", "session3"] + + # Filter by tag 'A' + sessions_A = await session_manager.get_sessions(filter={"tag": "A"}) + assert sorted(sessions_A) == ["session1", "session3"] + + # Filter by user '1' + sessions_user1 = await session_manager.get_sessions(filter={"user": "1"}) + assert sorted(sessions_user1) == ["session1", "session2"] + + # Filter by tag 'B' and user '1' + sessions_B_user1 = await session_manager.get_sessions(filter={"tag": "B", "user": "1"}) + assert sessions_B_user1 == ["session2"] + + # Filter with no matches + no_match = await session_manager.get_sessions(filter={"tag": "C"}) + assert no_match == [] + + +@pytest.mark.asyncio +async def test_get_sessions_empty(session_manager: SessionDataManagerImpl): + """Test retrieving sessions when none exist.""" + sessions = await session_manager.get_sessions() + assert sessions == [] + + +@pytest.mark.asyncio +async def test_save_short_term_memory_new(session_manager: SessionDataManagerImpl, episodic_memory_params: EpisodicMemoryParams): + """Test saving short-term memory for a session for the first time.""" + session_key = "stm_session_1" + await session_manager.create_new_session(session_key, {}, episodic_memory_params, "", {}) + + summary = "This is a summary." 
+ last_seq = 10 + episode_num = 5 + + await session_manager.save_short_term_memory(session_key, summary, last_seq, episode_num) + + ret_summary, ret_ep_num, ret_last_seq = await session_manager.get_short_term_memory(session_key) + + assert ret_summary == summary + assert ret_last_seq == last_seq + assert ret_ep_num == episode_num + + +@pytest.mark.asyncio +async def test_save_short_term_memory_update(session_manager: SessionDataManagerImpl, episodic_memory_params: EpisodicMemoryParams): + """Test updating existing short-term memory for a session.""" + session_key = "stm_session_2" + await session_manager.create_new_session(session_key, {}, episodic_memory_params, "", {}) + + # First save + await session_manager.save_short_term_memory(session_key, "summary1", 1, 1) + + # Second save (update) + summary = "This is an updated summary." + last_seq = 20 + episode_num = 10 + await session_manager.save_short_term_memory(session_key, summary, last_seq, episode_num) + + ret_summary, ret_ep_num, ret_last_seq = await session_manager.get_short_term_memory(session_key) + + assert ret_summary == summary + assert ret_last_seq == last_seq + assert ret_ep_num == episode_num + + +@pytest.mark.asyncio +async def test_save_short_term_memory_for_nonexistent_session(session_manager: SessionDataManagerImpl): + """Test that saving STM for a non-existent session raises a ValueError.""" + session_key = "nonexistent_session" + with pytest.raises(ValueError, match=f"Session {session_key} does not exists"): + await session_manager.save_short_term_memory(session_key, "summary", 1, 1) + + +@pytest.mark.asyncio +async def test_get_short_term_memory_nonexistent(session_manager: SessionDataManagerImpl, episodic_memory_params: EpisodicMemoryParams): + """Test that getting STM for which none has been saved raises a ValueError.""" + session_key = "session_no_stm" + await session_manager.create_new_session(session_key, {}, episodic_memory_params, "", {}) + + with pytest.raises(ValueError, match=f"session {session_key} does not have short term memory"): + await session_manager.get_short_term_memory(session_key) + + +@pytest.mark.asyncio +async def test_delete_session_cascades_to_short_term_memory(session_manager: SessionDataManagerImpl, episodic_memory_params: EpisodicMemoryParams): + """Test that deleting a session also deletes its associated short-term memory data.""" + session_key = "cascade_delete_session" + await session_manager.create_new_session(session_key, {}, episodic_memory_params, "", {}) + await session_manager.save_short_term_memory(session_key, "summary", 1, 1) + + # Verify STM exists + await session_manager.get_short_term_memory(session_key) + + # Delete the parent session + await session_manager.delete_session(session_key) + + # Verify STM is also gone + with pytest.raises(ValueError, match=f"session {session_key} does not have short term memory"): + await session_manager.get_short_term_memory(session_key) diff --git a/tests/memmachine/episodic_memory/short_term_memory/test_session_memory.py b/tests/memmachine/episodic_memory/short_term_memory/test_session_memory.py deleted file mode 100644 index 5958a41b5..000000000 --- a/tests/memmachine/episodic_memory/short_term_memory/test_session_memory.py +++ /dev/null @@ -1,159 +0,0 @@ -import uuid -from datetime import datetime -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from memmachine.episodic_memory.data_types import ( - ContentType, - Episode, - MemoryContext, -) -from memmachine.episodic_memory.short_term_memory.session_memory import ( - 
SessionMemory, -) - - -def create_test_episode(**kwargs): - """Helper function to create a valid Episode for testing.""" - defaults = { - "uuid": uuid.uuid4(), - "episode_type": "message", - "content_type": ContentType.STRING, - "content": "default content", - "timestamp": datetime.now(), - "group_id": "group1", - "session_id": "session1", - "producer_id": "user1", - } - defaults.update(kwargs) - return Episode(**defaults) - - -@pytest.fixture -def mock_model(): - """Fixture for a mocked language model.""" - model = MagicMock() - model.generate_response = AsyncMock(return_value=["summary"]) - return model - - -@pytest.fixture -def memory_context(): - """Fixture for a sample MemoryContext.""" - return MemoryContext( - group_id="group1", - agent_id={"agent1"}, - user_id={"user1"}, - session_id="session1", - ) - - -@pytest.fixture -def memory(mock_model, memory_context): - """Fixture for a SessionMemory instance.""" - return SessionMemory( - model=mock_model, - summary_system_prompt="System prompt", - summary_user_prompt="User prompt: {episodes} {summary}", - capacity=3, - max_message_len=100, - max_token_num=50, - memory_context=memory_context, - ) - - -@pytest.mark.asyncio -class TestSessionMemoryPublicAPI: - """Test suite for the public API of SessionMemory.""" - - async def test_initial_state(self, memory): - """Test that the SessionMemory instance is initialized correctly.""" - episodes, summary = await memory.get_session_memory_context(query="test") - assert episodes == [] - assert summary == "" - - async def test_add_episode(self, memory): - """Test adding an episode to the session memory.""" - episode1 = create_test_episode(content="Hello") - await memory.add_episode(episode1) - - episodes, summary = await memory.get_session_memory_context(query="test") - # session memory is not full - assert episodes == [episode1] - assert summary == "" - - episode2 = create_test_episode(content="World") - await memory.add_episode(episode2) - - episodes, summary = await memory.get_session_memory_context(query="test") - assert episodes == [episode1, episode2] - assert summary == "" - - # session memory is full - episode3 = create_test_episode(content="!") - await memory.add_episode(episode3) - episodes, summary = await memory.get_session_memory_context(query="test") - assert episodes == [episode1, episode2, episode3] - assert summary == "summary" - - # New episode push out the oldest one: episode1 - episode4 = create_test_episode(content="?") - await memory.add_episode(episode4) - episodes, summary = await memory.get_session_memory_context(query="test") - assert episodes == [episode2, episode3, episode4] - assert summary == "summary" - - async def test_clear_memory(self, memory): - """Test clearing the memory.""" - await memory.add_episode(create_test_episode(content="test")) - - await memory.clear_memory() - - episodes, summary = await memory.get_session_memory_context(query="test") - assert episodes == [] - assert summary == "" - - async def test_close(self, memory): - """Test closing the memory.""" - await memory.add_episode(create_test_episode(content="test")) - await memory.close() - episodes, summary = await memory.get_session_memory_context(query="test") - assert episodes == [] - assert summary == "" - - async def test_get_session_memory_context(self, memory): - """Test retrieving session memory context.""" - ep1 = create_test_episode(content="a" * 20) # 5 tokens - ep2 = create_test_episode(content="b" * 20) # 5 tokens - ep3 = create_test_episode(content="c" * 20) # 5 tokens - await 
memory.add_episode(ep1)
-        await memory.add_episode(ep2)
-        await memory.add_episode(ep3)
-
-        # Test with token limit that fits all
-        # summary (5) + ep1 (5) + ep2 (5) + ep3 (5) = 20 tokens
-        episodes, summary = await memory.get_session_memory_context(
-            query="test", max_token_num=21
-        )
-        assert len(episodes) == 3
-        assert episodes == [ep1, ep2, ep3]
-        assert summary == "summary"
-
-        # Test with a tighter token limit. Episodes are retrieved newest first.
-        # length=5 (summary)
-        # add ep1 (5 tokens), length=10.
-        # add ep2 (5 tokens), length=15. Now length >= 14, so loop breaks.
-        # Should return [ep1, ep2]
-        episodes, summary = await memory.get_session_memory_context(
-            query="test", max_token_num=14
-        )
-        assert len(episodes) == 2
-        assert episodes == [ep2, ep3]
-
-        # Test with episode limit
-        episodes, summary = await memory.get_session_memory_context(
-            query="test", limit=1
-        )
-        assert len(episodes) == 1
-        assert episodes == [ep3]
diff --git a/tests/memmachine/episodic_memory/short_term_memory/test_short_term_memory.py b/tests/memmachine/episodic_memory/short_term_memory/test_short_term_memory.py
new file mode 100644
index 000000000..33aaeaede
--- /dev/null
+++ b/tests/memmachine/episodic_memory/short_term_memory/test_short_term_memory.py
@@ -0,0 +1,239 @@
+import uuid
+from datetime import datetime
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+import pytest_asyncio
+
+from memmachine.common.language_model import LanguageModel
+from memmachine.configuration.episodic_config import ShortTermMemoryParams, EpisodicMemoryParams
+from memmachine.episodic_memory.data_types import (
+    ContentType,
+    Episode,
+)
+from memmachine.session_manager_interface import SessionDataManager
+from memmachine.episodic_memory.short_term_memory.short_term_memory import (
+    ShortTermMemory,
+)
+
+
+def create_test_episode(**kwargs):
+    """Helper function to create a valid Episode for testing."""
+    defaults = {
+        "uuid": uuid.uuid4(),
+        "sequence_num": 1,
+        "session_key": "session1",
+        "episode_type": "message",
+        "content_type": ContentType.STRING,
+        "content": "default content",
+        "timestamp": datetime.now(),
+        "producer_id": "user1",
+        "producer_role": "user",
+        "produced_for_id": None,
+        "user_metadata": None,
+    }
+    defaults.update(kwargs)
+    return Episode(**defaults)
+
+
+class MockShortTermMemoryDataManager(SessionDataManager):
+    """Mock implementation of SessionDataManager for testing."""
+
+    def __init__(self):
+        self.data = {}
+        self.tables_created = False
+
+    async def create_tables(self):
+        self.tables_created = True
+
+    async def save_short_term_memory(
+        self, session_key: str, summary: str, episodes: list[Episode], seq: int, num: int
+    ):
+        self.data[session_key] = (summary, episodes, seq, num)
+
+    async def get_short_term_memory(
+        self, session_key: str
+    ) -> tuple[str, list[Episode], int, int]:
+        if session_key not in self.data:
+            raise ValueError(f"No data for session key {session_key}")
+        return self.data[session_key]
+
+    async def close(self):
+        self.data = {}
+        self.tables_created = False
+
+    async def create_new_session(
+        self, session_key: str, configuration: dict, param: EpisodicMemoryParams, description: str, metadata: dict
+    ):
+        pass
+
+    async def delete_session(self, session_key: str):
+        pass
+
+    async def get_session_info(self, session_key: str) -> tuple[dict, str, dict, 
EpisodicMemoryParams]: + pass + + async def get_sessions(self, filter: dict[str, str] | None = None) -> list[str]: + pass + +class MockLanguageModel(MagicMock, LanguageModel): + """Mock implementation of LanguageModel for testing.""" + + async def generate_response( + self, + system_prompt: str | None = None, + user_prompt: str | None = None, + tools: list | None = None, + tool_choice: str | dict[str, str] = "", + max_attempts: int = 1, + ) -> tuple[str, Any]: + return "summary", "" + + +@pytest.fixture +def mock_model(): + """Fixture for a mocked language model.""" + model = MockLanguageModel() + model.generate_response = AsyncMock(return_value=["summary"]) + return model + + +@pytest.fixture +def mock_data_manager(): + """Fixture for a mocked ShortTermMemoryDataManager.""" + return MockShortTermMemoryDataManager() + + +@pytest.fixture +def short_term_memory_param(mock_model, mock_data_manager): + """Fixture for short_term_memory_params.""" + return ShortTermMemoryParams( + session_key="session1", + llm_model=mock_model, + data_manager=mock_data_manager, + summary_prompt_system="System prompt", + summary_prompt_user="User prompt: {episodes} {summary} {max_length}", + message_capacity=16, + ) + + +@pytest_asyncio.fixture +async def memory(short_term_memory_param): + """Fixture for a SessionMemory instance.""" + return await ShortTermMemory.create(short_term_memory_param) + + +@pytest.mark.asyncio +class TestSessionMemoryPublicAPI: + """Test suite for the public API of SessionMemory.""" + + async def test_initial_state(self, memory): + """Test that the SessionMemory instance is initialized correctly.""" + episodes, summary = await memory.get_session_memory_context(query="test") + assert episodes == [] + assert summary == "" + + async def test_add_episode(self, memory): + """Test adding an episode to the session memory.""" + episode1 = create_test_episode(content="Hello") + await memory.add_episode(episode1) + + episodes, summary = await memory.get_session_memory_context(query="test") + # session memory is not full + assert episodes == [episode1] + assert summary == "" + + episode2 = create_test_episode(content="World") + await memory.add_episode(episode2) + + episodes, summary = await memory.get_session_memory_context(query="test") + assert episodes == [episode1, episode2] + assert summary == "" + + # session memory is full + episode3 = create_test_episode(content="!" 
* 7)
+        await memory.add_episode(episode3)
+        episodes, summary = await memory.get_session_memory_context(query="test")
+        assert episodes == [episode1, episode2, episode3]
+        assert summary == "summary"
+
+        # New episode pushes out the oldest episodes (episode1 and episode2)
+        episode4 = create_test_episode(content="??")
+        await memory.add_episode(episode4)
+        episodes, summary = await memory.get_session_memory_context(query="test")
+        assert episodes == [episode3, episode4]
+        assert summary == "summary"
+
+    async def test_clear_memory(self, memory):
+        """Test clearing the memory."""
+        await memory.add_episode(create_test_episode(content="test"))
+
+        await memory.clear_memory()
+
+        episodes, summary = await memory.get_session_memory_context(query="test")
+        assert episodes == []
+        assert summary == ""
+
+    async def test_delete_episode(self, memory):
+        """Test deleting an episode from the memory."""
+        ep1 = create_test_episode(content="a")
+        ep2 = create_test_episode(content="b")
+        ep3 = create_test_episode(content="c")
+        await memory.add_episode(ep1)
+        await memory.add_episode(ep2)
+        await memory.add_episode(ep3)
+
+        await memory.delete_episode(ep2.uuid)
+        episodes, _ = await memory.get_session_memory_context(query="test")
+        assert episodes == [ep1, ep3]
+
+    async def test_close(self, memory):
+        """Test closing the memory."""
+        await memory.add_episode(create_test_episode(content="test"))
+        await memory.close()
+        with pytest.raises(RuntimeError):
+            await memory.add_episode(create_test_episode())
+        with pytest.raises(RuntimeError):
+            _, _ = await memory.get_session_memory_context(query="test")
+
+    async def test_get_session_memory_context(self, memory):
+        """Test retrieving session memory context."""
+        ep1 = create_test_episode(content="a" * 6)
+        ep2 = create_test_episode(content="b" * 6)
+        ep3 = create_test_episode(content="c" * 6)
+        await memory.add_episode(ep1)
+        await memory.add_episode(ep2)
+        await memory.add_episode(ep3)
+
+        # Test with message length limit that fits all
+        episodes, summary = await memory.get_session_memory_context(
+            query="test", max_message_length=100
+        )
+        assert len(episodes) == 3
+        assert episodes == [ep1, ep2, ep3]
+        assert summary == "summary"
+
+        # Test with a tighter message length limit. Episodes are gathered
+        # newest first:
+        # length=7 (summary),
+        # add ep3 (length 6), length=13,
+        # add ep2 (length 6), length=19. Now length >= 19, so the loop breaks.
+        # Should return [ep2, ep3] in chronological order.
+        episodes, summary = await memory.get_session_memory_context(
+            query="test", max_message_length=19
+        )
+        assert len(episodes) == 2
+        assert episodes == [ep2, ep3]
+
+        # Test with episode limit
+        episodes, summary = await memory.get_session_memory_context(
+            query="test", limit=1
+        )
+        assert len(episodes) == 1
+        assert episodes == [ep3]
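
Taken together, the pieces above compose roughly as follows. A minimal end-to-end sketch, assuming an in-memory SQLite engine; NoopMetricsFactory is a hypothetical stand-in modeled on the MockMetricsFactory test double, and every other name appears in the diffs above:

    import asyncio
    from unittest.mock import MagicMock

    from sqlalchemy.ext.asyncio import create_async_engine

    from memmachine.common.metrics_factory import MetricsFactory
    from memmachine.configuration.episodic_config import (
        EpisodicMemoryManagerParam,
        EpisodicMemoryParams,
    )
    from memmachine.episodic_memory_manager import EpisodicMemoryManager
    from memmachine.session_manager import SessionDataManagerImpl


    class NoopMetricsFactory(MetricsFactory):
        # Assumed no-op implementation; real deployments would wire in a
        # concrete MetricsFactory instead.
        def get_counter(self, name, description, label_names=...):
            return MagicMock()

        def get_summary(self, name, description, label_names=...):
            return MagicMock()

        def get_gauge(self, name, description, label_names=...):
            return MagicMock()

        def get_histogram(self, name, description, label_names=...):
            return MagicMock()

        def reset(self):
            pass


    async def main():
        engine = create_async_engine("sqlite+aiosqlite:///:memory:")
        storage = SessionDataManagerImpl(engine=engine)
        await storage.create_tables()

        manager = EpisodicMemoryManager(
            param=EpisodicMemoryManagerParam(session_storage=storage)
        )
        params = EpisodicMemoryParams(
            session_key="demo",
            metrics_factory=NoopMetricsFactory(),
            enabled=False,  # enabled=False skips the short/long-term sub-configs
        )
        # create_episodic_memory is an async context manager: the instance is
        # cached while the block runs and its ref count drops to 0 on exit.
        async with manager.create_episodic_memory("demo", params, "demo session", {}):
            pass
        await manager.close()
        # Storage lifecycle is managed separately here (assumption).
        await storage.close()


    asyncio.run(main())
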