diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py index 4cea9d970..abeb16936 100644 --- a/docs/openapi_generator/pyopenapi/operations.py +++ b/docs/openapi_generator/pyopenapi/operations.py @@ -172,10 +172,16 @@ def _get_endpoint_functions( def _get_defining_class(member_fn: str, derived_cls: type) -> type: "Find the class in which a member function is first defined in a class inheritance hierarchy." + # This import must be dynamic here + from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime + # iterate in reverse member resolution order to find most specific class first for cls in reversed(inspect.getmro(derived_cls)): for name, _ in inspect.getmembers(cls, inspect.isfunction): if name == member_fn: + # HACK ALERT + if cls == RAGToolRuntime: + return ToolRuntime return cls raise ValidationError( diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 459a53888..f00d7b291 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -1108,98 +1108,6 @@ ] } }, - "/v1/memory-banks/{memory_bank_id}": { - "get": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/MemoryBank" - }, - { - "type": "null" - } - ] - } - } - } - } - }, - "tags": [ - "MemoryBanks" - ], - "parameters": [ - { - "name": "memory_bank_id", - "in": "path", - "required": true, - "schema": { - "type": "string" - } - }, - { - "name": "X-LlamaStack-Provider-Data", - "in": "header", - "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "X-LlamaStack-Client-Version", - "in": "header", - "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", - "required": false, - "schema": { - "type": "string" - } - } - ] - }, - "delete": { - "responses": { - "200": { - "description": "OK" - } - }, - "tags": [ - "MemoryBanks" - ], - "parameters": [ - { - "name": "memory_bank_id", - "in": "path", - "required": true, - "schema": { - "type": "string" - } - }, - { - "name": "X-LlamaStack-Provider-Data", - "in": "header", - "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "X-LlamaStack-Client-Version", - "in": "header", - "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", - "required": false, - "schema": { - "type": "string" - } - } - ] - } - }, "/v1/models/{model_id}": { "get": { "responses": { @@ -1848,6 +1756,98 @@ ] } }, + "/v1/vector-dbs/{vector_db_id}": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "oneOf": [ + { + "$ref": "#/components/schemas/VectorDB" + }, + { + "type": "null" + } + ] + } + } + } + } + }, + "tags": [ + "VectorDBs" + ], + "parameters": [ + { + "name": "vector_db_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Provider-Data", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Client-Version", + "in": "header", + "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", + "required": false, + "schema": { + "type": "string" + } + } + ] + }, + "delete": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "VectorDBs" + ], + "parameters": [ + { + "name": "vector_db_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Provider-Data", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Client-Version", + "in": "header", + "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", + "required": false, + "schema": { + "type": "string" + } + } + ] + } + }, "/v1/health": { "get": { "responses": { @@ -1887,7 +1887,7 @@ ] } }, - "/v1/memory/insert": { + "/v1/vector-io/insert": { "post": { "responses": { "200": { @@ -1895,7 +1895,7 @@ } }, "tags": [ - "Memory" + "VectorIO" ], "parameters": [ { @@ -1917,6 +1917,49 @@ } } ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/InsertChunksRequest" + } + } + }, + "required": true + } + } + }, + "/v1/tool-runtime/rag-tool/insert-documents": { + "post": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "ToolRuntime" + ], + "summary": "Index documents so they can be used by the RAG system", + "parameters": [ + { + "name": "X-LlamaStack-Provider-Data", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Client-Version", + "in": "header", + "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", + "required": false, + "schema": { + "type": "string" + } + } + ], "requestBody": { "content": { "application/json": { @@ -2300,105 +2343,6 @@ } } }, - "/v1/memory-banks": { - "get": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ListMemoryBanksResponse" - } - } - } - } - }, - "tags": [ - "MemoryBanks" - ], - "parameters": [ - { - "name": "X-LlamaStack-Provider-Data", - "in": "header", - "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "X-LlamaStack-Client-Version", - "in": "header", - "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", - "required": false, - "schema": { - "type": "string" - } - } - ] - }, - "post": { - "responses": { - "200": { - "description": "", - "content": { - "application/json": { - "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/VectorMemoryBank" - }, - { - "$ref": "#/components/schemas/KeyValueMemoryBank" - }, - { - "$ref": "#/components/schemas/KeywordMemoryBank" - }, - { - "$ref": "#/components/schemas/GraphMemoryBank" - } - ] - } - } - } - } - }, - "tags": [ - "MemoryBanks" - ], - "parameters": [ - { - "name": "X-LlamaStack-Provider-Data", - "in": "header", - "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", - "required": false, - "schema": { - "type": "string" - } - }, - { - "name": "X-LlamaStack-Client-Version", - "in": "header", - "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", - "required": false, - "schema": { - "type": "string" - } - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/RegisterMemoryBankRequest" - } - } - }, - "required": true - } - } - }, "/v1/models": { "get": { "responses": { @@ -2912,6 +2856,92 @@ ] } }, + "/v1/vector-dbs": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListVectorDBsResponse" + } + } + } + } + }, + "tags": [ + "VectorDBs" + ], + "parameters": [ + { + "name": "X-LlamaStack-Provider-Data", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Client-Version", + "in": "header", + "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", + "required": false, + "schema": { + "type": "string" + } + } + ] + }, + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/VectorDB" + } + } + } + } + }, + "tags": [ + "VectorDBs" + ], + "parameters": [ + { + "name": "X-LlamaStack-Provider-Data", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Client-Version", + "in": "header", + "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RegisterVectorDbRequest" + } + } + }, + "required": true + } + } + }, "/v1/telemetry/events": { "post": { "responses": { @@ -3003,7 +3033,7 @@ } } }, - "/v1/memory/query": { + "/v1/vector-io/query": { "post": { "responses": { "200": { @@ -3011,14 +3041,14 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/QueryDocumentsResponse" + "$ref": "#/components/schemas/QueryChunksResponse" } } } } }, "tags": [ - "Memory" + "VectorIO" ], "parameters": [ { @@ -3044,7 +3074,57 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/QueryDocumentsRequest" + "$ref": "#/components/schemas/QueryChunksRequest" + } + } + }, + "required": true + } + } + }, + "/v1/tool-runtime/rag-tool/query-context": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RAGQueryResult" + } + } + } + } + }, + "tags": [ + "ToolRuntime" + ], + "summary": "Query the RAG system for context; typically invoked by the agent", + "parameters": [ + { + "name": "X-LlamaStack-Provider-Data", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-Client-Version", + "in": "header", + "description": "Version of the client making the request. This is used to ensure that the client and server are compatible.", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/QueryContextRequest" } } }, @@ -5851,118 +5931,6 @@ "aggregated_results" ] }, - "GraphMemoryBank": { - "type": "object", - "properties": { - "identifier": { - "type": "string" - }, - "provider_resource_id": { - "type": "string" - }, - "provider_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "memory_bank", - "default": "memory_bank" - }, - "memory_bank_type": { - "type": "string", - "const": "graph", - "default": "graph" - } - }, - "additionalProperties": false, - "required": [ - "identifier", - "provider_resource_id", - "provider_id", - "type", - "memory_bank_type" - ] - }, - "KeyValueMemoryBank": { - "type": "object", - "properties": { - "identifier": { - "type": "string" - }, - "provider_resource_id": { - "type": "string" - }, - "provider_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "memory_bank", - "default": "memory_bank" - }, - "memory_bank_type": { - "type": "string", - "const": "keyvalue", - "default": "keyvalue" - } - }, - "additionalProperties": false, - "required": [ - "identifier", - "provider_resource_id", - "provider_id", - "type", - "memory_bank_type" - ] - }, - "KeywordMemoryBank": { - "type": "object", - "properties": { - "identifier": { - "type": "string" - }, - "provider_resource_id": { - "type": "string" - }, - "provider_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "memory_bank", - "default": "memory_bank" - }, - "memory_bank_type": { - "type": "string", - "const": "keyword", - "default": "keyword" - } - }, - "additionalProperties": false, - "required": [ - "identifier", - "provider_resource_id", - "provider_id", - "type", - "memory_bank_type" - ] - }, - "MemoryBank": { - "oneOf": [ - { - "$ref": "#/components/schemas/VectorMemoryBank" - }, - { - "$ref": "#/components/schemas/KeyValueMemoryBank" - }, - { - "$ref": "#/components/schemas/KeywordMemoryBank" - }, - { - "$ref": "#/components/schemas/GraphMemoryBank" - } - ] - }, "Session": { "type": "object", "properties": { @@ -5981,9 +5949,6 @@ "started_at": { "type": "string", "format": "date-time" - }, - "memory_bank": { - "$ref": "#/components/schemas/MemoryBank" } }, "additionalProperties": false, @@ -5995,53 +5960,6 @@ ], "title": "A single session of an interaction with an Agentic System." }, - "VectorMemoryBank": { - "type": "object", - "properties": { - "identifier": { - "type": "string" - }, - "provider_resource_id": { - "type": "string" - }, - "provider_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "memory_bank", - "default": "memory_bank" - }, - "memory_bank_type": { - "type": "string", - "const": "vector", - "default": "vector" - }, - "embedding_model": { - "type": "string" - }, - "chunk_size_in_tokens": { - "type": "integer" - }, - "embedding_dimension": { - "type": "integer", - "default": 384 - }, - "overlap_size_in_tokens": { - "type": "integer" - } - }, - "additionalProperties": false, - "required": [ - "identifier", - "provider_resource_id", - "provider_id", - "type", - "memory_bank_type", - "embedding_model", - "chunk_size_in_tokens" - ] - }, "AgentStepResponse": { "type": "object", "properties": { @@ -7012,6 +6930,40 @@ "data" ] }, + "VectorDB": { + "type": "object", + "properties": { + "identifier": { + "type": "string" + }, + "provider_resource_id": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "vector_db", + "default": "vector_db" + }, + "embedding_model": { + "type": "string" + }, + "embedding_dimension": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "identifier", + "provider_resource_id", + "provider_id", + "type", + "embedding_model", + "embedding_dimension" + ] + }, "HealthInfo": { "type": "object", "properties": { @@ -7024,7 +6976,64 @@ "status" ] }, - "MemoryBankDocument": { + "InsertChunksRequest": { + "type": "object", + "properties": { + "vector_db_id": { + "type": "string" + }, + "chunks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "content": { + "$ref": "#/components/schemas/InterleavedContent" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "content", + "metadata" + ] + } + }, + "ttl_seconds": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "vector_db_id", + "chunks" + ] + }, + "RAGDocument": { "type": "object", "properties": { "document_id": { @@ -7088,23 +7097,24 @@ "InsertDocumentsRequest": { "type": "object", "properties": { - "bank_id": { - "type": "string" - }, "documents": { "type": "array", "items": { - "$ref": "#/components/schemas/MemoryBankDocument" + "$ref": "#/components/schemas/RAGDocument" } }, - "ttl_seconds": { + "vector_db_id": { + "type": "string" + }, + "chunk_size_in_tokens": { "type": "integer" } }, "additionalProperties": false, "required": [ - "bank_id", - "documents" + "documents", + "vector_db_id", + "chunk_size_in_tokens" ] }, "InvokeToolRequest": { @@ -7113,7 +7123,7 @@ "tool_name": { "type": "string" }, - "args": { + "kwargs": { "type": "object", "additionalProperties": { "oneOf": [ @@ -7142,7 +7152,7 @@ "additionalProperties": false, "required": [ "tool_name", - "args" + "kwargs" ] }, "ToolInvocationResult": { @@ -7193,21 +7203,6 @@ "data" ] }, - "ListMemoryBanksResponse": { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/components/schemas/MemoryBank" - } - } - }, - "additionalProperties": false, - "required": [ - "data" - ] - }, "ListModelsResponse": { "type": "object", "properties": { @@ -7356,6 +7351,21 @@ "data" ] }, + "ListVectorDBsResponse": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/VectorDB" + } + } + }, + "additionalProperties": false, + "required": [ + "data" + ] + }, "LogSeverity": { "type": "string", "enum": [ @@ -7873,10 +7883,10 @@ "job_uuid" ] }, - "QueryDocumentsRequest": { + "QueryChunksRequest": { "type": "object", "properties": { - "bank_id": { + "vector_db_id": { "type": "string" }, "query": { @@ -7910,11 +7920,11 @@ }, "additionalProperties": false, "required": [ - "bank_id", + "vector_db_id", "query" ] }, - "QueryDocumentsResponse": { + "QueryChunksResponse": { "type": "object", "properties": { "chunks": { @@ -7925,18 +7935,36 @@ "content": { "$ref": "#/components/schemas/InterleavedContent" }, - "token_count": { - "type": "integer" - }, - "document_id": { - "type": "string" + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } } }, "additionalProperties": false, "required": [ "content", - "token_count", - "document_id" + "metadata" ] } }, @@ -7953,6 +7981,111 @@ "scores" ] }, + "DefaultRAGQueryGeneratorConfig": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "default", + "default": "default" + }, + "separator": { + "type": "string", + "default": " " + } + }, + "additionalProperties": false, + "required": [ + "type", + "separator" + ] + }, + "LLMRAGQueryGeneratorConfig": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "llm", + "default": "llm" + }, + "model": { + "type": "string" + }, + "template": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "model", + "template" + ] + }, + "RAGQueryConfig": { + "type": "object", + "properties": { + "query_generator_config": { + "$ref": "#/components/schemas/RAGQueryGeneratorConfig" + }, + "max_tokens_in_context": { + "type": "integer", + "default": 4096 + }, + "max_chunks": { + "type": "integer", + "default": 5 + } + }, + "additionalProperties": false, + "required": [ + "query_generator_config", + "max_tokens_in_context", + "max_chunks" + ] + }, + "RAGQueryGeneratorConfig": { + "oneOf": [ + { + "$ref": "#/components/schemas/DefaultRAGQueryGeneratorConfig" + }, + { + "$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig" + } + ] + }, + "QueryContextRequest": { + "type": "object", + "properties": { + "content": { + "$ref": "#/components/schemas/InterleavedContent" + }, + "query_config": { + "$ref": "#/components/schemas/RAGQueryConfig" + }, + "vector_db_ids": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false, + "required": [ + "content", + "query_config", + "vector_db_ids" + ] + }, + "RAGQueryResult": { + "type": "object", + "properties": { + "content": { + "$ref": "#/components/schemas/InterleavedContent" + } + }, + "additionalProperties": false + }, "QueryCondition": { "type": "object", "properties": { @@ -8139,108 +8272,6 @@ "scoring_functions" ] }, - "GraphMemoryBankParams": { - "type": "object", - "properties": { - "memory_bank_type": { - "type": "string", - "const": "graph", - "default": "graph" - } - }, - "additionalProperties": false, - "required": [ - "memory_bank_type" - ] - }, - "KeyValueMemoryBankParams": { - "type": "object", - "properties": { - "memory_bank_type": { - "type": "string", - "const": "keyvalue", - "default": "keyvalue" - } - }, - "additionalProperties": false, - "required": [ - "memory_bank_type" - ] - }, - "KeywordMemoryBankParams": { - "type": "object", - "properties": { - "memory_bank_type": { - "type": "string", - "const": "keyword", - "default": "keyword" - } - }, - "additionalProperties": false, - "required": [ - "memory_bank_type" - ] - }, - "VectorMemoryBankParams": { - "type": "object", - "properties": { - "memory_bank_type": { - "type": "string", - "const": "vector", - "default": "vector" - }, - "embedding_model": { - "type": "string" - }, - "chunk_size_in_tokens": { - "type": "integer" - }, - "overlap_size_in_tokens": { - "type": "integer" - } - }, - "additionalProperties": false, - "required": [ - "memory_bank_type", - "embedding_model", - "chunk_size_in_tokens" - ] - }, - "RegisterMemoryBankRequest": { - "type": "object", - "properties": { - "memory_bank_id": { - "type": "string" - }, - "params": { - "oneOf": [ - { - "$ref": "#/components/schemas/VectorMemoryBankParams" - }, - { - "$ref": "#/components/schemas/KeyValueMemoryBankParams" - }, - { - "$ref": "#/components/schemas/KeywordMemoryBankParams" - }, - { - "$ref": "#/components/schemas/GraphMemoryBankParams" - } - ] - }, - "provider_id": { - "type": "string" - }, - "provider_memory_bank_id": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "memory_bank_id", - "params" - ] - }, "RegisterModelRequest": { "type": "object", "properties": { @@ -8413,6 +8444,31 @@ "provider_id" ] }, + "RegisterVectorDbRequest": { + "type": "object", + "properties": { + "vector_db_id": { + "type": "string" + }, + "embedding_model": { + "type": "string" + }, + "embedding_dimension": { + "type": "integer" + }, + "provider_id": { + "type": "string" + }, + "provider_vector_db_id": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "vector_db_id", + "embedding_model" + ] + }, "RunEvalRequest": { "type": "object", "properties": { @@ -9128,6 +9184,10 @@ { "name": "Datasets" }, + { + "name": "DefaultRAGQueryGeneratorConfig", + "description": "" + }, { "name": "EfficiencyConfig", "description": "" @@ -9158,14 +9218,6 @@ "name": "EvaluateRowsRequest", "description": "" }, - { - "name": "GraphMemoryBank", - "description": "" - }, - { - "name": "GraphMemoryBankParams", - "description": "" - }, { "name": "GreedySamplingStrategy", "description": "" @@ -9189,6 +9241,10 @@ "name": "InferenceStep", "description": "" }, + { + "name": "InsertChunksRequest", + "description": "" + }, { "name": "InsertDocumentsRequest", "description": "" @@ -9220,26 +9276,14 @@ "name": "JsonType", "description": "" }, - { - "name": "KeyValueMemoryBank", - "description": "" - }, - { - "name": "KeyValueMemoryBankParams", - "description": "" - }, - { - "name": "KeywordMemoryBank", - "description": "" - }, - { - "name": "KeywordMemoryBankParams", - "description": "" - }, { "name": "LLMAsJudgeScoringFnParams", "description": "" }, + { + "name": "LLMRAGQueryGeneratorConfig", + "description": "" + }, { "name": "ListDatasetsResponse", "description": "" @@ -9248,10 +9292,6 @@ "name": "ListEvalTasksResponse", "description": "" }, - { - "name": "ListMemoryBanksResponse", - "description": "" - }, { "name": "ListModelsResponse", "description": "" @@ -9284,6 +9324,10 @@ "name": "ListToolsResponse", "description": "" }, + { + "name": "ListVectorDBsResponse", + "description": "" + }, { "name": "LogEventRequest", "description": "" @@ -9296,20 +9340,6 @@ "name": "LoraFinetuningConfig", "description": "" }, - { - "name": "Memory" - }, - { - "name": "MemoryBank", - "description": "" - }, - { - "name": "MemoryBankDocument", - "description": "" - }, - { - "name": "MemoryBanks" - }, { "name": "MemoryRetrievalStep", "description": "" @@ -9388,6 +9418,14 @@ "name": "QATFinetuningConfig", "description": "" }, + { + "name": "QueryChunksRequest", + "description": "" + }, + { + "name": "QueryChunksResponse", + "description": "" + }, { "name": "QueryCondition", "description": "" @@ -9397,12 +9435,8 @@ "description": "" }, { - "name": "QueryDocumentsRequest", - "description": "" - }, - { - "name": "QueryDocumentsResponse", - "description": "" + "name": "QueryContextRequest", + "description": "" }, { "name": "QuerySpanTreeResponse", @@ -9416,6 +9450,22 @@ "name": "QueryTracesResponse", "description": "" }, + { + "name": "RAGDocument", + "description": "" + }, + { + "name": "RAGQueryConfig", + "description": "" + }, + { + "name": "RAGQueryGeneratorConfig", + "description": "" + }, + { + "name": "RAGQueryResult", + "description": "" + }, { "name": "RegexParserScoringFnParams", "description": "" @@ -9428,10 +9478,6 @@ "name": "RegisterEvalTaskRequest", "description": "" }, - { - "name": "RegisterMemoryBankRequest", - "description": "" - }, { "name": "RegisterModelRequest", "description": "" @@ -9448,6 +9494,10 @@ "name": "RegisterToolGroupRequest", "description": "" }, + { + "name": "RegisterVectorDbRequest", + "description": "" + }, { "name": "ResponseFormat", "description": "" @@ -9701,12 +9751,14 @@ "description": "" }, { - "name": "VectorMemoryBank", - "description": "" + "name": "VectorDB", + "description": "" }, { - "name": "VectorMemoryBankParams", - "description": "" + "name": "VectorDBs" + }, + { + "name": "VectorIO" }, { "name": "VersionInfo", @@ -9729,8 +9781,6 @@ "EvalTasks", "Inference", "Inspect", - "Memory", - "MemoryBanks", "Models", "PostTraining (Coming Soon)", "Safety", @@ -9740,7 +9790,9 @@ "SyntheticDataGeneration (Coming Soon)", "Telemetry", "ToolGroups", - "ToolRuntime" + "ToolRuntime", + "VectorDBs", + "VectorIO" ] }, { @@ -9793,19 +9845,19 @@ "DataConfig", "Dataset", "DatasetFormat", + "DefaultRAGQueryGeneratorConfig", "EfficiencyConfig", "EmbeddingsRequest", "EmbeddingsResponse", "EvalTask", "EvaluateResponse", "EvaluateRowsRequest", - "GraphMemoryBank", - "GraphMemoryBankParams", "GreedySamplingStrategy", "HealthInfo", "ImageContentItem", "ImageDelta", "InferenceStep", + "InsertChunksRequest", "InsertDocumentsRequest", "InterleavedContent", "InterleavedContentItem", @@ -9813,14 +9865,10 @@ "Job", "JobStatus", "JsonType", - "KeyValueMemoryBank", - "KeyValueMemoryBankParams", - "KeywordMemoryBank", - "KeywordMemoryBankParams", "LLMAsJudgeScoringFnParams", + "LLMRAGQueryGeneratorConfig", "ListDatasetsResponse", "ListEvalTasksResponse", - "ListMemoryBanksResponse", "ListModelsResponse", "ListPostTrainingJobsResponse", "ListProvidersResponse", @@ -9829,11 +9877,10 @@ "ListShieldsResponse", "ListToolGroupsResponse", "ListToolsResponse", + "ListVectorDBsResponse", "LogEventRequest", "LogSeverity", "LoraFinetuningConfig", - "MemoryBank", - "MemoryBankDocument", "MemoryRetrievalStep", "Message", "MetricEvent", @@ -9852,21 +9899,26 @@ "PreferenceOptimizeRequest", "ProviderInfo", "QATFinetuningConfig", + "QueryChunksRequest", + "QueryChunksResponse", "QueryCondition", "QueryConditionOp", - "QueryDocumentsRequest", - "QueryDocumentsResponse", + "QueryContextRequest", "QuerySpanTreeResponse", "QuerySpansResponse", "QueryTracesResponse", + "RAGDocument", + "RAGQueryConfig", + "RAGQueryGeneratorConfig", + "RAGQueryResult", "RegexParserScoringFnParams", "RegisterDatasetRequest", "RegisterEvalTaskRequest", - "RegisterMemoryBankRequest", "RegisterModelRequest", "RegisterScoringFunctionRequest", "RegisterShieldRequest", "RegisterToolGroupRequest", + "RegisterVectorDbRequest", "ResponseFormat", "RouteInfo", "RunEvalRequest", @@ -9924,8 +9976,7 @@ "UnionType", "UnstructuredLogEvent", "UserMessage", - "VectorMemoryBank", - "VectorMemoryBankParams", + "VectorDB", "VersionInfo", "ViolationLevel" ] diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 9aeac6db3..e1ae07c45 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -761,6 +761,20 @@ components: - instruct - dialog type: string + DefaultRAGQueryGeneratorConfig: + additionalProperties: false + properties: + separator: + default: ' ' + type: string + type: + const: default + default: default + type: string + required: + - type + - separator + type: object EfficiencyConfig: additionalProperties: false properties: @@ -891,40 +905,6 @@ components: - scoring_functions - task_config type: object - GraphMemoryBank: - additionalProperties: false - properties: - identifier: - type: string - memory_bank_type: - const: graph - default: graph - type: string - provider_id: - type: string - provider_resource_id: - type: string - type: - const: memory_bank - default: memory_bank - type: string - required: - - identifier - - provider_resource_id - - provider_id - - type - - memory_bank_type - type: object - GraphMemoryBankParams: - additionalProperties: false - properties: - memory_bank_type: - const: graph - default: graph - type: string - required: - - memory_bank_type - type: object GreedySamplingStrategy: additionalProperties: false properties: @@ -997,20 +977,53 @@ components: - step_type - model_response type: object - InsertDocumentsRequest: + InsertChunksRequest: additionalProperties: false properties: - bank_id: - type: string - documents: + chunks: items: - $ref: '#/components/schemas/MemoryBankDocument' + additionalProperties: false + properties: + content: + $ref: '#/components/schemas/InterleavedContent' + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + required: + - content + - metadata + type: object type: array ttl_seconds: type: integer + vector_db_id: + type: string + required: + - vector_db_id + - chunks + type: object + InsertDocumentsRequest: + additionalProperties: false + properties: + chunk_size_in_tokens: + type: integer + documents: + items: + $ref: '#/components/schemas/RAGDocument' + type: array + vector_db_id: + type: string required: - - bank_id - documents + - vector_db_id + - chunk_size_in_tokens type: object InterleavedContent: oneOf: @@ -1026,7 +1039,7 @@ components: InvokeToolRequest: additionalProperties: false properties: - args: + kwargs: additionalProperties: oneOf: - type: 'null' @@ -1040,7 +1053,7 @@ components: type: string required: - tool_name - - args + - kwargs type: object Job: additionalProperties: false @@ -1067,74 +1080,6 @@ components: required: - type type: object - KeyValueMemoryBank: - additionalProperties: false - properties: - identifier: - type: string - memory_bank_type: - const: keyvalue - default: keyvalue - type: string - provider_id: - type: string - provider_resource_id: - type: string - type: - const: memory_bank - default: memory_bank - type: string - required: - - identifier - - provider_resource_id - - provider_id - - type - - memory_bank_type - type: object - KeyValueMemoryBankParams: - additionalProperties: false - properties: - memory_bank_type: - const: keyvalue - default: keyvalue - type: string - required: - - memory_bank_type - type: object - KeywordMemoryBank: - additionalProperties: false - properties: - identifier: - type: string - memory_bank_type: - const: keyword - default: keyword - type: string - provider_id: - type: string - provider_resource_id: - type: string - type: - const: memory_bank - default: memory_bank - type: string - required: - - identifier - - provider_resource_id - - provider_id - - type - - memory_bank_type - type: object - KeywordMemoryBankParams: - additionalProperties: false - properties: - memory_bank_type: - const: keyword - default: keyword - type: string - required: - - memory_bank_type - type: object LLMAsJudgeScoringFnParams: additionalProperties: false properties: @@ -1158,6 +1103,22 @@ components: - type - judge_model type: object + LLMRAGQueryGeneratorConfig: + additionalProperties: false + properties: + model: + type: string + template: + type: string + type: + const: llm + default: llm + type: string + required: + - type + - model + - template + type: object ListDatasetsResponse: additionalProperties: false properties: @@ -1178,16 +1139,6 @@ components: required: - data type: object - ListMemoryBanksResponse: - additionalProperties: false - properties: - data: - items: - $ref: '#/components/schemas/MemoryBank' - type: array - required: - - data - type: object ListModelsResponse: additionalProperties: false properties: @@ -1274,6 +1225,16 @@ components: required: - data type: object + ListVectorDBsResponse: + additionalProperties: false + properties: + data: + items: + $ref: '#/components/schemas/VectorDB' + type: array + required: + - data + type: object LogEventRequest: additionalProperties: false properties: @@ -1330,42 +1291,6 @@ components: - rank - alpha type: object - MemoryBank: - oneOf: - - $ref: '#/components/schemas/VectorMemoryBank' - - $ref: '#/components/schemas/KeyValueMemoryBank' - - $ref: '#/components/schemas/KeywordMemoryBank' - - $ref: '#/components/schemas/GraphMemoryBank' - MemoryBankDocument: - additionalProperties: false - properties: - content: - oneOf: - - type: string - - $ref: '#/components/schemas/InterleavedContentItem' - - items: - $ref: '#/components/schemas/InterleavedContentItem' - type: array - - $ref: '#/components/schemas/URL' - document_id: - type: string - metadata: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - mime_type: - type: string - required: - - document_id - - content - - metadata - type: object MemoryRetrievalStep: additionalProperties: false properties: @@ -1705,6 +1630,59 @@ components: - quantizer_name - group_size type: object + QueryChunksRequest: + additionalProperties: false + properties: + params: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + query: + $ref: '#/components/schemas/InterleavedContent' + vector_db_id: + type: string + required: + - vector_db_id + - query + type: object + QueryChunksResponse: + additionalProperties: false + properties: + chunks: + items: + additionalProperties: false + properties: + content: + $ref: '#/components/schemas/InterleavedContent' + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + required: + - content + - metadata + type: object + type: array + scores: + items: + type: number + type: array + required: + - chunks + - scores + type: object QueryCondition: additionalProperties: false properties: @@ -1732,53 +1710,21 @@ components: - gt - lt type: string - QueryDocumentsRequest: + QueryContextRequest: additionalProperties: false properties: - bank_id: - type: string - params: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - query: + content: $ref: '#/components/schemas/InterleavedContent' - required: - - bank_id - - query - type: object - QueryDocumentsResponse: - additionalProperties: false - properties: - chunks: + query_config: + $ref: '#/components/schemas/RAGQueryConfig' + vector_db_ids: items: - additionalProperties: false - properties: - content: - $ref: '#/components/schemas/InterleavedContent' - document_id: - type: string - token_count: - type: integer - required: - - content - - token_count - - document_id - type: object - type: array - scores: - items: - type: number + type: string type: array required: - - chunks - - scores + - content + - query_config + - vector_db_ids type: object QuerySpanTreeResponse: additionalProperties: false @@ -1810,6 +1756,62 @@ components: required: - data type: object + RAGDocument: + additionalProperties: false + properties: + content: + oneOf: + - type: string + - $ref: '#/components/schemas/InterleavedContentItem' + - items: + $ref: '#/components/schemas/InterleavedContentItem' + type: array + - $ref: '#/components/schemas/URL' + document_id: + type: string + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + mime_type: + type: string + required: + - document_id + - content + - metadata + type: object + RAGQueryConfig: + additionalProperties: false + properties: + max_chunks: + default: 5 + type: integer + max_tokens_in_context: + default: 4096 + type: integer + query_generator_config: + $ref: '#/components/schemas/RAGQueryGeneratorConfig' + required: + - query_generator_config + - max_tokens_in_context + - max_chunks + type: object + RAGQueryGeneratorConfig: + oneOf: + - $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig' + - $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig' + RAGQueryResult: + additionalProperties: false + properties: + content: + $ref: '#/components/schemas/InterleavedContent' + type: object RegexParserScoringFnParams: additionalProperties: false properties: @@ -1888,25 +1890,6 @@ components: - dataset_id - scoring_functions type: object - RegisterMemoryBankRequest: - additionalProperties: false - properties: - memory_bank_id: - type: string - params: - oneOf: - - $ref: '#/components/schemas/VectorMemoryBankParams' - - $ref: '#/components/schemas/KeyValueMemoryBankParams' - - $ref: '#/components/schemas/KeywordMemoryBankParams' - - $ref: '#/components/schemas/GraphMemoryBankParams' - provider_id: - type: string - provider_memory_bank_id: - type: string - required: - - memory_bank_id - - params - type: object RegisterModelRequest: additionalProperties: false properties: @@ -1999,6 +1982,23 @@ components: - toolgroup_id - provider_id type: object + RegisterVectorDbRequest: + additionalProperties: false + properties: + embedding_dimension: + type: integer + embedding_model: + type: string + provider_id: + type: string + provider_vector_db_id: + type: string + vector_db_id: + type: string + required: + - vector_db_id + - embedding_model + type: object ResponseFormat: oneOf: - additionalProperties: false @@ -2298,8 +2298,6 @@ components: Session: additionalProperties: false properties: - memory_bank: - $ref: '#/components/schemas/MemoryBank' session_id: type: string session_name: @@ -3202,58 +3200,30 @@ components: - role - content type: object - VectorMemoryBank: + VectorDB: additionalProperties: false properties: - chunk_size_in_tokens: - type: integer embedding_dimension: - default: 384 type: integer embedding_model: type: string identifier: type: string - memory_bank_type: - const: vector - default: vector - type: string - overlap_size_in_tokens: - type: integer provider_id: type: string provider_resource_id: type: string type: - const: memory_bank - default: memory_bank + const: vector_db + default: vector_db type: string required: - identifier - provider_resource_id - provider_id - type - - memory_bank_type - embedding_model - - chunk_size_in_tokens - type: object - VectorMemoryBankParams: - additionalProperties: false - properties: - chunk_size_in_tokens: - type: integer - embedding_model: - type: string - memory_bank_type: - const: vector - default: vector - type: string - overlap_size_in_tokens: - type: integer - required: - - memory_bank_type - - embedding_model - - chunk_size_in_tokens + - embedding_dimension type: object VersionInfo: additionalProperties: false @@ -4272,186 +4242,6 @@ paths: description: OK tags: - Inspect - /v1/memory-banks: - get: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-Provider-Data - required: false - schema: - type: string - - description: Version of the client making the request. This is used to ensure - that the client and server are compatible. - in: header - name: X-LlamaStack-Client-Version - required: false - schema: - type: string - responses: - '200': - content: - application/json: - schema: - $ref: '#/components/schemas/ListMemoryBanksResponse' - description: OK - tags: - - MemoryBanks - post: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-Provider-Data - required: false - schema: - type: string - - description: Version of the client making the request. This is used to ensure - that the client and server are compatible. - in: header - name: X-LlamaStack-Client-Version - required: false - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterMemoryBankRequest' - required: true - responses: - '200': - content: - application/json: - schema: - oneOf: - - $ref: '#/components/schemas/VectorMemoryBank' - - $ref: '#/components/schemas/KeyValueMemoryBank' - - $ref: '#/components/schemas/KeywordMemoryBank' - - $ref: '#/components/schemas/GraphMemoryBank' - description: '' - tags: - - MemoryBanks - /v1/memory-banks/{memory_bank_id}: - delete: - parameters: - - in: path - name: memory_bank_id - required: true - schema: - type: string - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-Provider-Data - required: false - schema: - type: string - - description: Version of the client making the request. This is used to ensure - that the client and server are compatible. - in: header - name: X-LlamaStack-Client-Version - required: false - schema: - type: string - responses: - '200': - description: OK - tags: - - MemoryBanks - get: - parameters: - - in: path - name: memory_bank_id - required: true - schema: - type: string - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-Provider-Data - required: false - schema: - type: string - - description: Version of the client making the request. This is used to ensure - that the client and server are compatible. - in: header - name: X-LlamaStack-Client-Version - required: false - schema: - type: string - responses: - '200': - content: - application/json: - schema: - oneOf: - - $ref: '#/components/schemas/MemoryBank' - - type: 'null' - description: OK - tags: - - MemoryBanks - /v1/memory/insert: - post: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-Provider-Data - required: false - schema: - type: string - - description: Version of the client making the request. This is used to ensure - that the client and server are compatible. - in: header - name: X-LlamaStack-Client-Version - required: false - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/InsertDocumentsRequest' - required: true - responses: - '200': - description: OK - tags: - - Memory - /v1/memory/query: - post: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-Provider-Data - required: false - schema: - type: string - - description: Version of the client making the request. This is used to ensure - that the client and server are compatible. - in: header - name: X-LlamaStack-Client-Version - required: false - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/QueryDocumentsRequest' - required: true - responses: - '200': - content: - application/json: - schema: - $ref: '#/components/schemas/QueryDocumentsResponse' - description: OK - tags: - - Memory /v1/models: get: parameters: @@ -5386,6 +5176,68 @@ paths: description: OK tags: - ToolRuntime + /v1/tool-runtime/rag-tool/insert-documents: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/InsertDocumentsRequest' + required: true + responses: + '200': + description: OK + summary: Index documents so they can be used by the RAG system + tags: + - ToolRuntime + /v1/tool-runtime/rag-tool/query-context: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/QueryContextRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/RAGQueryResult' + description: OK + summary: Query the RAG system for context; typically invoked by the agent + tags: + - ToolRuntime /v1/toolgroups: get: parameters: @@ -5562,6 +5414,182 @@ paths: description: OK tags: - ToolGroups + /v1/vector-dbs: + get: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/ListVectorDBsResponse' + description: OK + tags: + - VectorDBs + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterVectorDbRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/VectorDB' + description: OK + tags: + - VectorDBs + /v1/vector-dbs/{vector_db_id}: + delete: + parameters: + - in: path + name: vector_db_id + required: true + schema: + type: string + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + responses: + '200': + description: OK + tags: + - VectorDBs + get: + parameters: + - in: path + name: vector_db_id + required: true + schema: + type: string + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + responses: + '200': + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/VectorDB' + - type: 'null' + description: OK + tags: + - VectorDBs + /v1/vector-io/insert: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/InsertChunksRequest' + required: true + responses: + '200': + description: OK + tags: + - VectorIO + /v1/vector-io/query: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-Provider-Data + required: false + schema: + type: string + - description: Version of the client making the request. This is used to ensure + that the client and server are compatible. + in: header + name: X-LlamaStack-Client-Version + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/QueryChunksRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/QueryChunksResponse' + description: OK + tags: + - VectorIO /v1/version: get: parameters: @@ -5748,6 +5776,9 @@ tags: name: DatasetFormat - name: DatasetIO - name: Datasets +- description: + name: DefaultRAGQueryGeneratorConfig - description: name: EfficiencyConfig @@ -5767,12 +5798,6 @@ tags: - description: name: EvaluateRowsRequest -- description: - name: GraphMemoryBank -- description: - name: GraphMemoryBankParams - description: name: GreedySamplingStrategy @@ -5786,6 +5811,9 @@ tags: - name: Inference - description: name: InferenceStep +- description: + name: InsertChunksRequest - description: name: InsertDocumentsRequest @@ -5805,30 +5833,18 @@ tags: name: JobStatus - description: name: JsonType -- description: - name: KeyValueMemoryBank -- description: - name: KeyValueMemoryBankParams -- description: - name: KeywordMemoryBank -- description: - name: KeywordMemoryBankParams - description: name: LLMAsJudgeScoringFnParams +- description: + name: LLMRAGQueryGeneratorConfig - description: name: ListDatasetsResponse - description: name: ListEvalTasksResponse -- description: - name: ListMemoryBanksResponse - description: name: ListModelsResponse @@ -5853,6 +5869,9 @@ tags: - description: name: ListToolsResponse +- description: + name: ListVectorDBsResponse - description: name: LogEventRequest @@ -5861,13 +5880,6 @@ tags: - description: name: LoraFinetuningConfig -- name: Memory -- description: - name: MemoryBank -- description: - name: MemoryBankDocument -- name: MemoryBanks - description: name: MemoryRetrievalStep @@ -5920,17 +5932,20 @@ tags: - description: name: QATFinetuningConfig +- description: + name: QueryChunksRequest +- description: + name: QueryChunksResponse - description: name: QueryCondition - description: name: QueryConditionOp -- description: - name: QueryDocumentsRequest -- description: - name: QueryDocumentsResponse + name: QueryContextRequest - description: name: QuerySpanTreeResponse @@ -5940,6 +5955,15 @@ tags: - description: name: QueryTracesResponse +- description: + name: RAGDocument +- description: + name: RAGQueryConfig +- description: + name: RAGQueryGeneratorConfig +- description: + name: RAGQueryResult - description: name: RegexParserScoringFnParams @@ -5949,9 +5973,6 @@ tags: - description: name: RegisterEvalTaskRequest -- description: - name: RegisterMemoryBankRequest - description: name: RegisterModelRequest @@ -5964,6 +5985,9 @@ tags: - description: name: RegisterToolGroupRequest +- description: + name: RegisterVectorDbRequest - description: name: ResponseFormat - description: @@ -6128,12 +6152,10 @@ tags: name: UnstructuredLogEvent - description: name: UserMessage -- description: - name: VectorMemoryBank -- description: - name: VectorMemoryBankParams +- description: + name: VectorDB +- name: VectorDBs +- name: VectorIO - description: name: VersionInfo - description: @@ -6149,8 +6171,6 @@ x-tagGroups: - EvalTasks - Inference - Inspect - - Memory - - MemoryBanks - Models - PostTraining (Coming Soon) - Safety @@ -6161,6 +6181,8 @@ x-tagGroups: - Telemetry - ToolGroups - ToolRuntime + - VectorDBs + - VectorIO - name: Types tags: - AgentCandidate @@ -6210,19 +6232,19 @@ x-tagGroups: - DataConfig - Dataset - DatasetFormat + - DefaultRAGQueryGeneratorConfig - EfficiencyConfig - EmbeddingsRequest - EmbeddingsResponse - EvalTask - EvaluateResponse - EvaluateRowsRequest - - GraphMemoryBank - - GraphMemoryBankParams - GreedySamplingStrategy - HealthInfo - ImageContentItem - ImageDelta - InferenceStep + - InsertChunksRequest - InsertDocumentsRequest - InterleavedContent - InterleavedContentItem @@ -6230,14 +6252,10 @@ x-tagGroups: - Job - JobStatus - JsonType - - KeyValueMemoryBank - - KeyValueMemoryBankParams - - KeywordMemoryBank - - KeywordMemoryBankParams - LLMAsJudgeScoringFnParams + - LLMRAGQueryGeneratorConfig - ListDatasetsResponse - ListEvalTasksResponse - - ListMemoryBanksResponse - ListModelsResponse - ListPostTrainingJobsResponse - ListProvidersResponse @@ -6246,11 +6264,10 @@ x-tagGroups: - ListShieldsResponse - ListToolGroupsResponse - ListToolsResponse + - ListVectorDBsResponse - LogEventRequest - LogSeverity - LoraFinetuningConfig - - MemoryBank - - MemoryBankDocument - MemoryRetrievalStep - Message - MetricEvent @@ -6269,21 +6286,26 @@ x-tagGroups: - PreferenceOptimizeRequest - ProviderInfo - QATFinetuningConfig + - QueryChunksRequest + - QueryChunksResponse - QueryCondition - QueryConditionOp - - QueryDocumentsRequest - - QueryDocumentsResponse + - QueryContextRequest - QuerySpanTreeResponse - QuerySpansResponse - QueryTracesResponse + - RAGDocument + - RAGQueryConfig + - RAGQueryGeneratorConfig + - RAGQueryResult - RegexParserScoringFnParams - RegisterDatasetRequest - RegisterEvalTaskRequest - - RegisterMemoryBankRequest - RegisterModelRequest - RegisterScoringFunctionRequest - RegisterShieldRequest - RegisterToolGroupRequest + - RegisterVectorDbRequest - ResponseFormat - RouteInfo - RunEvalRequest @@ -6341,7 +6363,6 @@ x-tagGroups: - UnionType - UnstructuredLogEvent - UserMessage - - VectorMemoryBank - - VectorMemoryBankParams + - VectorDB - VersionInfo - ViolationLevel diff --git a/llama_stack/apis/tools/__init__.py b/llama_stack/apis/tools/__init__.py index f747fcdc2..8cd798ebf 100644 --- a/llama_stack/apis/tools/__init__.py +++ b/llama_stack/apis/tools/__init__.py @@ -5,3 +5,4 @@ # the root directory of this source tree. from .tools import * # noqa: F401 F403 +from .rag_tool import * # noqa: F401 F403 diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py new file mode 100644 index 000000000..0247bb384 --- /dev/null +++ b/llama_stack/apis/tools/rag_tool.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + +from llama_models.schema_utils import json_schema_type, register_schema, webmethod +from pydantic import BaseModel, Field +from typing_extensions import Annotated, Protocol, runtime_checkable + +from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol + + +@json_schema_type +class RAGDocument(BaseModel): + document_id: str + content: InterleavedContent | URL + mime_type: str | None = None + metadata: Dict[str, Any] = Field(default_factory=dict) + + +@json_schema_type +class RAGQueryResult(BaseModel): + content: Optional[InterleavedContent] = None + + +@json_schema_type +class RAGQueryGenerator(Enum): + default = "default" + llm = "llm" + custom = "custom" + + +@json_schema_type +class DefaultRAGQueryGeneratorConfig(BaseModel): + type: Literal["default"] = "default" + separator: str = " " + + +@json_schema_type +class LLMRAGQueryGeneratorConfig(BaseModel): + type: Literal["llm"] = "llm" + model: str + template: str + + +RAGQueryGeneratorConfig = register_schema( + Annotated[ + Union[ + DefaultRAGQueryGeneratorConfig, + LLMRAGQueryGeneratorConfig, + ], + Field(discriminator="type"), + ], + name="RAGQueryGeneratorConfig", +) + + +@json_schema_type +class RAGQueryConfig(BaseModel): + # This config defines how a query is generated using the messages + # for memory bank retrieval. + query_generator_config: RAGQueryGeneratorConfig = Field( + default=DefaultRAGQueryGeneratorConfig() + ) + max_tokens_in_context: int = 4096 + max_chunks: int = 5 + + +@runtime_checkable +@trace_protocol +class RAGToolRuntime(Protocol): + @webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST") + async def insert_documents( + self, + documents: List[RAGDocument], + vector_db_id: str, + chunk_size_in_tokens: int = 512, + ) -> None: + """Index documents so they can be used by the RAG system""" + ... + + @webmethod(route="/tool-runtime/rag-tool/query-context", method="POST") + async def query_context( + self, + content: InterleavedContent, + query_config: RAGQueryConfig, + vector_db_ids: List[str], + ) -> RAGQueryResult: + """Query the RAG system for context; typically invoked by the agent""" + ... diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index fb990cc41..1af019bd4 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -15,6 +15,8 @@ from llama_stack.apis.common.content_types import InterleavedContent, URL from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from .rag_tool import RAGToolRuntime + @json_schema_type class ToolParameter(BaseModel): @@ -130,11 +132,17 @@ class ToolGroups(Protocol): ... +class SpecialToolGroup(Enum): + rag_tool = "rag_tool" + + @runtime_checkable @trace_protocol class ToolRuntime(Protocol): tool_store: ToolStore + rag_tool: RAGToolRuntime + # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed. @webmethod(route="/tool-runtime/list-tools", method="GET") async def list_runtime_tools( @@ -143,7 +151,7 @@ class ToolRuntime(Protocol): @webmethod(route="/tool-runtime/invoke", method="POST") async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: """Run a tool with the given arguments""" ... diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index bd5a9ae98..dd6d4be6f 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -333,6 +333,8 @@ async def instantiate_provider( impl.__provider_spec__ = provider_spec impl.__provider_config__ = config + # TODO: check compliance for special tool groups + # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol check_protocol_compliance(impl, protocols[provider_spec.api]) if ( not isinstance(provider_spec, AutoRoutedProviderSpec) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 979c68b72..3ae9833dc 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -36,7 +36,14 @@ from llama_stack.apis.scoring import ( ScoringFnParams, ) from llama_stack.apis.shields import Shield -from llama_stack.apis.tools import ToolDef, ToolRuntime +from llama_stack.apis.tools import ( + RAGDocument, + RAGQueryConfig, + RAGQueryResult, + RAGToolRuntime, + ToolDef, + ToolRuntime, +) from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import RoutingTable @@ -400,22 +407,55 @@ class EvalRouter(Eval): class ToolRuntimeRouter(ToolRuntime): + class RagToolImpl(RAGToolRuntime): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + self.routing_table = routing_table + + async def query_context( + self, + content: InterleavedContent, + query_config: RAGQueryConfig, + vector_db_ids: List[str], + ) -> RAGQueryResult: + return await self.routing_table.get_provider_impl( + "rag_tool.query_context" + ).query_context(content, query_config, vector_db_ids) + + async def insert_documents( + self, + documents: List[RAGDocument], + vector_db_id: str, + chunk_size_in_tokens: int = 512, + ) -> None: + return await self.routing_table.get_provider_impl( + "rag_tool.insert_documents" + ).insert_documents(documents, vector_db_id, chunk_size_in_tokens) + def __init__( self, routing_table: RoutingTable, ) -> None: self.routing_table = routing_table + # HACK ALERT this should be in sync with "get_all_api_endpoints()" + # TODO: make sure rag_tool vs builtin::memory is correct everywhere + self.rag_tool = self.RagToolImpl(routing_table) + setattr(self, "rag_tool.query_context", self.rag_tool.query_context) + setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents) + async def initialize(self) -> None: pass async def shutdown(self) -> None: pass - async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any: + async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any: return await self.routing_table.get_provider_impl(tool_name).invoke_tool( tool_name=tool_name, - args=args, + kwargs=kwargs, ) async def list_runtime_tools( diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py index af429e020..180479e40 100644 --- a/llama_stack/distribution/server/endpoints.py +++ b/llama_stack/distribution/server/endpoints.py @@ -9,6 +9,8 @@ from typing import Dict, List from pydantic import BaseModel +from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup + from llama_stack.apis.version import LLAMA_STACK_API_VERSION from llama_stack.distribution.resolver import api_protocol_map @@ -22,21 +24,39 @@ class ApiEndpoint(BaseModel): name: str +def toolgroup_protocol_map(): + return { + SpecialToolGroup.rag_tool: RAGToolRuntime, + } + + def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]: apis = {} protocols = api_protocol_map() + toolgroup_protocols = toolgroup_protocol_map() for api, protocol in protocols.items(): endpoints = [] protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) + # HACK ALERT + if api == Api.tool_runtime: + for tool_group in SpecialToolGroup: + sub_protocol = toolgroup_protocols[tool_group] + sub_protocol_methods = inspect.getmembers( + sub_protocol, predicate=inspect.isfunction + ) + for name, method in sub_protocol_methods: + if not hasattr(method, "__webmethod__"): + continue + protocol_methods.append((f"{tool_group.value}.{name}", method)) + for name, method in protocol_methods: if not hasattr(method, "__webmethod__"): continue webmethod = method.__webmethod__ route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}" - if webmethod.method == "GET": method = "get" elif webmethod.method == "DELETE": diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 180ec0ecc..f0c34dba4 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -29,7 +29,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration from llama_stack.apis.telemetry import Telemetry -from llama_stack.apis.tools import ToolGroups, ToolRuntime +from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime from llama_stack.apis.vector_dbs import VectorDBs from llama_stack.apis.vector_io import VectorIO from llama_stack.distribution.datatypes import StackRunConfig @@ -62,6 +62,7 @@ class LlamaStack( Inspect, ToolGroups, ToolRuntime, + RAGToolRuntime, ): pass diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 010d137ec..5c0b8b5db 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -35,7 +35,7 @@ class DistributionRegistry(Protocol): REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v5" +KEY_VERSION = "v6" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py index 50f61fb42..de34b8d2c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/__init__.py +++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py @@ -19,9 +19,8 @@ async def get_provider_impl( impl = MetaReferenceAgentsImpl( config, deps[Api.inference], - deps[Api.memory], + deps[Api.vector_io], deps[Api.safety], - deps[Api.memory_banks], deps[Api.tool_runtime], deps[Api.tool_groups], ) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 2ebc7ded1..5b5175cee 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -59,13 +59,18 @@ from llama_stack.apis.inference import ( ToolResponseMessage, UserMessage, ) -from llama_stack.apis.memory import Memory, MemoryBankDocument -from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams from llama_stack.apis.safety import Safety -from llama_stack.apis.tools import ToolGroups, ToolRuntime +from llama_stack.apis.tools import ( + DefaultRAGQueryGeneratorConfig, + RAGDocument, + RAGQueryConfig, + ToolGroups, + ToolRuntime, +) +from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.utils.kvstore import KVStore +from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.telemetry import tracing - from .persistence import AgentPersistence from .safety import SafetyException, ShieldRunnerMixin @@ -79,7 +84,7 @@ def make_random_string(length: int = 8): TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") -MEMORY_QUERY_TOOL = "query_memory" +MEMORY_QUERY_TOOL = "rag_tool.query_context" WEB_SEARCH_TOOL = "web_search" MEMORY_GROUP = "builtin::memory" @@ -91,20 +96,18 @@ class ChatAgent(ShieldRunnerMixin): agent_config: AgentConfig, tempdir: str, inference_api: Inference, - memory_api: Memory, - memory_banks_api: MemoryBanks, safety_api: Safety, tool_runtime_api: ToolRuntime, tool_groups_api: ToolGroups, + vector_io_api: VectorIO, persistence_store: KVStore, ): self.agent_id = agent_id self.agent_config = agent_config self.tempdir = tempdir self.inference_api = inference_api - self.memory_api = memory_api - self.memory_banks_api = memory_banks_api self.safety_api = safety_api + self.vector_io_api = vector_io_api self.storage = AgentPersistence(agent_id, persistence_store) self.tool_runtime_api = tool_runtime_api self.tool_groups_api = tool_groups_api @@ -370,24 +373,30 @@ class ChatAgent(ShieldRunnerMixin): documents: Optional[List[Document]] = None, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, ) -> AsyncGenerator: + # TODO: simplify all of this code, it can be simpler toolgroup_args = {} + toolgroups = set() for toolgroup in self.agent_config.toolgroups: if isinstance(toolgroup, AgentToolGroupWithArgs): + toolgroups.add(toolgroup.name) toolgroup_args[toolgroup.name] = toolgroup.args + else: + toolgroups.add(toolgroup) if toolgroups_for_turn: for toolgroup in toolgroups_for_turn: if isinstance(toolgroup, AgentToolGroupWithArgs): + toolgroups.add(toolgroup.name) toolgroup_args[toolgroup.name] = toolgroup.args + else: + toolgroups.add(toolgroup) tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn) if documents: await self.handle_documents( session_id, documents, input_messages, tool_defs ) - if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0: - memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None) - if memory_tool_group is None: - raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}") + + if MEMORY_GROUP in toolgroups and len(input_messages) > 0: with tracing.span(MEMORY_QUERY_TOOL) as span: step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -398,17 +407,15 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) - query_args = { - "messages": [msg.content for msg in input_messages], - **toolgroup_args.get(memory_tool_group, {}), - } + args = toolgroup_args.get(MEMORY_GROUP, {}) + vector_db_ids = args.get("vector_db_ids", []) session_info = await self.storage.get_session_info(session_id) + # if the session has a memory bank id, let the memory tool use it if session_info.memory_bank_id: - if "memory_bank_ids" not in query_args: - query_args["memory_bank_ids"] = [] - query_args["memory_bank_ids"].append(session_info.memory_bank_id) + vector_db_ids.append(session_info.memory_bank_id) + yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -425,10 +432,18 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) - result = await self.tool_runtime_api.invoke_tool( - tool_name=MEMORY_QUERY_TOOL, - args=query_args, + result = await self.tool_runtime_api.rag_tool.query_context( + content=concat_interleaved_content( + [msg.content for msg in input_messages] + ), + query_config=RAGQueryConfig( + query_generator_config=DefaultRAGQueryGeneratorConfig(), + max_tokens_in_context=4096, + max_chunks=5, + ), + vector_db_ids=vector_db_ids, ) + retrieved_context = result.content yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -449,7 +464,7 @@ class ChatAgent(ShieldRunnerMixin): ToolResponse( call_id="", tool_name=MEMORY_QUERY_TOOL, - content=result.content, + content=retrieved_context or [], ) ], ), @@ -459,13 +474,11 @@ class ChatAgent(ShieldRunnerMixin): span.set_attribute( "input", [m.model_dump_json() for m in input_messages] ) - span.set_attribute("output", result.content) - span.set_attribute("error_code", result.error_code) - span.set_attribute("error_message", result.error_message) + span.set_attribute("output", retrieved_context) span.set_attribute("tool_name", MEMORY_QUERY_TOOL) - if result.error_code == 0: + if retrieved_context: last_message = input_messages[-1] - last_message.context = result.content + last_message.context = retrieved_context output_attachments = [] @@ -842,12 +855,13 @@ class ChatAgent(ShieldRunnerMixin): if session_info.memory_bank_id is None: bank_id = f"memory_bank_{session_id}" - await self.memory_banks_api.register_memory_bank( - memory_bank_id=bank_id, - params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - ), + + # TODO: the semantic for registration is definitely not "creation" + # so we need to fix it if we expect the agent to create a new vector db + # for each session + await self.vector_io_api.register_vector_db( + vector_db_id=bank_id, + embedding_model="all-MiniLM-L6-v2", ) await self.storage.add_memory_bank_to_session(session_id, bank_id) else: @@ -858,9 +872,9 @@ class ChatAgent(ShieldRunnerMixin): async def add_to_session_memory_bank( self, session_id: str, data: List[Document] ) -> None: - bank_id = await self._ensure_memory_bank(session_id) + vector_db_id = await self._ensure_memory_bank(session_id) documents = [ - MemoryBankDocument( + RAGDocument( document_id=str(uuid.uuid4()), content=a.content, mime_type=a.mime_type, @@ -868,9 +882,10 @@ class ChatAgent(ShieldRunnerMixin): ) for a in data ] - await self.memory_api.insert_documents( - bank_id=bank_id, + await self.tool_runtime_api.rag_tool.insert_documents( documents=documents, + vector_db_id=vector_db_id, + chunk_size_in_tokens=512, ) @@ -955,7 +970,7 @@ async def execute_tool_call_maybe( result = await tool_runtime_api.invoke_tool( tool_name=name, - args=dict( + kwargs=dict( session_id=session_id, **tool_call_args, ), diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index d22ef82ab..b1844f4d0 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -26,10 +26,9 @@ from llama_stack.apis.agents import ( Turn, ) from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage -from llama_stack.apis.memory import Memory -from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime +from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from .agent_instance import ChatAgent @@ -44,17 +43,15 @@ class MetaReferenceAgentsImpl(Agents): self, config: MetaReferenceAgentsImplConfig, inference_api: Inference, - memory_api: Memory, + vector_io_api: VectorIO, safety_api: Safety, - memory_banks_api: MemoryBanks, tool_runtime_api: ToolRuntime, tool_groups_api: ToolGroups, ): self.config = config self.inference_api = inference_api - self.memory_api = memory_api + self.vector_io_api = vector_io_api self.safety_api = safety_api - self.memory_banks_api = memory_banks_api self.tool_runtime_api = tool_runtime_api self.tool_groups_api = tool_groups_api @@ -114,8 +111,7 @@ class MetaReferenceAgentsImpl(Agents): tempdir=self.tempdir, inference_api=self.inference_api, safety_api=self.safety_api, - memory_api=self.memory_api, - memory_banks_api=self.memory_banks_api, + vector_io_api=self.vector_io_api, tool_runtime_api=self.tool_runtime_api, tool_groups_api=self.tool_groups_api, persistence_store=( diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index 361c91a92..04434768d 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -60,9 +60,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): ] async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: - script = args["code"] + script = kwargs["code"] req = CodeExecutionRequest(scripts=[script]) res = self.code_executor.execute(req) pieces = [res["process_status"]] diff --git a/llama_stack/providers/inline/tool_runtime/memory/__init__.py b/llama_stack/providers/inline/tool_runtime/memory/__init__.py index 928afa484..42a0a6b01 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/__init__.py +++ b/llama_stack/providers/inline/tool_runtime/memory/__init__.py @@ -13,8 +13,6 @@ from .memory import MemoryToolRuntimeImpl async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]): - impl = MemoryToolRuntimeImpl( - config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference] - ) + impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/tool_runtime/memory/config.py b/llama_stack/providers/inline/tool_runtime/memory/config.py index 6ff242c6b..4a20c986c 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/config.py +++ b/llama_stack/providers/inline/tool_runtime/memory/config.py @@ -4,87 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum -from typing import Annotated, List, Literal, Union - -from pydantic import BaseModel, Field - - -class _MemoryBankConfigCommon(BaseModel): - bank_id: str - - -class VectorMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["vector"] = "vector" - - -class KeyValueMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["keyvalue"] = "keyvalue" - keys: List[str] # what keys to focus on - - -class KeywordMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["keyword"] = "keyword" - - -class GraphMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["graph"] = "graph" - entities: List[str] # what entities to focus on - - -MemoryBankConfig = Annotated[ - Union[ - VectorMemoryBankConfig, - KeyValueMemoryBankConfig, - KeywordMemoryBankConfig, - GraphMemoryBankConfig, - ], - Field(discriminator="type"), -] - - -class MemoryQueryGenerator(Enum): - default = "default" - llm = "llm" - custom = "custom" - - -class DefaultMemoryQueryGeneratorConfig(BaseModel): - type: Literal[MemoryQueryGenerator.default.value] = ( - MemoryQueryGenerator.default.value - ) - sep: str = " " - - -class LLMMemoryQueryGeneratorConfig(BaseModel): - type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value - model: str - template: str - - -class CustomMemoryQueryGeneratorConfig(BaseModel): - type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value - - -MemoryQueryGeneratorConfig = Annotated[ - Union[ - DefaultMemoryQueryGeneratorConfig, - LLMMemoryQueryGeneratorConfig, - CustomMemoryQueryGeneratorConfig, - ], - Field(discriminator="type"), -] - - -class MemoryToolConfig(BaseModel): - memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list) +from pydantic import BaseModel class MemoryToolRuntimeConfig(BaseModel): - # This config defines how a query is generated using the messages - # for memory bank retrieval. - query_generator_config: MemoryQueryGeneratorConfig = Field( - default=DefaultMemoryQueryGeneratorConfig() - ) - max_tokens_in_context: int = 4096 - max_chunks: int = 5 + pass diff --git a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py index 803981f07..e77ec76af 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py +++ b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py @@ -5,68 +5,64 @@ # the root directory of this source tree. -from typing import List - from jinja2 import Template -from pydantic import BaseModel from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.inference import UserMessage + +from llama_stack.apis.tools.rag_tool import ( + DefaultRAGQueryGeneratorConfig, + LLMRAGQueryGeneratorConfig, + RAGQueryGenerator, + RAGQueryGeneratorConfig, +) from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) -from .config import ( - DefaultMemoryQueryGeneratorConfig, - LLMMemoryQueryGeneratorConfig, - MemoryQueryGenerator, - MemoryQueryGeneratorConfig, -) - async def generate_rag_query( - config: MemoryQueryGeneratorConfig, - messages: List[InterleavedContent], + config: RAGQueryGeneratorConfig, + content: InterleavedContent, **kwargs, ): """ Generates a query that will be used for retrieving relevant information from the memory bank. """ - if config.type == MemoryQueryGenerator.default.value: - query = await default_rag_query_generator(config, messages, **kwargs) - elif config.type == MemoryQueryGenerator.llm.value: - query = await llm_rag_query_generator(config, messages, **kwargs) + if config.type == RAGQueryGenerator.default.value: + query = await default_rag_query_generator(config, content, **kwargs) + elif config.type == RAGQueryGenerator.llm.value: + query = await llm_rag_query_generator(config, content, **kwargs) else: raise NotImplementedError(f"Unsupported memory query generator {config.type}") return query async def default_rag_query_generator( - config: DefaultMemoryQueryGeneratorConfig, - messages: List[InterleavedContent], + config: DefaultRAGQueryGeneratorConfig, + content: InterleavedContent, **kwargs, ): - return config.sep.join(interleaved_content_as_str(m) for m in messages) + return interleaved_content_as_str(content, sep=config.separator) async def llm_rag_query_generator( - config: LLMMemoryQueryGeneratorConfig, - messages: List[InterleavedContent], + config: LLMRAGQueryGeneratorConfig, + content: InterleavedContent, **kwargs, ): assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api" inference_api = kwargs["inference_api"] - m_dict = { - "messages": [ - message.model_dump() if isinstance(message, BaseModel) else message - for message in messages - ] - } + messages = [] + if isinstance(content, list): + messages = [interleaved_content_as_str(m) for m in content] + else: + messages = [interleaved_content_as_str(content)] template = Template(config.template) - content = template.render(m_dict) + content = template.render({"messages": messages}) model = config.model message = UserMessage(content=content) diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py index fe6325abb..d3f8b07dc 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -10,20 +10,29 @@ import secrets import string from typing import Any, Dict, List, Optional -from llama_stack.apis.common.content_types import URL -from llama_stack.apis.inference import Inference, InterleavedContent -from llama_stack.apis.memory import Memory, QueryDocumentsResponse -from llama_stack.apis.memory_banks import MemoryBanks +from llama_stack.apis.common.content_types import ( + InterleavedContent, + TextContentItem, + URL, +) +from llama_stack.apis.inference import Inference from llama_stack.apis.tools import ( + RAGDocument, + RAGQueryConfig, + RAGQueryResult, + RAGToolRuntime, ToolDef, ToolInvocationResult, - ToolParameter, ToolRuntime, ) +from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import ToolsProtocolPrivate -from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content +from llama_stack.providers.utils.memory.vector_store import ( + content_from_doc, + make_overlapped_chunks, +) -from .config import MemoryToolConfig, MemoryToolRuntimeConfig +from .config import MemoryToolRuntimeConfig from .context_retriever import generate_rag_query log = logging.getLogger(__name__) @@ -35,65 +44,79 @@ def make_random_string(length: int = 8): ) -class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): +class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): def __init__( self, config: MemoryToolRuntimeConfig, - memory_api: Memory, - memory_banks_api: MemoryBanks, + vector_io_api: VectorIO, inference_api: Inference, ): self.config = config - self.memory_api = memory_api - self.memory_banks_api = memory_banks_api + self.vector_io_api = vector_io_api self.inference_api = inference_api async def initialize(self): pass - async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None - ) -> List[ToolDef]: - return [ - ToolDef( - name="query_memory", - description="Retrieve context from memory", - parameters=[ - ToolParameter( - name="messages", - description="The input messages to search for", - parameter_type="array", - ), - ], - ) - ] + async def shutdown(self): + pass + + async def insert_documents( + self, + documents: List[RAGDocument], + vector_db_id: str, + chunk_size_in_tokens: int = 512, + ) -> None: + chunks = [] + for doc in documents: + content = await content_from_doc(doc) + chunks.extend( + make_overlapped_chunks( + doc.document_id, + content, + chunk_size_in_tokens, + chunk_size_in_tokens // 4, + ) + ) + + if not chunks: + return + + await self.vector_io_api.insert_chunks( + chunks=chunks, + vector_db_id=vector_db_id, + ) + + async def query_context( + self, + content: InterleavedContent, + query_config: RAGQueryConfig, + vector_db_ids: List[str], + ) -> RAGQueryResult: + if not vector_db_ids: + return RAGQueryResult(content=None) - async def _retrieve_context( - self, input_messages: List[InterleavedContent], bank_ids: List[str] - ) -> Optional[List[InterleavedContent]]: - if not bank_ids: - return None query = await generate_rag_query( - self.config.query_generator_config, - input_messages, + query_config.query_generator_config, + content, inference_api=self.inference_api, ) tasks = [ - self.memory_api.query_documents( - bank_id=bank_id, + self.vector_io_api.query_chunks( + vector_db_id=vector_db_id, query=query, params={ - "max_chunks": self.config.max_chunks, + "max_chunks": query_config.max_chunks, }, ) - for bank_id in bank_ids + for vector_db_id in vector_db_ids ] - results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks) + results: List[QueryChunksResponse] = await asyncio.gather(*tasks) chunks = [c for r in results for c in r.chunks] scores = [s for r in results for s in r.scores] if not chunks: - return None + return RAGQueryResult(content=None) # sort by score chunks, scores = zip( @@ -102,45 +125,52 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): tokens = 0 picked = [] - for c in chunks[: self.config.max_chunks]: - tokens += c.token_count - if tokens > self.config.max_tokens_in_context: + for c in chunks[: query_config.max_chunks]: + metadata = c.metadata + tokens += metadata["token_count"] + if tokens > query_config.max_tokens_in_context: log.error( f"Using {len(picked)} chunks; reached max tokens in context: {tokens}", ) break - picked.append(f"id:{c.document_id}; content:{c.content}") + picked.append( + TextContentItem( + text=f"id:{metadata['document_id']}; content:{c.content}", + ) + ) + return RAGQueryResult( + content=[ + TextContentItem( + text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n", + ), + *picked, + TextContentItem( + text="\n=== END-RETRIEVED-CONTEXT ===\n", + ), + ], + ) + + async def list_runtime_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + # Parameters are not listed since these methods are not yet invoked automatically + # by the LLM. The method is only implemented so things like /tools can list without + # encountering fatals. return [ - "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n", - *picked, - "\n=== END-RETRIEVED-CONTEXT ===\n", + ToolDef( + name="rag_tool.query_context", + description="Retrieve context from memory", + ), + ToolDef( + name="rag_tool.insert_documents", + description="Insert documents into memory", + ), ] async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: - tool = await self.tool_store.get_tool(tool_name) - tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id) - final_args = tool_group.args or {} - final_args.update(args) - config = MemoryToolConfig() - if tool.metadata and tool.metadata.get("config") is not None: - config = MemoryToolConfig(**tool.metadata["config"]) - if "memory_bank_ids" in final_args: - bank_ids = final_args["memory_bank_ids"] - else: - bank_ids = [ - bank_config.bank_id for bank_config in config.memory_bank_configs - ] - if "messages" not in final_args: - raise ValueError("messages are required") - context = await self._retrieve_context( - final_args["messages"], - bank_ids, - ) - if context is None: - context = [] - return ToolInvocationResult( - content=concat_interleaved_content(context), error_code=0 + raise RuntimeError( + "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol" ) diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index b3ea68949..426fe22f2 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -23,7 +23,7 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[], module="llama_stack.providers.inline.tool_runtime.memory", config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig", - api_dependencies=[Api.vector_io, Api.vector_dbs, Api.inference], + api_dependencies=[Api.vector_io, Api.inference], ), InlineProviderSpec( api=Api.tool_runtime, diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py index 5114e06aa..677e29c12 100644 --- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -68,7 +68,7 @@ class BingSearchToolRuntimeImpl( ] async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: api_key = self._get_api_key() headers = { @@ -78,7 +78,7 @@ class BingSearchToolRuntimeImpl( "count": self.config.top_k, "textDecorations": True, "textFormat": "HTML", - "q": args["query"], + "q": kwargs["query"], } response = requests.get( diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 016f746ea..1162cc900 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -68,7 +68,7 @@ class BraveSearchToolRuntimeImpl( ] async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: api_key = self._get_api_key() url = "https://api.search.brave.com/res/v1/web/search" @@ -77,7 +77,7 @@ class BraveSearchToolRuntimeImpl( "Accept-Encoding": "gzip", "Accept": "application/json", } - payload = {"q": args["query"]} + payload = {"q": kwargs["query"]} response = requests.get(url=url, params=payload, headers=headers) response.raise_for_status() results = self._clean_brave_response(response.json()) diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index a304167e9..e0caec1d0 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -65,7 +65,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): return tools async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: tool = await self.tool_store.get_tool(tool_name) if tool.metadata is None or tool.metadata.get("endpoint") is None: @@ -77,7 +77,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): async with sse_client(endpoint) as streams: async with ClientSession(*streams) as session: await session.initialize() - result = await session.call_tool(tool.identifier, args) + result = await session.call_tool(tool.identifier, kwargs) return ToolInvocationResult( content="\n".join([result.model_dump_json() for result in result.content]), diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 82077193e..f5826c0ff 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -67,12 +67,12 @@ class TavilySearchToolRuntimeImpl( ] async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: api_key = self._get_api_key() response = requests.post( "https://api.tavily.com/search", - json={"api_key": api_key, "query": args["query"]}, + json={"api_key": api_key, "query": kwargs["query"]}, ) return ToolInvocationResult( diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index 04ecfcc15..bf298c13e 100644 --- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -68,11 +68,11 @@ class WolframAlphaToolRuntimeImpl( ] async def invoke_tool( - self, tool_name: str, args: Dict[str, Any] + self, tool_name: str, kwargs: Dict[str, Any] ) -> ToolInvocationResult: api_key = self._get_api_key() params = { - "input": args["query"], + "input": kwargs["query"], "appid": api_key, "format": "plaintext", "output": "json", diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index 4efdfe8b7..9c115e3a1 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -12,10 +12,10 @@ from ..conftest import ( get_test_config_for_api, ) from ..inference.fixtures import INFERENCE_FIXTURES -from ..memory.fixtures import MEMORY_FIXTURES from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield from ..tools.fixtures import TOOL_RUNTIME_FIXTURES +from ..vector_io.fixtures import VECTOR_IO_FIXTURES from .fixtures import AGENTS_FIXTURES DEFAULT_PROVIDER_COMBINATIONS = [ @@ -23,7 +23,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ { "inference": "meta_reference", "safety": "llama_guard", - "memory": "faiss", + "vector_io": "faiss", "agents": "meta_reference", "tool_runtime": "memory_and_search", }, @@ -34,7 +34,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ { "inference": "ollama", "safety": "llama_guard", - "memory": "faiss", + "vector_io": "faiss", "agents": "meta_reference", "tool_runtime": "memory_and_search", }, @@ -46,7 +46,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "inference": "together", "safety": "llama_guard", # make this work with Weaviate which is what the together distro supports - "memory": "faiss", + "vector_io": "faiss", "agents": "meta_reference", "tool_runtime": "memory_and_search", }, @@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ { "inference": "fireworks", "safety": "llama_guard", - "memory": "faiss", + "vector_io": "faiss", "agents": "meta_reference", "tool_runtime": "memory_and_search", }, @@ -68,7 +68,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ { "inference": "remote", "safety": "remote", - "memory": "remote", + "vector_io": "remote", "agents": "remote", "tool_runtime": "memory_and_search", }, @@ -115,7 +115,7 @@ def pytest_generate_tests(metafunc): available_fixtures = { "inference": INFERENCE_FIXTURES, "safety": SAFETY_FIXTURES, - "memory": MEMORY_FIXTURES, + "vector_io": VECTOR_IO_FIXTURES, "agents": AGENTS_FIXTURES, "tool_runtime": TOOL_RUNTIME_FIXTURES, } diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 1b1781f36..bb4a6e6a3 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -69,7 +69,7 @@ async def agents_stack( providers = {} provider_data = {} - for key in ["inference", "safety", "memory", "agents", "tool_runtime"]: + for key in ["inference", "safety", "vector_io", "agents", "tool_runtime"]: fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") providers[key] = fixture.providers if key == "inference": @@ -118,7 +118,7 @@ async def agents_stack( ) test_stack = await construct_stack_for_test( - [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime], + [Api.agents, Api.inference, Api.safety, Api.vector_io, Api.tool_runtime], providers, provider_data, models=models, diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 320096826..f11aef3ec 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -214,9 +214,11 @@ class TestAgents: turn_response = [ chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] - assert len(turn_response) > 0 + # FIXME: we need to check the content of the turn response and ensure + # RAG actually worked + @pytest.mark.asyncio async def test_create_agent_turn_with_tavily_search( self, agents_stack, search_query_messages, common_params diff --git a/llama_stack/providers/tests/vector_io/test_vector_io.py b/llama_stack/providers/tests/vector_io/test_vector_io.py index 901b8bd11..521131f63 100644 --- a/llama_stack/providers/tests/vector_io/test_vector_io.py +++ b/llama_stack/providers/tests/vector_io/test_vector_io.py @@ -8,13 +8,12 @@ import uuid import pytest +from llama_stack.apis.tools import RAGDocument + from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB from llama_stack.apis.vector_io import QueryChunksResponse -from llama_stack.providers.utils.memory.vector_store import ( - make_overlapped_chunks, - MemoryBankDocument, -) +from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks # How to run this test: # @@ -26,22 +25,22 @@ from llama_stack.providers.utils.memory.vector_store import ( @pytest.fixture(scope="session") def sample_chunks(): docs = [ - MemoryBankDocument( + RAGDocument( document_id="doc1", content="Python is a high-level programming language.", metadata={"category": "programming", "difficulty": "beginner"}, ), - MemoryBankDocument( + RAGDocument( document_id="doc2", content="Machine learning is a subset of artificial intelligence.", metadata={"category": "AI", "difficulty": "advanced"}, ), - MemoryBankDocument( + RAGDocument( document_id="doc3", content="Data structures are fundamental to computer science.", metadata={"category": "computer science", "difficulty": "intermediate"}, ), - MemoryBankDocument( + RAGDocument( document_id="doc4", content="Neural networks are inspired by biological neural networks.", metadata={"category": "AI", "difficulty": "advanced"}, diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index c2de6c714..82c0c9c07 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -19,7 +19,6 @@ import numpy as np from llama_models.llama3.api.tokenizer import Tokenizer from numpy.typing import NDArray -from pydantic import BaseModel, Field from pypdf import PdfReader from llama_stack.apis.common.content_types import ( @@ -27,6 +26,7 @@ from llama_stack.apis.common.content_types import ( TextContentItem, URL, ) +from llama_stack.apis.tools import RAGDocument from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse from llama_stack.providers.datatypes import Api @@ -34,17 +34,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) - log = logging.getLogger(__name__) -class MemoryBankDocument(BaseModel): - document_id: str - content: InterleavedContent | URL - mime_type: str | None = None - metadata: Dict[str, Any] = Field(default_factory=dict) - - def parse_pdf(data: bytes) -> str: # For PDF and DOC/DOCX files, we can't reliably convert to string pdf_bytes = io.BytesIO(data) @@ -122,7 +114,7 @@ def concat_interleaved_content(content: List[InterleavedContent]) -> Interleaved return ret -async def content_from_doc(doc: MemoryBankDocument) -> str: +async def content_from_doc(doc: RAGDocument) -> str: if isinstance(doc.content, URL): if doc.content.uri.startswith("data:"): return content_from_data(doc.content.uri) @@ -161,7 +153,13 @@ def make_overlapped_chunks( chunk = tokenizer.decode(toks) # chunk is a string chunks.append( - Chunk(content=chunk, token_count=len(toks), document_id=document_id) + Chunk( + content=chunk, + metadata={ + "token_count": len(toks), + "document_id": document_id, + }, + ) ) return chunks diff --git a/llama_stack/scripts/test_rag_via_curl.py b/llama_stack/scripts/test_rag_via_curl.py new file mode 100644 index 000000000..28d6fb601 --- /dev/null +++ b/llama_stack/scripts/test_rag_via_curl.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import List + +import pytest +import requests +from pydantic import TypeAdapter + +from llama_stack.apis.tools import ( + DefaultRAGQueryGeneratorConfig, + RAGDocument, + RAGQueryConfig, + RAGQueryResult, +) +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.providers.utils.memory.vector_store import interleaved_content_as_str + + +class TestRAGToolEndpoints: + @pytest.fixture + def base_url(self) -> str: + return "http://localhost:8321/v1" # Adjust port if needed + + @pytest.fixture + def sample_documents(self) -> List[RAGDocument]: + return [ + RAGDocument( + document_id="doc1", + content="Python is a high-level programming language.", + metadata={"category": "programming", "difficulty": "beginner"}, + ), + RAGDocument( + document_id="doc2", + content="Machine learning is a subset of artificial intelligence.", + metadata={"category": "AI", "difficulty": "advanced"}, + ), + RAGDocument( + document_id="doc3", + content="Data structures are fundamental to computer science.", + metadata={"category": "computer science", "difficulty": "intermediate"}, + ), + ] + + @pytest.mark.asyncio + async def test_rag_workflow( + self, base_url: str, sample_documents: List[RAGDocument] + ): + vector_db_payload = { + "vector_db_id": "test_vector_db", + "embedding_model": "all-MiniLM-L6-v2", + "embedding_dimension": 384, + } + + response = requests.post(f"{base_url}/vector-dbs", json=vector_db_payload) + assert response.status_code == 200 + vector_db = VectorDB(**response.json()) + + insert_payload = { + "documents": [ + json.loads(doc.model_dump_json()) for doc in sample_documents + ], + "vector_db_id": vector_db.identifier, + "chunk_size_in_tokens": 512, + } + + response = requests.post( + f"{base_url}/tool-runtime/rag-tool/insert-documents", + json=insert_payload, + ) + assert response.status_code == 200 + + query = "What is Python?" + query_config = RAGQueryConfig( + query_generator_config=DefaultRAGQueryGeneratorConfig(), + max_tokens_in_context=4096, + max_chunks=2, + ) + + query_payload = { + "content": query, + "query_config": json.loads(query_config.model_dump_json()), + "vector_db_ids": [vector_db.identifier], + } + + response = requests.post( + f"{base_url}/tool-runtime/rag-tool/query-context", + json=query_payload, + ) + assert response.status_code == 200 + result = response.json() + result = TypeAdapter(RAGQueryResult).validate_python(result) + + content_str = interleaved_content_as_str(result.content) + print(f"content: {content_str}") + assert len(content_str) > 0 + assert "Python" in content_str + + # Clean up: Delete the vector DB + response = requests.delete(f"{base_url}/vector-dbs/{vector_db.identifier}") + assert response.status_code == 200 diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml index ea7387a24..2160adb8e 100644 --- a/llama_stack/templates/together/build.yaml +++ b/llama_stack/templates/together/build.yaml @@ -4,7 +4,7 @@ distribution_spec: providers: inference: - remote::together - memory: + vector_io: - inline::faiss - remote::chromadb - remote::pgvector diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml index da25fd144..135b124e4 100644 --- a/llama_stack/templates/together/run.yaml +++ b/llama_stack/templates/together/run.yaml @@ -5,7 +5,7 @@ apis: - datasetio - eval - inference -- memory +- vector_io - safety - scoring - telemetry @@ -20,7 +20,7 @@ providers: - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} - memory: + vector_io: - provider_id: faiss provider_type: inline::faiss config: @@ -145,7 +145,6 @@ models: model_type: embedding shields: - shield_id: meta-llama/Llama-Guard-3-8B -memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: []