Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)
feat: Add dynamic authentication token forwarding support for vLLM (#3388)
# What does this PR do?

*Add dynamic authentication token forwarding support for the vLLM provider.*

This enables per-request authentication tokens for the vLLM provider, supporting use cases like RAG operations where different requests may need different authentication tokens. The implementation follows the same pattern as other providers such as Together AI, Fireworks, and Passthrough.

- Add `LiteLLMOpenAIMixin` so that `vllm_api_token` is managed properly.

Usage:

- Static: the `VLLM_API_TOKEN` environment variable or `config.api_token`
- Dynamic: the `X-LlamaStack-Provider-Data` header with `vllm_api_token`

All existing functionality is preserved while adding the new dynamic capability.

## Test Plan

```
curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Authorization: Bearer my-dynamic-token" \
  -H "X-LlamaStack-Provider-Data: {\"vllm_api_token\": \"Bearer my-dynamic-token\", \"vllm_url\": \"http://dynamic-server:8000\"}" \
  -H "Content-Type: application/json" \
  -d '{"model": "llama-3.1-8b", "messages": [{"role": "user", "content": "Hello!"}]}'
```

---------

Signed-off-by: Akram Ben Aissi <akram.benaissi@gmail.com>
This commit is contained in: parent 42c23b45f6, commit 4842145202.

4 changed files with 219 additions and 48 deletions.
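For reference, the Test Plan's curl call can also be expressed in Python. This is a hedged sketch that relies only on the header contract described in the PR body; the endpoint, model name, and tokens are the placeholder values from the Test Plan, not verified values:

```python
# Hedged Python equivalent of the curl example: the per-request token travels
# in the X-LlamaStack-Provider-Data header as a JSON object.
import json

import httpx

provider_data = {
    "vllm_api_token": "Bearer my-dynamic-token",  # dynamic token for this request
    "vllm_url": "http://dynamic-server:8000",
}

response = httpx.post(
    "http://localhost:8000/v1/chat/completions",
    headers={
        "Authorization": "Bearer my-dynamic-token",
        "X-LlamaStack-Provider-Data": json.dumps(provider_data),
        "Content-Type": "application/json",
    },
    json={
        "model": "llama-3.1-8b",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
response.raise_for_status()
print(response.json())
```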
Provider registry — register a provider data validator for the remote vLLM provider:

```diff
@@ -78,6 +78,7 @@ def available_providers() -> list[ProviderSpec]:
                 pip_packages=[],
                 module="llama_stack.providers.remote.inference.vllm",
                 config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
+                provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
                 description="Remote vLLM inference provider for connecting to vLLM servers.",
             ),
         ),
```
vLLM provider `__init__.py` — add the `VLLMProviderDataValidator`:

```diff
@@ -4,9 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from pydantic import BaseModel
+
 from .config import VLLMInferenceAdapterConfig
 
 
+class VLLMProviderDataValidator(BaseModel):
+    vllm_api_token: str | None = None
+
+
 async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
     from .vllm import VLLMInferenceAdapter
 
```
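As an aside (not part of the diff), the new validator simply parses the optional `vllm_api_token` field out of the provider-data JSON. A minimal sketch, assuming pydantic v2's `model_validate`:

```python
from pydantic import BaseModel


class VLLMProviderDataValidator(BaseModel):
    vllm_api_token: str | None = None


# Payload as it would arrive in the X-LlamaStack-Provider-Data header (already JSON-decoded).
data = VLLMProviderDataValidator.model_validate({"vllm_api_token": "Bearer my-dynamic-token"})
assert data.vllm_api_token == "Bearer my-dynamic-token"

# The field is optional, so requests without a dynamic token fall back to the static config token.
assert VLLMProviderDataValidator.model_validate({}).vllm_api_token is None
```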
Remote vLLM inference adapter (`vllm.py`), across several hunks:

```diff
@@ -4,8 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
+from urllib.parse import urljoin
 
 import httpx
 from openai import APIConnectionError, AsyncOpenAI
```
```diff
@@ -55,6 +56,7 @@ from llama_stack.providers.datatypes import (
     HealthStatus,
     ModelsProtocolPrivate,
 )
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_hf_repo_model_entry,
```
```diff
@@ -62,6 +64,7 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     UnparseableToolCall,
     convert_message_to_openai_dict,
+    convert_openai_chat_completion_stream,
     convert_tool_call,
     get_sampling_options,
     process_chat_completion_stream_response,
```
```diff
@@ -281,15 +284,31 @@ async def _process_vllm_chat_completion_stream_response(
         yield c
 
 
-class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None
 
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            build_hf_repo_model_entries(),
+            litellm_provider_name="vllm",
+            api_key_from_config=config.api_token,
+            provider_data_api_key_field="vllm_api_token",
+            openai_compat_api_base=config.url,
+        )
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
 
+    get_api_key = LiteLLMOpenAIMixin.get_api_key
+
+    def get_base_url(self) -> str:
+        """Get the base URL from config."""
+        if not self.config.url:
+            raise ValueError("No base URL configured")
+        return self.config.url
+
     async def initialize(self) -> None:
         if not self.config.url:
             raise ValueError(
```
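Aside: the effect of wiring `get_api_key = LiteLLMOpenAIMixin.get_api_key` with `provider_data_api_key_field="vllm_api_token"` is a per-request-first token lookup. A hedged sketch of that order follows; the `resolve_api_token` helper is illustrative only, not the mixin's actual code:

```python
from typing import Any


def resolve_api_token(provider_data: Any, config_api_token: str | None) -> str | None:
    # 1. Prefer the per-request token carried in the X-LlamaStack-Provider-Data header.
    dynamic = getattr(provider_data, "vllm_api_token", None) if provider_data else None
    if dynamic:
        return dynamic
    # 2. Otherwise fall back to the static token (VLLM_API_TOKEN env var / config.api_token).
    return config_api_token
```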
```diff
@@ -297,6 +316,7 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
             )
 
     async def should_refresh_models(self) -> bool:
+        # Strictly respecting the refresh_models directive
         return self.config.refresh_models
 
     async def list_models(self) -> list[Model] | None:
```
```diff
@@ -325,13 +345,19 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
         Performs a health check by verifying connectivity to the remote vLLM server.
         This method is used by the Provider API to verify
         that the service is running correctly.
+        Uses the unauthenticated /health endpoint.
         Returns:
 
             HealthResponse: A dictionary containing the health status.
         """
         try:
-            _ = [m async for m in self.client.models.list()]  # Ensure the client is initialized
-            return HealthResponse(status=HealthStatus.OK)
+            base_url = self.get_base_url()
+            health_url = urljoin(base_url, "health")
+
+            async with httpx.AsyncClient() as client:
+                response = await client.get(health_url)
+                response.raise_for_status()
+                return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
 
```
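Aside on the new health check: `urljoin` replaces the last path segment when the base URL has no trailing slash, so a configured base of `http://host:8000/v1` probes `http://host:8000/health`, the unauthenticated health route mentioned in the docstring. A standalone sketch of the same probe, with a placeholder URL:

```python
import asyncio
from urllib.parse import urljoin

import httpx


async def probe_vllm_health(base_url: str) -> bool:
    # urljoin("http://host:8000/v1", "health") -> "http://host:8000/health"
    health_url = urljoin(base_url, "health")
    async with httpx.AsyncClient() as client:
        response = await client.get(health_url)
        return response.status_code == 200


if __name__ == "__main__":
    # Point this at the same URL used for config.url.
    print(asyncio.run(probe_vllm_health("http://localhost:8000/v1")))
```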
```diff
@@ -340,16 +366,10 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
             raise ValueError("Model store not set")
         return await self.model_store.get_model(model_id)
 
-    def get_api_key(self):
-        return self.config.api_token
-
-    def get_base_url(self):
-        return self.config.url
-
     def get_extra_client_params(self):
         return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
 
-    async def completion(
+    async def completion(  # type: ignore[override] # Return type more specific than base class which is allows for both streaming and non-streaming responses.
         self,
         model_id: str,
         content: InterleavedContent,
```
```diff
@@ -411,13 +431,14 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
             tool_config=tool_config,
         )
         if stream:
-            return self._stream_chat_completion(request, self.client)
+            return self._stream_chat_completion_with_client(request, self.client)
         else:
             return await self._nonstream_chat_completion(request, self.client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> ChatCompletionResponse:
+        assert self.client is not None
         params = await self._get_params(request)
         r = await client.chat.completions.create(**params)
         choice = r.choices[0]
```
```diff
@@ -431,9 +452,24 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
         )
         return result
 
-    async def _stream_chat_completion(
+    async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
+        # This method is called from LiteLLMOpenAIMixin.chat_completion
+        # The response parameter contains the litellm response
+        # We need to convert it to our format
+        async def _stream_generator():
+            async for chunk in response:
+                yield chunk
+
+        async for chunk in convert_openai_chat_completion_stream(
+            _stream_generator(), enable_incremental_tool_calls=True
+        ):
+            yield chunk
+
+    async def _stream_chat_completion_with_client(
         self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
+        """Helper method for streaming with explicit client parameter."""
+        assert self.client is not None
         params = await self._get_params(request)
 
         stream = await client.chat.completions.create(**params)
```
```diff
@@ -445,7 +481,8 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
-        assert self.client is not None
+        if self.client is None:
+            raise RuntimeError("Client is not initialized")
         params = await self._get_params(request)
         r = await self.client.completions.create(**params)
         return process_completion_response(r)
```
```diff
@@ -453,7 +490,8 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
     async def _stream_completion(
         self, request: CompletionRequest
     ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
-        assert self.client is not None
+        if self.client is None:
+            raise RuntimeError("Client is not initialized")
         params = await self._get_params(request)
 
         stream = await self.client.completions.create(**params)
```
Unit tests for the remote vLLM adapter:

```diff
@@ -66,11 +66,15 @@ def mock_openai_models_list():
         yield mock_list
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 async def vllm_inference_adapter():
     config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
     inference_adapter = VLLMInferenceAdapter(config)
     inference_adapter.model_store = AsyncMock()
+    # Mock the __provider_spec__ attribute that would normally be set by the resolver
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_type = "vllm-inference"
+    inference_adapter.__provider_spec__.provider_data_validator = MagicMock()
     await inference_adapter.initialize()
     return inference_adapter
 
```
```diff
@@ -120,6 +124,10 @@ async def test_tool_call_response(vllm_inference_adapter):
         mock_client.chat.completions.create = AsyncMock()
         mock_create_client.return_value = mock_client
 
+        # Mock the model to return a proper provider_resource_id
+        mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
+        vllm_inference_adapter.model_store.get_model.return_value = mock_model
+
         messages = [
             SystemMessage(content="You are a helpful assistant"),
             UserMessage(content="How many?"),
```
```diff
@@ -555,31 +563,29 @@ async def test_health_status_success(vllm_inference_adapter):
     """
     Test the health method of VLLM InferenceAdapter when the connection is successful.
 
-    This test verifies that the health method returns a HealthResponse with status OK, only
-    when the connection to the vLLM server is successful.
+    This test verifies that the health method returns a HealthResponse with status OK
+    when the /health endpoint responds successfully.
     """
-    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
-        # Create mock client and models
-        mock_client = MagicMock()
-        mock_models = MagicMock()
-
-        # Create a mock async iterator that yields a model when iterated
-        async def mock_list():
-            for model in [MagicMock()]:
-                yield model
-
-        # Set up the models.list to return our mock async iterator
-        mock_models.list.return_value = mock_list()
-        mock_client.models = mock_models
-        mock_create_client.return_value = mock_client
+    with patch("httpx.AsyncClient") as mock_client_class:
+        # Create mock response
+        mock_response = MagicMock()
+        mock_response.raise_for_status.return_value = None
+
+        # Create mock client instance
+        mock_client_instance = MagicMock()
+        mock_client_instance.get = AsyncMock(return_value=mock_response)
+        mock_client_class.return_value.__aenter__.return_value = mock_client_instance
 
         # Call the health method
         health_response = await vllm_inference_adapter.health()
 
         # Verify the response
         assert health_response["status"] == HealthStatus.OK
 
-        # Verify that models.list was called
-        mock_models.list.assert_called_once()
+        # Verify that the health endpoint was called
+        mock_client_instance.get.assert_called_once()
+        call_args = mock_client_instance.get.call_args[0]
+        assert call_args[0].endswith("/health")
 
 
 async def test_health_status_failure(vllm_inference_adapter):
```
```diff
@@ -589,28 +595,42 @@ async def test_health_status_failure(vllm_inference_adapter):
     This test verifies that the health method returns a HealthResponse with status ERROR
     and an appropriate error message when the connection to the vLLM server fails.
     """
-    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
-        # Create mock client and models
-        mock_client = MagicMock()
-        mock_models = MagicMock()
-
-        # Create a mock async iterator that raises an exception when iterated
-        async def mock_list():
-            raise Exception("Connection failed")
-            yield  # Unreachable code
-
-        # Set up the models.list to return our mock async iterator
-        mock_models.list.return_value = mock_list()
-        mock_client.models = mock_models
-        mock_create_client.return_value = mock_client
+    with patch("httpx.AsyncClient") as mock_client_class:
+        # Create mock client instance that raises an exception
+        mock_client_instance = MagicMock()
+        mock_client_instance.get.side_effect = Exception("Connection failed")
+        mock_client_class.return_value.__aenter__.return_value = mock_client_instance
 
         # Call the health method
         health_response = await vllm_inference_adapter.health()
 
         # Verify the response
         assert health_response["status"] == HealthStatus.ERROR
         assert "Health check failed: Connection failed" in health_response["message"]
-
-        mock_models.list.assert_called_once()
+
+
+async def test_health_status_no_static_api_key(vllm_inference_adapter):
+    """
+    Test the health method of VLLM InferenceAdapter when no static API key is provided.
+
+    This test verifies that the health method returns a HealthResponse with status OK
+    when the /health endpoint responds successfully, regardless of API token configuration.
+    """
+    with patch("httpx.AsyncClient") as mock_client_class:
+        # Create mock response
+        mock_response = MagicMock()
+        mock_response.raise_for_status.return_value = None
+
+        # Create mock client instance
+        mock_client_instance = MagicMock()
+        mock_client_instance.get = AsyncMock(return_value=mock_response)
+        mock_client_class.return_value.__aenter__.return_value = mock_client_instance
+
+        # Call the health method
+        health_response = await vllm_inference_adapter.health()
+
+        # Verify the response
+        assert health_response["status"] == HealthStatus.OK
 
 
 async def test_openai_chat_completion_is_async(vllm_inference_adapter):
```
```diff
@@ -656,3 +676,109 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
 
     assert mock_create_client.call_count == 4  # no cheating
     assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
+
+
+async def test_should_refresh_models():
+    """
+    Test the should_refresh_models method with different refresh_models configurations.
+
+    This test verifies that:
+    1. When refresh_models is True, should_refresh_models returns True regardless of api_token
+    2. When refresh_models is False, should_refresh_models returns False regardless of api_token
+    """
+
+    # Test case 1: refresh_models is True, api_token is None
+    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
+    adapter1 = VLLMInferenceAdapter(config1)
+    result1 = await adapter1.should_refresh_models()
+    assert result1 is True, "should_refresh_models should return True when refresh_models is True"
+
+    # Test case 2: refresh_models is True, api_token is empty string
+    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
+    adapter2 = VLLMInferenceAdapter(config2)
+    result2 = await adapter2.should_refresh_models()
+    assert result2 is True, "should_refresh_models should return True when refresh_models is True"
+
+    # Test case 3: refresh_models is True, api_token is "fake" (default)
+    config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
+    adapter3 = VLLMInferenceAdapter(config3)
+    result3 = await adapter3.should_refresh_models()
+    assert result3 is True, "should_refresh_models should return True when refresh_models is True"
+
+    # Test case 4: refresh_models is True, api_token is real token
+    config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
+    adapter4 = VLLMInferenceAdapter(config4)
+    result4 = await adapter4.should_refresh_models()
+    assert result4 is True, "should_refresh_models should return True when refresh_models is True"
+
+    # Test case 5: refresh_models is False, api_token is real token
+    config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
+    adapter5 = VLLMInferenceAdapter(config5)
+    result5 = await adapter5.should_refresh_models()
+    assert result5 is False, "should_refresh_models should return False when refresh_models is False"
+
+
+async def test_provider_data_var_context_propagation(vllm_inference_adapter):
+    """
+    Test that PROVIDER_DATA_VAR context is properly propagated through the vLLM inference adapter.
+    This ensures that dynamic provider data (like API tokens) can be passed through context.
+    Note: The base URL is always taken from config.url, not from provider data.
+    """
+    # Mock the AsyncOpenAI class to capture provider data
+    with (
+        patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class,
+        patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data,
+    ):
+        mock_client = AsyncMock()
+        mock_client.chat.completions.create = AsyncMock()
+        mock_openai_class.return_value = mock_client
+
+        # Mock provider data to return test data
+        mock_provider_data = MagicMock()
+        mock_provider_data.vllm_api_token = "test-token-123"
+        mock_provider_data.vllm_url = "http://test-server:8000/v1"
+        mock_get_provider_data.return_value = mock_provider_data
+
+        # Mock the model
+        mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference")
+        vllm_inference_adapter.model_store.get_model.return_value = mock_model
+
+        try:
+            # Execute chat completion
+            await vllm_inference_adapter.chat_completion(
+                "test-model",
+                [UserMessage(content="Hello")],
+                stream=False,
+                tools=None,
+                tool_config=ToolConfig(tool_choice=ToolChoice.auto),
+            )
+
+            # Verify that ALL client calls were made with the correct parameters
+            calls = mock_openai_class.call_args_list
+            incorrect_calls = []
+
+            for i, call in enumerate(calls):
+                api_key = call[1]["api_key"]
+                base_url = call[1]["base_url"]
+
+                if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345":
+                    incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url})
+
+            if incorrect_calls:
+                error_msg = (
+                    f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n"
+                )
+                for incorrect_call in incorrect_calls:
+                    error_msg += f"  Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n"
+                error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'"
+                raise AssertionError(error_msg)
+
+            # Ensure at least one call was made
+            assert len(calls) >= 1, "No AsyncOpenAI client calls were made"
+
+            # Verify that chat completion was called
+            mock_client.chat.completions.create.assert_called_once()
+
+        finally:
+            # Clean up context
+            pass
```