Mirror of https://github.com/meta-llama/llama-stack.git
feat: File search tool for Responses API
This is an initial working prototype that wires the `file_search` builtin tool for the Responses API up to our existing RAG knowledge search tool.

I stubbed in a new test (one that uses a hardcoded URL and a hybrid of the OpenAI and Llama Stack clients for now, only until we finish landing the vector store APIs and insertion support). Note that this currently lives under tests/verifications only, because it sometimes flakes on tool calling with the small Llama-3.2-3B model we run in CI (and that the example below uses). We'd want to make the test more robust in some way before moving it over to tests/integration and running it in CI.

```
ollama run llama3.2:3b

INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \
  llama stack run ./llama_stack/templates/ollama/run.yaml \
  --image-type venv \
  --env OLLAMA_URL="http://0.0.0.0:11434"

pytest -sv 'tests/verifications/openai_api/test_responses.py::test_response_non_streaming_file_search' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model meta-llama/Llama-3.2-3B-Instruct
```

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Parent: e2e15ebb6c
Commit: fa34468308
7 changed files with 234 additions and 11 deletions
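Before the per-file diffs, here is a minimal client-side sketch of the flow these changes enable. The endpoint, model, and vector store name are illustrative (they match the test added below), and the sketch assumes a locally running stack with that vector store already populated:

```python
from openai import OpenAI

# Llama Stack exposes an OpenAI-compatible endpoint; a local stack ignores
# the API key, but the client constructor requires one.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    input="How many experts does the Llama 4 Maverick model have?",
    tools=[{"type": "file_search", "vector_store_ids": ["test_vector_store"]}],
)

# output[0] is the file_search_call record, output[1] the assistant's answer.
print(response.output[0].type)  # "file_search_call"
print(response.output_text)
```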
docs/_static/llama-stack-spec.html (vendored, 69 changes)

```diff
@@ -7047,6 +7047,9 @@
         {
           "$ref": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
         },
+        {
+          "$ref": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall"
+        },
         {
           "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
         },
@@ -7193,7 +7196,7 @@
           "const": "file_search",
           "default": "file_search"
         },
-        "vector_store_id": {
+        "vector_store_ids": {
           "type": "array",
           "items": {
             "type": "string"
@@ -7217,7 +7220,7 @@
       "additionalProperties": false,
       "required": [
         "type",
-        "vector_store_id"
+        "vector_store_ids"
       ],
       "title": "OpenAIResponseInputToolFileSearch"
     },
@@ -7484,6 +7487,64 @@
       ],
       "title": "OpenAIResponseOutputMessageContentOutputText"
     },
+    "OpenAIResponseOutputMessageFileSearchToolCall": {
+      "type": "object",
+      "properties": {
+        "id": {
+          "type": "string"
+        },
+        "queries": {
+          "type": "array",
+          "items": {
+            "type": "string"
+          }
+        },
+        "status": {
+          "type": "string"
+        },
+        "type": {
+          "type": "string",
+          "const": "file_search_call",
+          "default": "file_search_call"
+        },
+        "results": {
+          "type": "array",
+          "items": {
+            "type": "object",
+            "additionalProperties": {
+              "oneOf": [
+                {
+                  "type": "null"
+                },
+                {
+                  "type": "boolean"
+                },
+                {
+                  "type": "number"
+                },
+                {
+                  "type": "string"
+                },
+                {
+                  "type": "array"
+                },
+                {
+                  "type": "object"
+                }
+              ]
+            }
+          }
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "id",
+        "queries",
+        "status",
+        "type"
+      ],
+      "title": "OpenAIResponseOutputMessageFileSearchToolCall"
+    },
     "OpenAIResponseOutputMessageFunctionToolCall": {
       "type": "object",
       "properties": {
@@ -7760,6 +7821,9 @@
         {
           "$ref": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
         },
+        {
+          "$ref": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall"
+        },
         {
           "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
         },
@@ -7775,6 +7839,7 @@
       "mapping": {
         "message": "#/components/schemas/OpenAIResponseMessage",
         "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall",
+        "file_search_call": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall",
         "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
         "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
         "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
```
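Per the schema added above, a conforming `file_search_call` output item would look like this sketch (values are made up; the result keys mirror what the implementation below populates):

```python
# Illustrative instance of the new OpenAIResponseOutputMessageFileSearchToolCall
# schema; every value here is hypothetical.
file_search_call = {
    "id": "call_abc123",
    "type": "file_search_call",
    "status": "completed",
    "queries": ["llama 4 maverick experts"],
    "results": [
        {
            "file_id": "doc1",
            "filename": "doc1",
            "text": "Llama 4 Maverick has 128 experts",
            "score": 0.87,
        }
    ],
}
```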
docs/_static/llama-stack-spec.yaml (vendored, 42 changes)

```diff
@@ -5021,6 +5021,7 @@ components:
   OpenAIResponseInput:
     oneOf:
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+      - $ref: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
       - $ref: '#/components/schemas/OpenAIResponseInputFunctionToolCallOutput'
       - $ref: '#/components/schemas/OpenAIResponseMessage'
@@ -5115,7 +5116,7 @@ components:
         type: string
         const: file_search
         default: file_search
-      vector_store_id:
+      vector_store_ids:
         type: array
         items:
           type: string
@@ -5132,7 +5133,7 @@ components:
     additionalProperties: false
     required:
       - type
-      - vector_store_id
+      - vector_store_ids
     title: OpenAIResponseInputToolFileSearch
   OpenAIResponseInputToolFunction:
     type: object
@@ -5294,6 +5295,41 @@ components:
       - type
     title: >-
       OpenAIResponseOutputMessageContentOutputText
+  "OpenAIResponseOutputMessageFileSearchToolCall":
+    type: object
+    properties:
+      id:
+        type: string
+      queries:
+        type: array
+        items:
+          type: string
+      status:
+        type: string
+      type:
+        type: string
+        const: file_search_call
+        default: file_search_call
+      results:
+        type: array
+        items:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
+    additionalProperties: false
+    required:
+      - id
+      - queries
+      - status
+      - type
+    title: >-
+      OpenAIResponseOutputMessageFileSearchToolCall
  "OpenAIResponseOutputMessageFunctionToolCall":
    type: object
    properties:
@@ -5491,6 +5527,7 @@ components:
     oneOf:
       - $ref: '#/components/schemas/OpenAIResponseMessage'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+      - $ref: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
@@ -5499,6 +5536,7 @@ components:
     mapping:
       message: '#/components/schemas/OpenAIResponseMessage'
       web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+      file_search_call: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
      function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
      mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
      mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
```
llama_stack/apis/agents/openai_responses.py

```diff
@@ -81,6 +81,15 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
     type: Literal["web_search_call"] = "web_search_call"
 
 
+@json_schema_type
+class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
+    id: str
+    queries: list[str]
+    status: str
+    type: Literal["file_search_call"] = "file_search_call"
+    results: list[dict[str, Any]] | None = None
+
+
 @json_schema_type
 class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
     call_id: str
@@ -119,6 +128,7 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
 OpenAIResponseOutput = Annotated[
     OpenAIResponseMessage
     | OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFileSearchToolCall
     | OpenAIResponseOutputMessageFunctionToolCall
     | OpenAIResponseOutputMessageMCPCall
     | OpenAIResponseOutputMessageMCPListTools,
@@ -362,6 +372,7 @@ class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
 OpenAIResponseInput = Annotated[
     # Responses API allows output messages to be passed in as input
     OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFileSearchToolCall
     | OpenAIResponseOutputMessageFunctionToolCall
     | OpenAIResponseInputFunctionToolCallOutput
     | OpenAIResponseMessage
@@ -397,9 +408,9 @@ class FileSearchRankingOptions(BaseModel):
 @json_schema_type
 class OpenAIResponseInputToolFileSearch(BaseModel):
     type: Literal["file_search"] = "file_search"
-    vector_store_id: list[str]
+    vector_store_ids: list[str]
     ranking_options: FileSearchRankingOptions | None = None
-    # TODO: add filters
+    # TODO: add filters, max_num_results
 
 
 class ApprovalFilter(BaseModel):
```
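To see the new output type in isolation, here is a minimal sketch of constructing it. This is a trimmed-down copy for illustration only; the real class also carries the `@json_schema_type` decorator, and the field values below are hypothetical:

```python
from typing import Any, Literal

from pydantic import BaseModel


# Trimmed-down copy of the new model; the committed version is decorated
# with @json_schema_type in llama_stack.apis.agents.openai_responses.
class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
    id: str
    queries: list[str]
    status: str
    type: Literal["file_search_call"] = "file_search_call"
    results: list[dict[str, Any]] | None = None


call = OpenAIResponseOutputMessageFileSearchToolCall(
    id="call_abc123",  # illustrative id
    queries=["llama 4 maverick experts"],
    status="completed",
)
print(call.model_dump_json(indent=2))
```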
llama_stack/providers/inline/agents/meta_reference/openai_responses.py

```diff
@@ -24,6 +24,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolFileSearch,
     OpenAIResponseInputToolMCP,
     OpenAIResponseMessage,
     OpenAIResponseObject,
@@ -34,6 +35,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseOutputMessageFileSearchToolCall,
     OpenAIResponseOutputMessageFunctionToolCall,
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
@@ -198,7 +200,8 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
 class ChatCompletionContext(BaseModel):
     model: str
     messages: list[OpenAIMessageParam]
-    tools: list[ChatCompletionToolParam] | None = None
+    response_tools: list[OpenAIResponseInputTool] | None = None
+    chat_tools: list[ChatCompletionToolParam] | None = None
     mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
     temperature: float | None
     response_format: OpenAIResponseFormatParam
@@ -388,7 +391,8 @@ class OpenAIResponsesImpl:
         ctx = ChatCompletionContext(
             model=model,
             messages=messages,
-            tools=chat_tools,
+            response_tools=tools,
+            chat_tools=chat_tools,
             mcp_tool_to_server=mcp_tool_to_server,
             temperature=temperature,
             response_format=response_format,
@@ -417,7 +421,7 @@ class OpenAIResponsesImpl:
         completion_result = await self.inference_api.openai_chat_completion(
             model=ctx.model,
             messages=messages,
-            tools=ctx.tools,
+            tools=ctx.chat_tools,
             stream=True,
             temperature=ctx.temperature,
             response_format=ctx.response_format,
@@ -606,6 +610,12 @@ class OpenAIResponsesImpl:
                 if not tool:
                     raise ValueError(f"Tool {tool_name} not found")
                 chat_tools.append(make_openai_tool(tool_name, tool))
+            elif input_tool.type == "file_search":
+                tool_name = "knowledge_search"
+                tool = await self.tool_groups_api.get_tool(tool_name)
+                if not tool:
+                    raise ValueError(f"Tool {tool_name} not found")
+                chat_tools.append(make_openai_tool(tool_name, tool))
             elif input_tool.type == "mcp":
                 always_allowed = None
                 never_allowed = None
@@ -667,6 +677,7 @@
 
         tool_call_id = tool_call.id
         function = tool_call.function
+        tool_kwargs = json.loads(function.arguments) if function.arguments else {}
 
         if not function or not tool_call_id or not function.name:
             return None, None
@@ -680,12 +691,18 @@
                     endpoint=mcp_tool.server_url,
                     headers=mcp_tool.headers or {},
                     tool_name=function.name,
-                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                    kwargs=tool_kwargs,
                 )
             else:
+                if function.name == "knowledge_search":
+                    response_file_search_tool = next(
+                        t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)
+                    )
+                    if response_file_search_tool:
+                        tool_kwargs["vector_db_ids"] = response_file_search_tool.vector_store_ids
                 result = await self.tool_runtime_api.invoke_tool(
                     tool_name=function.name,
-                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                    kwargs=tool_kwargs,
                 )
         except Exception as e:
             error_exc = e
@@ -713,6 +730,27 @@
             )
             if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
                 message.status = "failed"
+        elif function.name == "knowledge_search":
+            message = OpenAIResponseOutputMessageFileSearchToolCall(
+                id=tool_call_id,
+                queries=[tool_kwargs.get("query", "")],
+                status="completed",
+            )
+            if "document_ids" in result.metadata:
+                message.results = []
+                for i, doc_id in enumerate(result.metadata["document_ids"]):
+                    text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
+                    score = result.metadata["scores"][i] if "scores" in result.metadata else None
+                    message.results.append(
+                        {
+                            "file_id": doc_id,
+                            "filename": doc_id,
+                            "text": text,
+                            "score": score,
+                        }
+                    )
+            if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+                message.status = "failed"
         else:
             raise ValueError(f"Unknown tool {function.name} called")
```
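In short, the handler above maps the `file_search` input tool onto the builtin `knowledge_search` tool, injects the configured `vector_store_ids` as `vector_db_ids` when invoking it, and reshapes the RAG result metadata into `file_search_call` results. A standalone sketch of that reshaping step (not the committed code verbatim), assuming the document_ids/chunks/scores metadata layout added to the rag tool below:

```python
def to_file_search_results(metadata: dict) -> list[dict]:
    # Pair each document id with its chunk text and score by index;
    # missing chunks/scores fall back to None, mirroring the handler above.
    results = []
    for i, doc_id in enumerate(metadata.get("document_ids", [])):
        results.append(
            {
                "file_id": doc_id,
                "filename": doc_id,  # no separate filename yet; reuse the doc id
                "text": metadata["chunks"][i] if "chunks" in metadata else None,
                "score": metadata["scores"][i] if "scores" in metadata else None,
            }
        )
    return results
```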
llama_stack/providers/inline/tool_runtime/rag/memory.py

```diff
@@ -170,6 +170,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
             content=picked,
             metadata={
                 "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
+                "chunks": [c.content for c in chunks[: len(picked)]],
+                "scores": scores[: len(picked)],
             },
         )
 
```
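With this change, the knowledge_search tool result carries three parallel lists aligned by index, for example (illustrative values):

```python
# Example metadata shape emitted by knowledge_search after this change.
metadata = {
    "document_ids": ["doc1"],
    "chunks": ["Llama 4 Maverick has 128 experts"],
    "scores": [0.87],
}
```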
tests/verifications/openai_api/fixtures/test_cases/responses.yaml

```diff
@@ -31,6 +31,18 @@ test_response_web_search:
             search_context_size: "low"
         output: "128"
 
+test_response_file_search:
+  test_name: test_response_file_search
+  test_params:
+    case:
+      - case_id: "llama_experts"
+        input: "How many experts does the Llama 4 Maverick model have?"
+        tools:
+          - type: file_search
+            vector_store_ids:
+              - test_vector_store
+        output: "128"
+
 test_response_mcp_tool:
   test_name: test_response_mcp_tool
   test_params:
```
tests/verifications/openai_api/test_responses.py

```diff
@@ -9,6 +9,7 @@ import json
 import httpx
 import openai
 import pytest
+from llama_stack_client import LlamaStackClient
 
 from llama_stack import LlamaStackAsLibraryClient
 from llama_stack.distribution.datatypes import AuthenticationRequiredError
@@ -258,6 +259,62 @@ def test_response_non_streaming_web_search(request, openai_client, model, provid
     assert case["output"].lower() in response.output_text.lower().strip()
 
 
+@pytest.mark.parametrize(
+    "case",
+    responses_test_cases["test_response_file_search"]["test_params"]["case"],
+    ids=case_id_generator,
+)
+def test_response_non_streaming_file_search(
+    base_url, request, openai_client, model, provider, verification_config, case
+):
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    lls_client = LlamaStackClient(base_url=base_url.replace("/v1/openai/v1", ""))
+    vector_db_id = "test_vector_store"
+
+    # Ensure the test starts from a clean vector store
+    try:
+        lls_client.vector_dbs.unregister(vector_db_id=vector_db_id)
+    except Exception:
+        pass
+
+    lls_client.vector_dbs.register(
+        vector_db_id=vector_db_id,
+        embedding_model="all-MiniLM-L6-v2",
+    )
+    doc_content = "Llama 4 Maverick has 128 experts"
+    chunks = [
+        {
+            "content": doc_content,
+            "mime_type": "text/plain",
+            "metadata": {
+                "document_id": "doc1",
+            },
+        },
+    ]
+    lls_client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
+
+    response = openai_client.responses.create(
+        model=model,
+        input=case["input"],
+        tools=case["tools"],
+        stream=False,
+    )
+    assert len(response.output) > 1
+    assert response.output[0].type == "file_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].queries  # ensure it's some non-empty list
+    assert response.output[0].results[0].text == doc_content
+    assert response.output[0].results[0].score > 0
+    assert response.output[1].type == "message"
+    assert response.output[1].status == "completed"
+    assert response.output[1].role == "assistant"
+    assert len(response.output[1].content) > 0
+    assert case["output"].lower() in response.output_text.lower().strip()
+
+
 @pytest.mark.parametrize(
     "case",
     responses_test_cases["test_response_mcp_tool"]["test_params"]["case"],
```