refactor(passthrough): replace client with httpx, simplify implementation

- Use httpx directly instead of AsyncLlamaStackClient
- Remove ModelRegistryHelper dependency (unused)
- Consolidate to auth_credential field via RemoteInferenceProviderConfig
- Implement list_models() to fetch from downstream /v1/models
- Implement all inference methods (completion, chat, embeddings)
- Fix provider data validator field names
- Add SSE parsing for streaming responses
Author: Ashwin Bharambe
Date: 2025-11-05 13:47:42 -08:00
commit 1ff6eeb434 (parent c899b50723)
3 changed files with 177 additions and 80 deletions

__init__.py

@@ -10,8 +10,8 @@ from .config import PassthroughImplConfig

 class PassthroughProviderDataValidator(BaseModel):
-    url: str
-    api_key: str
+    passthrough_url: str
+    passthrough_api_key: str


 async def get_adapter_impl(config: PassthroughImplConfig, _deps):
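
After the rename, per-request endpoint configuration travels in the X-LlamaStack-Provider-Data header under the new field names. A minimal sketch of building that header on the client side (the URL and key are placeholder values):

import json

provider_data = {
    "passthrough_url": "http://localhost:8321",  # placeholder downstream URL
    "passthrough_api_key": "my-api-key",         # placeholder credential
}
headers = {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}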

config.py

@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type

@@ -19,11 +19,6 @@ class PassthroughImplConfig(RemoteInferenceProviderConfig):
         description="The URL for the passthrough endpoint",
     )

-    api_key: SecretStr | None = Field(
-        default=None,
-        description="API Key for the passthrouth endpoint",
-    )
-
     @classmethod
     def sample_run_config(
         cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
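
With the local api_key field gone, the credential now lives on the auth_credential field inherited from RemoteInferenceProviderConfig (a SecretStr, judging by the get_secret_value() call in the adapter). A minimal sketch of constructing the config directly, with placeholder values:

from pydantic import SecretStr

config = PassthroughImplConfig(
    url="http://localhost:8321",              # placeholder endpoint
    auth_credential=SecretStr("my-api-key"),  # replaces the removed api_key field
)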

passthrough.py

@@ -4,10 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import json
 from collections.abc import AsyncIterator
-from typing import Any
+from enum import Enum

-from llama_stack_client import AsyncLlamaStackClient
+import httpx
+from pydantic import BaseModel

 from llama_stack.apis.inference import (
     Inference,
@@ -20,15 +22,13 @@ from llama_stack.apis.inference import (
     OpenAIEmbeddingsResponse,
 )
 from llama_stack.apis.models import Model
-from llama_stack.core.library_client import convert_pydantic_to_json_value
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.core.request_headers import NeedsRequestProviderData

 from .config import PassthroughImplConfig


-class PassthroughInferenceAdapter(Inference):
+class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
     def __init__(self, config: PassthroughImplConfig) -> None:
-        ModelRegistryHelper.__init__(self)
         self.config = config

     async def unregister_model(self, model_id: str) -> None:
@@ -37,86 +37,188 @@ class PassthroughInferenceAdapter(Inference):
     async def register_model(self, model: Model) -> Model:
         return model

-    def _get_client(self) -> AsyncLlamaStackClient:
-        passthrough_url = None
-        passthrough_api_key = None
-        provider_data = None
-
-        if self.config.url is not None:
-            passthrough_url = self.config.url
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.passthrough_url:
-                raise ValueError(
-                    'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
-                )
-            passthrough_url = provider_data.passthrough_url
-
-        if self.config.api_key is not None:
-            passthrough_api_key = self.config.api_key.get_secret_value()
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.passthrough_api_key:
-                raise ValueError(
-                    'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
-                )
-            passthrough_api_key = provider_data.passthrough_api_key
-
-        return AsyncLlamaStackClient(
-            base_url=passthrough_url,
-            api_key=passthrough_api_key,
-            provider_data=provider_data,
-        )
-
-    async def openai_embeddings(
-        self,
-        params: OpenAIEmbeddingsRequestWithExtraBody,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
-
-    async def openai_completion(
-        self,
-        params: OpenAICompletionRequestWithExtraBody,
-    ) -> OpenAICompletion:
-        client = self._get_client()
-        model_obj = await self.model_store.get_model(params.model)
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
-        request_params = params.model_dump(exclude_none=True)
-        return await client.inference.openai_completion(**request_params)
-
-    async def openai_chat_completion(
-        self,
-        params: OpenAIChatCompletionRequestWithExtraBody,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        client = self._get_client()
-        model_obj = await self.model_store.get_model(params.model)
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
-        request_params = params.model_dump(exclude_none=True)
-        return await client.inference.openai_chat_completion(**request_params)
-
-    def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
-        json_params = {}
-        for key, value in request_params.items():
-            json_input = convert_pydantic_to_json_value(value)
-            if isinstance(json_input, dict):
-                json_input = {k: v for k, v in json_input.items() if v is not None}
-            elif isinstance(json_input, list):
-                json_input = [x for x in json_input if x is not None]
-                new_input = []
-                for x in json_input:
-                    if isinstance(x, dict):
-                        x = {k: v for k, v in x.items() if v is not None}
-                    new_input.append(x)
-                json_input = new_input
-            json_params[key] = json_input
-        return json_params
+    async def list_models(self) -> list[Model]:
+        """List models by calling the downstream /v1/models endpoint."""
+        base_url = self._get_passthrough_url().rstrip("/")
+        api_key = self._get_passthrough_api_key()
+
+        url = f"{base_url}/v1/models"
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Accept": "application/json",
+        }
+
+        async with httpx.AsyncClient() as client:
+            response = await client.get(url, headers=headers, timeout=30.0)
+            response.raise_for_status()
+            response_data = response.json()
+
+        # The response should be a list of Model objects
+        return [Model.model_validate(m) for m in response_data]
+
+    async def should_refresh_models(self) -> bool:
+        """Passthrough should refresh models since they come from downstream dynamically."""
+        return self.config.refresh_models
+
+    def _get_passthrough_url(self) -> str:
+        """Get the passthrough URL from config or provider data."""
+        if self.config.url is not None:
+            return self.config.url
+
+        provider_data = self.get_request_provider_data()
+        if provider_data is None:
+            raise ValueError(
+                'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
+            )
+        return provider_data.passthrough_url
+
+    def _get_passthrough_api_key(self) -> str:
+        """Get the passthrough API key from config or provider data."""
+        if self.config.auth_credential is not None:
+            return self.config.auth_credential.get_secret_value()
+
+        provider_data = self.get_request_provider_data()
+        if provider_data is None:
+            raise ValueError(
+                'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
+            )
+        return provider_data.passthrough_api_key
+
+    def _serialize_value(self, value):
+        """Convert Pydantic models and enums to JSON-serializable values."""
+        if isinstance(value, BaseModel):
+            return json.loads(value.model_dump_json())
+        elif isinstance(value, Enum):
+            return value.value
+        elif isinstance(value, list):
+            return [self._serialize_value(item) for item in value]
+        elif isinstance(value, dict):
+            return {k: self._serialize_value(v) for k, v in value.items()}
+        else:
+            return value
+
+    async def _make_request(
+        self,
+        endpoint: str,
+        params: dict,
+        response_type: type,
+        stream: bool = False,
+    ):
+        """Make an HTTP request to the passthrough endpoint."""
+        base_url = self._get_passthrough_url().rstrip("/")
+        api_key = self._get_passthrough_api_key()
+
+        url = f"{base_url}/v1/{endpoint.lstrip('/')}"
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+
+        # Serialize the request body
+        json_body = self._serialize_value(params)
+
+        if stream:
+            return self._stream_request(url, headers, json_body, response_type)
+        else:
+            return await self._non_stream_request(url, headers, json_body, response_type)
+
+    async def _non_stream_request(
+        self,
+        url: str,
+        headers: dict,
+        json_body: dict,
+        response_type: type,
+    ):
+        """Make a non-streaming HTTP request."""
+        async with httpx.AsyncClient() as client:
+            response = await client.post(url, json=json_body, headers=headers, timeout=30.0)
+            response.raise_for_status()
+            response_data = response.json()
+            return response_type.model_validate(response_data)
+
+    async def _stream_request(
+        self,
+        url: str,
+        headers: dict,
+        json_body: dict,
+        response_type: type,
+    ) -> AsyncIterator:
+        """Make a streaming HTTP request with Server-Sent Events parsing."""
+        async with httpx.AsyncClient() as client:
+            async with client.stream("POST", url, json=json_body, headers=headers, timeout=30.0) as response:
+                response.raise_for_status()
+                async for line in response.aiter_lines():
+                    if line.startswith("data:"):
+                        # Extract JSON after "data: " prefix
+                        data_str = line[len("data:") :].strip()
+                        # Skip empty lines or "[DONE]" marker
+                        if not data_str or data_str == "[DONE]":
+                            continue
+                        try:
+                            data = json.loads(data_str)
+                            yield response_type.model_validate(data)
+                        except Exception:
+                            # Skip malformed chunks
+                            continue
+
+    async def openai_completion(
+        self,
+        params: OpenAICompletionRequestWithExtraBody,
+    ) -> OpenAICompletion:
+        model_obj = await self.model_store.get_model(params.model)
+        # Create a copy with the provider's model ID
+        params = params.model_copy()
+        params.model = model_obj.provider_resource_id
+        request_params = params.model_dump(exclude_none=True)
+        return await self._make_request(
+            endpoint="completions",
+            params=request_params,
+            response_type=OpenAICompletion,
+            stream=params.stream or False,
+        )
+
+    async def openai_chat_completion(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        model_obj = await self.model_store.get_model(params.model)
+        # Create a copy with the provider's model ID
+        params = params.model_copy()
+        params.model = model_obj.provider_resource_id
+        request_params = params.model_dump(exclude_none=True)
+        return await self._make_request(
+            endpoint="chat/completions",
+            params=request_params,
+            response_type=OpenAIChatCompletionChunk if params.stream else OpenAIChatCompletion,
+            stream=params.stream or False,
+        )
+
+    async def openai_embeddings(
+        self,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
+    ) -> OpenAIEmbeddingsResponse:
+        model_obj = await self.model_store.get_model(params.model)
+        # Create a copy with the provider's model ID
+        params = params.model_copy()
+        params.model = model_obj.provider_resource_id
+        request_params = params.model_dump(exclude_none=True)
+        return await self._make_request(
+            endpoint="embeddings",
+            params=request_params,
+            response_type=OpenAIEmbeddingsResponse,
+            stream=False,
+        )
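
For reference, _stream_request above consumes standard OpenAI-style Server-Sent Events: each chunk arrives on a "data:" line, and the stream terminates with a "data: [DONE]" sentinel. A minimal sketch of that parsing rule applied to sample wire lines (chunk payload abbreviated to an empty choices list):

import json

sample_lines = [
    'data: {"id": "chatcmpl-1", "choices": []}',  # one streamed chunk
    "",                                           # keep-alive blank line
    "data: [DONE]",                               # end-of-stream marker
]
for line in sample_lines:
    if line.startswith("data:"):
        data_str = line[len("data:") :].strip()
        if not data_str or data_str == "[DONE]":
            continue  # skip empty payloads and the terminator
        print(json.loads(data_str))  # {'id': 'chatcmpl-1', 'choices': []}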