chore: turn OpenAIMixin into a pydantic.BaseModel

- implement get_api_key instead of relying on LiteLLMOpenAIMixin.get_api_key
 - remove use of LiteLLMOpenAIMixin
 - add default initialize/shutdown methods to OpenAIMixin
 - remove __init__s to allow proper pydantic construction
 - remove dead code from vllm adapter and associated / duplicate unit tests
 - update vllm adapter to use openaimixin for model registration
 - remove ModelRegistryHelper from fireworks & together adapters
 - remove Inference from nvidia adapter
 - complete type hints on embedding_model_metadata
 - allow extra fields on OpenAIMixin, for model_store, __provider_id__, etc
 - new recordings for ollama
This commit is contained in:
Matthew Farrellee 2025-10-02 20:47:54 -04:00
parent ce77c27ff8
commit 60f0056cbc
57 changed files with 12520 additions and 1254 deletions

View file

@ -22,7 +22,7 @@ def test_groq_provider_openai_client_caching():
"""Ensure the Groq provider does not cache api keys across client requests"""
config = GroqConfig()
inference_adapter = GroqInferenceAdapter(config)
inference_adapter = GroqInferenceAdapter(config=config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
@ -40,7 +40,7 @@ def test_openai_provider_openai_client_caching():
"""Ensure the OpenAI provider does not cache api keys across client requests"""
config = OpenAIConfig()
inference_adapter = OpenAIInferenceAdapter(config)
inference_adapter = OpenAIInferenceAdapter(config=config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
@ -59,7 +59,7 @@ def test_together_provider_openai_client_caching():
"""Ensure the Together provider does not cache api keys across client requests"""
config = TogetherImplConfig()
inference_adapter = TogetherInferenceAdapter(config)
inference_adapter = TogetherInferenceAdapter(config=config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
@ -77,7 +77,7 @@ def test_together_provider_openai_client_caching():
def test_llama_compat_provider_openai_client_caching():
"""Ensure the LlamaCompat provider does not cache api keys across client requests"""
config = LlamaCompatConfig()
inference_adapter = LlamaCompatInferenceAdapter(config)
inference_adapter = LlamaCompatInferenceAdapter(config=config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (

View file

@ -18,7 +18,7 @@ class TestOpenAIBaseURLConfig:
def test_default_base_url_without_env_var(self):
"""Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
config = OpenAIConfig(api_key="test-key")
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == "https://api.openai.com/v1"
@ -27,7 +27,7 @@ class TestOpenAIBaseURLConfig:
"""Test that the adapter uses a custom base URL when provided in config."""
custom_url = "https://custom.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == custom_url
@ -39,7 +39,7 @@ class TestOpenAIBaseURLConfig:
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == "https://env.openai.com/v1"
@ -49,7 +49,7 @@ class TestOpenAIBaseURLConfig:
"""Test that explicit config value overrides environment variable."""
custom_url = "https://config.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Config should take precedence over environment variable
@ -60,7 +60,7 @@ class TestOpenAIBaseURLConfig:
"""Test that the OpenAI client is initialized with the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
@ -80,7 +80,7 @@ class TestOpenAIBaseURLConfig:
"""Test that check_model_availability uses the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method
@ -122,7 +122,7 @@ class TestOpenAIBaseURLConfig:
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method

View file

@ -5,45 +5,21 @@
# the root directory of this source tree.
import asyncio
import json
import time
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from openai.types.chat.chat_completion_chunk import (
Choice as OpenAIChoiceChunk,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
)
from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
ToolChoice,
UserMessage,
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import (
VLLMInferenceAdapter,
_process_vllm_chat_completion_stream_response,
)
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
# These are unit test for the remote vllm provider
# implementation. This should only contain tests which are specific to
@ -56,37 +32,15 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
# -v -s --tb=short --disable-warnings
@pytest.fixture(scope="module")
def mock_openai_models_list():
with patch("openai.resources.models.AsyncModels.list") as mock_list:
yield mock_list
@pytest.fixture(scope="function")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config)
inference_adapter = VLLMInferenceAdapter(config=config)
inference_adapter.model_store = AsyncMock()
# Mock the __provider_spec__ attribute that would normally be set by the resolver
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_type = "vllm-inference"
inference_adapter.__provider_spec__.provider_data_validator = MagicMock()
await inference_adapter.initialize()
return inference_adapter
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
async def mock_openai_models():
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
mock_openai_models_list.return_value = mock_openai_models()
foo_model = Model(identifier="foo", provider_resource_id="foo", provider_id="vllm-inference")
await vllm_inference_adapter.register_model(foo_model)
mock_openai_models_list.assert_called()
async def test_old_vllm_tool_choice(vllm_inference_adapter):
"""
Test that we set tool_choice to none when no tools are in use
@ -115,403 +69,6 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
assert call_args.kwargs["tool_choice"] == ToolChoice.none.value
async def test_tool_call_delta_empty_tool_call_buf():
"""
Test that we don't generate extra chunks when processing a
tool call response that didn't call any tools. Previously we would
emit chunks with spurious ToolCallParseStatus.succeeded or
ToolCallParseStatus.failed when processing chunks that didn't
actually make any tool calls.
"""
async def mock_stream():
delta = OpenAIChoiceDelta(content="", tool_calls=None)
choices = [OpenAIChoiceChunk(delta=delta, finish_reason="stop", index=0)]
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=choices,
)
for chunk in [mock_chunk]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "complete"
assert chunks[1].event.stop_reason == StopReason.end_of_turn
async def test_tool_call_delta_streaming_arguments_dict():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments="",
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "complete"
async def test_multiple_tool_calls():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=2,
function=OpenAIChoiceDeltaToolCallFunction(
name="multiple",
arguments='{"first_number": 4, "second_number": 7}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 4
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "progress"
assert chunks[2].event.delta.type == "tool_call"
assert chunks[2].event.delta.parse_status.value == "succeeded"
assert chunks[2].event.delta.tool_call.arguments == '{"first_number": 4, "second_number": 7}'
assert chunks[3].event.event_type.value == "complete"
async def test_process_vllm_chat_completion_stream_response_no_choices():
"""
Test that we don't error out when vLLM returns no choices for a
completion request. This can happen when there's an error thrown
in vLLM for example.
"""
async def mock_stream():
choices = []
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=choices,
)
for chunk in [mock_chunk]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 1
assert chunks[0].event.event_type.value == "start"
async def test_get_params_empty_tools(vllm_inference_adapter):
request = ChatCompletionRequest(
tools=[],
model="test_model",
messages=[UserMessage(content="test")],
)
params = await vllm_inference_adapter._get_params(request)
assert "tools" not in params
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
"""
Tests the edge case where the model returns the arguments for the tool call in the same chunk that
contains the finish reason (i.e., the last one).
We want to make sure the tool call is executed in this case, and the parameters are passed correctly.
"""
mock_tool_name = "mock_tool"
mock_tool_arguments = {"arg1": 0, "arg2": 100}
mock_tool_arguments_str = json.dumps(mock_tool_arguments)
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": None,
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {
"name": None,
"arguments": mock_tool_arguments_str,
},
}
],
},
"finish_reason": "tool_calls",
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments_str
async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
"""
Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
finish reason.
We want to make sure that this case is recognized and handled correctly, i.e., as a valid end of message.
"""
mock_tool_name = "mock_tool"
mock_tool_arguments = {"arg1": 0, "arg2": 100}
mock_tool_arguments_str = json.dumps(mock_tool_arguments)
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": mock_tool_arguments_str,
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments_str
async def test_process_vllm_chat_completion_stream_response_tool_without_args():
"""
Tests the edge case where no arguments are provided for the tool call.
Tool calls with no arguments should be treated as regular tool calls, which was not the case until now.
"""
mock_tool_name = "mock_tool"
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": "",
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == "{}"
async def test_health_status_success(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when the connection is successful.
@ -642,94 +199,30 @@ async def test_should_refresh_models():
# Test case 1: refresh_models is True, api_token is None
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
adapter1 = VLLMInferenceAdapter(config1)
adapter1 = VLLMInferenceAdapter(config=config1)
result1 = await adapter1.should_refresh_models()
assert result1 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 2: refresh_models is True, api_token is empty string
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
adapter2 = VLLMInferenceAdapter(config2)
adapter2 = VLLMInferenceAdapter(config=config2)
result2 = await adapter2.should_refresh_models()
assert result2 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 3: refresh_models is True, api_token is "fake" (default)
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
adapter3 = VLLMInferenceAdapter(config3)
adapter3 = VLLMInferenceAdapter(config=config3)
result3 = await adapter3.should_refresh_models()
assert result3 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 4: refresh_models is True, api_token is real token
config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
adapter4 = VLLMInferenceAdapter(config4)
adapter4 = VLLMInferenceAdapter(config=config4)
result4 = await adapter4.should_refresh_models()
assert result4 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 5: refresh_models is False, api_token is real token
config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
adapter5 = VLLMInferenceAdapter(config5)
adapter5 = VLLMInferenceAdapter(config=config5)
result5 = await adapter5.should_refresh_models()
assert result5 is False, "should_refresh_models should return False when refresh_models is False"
async def test_provider_data_var_context_propagation(vllm_inference_adapter):
"""
Test that PROVIDER_DATA_VAR context is properly propagated through the vLLM inference adapter.
This ensures that dynamic provider data (like API tokens) can be passed through context.
Note: The base URL is always taken from config.url, not from provider data.
"""
# Mock the AsyncOpenAI class to capture provider data
with (
patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class,
patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data,
):
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock()
mock_openai_class.return_value = mock_client
# Mock provider data to return test data
mock_provider_data = MagicMock()
mock_provider_data.vllm_api_token = "test-token-123"
mock_provider_data.vllm_url = "http://test-server:8000/v1"
mock_get_provider_data.return_value = mock_provider_data
# Mock the model
mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
try:
# Execute chat completion
await vllm_inference_adapter.openai_chat_completion(
model="test-model",
messages=[UserMessage(content="Hello")],
stream=False,
)
# Verify that ALL client calls were made with the correct parameters
calls = mock_openai_class.call_args_list
incorrect_calls = []
for i, call in enumerate(calls):
api_key = call[1]["api_key"]
base_url = call[1]["base_url"]
if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345":
incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url})
if incorrect_calls:
error_msg = (
f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n"
)
for incorrect_call in incorrect_calls:
error_msg += f" Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n"
error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'"
raise AssertionError(error_msg)
# Ensure at least one call was made
assert len(calls) >= 1, "No AsyncOpenAI client calls were made"
# Verify that chat completion was called
mock_client.chat.completions.create.assert_called_once()
finally:
# Clean up context
pass

