diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py
index 33d880f30..d2b2337e7 100644
--- a/llama_stack/providers/registry/tool_runtime.py
+++ b/llama_stack/providers/registry/tool_runtime.py
@@ -81,4 +81,13 @@ def available_providers() -> List[ProviderSpec]:
                 pip_packages=["mcp"],
             ),
         ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="mem0",
+                module="llama_stack.providers.remote.tool_runtime.mem0_memory",
+                config_class="llama_stack.providers.remote.tool_runtime.mem0_memory.config.Mem0ToolRuntimeConfig",
+                pip_packages=["mem0"],
+            ),
+        ),
     ]
diff --git a/llama_stack/providers/remote/tool_runtime/mem0_memory/__init__.py b/llama_stack/providers/remote/tool_runtime/mem0_memory/__init__.py
index d5fd1b928..18f99e5ad 100644
--- a/llama_stack/providers/remote/tool_runtime/mem0_memory/__init__.py
+++ b/llama_stack/providers/remote/tool_runtime/mem0_memory/__init__.py
@@ -12,7 +12,7 @@ from .config import Mem0ToolRuntimeConfig
 from .memory import Mem0MemoryToolRuntimeImpl
 
 
-async def get_provider_impl(config: Mem0ToolRuntimeConfig, deps: Dict[str, Any]):
-    impl = Mem0MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
+async def get_adapter_impl(config: Mem0ToolRuntimeConfig, _deps):
+    impl = Mem0MemoryToolRuntimeImpl(config)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/tool_runtime/mem0_memory/config.py b/llama_stack/providers/remote/tool_runtime/mem0_memory/config.py
index f4267b8e7..f1ff9c2fc 100644
--- a/llama_stack/providers/remote/tool_runtime/mem0_memory/config.py
+++ b/llama_stack/providers/remote/tool_runtime/mem0_memory/config.py
@@ -13,7 +13,7 @@ class Mem0ToolRuntimeConfig(BaseModel):
     """Configuration for Mem0 Tool Runtime"""
 
     host: Optional[str] = "https://api.mem0.ai"
-    api_key: str
+    api_key: Optional[str] = None
     top_k: int = 10
     org_id: Optional[str] = None
     project_id: Optional[str] = None
diff --git a/llama_stack/providers/remote/tool_runtime/mem0_memory/memory.py b/llama_stack/providers/remote/tool_runtime/mem0_memory/memory.py
index 9c8c9478e..9c0b5c585 100644
--- a/llama_stack/providers/remote/tool_runtime/mem0_memory/memory.py
+++ b/llama_stack/providers/remote/tool_runtime/mem0_memory/memory.py
@@ -53,12 +53,8 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
     def __init__(
         self,
         config: Mem0ToolRuntimeConfig,
-        vector_io_api: VectorIO,
-        inference_api: Inference,
     ):
         self.config = config
-        self.vector_io_api = vector_io_api
-        self.inference_api = inference_api
 
         # Mem0 API configuration
         self.api_base_url = config.host
@@ -69,7 +65,7 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
         # Validate configuration
         if not self.api_key:
             raise ValueError("Mem0 API Key not provided")
-        if not (self.org_id and self.project_id):
+        if (self.org_id is not None) != (self.project_id is not None):
             raise ValueError("Both org_id and project_id must be provided")
 
         # Setup headers
@@ -116,10 +112,11 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
                 payload = {
                     "messages": [{"role": "user", "content": content}],
                     "metadata": {"document_id": doc.document_id},
-                    "org_id": self.org_id,
-                    "project_id": self.project_id,
                     "user_id": vector_db_id,
                 }
+                if self.org_id and self.project_id:
+                    payload["org_id"] = self.org_id
+                    payload["project_id"] = self.project_id
 
                 response = requests.post(
                     urljoin(self.api_base_url, "/v1/memories/"),
@@ -127,6 +124,7 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
                     json=payload,
                     timeout=60
                 )
+                print(response.json())
                 response.raise_for_status()
             except requests.exceptions.RequestException as e:
                 log.error(f"Failed to insert document to Mem0: {str(e)}")
@@ -144,11 +142,6 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
         if not chunks:
             return
 
-        await self.vector_io_api.insert_chunks(
-            chunks=chunks,
-            vector_db_id=vector_db_id,
-        )
-
     async def query(
         self,
         content: InterleavedContent,
@@ -159,39 +152,40 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
             return RAGQueryResult(content=None)
 
         query_config = query_config or RAGQueryConfig()
-        query = await generate_rag_query(
-            query_config.query_generator_config,
-            content,
-            inference_api=self.inference_api,
-        )
+        query = content
+        print(query)
 
         # Search Mem0 memory via API
         mem0_chunks = []
-        try:
-            payload = {
-                "query": query,
-                "org_id": self.org_id,
-                "project_id": self.project_id,
-            }
+        for vector_db_id in vector_db_ids:
+            try:
+                payload = {
+                    "query": query,
+                    "user_id": vector_db_id
+                }
+                if self.org_id and self.project_id:
+                    payload["org_id"] = self.org_id
+                    payload["project_id"] = self.project_id
 
-            response = requests.post(
-                urljoin(self.api_base_url, "/v1/memories/search/"),
-                headers=self.headers,
-                json=payload,
-                timeout=60
-            )
-            response.raise_for_status()
-
-            mem0_results = response.json()
-            mem0_chunks = [
-                TextContentItem(
-                    text=f"id:{result.get('metadata', {}).get('document_id', 'unknown')}; content:{result.get('memory', '')}"
+                response = requests.post(
+                    urljoin(self.api_base_url, "/v1/memories/search/"),
+                    headers=self.headers,
+                    json=payload,
+                    timeout=60
                 )
-                for result in mem0_results
-            ]
-        except requests.exceptions.RequestException as e:
-            log.error(f"Failed to search Mem0: {str(e)}")
-            # Continue with vector store search even if Mem0 fails
+                print(response.json())
+                response.raise_for_status()
+
+                mem0_results = response.json()
+                mem0_chunks = [
+                    TextContentItem(
+                        text=f"id:{result.get('metadata', {}).get('document_id', 'unknown')}; content:{result.get('memory', '')}"
+                    )
+                    for result in mem0_results
+                ]
+            except requests.exceptions.RequestException as e:
+                log.error(f"Failed to search Mem0: {str(e)}")
+                # Continue with vector store search even if Mem0 fails
 
         if not mem0_chunks:
             return RAGQueryResult(content=None)
@@ -216,12 +210,12 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
         # encountering fatals.
         return [
             ToolDef(
-                name="query_from_memory",
-                description="Retrieve context from memory",
+                name="query_from_mem0",
+                description="Retrieve context from mem0",
             ),
             ToolDef(
-                name="insert_into_memory",
-                description="Insert documents into memory",
+                name="insert_into_mem0",
+                description="Insert documents into mem0",
             ),
         ]
 
diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml
index 0fee6808c..4e5a95ce8 100644
--- a/llama_stack/templates/ollama/build.yaml
+++ b/llama_stack/templates/ollama/build.yaml
@@ -28,4 +28,5 @@ distribution_spec:
     - remote::tavily-search
     - inline::code-interpreter
     - inline::rag-runtime
+    - remote::mem0
 image_type: conda
diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py
index d14cb3aad..bbded4c93 100644
--- a/llama_stack/templates/ollama/ollama.py
+++ b/llama_stack/templates/ollama/ollama.py
@@ -36,6 +36,7 @@ def get_distribution_template() -> DistributionTemplate:
             "remote::tavily-search",
             "inline::code-interpreter",
             "inline::rag-runtime",
+            "remote::mem0",
         ],
     }
     name = "ollama"
diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml
index 5b5c9c253..4933b2ad2 100644
--- a/llama_stack/templates/ollama/run-with-safety.yaml
+++ b/llama_stack/templates/ollama/run-with-safety.yaml
@@ -121,3 +121,5 @@ tool_groups:
   provider_id: rag-runtime
 - toolgroup_id: builtin::code_interpreter
   provider_id: code-interpreter
+- toolgroup_id: builtin::rag
+  provider_id: mem0
\ No newline at end of file
diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml
index 3cc1cb2ac..5da90472b 100644
--- a/llama_stack/templates/ollama/run.yaml
+++ b/llama_stack/templates/ollama/run.yaml
@@ -69,6 +69,9 @@ providers:
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
   tool_runtime:
+  - provider_id: mem0
+    provider_type: remote::mem0
+    config: {}
   - provider_id: brave-search
     provider_type: remote::brave-search
     config:
@@ -110,3 +113,5 @@ tool_groups:
   provider_id: rag-runtime
 - toolgroup_id: builtin::code_interpreter
   provider_id: code-interpreter
+- toolgroup_id: builtin::rag
+  provider_id: mem0
\ No newline at end of file
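
A minimal usage sketch (not part of the change set) of the adapter wired up above, assuming MEM0_API_KEY is exported in the environment; the import paths mirror the module and config_class registered in tool_runtime.py, while the asyncio driver and the empty deps dict are illustrative only:

# Hedged sketch: construct the config and adapter the same way the registry
# entry would. Assumes MEM0_API_KEY is set; org_id/project_id stay None,
# which satisfies the paired-or-absent validation added in memory.py.
import asyncio
import os

from llama_stack.providers.remote.tool_runtime.mem0_memory import get_adapter_impl
from llama_stack.providers.remote.tool_runtime.mem0_memory.config import Mem0ToolRuntimeConfig


async def main() -> None:
    config = Mem0ToolRuntimeConfig(api_key=os.environ["MEM0_API_KEY"])
    # The new get_adapter_impl ignores its deps argument, so an empty dict suffices.
    impl = await get_adapter_impl(config, {})
    print(type(impl).__name__)  # Mem0MemoryToolRuntimeImpl


asyncio.run(main())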