mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 02:58:21 +00:00
run with ollama
This commit is contained in:
parent
b7e1fc0090
commit
55612cb252
8 changed files with 59 additions and 47 deletions
|
@ -81,4 +81,13 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
pip_packages=["mcp"],
|
pip_packages=["mcp"],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.tool_runtime,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_type="mem0",
|
||||||
|
module="llama_stack.providers.remote.tool_runtime.mem0_memory",
|
||||||
|
config_class="llama_stack.providers.remote.tool_runtime.mem0_memory.config.Mem0ToolRuntimeConfig",
|
||||||
|
pip_packages=["mem0"],
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -12,7 +12,7 @@ from .config import Mem0ToolRuntimeConfig
|
||||||
from .memory import Mem0MemoryToolRuntimeImpl
|
from .memory import Mem0MemoryToolRuntimeImpl
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: Mem0ToolRuntimeConfig, deps: Dict[str, Any]):
|
async def get_adapter_impl(config: Mem0ToolRuntimeConfig, _deps):
|
||||||
impl = Mem0MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
|
impl = Mem0MemoryToolRuntimeImpl(config)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -13,7 +13,7 @@ class Mem0ToolRuntimeConfig(BaseModel):
|
||||||
"""Configuration for Mem0 Tool Runtime"""
|
"""Configuration for Mem0 Tool Runtime"""
|
||||||
|
|
||||||
host: Optional[str] = "https://api.mem0.ai"
|
host: Optional[str] = "https://api.mem0.ai"
|
||||||
api_key: str
|
api_key: Optional[str] = None
|
||||||
top_k: int = 10
|
top_k: int = 10
|
||||||
org_id: Optional[str] = None
|
org_id: Optional[str] = None
|
||||||
project_id: Optional[str] = None
|
project_id: Optional[str] = None
|
||||||
|
|
|
@ -53,12 +53,8 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Mem0ToolRuntimeConfig,
|
config: Mem0ToolRuntimeConfig,
|
||||||
vector_io_api: VectorIO,
|
|
||||||
inference_api: Inference,
|
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.vector_io_api = vector_io_api
|
|
||||||
self.inference_api = inference_api
|
|
||||||
|
|
||||||
# Mem0 API configuration
|
# Mem0 API configuration
|
||||||
self.api_base_url = config.host
|
self.api_base_url = config.host
|
||||||
|
@ -69,7 +65,7 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
|
||||||
# Validate configuration
|
# Validate configuration
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("Mem0 API Key not provided")
|
raise ValueError("Mem0 API Key not provided")
|
||||||
if not (self.org_id and self.project_id):
|
if (self.org_id is not None) != (self.project_id is not None):
|
||||||
raise ValueError("Both org_id and project_id must be provided")
|
raise ValueError("Both org_id and project_id must be provided")
|
||||||
|
|
||||||
# Setup headers
|
# Setup headers
|
||||||
|
@ -116,10 +112,11 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
|
||||||
payload = {
|
payload = {
|
||||||
"messages": [{"role": "user", "content": content}],
|
"messages": [{"role": "user", "content": content}],
|
||||||
"metadata": {"document_id": doc.document_id},
|
"metadata": {"document_id": doc.document_id},
|
||||||
"org_id": self.org_id,
|
|
||||||
"project_id": self.project_id,
|
|
||||||
"user_id": vector_db_id,
|
"user_id": vector_db_id,
|
||||||
}
|
}
|
||||||
|
if self.org_id and self.project_id:
|
||||||
|
payload["org_id"] = self.org_id
|
||||||
|
payload["project_id"] = self.project_id
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
urljoin(self.api_base_url, "/v1/memories/"),
|
urljoin(self.api_base_url, "/v1/memories/"),
|
||||||
|
@ -127,6 +124,7 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
|
||||||
json=payload,
|
json=payload,
|
||||||
timeout=60
|
timeout=60
|
||||||
)
|
)
|
||||||
|
print(response.json())
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
log.error(f"Failed to insert document to Mem0: {str(e)}")
|
log.error(f"Failed to insert document to Mem0: {str(e)}")
|
||||||
|
@ -144,11 +142,6 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return
|
return
|
||||||
|
|
||||||
await self.vector_io_api.insert_chunks(
|
|
||||||
chunks=chunks,
|
|
||||||
vector_db_id=vector_db_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
self,
|
self,
|
||||||
content: InterleavedContent,
|
content: InterleavedContent,
|
||||||
|
@ -159,39 +152,40 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
|
||||||
return RAGQueryResult(content=None)
|
return RAGQueryResult(content=None)
|
||||||
|
|
||||||
query_config = query_config or RAGQueryConfig()
|
query_config = query_config or RAGQueryConfig()
|
||||||
query = await generate_rag_query(
|
query = content
|
||||||
query_config.query_generator_config,
|
print(query)
|
||||||
content,
|
|
||||||
inference_api=self.inference_api,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Search Mem0 memory via API
|
# Search Mem0 memory via API
|
||||||
mem0_chunks = []
|
mem0_chunks = []
|
||||||
try:
|
for vector_db_id in vector_db_ids:
|
||||||
payload = {
|
try:
|
||||||
"query": query,
|
payload = {
|
||||||
"org_id": self.org_id,
|
"query": query,
|
||||||
"project_id": self.project_id,
|
"user_id": vector_db_id
|
||||||
}
|
}
|
||||||
|
if self.org_id and self.project_id:
|
||||||
|
payload["org_id"] = self.org_id
|
||||||
|
payload["project_id"] = self.project_id
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
urljoin(self.api_base_url, "/v1/memories/search/"),
|
urljoin(self.api_base_url, "/v1/memories/search/"),
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
json=payload,
|
json=payload,
|
||||||
timeout=60
|
timeout=60
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
mem0_results = response.json()
|
|
||||||
mem0_chunks = [
|
|
||||||
TextContentItem(
|
|
||||||
text=f"id:{result.get('metadata', {}).get('document_id', 'unknown')}; content:{result.get('memory', '')}"
|
|
||||||
)
|
)
|
||||||
for result in mem0_results
|
print(response.json())
|
||||||
]
|
response.raise_for_status()
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
log.error(f"Failed to search Mem0: {str(e)}")
|
mem0_results = response.json()
|
||||||
# Continue with vector store search even if Mem0 fails
|
mem0_chunks = [
|
||||||
|
TextContentItem(
|
||||||
|
text=f"id:{result.get('metadata', {}).get('document_id', 'unknown')}; content:{result.get('memory', '')}"
|
||||||
|
)
|
||||||
|
for result in mem0_results
|
||||||
|
]
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
log.error(f"Failed to search Mem0: {str(e)}")
|
||||||
|
# Continue with vector store search even if Mem0 fails
|
||||||
|
|
||||||
if not mem0_chunks:
|
if not mem0_chunks:
|
||||||
return RAGQueryResult(content=None)
|
return RAGQueryResult(content=None)
|
||||||
|
@ -216,12 +210,12 @@ class Mem0MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntim
|
||||||
# encountering fatals.
|
# encountering fatals.
|
||||||
return [
|
return [
|
||||||
ToolDef(
|
ToolDef(
|
||||||
name="query_from_memory",
|
name="query_from_mem0",
|
||||||
description="Retrieve context from memory",
|
description="Retrieve context from mem0",
|
||||||
),
|
),
|
||||||
ToolDef(
|
ToolDef(
|
||||||
name="insert_into_memory",
|
name="insert_into_mem0",
|
||||||
description="Insert documents into memory",
|
description="Insert documents into mem0",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -28,4 +28,5 @@ distribution_spec:
|
||||||
- remote::tavily-search
|
- remote::tavily-search
|
||||||
- inline::code-interpreter
|
- inline::code-interpreter
|
||||||
- inline::rag-runtime
|
- inline::rag-runtime
|
||||||
|
- remote::mem0
|
||||||
image_type: conda
|
image_type: conda
|
||||||
|
|
|
@ -36,6 +36,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
"remote::tavily-search",
|
"remote::tavily-search",
|
||||||
"inline::code-interpreter",
|
"inline::code-interpreter",
|
||||||
"inline::rag-runtime",
|
"inline::rag-runtime",
|
||||||
|
"remote::mem0",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
name = "ollama"
|
name = "ollama"
|
||||||
|
|
|
@ -121,3 +121,5 @@ tool_groups:
|
||||||
provider_id: rag-runtime
|
provider_id: rag-runtime
|
||||||
- toolgroup_id: builtin::code_interpreter
|
- toolgroup_id: builtin::code_interpreter
|
||||||
provider_id: code-interpreter
|
provider_id: code-interpreter
|
||||||
|
- toolgroup_id: builtin::rag
|
||||||
|
provider_id: mem0
|
|
@ -69,6 +69,9 @@ providers:
|
||||||
config:
|
config:
|
||||||
openai_api_key: ${env.OPENAI_API_KEY:}
|
openai_api_key: ${env.OPENAI_API_KEY:}
|
||||||
tool_runtime:
|
tool_runtime:
|
||||||
|
- provider_id: mem0
|
||||||
|
provider_type: remote::mem0
|
||||||
|
config: {}
|
||||||
- provider_id: brave-search
|
- provider_id: brave-search
|
||||||
provider_type: remote::brave-search
|
provider_type: remote::brave-search
|
||||||
config:
|
config:
|
||||||
|
@ -110,3 +113,5 @@ tool_groups:
|
||||||
provider_id: rag-runtime
|
provider_id: rag-runtime
|
||||||
- toolgroup_id: builtin::code_interpreter
|
- toolgroup_id: builtin::code_interpreter
|
||||||
provider_id: code-interpreter
|
provider_id: code-interpreter
|
||||||
|
- toolgroup_id: builtin::rag
|
||||||
|
provider_id: mem0
|
Loading…
Add table
Add a link
Reference in a new issue