[memory refactor][6/n] Update naming and routes (#839)

Making a few small naming changes as per feedback:

- RAGToolRuntime methods are renamed from `insert_documents` and
`query_context` to `insert` and `query` to keep them more general
- The tool names are changed from the namespaced forms
`rag_tool.insert_documents` and `rag_tool.query_context` to the
non-namespaced forms `insert_into_memory` and `query_from_memory`
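- The REST endpoints are more RESTful: `/v1/tool-runtime/rag-tool/insert-documents` becomes `/v1/tool-runtime/rag-tool/insert` and `/v1/tool-runtime/rag-tool/query-context` becomes `/v1/tool-runtime/rag-tool/query`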
This commit is contained in:
Ashwin Bharambe 2025-01-22 10:39:13 -08:00 committed by GitHub
parent c9e5578151
commit a63a43c646
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 319 additions and 330 deletions

View file

@ -1887,6 +1887,49 @@
] ]
} }
}, },
"/v1/tool-runtime/rag-tool/insert": {
"post": {
"responses": {
"200": {
"description": "OK"
}
},
"tags": [
"ToolRuntime"
],
"summary": "Index documents so they can be used by the RAG system",
"parameters": [
{
"name": "X-LlamaStack-Provider-Data",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
},
{
"name": "X-LlamaStack-Client-Version",
"in": "header",
"description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
"required": false,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/InsertRequest"
}
}
},
"required": true
}
}
},
"/v1/vector-io/insert": { "/v1/vector-io/insert": {
"post": { "post": {
"responses": { "responses": {
@ -1929,49 +1972,6 @@
} }
} }
}, },
"/v1/tool-runtime/rag-tool/insert-documents": {
"post": {
"responses": {
"200": {
"description": "OK"
}
},
"tags": [
"ToolRuntime"
],
"summary": "Index documents so they can be used by the RAG system",
"parameters": [
{
"name": "X-LlamaStack-Provider-Data",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
},
{
"name": "X-LlamaStack-Client-Version",
"in": "header",
"description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
"required": false,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/InsertDocumentsRequest"
}
}
},
"required": true
}
}
},
"/v1/tool-runtime/invoke": { "/v1/tool-runtime/invoke": {
"post": { "post": {
"responses": { "responses": {
@ -3033,6 +3033,56 @@
} }
} }
}, },
"/v1/tool-runtime/rag-tool/query": {
"post": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RAGQueryResult"
}
}
}
}
},
"tags": [
"ToolRuntime"
],
"summary": "Query the RAG system for context; typically invoked by the agent",
"parameters": [
{
"name": "X-LlamaStack-Provider-Data",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
},
{
"name": "X-LlamaStack-Client-Version",
"in": "header",
"description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
"required": false,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/QueryRequest"
}
}
},
"required": true
}
}
},
"/v1/vector-io/query": { "/v1/vector-io/query": {
"post": { "post": {
"responses": { "responses": {
@ -3082,56 +3132,6 @@
} }
} }
}, },
"/v1/tool-runtime/rag-tool/query-context": {
"post": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RAGQueryResult"
}
}
}
}
},
"tags": [
"ToolRuntime"
],
"summary": "Query the RAG system for context; typically invoked by the agent",
"parameters": [
{
"name": "X-LlamaStack-Provider-Data",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
},
{
"name": "X-LlamaStack-Client-Version",
"in": "header",
"description": "Version of the client making the request. This is used to ensure that the client and server are compatible.",
"required": false,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/QueryContextRequest"
}
}
},
"required": true
}
}
},
"/v1/telemetry/spans": { "/v1/telemetry/spans": {
"get": { "get": {
"responses": { "responses": {
@ -5256,11 +5256,8 @@
"const": "memory_retrieval", "const": "memory_retrieval",
"default": "memory_retrieval" "default": "memory_retrieval"
}, },
"memory_bank_ids": { "vector_db_ids": {
"type": "array",
"items": {
"type": "string" "type": "string"
}
}, },
"inserted_context": { "inserted_context": {
"$ref": "#/components/schemas/InterleavedContent" "$ref": "#/components/schemas/InterleavedContent"
@ -5271,7 +5268,7 @@
"turn_id", "turn_id",
"step_id", "step_id",
"step_type", "step_type",
"memory_bank_ids", "vector_db_ids",
"inserted_context" "inserted_context"
] ]
}, },
@ -6976,63 +6973,6 @@
"status" "status"
] ]
}, },
"InsertChunksRequest": {
"type": "object",
"properties": {
"vector_db_id": {
"type": "string"
},
"chunks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"content",
"metadata"
]
}
},
"ttl_seconds": {
"type": "integer"
}
},
"additionalProperties": false,
"required": [
"vector_db_id",
"chunks"
]
},
"RAGDocument": { "RAGDocument": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -7094,7 +7034,7 @@
"metadata" "metadata"
] ]
}, },
"InsertDocumentsRequest": { "InsertRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
"documents": { "documents": {
@ -7117,6 +7057,63 @@
"chunk_size_in_tokens" "chunk_size_in_tokens"
] ]
}, },
"InsertChunksRequest": {
"type": "object",
"properties": {
"vector_db_id": {
"type": "string"
},
"chunks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"content",
"metadata"
]
}
},
"ttl_seconds": {
"type": "integer"
}
},
"additionalProperties": false,
"required": [
"vector_db_id",
"chunks"
]
},
"InvokeToolRequest": { "InvokeToolRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -7883,6 +7880,110 @@
"job_uuid" "job_uuid"
] ]
}, },
"DefaultRAGQueryGeneratorConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "default",
"default": "default"
},
"separator": {
"type": "string",
"default": " "
}
},
"additionalProperties": false,
"required": [
"type",
"separator"
]
},
"LLMRAGQueryGeneratorConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "llm",
"default": "llm"
},
"model": {
"type": "string"
},
"template": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"type",
"model",
"template"
]
},
"RAGQueryConfig": {
"type": "object",
"properties": {
"query_generator_config": {
"$ref": "#/components/schemas/RAGQueryGeneratorConfig"
},
"max_tokens_in_context": {
"type": "integer",
"default": 4096
},
"max_chunks": {
"type": "integer",
"default": 5
}
},
"additionalProperties": false,
"required": [
"query_generator_config",
"max_tokens_in_context",
"max_chunks"
]
},
"RAGQueryGeneratorConfig": {
"oneOf": [
{
"$ref": "#/components/schemas/DefaultRAGQueryGeneratorConfig"
},
{
"$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig"
}
]
},
"QueryRequest": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent"
},
"vector_db_ids": {
"type": "array",
"items": {
"type": "string"
}
},
"query_config": {
"$ref": "#/components/schemas/RAGQueryConfig"
}
},
"additionalProperties": false,
"required": [
"content",
"vector_db_ids"
]
},
"RAGQueryResult": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent"
}
},
"additionalProperties": false
},
"QueryChunksRequest": { "QueryChunksRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -7981,111 +8082,6 @@
"scores" "scores"
] ]
}, },
"DefaultRAGQueryGeneratorConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "default",
"default": "default"
},
"separator": {
"type": "string",
"default": " "
}
},
"additionalProperties": false,
"required": [
"type",
"separator"
]
},
"LLMRAGQueryGeneratorConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "llm",
"default": "llm"
},
"model": {
"type": "string"
},
"template": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"type",
"model",
"template"
]
},
"RAGQueryConfig": {
"type": "object",
"properties": {
"query_generator_config": {
"$ref": "#/components/schemas/RAGQueryGeneratorConfig"
},
"max_tokens_in_context": {
"type": "integer",
"default": 4096
},
"max_chunks": {
"type": "integer",
"default": 5
}
},
"additionalProperties": false,
"required": [
"query_generator_config",
"max_tokens_in_context",
"max_chunks"
]
},
"RAGQueryGeneratorConfig": {
"oneOf": [
{
"$ref": "#/components/schemas/DefaultRAGQueryGeneratorConfig"
},
{
"$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig"
}
]
},
"QueryContextRequest": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent"
},
"query_config": {
"$ref": "#/components/schemas/RAGQueryConfig"
},
"vector_db_ids": {
"type": "array",
"items": {
"type": "string"
}
}
},
"additionalProperties": false,
"required": [
"content",
"query_config",
"vector_db_ids"
]
},
"RAGQueryResult": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent"
}
},
"additionalProperties": false
},
"QueryCondition": { "QueryCondition": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9246,8 +9242,8 @@
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/InsertChunksRequest\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/InsertChunksRequest\" />"
}, },
{ {
"name": "InsertDocumentsRequest", "name": "InsertRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/InsertDocumentsRequest\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/InsertRequest\" />"
}, },
{ {
"name": "Inspect" "name": "Inspect"
@ -9435,8 +9431,8 @@
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/QueryConditionOp\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/QueryConditionOp\" />"
}, },
{ {
"name": "QueryContextRequest", "name": "QueryRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/QueryContextRequest\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/QueryRequest\" />"
}, },
{ {
"name": "QuerySpanTreeResponse", "name": "QuerySpanTreeResponse",
@ -9858,7 +9854,7 @@
"ImageDelta", "ImageDelta",
"InferenceStep", "InferenceStep",
"InsertChunksRequest", "InsertChunksRequest",
"InsertDocumentsRequest", "InsertRequest",
"InterleavedContent", "InterleavedContent",
"InterleavedContentItem", "InterleavedContentItem",
"InvokeToolRequest", "InvokeToolRequest",
@ -9903,7 +9899,7 @@
"QueryChunksResponse", "QueryChunksResponse",
"QueryCondition", "QueryCondition",
"QueryConditionOp", "QueryConditionOp",
"QueryContextRequest", "QueryRequest",
"QuerySpanTreeResponse", "QuerySpanTreeResponse",
"QuerySpansResponse", "QuerySpansResponse",
"QueryTracesResponse", "QueryTracesResponse",

View file

@ -1009,7 +1009,7 @@ components:
- vector_db_id - vector_db_id
- chunks - chunks
type: object type: object
InsertDocumentsRequest: InsertRequest:
additionalProperties: false additionalProperties: false
properties: properties:
chunk_size_in_tokens: chunk_size_in_tokens:
@ -1299,10 +1299,6 @@ components:
type: string type: string
inserted_context: inserted_context:
$ref: '#/components/schemas/InterleavedContent' $ref: '#/components/schemas/InterleavedContent'
memory_bank_ids:
items:
type: string
type: array
started_at: started_at:
format: date-time format: date-time
type: string type: string
@ -1314,11 +1310,13 @@ components:
type: string type: string
turn_id: turn_id:
type: string type: string
vector_db_ids:
type: string
required: required:
- turn_id - turn_id
- step_id - step_id
- step_type - step_type
- memory_bank_ids - vector_db_ids
- inserted_context - inserted_context
type: object type: object
Message: Message:
@ -1710,7 +1708,7 @@ components:
- gt - gt
- lt - lt
type: string type: string
QueryContextRequest: QueryRequest:
additionalProperties: false additionalProperties: false
properties: properties:
content: content:
@ -1723,7 +1721,6 @@ components:
type: array type: array
required: required:
- content - content
- query_config
- vector_db_ids - vector_db_ids
type: object type: object
QuerySpanTreeResponse: QuerySpanTreeResponse:
@ -5176,7 +5173,7 @@ paths:
description: OK description: OK
tags: tags:
- ToolRuntime - ToolRuntime
/v1/tool-runtime/rag-tool/insert-documents: /v1/tool-runtime/rag-tool/insert:
post: post:
parameters: parameters:
- description: JSON-encoded provider data which will be made available to the - description: JSON-encoded provider data which will be made available to the
@ -5197,7 +5194,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/InsertDocumentsRequest' $ref: '#/components/schemas/InsertRequest'
required: true required: true
responses: responses:
'200': '200':
@ -5205,7 +5202,7 @@ paths:
summary: Index documents so they can be used by the RAG system summary: Index documents so they can be used by the RAG system
tags: tags:
- ToolRuntime - ToolRuntime
/v1/tool-runtime/rag-tool/query-context: /v1/tool-runtime/rag-tool/query:
post: post:
parameters: parameters:
- description: JSON-encoded provider data which will be made available to the - description: JSON-encoded provider data which will be made available to the
@ -5226,7 +5223,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/QueryContextRequest' $ref: '#/components/schemas/QueryRequest'
required: true required: true
responses: responses:
'200': '200':
@ -5814,9 +5811,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/InsertChunksRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/InsertChunksRequest"
/> />
name: InsertChunksRequest name: InsertChunksRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/InsertDocumentsRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/InsertRequest" />
/> name: InsertRequest
name: InsertDocumentsRequest
- name: Inspect - name: Inspect
- description: <SchemaDefinition schemaRef="#/components/schemas/InterleavedContent" - description: <SchemaDefinition schemaRef="#/components/schemas/InterleavedContent"
/> />
@ -5943,9 +5939,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/QueryConditionOp" - description: <SchemaDefinition schemaRef="#/components/schemas/QueryConditionOp"
/> />
name: QueryConditionOp name: QueryConditionOp
- description: <SchemaDefinition schemaRef="#/components/schemas/QueryContextRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/QueryRequest" />
/> name: QueryRequest
name: QueryContextRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/QuerySpanTreeResponse" - description: <SchemaDefinition schemaRef="#/components/schemas/QuerySpanTreeResponse"
/> />
name: QuerySpanTreeResponse name: QuerySpanTreeResponse
@ -6245,7 +6240,7 @@ x-tagGroups:
- ImageDelta - ImageDelta
- InferenceStep - InferenceStep
- InsertChunksRequest - InsertChunksRequest
- InsertDocumentsRequest - InsertRequest
- InterleavedContent - InterleavedContent
- InterleavedContentItem - InterleavedContentItem
- InvokeToolRequest - InvokeToolRequest
@ -6290,7 +6285,7 @@ x-tagGroups:
- QueryChunksResponse - QueryChunksResponse
- QueryCondition - QueryCondition
- QueryConditionOp - QueryConditionOp
- QueryContextRequest - QueryRequest
- QuerySpanTreeResponse - QuerySpanTreeResponse
- QuerySpansResponse - QuerySpansResponse
- QueryTracesResponse - QueryTracesResponse

View file

@ -74,8 +74,8 @@ class RAGQueryConfig(BaseModel):
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class RAGToolRuntime(Protocol): class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST") @webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
async def insert_documents( async def insert(
self, self,
documents: List[RAGDocument], documents: List[RAGDocument],
vector_db_id: str, vector_db_id: str,
@ -84,12 +84,12 @@ class RAGToolRuntime(Protocol):
"""Index documents so they can be used by the RAG system""" """Index documents so they can be used by the RAG system"""
... ...
@webmethod(route="/tool-runtime/rag-tool/query-context", method="POST") @webmethod(route="/tool-runtime/rag-tool/query", method="POST")
async def query_context( async def query(
self, self,
content: InterleavedContent, content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str], vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult: ) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent""" """Query the RAG system for context; typically invoked by the agent"""
... ...

View file

@ -38,7 +38,7 @@ class VectorDBStore(Protocol):
class VectorIO(Protocol): class VectorIO(Protocol):
vector_db_store: VectorDBStore vector_db_store: VectorDBStore
# this will just block now until documents are inserted, but it should # this will just block now until chunks are inserted, but it should
# probably return a Job instance which can be polled for completion # probably return a Job instance which can be polled for completion
@webmethod(route="/vector-io/insert", method="POST") @webmethod(route="/vector-io/insert", method="POST")
async def insert_chunks( async def insert_chunks(

View file

@ -414,25 +414,25 @@ class ToolRuntimeRouter(ToolRuntime):
) -> None: ) -> None:
self.routing_table = routing_table self.routing_table = routing_table
async def query_context( async def query(
self, self,
content: InterleavedContent, content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str], vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult: ) -> RAGQueryResult:
return await self.routing_table.get_provider_impl( return await self.routing_table.get_provider_impl(
"rag_tool.query_context" "query_from_memory"
).query_context(content, query_config, vector_db_ids) ).query(content, vector_db_ids, query_config)
async def insert_documents( async def insert(
self, self,
documents: List[RAGDocument], documents: List[RAGDocument],
vector_db_id: str, vector_db_id: str,
chunk_size_in_tokens: int = 512, chunk_size_in_tokens: int = 512,
) -> None: ) -> None:
return await self.routing_table.get_provider_impl( return await self.routing_table.get_provider_impl(
"rag_tool.insert_documents" "insert_into_memory"
).insert_documents(documents, vector_db_id, chunk_size_in_tokens) ).insert(documents, vector_db_id, chunk_size_in_tokens)
def __init__( def __init__(
self, self,
@ -441,10 +441,9 @@ class ToolRuntimeRouter(ToolRuntime):
self.routing_table = routing_table self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()" # HACK ALERT this should be in sync with "get_all_api_endpoints()"
# TODO: make sure rag_tool vs builtin::memory is correct everywhere
self.rag_tool = self.RagToolImpl(routing_table) self.rag_tool = self.RagToolImpl(routing_table)
setattr(self, "rag_tool.query_context", self.rag_tool.query_context) for method in ("query", "insert"):
setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents) setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass

View file

@ -84,7 +84,7 @@ def make_random_string(length: int = 8):
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "rag_tool.query_context" MEMORY_QUERY_TOOL = "query_from_memory"
WEB_SEARCH_TOOL = "web_search" WEB_SEARCH_TOOL = "web_search"
MEMORY_GROUP = "builtin::memory" MEMORY_GROUP = "builtin::memory"
@ -432,16 +432,16 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
) )
result = await self.tool_runtime_api.rag_tool.query_context( result = await self.tool_runtime_api.rag_tool.query(
content=concat_interleaved_content( content=concat_interleaved_content(
[msg.content for msg in input_messages] [msg.content for msg in input_messages]
), ),
vector_db_ids=vector_db_ids,
query_config=RAGQueryConfig( query_config=RAGQueryConfig(
query_generator_config=DefaultRAGQueryGeneratorConfig(), query_generator_config=DefaultRAGQueryGeneratorConfig(),
max_tokens_in_context=4096, max_tokens_in_context=4096,
max_chunks=5, max_chunks=5,
), ),
vector_db_ids=vector_db_ids,
) )
retrieved_context = result.content retrieved_context = result.content
@ -882,7 +882,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
for a in data for a in data
] ]
await self.tool_runtime_api.rag_tool.insert_documents( await self.tool_runtime_api.rag_tool.insert(
documents=documents, documents=documents,
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
chunk_size_in_tokens=512, chunk_size_in_tokens=512,

View file

@ -61,7 +61,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def shutdown(self): async def shutdown(self):
pass pass
async def insert_documents( async def insert(
self, self,
documents: List[RAGDocument], documents: List[RAGDocument],
vector_db_id: str, vector_db_id: str,
@ -87,15 +87,16 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
) )
async def query_context( async def query(
self, self,
content: InterleavedContent, content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str], vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult: ) -> RAGQueryResult:
if not vector_db_ids: if not vector_db_ids:
return RAGQueryResult(content=None) return RAGQueryResult(content=None)
query_config = query_config or RAGQueryConfig()
query = await generate_rag_query( query = await generate_rag_query(
query_config.query_generator_config, query_config.query_generator_config,
content, content,
@ -159,11 +160,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
# encountering fatals. # encountering fatals.
return [ return [
ToolDef( ToolDef(
name="rag_tool.query_context", name="query_from_memory",
description="Retrieve context from memory", description="Retrieve context from memory",
), ),
ToolDef( ToolDef(
name="rag_tool.insert_documents", name="insert_into_memory",
description="Insert documents into memory", description="Insert documents into memory",
), ),
] ]

View file

@ -96,14 +96,14 @@ class TestTools:
) )
# Insert documents into memory # Insert documents into memory
await tools_impl.rag_tool.insert_documents( await tools_impl.rag_tool.insert(
documents=sample_documents, documents=sample_documents,
vector_db_id="test_bank", vector_db_id="test_bank",
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
) )
# Execute the memory tool # Execute the memory tool
response = await tools_impl.rag_tool.query_context( response = await tools_impl.rag_tool.query(
content="What are the main topics covered in the documentation?", content="What are the main topics covered in the documentation?",
vector_db_ids=["test_bank"], vector_db_ids=["test_bank"],
) )

View file

@ -11,11 +11,9 @@ from pathlib import Path
import pytest import pytest
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.apis.tools import RAGDocument
content_from_doc,
MemoryBankDocument, from llama_stack.providers.utils.memory.vector_store import content_from_doc, URL
URL,
)
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf" DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
@ -41,33 +39,33 @@ class TestVectorStore:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_returns_content_from_pdf_data_uri(self): async def test_returns_content_from_pdf_data_uri(self):
data_uri = data_url_from_file(DUMMY_PDF_PATH) data_uri = data_url_from_file(DUMMY_PDF_PATH)
doc = MemoryBankDocument( doc = RAGDocument(
document_id="dummy", document_id="dummy",
content=data_uri, content=data_uri,
mime_type="application/pdf", mime_type="application/pdf",
metadata={}, metadata={},
) )
content = await content_from_doc(doc) content = await content_from_doc(doc)
assert content == "Dummy PDF file" assert content == "Dumm y PDF file"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content(self): async def test_downloads_pdf_and_returns_content(self):
# Using GitHub to host the PDF file # Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf" url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = MemoryBankDocument( doc = RAGDocument(
document_id="dummy", document_id="dummy",
content=url, content=url,
mime_type="application/pdf", mime_type="application/pdf",
metadata={}, metadata={},
) )
content = await content_from_doc(doc) content = await content_from_doc(doc)
assert content == "Dummy PDF file" assert content == "Dumm y PDF file"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content_with_url_object(self): async def test_downloads_pdf_and_returns_content_with_url_object(self):
# Using GitHub to host the PDF file # Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf" url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = MemoryBankDocument( doc = RAGDocument(
document_id="dummy", document_id="dummy",
content=URL( content=URL(
uri=url, uri=url,
@ -76,4 +74,4 @@ class TestVectorStore:
metadata={}, metadata={},
) )
content = await content_from_doc(doc) content = await content_from_doc(doc)
assert content == "Dummy PDF file" assert content == "Dumm y PDF file"

View file

@ -292,7 +292,7 @@ def test_rag_agent(llama_stack_client, agent_config):
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384, embedding_dimension=384,
) )
llama_stack_client.tool_runtime.rag_tool.insert_documents( llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents, documents=documents,
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
@ -321,4 +321,4 @@ def test_rag_agent(llama_stack_client, agent_config):
) )
logs = [str(log) for log in EventLogger().log(response) if log is not None] logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs) logs_str = "".join(logs)
assert "Tool:rag_tool.query_context" in logs_str assert "Tool:query_from_memory" in logs_str

View file

@ -73,7 +73,7 @@ def test_vector_db_insert_inline_and_query(
llama_stack_client, single_entry_vector_db_registry, sample_documents llama_stack_client, single_entry_vector_db_registry, sample_documents
): ):
vector_db_id = single_entry_vector_db_registry[0] vector_db_id = single_entry_vector_db_registry[0]
llama_stack_client.tool_runtime.rag_tool.insert_documents( llama_stack_client.tool_runtime.rag_tool.insert(
documents=sample_documents, documents=sample_documents,
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
@ -157,7 +157,7 @@ def test_vector_db_insert_from_url_and_query(
for i, url in enumerate(urls) for i, url in enumerate(urls)
] ]
llama_stack_client.tool_runtime.rag_tool.insert_documents( llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents, documents=documents,
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
chunk_size_in_tokens=512, chunk_size_in_tokens=512,