From 1a7490470a4622f53f17485a548899b5cf501396 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 22 Jan 2025 10:04:16 -0800 Subject: [PATCH] [memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832) See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. Third part: - we need to make `tool_runtime.rag_tool.query_context()` and `tool_runtime.rag_tool.insert_documents()` methods work smoothly with complete type safety. To that end, we introduce a sub-resource path `tool-runtime/rag-tool/` and make changes to the resolver to make things work. - the PR updates the agents implementation to directly call these typed APIs for memory accesses rather than going through the complex, untyped "invoke_tool" API. the code looks much nicer and simpler (expectedly.) - there are a number of hacks in the server resolver implementation still, we will live with some and fix some Note that we must make sure the client SDKs are able to handle this subresource complexity also. Stainless has support for subresources, so this should be possible but beware. ## Test Plan Our RAG test is sad (doesn't actually test for actual RAG output) but I verified that the implementation works. I will work on fixing the RAG test afterwards. ```bash pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B ``` --- .../openapi_generator/pyopenapi/operations.py | 6 + docs/resources/llama-stack-spec.html | 1191 +++++++++-------- docs/resources/llama-stack-spec.yaml | 997 +++++++------- llama_stack/apis/tools/__init__.py | 1 + llama_stack/apis/tools/rag_tool.py | 95 ++ llama_stack/apis/tools/tools.py | 10 +- llama_stack/distribution/resolver.py | 2 + llama_stack/distribution/routers/routers.py | 46 +- llama_stack/distribution/server/endpoints.py | 22 +- llama_stack/distribution/stack.py | 3 +- llama_stack/distribution/store/registry.py | 2 +- .../inline/agents/meta_reference/__init__.py | 3 +- .../agents/meta_reference/agent_instance.py | 95 +- .../inline/agents/meta_reference/agents.py | 12 +- .../code_interpreter/code_interpreter.py | 4 +- .../inline/tool_runtime/memory/__init__.py | 4 +- .../inline/tool_runtime/memory/config.py | 83 +- .../tool_runtime/memory/context_retriever.py | 52 +- .../inline/tool_runtime/memory/memory.py | 174 ++- .../providers/registry/tool_runtime.py | 2 +- .../tool_runtime/bing_search/bing_search.py | 4 +- .../tool_runtime/brave_search/brave_search.py | 4 +- .../model_context_protocol.py | 4 +- .../tavily_search/tavily_search.py | 4 +- .../wolfram_alpha/wolfram_alpha.py | 4 +- .../providers/tests/agents/conftest.py | 14 +- .../providers/tests/agents/fixtures.py | 4 +- .../providers/tests/agents/test_agents.py | 4 +- .../tests/vector_io/test_vector_io.py | 15 +- .../providers/utils/memory/vector_store.py | 20 +- llama_stack/scripts/test_rag_via_curl.py | 105 ++ llama_stack/templates/together/build.yaml | 2 +- llama_stack/templates/together/run.yaml | 5 +- 33 files changed, 1648 insertions(+), 1345 deletions(-) create mode 100644 llama_stack/apis/tools/rag_tool.py create mode 100644 llama_stack/scripts/test_rag_via_curl.py diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py index 4cea9d970..abeb16936 100644 --- a/docs/openapi_generator/pyopenapi/operations.py +++ b/docs/openapi_generator/pyopenapi/operations.py @@ -172,10 +172,16 @@ def _get_endpoint_functions( def _get_defining_class(member_fn: str, derived_cls: type) -> type: "Find the class in which a member function is first defined in a class inheritance hierarchy." + # This import must be dynamic here + from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime + # iterate in reverse member resolution order to find most specific class first for cls in reversed(inspect.getmro(derived_cls)): for name, _ in inspect.getmembers(cls, inspect.isfunction): if name == member_fn: + # HACK ALERT + if cls == RAGToolRuntime: + return ToolRuntime return cls raise ValidationError( diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 459a53888..f00d7b291 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -1108,98 +1108,6 @@ ] } }, - "/v1/memory-banks/{memory_bank_id}": { - "get": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/MemoryBank" - }, - { - "type": "null" - } - ] - } - } - } - } - }, - "tags": [ - "MemoryBanks" - ], - "parameters": [ - { - "name": "memory_bank_id", - "in": "path", - "required": true, - "schema": { - "type": "string" - } - }, - { - "name": "X-LlamaStack-Provider-Data", - "in": "header", - "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "X-LlamaStack-Client-Version", - "in": "header", - "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", - "required": false, - "schema": { - "type": "string" - } - } - ] - }, - "delete": { - "responses": { - "200": { - "description": "OK" - } - }, - "tags": [ - "MemoryBanks" - ], - "parameters": [ - { - "name": "memory_bank_id", - "in": "path", - "required": true, - "schema": { - "type": "string" - } - }, - { - "name": "X-LlamaStack-Provider-Data", - "in": "header", - "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "X-LlamaStack-Client-Version", - "in": "header", - "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", - "required": false, - "schema": { - "type": "string" - } - } - ] - } - }, "/v1/models/{model_id}": { "get": { "responses": { @@ -1848,6 +1756,98 @@ ] } }, + "/v1/vector-dbs/{vector_db_id}": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "oneOf": [ + { + "$ref": "#/components/schemas/VectorDB" + }, + { + "type": "null" + } + ] + } + } + } + } + }, + "tags": [ + "VectorDBs" + ], + "parameters": [ + { + "name": "vector_db_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Provider-Data", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Client-Version", + "in": "header", + "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", + "required": false, + "schema": { + "type": "string" + } + } + ] + }, + "delete": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "VectorDBs" + ], + "parameters": [ + { + "name": "vector_db_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Provider-Data", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Client-Version", + "in": "header", + "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", + "required": false, + "schema": { + "type": "string" + } + } + ] + } + }, "/v1/health": { "get": { "responses": { @@ -1887,7 +1887,7 @@ ] } }, - "/v1/memory/insert": { + "/v1/vector-io/insert": { "post": { "responses": { "200": { @@ -1895,7 +1895,7 @@ } }, "tags": [ - "Memory" + "VectorIO" ], "parameters": [ { @@ -1917,6 +1917,49 @@ } } ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/InsertChunksRequest" + } + } + }, + "required": true + } + } + }, + "/v1/tool-runtime/rag-tool/insert-documents": { + "post": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "ToolRuntime" + ], + "summary": "Index documents so they can be used by the RAG system", + "parameters": [ + { + "name": "X-LlamaStack-Provider-Data", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Client-Version", + "in": "header", + "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", + "required": false, + "schema": { + "type": "string" + } + } + ], "requestBody": { "content": { "application/json": { @@ -2300,105 +2343,6 @@ } } }, - "/v1/memory-banks": { - "get": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ListMemoryBanksResponse" - } - } - } - } - }, - "tags": [ - "MemoryBanks" - ], - "parameters": [ - { - "name": "X-LlamaStack-Provider-Data", - "in": "header", - "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "X-LlamaStack-Client-Version", - "in": "header", - "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", - "required": false, - "schema": { - "type": "string" - } - } - ] - }, - "post": { - "responses": { - "200": { - "description": "", - "content": { - "application/json": { - "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/VectorMemoryBank" - }, - { - "$ref": "#/components/schemas/KeyValueMemoryBank" - }, - { - "$ref": "#/components/schemas/KeywordMemoryBank" - }, - { - "$ref": "#/components/schemas/GraphMemoryBank" - } - ] - } - } - } - } - }, - "tags": [ - "MemoryBanks" - ], - "parameters": [ - { - "name": "X-LlamaStack-Provider-Data", - "in": "header", - "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "X-LlamaStack-Client-Version", - "in": "header", - "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", - "required": false, - "schema": { - "type": "string" - } - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/RegisterMemoryBankRequest" - } - } - }, - "required": true - } - } - }, "/v1/models": { "get": { "responses": { @@ -2912,6 +2856,92 @@ ] } }, + "/v1/vector-dbs": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListVectorDBsResponse" + } + } + } + } + }, + "tags": [ + "VectorDBs" + ], + "parameters": [ + { + "name": "X-LlamaStack-Provider-Data", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Client-Version", + "in": "header", + "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", + "required": false, + "schema": { + "type": "string" + } + } + ] + }, + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/VectorDB" + } + } + } + } + }, + "tags": [ + "VectorDBs" + ], + "parameters": [ + { + "name": "X-LlamaStack-Provider-Data", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Client-Version", + "in": "header", + "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RegisterVectorDbRequest" + } + } + }, + "required": true + } + } + }, "/v1/telemetry/events": { "post": { "responses": { @@ -3003,7 +3033,7 @@ } } }, - "/v1/memory/query": { + "/v1/vector-io/query": { "post": { "responses": { "200": { @@ -3011,14 +3041,14 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/QueryDocumentsResponse" + "$ref": "#/components/schemas/QueryChunksResponse" } } } } }, "tags": [ - "Memory" + "VectorIO" ], "parameters": [ { @@ -3044,7 +3074,57 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/QueryDocumentsRequest" + "$ref": "#/components/schemas/QueryChunksRequest" + } + } + }, + "required": true + } + } + }, + "/v1/tool-runtime/rag-tool/query-context": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RAGQueryResult" + } + } + } + } + }, + "tags": [ + "ToolRuntime" + ], + "summary": "Query the RAG system for context; typically invoked by the agent", + "parameters": [ + { + "name": "X-LlamaStack-Provider-Data", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Client-Version", + "in": "header", + "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/QueryContextRequest" } } }, @@ -5851,118 +5931,6 @@ "aggregated_results" ] }, - "GraphMemoryBank": { - "type": "object", - "properties": { - "identifier": { - "type": "string" - }, - "provider_resource_id": { - "type": "string" - }, - "provider_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "memory_bank", - "default": "memory_bank" - }, - "memory_bank_type": { - "type": "string", - "const": "graph", - "default": "graph" - } - }, - "additionalProperties": false, - "required": [ - "identifier", - "provider_resource_id", - "provider_id", - "type", - "memory_bank_type" - ] - }, - "KeyValueMemoryBank": { - "type": "object", - "properties": { - "identifier": { - "type": "string" - }, - "provider_resource_id": { - "type": "string" - }, - "provider_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "memory_bank", - "default": "memory_bank" - }, - "memory_bank_type": { - "type": "string", - "const": "keyvalue", - "default": "keyvalue" - } - }, - "additionalProperties": false, - "required": [ - "identifier", - "provider_resource_id", - "provider_id", - "type", - "memory_bank_type" - ] - }, - "KeywordMemoryBank": { - "type": "object", - "properties": { - "identifier": { - "type": "string" - }, - "provider_resource_id": { - "type": "string" - }, - "provider_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "memory_bank", - "default": "memory_bank" - }, - "memory_bank_type": { - "type": "string", - "const": "keyword", - "default": "keyword" - } - }, - "additionalProperties": false, - "required": [ - "identifier", - "provider_resource_id", - "provider_id", - "type", - "memory_bank_type" - ] - }, - "MemoryBank": { - "oneOf": [ - { - "$ref": "#/components/schemas/VectorMemoryBank" - }, - { - "$ref": "#/components/schemas/KeyValueMemoryBank" - }, - { - "$ref": "#/components/schemas/KeywordMemoryBank" - }, - { - "$ref": "#/components/schemas/GraphMemoryBank" - } - ] - }, "Session": { "type": "object", "properties": { @@ -5981,9 +5949,6 @@ "started_at": { "type": "string", "format": "date-time" - }, - "memory_bank": { - "$ref": "#/components/schemas/MemoryBank" } }, "additionalProperties": false, @@ -5995,53 +5960,6 @@ ], "title": "A single session of an interaction with an Agentic System." }, - "VectorMemoryBank": { - "type": "object", - "properties": { - "identifier": { - "type": "string" - }, - "provider_resource_id": { - "type": "string" - }, - "provider_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "memory_bank", - "default": "memory_bank" - }, - "memory_bank_type": { - "type": "string", - "const": "vector", - "default": "vector" - }, - "embedding_model": { - "type": "string" - }, - "chunk_size_in_tokens": { - "type": "integer" - }, - "embedding_dimension": { - "type": "integer", - "default": 384 - }, - "overlap_size_in_tokens": { - "type": "integer" - } - }, - "additionalProperties": false, - "required": [ - "identifier", - "provider_resource_id", - "provider_id", - "type", - "memory_bank_type", - "embedding_model", - "chunk_size_in_tokens" - ] - }, "AgentStepResponse": { "type": "object", "properties": { @@ -7012,6 +6930,40 @@ "data" ] }, + "VectorDB": { + "type": "object", + "properties": { + "identifier": { + "type": "string" + }, + "provider_resource_id": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "vector_db", + "default": "vector_db" + }, + "embedding_model": { + "type": "string" + }, + "embedding_dimension": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "identifier", + "provider_resource_id", + "provider_id", + "type", + "embedding_model", + "embedding_dimension" + ] + }, "HealthInfo": { "type": "object", "properties": { @@ -7024,7 +6976,64 @@ "status" ] }, - "MemoryBankDocument": { + "InsertChunksRequest": { + "type": "object", + "properties": { + "vector_db_id": { + "type": "string" + }, + "chunks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "content": { + "$ref": "#/components/schemas/InterleavedContent" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "content", + "metadata" + ] + } + }, + "ttl_seconds": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "vector_db_id", + "chunks" + ] + }, + "RAGDocument": { "type": "object", "properties": { "document_id": { @@ -7088,23 +7097,24 @@ "InsertDocumentsRequest": { "type": "object", "properties": { - "bank_id": { - "type": "string" - }, "documents": { "type": "array", "items": { - "$ref": "#/components/schemas/MemoryBankDocument" + "$ref": "#/components/schemas/RAGDocument" } }, - "ttl_seconds": { + "vector_db_id": { + "type": "string" + }, + "chunk_size_in_tokens": { "type": "integer" } }, "additionalProperties": false, "required": [ - "bank_id", - "documents" + "documents", + "vector_db_id", + "chunk_size_in_tokens" ] }, "InvokeToolRequest": { @@ -7113,7 +7123,7 @@ "tool_name": { "type": "string" }, - "args": { + "kwargs": { "type": "object", "additionalProperties": { "oneOf": [ @@ -7142,7 +7152,7 @@ "additionalProperties": false, "required": [ "tool_name", - "args" + "kwargs" ] }, "ToolInvocationResult": { @@ -7193,21 +7203,6 @@ "data" ] }, - "ListMemoryBanksResponse": { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/components/schemas/MemoryBank" - } - } - }, - "additionalProperties": false, - "required": [ - "data" - ] - }, "ListModelsResponse": { "type": "object", "properties": { @@ -7356,6 +7351,21 @@ "data" ] }, + "ListVectorDBsResponse": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/VectorDB" + } + } + }, + "additionalProperties": false, + "required": [ + "data" + ] + }, "LogSeverity": { "type": "string", "enum": [ @@ -7873,10 +7883,10 @@ "job_uuid" ] }, - "QueryDocumentsRequest": { + "QueryChunksRequest": { "type": "object", "properties": { - "bank_id": { + "vector_db_id": { "type": "string" }, "query": { @@ -7910,11 +7920,11 @@ }, "additionalProperties": false, "required": [ - "bank_id", + "vector_db_id", "query" ] }, - "QueryDocumentsResponse": { + "QueryChunksResponse": { "type": "object", "properties": { "chunks": { @@ -7925,18 +7935,36 @@ "content": { "$ref": "#/components/schemas/InterleavedContent" }, - "token_count": { - "type": "integer" - }, - "document_id": { - "type": "string" + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } } }, "additionalProperties": false, "required": [ "content", - "token_count", - "document_id" + "metadata" ] } }, @@ -7953,6 +7981,111 @@ "scores" ] }, + "DefaultRAGQueryGeneratorConfig": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "default", + "default": "default" + }, + "separator": { + "type": "string", + "default": " " + } + }, + "additionalProperties": false, + "required": [ + "type", + "separator" + ] + }, + "LLMRAGQueryGeneratorConfig": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "llm", + "default": "llm" + }, + "model": { + "type": "string" + }, + "template": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "model", + "template" + ] + }, + "RAGQueryConfig": { + "type": "object", + "properties": { + "query_generator_config": { + "$ref": "#/components/schemas/RAGQueryGeneratorConfig" + }, + "max_tokens_in_context": { + "type": "integer", + "default": 4096 + }, + "max_chunks": { + "type": "integer", + "default": 5 + } + }, + "additionalProperties": false, + "required": [ + "query_generator_config", + "max_tokens_in_context", + "max_chunks" + ] + }, + "RAGQueryGeneratorConfig": { + "oneOf": [ + { + "$ref": "#/components/schemas/DefaultRAGQueryGeneratorConfig" + }, + { + "$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig" + } + ] + }, + "QueryContextRequest": { + "type": "object", + "properties": { + "content": { + "$ref": "#/components/schemas/InterleavedContent" + }, + "query_config": { + "$ref": "#/components/schemas/RAGQueryConfig" + }, + "vector_db_ids": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false, + "required": [ + "content", + "query_config", + "vector_db_ids" + ] + }, + "RAGQueryResult": { + "type": "object", + "properties": { + "content": { + "$ref": "#/components/schemas/InterleavedContent" + } + }, + "additionalProperties": false + }, "QueryCondition": { "type": "object", "properties": { @@ -8139,108 +8272,6 @@ "scoring_functions" ] }, - "GraphMemoryBankParams": { - "type": "object", - "properties": { - "memory_bank_type": { - "type": "string", - "const": "graph", - "default": "graph" - } - }, - "additionalProperties": false, - "required": [ - "memory_bank_type" - ] - }, - "KeyValueMemoryBankParams": { - "type": "object", - "properties": { - "memory_bank_type": { - "type": "string", - "const": "keyvalue", - "default": "keyvalue" - } - }, - "additionalProperties": false, - "required": [ - "memory_bank_type" - ] - }, - "KeywordMemoryBankParams": { - "type": "object", - "properties": { - "memory_bank_type": { - "type": "string", - "const": "keyword", - "default": "keyword" - } - }, - "additionalProperties": false, - "required": [ - "memory_bank_type" - ] - }, - "VectorMemoryBankParams": { - "type": "object", - "properties": { - "memory_bank_type": { - "type": "string", - "const": "vector", - "default": "vector" - }, - "embedding_model": { - "type": "string" - }, - "chunk_size_in_tokens": { - "type": "integer" - }, - "overlap_size_in_tokens": { - "type": "integer" - } - }, - "additionalProperties": false, - "required": [ - "memory_bank_type", - "embedding_model", - "chunk_size_in_tokens" - ] - }, - "RegisterMemoryBankRequest": { - "type": "object", - "properties": { - "memory_bank_id": { - "type": "string" - }, - "params": { - "oneOf": [ - { - "$ref": "#/components/schemas/VectorMemoryBankParams" - }, - { - "$ref": "#/components/schemas/KeyValueMemoryBankParams" - }, - { - "$ref": "#/components/schemas/KeywordMemoryBankParams" - }, - { - "$ref": "#/components/schemas/GraphMemoryBankParams" - } - ] - }, - "provider_id": { - "type": "string" - }, - "provider_memory_bank_id": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "memory_bank_id", - "params" - ] - }, "RegisterModelRequest": { "type": "object", "properties": { @@ -8413,6 +8444,31 @@ "provider_id" ] }, + "RegisterVectorDbRequest": { + "type": "object", + "properties": { + "vector_db_id": { + "type": "string" + }, + "embedding_model": { + "type": "string" + }, + "embedding_dimension": { + "type": "integer" + }, + "provider_id": { + "type": "string" + }, + "provider_vector_db_id": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "vector_db_id", + "embedding_model" + ] + }, "RunEvalRequest": { "type": "object", "properties": { @@ -9128,6 +9184,10 @@ { "name": "Datasets" }, + { + "name": "DefaultRAGQueryGeneratorConfig", + "description": "" + }, { "name": "EfficiencyConfig", "description": "" @@ -9158,14 +9218,6 @@ "name": "EvaluateRowsRequest", "description": "" }, - { - "name": "GraphMemoryBank", - "description": "" - }, - { - "name": "GraphMemoryBankParams", - "description": "" - }, { "name": "GreedySamplingStrategy", "description": "" @@ -9189,6 +9241,10 @@ "name": "InferenceStep", "description": "" }, + { + "name": "InsertChunksRequest", + "description": "" + }, { "name": "InsertDocumentsRequest", "description": "" @@ -9220,26 +9276,14 @@ "name": "JsonType", "description": "" }, - { - "name": "KeyValueMemoryBank", - "description": "" - }, - { - "name": "KeyValueMemoryBankParams", - "description": "" - }, - { - "name": "KeywordMemoryBank", - "description": "" - }, - { - "name": "KeywordMemoryBankParams", - "description": "" - }, { "name": "LLMAsJudgeScoringFnParams", "description": "" }, + { + "name": "LLMRAGQueryGeneratorConfig", + "description": "" + }, { "name": "ListDatasetsResponse", "description": "" @@ -9248,10 +9292,6 @@ "name": "ListEvalTasksResponse", "description": "" }, - { - "name": "ListMemoryBanksResponse", - "description": "" - }, { "name": "ListModelsResponse", "description": "" @@ -9284,6 +9324,10 @@ "name": "ListToolsResponse", "description": "" }, + { + "name": "ListVectorDBsResponse", + "description": "" + }, { "name": "LogEventRequest", "description": "" @@ -9296,20 +9340,6 @@ "name": "LoraFinetuningConfig", "description": "" }, - { - "name": "Memory" - }, - { - "name": "MemoryBank", - "description": "" - }, - { - "name": "MemoryBankDocument", - "description": "" - }, - { - "name": "MemoryBanks" - }, { "name": "MemoryRetrievalStep", "description": "" @@ -9388,6 +9418,14 @@ "name": "QATFinetuningConfig", "description": "" }, + { + "name": "QueryChunksRequest", + "description": "" + }, + { + "name": "QueryChunksResponse", + "description": "" + }, { "name": "QueryCondition", "description": "" @@ -9397,12 +9435,8 @@ "description": "" }, { - "name": "QueryDocumentsRequest", - "description": "" - }, - { - "name": "QueryDocumentsResponse", - "description": "" + "name": "QueryContextRequest", + "description": "" }, { "name": "QuerySpanTreeResponse", @@ -9416,6 +9450,22 @@ "name": "QueryTracesResponse", "description": "" }, + { + "name": "RAGDocument", + "description": "" + }, + { + "name": "RAGQueryConfig", + "description": "" + }, + { + "name": "RAGQueryGeneratorConfig", + "description": "" + }, + { + "name": "RAGQueryResult", + "description": "" + }, { "name": "RegexParserScoringFnParams", "description": "" @@ -9428,10 +9478,6 @@ "name": "RegisterEvalTaskRequest", "description": "" }, - { - "name": "RegisterMemoryBankRequest", - "description": "" - }, { "name": "RegisterModelRequest", "description": "" @@ -9448,6 +9494,10 @@ "name": "RegisterToolGroupRequest", "description": "" }, + { + "name": "RegisterVectorDbRequest", + "description": "" + }, { "name": "ResponseFormat", "description": "" @@ -9701,12 +9751,14 @@ "description": "" }, { - "name": "VectorMemoryBank", - "description": "" + "name": "VectorDB", + "description": "" }, { - "name": "VectorMemoryBankParams", - "description": "" + "name": "VectorDBs" + }, + { + "name": "VectorIO" }, { "name": "VersionInfo", @@ -9729,8 +9781,6 @@ "EvalTasks", "Inference", "Inspect", - "Memory", - "MemoryBanks", "Models", "PostTraining (Coming Soon)", "Safety", @@ -9740,7 +9790,9 @@ "SyntheticDataGeneration (Coming Soon)", "Telemetry", "ToolGroups", - "ToolRuntime" + "ToolRuntime", + "VectorDBs", + "VectorIO" ] }, { @@ -9793,19 +9845,19 @@ "DataConfig", "Dataset", "DatasetFormat", + "DefaultRAGQueryGeneratorConfig", "EfficiencyConfig", "EmbeddingsRequest", "EmbeddingsResponse", "EvalTask", "EvaluateResponse", "EvaluateRowsRequest", - "GraphMemoryBank", - "GraphMemoryBankParams", "GreedySamplingStrategy", "HealthInfo", "ImageContentItem", "ImageDelta", "InferenceStep", + "InsertChunksRequest", "InsertDocumentsRequest", "InterleavedContent", "InterleavedContentItem", @@ -9813,14 +9865,10 @@ "Job", "JobStatus", "JsonType", - "KeyValueMemoryBank", - "KeyValueMemoryBankParams", - "KeywordMemoryBank", - "KeywordMemoryBankParams", "LLMAsJudgeScoringFnParams", + "LLMRAGQueryGeneratorConfig", "ListDatasetsResponse", "ListEvalTasksResponse", - "ListMemoryBanksResponse", "ListModelsResponse", "ListPostTrainingJobsResponse", "ListProvidersResponse", @@ -9829,11 +9877,10 @@ "ListShieldsResponse", "ListToolGroupsResponse", "ListToolsResponse", + "ListVectorDBsResponse", "LogEventRequest", "LogSeverity", "LoraFinetuningConfig", - "MemoryBank", - "MemoryBankDocument", "MemoryRetrievalStep", "Message", "MetricEvent", @@ -9852,21 +9899,26 @@ "PreferenceOptimizeRequest", "ProviderInfo", "QATFinetuningConfig", + "QueryChunksRequest", + "QueryChunksResponse", "QueryCondition", "QueryConditionOp", - "QueryDocumentsRequest", - "QueryDocumentsResponse", + "QueryContextRequest", "QuerySpanTreeResponse", "QuerySpansResponse", "QueryTracesResponse", + "RAGDocument", + "RAGQueryConfig", + "RAGQueryGeneratorConfig", + "RAGQueryResult", "RegexParserScoringFnParams", "RegisterDatasetRequest", "RegisterEvalTaskRequest", - "RegisterMemoryBankRequest", "RegisterModelRequest", "RegisterScoringFunctionRequest", "RegisterShieldRequest", "RegisterToolGroupRequest", + "RegisterVectorDbRequest", "ResponseFormat", "RouteInfo", "RunEvalRequest", @@ -9924,8 +9976,7 @@ "UnionType", "UnstructuredLogEvent", "UserMessage", - "VectorMemoryBank", - "VectorMemoryBankParams", + "VectorDB", "VersionInfo", "ViolationLevel" ] diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 9aeac6db3..e1ae07c45 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -761,6 +761,20 @@ components: - instruct - dialog type: string + DefaultRAGQueryGeneratorConfig: + additionalProperties: false + properties: + separator: + default: ' ' + type: string + type: + const: default + default: default + type: string + required: + - type + - separator + type: object EfficiencyConfig: additionalProperties: false properties: @@ -891,40 +905,6 @@ components: - scoring_functions - task_config type: object - GraphMemoryBank: - additionalProperties: false - properties: - identifier: - type: string - memory_bank_type: - const: graph - default: graph - type: string - provider_id: - type: string - provider_resource_id: - type: string - type: - const: memory_bank - default: memory_bank - type: string - required: - - identifier - - provider_resource_id - - provider_id - - type - - memory_bank_type - type: object - GraphMemoryBankParams: - additionalProperties: false - properties: - memory_bank_type: - const: graph - default: graph - type: string - required: - - memory_bank_type - type: object GreedySamplingStrategy: additionalProperties: false properties: @@ -997,20 +977,53 @@ components: - step_type - model_response type: object - InsertDocumentsRequest: + InsertChunksRequest: additionalProperties: false properties: - bank_id: - type: string - documents: + chunks: items: - $ref: '#/components/schemas/MemoryBankDocument' + additionalProperties: false + properties: + content: + $ref: '#/components/schemas/InterleavedContent' + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + required: + - content + - metadata + type: object type: array ttl_seconds: type: integer + vector_db_id: + type: string + required: + - vector_db_id + - chunks + type: object + InsertDocumentsRequest: + additionalProperties: false + properties: + chunk_size_in_tokens: + type: integer + documents: + items: + $ref: '#/components/schemas/RAGDocument' + type: array + vector_db_id: + type: string required: - - bank_id - documents + - vector_db_id + - chunk_size_in_tokens type: object InterleavedContent: oneOf: @@ -1026,7 +1039,7 @@ components: InvokeToolRequest: additionalProperties: false properties: - args: + kwargs: additionalProperties: oneOf: - type: 'null' @@ -1040,7 +1053,7 @@ components: type: string required: - tool_name - - args + - kwargs type: object Job: additionalProperties: false @@ -1067,74 +1080,6 @@ components: required: - type type: object - KeyValueMemoryBank: - additionalProperties: false - properties: - identifier: - type: string - memory_bank_type: - const: keyvalue - default: keyvalue - type: string - provider_id: - type: string - provider_resource_id: - type: string - type: - const: memory_bank - default: memory_bank - type: string - required: - - identifier - - provider_resource_id - - provider_id - - type - - memory_bank_type - type: object - KeyValueMemoryBankParams: - additionalProperties: false - properties: - memory_bank_type: - const: keyvalue - default: keyvalue - type: string - required: - - memory_bank_type - type: object - KeywordMemoryBank: - additionalProperties: false - properties: - identifier: - type: string - memory_bank_type: - const: keyword - default: keyword - type: string - provider_id: - type: string - provider_resource_id: - type: string - type: - const: memory_bank - default: memory_bank - type: string - required: - - identifier - - provider_resource_id - - provider_id - - type - - memory_bank_type - type: object - KeywordMemoryBankParams: - additionalProperties: false - properties: - memory_bank_type: - const: keyword - default: keyword - type: string - required: - - memory_bank_type - type: object LLMAsJudgeScoringFnParams: additionalProperties: false properties: @@ -1158,6 +1103,22 @@ components: - type - judge_model type: object + LLMRAGQueryGeneratorConfig: + additionalProperties: false + properties: + model: + type: string + template: + type: string + type: + const: llm + default: llm + type: string + required: + - type + - model + - template + type: object ListDatasetsResponse: additionalProperties: false properties: @@ -1178,16 +1139,6 @@ components: required: - data type: object - ListMemoryBanksResponse: - additionalProperties: false - properties: - data: - items: - $ref: '#/components/schemas/MemoryBank' - type: array - required: - - data - type: object ListModelsResponse: additionalProperties: false properties: @@ -1274,6 +1225,16 @@ components: required: - data type: object + ListVectorDBsResponse: + additionalProperties: false + properties: + data: + items: + $ref: '#/components/schemas/VectorDB' + type: array + required: + - data + type: object LogEventRequest: additionalProperties: false properties: @@ -1330,42 +1291,6 @@ components: - rank - alpha type: object - MemoryBank: - oneOf: - - $ref: '#/components/schemas/VectorMemoryBank' - - $ref: '#/components/schemas/KeyValueMemoryBank' - - $ref: '#/components/schemas/KeywordMemoryBank' - - $ref: '#/components/schemas/GraphMemoryBank' - MemoryBankDocument: - additionalProperties: false - properties: - content: - oneOf: - - type: string - - $ref: '#/components/schemas/InterleavedContentItem' - - items: - $ref: '#/components/schemas/InterleavedContentItem' - type: array - - $ref: '#/components/schemas/URL' - document_id: - type: string - metadata: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - mime_type: - type: string - required: - - document_id - - content - - metadata - type: object MemoryRetrievalStep: additionalProperties: false properties: @@ -1705,6 +1630,59 @@ components: - quantizer_name - group_size type: object + QueryChunksRequest: + additionalProperties: false + properties: + params: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + query: + $ref: '#/components/schemas/InterleavedContent' + vector_db_id: + type: string + required: + - vector_db_id + - query + type: object + QueryChunksResponse: + additionalProperties: false + properties: + chunks: + items: + additionalProperties: false + properties: + content: + $ref: '#/components/schemas/InterleavedContent' + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + required: + - content + - metadata + type: object + type: array + scores: + items: + type: number + type: array + required: + - chunks + - scores + type: object QueryCondition: additionalProperties: false properties: @@ -1732,53 +1710,21 @@ components: - gt - lt type: string - QueryDocumentsRequest: + QueryContextRequest: additionalProperties: false properties: - bank_id: - type: string - params: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - query: + content: $ref: '#/components/schemas/InterleavedContent' - required: - - bank_id - - query - type: object - QueryDocumentsResponse: - additionalProperties: false - properties: - chunks: + query_config: + $ref: '#/components/schemas/RAGQueryConfig' + vector_db_ids: items: - additionalProperties: false - properties: - content: - $ref: '#/components/schemas/InterleavedContent' - document_id: - type: string - token_count: - type: integer - required: - - content - - token_count - - document_id - type: object - type: array - scores: - items: - type: number + type: string type: array required: - - chunks - - scores + - content + - query_config + - vector_db_ids type: object QuerySpanTreeResponse: additionalProperties: false @@ -1810,6 +1756,62 @@ components: required: - data type: object + RAGDocument: + additionalProperties: false + properties: + content: + oneOf: + - type: string + - $ref: '#/components/schemas/InterleavedContentItem' + - items: + $ref: '#/components/schemas/InterleavedContentItem' + type: array + - $ref: '#/components/schemas/URL' + document_id: + type: string + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + mime_type: + type: string + required: + - document_id + - content + - metadata + type: object + RAGQueryConfig: + additionalProperties: false + properties: + max_chunks: + default: 5 + type: integer + max_tokens_in_context: + default: 4096 + type: integer + query_generator_config: + $ref: '#/components/schemas/RAGQueryGeneratorConfig' + required: + - query_generator_config + - max_tokens_in_context + - max_chunks + type: object + RAGQueryGeneratorConfig: + oneOf: + - $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig' + - $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig' + RAGQueryResult: + additionalProperties: false + properties: + content: + $ref: '#/components/schemas/InterleavedContent' + type: object RegexParserScoringFnParams: additionalProperties: false properties: @@ -1888,25 +1890,6 @@ components: - dataset_id - scoring_functions type: object - RegisterMemoryBankRequest: - additionalProperties: false - properties: - memory_bank_id: - type: string - params: - oneOf: - - $ref: '#/components/schemas/VectorMemoryBankParams' - - $ref: '#/components/schemas/KeyValueMemoryBankParams' - - $ref: '#/components/schemas/KeywordMemoryBankParams' - - $ref: '#/components/schemas/GraphMemoryBankParams' - provider_id: - type: string - provider_memory_bank_id: - type: string - required: - - memory_bank_id - - params - type: object RegisterModelRequest: additionalProperties: false properties: @@ -1999,6 +1982,23 @@ components: - toolgroup_id - provider_id type: object + RegisterVectorDbRequest: + additionalProperties: false + properties: + embedding_dimension: + type: integer + embedding_model: + type: string + provider_id: + type: string + provider_vector_db_id: + type: string + vector_db_id: + type: string + required: + - vector_db_id + - embedding_model + type: object ResponseFormat: oneOf: - additionalProperties: false @@ -2298,8 +2298,6 @@ components: Session: additionalProperties: false properties: - memory_bank: - $ref: '#/components/schemas/MemoryBank' session_id: type: string session_name: @@ -3202,58 +3200,30 @@ components: - role - content type: object - VectorMemoryBank: + VectorDB: additionalProperties: false properties: - chunk_size_in_tokens: - type: integer embedding_dimension: - default: 384 type: integer embedding_model: type: string identifier: type: string - memory_bank_type: - const: vector - default: vector - type: string - overlap_size_in_tokens: - type: integer provider_id: type: string provider_resource_id: type: string type: - const: memory_bank - default: memory_bank + const: vector_db + default: vector_db type: string required: - identifier - provider_resource_id - provider_id - type - - memory_bank_type - embedding_model - - chunk_size_in_tokens - type: object - VectorMemoryBankParams: - additionalProperties: false - properties: - chunk_size_in_tokens: - type: integer - embedding_model: - type: string - memory_bank_type: - const: vector - default: vector - type: string - overlap_size_in_tokens: - type: integer - required: - - memory_bank_type - - embedding_model - - chunk_size_in_tokens + - embedding_dimension type: object VersionInfo: additionalProperties: false @@ -4272,186 +4242,6 @@ paths: description: OK tags: - Inspect - /v1/memory-banks: - get: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-Provider-Data - required: false - schema: - type: string - - description: Version of the client making the request. This is used to ensure - that the client and server are compatible. - in: header - name: X-LlamaStack-Client-Version - required: false - schema: - type: string - responses: - '200': - content: - application/json: - schema: - $ref: '#/components/schemas/ListMemoryBanksResponse' - description: OK - tags: - - MemoryBanks - post: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-Provider-Data - required: false - schema: - type: string - - description: Version of the client making the request. This is used to ensure - that the client and server are compatible. - in: header - name: X-LlamaStack-Client-Version - required: false - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterMemoryBankRequest' - required: true - responses: - '200': - content: - application/json: - schema: - oneOf: - - $ref: '#/components/schemas/VectorMemoryBank' - - $ref: '#/components/schemas/KeyValueMemoryBank' - - $ref: '#/components/schemas/KeywordMemoryBank' - - $ref: '#/components/schemas/GraphMemoryBank' - description: '' - tags: - - MemoryBanks - /v1/memory-banks/{memory_bank_id}: - delete: - parameters: - - in: path - name: memory_bank_id - required: true - schema: - type: string - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-Provider-Data - required: false - schema: - type: string - - description: Version of the client making the request. This is used to ensure - that the client and server are compatible. - in: header - name: X-LlamaStack-Client-Version - required: false - schema: - type: string - responses: - '200': - description: OK - tags: - - MemoryBanks - get: - parameters: - - in: path - name: memory_bank_id - required: true - schema: - type: string - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-Provider-Data - required: false - schema: - type: string - - description: Version of the client making the request. This is used to ensure - that the client and server are compatible. - in: header - name: X-LlamaStack-Client-Version - required: false - schema: - type: string - responses: - '200': - content: - application/json: - schema: - oneOf: - - $ref: '#/components/schemas/MemoryBank' - - type: 'null' - description: OK - tags: - - MemoryBanks - /v1/memory/insert: - post: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-Provider-Data - required: false - schema: - type: string - - description: Version of the client making the request. This is used to ensure - that the client and server are compatible. - in: header - name: X-LlamaStack-Client-Version - required: false - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/InsertDocumentsRequest' - required: true - responses: - '200': - description: OK - tags: - - Memory - /v1/memory/query: - post: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-Provider-Data - required: false - schema: - type: string - - description: Version of the client making the request. This is used to ensure - that the client and server are compatible. - in: header - name: X-LlamaStack-Client-Version - required: false - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/QueryDocumentsRequest' - required: true - responses: - '200': - content: - application/json: - schema: - $ref: '#/components/schemas/QueryDocumentsResponse' - description: OK - tags: - - Memory /v1/models: get: parameters: @@ -5386,6 +5176,68 @@ paths: description: OK tags: - ToolRuntime + /v1/tool-runtime/rag-tool/insert-documents: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/InsertDocumentsRequest' + required: true + responses: + '200': + description: OK + summary: Index documents so they can be used by the RAG system + tags: + - ToolRuntime + /v1/tool-runtime/rag-tool/query-context: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/QueryContextRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/RAGQueryResult' + description: OK + summary: Query the RAG system for context; typically invoked by the agent + tags: + - ToolRuntime /v1/toolgroups: get: parameters: @@ -5562,6 +5414,182 @@ paths: description: OK tags: - ToolGroups + /v1/vector-dbs: + get: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/ListVectorDBsResponse' + description: OK + tags: + - VectorDBs + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterVectorDbRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/VectorDB' + description: OK + tags: + - VectorDBs + /v1/vector-dbs/{vector_db_id}: + delete: + parameters: + - in: path + name: vector_db_id + required: true + schema: + type: string + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + responses: + '200': + description: OK + tags: + - VectorDBs + get: + parameters: + - in: path + name: vector_db_id + required: true + schema: + type: string + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + responses: + '200': + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/VectorDB' + - type: 'null' + description: OK + tags: + - VectorDBs + /v1/vector-io/insert: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/InsertChunksRequest' + required: true + responses: + '200': + description: OK + tags: + - VectorIO + /v1/vector-io/query: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/QueryChunksRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/QueryChunksResponse' + description: OK + tags: + - VectorIO /v1/version: get: parameters: @@ -5748,6 +5776,9 @@ tags: name: DatasetFormat - name: DatasetIO - name: Datasets +- description: + name: DefaultRAGQueryGeneratorConfig - description: name: EfficiencyConfig @@ -5767,12 +5798,6 @@ tags: - description: name: EvaluateRowsRequest -- description: - name: GraphMemoryBank -- description: - name: GraphMemoryBankParams - description: name: GreedySamplingStrategy @@ -5786,6 +5811,9 @@ tags: - name: Inference - description: name: InferenceStep +- description: + name: InsertChunksRequest - description: name: InsertDocumentsRequest @@ -5805,30 +5833,18 @@ tags: name: JobStatus - description: name: JsonType -- description: - name: KeyValueMemoryBank -- description: - name: KeyValueMemoryBankParams -- description: - name: KeywordMemoryBank -- description: - name: KeywordMemoryBankParams - description: name: LLMAsJudgeScoringFnParams +- description: + name: LLMRAGQueryGeneratorConfig - description: name: ListDatasetsResponse - description: name: ListEvalTasksResponse -- description: - name: ListMemoryBanksResponse - description: name: ListModelsResponse @@ -5853,6 +5869,9 @@ tags: - description: name: ListToolsResponse +- description: + name: ListVectorDBsResponse - description: name: LogEventRequest @@ -5861,13 +5880,6 @@ tags: - description: name: LoraFinetuningConfig -- name: Memory -- description: - name: MemoryBank -- description: - name: MemoryBankDocument -- name: MemoryBanks - description: name: MemoryRetrievalStep @@ -5920,17 +5932,20 @@ tags: - description: name: QATFinetuningConfig +- description: + name: QueryChunksRequest +- description: + name: QueryChunksResponse - description: name: QueryCondition - description: name: QueryConditionOp -- description: - name: QueryDocumentsRequest -- description: - name: QueryDocumentsResponse + name: QueryContextRequest - description: name: QuerySpanTreeResponse @@ -5940,6 +5955,15 @@ tags: - description: name: QueryTracesResponse +- description: + name: RAGDocument +- description: + name: RAGQueryConfig +- description: + name: RAGQueryGeneratorConfig +- description: + name: RAGQueryResult - description: name: RegexParserScoringFnParams @@ -5949,9 +5973,6 @@ tags: - description: name: RegisterEvalTaskRequest -- description: - name: RegisterMemoryBankRequest - description: name: RegisterModelRequest @@ -5964,6 +5985,9 @@ tags: - description: name: RegisterToolGroupRequest +- description: + name: RegisterVectorDbRequest - description: name: ResponseFormat - description: @@ -6128,12 +6152,10 @@ tags: name: UnstructuredLogEvent - description: name: UserMessage -- description: - name: VectorMemoryBank -- description: - name: VectorMemoryBankParams +- description: + name: VectorDB +- name: VectorDBs +- name: VectorIO - description: name: VersionInfo - description: @@ -6149,8 +6171,6 @@ x-tagGroups: - EvalTasks - Inference - Inspect - - Memory - - MemoryBanks - Models - PostTraining (Coming Soon) - Safety @@ -6161,6 +6181,8 @@ x-tagGroups: - Telemetry - ToolGroups - ToolRuntime + - VectorDBs + - VectorIO - name: Types tags: - AgentCandidate @@ -6210,19 +6232,19 @@ x-tagGroups: - DataConfig - Dataset - DatasetFormat + - DefaultRAGQueryGeneratorConfig - EfficiencyConfig - EmbeddingsRequest - EmbeddingsResponse - EvalTask - EvaluateResponse - EvaluateRowsRequest - - GraphMemoryBank - - GraphMemoryBankParams - GreedySamplingStrategy - HealthInfo - ImageContentItem - ImageDelta - InferenceStep + - InsertChunksRequest - InsertDocumentsRequest - InterleavedContent - InterleavedContentItem @@ -6230,14 +6252,10 @@ x-tagGroups: - Job - JobStatus - JsonType - - KeyValueMemoryBank - - KeyValueMemoryBankParams - - KeywordMemoryBank - - KeywordMemoryBankParams - LLMAsJudgeScoringFnParams + - LLMRAGQueryGeneratorConfig - ListDatasetsResponse - ListEvalTasksResponse - - ListMemoryBanksResponse - ListModelsResponse - ListPostTrainingJobsResponse - ListProvidersResponse @@ -6246,11 +6264,10 @@ x-tagGroups: - ListShieldsResponse - ListToolGroupsResponse - ListToolsResponse + - ListVectorDBsResponse - LogEventRequest - LogSeverity - LoraFinetuningConfig - - MemoryBank - - MemoryBankDocument - MemoryRetrievalStep - Message - MetricEvent @@ -6269,21 +6286,26 @@ x-tagGroups: - PreferenceOptimizeRequest - ProviderInfo - QATFinetuningConfig + - QueryChunksRequest + - QueryChunksResponse - QueryCondition - QueryConditionOp - - QueryDocumentsRequest - - QueryDocumentsResponse + - QueryContextRequest - QuerySpanTreeResponse - QuerySpansResponse - QueryTracesResponse + - RAGDocument + - RAGQueryConfig + - RAGQueryGeneratorConfig + - RAGQueryResult - RegexParserScoringFnParams - RegisterDatasetRequest - RegisterEvalTaskRequest - - RegisterMemoryBankRequest - RegisterModelRequest - RegisterScoringFunctionRequest - RegisterShieldRequest - RegisterToolGroupRequest + - RegisterVectorDbRequest - ResponseFormat - RouteInfo - RunEvalRequest @@ -6341,7 +6363,6 @@ x-tagGroups: - UnionType - UnstructuredLogEvent - UserMessage - - VectorMemoryBank - - VectorMemoryBankParams + - VectorDB - VersionInfo - ViolationLevel diff --git a/llama_stack/apis/tools/__init__.py b/llama_stack/apis/tools/__init__.py index f747fcdc2..8cd798ebf 100644 --- a/llama_stack/apis/tools/__init__.py +++ b/llama_stack/apis/tools/__init__.py @@ -5,3 +5,4 @@ # the root directory of this source tree. from .tools import * # noqa: F401 F403 +from .rag_tool import * # noqa: F401 F403 diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py new file mode 100644 index 000000000..0247bb384 --- /dev/null +++ b/llama_stack/apis/tools/rag_tool.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + +from llama_models.schema_utils import json_schema_type, register_schema, webmethod +from pydantic import BaseModel, Field +from typing_extensions import Annotated, Protocol, runtime_checkable + +from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol + + +@json_schema_type +class RAGDocument(BaseModel): + document_id: str + content: InterleavedContent | URL + mime_type: str | None = None + metadata: Dict[str, Any] = Field(default_factory=dict) + + +@json_schema_type +class RAGQueryResult(BaseModel): + content: Optional[InterleavedContent] = None + + +@json_schema_type +class RAGQueryGenerator(Enum): + default = "default" + llm = "llm" + custom = "custom" + + +@json_schema_type +class DefaultRAGQueryGeneratorConfig(BaseModel): + type: Literal["default"] = "default" + separator: str = " " + + +@json_schema_type +class LLMRAGQueryGeneratorConfig(BaseModel): + type: Literal["llm"] = "llm" + model: str + template: str + + +RAGQueryGeneratorConfig = register_schema( + Annotated[ + Union[ + DefaultRAGQueryGeneratorConfig, + LLMRAGQueryGeneratorConfig, + ], + Field(discriminator="type"), + ], + name="RAGQueryGeneratorConfig", +) + + +@json_schema_type +class RAGQueryConfig(BaseModel): + # This config defines how a query is generated using the messages + # for memory bank retrieval. + query_generator_config: RAGQueryGeneratorConfig = Field( + default=DefaultRAGQueryGeneratorConfig() + ) + max_tokens_in_context: int = 4096 + max_chunks: int = 5 + + +@runtime_checkable +@trace_protocol +class RAGToolRuntime(Protocol): + @webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST") + async def insert_documents( + self, + documents: List[RAGDocument], + vector_db_id: str, + chunk_size_in_tokens: int = 512, + ) -> None: + """Index documents so they can be used by the RAG system""" + ... + + @webmethod(route="/tool-runtime/rag-tool/query-context", method="POST") + async def query_context( + self, + content: InterleavedContent, + query_config: RAGQueryConfig, + vector_db_ids: List[str], + ) -> RAGQueryResult: + """Query the RAG system for context; typically invoked by the agent""" + ... diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index fb990cc41..1af019bd4 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -15,6 +15,8 @@ from llama_stack.apis.common.content_types import InterleavedContent, URL from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from .rag_tool import RAGToolRuntime + @json_schema_type class ToolParameter(BaseModel): @@ -130,11 +132,17 @@ class ToolGroups(Protocol): ... +class SpecialToolGroup(Enum): + rag_tool = "rag_tool" + + @runtime_checkable @trace_protocol class ToolRuntime(Protocol): tool_store: ToolStore + rag_tool: RAGToolRuntime + # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed. @webmethod(route="/tool-runtime/list-tools", method="GET") async def list_runtime_tools( @@ -143,7 +151,7 @@ class ToolRuntime(Protocol): @webmethod(route="/tool-runtime/invoke", method="POST") async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: """Run a tool with the given arguments""" ... diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index bd5a9ae98..dd6d4be6f 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -333,6 +333,8 @@ async def instantiate_provider( impl.__provider_spec__ = provider_spec impl.__provider_config__ = config + # TODO: check compliance for special tool groups + # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol check_protocol_compliance(impl, protocols[provider_spec.api]) if ( not isinstance(provider_spec, AutoRoutedProviderSpec) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 979c68b72..3ae9833dc 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -36,7 +36,14 @@ from llama_stack.apis.scoring import ( ScoringFnParams, ) from llama_stack.apis.shields import Shield -from llama_stack.apis.tools import ToolDef, ToolRuntime +from llama_stack.apis.tools import ( + RAGDocument, + RAGQueryConfig, + RAGQueryResult, + RAGToolRuntime, + ToolDef, + ToolRuntime, +) from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import RoutingTable @@ -400,22 +407,55 @@ class EvalRouter(Eval): class ToolRuntimeRouter(ToolRuntime): + class RagToolImpl(RAGToolRuntime): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + self.routing_table = routing_table + + async def query_context( + self, + content: InterleavedContent, + query_config: RAGQueryConfig, + vector_db_ids: List[str], + ) -> RAGQueryResult: + return await self.routing_table.get_provider_impl( + "rag_tool.query_context" + ).query_context(content, query_config, vector_db_ids) + + async def insert_documents( + self, + documents: List[RAGDocument], + vector_db_id: str, + chunk_size_in_tokens: int = 512, + ) -> None: + return await self.routing_table.get_provider_impl( + "rag_tool.insert_documents" + ).insert_documents(documents, vector_db_id, chunk_size_in_tokens) + def __init__( self, routing_table: RoutingTable, ) -> None: self.routing_table = routing_table + # HACK ALERT this should be in sync with "get_all_api_endpoints()" + # TODO: make sure rag_tool vs builtin::memory is correct everywhere + self.rag_tool = self.RagToolImpl(routing_table) + setattr(self, "rag_tool.query_context", self.rag_tool.query_context) + setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents) + async def initialize(self) -> None: pass async def shutdown(self) -> None: pass - async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any: + async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any: return await self.routing_table.get_provider_impl(tool_name).invoke_tool( tool_name=tool_name, - args=args, + kwargs=kwargs, ) async def list_runtime_tools( diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py index af429e020..180479e40 100644 --- a/llama_stack/distribution/server/endpoints.py +++ b/llama_stack/distribution/server/endpoints.py @@ -9,6 +9,8 @@ from typing import Dict, List from pydantic import BaseModel +from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup + from llama_stack.apis.version import LLAMA_STACK_API_VERSION from llama_stack.distribution.resolver import api_protocol_map @@ -22,21 +24,39 @@ class ApiEndpoint(BaseModel): name: str +def toolgroup_protocol_map(): + return { + SpecialToolGroup.rag_tool: RAGToolRuntime, + } + + def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]: apis = {} protocols = api_protocol_map() + toolgroup_protocols = toolgroup_protocol_map() for api, protocol in protocols.items(): endpoints = [] protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) + # HACK ALERT + if api == Api.tool_runtime: + for tool_group in SpecialToolGroup: + sub_protocol = toolgroup_protocols[tool_group] + sub_protocol_methods = inspect.getmembers( + sub_protocol, predicate=inspect.isfunction + ) + for name, method in sub_protocol_methods: + if not hasattr(method, "__webmethod__"): + continue + protocol_methods.append((f"{tool_group.value}.{name}", method)) + for name, method in protocol_methods: if not hasattr(method, "__webmethod__"): continue webmethod = method.__webmethod__ route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}" - if webmethod.method == "GET": method = "get" elif webmethod.method == "DELETE": diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 180ec0ecc..f0c34dba4 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -29,7 +29,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration from llama_stack.apis.telemetry import Telemetry -from llama_stack.apis.tools import ToolGroups, ToolRuntime +from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime from llama_stack.apis.vector_dbs import VectorDBs from llama_stack.apis.vector_io import VectorIO from llama_stack.distribution.datatypes import StackRunConfig @@ -62,6 +62,7 @@ class LlamaStack( Inspect, ToolGroups, ToolRuntime, + RAGToolRuntime, ): pass diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 010d137ec..5c0b8b5db 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -35,7 +35,7 @@ class DistributionRegistry(Protocol): REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v5" +KEY_VERSION = "v6" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py index 50f61fb42..de34b8d2c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/__init__.py +++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py @@ -19,9 +19,8 @@ async def get_provider_impl( impl = MetaReferenceAgentsImpl( config, deps[Api.inference], - deps[Api.memory], + deps[Api.vector_io], deps[Api.safety], - deps[Api.memory_banks], deps[Api.tool_runtime], deps[Api.tool_groups], ) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 2ebc7ded1..5b5175cee 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -59,13 +59,18 @@ from llama_stack.apis.inference import ( ToolResponseMessage, UserMessage, ) -from llama_stack.apis.memory import Memory, MemoryBankDocument -from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams from llama_stack.apis.safety import Safety -from llama_stack.apis.tools import ToolGroups, ToolRuntime +from llama_stack.apis.tools import ( + DefaultRAGQueryGeneratorConfig, + RAGDocument, + RAGQueryConfig, + ToolGroups, + ToolRuntime, +) +from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.utils.kvstore import KVStore +from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.telemetry import tracing - from .persistence import AgentPersistence from .safety import SafetyException, ShieldRunnerMixin @@ -79,7 +84,7 @@ def make_random_string(length: int = 8): TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") -MEMORY_QUERY_TOOL = "query_memory" +MEMORY_QUERY_TOOL = "rag_tool.query_context" WEB_SEARCH_TOOL = "web_search" MEMORY_GROUP = "builtin::memory" @@ -91,20 +96,18 @@ class ChatAgent(ShieldRunnerMixin): agent_config: AgentConfig, tempdir: str, inference_api: Inference, - memory_api: Memory, - memory_banks_api: MemoryBanks, safety_api: Safety, tool_runtime_api: ToolRuntime, tool_groups_api: ToolGroups, + vector_io_api: VectorIO, persistence_store: KVStore, ): self.agent_id = agent_id self.agent_config = agent_config self.tempdir = tempdir self.inference_api = inference_api - self.memory_api = memory_api - self.memory_banks_api = memory_banks_api self.safety_api = safety_api + self.vector_io_api = vector_io_api self.storage = AgentPersistence(agent_id, persistence_store) self.tool_runtime_api = tool_runtime_api self.tool_groups_api = tool_groups_api @@ -370,24 +373,30 @@ class ChatAgent(ShieldRunnerMixin): documents: Optional[List[Document]] = None, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, ) -> AsyncGenerator: + # TODO: simplify all of this code, it can be simpler toolgroup_args = {} + toolgroups = set() for toolgroup in self.agent_config.toolgroups: if isinstance(toolgroup, AgentToolGroupWithArgs): + toolgroups.add(toolgroup.name) toolgroup_args[toolgroup.name] = toolgroup.args + else: + toolgroups.add(toolgroup) if toolgroups_for_turn: for toolgroup in toolgroups_for_turn: if isinstance(toolgroup, AgentToolGroupWithArgs): + toolgroups.add(toolgroup.name) toolgroup_args[toolgroup.name] = toolgroup.args + else: + toolgroups.add(toolgroup) tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn) if documents: await self.handle_documents( session_id, documents, input_messages, tool_defs ) - if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0: - memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None) - if memory_tool_group is None: - raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}") + + if MEMORY_GROUP in toolgroups and len(input_messages) > 0: with tracing.span(MEMORY_QUERY_TOOL) as span: step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -398,17 +407,15 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) - query_args = { - "messages": [msg.content for msg in input_messages], - **toolgroup_args.get(memory_tool_group, {}), - } + args = toolgroup_args.get(MEMORY_GROUP, {}) + vector_db_ids = args.get("vector_db_ids", []) session_info = await self.storage.get_session_info(session_id) + # if the session has a memory bank id, let the memory tool use it if session_info.memory_bank_id: - if "memory_bank_ids" not in query_args: - query_args["memory_bank_ids"] = [] - query_args["memory_bank_ids"].append(session_info.memory_bank_id) + vector_db_ids.append(session_info.memory_bank_id) + yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -425,10 +432,18 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) - result = await self.tool_runtime_api.invoke_tool( - tool_name=MEMORY_QUERY_TOOL, - args=query_args, + result = await self.tool_runtime_api.rag_tool.query_context( + content=concat_interleaved_content( + [msg.content for msg in input_messages] + ), + query_config=RAGQueryConfig( + query_generator_config=DefaultRAGQueryGeneratorConfig(), + max_tokens_in_context=4096, + max_chunks=5, + ), + vector_db_ids=vector_db_ids, ) + retrieved_context = result.content yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -449,7 +464,7 @@ class ChatAgent(ShieldRunnerMixin): ToolResponse( call_id="", tool_name=MEMORY_QUERY_TOOL, - content=result.content, + content=retrieved_context or [], ) ], ), @@ -459,13 +474,11 @@ class ChatAgent(ShieldRunnerMixin): span.set_attribute( "input", [m.model_dump_json() for m in input_messages] ) - span.set_attribute("output", result.content) - span.set_attribute("error_code", result.error_code) - span.set_attribute("error_message", result.error_message) + span.set_attribute("output", retrieved_context) span.set_attribute("tool_name", MEMORY_QUERY_TOOL) - if result.error_code == 0: + if retrieved_context: last_message = input_messages[-1] - last_message.context = result.content + last_message.context = retrieved_context output_attachments = [] @@ -842,12 +855,13 @@ class ChatAgent(ShieldRunnerMixin): if session_info.memory_bank_id is None: bank_id = f"memory_bank_{session_id}" - await self.memory_banks_api.register_memory_bank( - memory_bank_id=bank_id, - params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - ), + + # TODO: the semantic for registration is definitely not "creation" + # so we need to fix it if we expect the agent to create a new vector db + # for each session + await self.vector_io_api.register_vector_db( + vector_db_id=bank_id, + embedding_model="all-MiniLM-L6-v2", ) await self.storage.add_memory_bank_to_session(session_id, bank_id) else: @@ -858,9 +872,9 @@ class ChatAgent(ShieldRunnerMixin): async def add_to_session_memory_bank( self, session_id: str, data: List[Document] ) -> None: - bank_id = await self._ensure_memory_bank(session_id) + vector_db_id = await self._ensure_memory_bank(session_id) documents = [ - MemoryBankDocument( + RAGDocument( document_id=str(uuid.uuid4()), content=a.content, mime_type=a.mime_type, @@ -868,9 +882,10 @@ class ChatAgent(ShieldRunnerMixin): ) for a in data ] - await self.memory_api.insert_documents( - bank_id=bank_id, + await self.tool_runtime_api.rag_tool.insert_documents( documents=documents, + vector_db_id=vector_db_id, + chunk_size_in_tokens=512, ) @@ -955,7 +970,7 @@ async def execute_tool_call_maybe( result = await tool_runtime_api.invoke_tool( tool_name=name, - args=dict( + kwargs=dict( session_id=session_id, **tool_call_args, ), diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index d22ef82ab..b1844f4d0 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -26,10 +26,9 @@ from llama_stack.apis.agents import ( Turn, ) from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage -from llama_stack.apis.memory import Memory -from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime +from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from .agent_instance import ChatAgent @@ -44,17 +43,15 @@ class MetaReferenceAgentsImpl(Agents): self, config: MetaReferenceAgentsImplConfig, inference_api: Inference, - memory_api: Memory, + vector_io_api: VectorIO, safety_api: Safety, - memory_banks_api: MemoryBanks, tool_runtime_api: ToolRuntime, tool_groups_api: ToolGroups, ): self.config = config self.inference_api = inference_api - self.memory_api = memory_api + self.vector_io_api = vector_io_api self.safety_api = safety_api - self.memory_banks_api = memory_banks_api self.tool_runtime_api = tool_runtime_api self.tool_groups_api = tool_groups_api @@ -114,8 +111,7 @@ class MetaReferenceAgentsImpl(Agents): tempdir=self.tempdir, inference_api=self.inference_api, safety_api=self.safety_api, - memory_api=self.memory_api, - memory_banks_api=self.memory_banks_api, + vector_io_api=self.vector_io_api, tool_runtime_api=self.tool_runtime_api, tool_groups_api=self.tool_groups_api, persistence_store=( diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index 361c91a92..04434768d 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -60,9 +60,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): ] async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: - script = args["code"] + script = kwargs["code"] req = CodeExecutionRequest(scripts=[script]) res = self.code_executor.execute(req) pieces = [res["process_status"]] diff --git a/llama_stack/providers/inline/tool_runtime/memory/__init__.py b/llama_stack/providers/inline/tool_runtime/memory/__init__.py index 928afa484..42a0a6b01 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/__init__.py +++ b/llama_stack/providers/inline/tool_runtime/memory/__init__.py @@ -13,8 +13,6 @@ from .memory import MemoryToolRuntimeImpl async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]): - impl = MemoryToolRuntimeImpl( - config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference] - ) + impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/tool_runtime/memory/config.py b/llama_stack/providers/inline/tool_runtime/memory/config.py index 6ff242c6b..4a20c986c 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/config.py +++ b/llama_stack/providers/inline/tool_runtime/memory/config.py @@ -4,87 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum -from typing import Annotated, List, Literal, Union - -from pydantic import BaseModel, Field - - -class _MemoryBankConfigCommon(BaseModel): - bank_id: str - - -class VectorMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["vector"] = "vector" - - -class KeyValueMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["keyvalue"] = "keyvalue" - keys: List[str] # what keys to focus on - - -class KeywordMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["keyword"] = "keyword" - - -class GraphMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["graph"] = "graph" - entities: List[str] # what entities to focus on - - -MemoryBankConfig = Annotated[ - Union[ - VectorMemoryBankConfig, - KeyValueMemoryBankConfig, - KeywordMemoryBankConfig, - GraphMemoryBankConfig, - ], - Field(discriminator="type"), -] - - -class MemoryQueryGenerator(Enum): - default = "default" - llm = "llm" - custom = "custom" - - -class DefaultMemoryQueryGeneratorConfig(BaseModel): - type: Literal[MemoryQueryGenerator.default.value] = ( - MemoryQueryGenerator.default.value - ) - sep: str = " " - - -class LLMMemoryQueryGeneratorConfig(BaseModel): - type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value - model: str - template: str - - -class CustomMemoryQueryGeneratorConfig(BaseModel): - type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value - - -MemoryQueryGeneratorConfig = Annotated[ - Union[ - DefaultMemoryQueryGeneratorConfig, - LLMMemoryQueryGeneratorConfig, - CustomMemoryQueryGeneratorConfig, - ], - Field(discriminator="type"), -] - - -class MemoryToolConfig(BaseModel): - memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list) +from pydantic import BaseModel class MemoryToolRuntimeConfig(BaseModel): - # This config defines how a query is generated using the messages - # for memory bank retrieval. - query_generator_config: MemoryQueryGeneratorConfig = Field( - default=DefaultMemoryQueryGeneratorConfig() - ) - max_tokens_in_context: int = 4096 - max_chunks: int = 5 + pass diff --git a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py index 803981f07..e77ec76af 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py +++ b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py @@ -5,68 +5,64 @@ # the root directory of this source tree. -from typing import List - from jinja2 import Template -from pydantic import BaseModel from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.inference import UserMessage + +from llama_stack.apis.tools.rag_tool import ( + DefaultRAGQueryGeneratorConfig, + LLMRAGQueryGeneratorConfig, + RAGQueryGenerator, + RAGQueryGeneratorConfig, +) from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) -from .config import ( - DefaultMemoryQueryGeneratorConfig, - LLMMemoryQueryGeneratorConfig, - MemoryQueryGenerator, - MemoryQueryGeneratorConfig, -) - async def generate_rag_query( - config: MemoryQueryGeneratorConfig, - messages: List[InterleavedContent], + config: RAGQueryGeneratorConfig, + content: InterleavedContent, **kwargs, ): """ Generates a query that will be used for retrieving relevant information from the memory bank. """ - if config.type == MemoryQueryGenerator.default.value: - query = await default_rag_query_generator(config, messages, **kwargs) - elif config.type == MemoryQueryGenerator.llm.value: - query = await llm_rag_query_generator(config, messages, **kwargs) + if config.type == RAGQueryGenerator.default.value: + query = await default_rag_query_generator(config, content, **kwargs) + elif config.type == RAGQueryGenerator.llm.value: + query = await llm_rag_query_generator(config, content, **kwargs) else: raise NotImplementedError(f"Unsupported memory query generator {config.type}") return query async def default_rag_query_generator( - config: DefaultMemoryQueryGeneratorConfig, - messages: List[InterleavedContent], + config: DefaultRAGQueryGeneratorConfig, + content: InterleavedContent, **kwargs, ): - return config.sep.join(interleaved_content_as_str(m) for m in messages) + return interleaved_content_as_str(content, sep=config.separator) async def llm_rag_query_generator( - config: LLMMemoryQueryGeneratorConfig, - messages: List[InterleavedContent], + config: LLMRAGQueryGeneratorConfig, + content: InterleavedContent, **kwargs, ): assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api" inference_api = kwargs["inference_api"] - m_dict = { - "messages": [ - message.model_dump() if isinstance(message, BaseModel) else message - for message in messages - ] - } + messages = [] + if isinstance(content, list): + messages = [interleaved_content_as_str(m) for m in content] + else: + messages = [interleaved_content_as_str(content)] template = Template(config.template) - content = template.render(m_dict) + content = template.render({"messages": messages}) model = config.model message = UserMessage(content=content) diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py index fe6325abb..d3f8b07dc 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -10,20 +10,29 @@ import secrets import string from typing import Any, Dict, List, Optional -from llama_stack.apis.common.content_types import URL -from llama_stack.apis.inference import Inference, InterleavedContent -from llama_stack.apis.memory import Memory, QueryDocumentsResponse -from llama_stack.apis.memory_banks import MemoryBanks +from llama_stack.apis.common.content_types import ( + InterleavedContent, + TextContentItem, + URL, +) +from llama_stack.apis.inference import Inference from llama_stack.apis.tools import ( + RAGDocument, + RAGQueryConfig, + RAGQueryResult, + RAGToolRuntime, ToolDef, ToolInvocationResult, - ToolParameter, ToolRuntime, ) +from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import ToolsProtocolPrivate -from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content +from llama_stack.providers.utils.memory.vector_store import ( + content_from_doc, + make_overlapped_chunks, +) -from .config import MemoryToolConfig, MemoryToolRuntimeConfig +from .config import MemoryToolRuntimeConfig from .context_retriever import generate_rag_query log = logging.getLogger(__name__) @@ -35,65 +44,79 @@ def make_random_string(length: int = 8): ) -class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): +class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): def __init__( self, config: MemoryToolRuntimeConfig, - memory_api: Memory, - memory_banks_api: MemoryBanks, + vector_io_api: VectorIO, inference_api: Inference, ): self.config = config - self.memory_api = memory_api - self.memory_banks_api = memory_banks_api + self.vector_io_api = vector_io_api self.inference_api = inference_api async def initialize(self): pass - async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None - ) -> List[ToolDef]: - return [ - ToolDef( - name="query_memory", - description="Retrieve context from memory", - parameters=[ - ToolParameter( - name="messages", - description="The input messages to search for", - parameter_type="array", - ), - ], - ) - ] + async def shutdown(self): + pass + + async def insert_documents( + self, + documents: List[RAGDocument], + vector_db_id: str, + chunk_size_in_tokens: int = 512, + ) -> None: + chunks = [] + for doc in documents: + content = await content_from_doc(doc) + chunks.extend( + make_overlapped_chunks( + doc.document_id, + content, + chunk_size_in_tokens, + chunk_size_in_tokens // 4, + ) + ) + + if not chunks: + return + + await self.vector_io_api.insert_chunks( + chunks=chunks, + vector_db_id=vector_db_id, + ) + + async def query_context( + self, + content: InterleavedContent, + query_config: RAGQueryConfig, + vector_db_ids: List[str], + ) -> RAGQueryResult: + if not vector_db_ids: + return RAGQueryResult(content=None) - async def _retrieve_context( - self, input_messages: List[InterleavedContent], bank_ids: List[str] - ) -> Optional[List[InterleavedContent]]: - if not bank_ids: - return None query = await generate_rag_query( - self.config.query_generator_config, - input_messages, + query_config.query_generator_config, + content, inference_api=self.inference_api, ) tasks = [ - self.memory_api.query_documents( - bank_id=bank_id, + self.vector_io_api.query_chunks( + vector_db_id=vector_db_id, query=query, params={ - "max_chunks": self.config.max_chunks, + "max_chunks": query_config.max_chunks, }, ) - for bank_id in bank_ids + for vector_db_id in vector_db_ids ] - results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks) + results: List[QueryChunksResponse] = await asyncio.gather(*tasks) chunks = [c for r in results for c in r.chunks] scores = [s for r in results for s in r.scores] if not chunks: - return None + return RAGQueryResult(content=None) # sort by score chunks, scores = zip( @@ -102,45 +125,52 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): tokens = 0 picked = [] - for c in chunks[: self.config.max_chunks]: - tokens += c.token_count - if tokens > self.config.max_tokens_in_context: + for c in chunks[: query_config.max_chunks]: + metadata = c.metadata + tokens += metadata["token_count"] + if tokens > query_config.max_tokens_in_context: log.error( f"Using {len(picked)} chunks; reached max tokens in context: {tokens}", ) break - picked.append(f"id:{c.document_id}; content:{c.content}") + picked.append( + TextContentItem( + text=f"id:{metadata['document_id']}; content:{c.content}", + ) + ) + return RAGQueryResult( + content=[ + TextContentItem( + text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n", + ), + *picked, + TextContentItem( + text="\n=== END-RETRIEVED-CONTEXT ===\n", + ), + ], + ) + + async def list_runtime_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + # Parameters are not listed since these methods are not yet invoked automatically + # by the LLM. The method is only implemented so things like /tools can list without + # encountering fatals. return [ - "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n", - *picked, - "\n=== END-RETRIEVED-CONTEXT ===\n", + ToolDef( + name="rag_tool.query_context", + description="Retrieve context from memory", + ), + ToolDef( + name="rag_tool.insert_documents", + description="Insert documents into memory", + ), ] async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: - tool = await self.tool_store.get_tool(tool_name) - tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id) - final_args = tool_group.args or {} - final_args.update(args) - config = MemoryToolConfig() - if tool.metadata and tool.metadata.get("config") is not None: - config = MemoryToolConfig(**tool.metadata["config"]) - if "memory_bank_ids" in final_args: - bank_ids = final_args["memory_bank_ids"] - else: - bank_ids = [ - bank_config.bank_id for bank_config in config.memory_bank_configs - ] - if "messages" not in final_args: - raise ValueError("messages are required") - context = await self._retrieve_context( - final_args["messages"], - bank_ids, - ) - if context is None: - context = [] - return ToolInvocationResult( - content=concat_interleaved_content(context), error_code=0 + raise RuntimeError( + "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol" ) diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index b3ea68949..426fe22f2 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -23,7 +23,7 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[], module="llama_stack.providers.inline.tool_runtime.memory", config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig", - api_dependencies=[Api.vector_io, Api.vector_dbs, Api.inference], + api_dependencies=[Api.vector_io, Api.inference], ), InlineProviderSpec( api=Api.tool_runtime, diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py index 5114e06aa..677e29c12 100644 --- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -68,7 +68,7 @@ class BingSearchToolRuntimeImpl( ] async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: api_key = self._get_api_key() headers = { @@ -78,7 +78,7 @@ class BingSearchToolRuntimeImpl( "count": self.config.top_k, "textDecorations": True, "textFormat": "HTML", - "q": args["query"], + "q": kwargs["query"], } response = requests.get( diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 016f746ea..1162cc900 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -68,7 +68,7 @@ class BraveSearchToolRuntimeImpl( ] async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: api_key = self._get_api_key() url = "https://api.search.brave.com/res/v1/web/search" @@ -77,7 +77,7 @@ class BraveSearchToolRuntimeImpl( "Accept-Encoding": "gzip", "Accept": "application/json", } - payload = {"q": args["query"]} + payload = {"q": kwargs["query"]} response = requests.get(url=url, params=payload, headers=headers) response.raise_for_status() results = self._clean_brave_response(response.json()) diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index a304167e9..e0caec1d0 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -65,7 +65,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): return tools async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: tool = await self.tool_store.get_tool(tool_name) if tool.metadata is None or tool.metadata.get("endpoint") is None: @@ -77,7 +77,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): async with sse_client(endpoint) as streams: async with ClientSession(*streams) as session: await session.initialize() - result = await session.call_tool(tool.identifier, args) + result = await session.call_tool(tool.identifier, kwargs) return ToolInvocationResult( content="\n".join([result.model_dump_json() for result in result.content]), diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 82077193e..f5826c0ff 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -67,12 +67,12 @@ class TavilySearchToolRuntimeImpl( ] async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: api_key = self._get_api_key() response = requests.post( "https://api.tavily.com/search", - json={"api_key": api_key, "query": args["query"]}, + json={"api_key": api_key, "query": kwargs["query"]}, ) return ToolInvocationResult( diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index 04ecfcc15..bf298c13e 100644 --- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -68,11 +68,11 @@ class WolframAlphaToolRuntimeImpl( ] async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: api_key = self._get_api_key() params = { - "input": args["query"], + "input": kwargs["query"], "appid": api_key, "format": "plaintext", "output": "json", diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index 4efdfe8b7..9c115e3a1 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -12,10 +12,10 @@ from ..conftest import ( get_test_config_for_api, ) from ..inference.fixtures import INFERENCE_FIXTURES -from ..memory.fixtures import MEMORY_FIXTURES from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield from ..tools.fixtures import TOOL_RUNTIME_FIXTURES +from ..vector_io.fixtures import VECTOR_IO_FIXTURES from .fixtures import AGENTS_FIXTURES DEFAULT_PROVIDER_COMBINATIONS = [ @@ -23,7 +23,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ { "inference": "meta_reference", "safety": "llama_guard", - "memory": "faiss", + "vector_io": "faiss", "agents": "meta_reference", "tool_runtime": "memory_and_search", }, @@ -34,7 +34,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ { "inference": "ollama", "safety": "llama_guard", - "memory": "faiss", + "vector_io": "faiss", "agents": "meta_reference", "tool_runtime": "memory_and_search", }, @@ -46,7 +46,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "inference": "together", "safety": "llama_guard", # make this work with Weaviate which is what the together distro supports - "memory": "faiss", + "vector_io": "faiss", "agents": "meta_reference", "tool_runtime": "memory_and_search", }, @@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ { "inference": "fireworks", "safety": "llama_guard", - "memory": "faiss", + "vector_io": "faiss", "agents": "meta_reference", "tool_runtime": "memory_and_search", }, @@ -68,7 +68,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ { "inference": "remote", "safety": "remote", - "memory": "remote", + "vector_io": "remote", "agents": "remote", "tool_runtime": "memory_and_search", }, @@ -115,7 +115,7 @@ def pytest_generate_tests(metafunc): available_fixtures = { "inference": INFERENCE_FIXTURES, "safety": SAFETY_FIXTURES, - "memory": MEMORY_FIXTURES, + "vector_io": VECTOR_IO_FIXTURES, "agents": AGENTS_FIXTURES, "tool_runtime": TOOL_RUNTIME_FIXTURES, } diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 1b1781f36..bb4a6e6a3 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -69,7 +69,7 @@ async def agents_stack( providers = {} provider_data = {} - for key in ["inference", "safety", "memory", "agents", "tool_runtime"]: + for key in ["inference", "safety", "vector_io", "agents", "tool_runtime"]: fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") providers[key] = fixture.providers if key == "inference": @@ -118,7 +118,7 @@ async def agents_stack( ) test_stack = await construct_stack_for_test( - [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime], + [Api.agents, Api.inference, Api.safety, Api.vector_io, Api.tool_runtime], providers, provider_data, models=models, diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 320096826..f11aef3ec 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -214,9 +214,11 @@ class TestAgents: turn_response = [ chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] - assert len(turn_response) > 0 + # FIXME: we need to check the content of the turn response and ensure + # RAG actually worked + @pytest.mark.asyncio async def test_create_agent_turn_with_tavily_search( self, agents_stack, search_query_messages, common_params diff --git a/llama_stack/providers/tests/vector_io/test_vector_io.py b/llama_stack/providers/tests/vector_io/test_vector_io.py index 901b8bd11..521131f63 100644 --- a/llama_stack/providers/tests/vector_io/test_vector_io.py +++ b/llama_stack/providers/tests/vector_io/test_vector_io.py @@ -8,13 +8,12 @@ import uuid import pytest +from llama_stack.apis.tools import RAGDocument + from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB from llama_stack.apis.vector_io import QueryChunksResponse -from llama_stack.providers.utils.memory.vector_store import ( - make_overlapped_chunks, - MemoryBankDocument, -) +from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks # How to run this test: # @@ -26,22 +25,22 @@ from llama_stack.providers.utils.memory.vector_store import ( @pytest.fixture(scope="session") def sample_chunks(): docs = [ - MemoryBankDocument( + RAGDocument( document_id="doc1", content="Python is a high-level programming language.", metadata={"category": "programming", "difficulty": "beginner"}, ), - MemoryBankDocument( + RAGDocument( document_id="doc2", content="Machine learning is a subset of artificial intelligence.", metadata={"category": "AI", "difficulty": "advanced"}, ), - MemoryBankDocument( + RAGDocument( document_id="doc3", content="Data structures are fundamental to computer science.", metadata={"category": "computer science", "difficulty": "intermediate"}, ), - MemoryBankDocument( + RAGDocument( document_id="doc4", content="Neural networks are inspired by biological neural networks.", metadata={"category": "AI", "difficulty": "advanced"}, diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index c2de6c714..82c0c9c07 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -19,7 +19,6 @@ import numpy as np from llama_models.llama3.api.tokenizer import Tokenizer from numpy.typing import NDArray -from pydantic import BaseModel, Field from pypdf import PdfReader from llama_stack.apis.common.content_types import ( @@ -27,6 +26,7 @@ from llama_stack.apis.common.content_types import ( TextContentItem, URL, ) +from llama_stack.apis.tools import RAGDocument from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse from llama_stack.providers.datatypes import Api @@ -34,17 +34,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) - log = logging.getLogger(__name__) -class MemoryBankDocument(BaseModel): - document_id: str - content: InterleavedContent | URL - mime_type: str | None = None - metadata: Dict[str, Any] = Field(default_factory=dict) - - def parse_pdf(data: bytes) -> str: # For PDF and DOC/DOCX files, we can't reliably convert to string pdf_bytes = io.BytesIO(data) @@ -122,7 +114,7 @@ def concat_interleaved_content(content: List[InterleavedContent]) -> Interleaved return ret -async def content_from_doc(doc: MemoryBankDocument) -> str: +async def content_from_doc(doc: RAGDocument) -> str: if isinstance(doc.content, URL): if doc.content.uri.startswith("data:"): return content_from_data(doc.content.uri) @@ -161,7 +153,13 @@ def make_overlapped_chunks( chunk = tokenizer.decode(toks) # chunk is a string chunks.append( - Chunk(content=chunk, token_count=len(toks), document_id=document_id) + Chunk( + content=chunk, + metadata={ + "token_count": len(toks), + "document_id": document_id, + }, + ) ) return chunks diff --git a/llama_stack/scripts/test_rag_via_curl.py b/llama_stack/scripts/test_rag_via_curl.py new file mode 100644 index 000000000..28d6fb601 --- /dev/null +++ b/llama_stack/scripts/test_rag_via_curl.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import List + +import pytest +import requests +from pydantic import TypeAdapter + +from llama_stack.apis.tools import ( + DefaultRAGQueryGeneratorConfig, + RAGDocument, + RAGQueryConfig, + RAGQueryResult, +) +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.providers.utils.memory.vector_store import interleaved_content_as_str + + +class TestRAGToolEndpoints: + @pytest.fixture + def base_url(self) -> str: + return "http://localhost:8321/v1" # Adjust port if needed + + @pytest.fixture + def sample_documents(self) -> List[RAGDocument]: + return [ + RAGDocument( + document_id="doc1", + content="Python is a high-level programming language.", + metadata={"category": "programming", "difficulty": "beginner"}, + ), + RAGDocument( + document_id="doc2", + content="Machine learning is a subset of artificial intelligence.", + metadata={"category": "AI", "difficulty": "advanced"}, + ), + RAGDocument( + document_id="doc3", + content="Data structures are fundamental to computer science.", + metadata={"category": "computer science", "difficulty": "intermediate"}, + ), + ] + + @pytest.mark.asyncio + async def test_rag_workflow( + self, base_url: str, sample_documents: List[RAGDocument] + ): + vector_db_payload = { + "vector_db_id": "test_vector_db", + "embedding_model": "all-MiniLM-L6-v2", + "embedding_dimension": 384, + } + + response = requests.post(f"{base_url}/vector-dbs", json=vector_db_payload) + assert response.status_code == 200 + vector_db = VectorDB(**response.json()) + + insert_payload = { + "documents": [ + json.loads(doc.model_dump_json()) for doc in sample_documents + ], + "vector_db_id": vector_db.identifier, + "chunk_size_in_tokens": 512, + } + + response = requests.post( + f"{base_url}/tool-runtime/rag-tool/insert-documents", + json=insert_payload, + ) + assert response.status_code == 200 + + query = "What is Python?" + query_config = RAGQueryConfig( + query_generator_config=DefaultRAGQueryGeneratorConfig(), + max_tokens_in_context=4096, + max_chunks=2, + ) + + query_payload = { + "content": query, + "query_config": json.loads(query_config.model_dump_json()), + "vector_db_ids": [vector_db.identifier], + } + + response = requests.post( + f"{base_url}/tool-runtime/rag-tool/query-context", + json=query_payload, + ) + assert response.status_code == 200 + result = response.json() + result = TypeAdapter(RAGQueryResult).validate_python(result) + + content_str = interleaved_content_as_str(result.content) + print(f"content: {content_str}") + assert len(content_str) > 0 + assert "Python" in content_str + + # Clean up: Delete the vector DB + response = requests.delete(f"{base_url}/vector-dbs/{vector_db.identifier}") + assert response.status_code == 200 diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml index ea7387a24..2160adb8e 100644 --- a/llama_stack/templates/together/build.yaml +++ b/llama_stack/templates/together/build.yaml @@ -4,7 +4,7 @@ distribution_spec: providers: inference: - remote::together - memory: + vector_io: - inline::faiss - remote::chromadb - remote::pgvector diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml index da25fd144..135b124e4 100644 --- a/llama_stack/templates/together/run.yaml +++ b/llama_stack/templates/together/run.yaml @@ -5,7 +5,7 @@ apis: - datasetio - eval - inference -- memory +- vector_io - safety - scoring - telemetry @@ -20,7 +20,7 @@ providers: - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} - memory: + vector_io: - provider_id: faiss provider_type: inline::faiss config: @@ -145,7 +145,6 @@ models: model_type: embedding shields: - shield_id: meta-llama/Llama-Guard-3-8B -memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: []