mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
fixes, add auth test
This commit is contained in:
parent
5937d94da5
commit
9f7ed4be43
3 changed files with 41 additions and 7 deletions
|
@ -57,7 +57,7 @@ from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||||
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
|
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="openai_responses")
|
logger = get_logger(name=__name__, category="openai_responses")
|
||||||
|
|
||||||
|
@ -277,7 +277,7 @@ class OpenAIResponsesImpl:
|
||||||
messages = await _convert_response_input_to_chat_messages(input)
|
messages = await _convert_response_input_to_chat_messages(input)
|
||||||
await self._prepend_instructions(messages, instructions)
|
await self._prepend_instructions(messages, instructions)
|
||||||
chat_tools, mcp_tool_to_server, mcp_list_message = (
|
chat_tools, mcp_tool_to_server, mcp_list_message = (
|
||||||
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {})
|
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
|
||||||
)
|
)
|
||||||
if mcp_list_message:
|
if mcp_list_message:
|
||||||
output_messages.append(mcp_list_message)
|
output_messages.append(mcp_list_message)
|
||||||
|
@ -487,7 +487,7 @@ class OpenAIResponsesImpl:
|
||||||
|
|
||||||
tool_defs = await list_mcp_tools(
|
tool_defs = await list_mcp_tools(
|
||||||
endpoint=input_tool.server_url,
|
endpoint=input_tool.server_url,
|
||||||
headers=convert_header_list_to_dict(input_tool.headers or []),
|
headers=input_tool.headers or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||||
|
@ -584,12 +584,13 @@ class OpenAIResponsesImpl:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
error_exc = None
|
error_exc = None
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
if function.name in ctx.mcp_tool_to_server:
|
if function.name in ctx.mcp_tool_to_server:
|
||||||
mcp_tool = ctx.mcp_tool_to_server[function.name]
|
mcp_tool = ctx.mcp_tool_to_server[function.name]
|
||||||
result = await invoke_mcp_tool(
|
result = await invoke_mcp_tool(
|
||||||
endpoint=mcp_tool.server_url,
|
endpoint=mcp_tool.server_url,
|
||||||
headers=convert_header_list_to_dict(mcp_tool.headers or []),
|
headers=mcp_tool.headers or {},
|
||||||
tool_name=function.name,
|
tool_name=function.name,
|
||||||
kwargs=json.loads(function.arguments) if function.arguments else {},
|
kwargs=json.loads(function.arguments) if function.arguments else {},
|
||||||
)
|
)
|
||||||
|
@ -628,7 +629,7 @@ class OpenAIResponsesImpl:
|
||||||
raise ValueError(f"Unknown tool {function.name} called")
|
raise ValueError(f"Unknown tool {function.name} called")
|
||||||
|
|
||||||
input_message = None
|
input_message = None
|
||||||
if result.content:
|
if result and result.content:
|
||||||
if isinstance(result.content, str):
|
if isinstance(result.content, str):
|
||||||
content = result.content
|
content = result.content
|
||||||
elif isinstance(result.content, list):
|
elif isinstance(result.content, list):
|
||||||
|
|
|
@ -41,8 +41,6 @@ test_response_mcp_tool:
|
||||||
- type: mcp
|
- type: mcp
|
||||||
server_label: "localmcp"
|
server_label: "localmcp"
|
||||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||||
headers:
|
|
||||||
Authorization: "Bearer test-token"
|
|
||||||
output: "Hello, world!"
|
output: "Hello, world!"
|
||||||
|
|
||||||
test_response_custom_tool:
|
test_response_custom_tool:
|
||||||
|
|
|
@ -6,8 +6,11 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack import LlamaStackAsLibraryClient
|
||||||
|
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
||||||
from tests.common.mcp import make_mcp_server
|
from tests.common.mcp import make_mcp_server
|
||||||
from tests.verifications.openai_api.fixtures.fixtures import (
|
from tests.verifications.openai_api.fixtures.fixtures import (
|
||||||
case_id_generator,
|
case_id_generator,
|
||||||
|
@ -166,6 +169,38 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
|
||||||
text_content = message.content[0].text
|
text_content = message.content[0].text
|
||||||
assert "boiling point" in text_content.lower()
|
assert "boiling point" in text_content.lower()
|
||||||
|
|
||||||
|
with make_mcp_server(required_auth_token="test-token") as mcp_server_info:
|
||||||
|
tools = case["tools"]
|
||||||
|
for tool in tools:
|
||||||
|
if tool["type"] == "mcp":
|
||||||
|
tool["server_url"] = mcp_server_info["server_url"]
|
||||||
|
|
||||||
|
exc_type = (
|
||||||
|
AuthenticationRequiredError
|
||||||
|
if isinstance(openai_client, LlamaStackAsLibraryClient)
|
||||||
|
else httpx.HTTPStatusError
|
||||||
|
)
|
||||||
|
with pytest.raises(exc_type):
|
||||||
|
openai_client.responses.create(
|
||||||
|
model=model,
|
||||||
|
input=case["input"],
|
||||||
|
tools=tools,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
for tool in tools:
|
||||||
|
if tool["type"] == "mcp":
|
||||||
|
tool["server_url"] = mcp_server_info["server_url"]
|
||||||
|
tool["headers"] = {"Authorization": "Bearer test-token"}
|
||||||
|
|
||||||
|
response = openai_client.responses.create(
|
||||||
|
model=model,
|
||||||
|
input=case["input"],
|
||||||
|
tools=tools,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
assert len(response.output) >= 3
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"case",
|
"case",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue