commit b37eb24650
parent 6115679679
Author: Ashwin Bharambe
Date:   2025-11-05 15:05:52 -08:00


@@ -4,12 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import json
 from collections.abc import AsyncIterator
-from enum import Enum
-import httpx
-from pydantic import BaseModel
+from openai import AsyncOpenAI
 
 from llama_stack.apis.inference import (
     Inference,
@@ -45,27 +42,15 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
     async def list_models(self) -> list[Model]:
         """List models by calling the downstream /v1/models endpoint."""
-        base_url = self._get_passthrough_url().rstrip("/")
-        api_key = self._get_passthrough_api_key()
-        url = f"{base_url}/v1/models"
-        headers = {
-            "Authorization": f"Bearer {api_key}",
-            "Accept": "application/json",
-        }
-        async with httpx.AsyncClient() as client:
-            response = await client.get(url, headers=headers, timeout=30.0)
-            response.raise_for_status()
-            response_data = response.json()
-        models_data = response_data["data"]
+        client = self._get_openai_client()
+        response = await client.models.list()
 
         # Convert from OpenAI format to Llama Stack Model format
         models = []
-        for model_data in models_data:
-            downstream_model_id = model_data["id"]
-            custom_metadata = model_data.get("custom_metadata", {})
+        for model_data in response.data:
+            downstream_model_id = model_data.id
+            custom_metadata = getattr(model_data, "custom_metadata", {}) or {}
 
             # Prefix identifier with provider ID for local registry
             local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
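
Review note: the rewritten list_models leans on the OpenAI SDK's typed objects (response.data, model_data.id) rather than hand-parsed JSON. A self-contained sketch of the same flow, using a hypothetical downstream URL, API key, and provider id that are not part of this patch:

    import asyncio

    from openai import AsyncOpenAI


    async def list_downstream_models() -> None:
        # Hypothetical stand-ins for _get_passthrough_url() / _get_passthrough_api_key()
        client = AsyncOpenAI(base_url="http://localhost:8321/v1", api_key="sk-example")
        response = await client.models.list()
        for model_data in response.data:
            # The SDK's Model objects tolerate extra fields, so a downstream server
            # may attach custom_metadata; getattr covers servers that don't.
            custom_metadata = getattr(model_data, "custom_metadata", {}) or {}
            print(f"passthrough/{model_data.id}", custom_metadata)  # hypothetical provider id


    asyncio.run(list_downstream_models())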
@@ -85,6 +70,16 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         """Passthrough should refresh models since they come from downstream dynamically."""
         return self.config.refresh_models
 
+    def _get_openai_client(self) -> AsyncOpenAI:
+        """Get an AsyncOpenAI client configured for the downstream server."""
+        base_url = self._get_passthrough_url()
+        api_key = self._get_passthrough_api_key()
+        return AsyncOpenAI(
+            base_url=f"{base_url.rstrip('/')}/v1",
+            api_key=api_key,
+        )
+
     def _get_passthrough_url(self) -> str:
         """Get the passthrough URL from config or provider data."""
         if self.config.url is not None:
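
Review note: _get_openai_client() centralizes what the deleted httpx helpers rebuilt on every call: URL normalization and bearer auth. A minimal sketch of the equivalent wiring, assuming a hypothetical config url and key:

    from openai import AsyncOpenAI

    # rstrip('/') guards against a configured trailing slash, so the base URL
    # always ends in exactly one /v1; api_key becomes the Bearer token.
    url = "http://localhost:8321/"  # stand-in for config.url
    client = AsyncOpenAI(base_url=f"{url.rstrip('/')}/v1", api_key="sk-example")

One trade-off worth noting: constructing a fresh AsyncOpenAI per request forgoes connection reuse across calls; caching the client on the adapter would be a natural follow-up.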
@@ -109,130 +104,32 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
             )
         return provider_data.passthrough_api_key
 
-    def _serialize_value(self, value):
-        """Convert Pydantic models and enums to JSON-serializable values."""
-        if isinstance(value, BaseModel):
-            return json.loads(value.model_dump_json())
-        elif isinstance(value, Enum):
-            return value.value
-        elif isinstance(value, list):
-            return [self._serialize_value(item) for item in value]
-        elif isinstance(value, dict):
-            return {k: self._serialize_value(v) for k, v in value.items()}
-        else:
-            return value
-
-    async def _make_request(
-        self,
-        endpoint: str,
-        params: dict,
-        response_type: type,
-        stream: bool = False,
-    ):
-        """Make an HTTP request to the passthrough endpoint."""
-        base_url = self._get_passthrough_url().rstrip("/")
-        api_key = self._get_passthrough_api_key()
-        url = f"{base_url}/v1/{endpoint.lstrip('/')}"
-        headers = {
-            "Authorization": f"Bearer {api_key}",
-            "Content-Type": "application/json",
-            "Accept": "application/json",
-        }
-
-        # Serialize the request body
-        json_body = self._serialize_value(params)
-
-        if stream:
-            return self._stream_request(url, headers, json_body, response_type)
-        else:
-            return await self._non_stream_request(url, headers, json_body, response_type)
-
-    async def _non_stream_request(
-        self,
-        url: str,
-        headers: dict,
-        json_body: dict,
-        response_type: type,
-    ):
-        """Make a non-streaming HTTP request."""
-        async with httpx.AsyncClient() as client:
-            response = await client.post(url, json=json_body, headers=headers, timeout=30.0)
-            response.raise_for_status()
-            response_data = response.json()
-            return response_type.model_validate(response_data)
-
-    async def _stream_request(
-        self,
-        url: str,
-        headers: dict,
-        json_body: dict,
-        response_type: type,
-    ) -> AsyncIterator:
-        """Make a streaming HTTP request with Server-Sent Events parsing."""
-        async with httpx.AsyncClient() as client:
-            async with client.stream("POST", url, json=json_body, headers=headers, timeout=30.0) as response:
-                response.raise_for_status()
-                async for line in response.aiter_lines():
-                    if line.startswith("data:"):
-                        # Extract JSON after "data: " prefix
-                        data_str = line[len("data:") :].strip()
-                        # Skip empty lines or "[DONE]" marker
-                        if not data_str or data_str == "[DONE]":
-                            continue
-                        data = json.loads(data_str)
-                        # Fix OpenAI compatibility: finish_reason can be null in intermediate chunks
-                        # but our Pydantic model may not accept null
-                        if "choices" in data:
-                            for choice in data["choices"]:
-                                if choice.get("finish_reason") is None:
-                                    choice["finish_reason"] = ""
-                        yield response_type.model_validate(data)
-
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
         # params.model is already the provider_resource_id (router translated it)
+        """Forward completion request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
-        return await self._make_request(
-            endpoint="completions",
-            params=request_params,
-            response_type=OpenAICompletion,
-            stream=params.stream or False,
-        )
+        response = await client.completions.create(**request_params)
+        return response  # type: ignore
 
     async def openai_chat_completion(
        self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         # params.model is already the provider_resource_id (router translated it)
+        """Forward chat completion request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
-        return await self._make_request(
-            endpoint="chat/completions",
-            params=request_params,
-            response_type=OpenAIChatCompletionChunk if params.stream else OpenAIChatCompletion,
-            stream=params.stream or False,
-        )
+        response = await client.chat.completions.create(**request_params)
+        return response  # type: ignore
 
     async def openai_embeddings(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
         # params.model is already the provider_resource_id (router translated it)
+        """Forward embeddings request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
-        return await self._make_request(
-            endpoint="embeddings",
-            params=request_params,
-            response_type=OpenAIEmbeddingsResponse,
-            stream=False,
-        )
+        response = await client.embeddings.create(**request_params)
+        return response  # type: ignore
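
Review note: deleting _stream_request is safe because the SDK subsumes the SSE handling (the "data:" framing, the [DONE] sentinel, and null finish_reason on intermediate chunks, which its chunk type accepts as None): with stream=True, create() resolves to an async iterator of typed chunks, which openai_chat_completion returns through unchanged. A hedged sketch against a hypothetical downstream server and model id:

    import asyncio

    from openai import AsyncOpenAI


    async def stream_chat() -> None:
        client = AsyncOpenAI(base_url="http://localhost:8321/v1", api_key="sk-example")
        # stream=True: the SDK parses the event stream and yields
        # ChatCompletionChunk objects instead of raw "data: ..." lines.
        stream = await client.chat.completions.create(
            model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical downstream id
            messages=[{"role": "user", "content": "Hello"}],
            stream=True,
        )
        async for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                print(chunk.choices[0].delta.content, end="", flush=True)


    asyncio.run(stream_chat())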