diff --git a/src/llama_stack/providers/remote/inference/passthrough/__init__.py b/src/llama_stack/providers/remote/inference/passthrough/__init__.py index 69dd4c461..1cc46bff1 100644 --- a/src/llama_stack/providers/remote/inference/passthrough/__init__.py +++ b/src/llama_stack/providers/remote/inference/passthrough/__init__.py @@ -10,8 +10,8 @@ from .config import PassthroughImplConfig class PassthroughProviderDataValidator(BaseModel): - url: str - api_key: str + passthrough_url: str + passthrough_api_key: str async def get_adapter_impl(config: PassthroughImplConfig, _deps): diff --git a/src/llama_stack/providers/remote/inference/passthrough/config.py b/src/llama_stack/providers/remote/inference/passthrough/config.py index f8e8b8ce5..eca28a86a 100644 --- a/src/llama_stack/providers/remote/inference/passthrough/config.py +++ b/src/llama_stack/providers/remote/inference/passthrough/config.py @@ -6,7 +6,7 @@ from typing import Any -from pydantic import Field, SecretStr +from pydantic import Field from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @@ -19,11 +19,6 @@ class PassthroughImplConfig(RemoteInferenceProviderConfig): description="The URL for the passthrough endpoint", ) - api_key: SecretStr | None = Field( - default=None, - description="API Key for the passthrouth endpoint", - ) - @classmethod def sample_run_config( cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs diff --git a/src/llama_stack/providers/remote/inference/passthrough/passthrough.py b/src/llama_stack/providers/remote/inference/passthrough/passthrough.py index 4d4d4f41d..81581a238 100644 --- a/src/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/src/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -4,10 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json from collections.abc import AsyncIterator -from typing import Any +from enum import Enum -from llama_stack_client import AsyncLlamaStackClient +import httpx +from pydantic import BaseModel from llama_stack.apis.inference import ( Inference, @@ -20,15 +22,13 @@ from llama_stack.apis.inference import ( OpenAIEmbeddingsResponse, ) from llama_stack.apis.models import Model -from llama_stack.core.library_client import convert_pydantic_to_json_value -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.core.request_headers import NeedsRequestProviderData from .config import PassthroughImplConfig -class PassthroughInferenceAdapter(Inference): +class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference): def __init__(self, config: PassthroughImplConfig) -> None: - ModelRegistryHelper.__init__(self) self.config = config async def unregister_model(self, model_id: str) -> None: @@ -37,86 +37,188 @@ class PassthroughInferenceAdapter(Inference): async def register_model(self, model: Model) -> Model: return model - def _get_client(self) -> AsyncLlamaStackClient: - passthrough_url = None - passthrough_api_key = None - provider_data = None + async def list_models(self) -> list[Model]: + """List models by calling the downstream /v1/models endpoint.""" + base_url = self._get_passthrough_url().rstrip("/") + api_key = self._get_passthrough_api_key() + url = f"{base_url}/v1/models" + headers = { + "Authorization": f"Bearer {api_key}", + "Accept": "application/json", + } + + async with httpx.AsyncClient() as client: + response = await client.get(url, headers=headers, timeout=30.0) + response.raise_for_status() + + response_data = response.json() + # The response should be a list of Model objects + return [Model.model_validate(m) for m in response_data] + + async def should_refresh_models(self) -> bool: + """Passthrough should refresh models since they come from downstream dynamically.""" + return self.config.refresh_models + + def _get_passthrough_url(self) -> str: + """Get the passthrough URL from config or provider data.""" if self.config.url is not None: - passthrough_url = self.config.url - else: - provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.passthrough_url: - raise ValueError( - 'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": }' - ) - passthrough_url = provider_data.passthrough_url + return self.config.url - if self.config.api_key is not None: - passthrough_api_key = self.config.api_key.get_secret_value() - else: - provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.passthrough_api_key: - raise ValueError( - 'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": }' - ) - passthrough_api_key = provider_data.passthrough_api_key + provider_data = self.get_request_provider_data() + if provider_data is None: + raise ValueError( + 'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": }' + ) + return provider_data.passthrough_url - return AsyncLlamaStackClient( - base_url=passthrough_url, - api_key=passthrough_api_key, - provider_data=provider_data, + def _get_passthrough_api_key(self) -> str: + """Get the passthrough API key from config or provider data.""" + if self.config.auth_credential is not None: + return self.config.auth_credential.get_secret_value() + + provider_data = self.get_request_provider_data() + if provider_data is None: + raise ValueError( + 'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": }' + ) + return provider_data.passthrough_api_key + + def _serialize_value(self, value): + """Convert Pydantic models and enums to JSON-serializable values.""" + if isinstance(value, BaseModel): + return json.loads(value.model_dump_json()) + elif isinstance(value, Enum): + return value.value + elif isinstance(value, list): + return [self._serialize_value(item) for item in value] + elif isinstance(value, dict): + return {k: self._serialize_value(v) for k, v in value.items()} + else: + return value + + async def _make_request( + self, + endpoint: str, + params: dict, + response_type: type, + stream: bool = False, + ): + """Make an HTTP request to the passthrough endpoint.""" + base_url = self._get_passthrough_url().rstrip("/") + api_key = self._get_passthrough_api_key() + + url = f"{base_url}/v1/{endpoint.lstrip('/')}" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "Accept": "application/json", + } + + # Serialize the request body + json_body = self._serialize_value(params) + + if stream: + return self._stream_request(url, headers, json_body, response_type) + else: + return await self._non_stream_request(url, headers, json_body, response_type) + + async def _non_stream_request( + self, + url: str, + headers: dict, + json_body: dict, + response_type: type, + ): + """Make a non-streaming HTTP request.""" + async with httpx.AsyncClient() as client: + response = await client.post(url, json=json_body, headers=headers, timeout=30.0) + response.raise_for_status() + + response_data = response.json() + return response_type.model_validate(response_data) + + async def _stream_request( + self, + url: str, + headers: dict, + json_body: dict, + response_type: type, + ) -> AsyncIterator: + """Make a streaming HTTP request with Server-Sent Events parsing.""" + async with httpx.AsyncClient() as client: + async with client.stream("POST", url, json=json_body, headers=headers, timeout=30.0) as response: + response.raise_for_status() + + async for line in response.aiter_lines(): + if line.startswith("data:"): + # Extract JSON after "data: " prefix + data_str = line[len("data:") :].strip() + + # Skip empty lines or "[DONE]" marker + if not data_str or data_str == "[DONE]": + continue + + try: + data = json.loads(data_str) + yield response_type.model_validate(data) + except Exception: + # Log and skip malformed chunks + continue + + async def openai_completion( + self, + params: OpenAICompletionRequestWithExtraBody, + ) -> OpenAICompletion: + model_obj = await self.model_store.get_model(params.model) + + # Create a copy with the provider's model ID + params = params.model_copy() + params.model = model_obj.provider_resource_id + + request_params = params.model_dump(exclude_none=True) + + return await self._make_request( + endpoint="completions", + params=request_params, + response_type=OpenAICompletion, + stream=params.stream or False, + ) + + async def openai_chat_completion( + self, + params: OpenAIChatCompletionRequestWithExtraBody, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + model_obj = await self.model_store.get_model(params.model) + + # Create a copy with the provider's model ID + params = params.model_copy() + params.model = model_obj.provider_resource_id + + request_params = params.model_dump(exclude_none=True) + + return await self._make_request( + endpoint="chat/completions", + params=request_params, + response_type=OpenAIChatCompletionChunk if params.stream else OpenAIChatCompletion, + stream=params.stream or False, ) async def openai_embeddings( self, params: OpenAIEmbeddingsRequestWithExtraBody, ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - - async def openai_completion( - self, - params: OpenAICompletionRequestWithExtraBody, - ) -> OpenAICompletion: - client = self._get_client() model_obj = await self.model_store.get_model(params.model) + # Create a copy with the provider's model ID params = params.model_copy() params.model = model_obj.provider_resource_id request_params = params.model_dump(exclude_none=True) - return await client.inference.openai_completion(**request_params) - - async def openai_chat_completion( - self, - params: OpenAIChatCompletionRequestWithExtraBody, - ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - client = self._get_client() - model_obj = await self.model_store.get_model(params.model) - - params = params.model_copy() - params.model = model_obj.provider_resource_id - - request_params = params.model_dump(exclude_none=True) - - return await client.inference.openai_chat_completion(**request_params) - - def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]: - json_params = {} - for key, value in request_params.items(): - json_input = convert_pydantic_to_json_value(value) - if isinstance(json_input, dict): - json_input = {k: v for k, v in json_input.items() if v is not None} - elif isinstance(json_input, list): - json_input = [x for x in json_input if x is not None] - new_input = [] - for x in json_input: - if isinstance(x, dict): - x = {k: v for k, v in x.items() if v is not None} - new_input.append(x) - json_input = new_input - - json_params[key] = json_input - - return json_params + return await self._make_request( + endpoint="embeddings", + params=request_params, + response_type=OpenAIEmbeddingsResponse, + stream=False, + )