View file

@ -13,6 +13,7 @@ from pydantic import BaseModel, Field
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -29,7 +30,7 @@ class OpenAIMixinImpl(OpenAIMixin):
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
"""Test implementation with embedding model metadata"""
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
}
@ -38,7 +39,8 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin with mocked model_store"""
mixin_instance = OpenAIMixinImpl()
config = RemoteInferenceProviderConfig()
mixin_instance = OpenAIMixinImpl(config=config)
# just enough to satisfy _get_provider_model_id calls
mock_model_store = MagicMock()
@ -53,7 +55,8 @@ def mixin():
@pytest.fixture
def mixin_with_embeddings():
"""Create a test instance of OpenAIMixin with embedding model metadata"""
return OpenAIMixinWithEmbeddingsImpl()
config = RemoteInferenceProviderConfig()
return OpenAIMixinWithEmbeddingsImpl(config=config)
@pytest.fixture
@ -504,7 +507,8 @@ class TestOpenAIMixinProviderDataApiKey:
@pytest.fixture
def mixin_with_provider_data_field(self):
"""Mixin instance with provider_data_api_key_field set"""
mixin_instance = OpenAIMixinWithProviderData()
config = RemoteInferenceProviderConfig()
mixin_instance = OpenAIMixinWithProviderData(config=config)
# Mock provider_spec for provider data validation
mock_provider_spec = MagicMock()