Mirror of https://github.com/meta-llama/llama-stack.git

Commit b37eb24650 ("simpler"), parent 6115679679
1 changed file with 41 additions and 144 deletions
@@ -4,12 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import json
 from collections.abc import AsyncIterator
-from enum import Enum
 
-import httpx
-from pydantic import BaseModel
+from openai import AsyncOpenAI
 
 from llama_stack.apis.inference import (
     Inference,
@@ -45,46 +42,44 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
 
     async def list_models(self) -> list[Model]:
         """List models by calling the downstream /v1/models endpoint."""
-        base_url = self._get_passthrough_url().rstrip("/")
-        api_key = self._get_passthrough_api_key()
-
-        url = f"{base_url}/v1/models"
-        headers = {
-            "Authorization": f"Bearer {api_key}",
-            "Accept": "application/json",
-        }
-
-        async with httpx.AsyncClient() as client:
-            response = await client.get(url, headers=headers, timeout=30.0)
-            response.raise_for_status()
-
-            response_data = response.json()
-            models_data = response_data["data"]
-
-            # Convert from OpenAI format to Llama Stack Model format
-            models = []
-            for model_data in models_data:
-                downstream_model_id = model_data["id"]
-                custom_metadata = model_data.get("custom_metadata", {})
-
-                # Prefix identifier with provider ID for local registry
-                local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
-
-                model = Model(
-                    identifier=local_identifier,
-                    provider_id=self.__provider_id__,
-                    provider_resource_id=downstream_model_id,
-                    model_type=custom_metadata.get("model_type", "llm"),
-                    metadata=custom_metadata,
-                )
-                models.append(model)
-
-            return models
+        client = self._get_openai_client()
+
+        response = await client.models.list()
+
+        # Convert from OpenAI format to Llama Stack Model format
+        models = []
+        for model_data in response.data:
+            downstream_model_id = model_data.id
+            custom_metadata = getattr(model_data, "custom_metadata", {}) or {}
+
+            # Prefix identifier with provider ID for local registry
+            local_identifier = f"{self.__provider_id__}/{downstream_model_id}"
+
+            model = Model(
+                identifier=local_identifier,
+                provider_id=self.__provider_id__,
+                provider_resource_id=downstream_model_id,
+                model_type=custom_metadata.get("model_type", "llm"),
+                metadata=custom_metadata,
+            )
+            models.append(model)
+
+        return models
 
     async def should_refresh_models(self) -> bool:
         """Passthrough should refresh models since they come from downstream dynamically."""
         return self.config.refresh_models
 
+    def _get_openai_client(self) -> AsyncOpenAI:
+        """Get an AsyncOpenAI client configured for the downstream server."""
+        base_url = self._get_passthrough_url()
+        api_key = self._get_passthrough_api_key()
+
+        return AsyncOpenAI(
+            base_url=f"{base_url.rstrip('/')}/v1",
+            api_key=api_key,
+        )
+
     def _get_passthrough_url(self) -> str:
         """Get the passthrough URL from config or provider data."""
         if self.config.url is not None:
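For orientation, below is a minimal sketch of the downstream call the rewritten list_models now makes and of the identifier prefixing it applies. The base URL http://localhost:8321/v1, the API key, and the provider id "passthrough" are illustrative assumptions, not values taken from this change.

# Sketch only: assumes a reachable downstream OpenAI-compatible server and a
# provider id of "passthrough"; neither is defined by this diff.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(
        base_url="http://localhost:8321/v1",  # assumed downstream URL
        api_key="dummy-key",                  # assumed credential
    )
    response = await client.models.list()
    for model_data in response.data:
        # Same prefixing rule as the adapter: "<provider_id>/<downstream model id>"
        print(f"passthrough/{model_data.id}")


asyncio.run(main())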
@@ -109,130 +104,32 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
         )
         return provider_data.passthrough_api_key
 
-    def _serialize_value(self, value):
-        """Convert Pydantic models and enums to JSON-serializable values."""
-        if isinstance(value, BaseModel):
-            return json.loads(value.model_dump_json())
-        elif isinstance(value, Enum):
-            return value.value
-        elif isinstance(value, list):
-            return [self._serialize_value(item) for item in value]
-        elif isinstance(value, dict):
-            return {k: self._serialize_value(v) for k, v in value.items()}
-        else:
-            return value
-
-    async def _make_request(
-        self,
-        endpoint: str,
-        params: dict,
-        response_type: type,
-        stream: bool = False,
-    ):
-        """Make an HTTP request to the passthrough endpoint."""
-        base_url = self._get_passthrough_url().rstrip("/")
-        api_key = self._get_passthrough_api_key()
-
-        url = f"{base_url}/v1/{endpoint.lstrip('/')}"
-        headers = {
-            "Authorization": f"Bearer {api_key}",
-            "Content-Type": "application/json",
-            "Accept": "application/json",
-        }
-
-        # Serialize the request body
-        json_body = self._serialize_value(params)
-
-        if stream:
-            return self._stream_request(url, headers, json_body, response_type)
-        else:
-            return await self._non_stream_request(url, headers, json_body, response_type)
-
-    async def _non_stream_request(
-        self,
-        url: str,
-        headers: dict,
-        json_body: dict,
-        response_type: type,
-    ):
-        """Make a non-streaming HTTP request."""
-        async with httpx.AsyncClient() as client:
-            response = await client.post(url, json=json_body, headers=headers, timeout=30.0)
-            response.raise_for_status()
-
-            response_data = response.json()
-            return response_type.model_validate(response_data)
-
-    async def _stream_request(
-        self,
-        url: str,
-        headers: dict,
-        json_body: dict,
-        response_type: type,
-    ) -> AsyncIterator:
-        """Make a streaming HTTP request with Server-Sent Events parsing."""
-        async with httpx.AsyncClient() as client:
-            async with client.stream("POST", url, json=json_body, headers=headers, timeout=30.0) as response:
-                response.raise_for_status()
-
-                async for line in response.aiter_lines():
-                    if line.startswith("data:"):
-                        # Extract JSON after "data: " prefix
-                        data_str = line[len("data:") :].strip()
-
-                        # Skip empty lines or "[DONE]" marker
-                        if not data_str or data_str == "[DONE]":
-                            continue
-
-                        data = json.loads(data_str)
-
-                        # Fix OpenAI compatibility: finish_reason can be null in intermediate chunks
-                        # but our Pydantic model may not accept null
-                        if "choices" in data:
-                            for choice in data["choices"]:
-                                if choice.get("finish_reason") is None:
-                                    choice["finish_reason"] = ""
-
-                        yield response_type.model_validate(data)
-
     async def openai_completion(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
-        # params.model is already the provider_resource_id (router translated it)
+        """Forward completion request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
-        return await self._make_request(
-            endpoint="completions",
-            params=request_params,
-            response_type=OpenAICompletion,
-            stream=params.stream or False,
-        )
+        response = await client.completions.create(**request_params)
+        return response  # type: ignore
 
     async def openai_chat_completion(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        # params.model is already the provider_resource_id (router translated it)
+        """Forward chat completion request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
-        return await self._make_request(
-            endpoint="chat/completions",
-            params=request_params,
-            response_type=OpenAIChatCompletionChunk if params.stream else OpenAIChatCompletion,
-            stream=params.stream or False,
-        )
+        response = await client.chat.completions.create(**request_params)
+        return response  # type: ignore
 
     async def openai_embeddings(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
-        # params.model is already the provider_resource_id (router translated it)
+        """Forward embeddings request to downstream using OpenAI client."""
+        client = self._get_openai_client()
         request_params = params.model_dump(exclude_none=True)
-        return await self._make_request(
-            endpoint="embeddings",
-            params=request_params,
-            response_type=OpenAIEmbeddingsResponse,
-            stream=False,
-        )
+        response = await client.embeddings.create(**request_params)
+        return response  # type: ignore
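The simplification also relies on the OpenAI client returning either a parsed response object or an async stream of chunks depending on the stream flag, which is why the adapter can hand the response back unchanged in both cases. Below is a minimal sketch of that behavior; the downstream URL, API key, and model id are placeholders assumed for illustration, not values from this commit.

# Sketch only: shows that AsyncOpenAI returns a ChatCompletion when stream is
# off and an async iterator of ChatCompletionChunk objects when stream=True.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8321/v1", api_key="dummy-key")

    # Non-streaming: a single parsed ChatCompletion comes back.
    completion = await client.chat.completions.create(
        model="meta-llama/Llama-3.3-70B-Instruct",  # assumed downstream model id
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(completion.choices[0].message.content)

    # Streaming: an async stream of chunks comes back, which callers can
    # iterate with `async for` exactly as they did before this change.
    stream = await client.chat.completions.create(
        model="meta-llama/Llama-3.3-70B-Instruct",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")


asyncio.run(main())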