mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 02:03:44 +00:00
simpler
This commit is contained in:
parent 6115679679
commit b37eb24650
1 changed file with 41 additions and 144 deletions
@@ -4,12 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import json
 from collections.abc import AsyncIterator
-from enum import Enum
 
-import httpx
-from pydantic import BaseModel
+from openai import AsyncOpenAI
 
 from llama_stack.apis.inference import (
     Inference,
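The import swap above is the core of the change: instead of hand-rolling HTTP calls with httpx and re-serializing Pydantic models, the adapter talks to the downstream server through the official AsyncOpenAI client. A minimal sketch of that pattern (the base URL, port, and API key below are placeholders, not values from this repository):

```python
# Sketch only, not part of the commit: base_url and api_key are hypothetical
# placeholders for a downstream OpenAI-compatible server.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(
        base_url="http://localhost:8000/v1",  # placeholder downstream /v1 endpoint
        api_key="example-key",                # placeholder credential
    )
    # The same client object serves models.list(), completions, chat, and embeddings.
    models = await client.models.list()
    for model in models.data:
        print(model.id)


asyncio.run(main())
```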
@@ -45,46 +42,44 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
 
     async def list_models(self) -> list[Model]:
         """List models by calling the downstream /v1/models endpoint."""
-        base_url = self._get_passthrough_url().rstrip("/")
-        api_key = self._get_passthrough_api_key()
-
-        url = f"{base_url}/v1/models"
-        headers = {
-            "Authorization": f"Bearer {api_key}",
-            "Accept": "application/json",
-        }
-
-        async with httpx.AsyncClient() as client:
-            response = await client.get(url, headers=headers, timeout=30.0)
-            response.raise_for_status()
-
-        response_data = response.json()
-        models_data = response_data["data"]
-
-        # Convert from OpenAI format to Llama Stack Model format
-        models = []
-        for model_data in models_data:
-            downstream_model_id = model_data["id"]
-            custom_metadata = model_data.get("custom_metadata", {})
-
-            # Prefix identifier with provider ID for local registry
-            local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
-
-            model = Model(
-                identifier=local_identifier,
-                provider_id=self.__provider_id__,
-                provider_resource_id=downstream_model_id,
-                model_type=custom_metadata.get("model_type", "llm"),
-                metadata=custom_metadata,
-            )
-            models.append(model)
-
-        return models
+        client = self._get_openai_client()
+        response = await client.models.list()
+
+        # Convert from OpenAI format to Llama Stack Model format
+        models = []
+        for model_data in response.data:
+            downstream_model_id = model_data.id
+            custom_metadata = getattr(model_data, "custom_metadata", {}) or {}
+
+            # Prefix identifier with provider ID for local registry
+            local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
+
+            model = Model(
+                identifier=local_identifier,
+                provider_id=self.__provider_id__,
+                provider_resource_id=downstream_model_id,
+                model_type=custom_metadata.get("model_type", "llm"),
+                metadata=custom_metadata,
+            )
+            models.append(model)
+
+        return models
 
     async def should_refresh_models(self) -> bool:
         """Passthrough should refresh models since they come from downstream dynamically."""
         return self.config.refresh_models
 
+    def _get_openai_client(self) -> AsyncOpenAI:
+        """Get an AsyncOpenAI client configured for the downstream server."""
+        base_url = self._get_passthrough_url()
+        api_key = self._get_passthrough_api_key()
+
+        return AsyncOpenAI(
+            base_url=f"{base_url.rstrip('/')}/v1",
+            api_key=api_key,
+        )
+
     def _get_passthrough_url(self) -> str:
         """Get the passthrough URL from config or provider data."""
         if self.config.url is not None:
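The rewritten list_models keeps the same conversion logic as before: each downstream model id is prefixed with the provider id so it can live in the local registry, and any custom_metadata is carried over. A small self-contained sketch of that mapping, with a plain dict standing in for the Llama Stack Model class (the example ids are made up):

```python
# Illustrative only: a plain dict stands in for llama_stack's Model type.
def to_local_model(provider_id: str, downstream: dict) -> dict:
    downstream_model_id = downstream["id"]
    custom_metadata = downstream.get("custom_metadata") or {}
    return {
        # Prefix the identifier with the provider ID for the local registry,
        # e.g. "passthrough/some-downstream-model".
        "identifier": f"{provider_id}/{downstream_model_id}",
        "provider_id": provider_id,
        "provider_resource_id": downstream_model_id,
        "model_type": custom_metadata.get("model_type", "llm"),
        "metadata": custom_metadata,
    }


print(to_local_model("passthrough", {"id": "some-downstream-model"}))
```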
@@ -109,130 +104,32 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
             )
         return provider_data.passthrough_api_key
 
-    def _serialize_value(self, value):
-        """Convert Pydantic models and enums to JSON-serializable values."""
-        if isinstance(value, BaseModel):
-            return json.loads(value.model_dump_json())
-        elif isinstance(value, Enum):
-            return value.value
-        elif isinstance(value, list):
-            return [self._serialize_value(item) for item in value]
-        elif isinstance(value, dict):
-            return {k: self._serialize_value(v) for k, v in value.items()}
-        else:
-            return value
-
-    async def _make_request(
-        self,
-        endpoint: str,
-        params: dict,
-        response_type: type,
-        stream: bool = False,
-    ):
-        """Make an HTTP request to the passthrough endpoint."""
-        base_url = self._get_passthrough_url().rstrip("/")
-        api_key = self._get_passthrough_api_key()
-
-        url = f"{base_url}/v1/{endpoint.lstrip('/')}"
-        headers = {
-            "Authorization": f"Bearer {api_key}",
-            "Content-Type": "application/json",
-            "Accept": "application/json",
-        }
-
-        # Serialize the request body
-        json_body = self._serialize_value(params)
-
-        if stream:
-            return self._stream_request(url, headers, json_body, response_type)
-        else:
-            return await self._non_stream_request(url, headers, json_body, response_type)
-
-    async def _non_stream_request(
-        self,
-        url: str,
-        headers: dict,
-        json_body: dict,
-        response_type: type,
-    ):
-        """Make a non-streaming HTTP request."""
-        async with httpx.AsyncClient() as client:
-            response = await client.post(url, json=json_body, headers=headers, timeout=30.0)
-            response.raise_for_status()
-
-            response_data = response.json()
-            return response_type.model_validate(response_data)
-
-    async def _stream_request(
-        self,
-        url: str,
-        headers: dict,
-        json_body: dict,
-        response_type: type,
-    ) -> AsyncIterator:
-        """Make a streaming HTTP request with Server-Sent Events parsing."""
-        async with httpx.AsyncClient() as client:
-            async with client.stream("POST", url, json=json_body, headers=headers, timeout=30.0) as response:
-                response.raise_for_status()
-
-                async for line in response.aiter_lines():
-                    if line.startswith("data:"):
-                        # Extract JSON after "data:" prefix
-                        data_str = line[len("data:") :].strip()
-
-                        # Skip empty lines or "[DONE]" marker
-                        if not data_str or data_str == "[DONE]":
-                            continue
-
-                        data = json.loads(data_str)
-
-                        # Fix OpenAI compatibility: finish_reason can be null in intermediate chunks
-                        # but our Pydantic model may not accept null
-                        if "choices" in data:
-                            for choice in data["choices"]:
-                                if choice.get("finish_reason") is None:
-                                    choice["finish_reason"] = ""
-
-                        yield response_type.model_validate(data)
-
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
-        # params.model is already the provider_resource_id (router translated it)
-        request_params = params.model_dump(exclude_none=True)
-
-        return await self._make_request(
-            endpoint="completions",
-            params=request_params,
-            response_type=OpenAICompletion,
-            stream=params.stream or False,
-        )
+        """Forward completion request to downstream using OpenAI client."""
+        # params.model is already the provider_resource_id (router translated it)
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
+        response = await client.completions.create(**request_params)
+        return response  # type: ignore
 
     async def openai_chat_completion(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        # params.model is already the provider_resource_id (router translated it)
-        request_params = params.model_dump(exclude_none=True)
-
-        return await self._make_request(
-            endpoint="chat/completions",
-            params=request_params,
-            response_type=OpenAIChatCompletionChunk if params.stream else OpenAIChatCompletion,
-            stream=params.stream or False,
-        )
+        """Forward chat completion request to downstream using OpenAI client."""
+        # params.model is already the provider_resource_id (router translated it)
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
+        response = await client.chat.completions.create(**request_params)
+        return response  # type: ignore
 
     async def openai_embeddings(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
-        # params.model is already the provider_resource_id (router translated it)
-        request_params = params.model_dump(exclude_none=True)
-
-        return await self._make_request(
-            endpoint="embeddings",
-            params=request_params,
-            response_type=OpenAIEmbeddingsResponse,
-            stream=False,
-        )
+        """Forward embeddings request to downstream using OpenAI client."""
+        # params.model is already the provider_resource_id (router translated it)
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
+        response = await client.embeddings.create(**request_params)
+        return response  # type: ignore
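The new openai_* methods all follow one forwarding pattern: dump the validated request to a dict with exclude_none=True and splat it into the matching client call, so unset optional fields are simply omitted from the downstream request. A rough sketch of that idea with an invented Pydantic request model (ExampleRequest and its fields are hypothetical, not types from llama-stack):

```python
# Sketch of the exclude_none forwarding pattern; ExampleRequest is hypothetical.
from pydantic import BaseModel


class ExampleRequest(BaseModel):
    model: str
    prompt: str
    max_tokens: int | None = None
    temperature: float | None = None


params = ExampleRequest(model="some-model", prompt="Hello")
request_params = params.model_dump(exclude_none=True)
# Only fields that were actually set survive, so the downstream server
# never receives explicit nulls:
print(request_params)  # {'model': 'some-model', 'prompt': 'Hello'}
```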
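The commit also deletes the hand-written Server-Sent Events parsing in _stream_request, including the finish_reason workaround, because the OpenAI client parses the stream itself and yields typed chunks. A hedged sketch of consuming a streaming chat completion through that client (the URL, key, and model name are placeholders):

```python
# Sketch only: placeholders for the downstream server and model.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="example-key")
    stream = await client.chat.completions.create(
        model="some-downstream-model",  # placeholder id
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,  # the client handles the "data: ..." SSE framing and "[DONE]" marker
    )
    async for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="")


asyncio.run(main())
```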