llama-stack-mirror/llama_stack/providers/remote/inference/passthrough/passthrough.py
Eric Huang a93130e323 test
# What does this PR do?


## Test Plan
Completes the refactoring started in previous commit by:

1. **Fix library client** (critical): Add logic to detect Pydantic model parameters
   and construct them properly from request bodies. The key fix is to NOT exclude
   any params when converting the body for Pydantic models - we need all fields
   to pass to the Pydantic constructor.

   Before: _convert_body excluded all params, leaving body empty for Pydantic construction
   After: Check for Pydantic params first, skip exclusion, construct model with full body

2. **Update remaining providers** to use new Pydantic-based signatures:
   - litellm_openai_mixin: Extract extra fields via __pydantic_extra__
   - databricks: Use TYPE_CHECKING import for params type
   - llama_openai_compat: Use TYPE_CHECKING import for params type
   - sentence_transformers: Update method signatures to use params

3. **Update unit tests** to use new Pydantic signature:
   - test_openai_mixin.py: Use OpenAIChatCompletionRequestParams
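
The TYPE_CHECKING import mentioned for the databricks and llama_openai_compat providers can be sketched roughly as below. This is only an illustration of the general pattern, not the providers' actual code: the module `heavy_params_module`, the class `RequestParams`, and the function `describe_request` are hypothetical names introduced for the example.

```python
from types import SimpleNamespace
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    # `heavy_params_module` is a hypothetical module used for illustration.
    from heavy_params_module import RequestParams


def describe_request(params: "RequestParams") -> str:
    # The annotation is a string, so Python does not resolve it at call time
    # and the import above is never needed while the program runs.
    return f"model={params.model}"


# At runtime any object with a `.model` attribute is accepted.
summary = describe_request(SimpleNamespace(model="example-model"))
```

The type name stays visible to type checkers and editors, while the runtime import (and any circular-import risk that comes with it) is avoided.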

This fixes test failures where the library client was trying to construct
Pydantic models with empty dictionaries.

The previous fix had a bug: it called _convert_body() which only keeps fields
that match function parameter names. For Pydantic methods with signature:
  openai_chat_completion(params: OpenAIChatCompletionRequestParams)

The signature only has 'params', but the body has 'model', 'messages', etc.
So _convert_body() returned an empty dict.

Fix: Skip _convert_body() entirely for Pydantic params. Use the raw body
directly to construct the Pydantic model (after stripping NOT_GIVENs).
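
A minimal sketch of the problem and the fix described above, under stated assumptions: a standard-library dataclass (`ChatParams`) stands in for the Pydantic model so the example has no third-party dependency, and `filter_by_param_names` and `build_kwargs` are hypothetical helpers that only mimic the described behavior of _convert_body(); they are not the library client's real code.

```python
import inspect
from dataclasses import dataclass, is_dataclass
from typing import Any


@dataclass
class ChatParams:
    # Stand-in for a Pydantic request model (hypothetical, for illustration).
    model: str
    messages: list[dict[str, Any]]


def chat(params: ChatParams) -> str:
    return f"{params.model}: {len(params.messages)} message(s)"


def filter_by_param_names(func, body: dict[str, Any]) -> dict[str, Any]:
    # Mimics the described behavior: keep only keys matching parameter names.
    names = set(inspect.signature(func).parameters)
    return {k: v for k, v in body.items() if k in names}


def build_kwargs(func, body: dict[str, Any]) -> dict[str, Any]:
    # If a parameter is annotated with a model class, construct the model from
    # the raw body instead of filtering the body by parameter names.
    for name, param in inspect.signature(func).parameters.items():
        annotation = param.annotation
        if isinstance(annotation, type) and is_dataclass(annotation):
            return {name: annotation(**body)}
    return filter_by_param_names(func, body)


body = {"model": "example-model", "messages": [{"role": "user", "content": "hi"}]}

# The signature only has 'params', so filtering by name drops every field.
filtered = filter_by_param_names(chat, body)

# Building the model from the raw body keeps 'model' and 'messages'.
result = chat(**build_kwargs(chat, body))
```

Here `filtered` ends up empty, which is the condition that led to the missing-field validation failure, while `build_kwargs` passes the full body to the model constructor.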

This properly fixes the ValidationError where required fields were missing.

The streaming code path (_call_streaming) had the same issue as non-streaming:
it called _convert_body() which returned empty dict for Pydantic params.

Applied the same fix as commit 7476c0ae:
- Detect Pydantic model parameters before body conversion
- Skip _convert_body() for Pydantic params
- Construct Pydantic model directly from raw body (after stripping NOT_GIVENs)

This fixes streaming endpoints like openai_chat_completion with stream=True.
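
The "stripping NOT_GIVENs" step that precedes model construction in both code paths might look roughly like the sketch below. Everything here is an assumption for illustration: `_NotGiven`, `NOT_GIVEN`, and `strip_not_given` are local stand-ins rather than the client library's own objects, and the recursive treatment of nested values is a guess at the intent, not confirmed by the commit.

```python
from typing import Any


class _NotGiven:
    # Local stand-in for a client's "argument omitted" sentinel (hypothetical).
    def __repr__(self) -> str:
        return "NOT_GIVEN"


NOT_GIVEN = _NotGiven()


def strip_not_given(value: Any) -> Any:
    # Recursively drop sentinel entries from dicts and lists. An explicit None
    # is kept, because it is different from an omitted argument.
    if isinstance(value, dict):
        return {k: strip_not_given(v) for k, v in value.items() if v is not NOT_GIVEN}
    if isinstance(value, list):
        return [strip_not_given(v) for v in value if v is not NOT_GIVEN]
    return value


raw_body = {
    "model": "example-model",
    "stream": True,
    "temperature": NOT_GIVEN,
    "messages": [{"role": "user", "content": "hi", "name": NOT_GIVEN}],
}

# Only fields the caller actually supplied reach the model constructor.
clean_body = strip_not_given(raw_body)
```

Stripping before construction matters because a sentinel object is not a valid value for any typed field, so leaving it in the body would fail validation in the same way for streaming and non-streaming calls.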
2025-10-09 13:53:33 -07:00

127 lines
4.6 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Any

from llama_stack_client import AsyncLlamaStackClient

from llama_stack.apis.inference import (
    Inference,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionRequestParams,
    OpenAICompletion,
    OpenAICompletionRequestParams,
    OpenAIEmbeddingsResponse,
)
from llama_stack.apis.models import Model
from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

from .config import PassthroughImplConfig


class PassthroughInferenceAdapter(Inference):
    def __init__(self, config: PassthroughImplConfig) -> None:
        ModelRegistryHelper.__init__(self)
        self.config = config

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def register_model(self, model: Model) -> Model:
        return model

    def _get_client(self) -> AsyncLlamaStackClient:
        passthrough_url = None
        passthrough_api_key = None
        provider_data = None

        if self.config.url is not None:
            passthrough_url = self.config.url
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.passthrough_url:
                raise ValueError(
                    'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
                )
            passthrough_url = provider_data.passthrough_url

        if self.config.api_key is not None:
            passthrough_api_key = self.config.api_key.get_secret_value()
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.passthrough_api_key:
                raise ValueError(
                    'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
                )
            passthrough_api_key = provider_data.passthrough_api_key

        return AsyncLlamaStackClient(
            base_url=passthrough_url,
            api_key=passthrough_api_key,
            provider_data=provider_data,
        )

    async def openai_embeddings(
        self,
        model: str,
        input: str | list[str],
        encoding_format: str | None = "float",
        dimensions: int | None = None,
        user: str | None = None,
    ) -> OpenAIEmbeddingsResponse:
        raise NotImplementedError()

    async def openai_completion(
        self,
        params: OpenAICompletionRequestParams,
    ) -> OpenAICompletion:
        client = self._get_client()
        model_obj = await self.model_store.get_model(params.model)

        # Update model with provider resource ID
        params.model = model_obj.provider_resource_id

        # Convert Pydantic model to dict, including extra fields
        request_params = params.model_dump(exclude_none=True)

        return await client.inference.openai_completion(**request_params)

    async def openai_chat_completion(
        self,
        params: OpenAIChatCompletionRequestParams,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        client = self._get_client()
        model_obj = await self.model_store.get_model(params.model)

        # Update model with provider resource ID
        params.model = model_obj.provider_resource_id

        # Convert Pydantic model to dict, including extra fields
        request_params = params.model_dump(exclude_none=True)

        return await client.inference.openai_chat_completion(**request_params)

    def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
        json_params = {}
        for key, value in request_params.items():
            json_input = convert_pydantic_to_json_value(value)
            if isinstance(json_input, dict):
                json_input = {k: v for k, v in json_input.items() if v is not None}
            elif isinstance(json_input, list):
                json_input = [x for x in json_input if x is not None]
                new_input = []
                for x in json_input:
                    if isinstance(x, dict):
                        x = {k: v for k, v in x.items() if v is not None}
                    new_input.append(x)
                json_input = new_input
            json_params[key] = json_input

        return json_params