diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 8ba705f59..96469acac 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -4,178 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import json
-from collections.abc import Iterable
-
-import requests
-from openai.types.chat import (
-    ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
-)
-from openai.types.chat import (
-    ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
-)
-from openai.types.chat import (
-    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
-)
-from openai.types.chat import (
-    ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
-)
-from openai.types.chat import (
-    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
-)
-from openai.types.chat import (
-    ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
-)
-from openai.types.chat import (
-    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
-)
-from openai.types.chat import (
-    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
-)
-from openai.types.chat import (
-    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
-)
-from openai.types.chat.chat_completion_content_part_image_param import (
-    ImageURL as OpenAIImageURL,
-)
-from openai.types.chat.chat_completion_message_tool_call_param import (
-    Function as OpenAIFunction,
-)
-
-from llama_stack.apis.common.content_types import (
-    ImageContentItem,
-    InterleavedContent,
-    TextContentItem,
-)
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    CompletionMessage,
-    JsonSchemaResponseFormat,
-    Message,
-    SystemMessage,
-    ToolChoice,
-    ToolResponseMessage,
-    UserMessage,
-)
-from llama_stack.apis.models import Model
-from llama_stack.log import get_logger
-from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
-from llama_stack.providers.utils.inference.openai_compat import (
-    convert_tooldef_to_openai_tool,
-    get_sampling_options,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url
 
 from .config import SambaNovaImplConfig
 from .models import MODEL_ENTRIES
 
-logger = get_logger(name=__name__, category="inference")
-
-
-async def convert_message_to_openai_dict_with_b64_images(
-    message: Message | dict,
-) -> OpenAIChatCompletionMessage:
-    """
-    Convert a Message to an OpenAI API-compatible dictionary.
-    """
-    # users can supply a dict instead of a Message object, we'll
-    # convert it to a Message object and proceed with some type safety.
-    if isinstance(message, dict):
-        if "role" not in message:
-            raise ValueError("role is required in message")
-        if message["role"] == "user":
-            message = UserMessage(**message)
-        elif message["role"] == "assistant":
-            message = CompletionMessage(**message)
-        elif message["role"] == "tool":
-            message = ToolResponseMessage(**message)
-        elif message["role"] == "system":
-            message = SystemMessage(**message)
-        else:
-            raise ValueError(f"Unsupported message role: {message['role']}")
-
-    # Map Llama Stack spec to OpenAI spec -
-    #  str -> str
-    #  {"type": "text", "text": ...} -> {"type": "text", "text": ...}
-    #  {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
-    #  {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
-    #  List[...] -> List[...]
-    async def _convert_message_content(
-        content: InterleavedContent,
-    ) -> str | Iterable[OpenAIChatCompletionContentPartParam]:
-        async def impl(
-            content_: InterleavedContent,
-        ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]:
-            # Llama Stack and OpenAI spec match for str and text input
-            if isinstance(content_, str):
-                return content_
-            elif isinstance(content_, TextContentItem):
-                return OpenAIChatCompletionContentPartTextParam(
-                    type="text",
-                    text=content_.text,
-                )
-            elif isinstance(content_, ImageContentItem):
-                return OpenAIChatCompletionContentPartImageParam(
-                    type="image_url",
-                    image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_, download=True)),
-                )
-            elif isinstance(content_, list):
-                return [await impl(item) for item in content_]
-            else:
-                raise ValueError(f"Unsupported content type: {type(content_)}")
-
-        ret = await impl(content)
-
-        # OpenAI*Message expects a str or list
-        if isinstance(ret, str) or isinstance(ret, list):
-            return ret
-        else:
-            return [ret]
-
-    out: OpenAIChatCompletionMessage = None
-    if isinstance(message, UserMessage):
-        out = OpenAIChatCompletionUserMessage(
-            role="user",
-            content=await _convert_message_content(message.content),
-        )
-    elif isinstance(message, CompletionMessage):
-        out = OpenAIChatCompletionAssistantMessage(
-            role="assistant",
-            content=await _convert_message_content(message.content),
-            tool_calls=[
-                OpenAIChatCompletionMessageToolCall(
-                    id=tool.call_id,
-                    function=OpenAIFunction(
-                        name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
-                        arguments=json.dumps(tool.arguments),
-                    ),
-                    type="function",
-                )
-                for tool in message.tool_calls
-            ]
-            or None,
-        )
-    elif isinstance(message, ToolResponseMessage):
-        out = OpenAIChatCompletionToolMessage(
-            role="tool",
-            tool_call_id=message.call_id,
-            content=await _convert_message_content(message.content),
-        )
-    elif isinstance(message, SystemMessage):
-        out = OpenAIChatCompletionSystemMessage(
-            role="system",
-            content=await _convert_message_content(message.content),
-        )
-    else:
-        raise ValueError(f"Unsupported message type: {type(message)}")
-
-    return out
-
 
 class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
-    _config: SambaNovaImplConfig
-
     def __init__(self, config: SambaNovaImplConfig):
         self.config = config
         self.environment_available_models = []
@@ -185,89 +20,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
             litellm_provider_name="sambanova",
             api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
             provider_data_api_key_field="sambanova_api_key",
+            openai_compat_api_base=self.config.url,
+            download_images=True,  # SambaNova requires base64 image encoding
+            json_schema_strict=False,  # SambaNova doesn't support strict=True yet
         )
-
-    def _get_api_key(self) -> str:
-        config_api_key = self.config.api_key if self.config.api_key else None
-        if config_api_key:
-            return config_api_key.get_secret_value()
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.sambanova_api_key:
-                raise ValueError(
-                    'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
-                )
-            return provider_data.sambanova_api_key
-
-    async def _get_params(self, request: ChatCompletionRequest) -> dict:
-        input_dict = {}
-
-        input_dict["messages"] = [await convert_message_to_openai_dict_with_b64_images(m) for m in request.messages]
-        if fmt := request.response_format:
-            if not isinstance(fmt, JsonSchemaResponseFormat):
-                raise ValueError(
-                    f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
-                )
-
-            fmt = fmt.json_schema
-            name = fmt["title"]
-            del fmt["title"]
-            fmt["additionalProperties"] = False
-
-            # Apply additionalProperties: False recursively to all objects
-            fmt = self._add_additional_properties_recursive(fmt)
-
-            input_dict["response_format"] = {
-                "type": "json_schema",
-                "json_schema": {
-                    "name": name,
-                    "schema": fmt,
-                    "strict": False,
-                },
-            }
-        if request.tools:
-            input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
-            if request.tool_config.tool_choice:
-                input_dict["tool_choice"] = (
-                    request.tool_config.tool_choice.value
-                    if isinstance(request.tool_config.tool_choice, ToolChoice)
-                    else request.tool_config.tool_choice
-                )
-
-        provider_data = self.get_request_provider_data()
-        key_field = self.provider_data_api_key_field
-        if provider_data and getattr(provider_data, key_field, None):
-            api_key = getattr(provider_data, key_field)
-        else:
-            api_key = self._get_api_key()
-
-        return {
-            "model": request.model,
-            "api_key": api_key,
-            "api_base": self.config.url,
-            **input_dict,
-            "stream": request.stream,
-            **get_sampling_options(request.sampling_params),
-        }
-
-    async def register_model(self, model: Model) -> Model:
-        model_id = self.get_provider_model_id(model.provider_resource_id)
-
-        list_models_url = self.config.url + "/models"
-        if len(self.environment_available_models) == 0:
-            try:
-                response = requests.get(list_models_url)
-                response.raise_for_status()
-            except requests.exceptions.RequestException as e:
-                raise RuntimeError(f"Request to {list_models_url} failed") from e
-            self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
-
-        if model_id.split("sambanova/")[-1] not in self.environment_available_models:
-            logger.warning(f"Model {model_id} not available in {list_models_url}")
-        return model
-
-    async def initialize(self):
-        await super().initialize()
-
-    async def shutdown(self):
-        await super().shutdown()
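Net effect on `sambanova.py`: the adapter reduces to a thin wrapper around `LiteLLMOpenAIMixin`. A minimal sketch of the resulting module, reconstructed from the added lines and surviving context above (the exact opener of the mixin `__init__` call sits outside the hunk context, so the `LiteLLMOpenAIMixin.__init__(self, model_entries=MODEL_ENTRIES, ...)` form shown here is an assumption):

```python
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from .config import SambaNovaImplConfig
from .models import MODEL_ENTRIES


class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
    def __init__(self, config: SambaNovaImplConfig):
        self.config = config
        self.environment_available_models = []
        # Assumed opener; the diff context only shows the keyword arguments.
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            litellm_provider_name="sambanova",
            api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
            provider_data_api_key_field="sambanova_api_key",
            openai_compat_api_base=self.config.url,
            download_images=True,  # SambaNova requires base64 image encoding
            json_schema_strict=False,  # SambaNova doesn't support strict=True yet
        )
```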
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index e9a41fcf3..befb4b092 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -72,6 +72,8 @@ class LiteLLMOpenAIMixin(
         api_key_from_config: str | None,
         provider_data_api_key_field: str,
         openai_compat_api_base: str | None = None,
+        download_images: bool = False,
+        json_schema_strict: bool = True,
     ):
         """
         Initialize the LiteLLMOpenAIMixin.
@@ -81,6 +83,8 @@ class LiteLLMOpenAIMixin(
         :param provider_data_api_key_field: The field in the provider data that contains the API key.
         :param litellm_provider_name: The name of the provider, used for model lookups.
         :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
+        :param download_images: Whether to download images and convert to base64 for message conversion.
+        :param json_schema_strict: Whether to use strict mode for JSON schema validation.
         """
         ModelRegistryHelper.__init__(self, model_entries)
@@ -88,6 +92,8 @@ class LiteLLMOpenAIMixin(
         self.api_key_from_config = api_key_from_config
         self.provider_data_api_key_field = provider_data_api_key_field
         self.api_base = openai_compat_api_base
+        self.download_images = download_images
+        self.json_schema_strict = json_schema_strict
 
         if openai_compat_api_base:
             self.is_openai_compat = True
@@ -206,7 +212,9 @@ class LiteLLMOpenAIMixin(
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         input_dict = {}
 
-        input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
+        input_dict["messages"] = [
+            await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
+        ]
         if fmt := request.response_format:
             if not isinstance(fmt, JsonSchemaResponseFormat):
                 raise ValueError(
@@ -226,7 +234,7 @@ class LiteLLMOpenAIMixin(
                 "json_schema": {
                     "name": name,
                     "schema": fmt,
-                    "strict": True,
+                    "strict": self.json_schema_strict,
                 },
             }
         if request.tools:
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 47144ee0e..e6e5ccc8a 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -564,6 +564,7 @@ class UnparseableToolCall(BaseModel):
 
 async def convert_message_to_openai_dict_new(
     message: Message | dict,
+    download_images: bool = False,
 ) -> OpenAIChatCompletionMessage:
     """
     Convert a Message to an OpenAI API-compatible dictionary.
@@ -607,7 +608,9 @@ async def convert_message_to_openai_dict_new(
         elif isinstance(content_, ImageContentItem):
             return OpenAIChatCompletionContentPartImageParam(
                 type="image_url",
-                image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)),
+                image_url=OpenAIImageURL(
+                    url=await convert_image_content_to_url(content_, download=download_images)
+                ),
             )
         elif isinstance(content_, list):
             return [await impl(item) for item in content_]
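Usage note (not part of the diff): a minimal sketch of what the new `download_images` flag changes for callers of `convert_message_to_openai_dict_new`. The nested-dict message shape and the example URL are assumptions, mirroring the mapping comment in the removed `sambanova.py` helper (`{"type": "image", "image": {"url": {"uri": ...}}}`) and relying on pydantic coercing it to an `ImageContentItem`:

```python
import asyncio

from llama_stack.apis.inference import UserMessage
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict_new,
)


async def main() -> None:
    # Hypothetical message carrying a remote image; the URL is a placeholder.
    message = UserMessage(
        content=[{"type": "image", "image": {"url": {"uri": "https://example.com/cat.png"}}}],
    )

    # Default behavior: the image URL is passed through to the provider as-is.
    passthrough = await convert_message_to_openai_dict_new(message)

    # With download_images=True the image is fetched and inlined as a base64
    # data: URL, which is what SambaNova expects.
    inlined = await convert_message_to_openai_dict_new(message, download_images=True)
    print(passthrough, inlined, sep="\n")


asyncio.run(main())
```

With the default `download_images=False` nothing changes for existing providers; adapters whose backends cannot fetch remote images, such as SambaNova, opt in through the new mixin flag so images arrive inlined as base64 `data:` URLs.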