diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py
index 4cea9d970..abeb16936 100644
--- a/docs/openapi_generator/pyopenapi/operations.py
+++ b/docs/openapi_generator/pyopenapi/operations.py
@@ -172,10 +172,16 @@ def _get_endpoint_functions(
def _get_defining_class(member_fn: str, derived_cls: type) -> type:
"Find the class in which a member function is first defined in a class inheritance hierarchy."
+    # NOTE: imported lazily inside the function — presumably to avoid a circular import at module load time; confirm before hoisting to module scope
+ from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime
+
# iterate in reverse member resolution order to find most specific class first
for cls in reversed(inspect.getmro(derived_cls)):
for name, _ in inspect.getmembers(cls, inspect.isfunction):
if name == member_fn:
+                # HACK: methods defined on RAGToolRuntime are exposed under the ToolRuntime API, so report ToolRuntime as the defining class
+ if cls == RAGToolRuntime:
+ return ToolRuntime
return cls
raise ValidationError(
diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 459a53888..f00d7b291 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -1108,98 +1108,6 @@
]
}
},
- "/v1/memory-banks/{memory_bank_id}": {
- "get": {
- "responses": {
- "200": {
- "description": "OK",
- "content": {
- "application/json": {
- "schema": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/MemoryBank"
- },
- {
- "type": "null"
- }
- ]
- }
- }
- }
- }
- },
- "tags": [
- "MemoryBanks"
- ],
- "parameters": [
- {
- "name": "memory_bank_id",
- "in": "path",
- "required": true,
- "schema": {
- "type": "string"
- }
- },
- {
- "name": "X-LlamaStack-Provider-Data",
- "in": "header",
- "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
- "required": false,
- "schema": {
- "type": "string"
- }
- },
- {
- "name": "X-LlamaStack-Client-Version",
- "in": "header",
- "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
- "required": false,
- "schema": {
- "type": "string"
- }
- }
- ]
- },
- "delete": {
- "responses": {
- "200": {
- "description": "OK"
- }
- },
- "tags": [
- "MemoryBanks"
- ],
- "parameters": [
- {
- "name": "memory_bank_id",
- "in": "path",
- "required": true,
- "schema": {
- "type": "string"
- }
- },
- {
- "name": "X-LlamaStack-Provider-Data",
- "in": "header",
- "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
- "required": false,
- "schema": {
- "type": "string"
- }
- },
- {
- "name": "X-LlamaStack-Client-Version",
- "in": "header",
- "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
- "required": false,
- "schema": {
- "type": "string"
- }
- }
- ]
- }
- },
"/v1/models/{model_id}": {
"get": {
"responses": {
@@ -1848,6 +1756,98 @@
]
}
},
+ "/v1/vector-dbs/{vector_db_id}": {
+ "get": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/VectorDB"
+ },
+ {
+ "type": "null"
+ }
+ ]
+ }
+ }
+ }
+ }
+ },
+ "tags": [
+ "VectorDBs"
+ ],
+ "parameters": [
+ {
+ "name": "vector_db_id",
+ "in": "path",
+ "required": true,
+ "schema": {
+ "type": "string"
+ }
+ },
+ {
+ "name": "X-LlamaStack-Provider-Data",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ },
+ {
+ "name": "X-LlamaStack-Client-Version",
+ "in": "header",
+ "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ]
+ },
+ "delete": {
+ "responses": {
+ "200": {
+ "description": "OK"
+ }
+ },
+ "tags": [
+ "VectorDBs"
+ ],
+ "parameters": [
+ {
+ "name": "vector_db_id",
+ "in": "path",
+ "required": true,
+ "schema": {
+ "type": "string"
+ }
+ },
+ {
+ "name": "X-LlamaStack-Provider-Data",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ },
+ {
+ "name": "X-LlamaStack-Client-Version",
+ "in": "header",
+ "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ]
+ }
+ },
"/v1/health": {
"get": {
"responses": {
@@ -1887,7 +1887,7 @@
]
}
},
- "/v1/memory/insert": {
+ "/v1/vector-io/insert": {
"post": {
"responses": {
"200": {
@@ -1895,7 +1895,7 @@
}
},
"tags": [
- "Memory"
+ "VectorIO"
],
"parameters": [
{
@@ -1917,6 +1917,49 @@
}
}
],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/InsertChunksRequest"
+ }
+ }
+ },
+ "required": true
+ }
+ }
+ },
+ "/v1/tool-runtime/rag-tool/insert-documents": {
+ "post": {
+ "responses": {
+ "200": {
+ "description": "OK"
+ }
+ },
+ "tags": [
+ "ToolRuntime"
+ ],
+ "summary": "Index documents so they can be used by the RAG system",
+ "parameters": [
+ {
+ "name": "X-LlamaStack-Provider-Data",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ },
+ {
+ "name": "X-LlamaStack-Client-Version",
+ "in": "header",
+ "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ],
"requestBody": {
"content": {
"application/json": {
@@ -2300,105 +2343,6 @@
}
}
},
- "/v1/memory-banks": {
- "get": {
- "responses": {
- "200": {
- "description": "OK",
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/ListMemoryBanksResponse"
- }
- }
- }
- }
- },
- "tags": [
- "MemoryBanks"
- ],
- "parameters": [
- {
- "name": "X-LlamaStack-Provider-Data",
- "in": "header",
- "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
- "required": false,
- "schema": {
- "type": "string"
- }
- },
- {
- "name": "X-LlamaStack-Client-Version",
- "in": "header",
- "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
- "required": false,
- "schema": {
- "type": "string"
- }
- }
- ]
- },
- "post": {
- "responses": {
- "200": {
- "description": "",
- "content": {
- "application/json": {
- "schema": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/VectorMemoryBank"
- },
- {
- "$ref": "#/components/schemas/KeyValueMemoryBank"
- },
- {
- "$ref": "#/components/schemas/KeywordMemoryBank"
- },
- {
- "$ref": "#/components/schemas/GraphMemoryBank"
- }
- ]
- }
- }
- }
- }
- },
- "tags": [
- "MemoryBanks"
- ],
- "parameters": [
- {
- "name": "X-LlamaStack-Provider-Data",
- "in": "header",
- "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
- "required": false,
- "schema": {
- "type": "string"
- }
- },
- {
- "name": "X-LlamaStack-Client-Version",
- "in": "header",
- "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
- "required": false,
- "schema": {
- "type": "string"
- }
- }
- ],
- "requestBody": {
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/RegisterMemoryBankRequest"
- }
- }
- },
- "required": true
- }
- }
- },
"/v1/models": {
"get": {
"responses": {
@@ -2912,6 +2856,92 @@
]
}
},
+ "/v1/vector-dbs": {
+ "get": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ListVectorDBsResponse"
+ }
+ }
+ }
+ }
+ },
+ "tags": [
+ "VectorDBs"
+ ],
+ "parameters": [
+ {
+ "name": "X-LlamaStack-Provider-Data",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ },
+ {
+ "name": "X-LlamaStack-Client-Version",
+ "in": "header",
+ "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ]
+ },
+ "post": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/VectorDB"
+ }
+ }
+ }
+ }
+ },
+ "tags": [
+ "VectorDBs"
+ ],
+ "parameters": [
+ {
+ "name": "X-LlamaStack-Provider-Data",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ },
+ {
+ "name": "X-LlamaStack-Client-Version",
+ "in": "header",
+ "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/RegisterVectorDbRequest"
+ }
+ }
+ },
+ "required": true
+ }
+ }
+ },
"/v1/telemetry/events": {
"post": {
"responses": {
@@ -3003,7 +3033,7 @@
}
}
},
- "/v1/memory/query": {
+ "/v1/vector-io/query": {
"post": {
"responses": {
"200": {
@@ -3011,14 +3041,14 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/QueryDocumentsResponse"
+ "$ref": "#/components/schemas/QueryChunksResponse"
}
}
}
}
},
"tags": [
- "Memory"
+ "VectorIO"
],
"parameters": [
{
@@ -3044,7 +3074,57 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/QueryDocumentsRequest"
+ "$ref": "#/components/schemas/QueryChunksRequest"
+ }
+ }
+ },
+ "required": true
+ }
+ }
+ },
+ "/v1/tool-runtime/rag-tool/query-context": {
+ "post": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/RAGQueryResult"
+ }
+ }
+ }
+ }
+ },
+ "tags": [
+ "ToolRuntime"
+ ],
+ "summary": "Query the RAG system for context; typically invoked by the agent",
+ "parameters": [
+ {
+ "name": "X-LlamaStack-Provider-Data",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ },
+ {
+ "name": "X-LlamaStack-Client-Version",
+ "in": "header",
+ "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/QueryContextRequest"
}
}
},
@@ -5851,118 +5931,6 @@
"aggregated_results"
]
},
- "GraphMemoryBank": {
- "type": "object",
- "properties": {
- "identifier": {
- "type": "string"
- },
- "provider_resource_id": {
- "type": "string"
- },
- "provider_id": {
- "type": "string"
- },
- "type": {
- "type": "string",
- "const": "memory_bank",
- "default": "memory_bank"
- },
- "memory_bank_type": {
- "type": "string",
- "const": "graph",
- "default": "graph"
- }
- },
- "additionalProperties": false,
- "required": [
- "identifier",
- "provider_resource_id",
- "provider_id",
- "type",
- "memory_bank_type"
- ]
- },
- "KeyValueMemoryBank": {
- "type": "object",
- "properties": {
- "identifier": {
- "type": "string"
- },
- "provider_resource_id": {
- "type": "string"
- },
- "provider_id": {
- "type": "string"
- },
- "type": {
- "type": "string",
- "const": "memory_bank",
- "default": "memory_bank"
- },
- "memory_bank_type": {
- "type": "string",
- "const": "keyvalue",
- "default": "keyvalue"
- }
- },
- "additionalProperties": false,
- "required": [
- "identifier",
- "provider_resource_id",
- "provider_id",
- "type",
- "memory_bank_type"
- ]
- },
- "KeywordMemoryBank": {
- "type": "object",
- "properties": {
- "identifier": {
- "type": "string"
- },
- "provider_resource_id": {
- "type": "string"
- },
- "provider_id": {
- "type": "string"
- },
- "type": {
- "type": "string",
- "const": "memory_bank",
- "default": "memory_bank"
- },
- "memory_bank_type": {
- "type": "string",
- "const": "keyword",
- "default": "keyword"
- }
- },
- "additionalProperties": false,
- "required": [
- "identifier",
- "provider_resource_id",
- "provider_id",
- "type",
- "memory_bank_type"
- ]
- },
- "MemoryBank": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/VectorMemoryBank"
- },
- {
- "$ref": "#/components/schemas/KeyValueMemoryBank"
- },
- {
- "$ref": "#/components/schemas/KeywordMemoryBank"
- },
- {
- "$ref": "#/components/schemas/GraphMemoryBank"
- }
- ]
- },
"Session": {
"type": "object",
"properties": {
@@ -5981,9 +5949,6 @@
"started_at": {
"type": "string",
"format": "date-time"
- },
- "memory_bank": {
- "$ref": "#/components/schemas/MemoryBank"
}
},
"additionalProperties": false,
@@ -5995,53 +5960,6 @@
],
"title": "A single session of an interaction with an Agentic System."
},
- "VectorMemoryBank": {
- "type": "object",
- "properties": {
- "identifier": {
- "type": "string"
- },
- "provider_resource_id": {
- "type": "string"
- },
- "provider_id": {
- "type": "string"
- },
- "type": {
- "type": "string",
- "const": "memory_bank",
- "default": "memory_bank"
- },
- "memory_bank_type": {
- "type": "string",
- "const": "vector",
- "default": "vector"
- },
- "embedding_model": {
- "type": "string"
- },
- "chunk_size_in_tokens": {
- "type": "integer"
- },
- "embedding_dimension": {
- "type": "integer",
- "default": 384
- },
- "overlap_size_in_tokens": {
- "type": "integer"
- }
- },
- "additionalProperties": false,
- "required": [
- "identifier",
- "provider_resource_id",
- "provider_id",
- "type",
- "memory_bank_type",
- "embedding_model",
- "chunk_size_in_tokens"
- ]
- },
"AgentStepResponse": {
"type": "object",
"properties": {
@@ -7012,6 +6930,40 @@
"data"
]
},
+ "VectorDB": {
+ "type": "object",
+ "properties": {
+ "identifier": {
+ "type": "string"
+ },
+ "provider_resource_id": {
+ "type": "string"
+ },
+ "provider_id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "vector_db",
+ "default": "vector_db"
+ },
+ "embedding_model": {
+ "type": "string"
+ },
+ "embedding_dimension": {
+ "type": "integer"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "identifier",
+ "provider_resource_id",
+ "provider_id",
+ "type",
+ "embedding_model",
+ "embedding_dimension"
+ ]
+ },
"HealthInfo": {
"type": "object",
"properties": {
@@ -7024,7 +6976,64 @@
"status"
]
},
- "MemoryBankDocument": {
+ "InsertChunksRequest": {
+ "type": "object",
+ "properties": {
+ "vector_db_id": {
+ "type": "string"
+ },
+ "chunks": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "content": {
+ "$ref": "#/components/schemas/InterleavedContent"
+ },
+ "metadata": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "content",
+ "metadata"
+ ]
+ }
+ },
+ "ttl_seconds": {
+ "type": "integer"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "vector_db_id",
+ "chunks"
+ ]
+ },
+ "RAGDocument": {
"type": "object",
"properties": {
"document_id": {
@@ -7088,23 +7097,24 @@
"InsertDocumentsRequest": {
"type": "object",
"properties": {
- "bank_id": {
- "type": "string"
- },
"documents": {
"type": "array",
"items": {
- "$ref": "#/components/schemas/MemoryBankDocument"
+ "$ref": "#/components/schemas/RAGDocument"
}
},
- "ttl_seconds": {
+ "vector_db_id": {
+ "type": "string"
+ },
+ "chunk_size_in_tokens": {
"type": "integer"
}
},
"additionalProperties": false,
"required": [
- "bank_id",
- "documents"
+ "documents",
+ "vector_db_id",
+ "chunk_size_in_tokens"
]
},
"InvokeToolRequest": {
@@ -7113,7 +7123,7 @@
"tool_name": {
"type": "string"
},
- "args": {
+ "kwargs": {
"type": "object",
"additionalProperties": {
"oneOf": [
@@ -7142,7 +7152,7 @@
"additionalProperties": false,
"required": [
"tool_name",
- "args"
+ "kwargs"
]
},
"ToolInvocationResult": {
@@ -7193,21 +7203,6 @@
"data"
]
},
- "ListMemoryBanksResponse": {
- "type": "object",
- "properties": {
- "data": {
- "type": "array",
- "items": {
- "$ref": "#/components/schemas/MemoryBank"
- }
- }
- },
- "additionalProperties": false,
- "required": [
- "data"
- ]
- },
"ListModelsResponse": {
"type": "object",
"properties": {
@@ -7356,6 +7351,21 @@
"data"
]
},
+ "ListVectorDBsResponse": {
+ "type": "object",
+ "properties": {
+ "data": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/VectorDB"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "data"
+ ]
+ },
"LogSeverity": {
"type": "string",
"enum": [
@@ -7873,10 +7883,10 @@
"job_uuid"
]
},
- "QueryDocumentsRequest": {
+ "QueryChunksRequest": {
"type": "object",
"properties": {
- "bank_id": {
+ "vector_db_id": {
"type": "string"
},
"query": {
@@ -7910,11 +7920,11 @@
},
"additionalProperties": false,
"required": [
- "bank_id",
+ "vector_db_id",
"query"
]
},
- "QueryDocumentsResponse": {
+ "QueryChunksResponse": {
"type": "object",
"properties": {
"chunks": {
@@ -7925,18 +7935,36 @@
"content": {
"$ref": "#/components/schemas/InterleavedContent"
},
- "token_count": {
- "type": "integer"
- },
- "document_id": {
- "type": "string"
+ "metadata": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
}
},
"additionalProperties": false,
"required": [
"content",
- "token_count",
- "document_id"
+ "metadata"
]
}
},
@@ -7953,6 +7981,111 @@
"scores"
]
},
+ "DefaultRAGQueryGeneratorConfig": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "default",
+ "default": "default"
+ },
+ "separator": {
+ "type": "string",
+ "default": " "
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "separator"
+ ]
+ },
+ "LLMRAGQueryGeneratorConfig": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "llm",
+ "default": "llm"
+ },
+ "model": {
+ "type": "string"
+ },
+ "template": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "model",
+ "template"
+ ]
+ },
+ "RAGQueryConfig": {
+ "type": "object",
+ "properties": {
+ "query_generator_config": {
+ "$ref": "#/components/schemas/RAGQueryGeneratorConfig"
+ },
+ "max_tokens_in_context": {
+ "type": "integer",
+ "default": 4096
+ },
+ "max_chunks": {
+ "type": "integer",
+ "default": 5
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "query_generator_config",
+ "max_tokens_in_context",
+ "max_chunks"
+ ]
+ },
+ "RAGQueryGeneratorConfig": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/DefaultRAGQueryGeneratorConfig"
+ },
+ {
+ "$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig"
+ }
+ ]
+ },
+ "QueryContextRequest": {
+ "type": "object",
+ "properties": {
+ "content": {
+ "$ref": "#/components/schemas/InterleavedContent"
+ },
+ "query_config": {
+ "$ref": "#/components/schemas/RAGQueryConfig"
+ },
+ "vector_db_ids": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "content",
+ "query_config",
+ "vector_db_ids"
+ ]
+ },
+ "RAGQueryResult": {
+ "type": "object",
+ "properties": {
+ "content": {
+ "$ref": "#/components/schemas/InterleavedContent"
+ }
+ },
+ "additionalProperties": false
+ },
"QueryCondition": {
"type": "object",
"properties": {
@@ -8139,108 +8272,6 @@
"scoring_functions"
]
},
- "GraphMemoryBankParams": {
- "type": "object",
- "properties": {
- "memory_bank_type": {
- "type": "string",
- "const": "graph",
- "default": "graph"
- }
- },
- "additionalProperties": false,
- "required": [
- "memory_bank_type"
- ]
- },
- "KeyValueMemoryBankParams": {
- "type": "object",
- "properties": {
- "memory_bank_type": {
- "type": "string",
- "const": "keyvalue",
- "default": "keyvalue"
- }
- },
- "additionalProperties": false,
- "required": [
- "memory_bank_type"
- ]
- },
- "KeywordMemoryBankParams": {
- "type": "object",
- "properties": {
- "memory_bank_type": {
- "type": "string",
- "const": "keyword",
- "default": "keyword"
- }
- },
- "additionalProperties": false,
- "required": [
- "memory_bank_type"
- ]
- },
- "VectorMemoryBankParams": {
- "type": "object",
- "properties": {
- "memory_bank_type": {
- "type": "string",
- "const": "vector",
- "default": "vector"
- },
- "embedding_model": {
- "type": "string"
- },
- "chunk_size_in_tokens": {
- "type": "integer"
- },
- "overlap_size_in_tokens": {
- "type": "integer"
- }
- },
- "additionalProperties": false,
- "required": [
- "memory_bank_type",
- "embedding_model",
- "chunk_size_in_tokens"
- ]
- },
- "RegisterMemoryBankRequest": {
- "type": "object",
- "properties": {
- "memory_bank_id": {
- "type": "string"
- },
- "params": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/VectorMemoryBankParams"
- },
- {
- "$ref": "#/components/schemas/KeyValueMemoryBankParams"
- },
- {
- "$ref": "#/components/schemas/KeywordMemoryBankParams"
- },
- {
- "$ref": "#/components/schemas/GraphMemoryBankParams"
- }
- ]
- },
- "provider_id": {
- "type": "string"
- },
- "provider_memory_bank_id": {
- "type": "string"
- }
- },
- "additionalProperties": false,
- "required": [
- "memory_bank_id",
- "params"
- ]
- },
"RegisterModelRequest": {
"type": "object",
"properties": {
@@ -8413,6 +8444,31 @@
"provider_id"
]
},
+ "RegisterVectorDbRequest": {
+ "type": "object",
+ "properties": {
+ "vector_db_id": {
+ "type": "string"
+ },
+ "embedding_model": {
+ "type": "string"
+ },
+ "embedding_dimension": {
+ "type": "integer"
+ },
+ "provider_id": {
+ "type": "string"
+ },
+ "provider_vector_db_id": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "vector_db_id",
+ "embedding_model"
+ ]
+ },
"RunEvalRequest": {
"type": "object",
"properties": {
@@ -9128,6 +9184,10 @@
{
"name": "Datasets"
},
+ {
+ "name": "DefaultRAGQueryGeneratorConfig",
+ "description": ""
+ },
{
"name": "EfficiencyConfig",
"description": ""
@@ -9158,14 +9218,6 @@
"name": "EvaluateRowsRequest",
"description": ""
},
- {
- "name": "GraphMemoryBank",
- "description": ""
- },
- {
- "name": "GraphMemoryBankParams",
- "description": ""
- },
{
"name": "GreedySamplingStrategy",
"description": ""
@@ -9189,6 +9241,10 @@
"name": "InferenceStep",
"description": ""
},
+ {
+ "name": "InsertChunksRequest",
+ "description": ""
+ },
{
"name": "InsertDocumentsRequest",
"description": ""
@@ -9220,26 +9276,14 @@
"name": "JsonType",
"description": ""
},
- {
- "name": "KeyValueMemoryBank",
- "description": ""
- },
- {
- "name": "KeyValueMemoryBankParams",
- "description": ""
- },
- {
- "name": "KeywordMemoryBank",
- "description": ""
- },
- {
- "name": "KeywordMemoryBankParams",
- "description": ""
- },
{
"name": "LLMAsJudgeScoringFnParams",
"description": ""
},
+ {
+ "name": "LLMRAGQueryGeneratorConfig",
+ "description": ""
+ },
{
"name": "ListDatasetsResponse",
"description": ""
@@ -9248,10 +9292,6 @@
"name": "ListEvalTasksResponse",
"description": ""
},
- {
- "name": "ListMemoryBanksResponse",
- "description": ""
- },
{
"name": "ListModelsResponse",
"description": ""
@@ -9284,6 +9324,10 @@
"name": "ListToolsResponse",
"description": ""
},
+ {
+ "name": "ListVectorDBsResponse",
+ "description": ""
+ },
{
"name": "LogEventRequest",
"description": ""
@@ -9296,20 +9340,6 @@
"name": "LoraFinetuningConfig",
"description": ""
},
- {
- "name": "Memory"
- },
- {
- "name": "MemoryBank",
- "description": ""
- },
- {
- "name": "MemoryBankDocument",
- "description": ""
- },
- {
- "name": "MemoryBanks"
- },
{
"name": "MemoryRetrievalStep",
"description": ""
@@ -9388,6 +9418,14 @@
"name": "QATFinetuningConfig",
"description": ""
},
+ {
+ "name": "QueryChunksRequest",
+ "description": ""
+ },
+ {
+ "name": "QueryChunksResponse",
+ "description": ""
+ },
{
"name": "QueryCondition",
"description": ""
@@ -9397,12 +9435,8 @@
"description": ""
},
{
- "name": "QueryDocumentsRequest",
- "description": ""
- },
- {
- "name": "QueryDocumentsResponse",
- "description": ""
+ "name": "QueryContextRequest",
+ "description": ""
},
{
"name": "QuerySpanTreeResponse",
@@ -9416,6 +9450,22 @@
"name": "QueryTracesResponse",
"description": ""
},
+ {
+ "name": "RAGDocument",
+ "description": ""
+ },
+ {
+ "name": "RAGQueryConfig",
+ "description": ""
+ },
+ {
+ "name": "RAGQueryGeneratorConfig",
+ "description": ""
+ },
+ {
+ "name": "RAGQueryResult",
+ "description": ""
+ },
{
"name": "RegexParserScoringFnParams",
"description": ""
@@ -9428,10 +9478,6 @@
"name": "RegisterEvalTaskRequest",
"description": ""
},
- {
- "name": "RegisterMemoryBankRequest",
- "description": ""
- },
{
"name": "RegisterModelRequest",
"description": ""
@@ -9448,6 +9494,10 @@
"name": "RegisterToolGroupRequest",
"description": ""
},
+ {
+ "name": "RegisterVectorDbRequest",
+ "description": ""
+ },
{
"name": "ResponseFormat",
"description": ""
@@ -9701,12 +9751,14 @@
"description": ""
},
{
- "name": "VectorMemoryBank",
- "description": ""
+ "name": "VectorDB",
+ "description": ""
},
{
- "name": "VectorMemoryBankParams",
- "description": ""
+ "name": "VectorDBs"
+ },
+ {
+ "name": "VectorIO"
},
{
"name": "VersionInfo",
@@ -9729,8 +9781,6 @@
"EvalTasks",
"Inference",
"Inspect",
- "Memory",
- "MemoryBanks",
"Models",
"PostTraining (Coming Soon)",
"Safety",
@@ -9740,7 +9790,9 @@
"SyntheticDataGeneration (Coming Soon)",
"Telemetry",
"ToolGroups",
- "ToolRuntime"
+ "ToolRuntime",
+ "VectorDBs",
+ "VectorIO"
]
},
{
@@ -9793,19 +9845,19 @@
"DataConfig",
"Dataset",
"DatasetFormat",
+ "DefaultRAGQueryGeneratorConfig",
"EfficiencyConfig",
"EmbeddingsRequest",
"EmbeddingsResponse",
"EvalTask",
"EvaluateResponse",
"EvaluateRowsRequest",
- "GraphMemoryBank",
- "GraphMemoryBankParams",
"GreedySamplingStrategy",
"HealthInfo",
"ImageContentItem",
"ImageDelta",
"InferenceStep",
+ "InsertChunksRequest",
"InsertDocumentsRequest",
"InterleavedContent",
"InterleavedContentItem",
@@ -9813,14 +9865,10 @@
"Job",
"JobStatus",
"JsonType",
- "KeyValueMemoryBank",
- "KeyValueMemoryBankParams",
- "KeywordMemoryBank",
- "KeywordMemoryBankParams",
"LLMAsJudgeScoringFnParams",
+ "LLMRAGQueryGeneratorConfig",
"ListDatasetsResponse",
"ListEvalTasksResponse",
- "ListMemoryBanksResponse",
"ListModelsResponse",
"ListPostTrainingJobsResponse",
"ListProvidersResponse",
@@ -9829,11 +9877,10 @@
"ListShieldsResponse",
"ListToolGroupsResponse",
"ListToolsResponse",
+ "ListVectorDBsResponse",
"LogEventRequest",
"LogSeverity",
"LoraFinetuningConfig",
- "MemoryBank",
- "MemoryBankDocument",
"MemoryRetrievalStep",
"Message",
"MetricEvent",
@@ -9852,21 +9899,26 @@
"PreferenceOptimizeRequest",
"ProviderInfo",
"QATFinetuningConfig",
+ "QueryChunksRequest",
+ "QueryChunksResponse",
"QueryCondition",
"QueryConditionOp",
- "QueryDocumentsRequest",
- "QueryDocumentsResponse",
+ "QueryContextRequest",
"QuerySpanTreeResponse",
"QuerySpansResponse",
"QueryTracesResponse",
+ "RAGDocument",
+ "RAGQueryConfig",
+ "RAGQueryGeneratorConfig",
+ "RAGQueryResult",
"RegexParserScoringFnParams",
"RegisterDatasetRequest",
"RegisterEvalTaskRequest",
- "RegisterMemoryBankRequest",
"RegisterModelRequest",
"RegisterScoringFunctionRequest",
"RegisterShieldRequest",
"RegisterToolGroupRequest",
+ "RegisterVectorDbRequest",
"ResponseFormat",
"RouteInfo",
"RunEvalRequest",
@@ -9924,8 +9976,7 @@
"UnionType",
"UnstructuredLogEvent",
"UserMessage",
- "VectorMemoryBank",
- "VectorMemoryBankParams",
+ "VectorDB",
"VersionInfo",
"ViolationLevel"
]
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 9aeac6db3..e1ae07c45 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -761,6 +761,20 @@ components:
- instruct
- dialog
type: string
+ DefaultRAGQueryGeneratorConfig:
+ additionalProperties: false
+ properties:
+ separator:
+ default: ' '
+ type: string
+ type:
+ const: default
+ default: default
+ type: string
+ required:
+ - type
+ - separator
+ type: object
EfficiencyConfig:
additionalProperties: false
properties:
@@ -891,40 +905,6 @@ components:
- scoring_functions
- task_config
type: object
- GraphMemoryBank:
- additionalProperties: false
- properties:
- identifier:
- type: string
- memory_bank_type:
- const: graph
- default: graph
- type: string
- provider_id:
- type: string
- provider_resource_id:
- type: string
- type:
- const: memory_bank
- default: memory_bank
- type: string
- required:
- - identifier
- - provider_resource_id
- - provider_id
- - type
- - memory_bank_type
- type: object
- GraphMemoryBankParams:
- additionalProperties: false
- properties:
- memory_bank_type:
- const: graph
- default: graph
- type: string
- required:
- - memory_bank_type
- type: object
GreedySamplingStrategy:
additionalProperties: false
properties:
@@ -997,20 +977,53 @@ components:
- step_type
- model_response
type: object
- InsertDocumentsRequest:
+ InsertChunksRequest:
additionalProperties: false
properties:
- bank_id:
- type: string
- documents:
+ chunks:
items:
- $ref: '#/components/schemas/MemoryBankDocument'
+ additionalProperties: false
+ properties:
+ content:
+ $ref: '#/components/schemas/InterleavedContent'
+ metadata:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ required:
+ - content
+ - metadata
+ type: object
type: array
ttl_seconds:
type: integer
+ vector_db_id:
+ type: string
+ required:
+ - vector_db_id
+ - chunks
+ type: object
+ InsertDocumentsRequest:
+ additionalProperties: false
+ properties:
+ chunk_size_in_tokens:
+ type: integer
+ documents:
+ items:
+ $ref: '#/components/schemas/RAGDocument'
+ type: array
+ vector_db_id:
+ type: string
required:
- - bank_id
- documents
+ - vector_db_id
+ - chunk_size_in_tokens
type: object
InterleavedContent:
oneOf:
@@ -1026,7 +1039,7 @@ components:
InvokeToolRequest:
additionalProperties: false
properties:
- args:
+ kwargs:
additionalProperties:
oneOf:
- type: 'null'
@@ -1040,7 +1053,7 @@ components:
type: string
required:
- tool_name
- - args
+ - kwargs
type: object
Job:
additionalProperties: false
@@ -1067,74 +1080,6 @@ components:
required:
- type
type: object
- KeyValueMemoryBank:
- additionalProperties: false
- properties:
- identifier:
- type: string
- memory_bank_type:
- const: keyvalue
- default: keyvalue
- type: string
- provider_id:
- type: string
- provider_resource_id:
- type: string
- type:
- const: memory_bank
- default: memory_bank
- type: string
- required:
- - identifier
- - provider_resource_id
- - provider_id
- - type
- - memory_bank_type
- type: object
- KeyValueMemoryBankParams:
- additionalProperties: false
- properties:
- memory_bank_type:
- const: keyvalue
- default: keyvalue
- type: string
- required:
- - memory_bank_type
- type: object
- KeywordMemoryBank:
- additionalProperties: false
- properties:
- identifier:
- type: string
- memory_bank_type:
- const: keyword
- default: keyword
- type: string
- provider_id:
- type: string
- provider_resource_id:
- type: string
- type:
- const: memory_bank
- default: memory_bank
- type: string
- required:
- - identifier
- - provider_resource_id
- - provider_id
- - type
- - memory_bank_type
- type: object
- KeywordMemoryBankParams:
- additionalProperties: false
- properties:
- memory_bank_type:
- const: keyword
- default: keyword
- type: string
- required:
- - memory_bank_type
- type: object
LLMAsJudgeScoringFnParams:
additionalProperties: false
properties:
@@ -1158,6 +1103,22 @@ components:
- type
- judge_model
type: object
+ LLMRAGQueryGeneratorConfig:
+ additionalProperties: false
+ properties:
+ model:
+ type: string
+ template:
+ type: string
+ type:
+ const: llm
+ default: llm
+ type: string
+ required:
+ - type
+ - model
+ - template
+ type: object
ListDatasetsResponse:
additionalProperties: false
properties:
@@ -1178,16 +1139,6 @@ components:
required:
- data
type: object
- ListMemoryBanksResponse:
- additionalProperties: false
- properties:
- data:
- items:
- $ref: '#/components/schemas/MemoryBank'
- type: array
- required:
- - data
- type: object
ListModelsResponse:
additionalProperties: false
properties:
@@ -1274,6 +1225,16 @@ components:
required:
- data
type: object
+ ListVectorDBsResponse:
+ additionalProperties: false
+ properties:
+ data:
+ items:
+ $ref: '#/components/schemas/VectorDB'
+ type: array
+ required:
+ - data
+ type: object
LogEventRequest:
additionalProperties: false
properties:
@@ -1330,42 +1291,6 @@ components:
- rank
- alpha
type: object
- MemoryBank:
- oneOf:
- - $ref: '#/components/schemas/VectorMemoryBank'
- - $ref: '#/components/schemas/KeyValueMemoryBank'
- - $ref: '#/components/schemas/KeywordMemoryBank'
- - $ref: '#/components/schemas/GraphMemoryBank'
- MemoryBankDocument:
- additionalProperties: false
- properties:
- content:
- oneOf:
- - type: string
- - $ref: '#/components/schemas/InterleavedContentItem'
- - items:
- $ref: '#/components/schemas/InterleavedContentItem'
- type: array
- - $ref: '#/components/schemas/URL'
- document_id:
- type: string
- metadata:
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
- type: object
- mime_type:
- type: string
- required:
- - document_id
- - content
- - metadata
- type: object
MemoryRetrievalStep:
additionalProperties: false
properties:
@@ -1705,6 +1630,59 @@ components:
- quantizer_name
- group_size
type: object
+ QueryChunksRequest:
+ additionalProperties: false
+ properties:
+ params:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ query:
+ $ref: '#/components/schemas/InterleavedContent'
+ vector_db_id:
+ type: string
+ required:
+ - vector_db_id
+ - query
+ type: object
+ QueryChunksResponse:
+ additionalProperties: false
+ properties:
+ chunks:
+ items:
+ additionalProperties: false
+ properties:
+ content:
+ $ref: '#/components/schemas/InterleavedContent'
+ metadata:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ required:
+ - content
+ - metadata
+ type: object
+ type: array
+ scores:
+ items:
+ type: number
+ type: array
+ required:
+ - chunks
+ - scores
+ type: object
QueryCondition:
additionalProperties: false
properties:
@@ -1732,53 +1710,21 @@ components:
- gt
- lt
type: string
- QueryDocumentsRequest:
+ QueryContextRequest:
additionalProperties: false
properties:
- bank_id:
- type: string
- params:
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
- type: object
- query:
+ content:
$ref: '#/components/schemas/InterleavedContent'
- required:
- - bank_id
- - query
- type: object
- QueryDocumentsResponse:
- additionalProperties: false
- properties:
- chunks:
+ query_config:
+ $ref: '#/components/schemas/RAGQueryConfig'
+ vector_db_ids:
items:
- additionalProperties: false
- properties:
- content:
- $ref: '#/components/schemas/InterleavedContent'
- document_id:
- type: string
- token_count:
- type: integer
- required:
- - content
- - token_count
- - document_id
- type: object
- type: array
- scores:
- items:
- type: number
+ type: string
type: array
required:
- - chunks
- - scores
+ - content
+ - query_config
+ - vector_db_ids
type: object
QuerySpanTreeResponse:
additionalProperties: false
@@ -1810,6 +1756,62 @@ components:
required:
- data
type: object
+ RAGDocument:
+ additionalProperties: false
+ properties:
+ content:
+ oneOf:
+ - type: string
+ - $ref: '#/components/schemas/InterleavedContentItem'
+ - items:
+ $ref: '#/components/schemas/InterleavedContentItem'
+ type: array
+ - $ref: '#/components/schemas/URL'
+ document_id:
+ type: string
+ metadata:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ mime_type:
+ type: string
+ required:
+ - document_id
+ - content
+ - metadata
+ type: object
+ RAGQueryConfig:
+ additionalProperties: false
+ properties:
+ max_chunks:
+ default: 5
+ type: integer
+ max_tokens_in_context:
+ default: 4096
+ type: integer
+ query_generator_config:
+ $ref: '#/components/schemas/RAGQueryGeneratorConfig'
+ required:
+ - query_generator_config
+ - max_tokens_in_context
+ - max_chunks
+ type: object
+ RAGQueryGeneratorConfig:
+ oneOf:
+ - $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
+ - $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
+ RAGQueryResult:
+ additionalProperties: false
+ properties:
+ content:
+ $ref: '#/components/schemas/InterleavedContent'
+ type: object
RegexParserScoringFnParams:
additionalProperties: false
properties:
@@ -1888,25 +1890,6 @@ components:
- dataset_id
- scoring_functions
type: object
- RegisterMemoryBankRequest:
- additionalProperties: false
- properties:
- memory_bank_id:
- type: string
- params:
- oneOf:
- - $ref: '#/components/schemas/VectorMemoryBankParams'
- - $ref: '#/components/schemas/KeyValueMemoryBankParams'
- - $ref: '#/components/schemas/KeywordMemoryBankParams'
- - $ref: '#/components/schemas/GraphMemoryBankParams'
- provider_id:
- type: string
- provider_memory_bank_id:
- type: string
- required:
- - memory_bank_id
- - params
- type: object
RegisterModelRequest:
additionalProperties: false
properties:
@@ -1999,6 +1982,23 @@ components:
- toolgroup_id
- provider_id
type: object
+ RegisterVectorDbRequest:
+ additionalProperties: false
+ properties:
+ embedding_dimension:
+ type: integer
+ embedding_model:
+ type: string
+ provider_id:
+ type: string
+ provider_vector_db_id:
+ type: string
+ vector_db_id:
+ type: string
+ required:
+ - vector_db_id
+ - embedding_model
+ type: object
ResponseFormat:
oneOf:
- additionalProperties: false
@@ -2298,8 +2298,6 @@ components:
Session:
additionalProperties: false
properties:
- memory_bank:
- $ref: '#/components/schemas/MemoryBank'
session_id:
type: string
session_name:
@@ -3202,58 +3200,30 @@ components:
- role
- content
type: object
- VectorMemoryBank:
+ VectorDB:
additionalProperties: false
properties:
- chunk_size_in_tokens:
- type: integer
embedding_dimension:
- default: 384
type: integer
embedding_model:
type: string
identifier:
type: string
- memory_bank_type:
- const: vector
- default: vector
- type: string
- overlap_size_in_tokens:
- type: integer
provider_id:
type: string
provider_resource_id:
type: string
type:
- const: memory_bank
- default: memory_bank
+ const: vector_db
+ default: vector_db
type: string
required:
- identifier
- provider_resource_id
- provider_id
- type
- - memory_bank_type
- embedding_model
- - chunk_size_in_tokens
- type: object
- VectorMemoryBankParams:
- additionalProperties: false
- properties:
- chunk_size_in_tokens:
- type: integer
- embedding_model:
- type: string
- memory_bank_type:
- const: vector
- default: vector
- type: string
- overlap_size_in_tokens:
- type: integer
- required:
- - memory_bank_type
- - embedding_model
- - chunk_size_in_tokens
+ - embedding_dimension
type: object
VersionInfo:
additionalProperties: false
@@ -4272,186 +4242,6 @@ paths:
description: OK
tags:
- Inspect
- /v1/memory-banks:
- get:
- parameters:
- - description: JSON-encoded provider data which will be made available to the
- adapter servicing the API
- in: header
- name: X-LlamaStack-Provider-Data
- required: false
- schema:
- type: string
- - description: Version of the client making the request. This is used to ensure
- that the client and server are compatible.
- in: header
- name: X-LlamaStack-Client-Version
- required: false
- schema:
- type: string
- responses:
- '200':
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/ListMemoryBanksResponse'
- description: OK
- tags:
- - MemoryBanks
- post:
- parameters:
- - description: JSON-encoded provider data which will be made available to the
- adapter servicing the API
- in: header
- name: X-LlamaStack-Provider-Data
- required: false
- schema:
- type: string
- - description: Version of the client making the request. This is used to ensure
- that the client and server are compatible.
- in: header
- name: X-LlamaStack-Client-Version
- required: false
- schema:
- type: string
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/RegisterMemoryBankRequest'
- required: true
- responses:
- '200':
- content:
- application/json:
- schema:
- oneOf:
- - $ref: '#/components/schemas/VectorMemoryBank'
- - $ref: '#/components/schemas/KeyValueMemoryBank'
- - $ref: '#/components/schemas/KeywordMemoryBank'
- - $ref: '#/components/schemas/GraphMemoryBank'
- description: ''
- tags:
- - MemoryBanks
- /v1/memory-banks/{memory_bank_id}:
- delete:
- parameters:
- - in: path
- name: memory_bank_id
- required: true
- schema:
- type: string
- - description: JSON-encoded provider data which will be made available to the
- adapter servicing the API
- in: header
- name: X-LlamaStack-Provider-Data
- required: false
- schema:
- type: string
- - description: Version of the client making the request. This is used to ensure
- that the client and server are compatible.
- in: header
- name: X-LlamaStack-Client-Version
- required: false
- schema:
- type: string
- responses:
- '200':
- description: OK
- tags:
- - MemoryBanks
- get:
- parameters:
- - in: path
- name: memory_bank_id
- required: true
- schema:
- type: string
- - description: JSON-encoded provider data which will be made available to the
- adapter servicing the API
- in: header
- name: X-LlamaStack-Provider-Data
- required: false
- schema:
- type: string
- - description: Version of the client making the request. This is used to ensure
- that the client and server are compatible.
- in: header
- name: X-LlamaStack-Client-Version
- required: false
- schema:
- type: string
- responses:
- '200':
- content:
- application/json:
- schema:
- oneOf:
- - $ref: '#/components/schemas/MemoryBank'
- - type: 'null'
- description: OK
- tags:
- - MemoryBanks
- /v1/memory/insert:
- post:
- parameters:
- - description: JSON-encoded provider data which will be made available to the
- adapter servicing the API
- in: header
- name: X-LlamaStack-Provider-Data
- required: false
- schema:
- type: string
- - description: Version of the client making the request. This is used to ensure
- that the client and server are compatible.
- in: header
- name: X-LlamaStack-Client-Version
- required: false
- schema:
- type: string
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/InsertDocumentsRequest'
- required: true
- responses:
- '200':
- description: OK
- tags:
- - Memory
- /v1/memory/query:
- post:
- parameters:
- - description: JSON-encoded provider data which will be made available to the
- adapter servicing the API
- in: header
- name: X-LlamaStack-Provider-Data
- required: false
- schema:
- type: string
- - description: Version of the client making the request. This is used to ensure
- that the client and server are compatible.
- in: header
- name: X-LlamaStack-Client-Version
- required: false
- schema:
- type: string
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/QueryDocumentsRequest'
- required: true
- responses:
- '200':
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/QueryDocumentsResponse'
- description: OK
- tags:
- - Memory
/v1/models:
get:
parameters:
@@ -5386,6 +5176,68 @@ paths:
description: OK
tags:
- ToolRuntime
+ /v1/tool-runtime/rag-tool/insert-documents:
+ post:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-Provider-Data
+ required: false
+ schema:
+ type: string
+ - description: Version of the client making the request. This is used to ensure
+ that the client and server are compatible.
+ in: header
+ name: X-LlamaStack-Client-Version
+ required: false
+ schema:
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/InsertDocumentsRequest'
+ required: true
+ responses:
+ '200':
+ description: OK
+ summary: Index documents so they can be used by the RAG system
+ tags:
+ - ToolRuntime
+ /v1/tool-runtime/rag-tool/query-context:
+ post:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-Provider-Data
+ required: false
+ schema:
+ type: string
+ - description: Version of the client making the request. This is used to ensure
+ that the client and server are compatible.
+ in: header
+ name: X-LlamaStack-Client-Version
+ required: false
+ schema:
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/QueryContextRequest'
+ required: true
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/RAGQueryResult'
+ description: OK
+ summary: Query the RAG system for context; typically invoked by the agent
+ tags:
+ - ToolRuntime
/v1/toolgroups:
get:
parameters:
@@ -5562,6 +5414,182 @@ paths:
description: OK
tags:
- ToolGroups
+ /v1/vector-dbs:
+ get:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-Provider-Data
+ required: false
+ schema:
+ type: string
+ - description: Version of the client making the request. This is used to ensure
+ that the client and server are compatible.
+ in: header
+ name: X-LlamaStack-Client-Version
+ required: false
+ schema:
+ type: string
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/ListVectorDBsResponse'
+ description: OK
+ tags:
+ - VectorDBs
+ post:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-Provider-Data
+ required: false
+ schema:
+ type: string
+ - description: Version of the client making the request. This is used to ensure
+ that the client and server are compatible.
+ in: header
+ name: X-LlamaStack-Client-Version
+ required: false
+ schema:
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/RegisterVectorDbRequest'
+ required: true
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/VectorDB'
+ description: OK
+ tags:
+ - VectorDBs
+ /v1/vector-dbs/{vector_db_id}:
+ delete:
+ parameters:
+ - in: path
+ name: vector_db_id
+ required: true
+ schema:
+ type: string
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-Provider-Data
+ required: false
+ schema:
+ type: string
+ - description: Version of the client making the request. This is used to ensure
+ that the client and server are compatible.
+ in: header
+ name: X-LlamaStack-Client-Version
+ required: false
+ schema:
+ type: string
+ responses:
+ '200':
+ description: OK
+ tags:
+ - VectorDBs
+ get:
+ parameters:
+ - in: path
+ name: vector_db_id
+ required: true
+ schema:
+ type: string
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-Provider-Data
+ required: false
+ schema:
+ type: string
+ - description: Version of the client making the request. This is used to ensure
+ that the client and server are compatible.
+ in: header
+ name: X-LlamaStack-Client-Version
+ required: false
+ schema:
+ type: string
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ oneOf:
+ - $ref: '#/components/schemas/VectorDB'
+ - type: 'null'
+ description: OK
+ tags:
+ - VectorDBs
+ /v1/vector-io/insert:
+ post:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-Provider-Data
+ required: false
+ schema:
+ type: string
+ - description: Version of the client making the request. This is used to ensure
+ that the client and server are compatible.
+ in: header
+ name: X-LlamaStack-Client-Version
+ required: false
+ schema:
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/InsertChunksRequest'
+ required: true
+ responses:
+ '200':
+ description: OK
+ tags:
+ - VectorIO
+ /v1/vector-io/query:
+ post:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-Provider-Data
+ required: false
+ schema:
+ type: string
+ - description: Version of the client making the request. This is used to ensure
+ that the client and server are compatible.
+ in: header
+ name: X-LlamaStack-Client-Version
+ required: false
+ schema:
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/QueryChunksRequest'
+ required: true
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/QueryChunksResponse'
+ description: OK
+ tags:
+ - VectorIO
/v1/version:
get:
parameters:
@@ -5748,6 +5776,9 @@ tags:
name: DatasetFormat
- name: DatasetIO
- name: Datasets
+- description:
+ name: DefaultRAGQueryGeneratorConfig
- description:
name: EfficiencyConfig
@@ -5767,12 +5798,6 @@ tags:
- description:
name: EvaluateRowsRequest
-- description:
- name: GraphMemoryBank
-- description:
- name: GraphMemoryBankParams
- description:
name: GreedySamplingStrategy
@@ -5786,6 +5811,9 @@ tags:
- name: Inference
- description:
name: InferenceStep
+- description:
+ name: InsertChunksRequest
- description:
name: InsertDocumentsRequest
@@ -5805,30 +5833,18 @@ tags:
name: JobStatus
- description:
name: JsonType
-- description:
- name: KeyValueMemoryBank
-- description:
- name: KeyValueMemoryBankParams
-- description:
- name: KeywordMemoryBank
-- description:
- name: KeywordMemoryBankParams
- description:
name: LLMAsJudgeScoringFnParams
+- description:
+ name: LLMRAGQueryGeneratorConfig
- description:
name: ListDatasetsResponse
- description:
name: ListEvalTasksResponse
-- description:
- name: ListMemoryBanksResponse
- description:
name: ListModelsResponse
@@ -5853,6 +5869,9 @@ tags:
- description:
name: ListToolsResponse
+- description:
+ name: ListVectorDBsResponse
- description:
name: LogEventRequest
@@ -5861,13 +5880,6 @@ tags:
- description:
name: LoraFinetuningConfig
-- name: Memory
-- description:
- name: MemoryBank
-- description:
- name: MemoryBankDocument
-- name: MemoryBanks
- description:
name: MemoryRetrievalStep
@@ -5920,17 +5932,20 @@ tags:
- description:
name: QATFinetuningConfig
+- description:
+ name: QueryChunksRequest
+- description:
+ name: QueryChunksResponse
- description:
name: QueryCondition
- description:
name: QueryConditionOp
-- description:
- name: QueryDocumentsRequest
-- description:
- name: QueryDocumentsResponse
+ name: QueryContextRequest
- description:
name: QuerySpanTreeResponse
@@ -5940,6 +5955,15 @@ tags:
- description:
name: QueryTracesResponse
+- description:
+ name: RAGDocument
+- description:
+ name: RAGQueryConfig
+- description:
+ name: RAGQueryGeneratorConfig
+- description:
+ name: RAGQueryResult
- description:
name: RegexParserScoringFnParams
@@ -5949,9 +5973,6 @@ tags:
- description:
name: RegisterEvalTaskRequest
-- description:
- name: RegisterMemoryBankRequest
- description:
name: RegisterModelRequest
@@ -5964,6 +5985,9 @@ tags:
- description:
name: RegisterToolGroupRequest
+- description:
+ name: RegisterVectorDbRequest
- description:
name: ResponseFormat
- description:
@@ -6128,12 +6152,10 @@ tags:
name: UnstructuredLogEvent
- description:
name: UserMessage
-- description:
- name: VectorMemoryBank
-- description:
- name: VectorMemoryBankParams
+- description:
+ name: VectorDB
+- name: VectorDBs
+- name: VectorIO
- description:
name: VersionInfo
- description:
@@ -6149,8 +6171,6 @@ x-tagGroups:
- EvalTasks
- Inference
- Inspect
- - Memory
- - MemoryBanks
- Models
- PostTraining (Coming Soon)
- Safety
@@ -6161,6 +6181,8 @@ x-tagGroups:
- Telemetry
- ToolGroups
- ToolRuntime
+ - VectorDBs
+ - VectorIO
- name: Types
tags:
- AgentCandidate
@@ -6210,19 +6232,19 @@ x-tagGroups:
- DataConfig
- Dataset
- DatasetFormat
+ - DefaultRAGQueryGeneratorConfig
- EfficiencyConfig
- EmbeddingsRequest
- EmbeddingsResponse
- EvalTask
- EvaluateResponse
- EvaluateRowsRequest
- - GraphMemoryBank
- - GraphMemoryBankParams
- GreedySamplingStrategy
- HealthInfo
- ImageContentItem
- ImageDelta
- InferenceStep
+ - InsertChunksRequest
- InsertDocumentsRequest
- InterleavedContent
- InterleavedContentItem
@@ -6230,14 +6252,10 @@ x-tagGroups:
- Job
- JobStatus
- JsonType
- - KeyValueMemoryBank
- - KeyValueMemoryBankParams
- - KeywordMemoryBank
- - KeywordMemoryBankParams
- LLMAsJudgeScoringFnParams
+ - LLMRAGQueryGeneratorConfig
- ListDatasetsResponse
- ListEvalTasksResponse
- - ListMemoryBanksResponse
- ListModelsResponse
- ListPostTrainingJobsResponse
- ListProvidersResponse
@@ -6246,11 +6264,10 @@ x-tagGroups:
- ListShieldsResponse
- ListToolGroupsResponse
- ListToolsResponse
+ - ListVectorDBsResponse
- LogEventRequest
- LogSeverity
- LoraFinetuningConfig
- - MemoryBank
- - MemoryBankDocument
- MemoryRetrievalStep
- Message
- MetricEvent
@@ -6269,21 +6286,26 @@ x-tagGroups:
- PreferenceOptimizeRequest
- ProviderInfo
- QATFinetuningConfig
+ - QueryChunksRequest
+ - QueryChunksResponse
- QueryCondition
- QueryConditionOp
- - QueryDocumentsRequest
- - QueryDocumentsResponse
+ - QueryContextRequest
- QuerySpanTreeResponse
- QuerySpansResponse
- QueryTracesResponse
+ - RAGDocument
+ - RAGQueryConfig
+ - RAGQueryGeneratorConfig
+ - RAGQueryResult
- RegexParserScoringFnParams
- RegisterDatasetRequest
- RegisterEvalTaskRequest
- - RegisterMemoryBankRequest
- RegisterModelRequest
- RegisterScoringFunctionRequest
- RegisterShieldRequest
- RegisterToolGroupRequest
+ - RegisterVectorDbRequest
- ResponseFormat
- RouteInfo
- RunEvalRequest
@@ -6341,7 +6363,6 @@ x-tagGroups:
- UnionType
- UnstructuredLogEvent
- UserMessage
- - VectorMemoryBank
- - VectorMemoryBankParams
+ - VectorDB
- VersionInfo
- ViolationLevel
diff --git a/llama_stack/apis/tools/__init__.py b/llama_stack/apis/tools/__init__.py
index f747fcdc2..8cd798ebf 100644
--- a/llama_stack/apis/tools/__init__.py
+++ b/llama_stack/apis/tools/__init__.py
@@ -5,3 +5,4 @@
# the root directory of this source tree.
from .tools import * # noqa: F401 F403
+from .rag_tool import * # noqa: F401 F403
diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py
new file mode 100644
index 000000000..0247bb384
--- /dev/null
+++ b/llama_stack/apis/tools/rag_tool.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated, Protocol, runtime_checkable
+
+from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+
+
+@json_schema_type
+class RAGDocument(BaseModel):
+ document_id: str
+ content: InterleavedContent | URL
+ mime_type: str | None = None
+ metadata: Dict[str, Any] = Field(default_factory=dict)
+
+
+@json_schema_type
+class RAGQueryResult(BaseModel):
+ content: Optional[InterleavedContent] = None
+
+
+@json_schema_type
+class RAGQueryGenerator(Enum):
+ default = "default"
+ llm = "llm"
+ custom = "custom"
+
+
+@json_schema_type
+class DefaultRAGQueryGeneratorConfig(BaseModel):
+ type: Literal["default"] = "default"
+ separator: str = " "
+
+
+@json_schema_type
+class LLMRAGQueryGeneratorConfig(BaseModel):
+ type: Literal["llm"] = "llm"
+ model: str
+ template: str
+
+
+RAGQueryGeneratorConfig = register_schema(
+ Annotated[
+ Union[
+ DefaultRAGQueryGeneratorConfig,
+ LLMRAGQueryGeneratorConfig,
+ ],
+ Field(discriminator="type"),
+ ],
+ name="RAGQueryGeneratorConfig",
+)
+
+
+@json_schema_type
+class RAGQueryConfig(BaseModel):
+    # This config defines how a retrieval query is generated from the
+    # conversation messages before querying the vector databases.
+ query_generator_config: RAGQueryGeneratorConfig = Field(
+ default=DefaultRAGQueryGeneratorConfig()
+ )
+ max_tokens_in_context: int = 4096
+ max_chunks: int = 5
+
+
+@runtime_checkable
+@trace_protocol
+class RAGToolRuntime(Protocol):
+ @webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST")
+ async def insert_documents(
+ self,
+ documents: List[RAGDocument],
+ vector_db_id: str,
+ chunk_size_in_tokens: int = 512,
+ ) -> None:
+ """Index documents so they can be used by the RAG system"""
+ ...
+
+ @webmethod(route="/tool-runtime/rag-tool/query-context", method="POST")
+ async def query_context(
+ self,
+ content: InterleavedContent,
+ query_config: RAGQueryConfig,
+ vector_db_ids: List[str],
+ ) -> RAGQueryResult:
+ """Query the RAG system for context; typically invoked by the agent"""
+ ...
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index fb990cc41..1af019bd4 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -15,6 +15,8 @@ from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from .rag_tool import RAGToolRuntime
+
@json_schema_type
class ToolParameter(BaseModel):
@@ -130,11 +132,17 @@ class ToolGroups(Protocol):
...
+class SpecialToolGroup(Enum):
+ rag_tool = "rag_tool"
+
+
@runtime_checkable
@trace_protocol
class ToolRuntime(Protocol):
tool_store: ToolStore
+ rag_tool: RAGToolRuntime
+
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
@@ -143,7 +151,7 @@ class ToolRuntime(Protocol):
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(
- self, tool_name: str, args: Dict[str, Any]
+ self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
"""Run a tool with the given arguments"""
...
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index bd5a9ae98..dd6d4be6f 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -333,6 +333,8 @@ async def instantiate_provider(
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
+    # TODO: also check protocol compliance for special tool groups. The impl
+    # belongs to Api.tool_runtime; the name is the special tool group and the protocol is that group's sub-protocol.
check_protocol_compliance(impl, protocols[provider_spec.api])
if (
not isinstance(provider_spec, AutoRoutedProviderSpec)
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 979c68b72..3ae9833dc 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -36,7 +36,14 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import ToolDef, ToolRuntime
+from llama_stack.apis.tools import (
+ RAGDocument,
+ RAGQueryConfig,
+ RAGQueryResult,
+ RAGToolRuntime,
+ ToolDef,
+ ToolRuntime,
+)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import RoutingTable
@@ -400,22 +407,55 @@ class EvalRouter(Eval):
class ToolRuntimeRouter(ToolRuntime):
+ class RagToolImpl(RAGToolRuntime):
+ def __init__(
+ self,
+ routing_table: RoutingTable,
+ ) -> None:
+ self.routing_table = routing_table
+
+ async def query_context(
+ self,
+ content: InterleavedContent,
+ query_config: RAGQueryConfig,
+ vector_db_ids: List[str],
+ ) -> RAGQueryResult:
+ return await self.routing_table.get_provider_impl(
+ "rag_tool.query_context"
+ ).query_context(content, query_config, vector_db_ids)
+
+ async def insert_documents(
+ self,
+ documents: List[RAGDocument],
+ vector_db_id: str,
+ chunk_size_in_tokens: int = 512,
+ ) -> None:
+ return await self.routing_table.get_provider_impl(
+ "rag_tool.insert_documents"
+ ).insert_documents(documents, vector_db_id, chunk_size_in_tokens)
+
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
+        # HACK ALERT: these dotted attribute names must stay in sync with the routes built in "get_all_api_endpoints()"
+        # TODO: verify that "rag_tool" vs. "builtin::memory" naming is used consistently everywhere
+ self.rag_tool = self.RagToolImpl(routing_table)
+ setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
+ setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
+
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
- async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
+ async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
- args=args,
+ kwargs=kwargs,
)
async def list_runtime_tools(
diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py
index af429e020..180479e40 100644
--- a/llama_stack/distribution/server/endpoints.py
+++ b/llama_stack/distribution/server/endpoints.py
@@ -9,6 +9,8 @@ from typing import Dict, List
from pydantic import BaseModel
+from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
+
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map
@@ -22,21 +24,39 @@ class ApiEndpoint(BaseModel):
name: str
+def toolgroup_protocol_map():
+ return {
+ SpecialToolGroup.rag_tool: RAGToolRuntime,
+ }
+
+
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = api_protocol_map()
+ toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
+        # HACK ALERT: splice the RAG tool sub-protocol's endpoints into the tool_runtime API surface
+ if api == Api.tool_runtime:
+ for tool_group in SpecialToolGroup:
+ sub_protocol = toolgroup_protocols[tool_group]
+ sub_protocol_methods = inspect.getmembers(
+ sub_protocol, predicate=inspect.isfunction
+ )
+ for name, method in sub_protocol_methods:
+ if not hasattr(method, "__webmethod__"):
+ continue
+ protocol_methods.append((f"{tool_group.value}.{name}", method))
+
for name, method in protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
webmethod = method.__webmethod__
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
-
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":
diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index 180ec0ecc..f0c34dba4 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -29,7 +29,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
-from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import StackRunConfig
@@ -62,6 +62,7 @@ class LlamaStack(
Inspect,
ToolGroups,
ToolRuntime,
+ RAGToolRuntime,
):
pass
diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py
index 010d137ec..5c0b8b5db 100644
--- a/llama_stack/distribution/store/registry.py
+++ b/llama_stack/distribution/store/registry.py
@@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v5"
+KEY_VERSION = "v6"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py
index 50f61fb42..de34b8d2c 100644
--- a/llama_stack/providers/inline/agents/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py
@@ -19,9 +19,8 @@ async def get_provider_impl(
impl = MetaReferenceAgentsImpl(
config,
deps[Api.inference],
- deps[Api.memory],
+ deps[Api.vector_io],
deps[Api.safety],
- deps[Api.memory_banks],
deps[Api.tool_runtime],
deps[Api.tool_groups],
)
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 2ebc7ded1..5b5175cee 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -59,13 +59,18 @@ from llama_stack.apis.inference import (
ToolResponseMessage,
UserMessage,
)
-from llama_stack.apis.memory import Memory, MemoryBankDocument
-from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.tools import (
+ DefaultRAGQueryGeneratorConfig,
+ RAGDocument,
+ RAGQueryConfig,
+ ToolGroups,
+ ToolRuntime,
+)
+from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
-
from .persistence import AgentPersistence
from .safety import SafetyException, ShieldRunnerMixin
@@ -79,7 +84,7 @@ def make_random_string(length: int = 8):
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
-MEMORY_QUERY_TOOL = "query_memory"
+MEMORY_QUERY_TOOL = "rag_tool.query_context"
WEB_SEARCH_TOOL = "web_search"
MEMORY_GROUP = "builtin::memory"
@@ -91,20 +96,18 @@ class ChatAgent(ShieldRunnerMixin):
agent_config: AgentConfig,
tempdir: str,
inference_api: Inference,
- memory_api: Memory,
- memory_banks_api: MemoryBanks,
safety_api: Safety,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
+ vector_io_api: VectorIO,
persistence_store: KVStore,
):
self.agent_id = agent_id
self.agent_config = agent_config
self.tempdir = tempdir
self.inference_api = inference_api
- self.memory_api = memory_api
- self.memory_banks_api = memory_banks_api
self.safety_api = safety_api
+ self.vector_io_api = vector_io_api
self.storage = AgentPersistence(agent_id, persistence_store)
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
@@ -370,24 +373,30 @@ class ChatAgent(ShieldRunnerMixin):
documents: Optional[List[Document]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
+ # TODO: simplify all of this code, it can be simpler
toolgroup_args = {}
+ toolgroups = set()
for toolgroup in self.agent_config.toolgroups:
if isinstance(toolgroup, AgentToolGroupWithArgs):
+ toolgroups.add(toolgroup.name)
toolgroup_args[toolgroup.name] = toolgroup.args
+ else:
+ toolgroups.add(toolgroup)
if toolgroups_for_turn:
for toolgroup in toolgroups_for_turn:
if isinstance(toolgroup, AgentToolGroupWithArgs):
+ toolgroups.add(toolgroup.name)
toolgroup_args[toolgroup.name] = toolgroup.args
+ else:
+ toolgroups.add(toolgroup)
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(
session_id, documents, input_messages, tool_defs
)
- if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
- memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
- if memory_tool_group is None:
- raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
+
+ if MEMORY_GROUP in toolgroups and len(input_messages) > 0:
with tracing.span(MEMORY_QUERY_TOOL) as span:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@@ -398,17 +407,15 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
- query_args = {
- "messages": [msg.content for msg in input_messages],
- **toolgroup_args.get(memory_tool_group, {}),
- }
+ args = toolgroup_args.get(MEMORY_GROUP, {})
+ vector_db_ids = args.get("vector_db_ids", [])
session_info = await self.storage.get_session_info(session_id)
+
# if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id:
- if "memory_bank_ids" not in query_args:
- query_args["memory_bank_ids"] = []
- query_args["memory_bank_ids"].append(session_info.memory_bank_id)
+ vector_db_ids.append(session_info.memory_bank_id)
+
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
@@ -425,10 +432,18 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
- result = await self.tool_runtime_api.invoke_tool(
- tool_name=MEMORY_QUERY_TOOL,
- args=query_args,
+ result = await self.tool_runtime_api.rag_tool.query_context(
+ content=concat_interleaved_content(
+ [msg.content for msg in input_messages]
+ ),
+ query_config=RAGQueryConfig(
+ query_generator_config=DefaultRAGQueryGeneratorConfig(),
+ max_tokens_in_context=4096,
+ max_chunks=5,
+ ),
+ vector_db_ids=vector_db_ids,
)
+ retrieved_context = result.content
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@@ -449,7 +464,7 @@ class ChatAgent(ShieldRunnerMixin):
ToolResponse(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
- content=result.content,
+ content=retrieved_context or [],
)
],
),
@@ -459,13 +474,11 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
- span.set_attribute("output", result.content)
- span.set_attribute("error_code", result.error_code)
- span.set_attribute("error_message", result.error_message)
+ span.set_attribute("output", retrieved_context)
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
- if result.error_code == 0:
+ if retrieved_context:
last_message = input_messages[-1]
- last_message.context = result.content
+ last_message.context = retrieved_context
output_attachments = []
@@ -842,12 +855,13 @@ class ChatAgent(ShieldRunnerMixin):
if session_info.memory_bank_id is None:
bank_id = f"memory_bank_{session_id}"
- await self.memory_banks_api.register_memory_bank(
- memory_bank_id=bank_id,
- params=VectorMemoryBankParams(
- embedding_model="all-MiniLM-L6-v2",
- chunk_size_in_tokens=512,
- ),
+
+ # TODO: the semantic for registration is definitely not "creation"
+ # so we need to fix it if we expect the agent to create a new vector db
+ # for each session
+ await self.vector_io_api.register_vector_db(
+ vector_db_id=bank_id,
+ embedding_model="all-MiniLM-L6-v2",
)
await self.storage.add_memory_bank_to_session(session_id, bank_id)
else:
@@ -858,9 +872,9 @@ class ChatAgent(ShieldRunnerMixin):
async def add_to_session_memory_bank(
self, session_id: str, data: List[Document]
) -> None:
- bank_id = await self._ensure_memory_bank(session_id)
+ vector_db_id = await self._ensure_memory_bank(session_id)
documents = [
- MemoryBankDocument(
+ RAGDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
@@ -868,9 +882,10 @@ class ChatAgent(ShieldRunnerMixin):
)
for a in data
]
- await self.memory_api.insert_documents(
- bank_id=bank_id,
+ await self.tool_runtime_api.rag_tool.insert_documents(
documents=documents,
+ vector_db_id=vector_db_id,
+ chunk_size_in_tokens=512,
)
@@ -955,7 +970,7 @@ async def execute_tool_call_maybe(
result = await tool_runtime_api.invoke_tool(
tool_name=name,
- args=dict(
+ kwargs=dict(
session_id=session_id,
**tool_call_args,
),
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index d22ef82ab..b1844f4d0 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -26,10 +26,9 @@ from llama_stack.apis.agents import (
Turn,
)
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from .agent_instance import ChatAgent
@@ -44,17 +43,15 @@ class MetaReferenceAgentsImpl(Agents):
self,
config: MetaReferenceAgentsImplConfig,
inference_api: Inference,
- memory_api: Memory,
+ vector_io_api: VectorIO,
safety_api: Safety,
- memory_banks_api: MemoryBanks,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
):
self.config = config
self.inference_api = inference_api
- self.memory_api = memory_api
+ self.vector_io_api = vector_io_api
self.safety_api = safety_api
- self.memory_banks_api = memory_banks_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
@@ -114,8 +111,7 @@ class MetaReferenceAgentsImpl(Agents):
tempdir=self.tempdir,
inference_api=self.inference_api,
safety_api=self.safety_api,
- memory_api=self.memory_api,
- memory_banks_api=self.memory_banks_api,
+ vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(
diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
index 361c91a92..04434768d 100644
--- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
+++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
@@ -60,9 +60,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
]
async def invoke_tool(
- self, tool_name: str, args: Dict[str, Any]
+ self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
- script = args["code"]
+ script = kwargs["code"]
req = CodeExecutionRequest(scripts=[script])
res = self.code_executor.execute(req)
pieces = [res["process_status"]]
diff --git a/llama_stack/providers/inline/tool_runtime/memory/__init__.py b/llama_stack/providers/inline/tool_runtime/memory/__init__.py
index 928afa484..42a0a6b01 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/__init__.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/__init__.py
@@ -13,8 +13,6 @@ from .memory import MemoryToolRuntimeImpl
async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
- impl = MemoryToolRuntimeImpl(
- config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
- )
+ impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
await impl.initialize()
return impl
diff --git a/llama_stack/providers/inline/tool_runtime/memory/config.py b/llama_stack/providers/inline/tool_runtime/memory/config.py
index 6ff242c6b..4a20c986c 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/config.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/config.py
@@ -4,87 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from enum import Enum
-from typing import Annotated, List, Literal, Union
-
-from pydantic import BaseModel, Field
-
-
-class _MemoryBankConfigCommon(BaseModel):
- bank_id: str
-
-
-class VectorMemoryBankConfig(_MemoryBankConfigCommon):
- type: Literal["vector"] = "vector"
-
-
-class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
- type: Literal["keyvalue"] = "keyvalue"
- keys: List[str] # what keys to focus on
-
-
-class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
- type: Literal["keyword"] = "keyword"
-
-
-class GraphMemoryBankConfig(_MemoryBankConfigCommon):
- type: Literal["graph"] = "graph"
- entities: List[str] # what entities to focus on
-
-
-MemoryBankConfig = Annotated[
- Union[
- VectorMemoryBankConfig,
- KeyValueMemoryBankConfig,
- KeywordMemoryBankConfig,
- GraphMemoryBankConfig,
- ],
- Field(discriminator="type"),
-]
-
-
-class MemoryQueryGenerator(Enum):
- default = "default"
- llm = "llm"
- custom = "custom"
-
-
-class DefaultMemoryQueryGeneratorConfig(BaseModel):
- type: Literal[MemoryQueryGenerator.default.value] = (
- MemoryQueryGenerator.default.value
- )
- sep: str = " "
-
-
-class LLMMemoryQueryGeneratorConfig(BaseModel):
- type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
- model: str
- template: str
-
-
-class CustomMemoryQueryGeneratorConfig(BaseModel):
- type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
-
-
-MemoryQueryGeneratorConfig = Annotated[
- Union[
- DefaultMemoryQueryGeneratorConfig,
- LLMMemoryQueryGeneratorConfig,
- CustomMemoryQueryGeneratorConfig,
- ],
- Field(discriminator="type"),
-]
-
-
-class MemoryToolConfig(BaseModel):
- memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+from pydantic import BaseModel
class MemoryToolRuntimeConfig(BaseModel):
- # This config defines how a query is generated using the messages
- # for memory bank retrieval.
- query_generator_config: MemoryQueryGeneratorConfig = Field(
- default=DefaultMemoryQueryGeneratorConfig()
- )
- max_tokens_in_context: int = 4096
- max_chunks: int = 5
+ pass
diff --git a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py
index 803981f07..e77ec76af 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py
@@ -5,68 +5,64 @@
# the root directory of this source tree.
-from typing import List
-
from jinja2 import Template
-from pydantic import BaseModel
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
+
+from llama_stack.apis.tools.rag_tool import (
+ DefaultRAGQueryGeneratorConfig,
+ LLMRAGQueryGeneratorConfig,
+ RAGQueryGenerator,
+ RAGQueryGeneratorConfig,
+)
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
-from .config import (
- DefaultMemoryQueryGeneratorConfig,
- LLMMemoryQueryGeneratorConfig,
- MemoryQueryGenerator,
- MemoryQueryGeneratorConfig,
-)
-
async def generate_rag_query(
- config: MemoryQueryGeneratorConfig,
- messages: List[InterleavedContent],
+ config: RAGQueryGeneratorConfig,
+ content: InterleavedContent,
**kwargs,
):
"""
Generates a query that will be used for
retrieving relevant information from the memory bank.
"""
- if config.type == MemoryQueryGenerator.default.value:
- query = await default_rag_query_generator(config, messages, **kwargs)
- elif config.type == MemoryQueryGenerator.llm.value:
- query = await llm_rag_query_generator(config, messages, **kwargs)
+ if config.type == RAGQueryGenerator.default.value:
+ query = await default_rag_query_generator(config, content, **kwargs)
+ elif config.type == RAGQueryGenerator.llm.value:
+ query = await llm_rag_query_generator(config, content, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
return query
async def default_rag_query_generator(
- config: DefaultMemoryQueryGeneratorConfig,
- messages: List[InterleavedContent],
+ config: DefaultRAGQueryGeneratorConfig,
+ content: InterleavedContent,
**kwargs,
):
- return config.sep.join(interleaved_content_as_str(m) for m in messages)
+ return interleaved_content_as_str(content, sep=config.separator)
async def llm_rag_query_generator(
- config: LLMMemoryQueryGeneratorConfig,
- messages: List[InterleavedContent],
+ config: LLMRAGQueryGeneratorConfig,
+ content: InterleavedContent,
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
- m_dict = {
- "messages": [
- message.model_dump() if isinstance(message, BaseModel) else message
- for message in messages
- ]
- }
+ messages = []
+ if isinstance(content, list):
+ messages = [interleaved_content_as_str(m) for m in content]
+ else:
+ messages = [interleaved_content_as_str(content)]
template = Template(config.template)
- content = template.render(m_dict)
+ content = template.render({"messages": messages})
model = config.model
message = UserMessage(content=content)
diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py
index fe6325abb..d3f8b07dc 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py
@@ -10,20 +10,29 @@ import secrets
import string
from typing import Any, Dict, List, Optional
-from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.memory import Memory, QueryDocumentsResponse
-from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.common.content_types import (
+ InterleavedContent,
+ TextContentItem,
+ URL,
+)
+from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
+ RAGDocument,
+ RAGQueryConfig,
+ RAGQueryResult,
+ RAGToolRuntime,
ToolDef,
ToolInvocationResult,
- ToolParameter,
ToolRuntime,
)
+from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
+from llama_stack.providers.utils.memory.vector_store import (
+ content_from_doc,
+ make_overlapped_chunks,
+)
-from .config import MemoryToolConfig, MemoryToolRuntimeConfig
+from .config import MemoryToolRuntimeConfig
from .context_retriever import generate_rag_query
log = logging.getLogger(__name__)
@@ -35,65 +44,79 @@ def make_random_string(length: int = 8):
)
-class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
+class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__(
self,
config: MemoryToolRuntimeConfig,
- memory_api: Memory,
- memory_banks_api: MemoryBanks,
+ vector_io_api: VectorIO,
inference_api: Inference,
):
self.config = config
- self.memory_api = memory_api
- self.memory_banks_api = memory_banks_api
+ self.vector_io_api = vector_io_api
self.inference_api = inference_api
async def initialize(self):
pass
- async def list_runtime_tools(
- self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
- ) -> List[ToolDef]:
- return [
- ToolDef(
- name="query_memory",
- description="Retrieve context from memory",
- parameters=[
- ToolParameter(
- name="messages",
- description="The input messages to search for",
- parameter_type="array",
- ),
- ],
- )
- ]
+ async def shutdown(self):
+ pass
+
+ async def insert_documents(
+ self,
+ documents: List[RAGDocument],
+ vector_db_id: str,
+ chunk_size_in_tokens: int = 512,
+ ) -> None:
+ chunks = []
+ for doc in documents:
+ content = await content_from_doc(doc)
+ chunks.extend(
+ make_overlapped_chunks(
+ doc.document_id,
+ content,
+ chunk_size_in_tokens,
+ chunk_size_in_tokens // 4,
+ )
+ )
+
+ if not chunks:
+ return
+
+ await self.vector_io_api.insert_chunks(
+ chunks=chunks,
+ vector_db_id=vector_db_id,
+ )
+
+ async def query_context(
+ self,
+ content: InterleavedContent,
+ query_config: RAGQueryConfig,
+ vector_db_ids: List[str],
+ ) -> RAGQueryResult:
+ if not vector_db_ids:
+ return RAGQueryResult(content=None)
- async def _retrieve_context(
- self, input_messages: List[InterleavedContent], bank_ids: List[str]
- ) -> Optional[List[InterleavedContent]]:
- if not bank_ids:
- return None
query = await generate_rag_query(
- self.config.query_generator_config,
- input_messages,
+ query_config.query_generator_config,
+ content,
inference_api=self.inference_api,
)
tasks = [
- self.memory_api.query_documents(
- bank_id=bank_id,
+ self.vector_io_api.query_chunks(
+ vector_db_id=vector_db_id,
query=query,
params={
- "max_chunks": self.config.max_chunks,
+ "max_chunks": query_config.max_chunks,
},
)
- for bank_id in bank_ids
+ for vector_db_id in vector_db_ids
]
- results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
+ results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
if not chunks:
- return None
+ return RAGQueryResult(content=None)
# sort by score
chunks, scores = zip(
@@ -102,45 +125,52 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
tokens = 0
picked = []
- for c in chunks[: self.config.max_chunks]:
- tokens += c.token_count
- if tokens > self.config.max_tokens_in_context:
+ for c in chunks[: query_config.max_chunks]:
+ metadata = c.metadata
+ tokens += metadata["token_count"]
+ if tokens > query_config.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
- picked.append(f"id:{c.document_id}; content:{c.content}")
+ picked.append(
+ TextContentItem(
+ text=f"id:{metadata['document_id']}; content:{c.content}",
+ )
+ )
+ return RAGQueryResult(
+ content=[
+ TextContentItem(
+ text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
+ ),
+ *picked,
+ TextContentItem(
+ text="\n=== END-RETRIEVED-CONTEXT ===\n",
+ ),
+ ],
+ )
+
+ async def list_runtime_tools(
+ self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+ ) -> List[ToolDef]:
+ # Parameters are not listed since these methods are not yet invoked automatically
+ # by the LLM. The method is only implemented so things like /tools can list without
+ # encountering fatals.
return [
- "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
- *picked,
- "\n=== END-RETRIEVED-CONTEXT ===\n",
+ ToolDef(
+ name="rag_tool.query_context",
+ description="Retrieve context from memory",
+ ),
+ ToolDef(
+ name="rag_tool.insert_documents",
+ description="Insert documents into memory",
+ ),
]
async def invoke_tool(
- self, tool_name: str, args: Dict[str, Any]
+ self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
- tool = await self.tool_store.get_tool(tool_name)
- tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
- final_args = tool_group.args or {}
- final_args.update(args)
- config = MemoryToolConfig()
- if tool.metadata and tool.metadata.get("config") is not None:
- config = MemoryToolConfig(**tool.metadata["config"])
- if "memory_bank_ids" in final_args:
- bank_ids = final_args["memory_bank_ids"]
- else:
- bank_ids = [
- bank_config.bank_id for bank_config in config.memory_bank_configs
- ]
- if "messages" not in final_args:
- raise ValueError("messages are required")
- context = await self._retrieve_context(
- final_args["messages"],
- bank_ids,
- )
- if context is None:
- context = []
- return ToolInvocationResult(
- content=concat_interleaved_content(context), error_code=0
+ raise RuntimeError(
+ "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
)
diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py
index b3ea68949..426fe22f2 100644
--- a/llama_stack/providers/registry/tool_runtime.py
+++ b/llama_stack/providers/registry/tool_runtime.py
@@ -23,7 +23,7 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=[],
module="llama_stack.providers.inline.tool_runtime.memory",
config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
- api_dependencies=[Api.vector_io, Api.vector_dbs, Api.inference],
+ api_dependencies=[Api.vector_io, Api.inference],
),
InlineProviderSpec(
api=Api.tool_runtime,
diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py
index 5114e06aa..677e29c12 100644
--- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py
+++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py
@@ -68,7 +68,7 @@ class BingSearchToolRuntimeImpl(
]
async def invoke_tool(
- self, tool_name: str, args: Dict[str, Any]
+ self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
headers = {
@@ -78,7 +78,7 @@ class BingSearchToolRuntimeImpl(
"count": self.config.top_k,
"textDecorations": True,
"textFormat": "HTML",
- "q": args["query"],
+ "q": kwargs["query"],
}
response = requests.get(
diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
index 016f746ea..1162cc900 100644
--- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
+++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
@@ -68,7 +68,7 @@ class BraveSearchToolRuntimeImpl(
]
async def invoke_tool(
- self, tool_name: str, args: Dict[str, Any]
+ self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
url = "https://api.search.brave.com/res/v1/web/search"
@@ -77,7 +77,7 @@ class BraveSearchToolRuntimeImpl(
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
- payload = {"q": args["query"]}
+ payload = {"q": kwargs["query"]}
response = requests.get(url=url, params=payload, headers=headers)
response.raise_for_status()
results = self._clean_brave_response(response.json())
diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
index a304167e9..e0caec1d0 100644
--- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
+++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
@@ -65,7 +65,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
return tools
async def invoke_tool(
- self, tool_name: str, args: Dict[str, Any]
+ self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
@@ -77,7 +77,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async with sse_client(endpoint) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
- result = await session.call_tool(tool.identifier, args)
+ result = await session.call_tool(tool.identifier, kwargs)
return ToolInvocationResult(
content="\n".join([result.model_dump_json() for result in result.content]),
diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py
index 82077193e..f5826c0ff 100644
--- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py
+++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py
@@ -67,12 +67,12 @@ class TavilySearchToolRuntimeImpl(
]
async def invoke_tool(
- self, tool_name: str, args: Dict[str, Any]
+ self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
response = requests.post(
"https://api.tavily.com/search",
- json={"api_key": api_key, "query": args["query"]},
+ json={"api_key": api_key, "query": kwargs["query"]},
)
return ToolInvocationResult(
diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
index 04ecfcc15..bf298c13e 100644
--- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
+++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
@@ -68,11 +68,11 @@ class WolframAlphaToolRuntimeImpl(
]
async def invoke_tool(
- self, tool_name: str, args: Dict[str, Any]
+ self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
params = {
- "input": args["query"],
+ "input": kwargs["query"],
"appid": api_key,
"format": "plaintext",
"output": "json",
diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py
index 4efdfe8b7..9c115e3a1 100644
--- a/llama_stack/providers/tests/agents/conftest.py
+++ b/llama_stack/providers/tests/agents/conftest.py
@@ -12,10 +12,10 @@ from ..conftest import (
get_test_config_for_api,
)
from ..inference.fixtures import INFERENCE_FIXTURES
-from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
+from ..vector_io.fixtures import VECTOR_IO_FIXTURES
from .fixtures import AGENTS_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
@@ -23,7 +23,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "meta_reference",
"safety": "llama_guard",
- "memory": "faiss",
+ "vector_io": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
@@ -34,7 +34,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "ollama",
"safety": "llama_guard",
- "memory": "faiss",
+ "vector_io": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
@@ -46,7 +46,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"inference": "together",
"safety": "llama_guard",
# make this work with Weaviate which is what the together distro supports
- "memory": "faiss",
+ "vector_io": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
@@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "fireworks",
"safety": "llama_guard",
- "memory": "faiss",
+ "vector_io": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
@@ -68,7 +68,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "remote",
"safety": "remote",
- "memory": "remote",
+ "vector_io": "remote",
"agents": "remote",
"tool_runtime": "memory_and_search",
},
@@ -115,7 +115,7 @@ def pytest_generate_tests(metafunc):
available_fixtures = {
"inference": INFERENCE_FIXTURES,
"safety": SAFETY_FIXTURES,
- "memory": MEMORY_FIXTURES,
+ "vector_io": VECTOR_IO_FIXTURES,
"agents": AGENTS_FIXTURES,
"tool_runtime": TOOL_RUNTIME_FIXTURES,
}
diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py
index 1b1781f36..bb4a6e6a3 100644
--- a/llama_stack/providers/tests/agents/fixtures.py
+++ b/llama_stack/providers/tests/agents/fixtures.py
@@ -69,7 +69,7 @@ async def agents_stack(
providers = {}
provider_data = {}
- for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
+ for key in ["inference", "safety", "vector_io", "agents", "tool_runtime"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if key == "inference":
@@ -118,7 +118,7 @@ async def agents_stack(
)
test_stack = await construct_stack_for_test(
- [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
+ [Api.agents, Api.inference, Api.safety, Api.vector_io, Api.tool_runtime],
providers,
provider_data,
models=models,
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index 320096826..f11aef3ec 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -214,9 +214,11 @@ class TestAgents:
turn_response = [
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
-
assert len(turn_response) > 0
+ # FIXME: we need to check the content of the turn response and ensure
+ # RAG actually worked
+
@pytest.mark.asyncio
async def test_create_agent_turn_with_tavily_search(
self, agents_stack, search_query_messages, common_params
diff --git a/llama_stack/providers/tests/vector_io/test_vector_io.py b/llama_stack/providers/tests/vector_io/test_vector_io.py
index 901b8bd11..521131f63 100644
--- a/llama_stack/providers/tests/vector_io/test_vector_io.py
+++ b/llama_stack/providers/tests/vector_io/test_vector_io.py
@@ -8,13 +8,12 @@ import uuid
import pytest
+from llama_stack.apis.tools import RAGDocument
+
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
from llama_stack.apis.vector_io import QueryChunksResponse
-from llama_stack.providers.utils.memory.vector_store import (
- make_overlapped_chunks,
- MemoryBankDocument,
-)
+from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks
# How to run this test:
#
@@ -26,22 +25,22 @@ from llama_stack.providers.utils.memory.vector_store import (
@pytest.fixture(scope="session")
def sample_chunks():
docs = [
- MemoryBankDocument(
+ RAGDocument(
document_id="doc1",
content="Python is a high-level programming language.",
metadata={"category": "programming", "difficulty": "beginner"},
),
- MemoryBankDocument(
+ RAGDocument(
document_id="doc2",
content="Machine learning is a subset of artificial intelligence.",
metadata={"category": "AI", "difficulty": "advanced"},
),
- MemoryBankDocument(
+ RAGDocument(
document_id="doc3",
content="Data structures are fundamental to computer science.",
metadata={"category": "computer science", "difficulty": "intermediate"},
),
- MemoryBankDocument(
+ RAGDocument(
document_id="doc4",
content="Neural networks are inspired by biological neural networks.",
metadata={"category": "AI", "difficulty": "advanced"},
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index c2de6c714..82c0c9c07 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -19,7 +19,6 @@ import numpy as np
from llama_models.llama3.api.tokenizer import Tokenizer
from numpy.typing import NDArray
-from pydantic import BaseModel, Field
from pypdf import PdfReader
from llama_stack.apis.common.content_types import (
@@ -27,6 +26,7 @@ from llama_stack.apis.common.content_types import (
TextContentItem,
URL,
)
+from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.datatypes import Api
@@ -34,17 +34,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
-
log = logging.getLogger(__name__)
-class MemoryBankDocument(BaseModel):
- document_id: str
- content: InterleavedContent | URL
- mime_type: str | None = None
- metadata: Dict[str, Any] = Field(default_factory=dict)
-
-
def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
pdf_bytes = io.BytesIO(data)
@@ -122,7 +114,7 @@ def concat_interleaved_content(content: List[InterleavedContent]) -> Interleaved
return ret
-async def content_from_doc(doc: MemoryBankDocument) -> str:
+async def content_from_doc(doc: RAGDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
return content_from_data(doc.content.uri)
@@ -161,7 +153,13 @@ def make_overlapped_chunks(
chunk = tokenizer.decode(toks)
# chunk is a string
chunks.append(
- Chunk(content=chunk, token_count=len(toks), document_id=document_id)
+ Chunk(
+ content=chunk,
+ metadata={
+ "token_count": len(toks),
+ "document_id": document_id,
+ },
+ )
)
return chunks
diff --git a/llama_stack/scripts/test_rag_via_curl.py b/llama_stack/scripts/test_rag_via_curl.py
new file mode 100644
index 000000000..28d6fb601
--- /dev/null
+++ b/llama_stack/scripts/test_rag_via_curl.py
@@ -0,0 +1,105 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from typing import List
+
+import pytest
+import requests
+from pydantic import TypeAdapter
+
+from llama_stack.apis.tools import (
+ DefaultRAGQueryGeneratorConfig,
+ RAGDocument,
+ RAGQueryConfig,
+ RAGQueryResult,
+)
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.providers.utils.memory.vector_store import interleaved_content_as_str
+
+
+class TestRAGToolEndpoints:
+ @pytest.fixture
+ def base_url(self) -> str:
+ return "http://localhost:8321/v1" # Adjust port if needed
+
+ @pytest.fixture
+ def sample_documents(self) -> List[RAGDocument]:
+ return [
+ RAGDocument(
+ document_id="doc1",
+ content="Python is a high-level programming language.",
+ metadata={"category": "programming", "difficulty": "beginner"},
+ ),
+ RAGDocument(
+ document_id="doc2",
+ content="Machine learning is a subset of artificial intelligence.",
+ metadata={"category": "AI", "difficulty": "advanced"},
+ ),
+ RAGDocument(
+ document_id="doc3",
+ content="Data structures are fundamental to computer science.",
+ metadata={"category": "computer science", "difficulty": "intermediate"},
+ ),
+ ]
+
+ @pytest.mark.asyncio
+ async def test_rag_workflow(
+ self, base_url: str, sample_documents: List[RAGDocument]
+ ):
+ vector_db_payload = {
+ "vector_db_id": "test_vector_db",
+ "embedding_model": "all-MiniLM-L6-v2",
+ "embedding_dimension": 384,
+ }
+
+ response = requests.post(f"{base_url}/vector-dbs", json=vector_db_payload)
+ assert response.status_code == 200
+ vector_db = VectorDB(**response.json())
+
+ insert_payload = {
+ "documents": [
+ json.loads(doc.model_dump_json()) for doc in sample_documents
+ ],
+ "vector_db_id": vector_db.identifier,
+ "chunk_size_in_tokens": 512,
+ }
+
+ response = requests.post(
+ f"{base_url}/tool-runtime/rag-tool/insert-documents",
+ json=insert_payload,
+ )
+ assert response.status_code == 200
+
+ query = "What is Python?"
+ query_config = RAGQueryConfig(
+ query_generator_config=DefaultRAGQueryGeneratorConfig(),
+ max_tokens_in_context=4096,
+ max_chunks=2,
+ )
+
+ query_payload = {
+ "content": query,
+ "query_config": json.loads(query_config.model_dump_json()),
+ "vector_db_ids": [vector_db.identifier],
+ }
+
+ response = requests.post(
+ f"{base_url}/tool-runtime/rag-tool/query-context",
+ json=query_payload,
+ )
+ assert response.status_code == 200
+ result = response.json()
+ result = TypeAdapter(RAGQueryResult).validate_python(result)
+
+ content_str = interleaved_content_as_str(result.content)
+ print(f"content: {content_str}")
+ assert len(content_str) > 0
+ assert "Python" in content_str
+
+ # Clean up: Delete the vector DB
+ response = requests.delete(f"{base_url}/vector-dbs/{vector_db.identifier}")
+ assert response.status_code == 200
diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml
index ea7387a24..2160adb8e 100644
--- a/llama_stack/templates/together/build.yaml
+++ b/llama_stack/templates/together/build.yaml
@@ -4,7 +4,7 @@ distribution_spec:
providers:
inference:
- remote::together
- memory:
+ vector_io:
- inline::faiss
- remote::chromadb
- remote::pgvector
diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml
index da25fd144..135b124e4 100644
--- a/llama_stack/templates/together/run.yaml
+++ b/llama_stack/templates/together/run.yaml
@@ -5,7 +5,7 @@ apis:
- datasetio
- eval
- inference
-- memory
+- vector_io
- safety
- scoring
- telemetry
@@ -20,7 +20,7 @@ providers:
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
- memory:
+ vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
@@ -145,7 +145,6 @@ models:
model_type: embedding
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
-memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []