Merge remote-tracking branch 'origin/main' into resp_branching

Ashwin Bharambe 2025-10-01 21:13:12 -07:00
commit 1536ae0333
144 changed files with 62682 additions and 51560 deletions


@@ -333,6 +333,132 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
    assert chunks[5].response.output[0].name == "get_weather"


async def test_create_openai_response_with_tool_call_function_arguments_none(openai_responses_impl, mock_inference_api):
    """Test creating an OpenAI response with a tool call response that has a function that does not accept arguments, or arguments set to None when they are not mandatory."""
    # Setup
    input_text = "What is the time right now?"
    model = "meta-llama/Llama-3.1-8B-Instruct"

    async def fake_stream_toolcall():
        yield ChatCompletionChunk(
            id="123",
            choices=[
                Choice(
                    index=0,
                    delta=ChoiceDelta(
                        tool_calls=[
                            ChoiceDeltaToolCall(
                                index=0,
                                id="tc_123",
                                function=ChoiceDeltaToolCallFunction(name="get_current_time", arguments=None),
                                type=None,
                            )
                        ]
                    ),
                ),
            ],
            created=1,
            model=model,
            object="chat.completion.chunk",
        )

    mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()

    # Function does not accept arguments
    result = await openai_responses_impl.create_openai_response(
        input=input_text,
        model=model,
        stream=True,
        temperature=0.1,
        tools=[
            OpenAIResponseInputToolFunction(
                name="get_current_time",
                description="Get current time for system's timezone",
                parameters={},
            )
        ],
    )

    # Check that we got the content from our mocked tool execution result
    chunks = [chunk async for chunk in result]
    # Verify event types
    # Should have: response.created, output_item.added, function_call_arguments.done,
    # output_item.done, response.completed (no arguments.delta event, since arguments is None)
    assert len(chunks) == 5

    # Verify inference API was called correctly (after iterating over result)
    first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
    assert first_call.kwargs["messages"][0].content == input_text
    assert first_call.kwargs["tools"] is not None
    assert first_call.kwargs["temperature"] == 0.1

    # Check response.created event (should have empty output)
    assert chunks[0].type == "response.created"
    assert len(chunks[0].response.output) == 0

    # Check streaming events
    assert chunks[1].type == "response.output_item.added"
    assert chunks[2].type == "response.function_call_arguments.done"
    assert chunks[3].type == "response.output_item.done"

    # Check response.completed event (should have the tool call with arguments set to "{}")
    assert chunks[4].type == "response.completed"
    assert len(chunks[4].response.output) == 1
    assert chunks[4].response.output[0].type == "function_call"
    assert chunks[4].response.output[0].name == "get_current_time"
    assert chunks[4].response.output[0].arguments == "{}"

    mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()

    # Function accepts optional arguments
    result = await openai_responses_impl.create_openai_response(
        input=input_text,
        model=model,
        stream=True,
        temperature=0.1,
        tools=[
            OpenAIResponseInputToolFunction(
                name="get_current_time",
                description="Get current time for system's timezone",
                parameters={
                    "timezone": "string",
                },
            )
        ],
    )

    # Check that we got the content from our mocked tool execution result
    chunks = [chunk async for chunk in result]
    # Verify event types
    # Should have: response.created, output_item.added, function_call_arguments.done,
    # output_item.done, response.completed (again no arguments.delta event for None arguments)
    assert len(chunks) == 5

    # Verify inference API was called correctly (after iterating over result)
    first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
    assert first_call.kwargs["messages"][0].content == input_text
    assert first_call.kwargs["tools"] is not None
    assert first_call.kwargs["temperature"] == 0.1

    # Check response.created event (should have empty output)
    assert chunks[0].type == "response.created"
    assert len(chunks[0].response.output) == 0

    # Check streaming events
    assert chunks[1].type == "response.output_item.added"
    assert chunks[2].type == "response.function_call_arguments.done"
    assert chunks[3].type == "response.output_item.done"

    # Check response.completed event (should have the tool call with arguments set to "{}")
    assert chunks[4].type == "response.completed"
    assert len(chunks[4].response.output) == 1
    assert chunks[4].response.output[0].type == "function_call"
    assert chunks[4].response.output[0].name == "get_current_time"
    assert chunks[4].response.output[0].arguments == "{}"


async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
    """Test creating an OpenAI response with multiple messages."""
    # Setup
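
The test above feeds a streamed tool-call delta whose function arguments are None and then asserts that the final function_call output carries the string "{}". A minimal sketch of that expected normalization, using a hypothetical helper (normalize_tool_arguments is not part of this commit):

# Hypothetical helper illustrating the behavior the assertions above expect;
# it is not the implementation under test.
def normalize_tool_arguments(arguments: str | None) -> str:
    """Coerce absent streamed tool-call arguments to an empty JSON object string."""
    return arguments if arguments is not None else "{}"


assert normalize_tool_arguments(None) == "{}"
assert normalize_tool_arguments('{"timezone": "UTC"}') == '{"timezone": "UTC"}'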


@@ -19,6 +19,7 @@ class TestOpenAIBaseURLConfig:
        """Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
        config = OpenAIConfig(api_key="test-key")
        adapter = OpenAIInferenceAdapter(config)
        adapter.provider_data_api_key_field = None  # Disable provider data for this test

        assert adapter.get_base_url() == "https://api.openai.com/v1"

@@ -27,6 +28,7 @@ class TestOpenAIBaseURLConfig:
        custom_url = "https://custom.openai.com/v1"
        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
        adapter = OpenAIInferenceAdapter(config)
        adapter.provider_data_api_key_field = None  # Disable provider data for this test

        assert adapter.get_base_url() == custom_url

@@ -38,6 +40,7 @@ class TestOpenAIBaseURLConfig:
        processed_config = replace_env_vars(config_data)
        config = OpenAIConfig.model_validate(processed_config)
        adapter = OpenAIInferenceAdapter(config)
        adapter.provider_data_api_key_field = None  # Disable provider data for this test

        assert adapter.get_base_url() == "https://env.openai.com/v1"

@@ -47,6 +50,7 @@ class TestOpenAIBaseURLConfig:
        custom_url = "https://config.openai.com/v1"
        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
        adapter = OpenAIInferenceAdapter(config)
        adapter.provider_data_api_key_field = None  # Disable provider data for this test

        # Config should take precedence over environment variable
        assert adapter.get_base_url() == custom_url

@@ -57,6 +61,7 @@ class TestOpenAIBaseURLConfig:
        custom_url = "https://test.openai.com/v1"
        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
        adapter = OpenAIInferenceAdapter(config)
        adapter.provider_data_api_key_field = None  # Disable provider data for this test

        # Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
        adapter.get_api_key = MagicMock(return_value="test-key")

@@ -76,6 +81,7 @@ class TestOpenAIBaseURLConfig:
        custom_url = "https://test.openai.com/v1"
        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
        adapter = OpenAIInferenceAdapter(config)
        adapter.provider_data_api_key_field = None  # Disable provider data for this test

        # Mock the get_api_key method
        adapter.get_api_key = MagicMock(return_value="test-key")

@@ -117,6 +123,7 @@ class TestOpenAIBaseURLConfig:
        processed_config = replace_env_vars(config_data)
        config = OpenAIConfig.model_validate(processed_config)
        adapter = OpenAIInferenceAdapter(config)
        adapter.provider_data_api_key_field = None  # Disable provider data for this test

        # Mock the get_api_key method
        adapter.get_api_key = MagicMock(return_value="test-key")
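
Every test in this class now sets adapter.provider_data_api_key_field = None right after constructing the adapter, so get_base_url() is checked against the static config without any per-request provider data. A hedged sketch of how the repeated line could be pulled into a shared fixture (a suggestion only; adapter_without_provider_data is a hypothetical name, and OpenAIConfig / OpenAIInferenceAdapter are imported as in the existing test module):

import pytest


@pytest.fixture
def adapter_without_provider_data():
    # Hypothetical fixture, not part of this commit: build the adapter once and
    # disable provider data so get_base_url() reflects only the static config.
    config = OpenAIConfig(api_key="test-key")
    adapter = OpenAIInferenceAdapter(config)
    adapter.provider_data_api_key_field = None  # Disable provider data for these tests
    return adapter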


@@ -4,18 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import json
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch

import pytest
from pydantic import BaseModel, Field

from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class OpenAIMixinImpl(OpenAIMixin):
    def __init__(self):
        self.__provider_id__ = "test-provider"
    __provider_id__: str = "test-provider"

    def get_api_key(self) -> str:
        raise NotImplementedError("This method should be mocked in tests")

@@ -24,7 +26,7 @@ class OpenAIMixinImpl(OpenAIMixin):
        raise NotImplementedError("This method should be mocked in tests")


class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
    """Test implementation with embedding model metadata"""

    embedding_model_metadata = {

@@ -32,14 +34,6 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
        "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
    }

    __provider_id__ = "test-provider"

    def get_api_key(self) -> str:
        raise NotImplementedError("This method should be mocked in tests")

    def get_base_url(self) -> str:
        raise NotImplementedError("This method should be mocked in tests")


@pytest.fixture
def mixin():

@@ -366,3 +360,78 @@ class TestOpenAIMixinAllowedModels:
        assert await mixin.check_model_availability("final-mock-model-id")
        assert not await mixin.check_model_availability("some-mock-model-id")
        assert not await mixin.check_model_availability("another-mock-model-id")


class ProviderDataValidator(BaseModel):
    """Validator for provider data in tests"""

    test_api_key: str | None = Field(default=None)


class OpenAIMixinWithProviderData(OpenAIMixinImpl):
    """Test implementation that supports provider data API key field"""

    provider_data_api_key_field: str = "test_api_key"

    def get_api_key(self) -> str:
        return "default-api-key"

    def get_base_url(self):
        return "default-base-url"


class TestOpenAIMixinProviderDataApiKey:
    """Test cases for provider_data_api_key_field functionality"""

    @pytest.fixture
    def mixin_with_provider_data_field(self):
        """Mixin instance with provider_data_api_key_field set"""
        mixin_instance = OpenAIMixinWithProviderData()

        # Mock provider_spec for provider data validation
        mock_provider_spec = MagicMock()
        mock_provider_spec.provider_type = "test-provider-with-data"
        mock_provider_spec.provider_data_validator = (
            "tests.unit.providers.utils.inference.test_openai_mixin.ProviderDataValidator"
        )
        mixin_instance.__provider_spec__ = mock_provider_spec

        return mixin_instance

    @pytest.fixture
    def mixin_with_provider_data_field_and_none_api_key(self, mixin_with_provider_data_field):
        mixin_with_provider_data_field.get_api_key = Mock(return_value=None)
        return mixin_with_provider_data_field

    def test_no_provider_data(self, mixin_with_provider_data_field):
        """Test that client uses config API key when no provider data is available"""
        assert mixin_with_provider_data_field.client.api_key == "default-api-key"

    def test_with_provider_data(self, mixin_with_provider_data_field):
        """Test that provider data API key overrides config API key"""
        with request_provider_data_context(
            {"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})}
        ):
            assert mixin_with_provider_data_field.client.api_key == "provider-data-key"

    def test_with_wrong_key(self, mixin_with_provider_data_field):
        """Test fallback to config when provider data doesn't have the required key"""
        with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
            assert mixin_with_provider_data_field.client.api_key == "default-api-key"

    def test_error_when_no_config_and_provider_data_has_wrong_key(
        self, mixin_with_provider_data_field_and_none_api_key
    ):
        """Test that ValueError is raised when provider data exists but doesn't have required key"""
        with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
            with pytest.raises(ValueError, match="API key is not set"):
                _ = mixin_with_provider_data_field_and_none_api_key.client

    def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):
        """Test that error message includes correct field name and header information"""
        with pytest.raises(ValueError) as exc_info:
            _ = mixin_with_provider_data_field_and_none_api_key.client

        error_message = str(exc_info.value)
        assert "test_api_key" in error_message
        assert "x-llamastack-provider-data" in error_message