run with ollama

This commit is contained in:
Dev-Khant 2025-02-03 13:16:15 +05:30
parent b7e1fc0090
commit 55612cb252
8 changed files with 59 additions and 47 deletions

View file

@ -81,4 +81,13 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=["mcp"],
),
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="mem0",
module="llama_stack.providers.remote.tool_runtime.mem0_memory",
config_class="llama_stack.providers.remote.tool_runtime.mem0_memory.config.Mem0ToolRuntimeConfig",
pip_packages=["mem0"],
),
),
]

View file

@ -12,7 +12,7 @@ from .config import Mem0ToolRuntimeConfig
from .memory import Mem0MemoryToolRuntimeImpl
async def get_provider_impl(config: Mem0ToolRuntimeConfig, deps: Dict[str, Any]):
impl = Mem0MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
async def get_adapter_impl(config: Mem0ToolRuntimeConfig, _deps):
impl = Mem0MemoryToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -13,7 +13,7 @@ class Mem0ToolRuntimeConfig(BaseModel):
"""Configuration for Mem0 Tool Runtime"""
host: Optional[str] = "https://api.mem0.ai"
api_key: str
api_key: Optional[str] = None
top_k: int = 10
org_id: Optional[str] = None
project_id: Optional[str] = None

View file

@ -53,12 +53,8 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
def __init__(
self,
config: Mem0ToolRuntimeConfig,
vector_io_api: VectorIO,
inference_api: Inference,
):
self.config = config
self.vector_io_api = vector_io_api
self.inference_api = inference_api
# Mem0 API configuration
self.api_base_url = config.host
@ -69,7 +65,7 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
# Validate configuration
if not self.api_key:
raise ValueError("Mem0 API Key not provided")
if not (self.org_id and self.project_id):
if (self.org_id is not None) != (self.project_id is not None):
raise ValueError("Both org_id and project_id must be provided")
# Setup headers
@ -116,10 +112,11 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
payload = {
"messages": [{"role": "user", "content": content}],
"metadata": {"document_id": doc.document_id},
"org_id": self.org_id,
"project_id": self.project_id,
"user_id": vector_db_id,
}
if self.org_id and self.project_id:
payload["org_id"] = self.org_id
payload["project_id"] = self.project_id
response = requests.post(
urljoin(self.api_base_url, "/v1/memories/"),
@ -127,6 +124,7 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
json=payload,
timeout=60
)
print(response.json())
response.raise_for_status()
except requests.exceptions.RequestException as e:
log.error(f"Failed to insert document to Mem0: {str(e)}")
@ -144,11 +142,6 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
if not chunks:
return
await self.vector_io_api.insert_chunks(
chunks=chunks,
vector_db_id=vector_db_id,
)
async def query(
self,
content: InterleavedContent,
@ -159,39 +152,40 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
return RAGQueryResult(content=None)
query_config = query_config or RAGQueryConfig()
query = await generate_rag_query(
query_config.query_generator_config,
content,
inference_api=self.inference_api,
)
query = content
print(query)
# Search Mem0 memory via API
mem0_chunks = []
try:
payload = {
"query": query,
"org_id": self.org_id,
"project_id": self.project_id,
}
for vector_db_id in vector_db_ids:
try:
payload = {
"query": query,
"user_id": vector_db_id
}
if self.org_id and self.project_id:
payload["org_id"] = self.org_id
payload["project_id"] = self.project_id
response = requests.post(
urljoin(self.api_base_url, "/v1/memories/search/"),
headers=self.headers,
json=payload,
timeout=60
)
response.raise_for_status()
mem0_results = response.json()
mem0_chunks = [
TextContentItem(
text=f"id:{result.get('metadata', {}).get('document_id', 'unknown')}; content:{result.get('memory', '')}"
response = requests.post(
urljoin(self.api_base_url, "/v1/memories/search/"),
headers=self.headers,
json=payload,
timeout=60
)
for result in mem0_results
]
except requests.exceptions.RequestException as e:
log.error(f"Failed to search Mem0: {str(e)}")
# Continue with vector store search even if Mem0 fails
print(response.json())
response.raise_for_status()
mem0_results = response.json()
mem0_chunks = [
TextContentItem(
text=f"id:{result.get('metadata', {}).get('document_id', 'unknown')}; content:{result.get('memory', '')}"
)
for result in mem0_results
]
except requests.exceptions.RequestException as e:
log.error(f"Failed to search Mem0: {str(e)}")
# Continue with the remaining vector_db_ids even if this Mem0 search fails
if not mem0_chunks:
return RAGQueryResult(content=None)
@ -216,12 +210,12 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
# encountering fatals.
return [
ToolDef(
name="query_from_memory",
description="Retrieve context from memory",
name="query_from_mem0",
description="Retrieve context from mem0",
),
ToolDef(
name="insert_into_memory",
description="Insert documents into memory",
name="insert_into_mem0",
description="Insert documents into mem0",
),
]

View file

@ -28,4 +28,5 @@ distribution_spec:
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime
- remote::mem0
image_type: conda

View file

@ -36,6 +36,7 @@ def get_distribution_template() -> DistributionTemplate:
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::mem0",
],
}
name = "ollama"

View file

@ -121,3 +121,5 @@ tool_groups:
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
- toolgroup_id: builtin::rag
provider_id: mem0

View file

@ -69,6 +69,9 @@ providers:
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: mem0
provider_type: remote::mem0
config: {}
- provider_id: brave-search
provider_type: remote::brave-search
config:
@ -110,3 +113,5 @@ tool_groups:
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
- toolgroup_id: builtin::rag
provider_id: mem0