From bef1b044bde10fa5a1ef70eb0269c04afeaef817 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 5 Nov 2025 18:15:11 -0800
Subject: [PATCH] refactor(passthrough): use AsyncOpenAI instead of
 AsyncLlamaStackClient (#4085)

We'd like to remove the dependence of `llama-stack` on `llama-stack-client`.
This is a necessary step.

A few small cleanups:
- Enables `embeddings` as well
- Removes the unused ModelRegistryHelper dependency
- Consolidates on the auth_credential field via RemoteInferenceProviderConfig
- Implements list_models() to fetch models from the downstream /v1/models endpoint

## Test Plan

Tested using this script:
https://gist.github.com/ashwinb/6356463d10f989c0682ab3bff8589581

Output:
```
Listing models from downstream server...
Available models: ['passthrough/ollama/nomic-embed-text:latest', 'passthrough/ollama/all-minilm:l6-v2', 'passthrough/ollama/llama3.2-vision:11b', 'passthrough/ollama/llama3.2-vision:latest', 'passthrough/ollama/llama-guard3:1b', 'passthrough/ollama/llama3.2:1b', 'passthrough/ollama/all-minilm:latest', 'passthrough/ollama/llama3.2:3b', 'passthrough/ollama/llama3.2:3b-instruct-fp16', 'passthrough/bedrock/meta.llama3-1-8b-instruct-v1:0', 'passthrough/bedrock/meta.llama3-1-70b-instruct-v1:0', 'passthrough/bedrock/meta.llama3-1-405b-instruct-v1:0', 'passthrough/sentence-transformers/nomic-ai/nomic-embed-text-v1.5']

Using LLM model: passthrough/ollama/llama3.2-vision:11b
Making inference request...
Response: 4.

--- Testing streaming ---
Streamed response: ChatCompletionChunk(id='chatcmpl-64', choices=[Choice(delta=ChoiceDelta(content='1', reasoning_content=None, refusal=None, role='assistant', tool_calls=None), finish_reason='', index=0, logprobs=None)], created=1762381674, model='passthrough/ollama/llama3.2-vision:11b', object='chat.completion.chunk', usage=None)
...
5ChatCompletionChunk(id='chatcmpl-64', choices=[Choice(delta=ChoiceDelta(content='', reasoning_content=None, refusal=None, role='assistant', tool_calls=None), finish_reason='stop', index=0, logprobs=None)], created=1762381674, model='passthrough/ollama/llama3.2-vision:11b', object='chat.completion.chunk', usage=None)
```
---
 .../inference/remote_passthrough.mdx         |   2 +-
 .../remote/inference/passthrough/__init__.py |   4 +-
 .../remote/inference/passthrough/config.py   |   7 +-
 .../inference/passthrough/passthrough.py     | 155 ++++++++++--------
 4 files changed, 88 insertions(+), 80 deletions(-)

diff --git a/docs/docs/providers/inference/remote_passthrough.mdx b/docs/docs/providers/inference/remote_passthrough.mdx
index 7a2931690..957cd04da 100644
--- a/docs/docs/providers/inference/remote_passthrough.mdx
+++ b/docs/docs/providers/inference/remote_passthrough.mdx
@@ -16,7 +16,7 @@ Passthrough inference provider for connecting to any external inference service
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrouth endpoint |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `` | No | | The URL for the passthrough endpoint |
 
 ## Sample Configuration
diff --git a/src/llama_stack/providers/remote/inference/passthrough/__init__.py b/src/llama_stack/providers/remote/inference/passthrough/__init__.py
index 69dd4c461..1cc46bff1 100644
--- a/src/llama_stack/providers/remote/inference/passthrough/__init__.py
+++ b/src/llama_stack/providers/remote/inference/passthrough/__init__.py
@@ -10,8 +10,8 @@ from .config import PassthroughImplConfig
 
 
 class PassthroughProviderDataValidator(BaseModel):
-    url: str
-    api_key: str
+    passthrough_url: str
+    passthrough_api_key: str
 
 
 async def get_adapter_impl(config: PassthroughImplConfig, _deps):
diff --git a/src/llama_stack/providers/remote/inference/passthrough/config.py b/src/llama_stack/providers/remote/inference/passthrough/config.py
index f8e8b8ce5..eca28a86a 100644
--- a/src/llama_stack/providers/remote/inference/passthrough/config.py
+++ b/src/llama_stack/providers/remote/inference/passthrough/config.py
@@ -6,7 +6,7 @@
 
 from typing import Any
 
-from pydantic import Field, SecretStr
+from pydantic import Field
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -19,11 +19,6 @@ class PassthroughImplConfig(RemoteInferenceProviderConfig):
         description="The URL for the passthrough endpoint",
     )
 
-    api_key: SecretStr | None = Field(
-        default=None,
-        description="API Key for the passthrouth endpoint",
-    )
-
     @classmethod
     def sample_run_config(
         cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
diff --git a/src/llama_stack/providers/remote/inference/passthrough/passthrough.py b/src/llama_stack/providers/remote/inference/passthrough/passthrough.py
index 4d4d4f41d..3c56acfbd 100644
--- a/src/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/src/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -5,9 +5,8 @@
 # the root directory of this source tree.
 
 from collections.abc import AsyncIterator
-from typing import Any
 
-from llama_stack_client import AsyncLlamaStackClient
+from openai import AsyncOpenAI
 
 from llama_stack.apis.inference import (
     Inference,
@@ -20,103 +19,117 @@ from llama_stack.apis.inference import (
     OpenAIEmbeddingsResponse,
 )
 from llama_stack.apis.models import Model
-from llama_stack.core.library_client import convert_pydantic_to_json_value
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.core.request_headers import NeedsRequestProviderData
 
 from .config import PassthroughImplConfig
 
 
-class PassthroughInferenceAdapter(Inference):
+class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
     def __init__(self, config: PassthroughImplConfig) -> None:
-        ModelRegistryHelper.__init__(self)
         self.config = config
 
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
     async def unregister_model(self, model_id: str) -> None:
         pass
 
     async def register_model(self, model: Model) -> Model:
         return model
 
-    def _get_client(self) -> AsyncLlamaStackClient:
-        passthrough_url = None
-        passthrough_api_key = None
-        provider_data = None
+    async def list_models(self) -> list[Model]:
+        """List models by calling the downstream /v1/models endpoint."""
+        client = self._get_openai_client()
 
-        if self.config.url is not None:
-            passthrough_url = self.config.url
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.passthrough_url:
-                raise ValueError(
-                    'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
-                )
-            passthrough_url = provider_data.passthrough_url
+        response = await client.models.list()
 
-        if self.config.api_key is not None:
-            passthrough_api_key = self.config.api_key.get_secret_value()
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.passthrough_api_key:
-                raise ValueError(
-                    'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
-                )
-            passthrough_api_key = provider_data.passthrough_api_key
+        # Convert from OpenAI format to Llama Stack Model format
+        models = []
+        for model_data in response.data:
+            downstream_model_id = model_data.id
+            custom_metadata = getattr(model_data, "custom_metadata", {}) or {}
 
-        return AsyncLlamaStackClient(
-            base_url=passthrough_url,
-            api_key=passthrough_api_key,
-            provider_data=provider_data,
+            # Prefix identifier with provider ID for local registry
+            local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
+
+            model = Model(
+                identifier=local_identifier,
+                provider_id=self.__provider_id__,
+                provider_resource_id=downstream_model_id,
+                model_type=custom_metadata.get("model_type", "llm"),
+                metadata=custom_metadata,
+            )
+            models.append(model)
+
+        return models
+
+    async def should_refresh_models(self) -> bool:
+        """Passthrough should refresh models since they come from downstream dynamically."""
+        return self.config.refresh_models
+
+    def _get_openai_client(self) -> AsyncOpenAI:
+        """Get an AsyncOpenAI client configured for the downstream server."""
+        base_url = self._get_passthrough_url()
+        api_key = self._get_passthrough_api_key()
+
+        return AsyncOpenAI(
+            base_url=f"{base_url.rstrip('/')}/v1",
+            api_key=api_key,
         )
 
-    async def openai_embeddings(
-        self,
-        params: OpenAIEmbeddingsRequestWithExtraBody,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+    def _get_passthrough_url(self) -> str:
+        """Get the passthrough URL from config or provider data."""
+        if self.config.url is not None:
+            return self.config.url
+
+        provider_data = self.get_request_provider_data()
+        if provider_data is None:
+            raise ValueError(
+                'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
+            )
+        return provider_data.passthrough_url
+
+    def _get_passthrough_api_key(self) -> str:
+        """Get the passthrough API key from config or provider data."""
+        if self.config.auth_credential is not None:
+            return self.config.auth_credential.get_secret_value()
+
+        provider_data = self.get_request_provider_data()
+        if provider_data is None:
+            raise ValueError(
+                'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
+            )
+        return provider_data.passthrough_api_key
 
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
-        client = self._get_client()
-        model_obj = await self.model_store.get_model(params.model)
-
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
-
+        """Forward completion request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
-
-        return await client.inference.openai_completion(**request_params)
+        response = await client.completions.create(**request_params)
+        return response  # type: ignore
 
     async def openai_chat_completion(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        client = self._get_client()
-        model_obj = await self.model_store.get_model(params.model)
-
-        params = params.model_copy()
-        params.model = model_obj.provider_resource_id
-
+        """Forward chat completion request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
+        response = await client.chat.completions.create(**request_params)
+        return response  # type: ignore
 
-        return await client.inference.openai_chat_completion(**request_params)
-
-    def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
-        json_params = {}
-        for key, value in request_params.items():
-            json_input = convert_pydantic_to_json_value(value)
-            if isinstance(json_input, dict):
-                json_input = {k: v for k, v in json_input.items() if v is not None}
-            elif isinstance(json_input, list):
-                json_input = [x for x in json_input if x is not None]
-                new_input = []
-                for x in json_input:
-                    if isinstance(x, dict):
-                        x = {k: v for k, v in x.items() if v is not None}
-                    new_input.append(x)
-                json_input = new_input
-
-            json_params[key] = json_input
-
-        return json_params
+    async def openai_embeddings(
+        self,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
+    ) -> OpenAIEmbeddingsResponse:
+        """Forward embeddings request to downstream using OpenAI client."""
+        client = self._get_openai_client()
+        request_params = params.model_dump(exclude_none=True)
+        response = await client.embeddings.create(**request_params)
+        return response  # type: ignore
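
For reference, below is a minimal sketch of the kind of end-to-end check the test-plan gist performs (it is not the gist itself). It assumes a running Llama Stack server with this passthrough provider, an OpenAI-compatible endpoint at http://localhost:8321/v1, and one of the model IDs listed in the output above; the base URL, API key, and model name are assumptions to adjust for your deployment.

```python
# Hypothetical smoke test for the passthrough provider (not the gist from the test plan).
# Assumes a Llama Stack server exposing an OpenAI-compatible API at http://localhost:8321/v1.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8321/v1", api_key="none")

    # list_models() on the provider surfaces downstream models prefixed with the provider ID.
    models = await client.models.list()
    print("Available models:", [m.id for m in models.data])

    model = "passthrough/ollama/llama3.2-vision:11b"  # pick any model printed above

    # Non-streaming chat completion, forwarded to the downstream server.
    resp = await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "What is 2 + 2? Answer with a single number."}],
    )
    print("Response:", resp.choices[0].message.content)

    # Streaming: chunks arrive as ChatCompletionChunk objects, as in the output above.
    stream = await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "Count from 1 to 5."}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")
    print()


asyncio.run(main())
```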
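
The per-request credential path (`_get_passthrough_url` / `_get_passthrough_api_key` falling back to `get_request_provider_data()`) is exercised by sending the `X-LlamaStack-Provider-Data` header instead of setting `url` / `auth_credential` in the run config. A minimal sketch, assuming the same local Llama Stack endpoint as above; the downstream URL, API key, and model ID values are placeholders:

```python
# Hypothetical example of supplying passthrough credentials per request via provider data.
# The header name and JSON keys come from the provider; the URL/key/model values are placeholders.
import json

from openai import OpenAI

provider_data = {
    "passthrough_url": "http://downstream-stack:8321",  # placeholder downstream server
    "passthrough_api_key": "example-key",  # placeholder credential
}

client = OpenAI(
    base_url="http://localhost:8321/v1",  # assumed local Llama Stack endpoint
    api_key="none",
    default_headers={"X-LlamaStack-Provider-Data": json.dumps(provider_data)},
)

# The request is routed to the passthrough provider, which reads the header for URL/key.
resp = client.chat.completions.create(
    model="passthrough/ollama/llama3.2:3b",  # placeholder passthrough model ID
    messages=[{"role": "user", "content": "Hello!"}],
)
print(resp.choices[0].message.content)
```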