From 6711fd4f5a0328a9c1c62b708671809ef3ebe723 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 31 Jul 2025 13:27:01 -0500 Subject: [PATCH] update SambaNovaInferenceAdapter to use _get_params from LiteLLMOpenAIMixin by adding extra params to the mixin --- .../remote/inference/sambanova/sambanova.py | 58 +------------------ .../utils/inference/litellm_openai_mixin.py | 12 +++- 2 files changed, 12 insertions(+), 58 deletions(-) diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 1391a2c8f..91f6c471d 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -6,19 +6,9 @@ import requests -from llama_stack.apis.inference import ( - ChatCompletionRequest, - JsonSchemaResponseFormat, - ToolChoice, -) from llama_stack.apis.models import Model from llama_stack.log import get_logger from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin -from llama_stack.providers.utils.inference.openai_compat import ( - convert_message_to_openai_dict_new, - convert_tooldef_to_openai_tool, - get_sampling_options, -) from .config import SambaNovaImplConfig from .models import MODEL_ENTRIES @@ -39,54 +29,10 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None, provider_data_api_key_field="sambanova_api_key", openai_compat_api_base=self.config.url, + download_images=True, # SambaNova requires base64 image encoding + json_schema_strict=False, # SambaNova doesn't support strict=True yet ) - async def _get_params(self, request: ChatCompletionRequest) -> dict: - input_dict = {} - - input_dict["messages"] = [ - await convert_message_to_openai_dict_new(m, download_images=True) for m in request.messages - ] - if fmt := request.response_format: - if not isinstance(fmt, JsonSchemaResponseFormat): - raise ValueError( - f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported." - ) - - fmt = fmt.json_schema - name = fmt["title"] - del fmt["title"] - fmt["additionalProperties"] = False - - # Apply additionalProperties: False recursively to all objects - fmt = self._add_additional_properties_recursive(fmt) - - input_dict["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": name, - "schema": fmt, - "strict": False, - }, - } - if request.tools: - input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools] - if request.tool_config.tool_choice: - input_dict["tool_choice"] = ( - request.tool_config.tool_choice.value - if isinstance(request.tool_config.tool_choice, ToolChoice) - else request.tool_config.tool_choice - ) - - return { - "model": request.model, - "api_key": self.get_api_key(), - "api_base": self.api_base, - **input_dict, - "stream": request.stream, - **get_sampling_options(request.sampling_params), - } - async def register_model(self, model: Model) -> Model: model_id = self.get_provider_model_id(model.provider_resource_id) diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index e9a41fcf3..befb4b092 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -72,6 +72,8 @@ class LiteLLMOpenAIMixin( api_key_from_config: str | None, provider_data_api_key_field: str, openai_compat_api_base: str | None = None, + download_images: bool = False, + json_schema_strict: bool = True, ): """ Initialize the LiteLLMOpenAIMixin. @@ -81,6 +83,8 @@ class LiteLLMOpenAIMixin( :param provider_data_api_key_field: The field in the provider data that contains the API key. :param litellm_provider_name: The name of the provider, used for model lookups. :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility. + :param download_images: Whether to download images and convert to base64 for message conversion. + :param json_schema_strict: Whether to use strict mode for JSON schema validation. """ ModelRegistryHelper.__init__(self, model_entries) @@ -88,6 +92,8 @@ class LiteLLMOpenAIMixin( self.api_key_from_config = api_key_from_config self.provider_data_api_key_field = provider_data_api_key_field self.api_base = openai_compat_api_base + self.download_images = download_images + self.json_schema_strict = json_schema_strict if openai_compat_api_base: self.is_openai_compat = True @@ -206,7 +212,9 @@ class LiteLLMOpenAIMixin( async def _get_params(self, request: ChatCompletionRequest) -> dict: input_dict = {} - input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages] + input_dict["messages"] = [ + await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages + ] if fmt := request.response_format: if not isinstance(fmt, JsonSchemaResponseFormat): raise ValueError( @@ -226,7 +234,7 @@ class LiteLLMOpenAIMixin( "json_schema": { "name": name, "schema": fmt, - "strict": True, + "strict": self.json_schema_strict, }, } if request.tools: