From b37eb2465004914a819c606a6e2ad2983bccfe03 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 5 Nov 2025 15:05:52 -0800
Subject: [PATCH] simpler

---
 .../inference/passthrough/passthrough.py | 185 ++++----------
 1 file changed, 41 insertions(+), 144 deletions(-)

diff --git a/src/llama_stack/providers/remote/inference/passthrough/passthrough.py b/src/llama_stack/providers/remote/inference/passthrough/passthrough.py
index d991a663b..3c56acfbd 100644
--- a/src/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/src/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -4,12 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import json
 from collections.abc import AsyncIterator
-from enum import Enum
 
-import httpx
-from pydantic import BaseModel
+from openai import AsyncOpenAI
 
 from llama_stack.apis.inference import (
     Inference,
@@ -45,46 +42,44 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
 
     async def list_models(self) -> list[Model]:
         """List models by calling the downstream /v1/models endpoint."""
-        base_url = self._get_passthrough_url().rstrip("/")
-        api_key = self._get_passthrough_api_key()
+        client = self._get_openai_client()
 
-        url = f"{base_url}/v1/models"
-        headers = {
-            "Authorization": f"Bearer {api_key}",
-            "Accept": "application/json",
-        }
+        response = await client.models.list()
 
-        async with httpx.AsyncClient() as client:
-            response = await client.get(url, headers=headers, timeout=30.0)
-            response.raise_for_status()
+        # Convert from OpenAI format to Llama Stack Model format
+        models = []
+        for model_data in response.data:
+            downstream_model_id = model_data.id
+            custom_metadata = getattr(model_data, "custom_metadata", {}) or {}
 
-            response_data = response.json()
-            models_data = response_data["data"]
+            # Prefix identifier with provider ID for local registry
+            local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
 
-            # Convert from OpenAI format to Llama Stack Model format
-            models = []
-            for model_data in models_data:
-                downstream_model_id = model_data["id"]
-                custom_metadata = model_data.get("custom_metadata", {})
+            model = Model(
+                identifier=local_identifier,
+                provider_id=self.__provider_id__,
+                provider_resource_id=downstream_model_id,
+                model_type=custom_metadata.get("model_type", "llm"),
+                metadata=custom_metadata,
+            )
+            models.append(model)
 
-                # Prefix identifier with provider ID for local registry
-                local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
-
-                model = Model(
-                    identifier=local_identifier,
-                    provider_id=self.__provider_id__,
-                    provider_resource_id=downstream_model_id,
-                    model_type=custom_metadata.get("model_type", "llm"),
-                    metadata=custom_metadata,
-                )
-                models.append(model)
-
-        return models
+        return models
 
     async def should_refresh_models(self) -> bool:
         """Passthrough should refresh models since they come from downstream dynamically."""
         return self.config.refresh_models
 
+    def _get_openai_client(self) -> AsyncOpenAI:
+        """Get an AsyncOpenAI client configured for the downstream server."""
+        base_url = self._get_passthrough_url()
+        api_key = self._get_passthrough_api_key()
+
+        return AsyncOpenAI(
+            base_url=f"{base_url.rstrip('/')}/v1",
+            api_key=api_key,
+        )
+
     def _get_passthrough_url(self) -> str:
         """Get the passthrough URL from config or provider data."""
         if self.config.url is not None:
@@ -109,130 +104,32 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
             )
         return provider_data.passthrough_api_key
 
-    def _serialize_value(self, value):
-        """Convert Pydantic models and enums to JSON-serializable values."""
-        if isinstance(value, BaseModel):
-            return json.loads(value.model_dump_json())
-        elif isinstance(value, Enum):
-            return value.value
-        elif isinstance(value, list):
-            return [self._serialize_value(item) for item in value]
-        elif isinstance(value, dict):
-            return {k: self._serialize_value(v) for k, v in value.items()}
-        else:
-            return value
-
-    async def _make_request(
-        self,
-        endpoint: str,
-        params: dict,
-        response_type: type,
-        stream: bool = False,
-    ):
-        """Make an HTTP request to the passthrough endpoint."""
-        base_url = self._get_passthrough_url().rstrip("/")
-        api_key = self._get_passthrough_api_key()
-
-        url = f"{base_url}/v1/{endpoint.lstrip('/')}"
-        headers = {
-            "Authorization": f"Bearer {api_key}",
-            "Content-Type": "application/json",
-            "Accept": "application/json",
-        }
-
-        # Serialize the request body
-        json_body = self._serialize_value(params)
-
-        if stream:
-            return self._stream_request(url, headers, json_body, response_type)
-        else:
-            return await self._non_stream_request(url, headers, json_body, response_type)
-
-    async def _non_stream_request(
-        self,
-        url: str,
-        headers: dict,
-        json_body: dict,
-        response_type: type,
-    ):
-        """Make a non-streaming HTTP request."""
-        async with httpx.AsyncClient() as client:
-            response = await client.post(url, json=json_body, headers=headers, timeout=30.0)
-            response.raise_for_status()
-
-            response_data = response.json()
-            return response_type.model_validate(response_data)
-
-    async def _stream_request(
-        self,
-        url: str,
-        headers: dict,
-        json_body: dict,
-        response_type: type,
-    ) -> AsyncIterator:
-        """Make a streaming HTTP request with Server-Sent Events parsing."""
-        async with httpx.AsyncClient() as client:
-            async with client.stream("POST", url, json=json_body, headers=headers, timeout=30.0) as response:
-                response.raise_for_status()
-
-                async for line in response.aiter_lines():
-                    if line.startswith("data:"):
-                        # Extract JSON after "data: " prefix
-                        data_str = line[len("data:") :].strip()
-
-                        # Skip empty lines or "[DONE]" marker
-                        if not data_str or data_str == "[DONE]":
-                            continue
-
-                        data = json.loads(data_str)
-
-                        # Fix OpenAI compatibility: finish_reason can be null in intermediate chunks
-                        # but our Pydantic model may not accept null
-                        if "choices" in data:
-                            for choice in data["choices"]:
-                                if choice.get("finish_reason") is None:
-                                    choice["finish_reason"] = ""
-
-                        yield response_type.model_validate(data)
-
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
-        # params.model is already the provider_resource_id (router translated it)
+        """Forward completion request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
-
-        return await self._make_request(
-            endpoint="completions",
-            params=request_params,
-            response_type=OpenAICompletion,
-            stream=params.stream or False,
-        )
+        response = await client.completions.create(**request_params)
+        return response  # type: ignore
 
     async def openai_chat_completion(
        self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        # params.model is already the provider_resource_id (router translated it)
+        """Forward chat completion request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
-
-        return await self._make_request(
-            endpoint="chat/completions",
-            params=request_params,
-            response_type=OpenAIChatCompletionChunk if params.stream else OpenAIChatCompletion,
-            stream=params.stream or False,
-        )
+        response = await client.chat.completions.create(**request_params)
+        return response  # type: ignore
 
     async def openai_embeddings(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
-        # params.model is already the provider_resource_id (router translated it)
+        """Forward embeddings request to downstream using OpenAI client."""
+        client = self._get_openai_client()
        request_params = params.model_dump(exclude_none=True)
-
-        return await self._make_request(
-            endpoint="embeddings",
-            params=request_params,
-            response_type=OpenAIEmbeddingsResponse,
-            stream=False,
-        )
+        response = await client.embeddings.create(**request_params)
+        return response  # type: ignore