commit b37eb24650 (parent 6115679679)
Author: Ashwin Bharambe
Date:   2025-11-05 15:05:52 -08:00

@@ -4,12 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import json
 from collections.abc import AsyncIterator
-from enum import Enum
 
-import httpx
-from pydantic import BaseModel
+from openai import AsyncOpenAI
 
 from llama_stack.apis.inference import (
     Inference,
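
The swapped imports summarize the change: json, Enum, httpx, and BaseModel existed only to serialize request bodies and parse responses by hand, and the OpenAI SDK now covers all of that. A minimal sketch of the new dependency in isolation (the URL, key, and server address are illustrative assumptions, not part of this commit):

import asyncio

from openai import AsyncOpenAI

async def main() -> None:
    # Point the SDK at any OpenAI-compatible server; it adds the
    # Authorization header and JSON handling that the httpx code did manually.
    client = AsyncOpenAI(base_url="http://localhost:8321/v1", api_key="example-key")
    models = await client.models.list()
    print([m.id for m in models.data])

asyncio.run(main())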
@@ -45,27 +42,15 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
     async def list_models(self) -> list[Model]:
         """List models by calling the downstream /v1/models endpoint."""
-        base_url = self._get_passthrough_url().rstrip("/")
-        api_key = self._get_passthrough_api_key()
-        url = f"{base_url}/v1/models"
-        headers = {
-            "Authorization": f"Bearer {api_key}",
-            "Accept": "application/json",
-        }
-        async with httpx.AsyncClient() as client:
-            response = await client.get(url, headers=headers, timeout=30.0)
-            response.raise_for_status()
-            response_data = response.json()
-        models_data = response_data["data"]
+        client = self._get_openai_client()
+        response = await client.models.list()
 
         # Convert from OpenAI format to Llama Stack Model format
         models = []
-        for model_data in models_data:
-            downstream_model_id = model_data["id"]
-            custom_metadata = model_data.get("custom_metadata", {})
+        for model_data in response.data:
+            downstream_model_id = model_data.id
+            custom_metadata = getattr(model_data, "custom_metadata", {}) or {}
 
             # Prefix identifier with provider ID for local registry
             local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
@@ -85,6 +70,16 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         """Passthrough should refresh models since they come from downstream dynamically."""
         return self.config.refresh_models
 
+    def _get_openai_client(self) -> AsyncOpenAI:
+        """Get an AsyncOpenAI client configured for the downstream server."""
+        base_url = self._get_passthrough_url()
+        api_key = self._get_passthrough_api_key()
+        return AsyncOpenAI(
+            base_url=f"{base_url.rstrip('/')}/v1",
+            api_key=api_key,
+        )
+
     def _get_passthrough_url(self) -> str:
         """Get the passthrough URL from config or provider data."""
         if self.config.url is not None:
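
The new helper centralizes what every endpoint method needs: one configured AsyncOpenAI instance per request, with the URL and key resolved from static config or per-request provider data. A hedged sketch of the equivalent construction outside the adapter (address and key are assumed values for illustration):

from openai import AsyncOpenAI

# Equivalent of _get_openai_client() with concrete values filled in;
# the adapter resolves these from config or request provider data.
client = AsyncOpenAI(
    base_url="http://downstream:8321/v1",  # passthrough URL + /v1
    api_key="example-key",
)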
@@ -109,130 +104,32 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
             )
         return provider_data.passthrough_api_key
 
-    def _serialize_value(self, value):
-        """Convert Pydantic models and enums to JSON-serializable values."""
-        if isinstance(value, BaseModel):
-            return json.loads(value.model_dump_json())
-        elif isinstance(value, Enum):
-            return value.value
-        elif isinstance(value, list):
-            return [self._serialize_value(item) for item in value]
-        elif isinstance(value, dict):
-            return {k: self._serialize_value(v) for k, v in value.items()}
-        else:
-            return value
-
-    async def _make_request(
-        self,
-        endpoint: str,
-        params: dict,
-        response_type: type,
-        stream: bool = False,
-    ):
-        """Make an HTTP request to the passthrough endpoint."""
-        base_url = self._get_passthrough_url().rstrip("/")
-        api_key = self._get_passthrough_api_key()
-        url = f"{base_url}/v1/{endpoint.lstrip('/')}"
-        headers = {
-            "Authorization": f"Bearer {api_key}",
-            "Content-Type": "application/json",
-            "Accept": "application/json",
-        }
-
-        # Serialize the request body
-        json_body = self._serialize_value(params)
-
-        if stream:
-            return self._stream_request(url, headers, json_body, response_type)
-        else:
-            return await self._non_stream_request(url, headers, json_body, response_type)
-
-    async def _non_stream_request(
-        self,
-        url: str,
-        headers: dict,
-        json_body: dict,
-        response_type: type,
-    ):
-        """Make a non-streaming HTTP request."""
-        async with httpx.AsyncClient() as client:
-            response = await client.post(url, json=json_body, headers=headers, timeout=30.0)
-            response.raise_for_status()
-            response_data = response.json()
-            return response_type.model_validate(response_data)
-
-    async def _stream_request(
-        self,
-        url: str,
-        headers: dict,
-        json_body: dict,
-        response_type: type,
-    ) -> AsyncIterator:
-        """Make a streaming HTTP request with Server-Sent Events parsing."""
-        async with httpx.AsyncClient() as client:
-            async with client.stream("POST", url, json=json_body, headers=headers, timeout=30.0) as response:
-                response.raise_for_status()
-                async for line in response.aiter_lines():
-                    if line.startswith("data:"):
-                        # Extract JSON after "data: " prefix
-                        data_str = line[len("data:") :].strip()
-                        # Skip empty lines or "[DONE]" marker
-                        if not data_str or data_str == "[DONE]":
-                            continue
-                        data = json.loads(data_str)
-                        # Fix OpenAI compatibility: finish_reason can be null in intermediate chunks
-                        # but our Pydantic model may not accept null
-                        if "choices" in data:
-                            for choice in data["choices"]:
-                                if choice.get("finish_reason") is None:
-                                    choice["finish_reason"] = ""
-                        yield response_type.model_validate(data)
-
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
-        # params.model is already the provider_resource_id (router translated it)
+        """Forward completion request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
-        return await self._make_request(
-            endpoint="completions",
-            params=request_params,
-            response_type=OpenAICompletion,
-            stream=params.stream or False,
-        )
+        response = await client.completions.create(**request_params)
+        return response  # type: ignore
 
     async def openai_chat_completion(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        # params.model is already the provider_resource_id (router translated it)
+        """Forward chat completion request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
-        return await self._make_request(
-            endpoint="chat/completions",
-            params=request_params,
-            response_type=OpenAIChatCompletionChunk if params.stream else OpenAIChatCompletion,
-            stream=params.stream or False,
-        )
+        response = await client.chat.completions.create(**request_params)
+        return response  # type: ignore
 
     async def openai_embeddings(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
-        # params.model is already the provider_resource_id (router translated it)
+        """Forward embeddings request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
-        return await self._make_request(
-            endpoint="embeddings",
-            params=request_params,
-            response_type=OpenAIEmbeddingsResponse,
-            stream=False,
-        )
+        response = await client.embeddings.create(**request_params)
+        return response  # type: ignore
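
Deleting _stream_request works because the SDK owns SSE handling: when stream=True is forwarded in request_params, chat.completions.create returns an async stream of chunk objects instead of a parsed completion, so the adapter can hand it back unchanged. A sketch of what callers see (model id and server address are assumptions):

import asyncio

from openai import AsyncOpenAI

async def stream_demo() -> None:
    client = AsyncOpenAI(base_url="http://downstream:8321/v1", api_key="example-key")
    stream = await client.chat.completions.create(
        model="llama3.2:3b",
        messages=[{"role": "user", "content": "Count to three."}],
        stream=True,
    )
    # The SDK parses each "data:" SSE line into a typed chunk, including the
    # [DONE] sentinel and null finish_reason cases the deleted code handled manually.
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)

asyncio.run(stream_demo())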