refactor(passthrough): use AsyncOpenAI instead of AsyncLlamaStackClient (#4085)

We'd like to remove the dependence of `llama-stack` on
`llama-stack-client`. This is a necessary step toward that goal.

A few small cleanups (a minimal sketch of the new client setup follows this list):
- Enables `embeddings` as well
- Removes the unused `ModelRegistryHelper` dependency
- Consolidates on the `auth_credential` field from `RemoteInferenceProviderConfig`
- Implements `list_models()` to fetch models from the downstream `/v1/models` endpoint
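
To make the change concrete, here is a minimal, hypothetical sketch of the idea: the downstream Llama Stack server exposes OpenAI-compatible routes under `/v1`, so a plain `AsyncOpenAI` client pointed at that prefix covers models, chat, completions, and embeddings. The URL and key below are made-up placeholders, not values from this PR.

```python
# Hedged sketch only: the real adapter builds this client from its config /
# provider data (see the passthrough.py diff below); values here are made up.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(
        base_url="http://localhost:8321/v1",  # downstream Llama Stack server + /v1
        api_key="dummy-or-real-key",
    )
    models = await client.models.list()          # backs the new list_models()
    print([m.id for m in models.data])

    emb = await client.embeddings.create(        # embeddings now pass through too
        model=models.data[0].id, input=["hello world"]
    )
    print(len(emb.data[0].embedding))


asyncio.run(main())
```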

## Test Plan

Tested using this script
https://gist.github.com/ashwinb/6356463d10f989c0682ab3bff8589581
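
For readers who don't want to open the gist, the script is roughly along these lines. This is a hedged reconstruction from the output below, not the gist itself; the base URL, prompts, and model-selection logic are assumptions.

```python
# Rough, hypothetical reconstruction of the test script; endpoint and prompts are assumed.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8321/v1", api_key="dummy")

    print("Listing models from downstream server...")
    models = await client.models.list()
    ids = [m.id for m in models.data]
    print(f"Available models: {ids}")

    llm = next(m for m in ids if "llama3.2-vision:11b" in m)
    print(f"\nUsing LLM model: {llm}")

    print("\nMaking inference request...")
    resp = await client.chat.completions.create(
        model=llm,
        messages=[{"role": "user", "content": "What is 2 + 2? Answer with just the number."}],
    )
    print(f"\nResponse: {resp.choices[0].message.content}")

    print("\n--- Testing streaming ---")
    stream = await client.chat.completions.create(
        model=llm,
        messages=[{"role": "user", "content": "Count from 1 to 5."}],
        stream=True,
    )
    async for chunk in stream:
        print(f"Streamed response: {chunk}")


asyncio.run(main())
```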

Output:
```
Listing models from downstream server...
Available models: ['passthrough/ollama/nomic-embed-text:latest', 'passthrough/ollama/all-minilm:l6-v2', 'passthrough/ollama/llama3.2-vision:11b', 'passthrough/ollama/llama3.2-vision:latest', 'passthrough/ollama/llama-guard3:1b', 'passthrough/ollama/llama3.2:1b', 'passthrough/ollama/all-minilm:latest', 'passthrough/ollama/llama3.2:3b', 'passthrough/ollama/llama3.2:3b-instruct-fp16', 'passthrough/bedrock/meta.llama3-1-8b-instruct-v1:0', 'passthrough/bedrock/meta.llama3-1-70b-instruct-v1:0', 'passthrough/bedrock/meta.llama3-1-405b-instruct-v1:0', 'passthrough/sentence-transformers/nomic-ai/nomic-embed-text-v1.5']

Using LLM model: passthrough/ollama/llama3.2-vision:11b

Making inference request...

Response: 4.

--- Testing streaming ---
Streamed response: ChatCompletionChunk(id='chatcmpl-64', choices=[Choice(delta=ChoiceDelta(content='1', reasoning_content=None, refusal=None, role='assistant', tool_calls=None), finish_reason='', index=0, logprobs=None)], created=1762381674, model='passthrough/ollama/llama3.2-vision:11b', object='chat.completion.chunk', usage=None)
...
5ChatCompletionChunk(id='chatcmpl-64', choices=[Choice(delta=ChoiceDelta(content='', reasoning_content=None, refusal=None, role='assistant', tool_calls=None), finish_reason='stop', index=0, logprobs=None)], created=1762381674, model='passthrough/ollama/llama3.2-vision:11b', object='chat.completion.chunk', usage=None)
```
Commit bef1b044bd (parent b335419faa), authored by Ashwin Bharambe on 2025-11-05 18:15:11 -08:00, committed by GitHub.
4 changed files with 88 additions and 80 deletions


```diff
@@ -16,7 +16,7 @@ Passthrough inference provider for connecting to any external inference service
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrouth endpoint |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | | The URL for the passthrough endpoint |

 ## Sample Configuration
```


```diff
@@ -10,8 +10,8 @@ from .config import PassthroughImplConfig


 class PassthroughProviderDataValidator(BaseModel):
-    url: str
-    api_key: str
+    passthrough_url: str
+    passthrough_api_key: str


 async def get_adapter_impl(config: PassthroughImplConfig, _deps):
```


```diff
@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -19,11 +19,6 @@ class PassthroughImplConfig(RemoteInferenceProviderConfig):
         description="The URL for the passthrough endpoint",
     )

-    api_key: SecretStr | None = Field(
-        default=None,
-        description="API Key for the passthrouth endpoint",
-    )
-
     @classmethod
     def sample_run_config(
         cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
```


```diff
@@ -5,9 +5,8 @@
 # the root directory of this source tree.

 from collections.abc import AsyncIterator
-from typing import Any

-from llama_stack_client import AsyncLlamaStackClient
+from openai import AsyncOpenAI

 from llama_stack.apis.inference import (
     Inference,
@@ -20,103 +19,117 @@ from llama_stack.apis.inference import (
     OpenAIEmbeddingsResponse,
 )
 from llama_stack.apis.models import Model
-from llama_stack.core.library_client import convert_pydantic_to_json_value
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.core.request_headers import NeedsRequestProviderData

 from .config import PassthroughImplConfig


-class PassthroughInferenceAdapter(Inference):
+class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
     def __init__(self, config: PassthroughImplConfig) -> None:
-        ModelRegistryHelper.__init__(self)
         self.config = config

-    async def initialize(self) -> None:
-        pass
-
-    async def shutdown(self) -> None:
-        pass
-
     async def unregister_model(self, model_id: str) -> None:
         pass

     async def register_model(self, model: Model) -> Model:
         return model

-    def _get_client(self) -> AsyncLlamaStackClient:
-        passthrough_url = None
-        passthrough_api_key = None
-        provider_data = None
-
-        if self.config.url is not None:
-            passthrough_url = self.config.url
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.passthrough_url:
-                raise ValueError(
-                    'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
-                )
-            passthrough_url = provider_data.passthrough_url
-
-        if self.config.api_key is not None:
-            passthrough_api_key = self.config.api_key.get_secret_value()
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.passthrough_api_key:
-                raise ValueError(
-                    'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
-                )
-            passthrough_api_key = provider_data.passthrough_api_key
-
-        return AsyncLlamaStackClient(
-            base_url=passthrough_url,
-            api_key=passthrough_api_key,
-            provider_data=provider_data,
-        )
-
-    async def openai_embeddings(
-        self,
-        params: OpenAIEmbeddingsRequestWithExtraBody,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+    async def list_models(self) -> list[Model]:
+        """List models by calling the downstream /v1/models endpoint."""
+        client = self._get_openai_client()
+
+        response = await client.models.list()
+
+        # Convert from OpenAI format to Llama Stack Model format
+        models = []
+        for model_data in response.data:
+            downstream_model_id = model_data.id
+            custom_metadata = getattr(model_data, "custom_metadata", {}) or {}
+
+            # Prefix identifier with provider ID for local registry
+            local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
+
+            model = Model(
+                identifier=local_identifier,
+                provider_id=self.__provider_id__,
+                provider_resource_id=downstream_model_id,
+                model_type=custom_metadata.get("model_type", "llm"),
+                metadata=custom_metadata,
+            )
+            models.append(model)
+
+        return models
+
+    async def should_refresh_models(self) -> bool:
+        """Passthrough should refresh models since they come from downstream dynamically."""
+        return self.config.refresh_models
+
+    def _get_openai_client(self) -> AsyncOpenAI:
+        """Get an AsyncOpenAI client configured for the downstream server."""
+        base_url = self._get_passthrough_url()
+        api_key = self._get_passthrough_api_key()
+
+        return AsyncOpenAI(
+            base_url=f"{base_url.rstrip('/')}/v1",
+            api_key=api_key,
+        )
+
+    def _get_passthrough_url(self) -> str:
+        """Get the passthrough URL from config or provider data."""
+        if self.config.url is not None:
+            return self.config.url
+
+        provider_data = self.get_request_provider_data()
+        if provider_data is None:
+            raise ValueError(
+                'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
+            )
+        return provider_data.passthrough_url
+
+    def _get_passthrough_api_key(self) -> str:
+        """Get the passthrough API key from config or provider data."""
+        if self.config.auth_credential is not None:
+            return self.config.auth_credential.get_secret_value()
+
+        provider_data = self.get_request_provider_data()
+        if provider_data is None:
+            raise ValueError(
+                'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
+            )
+        return provider_data.passthrough_api_key

     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
-        client = self._get_client()
-        model_obj = await self.model_store.get_model(params.model)
-
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
+        """Forward completion request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
-
-        return await client.inference.openai_completion(**request_params)
+        response = await client.completions.create(**request_params)
+        return response  # type: ignore

     async def openai_chat_completion(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        client = self._get_client()
-        model_obj = await self.model_store.get_model(params.model)
-
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
+        """Forward chat completion request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
+        response = await client.chat.completions.create(**request_params)
+        return response  # type: ignore

-        return await client.inference.openai_chat_completion(**request_params)
-
-    def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
-        json_params = {}
-        for key, value in request_params.items():
-            json_input = convert_pydantic_to_json_value(value)
-            if isinstance(json_input, dict):
-                json_input = {k: v for k, v in json_input.items() if v is not None}
-            elif isinstance(json_input, list):
-                json_input = [x for x in json_input if x is not None]
-                new_input = []
-                for x in json_input:
-                    if isinstance(x, dict):
-                        x = {k: v for k, v in x.items() if v is not None}
-                    new_input.append(x)
-                json_input = new_input
-
-            json_params[key] = json_input
-
-        return json_params
+    async def openai_embeddings(
+        self,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
+    ) -> OpenAIEmbeddingsResponse:
+        """Forward embeddings request to downstream using OpenAI client."""
+        client = self._get_openai_client()
+        request_params = params.model_dump(exclude_none=True)
+        response = await client.embeddings.create(**request_params)
+        return response  # type: ignore
```
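
Not exercised by the test plan above: when `url` and `auth_credential` are absent from the run config, the adapter falls back to per-request provider data, as the error messages in the diff describe. Below is a hedged sketch of how a caller might supply that header; the host/port, key, and model id are made-up placeholders.

```python
# Hypothetical sketch: per-request downstream credentials via the
# X-LlamaStack-Provider-Data header. Host/port, key, and model id are made up.
import asyncio
import json

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(
        base_url="http://localhost:8321/v1",  # the Llama Stack server running the passthrough provider
        api_key="dummy",
        default_headers={
            "X-LlamaStack-Provider-Data": json.dumps(
                {
                    "passthrough_url": "http://downstream:8322",
                    "passthrough_api_key": "sk-downstream-key",
                }
            )
        },
    )
    resp = await client.chat.completions.create(
        model="passthrough/ollama/llama3.2:3b",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(resp.choices[0].message.content)


asyncio.run(main())
```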