mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
refactor(passthrough): replace client with httpx, simplify implementation
- Use httpx directly instead of AsyncLlamaStackClient - Remove ModelRegistryHelper dependency (unused) - Consolidate to auth_credential field via RemoteInferenceProviderConfig - Implement list_models() to fetch from downstream /v1/models - Implement all inference methods (completion, chat, embeddings) - Fix provider data validator field names - Add SSE parsing for streaming responses
This commit is contained in:
parent
c899b50723
commit
1ff6eeb434
3 changed files with 177 additions and 80 deletions
|
|
@ -10,8 +10,8 @@ from .config import PassthroughImplConfig
|
||||||
|
|
||||||
|
|
||||||
class PassthroughProviderDataValidator(BaseModel):
|
class PassthroughProviderDataValidator(BaseModel):
|
||||||
url: str
|
passthrough_url: str
|
||||||
api_key: str
|
passthrough_api_key: str
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: PassthroughImplConfig, _deps):
|
async def get_adapter_impl(config: PassthroughImplConfig, _deps):
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field, SecretStr
|
from pydantic import Field
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
@ -19,11 +19,6 @@ class PassthroughImplConfig(RemoteInferenceProviderConfig):
|
||||||
description="The URL for the passthrough endpoint",
|
description="The URL for the passthrough endpoint",
|
||||||
)
|
)
|
||||||
|
|
||||||
api_key: SecretStr | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="API Key for the passthrouth endpoint",
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(
|
def sample_run_config(
|
||||||
cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
|
cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from typing import Any
|
from enum import Enum
|
||||||
|
|
||||||
from llama_stack_client import AsyncLlamaStackClient
|
import httpx
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
Inference,
|
Inference,
|
||||||
|
|
@ -20,15 +22,13 @@ from llama_stack.apis.inference import (
|
||||||
OpenAIEmbeddingsResponse,
|
OpenAIEmbeddingsResponse,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.core.library_client import convert_pydantic_to_json_value
|
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
|
||||||
|
|
||||||
from .config import PassthroughImplConfig
|
from .config import PassthroughImplConfig
|
||||||
|
|
||||||
|
|
||||||
class PassthroughInferenceAdapter(Inference):
|
class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
|
||||||
def __init__(self, config: PassthroughImplConfig) -> None:
|
def __init__(self, config: PassthroughImplConfig) -> None:
|
||||||
ModelRegistryHelper.__init__(self)
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
|
|
@ -37,86 +37,188 @@ class PassthroughInferenceAdapter(Inference):
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _get_client(self) -> AsyncLlamaStackClient:
|
async def list_models(self) -> list[Model]:
|
||||||
passthrough_url = None
|
"""List models by calling the downstream /v1/models endpoint."""
|
||||||
passthrough_api_key = None
|
base_url = self._get_passthrough_url().rstrip("/")
|
||||||
provider_data = None
|
api_key = self._get_passthrough_api_key()
|
||||||
|
|
||||||
|
url = f"{base_url}/v1/models"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(url, headers=headers, timeout=30.0)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
# The response should be a list of Model objects
|
||||||
|
return [Model.model_validate(m) for m in response_data]
|
||||||
|
|
||||||
|
async def should_refresh_models(self) -> bool:
|
||||||
|
"""Passthrough should refresh models since they come from downstream dynamically."""
|
||||||
|
return self.config.refresh_models
|
||||||
|
|
||||||
|
def _get_passthrough_url(self) -> str:
|
||||||
|
"""Get the passthrough URL from config or provider data."""
|
||||||
if self.config.url is not None:
|
if self.config.url is not None:
|
||||||
passthrough_url = self.config.url
|
return self.config.url
|
||||||
else:
|
|
||||||
provider_data = self.get_request_provider_data()
|
|
||||||
if provider_data is None or not provider_data.passthrough_url:
|
|
||||||
raise ValueError(
|
|
||||||
'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
|
|
||||||
)
|
|
||||||
passthrough_url = provider_data.passthrough_url
|
|
||||||
|
|
||||||
if self.config.api_key is not None:
|
provider_data = self.get_request_provider_data()
|
||||||
passthrough_api_key = self.config.api_key.get_secret_value()
|
if provider_data is None:
|
||||||
else:
|
raise ValueError(
|
||||||
provider_data = self.get_request_provider_data()
|
'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
|
||||||
if provider_data is None or not provider_data.passthrough_api_key:
|
)
|
||||||
raise ValueError(
|
return provider_data.passthrough_url
|
||||||
'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
|
|
||||||
)
|
|
||||||
passthrough_api_key = provider_data.passthrough_api_key
|
|
||||||
|
|
||||||
return AsyncLlamaStackClient(
|
def _get_passthrough_api_key(self) -> str:
|
||||||
base_url=passthrough_url,
|
"""Get the passthrough API key from config or provider data."""
|
||||||
api_key=passthrough_api_key,
|
if self.config.auth_credential is not None:
|
||||||
provider_data=provider_data,
|
return self.config.auth_credential.get_secret_value()
|
||||||
|
|
||||||
|
provider_data = self.get_request_provider_data()
|
||||||
|
if provider_data is None:
|
||||||
|
raise ValueError(
|
||||||
|
'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
|
||||||
|
)
|
||||||
|
return provider_data.passthrough_api_key
|
||||||
|
|
||||||
|
def _serialize_value(self, value):
|
||||||
|
"""Convert Pydantic models and enums to JSON-serializable values."""
|
||||||
|
if isinstance(value, BaseModel):
|
||||||
|
return json.loads(value.model_dump_json())
|
||||||
|
elif isinstance(value, Enum):
|
||||||
|
return value.value
|
||||||
|
elif isinstance(value, list):
|
||||||
|
return [self._serialize_value(item) for item in value]
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
return {k: self._serialize_value(v) for k, v in value.items()}
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def _make_request(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
params: dict,
|
||||||
|
response_type: type,
|
||||||
|
stream: bool = False,
|
||||||
|
):
|
||||||
|
"""Make an HTTP request to the passthrough endpoint."""
|
||||||
|
base_url = self._get_passthrough_url().rstrip("/")
|
||||||
|
api_key = self._get_passthrough_api_key()
|
||||||
|
|
||||||
|
url = f"{base_url}/v1/{endpoint.lstrip('/')}"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Serialize the request body
|
||||||
|
json_body = self._serialize_value(params)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._stream_request(url, headers, json_body, response_type)
|
||||||
|
else:
|
||||||
|
return await self._non_stream_request(url, headers, json_body, response_type)
|
||||||
|
|
||||||
|
async def _non_stream_request(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
headers: dict,
|
||||||
|
json_body: dict,
|
||||||
|
response_type: type,
|
||||||
|
):
|
||||||
|
"""Make a non-streaming HTTP request."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(url, json=json_body, headers=headers, timeout=30.0)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
return response_type.model_validate(response_data)
|
||||||
|
|
||||||
|
async def _stream_request(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
headers: dict,
|
||||||
|
json_body: dict,
|
||||||
|
response_type: type,
|
||||||
|
) -> AsyncIterator:
|
||||||
|
"""Make a streaming HTTP request with Server-Sent Events parsing."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
async with client.stream("POST", url, json=json_body, headers=headers, timeout=30.0) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data:"):
|
||||||
|
# Extract JSON after "data: " prefix
|
||||||
|
data_str = line[len("data:") :].strip()
|
||||||
|
|
||||||
|
# Skip empty lines or "[DONE]" marker
|
||||||
|
if not data_str or data_str == "[DONE]":
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
yield response_type.model_validate(data)
|
||||||
|
except Exception:
|
||||||
|
# Log and skip malformed chunks
|
||||||
|
continue
|
||||||
|
|
||||||
|
async def openai_completion(
|
||||||
|
self,
|
||||||
|
params: OpenAICompletionRequestWithExtraBody,
|
||||||
|
) -> OpenAICompletion:
|
||||||
|
model_obj = await self.model_store.get_model(params.model)
|
||||||
|
|
||||||
|
# Create a copy with the provider's model ID
|
||||||
|
params = params.model_copy()
|
||||||
|
params.model = model_obj.provider_resource_id
|
||||||
|
|
||||||
|
request_params = params.model_dump(exclude_none=True)
|
||||||
|
|
||||||
|
return await self._make_request(
|
||||||
|
endpoint="completions",
|
||||||
|
params=request_params,
|
||||||
|
response_type=OpenAICompletion,
|
||||||
|
stream=params.stream or False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_chat_completion(
|
||||||
|
self,
|
||||||
|
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||||
|
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||||
|
model_obj = await self.model_store.get_model(params.model)
|
||||||
|
|
||||||
|
# Create a copy with the provider's model ID
|
||||||
|
params = params.model_copy()
|
||||||
|
params.model = model_obj.provider_resource_id
|
||||||
|
|
||||||
|
request_params = params.model_dump(exclude_none=True)
|
||||||
|
|
||||||
|
return await self._make_request(
|
||||||
|
endpoint="chat/completions",
|
||||||
|
params=request_params,
|
||||||
|
response_type=OpenAIChatCompletionChunk if params.stream else OpenAIChatCompletion,
|
||||||
|
stream=params.stream or False,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def openai_embeddings(
|
async def openai_embeddings(
|
||||||
self,
|
self,
|
||||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||||
) -> OpenAIEmbeddingsResponse:
|
) -> OpenAIEmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def openai_completion(
|
|
||||||
self,
|
|
||||||
params: OpenAICompletionRequestWithExtraBody,
|
|
||||||
) -> OpenAICompletion:
|
|
||||||
client = self._get_client()
|
|
||||||
model_obj = await self.model_store.get_model(params.model)
|
model_obj = await self.model_store.get_model(params.model)
|
||||||
|
|
||||||
|
# Create a copy with the provider's model ID
|
||||||
params = params.model_copy()
|
params = params.model_copy()
|
||||||
params.model = model_obj.provider_resource_id
|
params.model = model_obj.provider_resource_id
|
||||||
|
|
||||||
request_params = params.model_dump(exclude_none=True)
|
request_params = params.model_dump(exclude_none=True)
|
||||||
|
|
||||||
return await client.inference.openai_completion(**request_params)
|
return await self._make_request(
|
||||||
|
endpoint="embeddings",
|
||||||
async def openai_chat_completion(
|
params=request_params,
|
||||||
self,
|
response_type=OpenAIEmbeddingsResponse,
|
||||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
stream=False,
|
||||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
)
|
||||||
client = self._get_client()
|
|
||||||
model_obj = await self.model_store.get_model(params.model)
|
|
||||||
|
|
||||||
params = params.model_copy()
|
|
||||||
params.model = model_obj.provider_resource_id
|
|
||||||
|
|
||||||
request_params = params.model_dump(exclude_none=True)
|
|
||||||
|
|
||||||
return await client.inference.openai_chat_completion(**request_params)
|
|
||||||
|
|
||||||
def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
json_params = {}
|
|
||||||
for key, value in request_params.items():
|
|
||||||
json_input = convert_pydantic_to_json_value(value)
|
|
||||||
if isinstance(json_input, dict):
|
|
||||||
json_input = {k: v for k, v in json_input.items() if v is not None}
|
|
||||||
elif isinstance(json_input, list):
|
|
||||||
json_input = [x for x in json_input if x is not None]
|
|
||||||
new_input = []
|
|
||||||
for x in json_input:
|
|
||||||
if isinstance(x, dict):
|
|
||||||
x = {k: v for k, v in x.items() if v is not None}
|
|
||||||
new_input.append(x)
|
|
||||||
json_input = new_input
|
|
||||||
|
|
||||||
json_params[key] = json_input
|
|
||||||
|
|
||||||
return json_params
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue