featu: support passing "extra body" throught to providers

# What does this PR do?
Allows passing through extra_body parameters to inference providers.


closes #2720

## Test Plan
CI and added new test
This commit is contained in:
Eric Huang 2025-10-10 15:40:30 -07:00
parent cb7fb0705b
commit 9f50338a4e
35 changed files with 1892 additions and 199 deletions

View file

@ -33,7 +33,7 @@ from llama_stack.apis.agents.openai_responses import (
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionRequest,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIDeveloperMessageParam,
OpenAIJSONSchema,
OpenAIResponseFormatJSONObject,
@ -162,7 +162,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
chunks = [chunk async for chunk in result]
mock_inference_api.openai_chat_completion.assert_called_once_with(
OpenAIChatCompletionRequest(
OpenAIChatCompletionRequestWithExtraBody(
model=model,
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
response_format=None,

View file

@ -13,11 +13,16 @@ import pytest
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionRequest,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChoice,
OpenAICompletion,
OpenAICompletionChoice,
OpenAICompletionRequestWithExtraBody,
ToolChoice,
)
from llama_stack.apis.models import Model
from llama_stack.core.routers.inference import InferenceRouter
from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
@ -57,7 +62,7 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
mock_client_property.return_value = mock_client
# No tools but auto tool choice
params = OpenAIChatCompletionRequest(
params = OpenAIChatCompletionRequestWithExtraBody(
model="mock-model",
messages=[{"role": "user", "content": "test"}],
stream=False,
@ -173,7 +178,7 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
)
async def do_inference():
params = OpenAIChatCompletionRequest(
params = OpenAIChatCompletionRequestWithExtraBody(
model="mock-model",
messages=[{"role": "user", "content": "one fish two fish"}],
stream=False,
@ -191,3 +196,148 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
assert mock_create_client.call_count == 4 # no cheating
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
async def test_vllm_completion_extra_body():
"""
Test that vLLM-specific guided_choice and prompt_logprobs parameters are correctly forwarded
via extra_body to the underlying OpenAI client through the InferenceRouter.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()
# Create a mock model store
mock_model_store = AsyncMock()
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
mock_model_store.get_model.return_value = mock_model
mock_model_store.has_model.return_value = True
# Create a mock dist_registry
mock_dist_registry = MagicMock()
mock_dist_registry.get = AsyncMock(return_value=mock_model)
mock_dist_registry.set = AsyncMock()
# Set up the routing table
routing_table = ModelsRoutingTable(
impls_by_provider_id={"vllm": vllm_adapter},
dist_registry=mock_dist_registry,
policy=[],
)
# Inject the model store into the adapter
vllm_adapter.model_store = routing_table
# Create the InferenceRouter
router = InferenceRouter(routing_table=routing_table)
# Patch the OpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client.completions.create = AsyncMock(
return_value=OpenAICompletion(
id="cmpl-abc123",
created=1,
model="mock-model",
choices=[
OpenAICompletionChoice(
text="joy",
finish_reason="stop",
index=0,
)
],
)
)
mock_client_property.return_value = mock_client
# Test with guided_choice and prompt_logprobs as extra fields
params = OpenAICompletionRequestWithExtraBody(
model="mock-model",
prompt="I am feeling happy",
stream=False,
guided_choice=["joy", "sadness"],
prompt_logprobs=5,
)
await router.openai_completion(params)
# Verify that the client was called with extra_body containing both parameters
mock_client.completions.create.assert_called_once()
call_kwargs = mock_client.completions.create.call_args.kwargs
assert "extra_body" in call_kwargs
assert "guided_choice" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["guided_choice"] == ["joy", "sadness"]
assert "prompt_logprobs" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["prompt_logprobs"] == 5
async def test_vllm_chat_completion_extra_body():
"""
Test that vLLM-specific parameters (e.g., chat_template_kwargs) are correctly forwarded
via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()
# Create a mock model store
mock_model_store = AsyncMock()
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
mock_model_store.get_model.return_value = mock_model
mock_model_store.has_model.return_value = True
# Create a mock dist_registry
mock_dist_registry = MagicMock()
mock_dist_registry.get = AsyncMock(return_value=mock_model)
mock_dist_registry.set = AsyncMock()
# Set up the routing table
routing_table = ModelsRoutingTable(
impls_by_provider_id={"vllm": vllm_adapter},
dist_registry=mock_dist_registry,
policy=[],
)
# Inject the model store into the adapter
vllm_adapter.model_store = routing_table
# Create the InferenceRouter
router = InferenceRouter(routing_table=routing_table)
# Patch the OpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
return_value=OpenAIChatCompletion(
id="chatcmpl-abc123",
created=1,
model="mock-model",
choices=[
OpenAIChoice(
message=OpenAIAssistantMessageParam(
content="test response",
),
finish_reason="stop",
index=0,
)
],
)
)
mock_client_property.return_value = mock_client
# Test with chat_template_kwargs as extra field
params = OpenAIChatCompletionRequestWithExtraBody(
model="mock-model",
messages=[{"role": "user", "content": "test"}],
stream=False,
chat_template_kwargs={"thinking": True},
)
await router.openai_chat_completion(params)
# Verify that the client was called with extra_body containing chat_template_kwargs
mock_client.chat.completions.create.assert_called_once()
call_kwargs = mock_client.chat.completions.create.call_args.kwargs
assert "extra_body" in call_kwargs
assert "chat_template_kwargs" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}

View file

@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import pytest
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Model, OpenAIChatCompletionRequest, OpenAIUserMessageParam
from llama_stack.apis.inference import Model, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
@ -271,7 +271,7 @@ class TestOpenAIMixinImagePreprocessing:
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg")
params = OpenAIChatCompletionRequest(model="test-model", messages=[message])
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_called_once_with("http://example.com/image.jpg")
@ -304,7 +304,7 @@ class TestOpenAIMixinImagePreprocessing:
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
params = OpenAIChatCompletionRequest(model="test-model", messages=[message])
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_not_called()