From b2b00a216b957309501e26ce57a4e4837f9f5345 Mon Sep 17 00:00:00 2001
From: Jorge Piedrahita Ortiz
Date: Tue, 6 May 2025 18:50:22 -0500
Subject: [PATCH] feat(providers): sambanova updated to use LiteLLM openai-compat (#1596)

# What does this PR do?
Switch the SambaNova inference adapter to LiteLLM to simplify the integration and fix the streaming and tool-calling issues in the current adapter. Models and templates are updated accordingly. An illustrative client-side usage sketch follows the diff.

## Test Plan
pytest -s -v tests/integration/inference/test_text_inference.py --stack-config=sambanova --text-model=sambanova/Meta-Llama-3.3-70B-Instruct

pytest -s -v tests/integration/inference/test_vision_inference.py --stack-config=sambanova --vision-model=sambanova/Llama-3.2-11B-Vision-Instruct
---
 .../self_hosted_distro/sambanova.md           |  26 +-
 llama_stack/providers/registry/inference.py   |   5 +-
 .../remote/inference/sambanova/__init__.py    |   8 +-
 .../remote/inference/sambanova/config.py      |  17 +-
 .../remote/inference/sambanova/models.py      |  32 +-
 .../remote/inference/sambanova/sambanova.py   | 466 ++++++++----------
 llama_stack/templates/dependencies.json       |   7 +-
 llama_stack/templates/dev/build.yaml          |   1 +
 llama_stack/templates/dev/dev.py              |   9 +
 llama_stack/templates/dev/run.yaml            | 105 ++++
 llama_stack/templates/sambanova/build.yaml    |   5 +-
 llama_stack/templates/sambanova/run.yaml      | 105 ++--
 llama_stack/templates/sambanova/sambanova.py  |  44 +-
 llama_stack/templates/verification/run.yaml   |  88 ++--
 .../inference/test_text_inference.py          |  15 +
 15 files changed, 529 insertions(+), 404 deletions(-)

diff --git a/docs/source/distributions/self_hosted_distro/sambanova.md b/docs/source/distributions/self_hosted_distro/sambanova.md
index 873c3075c..aaa8fd3cc 100644
--- a/docs/source/distributions/self_hosted_distro/sambanova.md
+++ b/docs/source/distributions/self_hosted_distro/sambanova.md
@@ -16,10 +16,10 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
 | API | Provider(s) |
 |-----|-------------|
 | agents | `inline::meta-reference` |
-| inference | `remote::sambanova` |
+| inference | `remote::sambanova`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
 
@@ -28,22 +28,22 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
 
 The following environment variables can be configured:
 
 - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
-- `SAMBANOVA_API_KEY`: SambaNova.AI API Key (default: ``)
+- `SAMBANOVA_API_KEY`: SambaNova API Key (default: ``)
 
 ### Models
 
 The following models are available by default:
 
-- `Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
-- `Meta-Llama-3.1-70B-Instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
-- `Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
-- `Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
-- `Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
-- `Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
-- `Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
-- `Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
-- 
`Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)` -- `Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)` +- `sambanova/Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)` +- `sambanova/Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)` +- `sambanova/Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)` +- `sambanova/Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)` +- `sambanova/Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)` +- `sambanova/Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)` +- `sambanova/Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)` +- `sambanova/Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)` +- `sambanova/Llama-4-Maverick-17B-128E-Instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)` +- `sambanova/Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)` ### Prerequisite: API Keys diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index b0abc1818..7b49ef09b 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -280,11 +280,10 @@ def available_providers() -> list[ProviderSpec]: api=Api.inference, adapter=AdapterSpec( adapter_type="sambanova", - pip_packages=[ - "openai", - ], + pip_packages=["litellm"], module="llama_stack.providers.remote.inference.sambanova", config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", ), ), remote_provider_spec( diff --git a/llama_stack/providers/remote/inference/sambanova/__init__.py b/llama_stack/providers/remote/inference/sambanova/__init__.py index 3e682e69c..a3a7b8fbd 100644 --- a/llama_stack/providers/remote/inference/sambanova/__init__.py +++ b/llama_stack/providers/remote/inference/sambanova/__init__.py @@ -4,16 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from pydantic import BaseModel
+from llama_stack.apis.inference import Inference
 
 from .config import SambaNovaImplConfig
 
 
-class SambaNovaProviderDataValidator(BaseModel):
-    sambanova_api_key: str
-
-
-async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
+async def get_adapter_impl(config: SambaNovaImplConfig, _deps) -> Inference:
     from .sambanova import SambaNovaInferenceAdapter
 
     assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
diff --git a/llama_stack/providers/remote/inference/sambanova/config.py b/llama_stack/providers/remote/inference/sambanova/config.py
index 8ca11de78..abbf9430f 100644
--- a/llama_stack/providers/remote/inference/sambanova/config.py
+++ b/llama_stack/providers/remote/inference/sambanova/config.py
@@ -6,25 +6,32 @@
 
 from typing import Any
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 from llama_stack.schema_utils import json_schema_type
 
 
+class SambaNovaProviderDataValidator(BaseModel):
+    sambanova_api_key: str | None = Field(
+        default=None,
+        description="SambaNova Cloud API key",
+    )
+
+
 @json_schema_type
 class SambaNovaImplConfig(BaseModel):
     url: str = Field(
         default="https://api.sambanova.ai/v1",
         description="The URL for the SambaNova AI server",
     )
-    api_key: str | None = Field(
+    api_key: SecretStr | None = Field(
         default=None,
-        description="The SambaNova.ai API Key",
+        description="The SambaNova Cloud API Key",
     )
 
     @classmethod
-    def sample_run_config(cls, **kwargs) -> dict[str, Any]:
+    def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
         return {
             "url": "https://api.sambanova.ai/v1",
-            "api_key": "${env.SAMBANOVA_API_KEY}",
+            "api_key": api_key,
         }
diff --git a/llama_stack/providers/remote/inference/sambanova/models.py b/llama_stack/providers/remote/inference/sambanova/models.py
index 43041e94a..9954fa7a0 100644
--- a/llama_stack/providers/remote/inference/sambanova/models.py
+++ b/llama_stack/providers/remote/inference/sambanova/models.py
@@ -11,43 +11,43 @@ from llama_stack.providers.utils.inference.model_registry import (
 
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
-        "Meta-Llama-3.1-8B-Instruct",
+        "sambanova/Meta-Llama-3.1-8B-Instruct",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
     build_hf_repo_model_entry(
-        "Meta-Llama-3.1-70B-Instruct",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "Meta-Llama-3.1-405B-Instruct",
+        "sambanova/Meta-Llama-3.1-405B-Instruct",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
     build_hf_repo_model_entry(
-        "Meta-Llama-3.2-1B-Instruct",
+        "sambanova/Meta-Llama-3.2-1B-Instruct",
         CoreModelId.llama3_2_1b_instruct.value,
     ),
     build_hf_repo_model_entry(
-        "Meta-Llama-3.2-3B-Instruct",
+        "sambanova/Meta-Llama-3.2-3B-Instruct",
         CoreModelId.llama3_2_3b_instruct.value,
     ),
     build_hf_repo_model_entry(
-        "Meta-Llama-3.3-70B-Instruct",
+        "sambanova/Meta-Llama-3.3-70B-Instruct",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
     build_hf_repo_model_entry(
-        "Llama-3.2-11B-Vision-Instruct",
+        "sambanova/Llama-3.2-11B-Vision-Instruct",
         CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
     build_hf_repo_model_entry(
-        "Llama-3.2-90B-Vision-Instruct",
+        "sambanova/Llama-3.2-90B-Vision-Instruct",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
     build_hf_repo_model_entry(
-        "Meta-Llama-Guard-3-8B",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "Llama-4-Scout-17B-16E-Instruct",
+        "sambanova/Llama-4-Scout-17B-16E-Instruct",
         CoreModelId.llama4_scout_17b_16e_instruct.value,
), + build_hf_repo_model_entry( + "sambanova/Llama-4-Maverick-17B-128E-Instruct", + CoreModelId.llama4_maverick_17b_128e_instruct.value, + ), + build_hf_repo_model_entry( + "sambanova/Meta-Llama-Guard-3-8B", + CoreModelId.llama_guard_3_8b.value, + ), ] diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 3db95dcb4..d182aa1dc 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -5,305 +5,249 @@ # the root directory of this source tree. import json -from collections.abc import AsyncGenerator +from collections.abc import Iterable -from openai import OpenAI +from openai.types.chat import ( + ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, +) +from openai.types.chat import ( + ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, +) +from openai.types.chat import ( + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, +) +from openai.types.chat import ( + ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, +) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessage, +) +from openai.types.chat import ( + ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, +) +from openai.types.chat import ( + ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, +) +from openai.types.chat import ( + ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, +) +from openai.types.chat import ( + ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, +) +from openai.types.chat.chat_completion_content_part_image_param import ( + ImageURL as OpenAIImageURL, +) +from openai.types.chat.chat_completion_message_tool_call_param import ( + Function as OpenAIFunction, +) from llama_stack.apis.common.content_types import ( ImageContentItem, InterleavedContent, - InterleavedContentItem, TextContentItem, ) from llama_stack.apis.inference import ( ChatCompletionRequest, - ChatCompletionResponse, CompletionMessage, - EmbeddingsResponse, - EmbeddingTaskType, - GreedySamplingStrategy, - Inference, - LogProbConfig, + JsonSchemaResponseFormat, Message, - ResponseFormat, - SamplingParams, - StopReason, SystemMessage, - TextTruncation, - ToolCall, ToolChoice, - ToolConfig, - ToolDefinition, - ToolPromptFormat, ToolResponseMessage, - TopKSamplingStrategy, - TopPSamplingStrategy, UserMessage, ) -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.log import get_logger +from llama_stack.models.llama.datatypes import BuiltinTool +from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionToLlamaStackMixin, - OpenAICompletionToLlamaStackMixin, - process_chat_completion_stream_response, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - convert_image_content_to_url, + convert_tooldef_to_openai_tool, + get_sampling_options, ) +from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url from .config import SambaNovaImplConfig from .models import MODEL_ENTRIES +logger = get_logger(name=__name__, category="inference") -class SambaNovaInferenceAdapter( - ModelRegistryHelper, - Inference, - OpenAIChatCompletionToLlamaStackMixin, - 
OpenAICompletionToLlamaStackMixin, -): - def __init__(self, config: SambaNovaImplConfig) -> None: - ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) - self.config = config - async def initialize(self) -> None: - return +async def convert_message_to_openai_dict_with_b64_images( + message: Message | dict, +) -> OpenAIChatCompletionMessage: + """ + Convert a Message to an OpenAI API-compatible dictionary. + """ + # users can supply a dict instead of a Message object, we'll + # convert it to a Message object and proceed with some type safety. + if isinstance(message, dict): + if "role" not in message: + raise ValueError("role is required in message") + if message["role"] == "user": + message = UserMessage(**message) + elif message["role"] == "assistant": + message = CompletionMessage(**message) + elif message["role"] == "tool": + message = ToolResponseMessage(**message) + elif message["role"] == "system": + message = SystemMessage(**message) + else: + raise ValueError(f"Unsupported message role: {message['role']}") - async def shutdown(self) -> None: - pass - - def _get_client(self) -> OpenAI: - return OpenAI(base_url=self.config.url, api_key=self.config.api_key) - - async def completion( - self, - model_id: str, + # Map Llama Stack spec to OpenAI spec - + # str -> str + # {"type": "text", "text": ...} -> {"type": "text", "text": ...} + # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}} + # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}} + # List[...] -> List[...] + async def _convert_message_content( content: InterleavedContent, - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - ) -> AsyncGenerator: - raise NotImplementedError() - - async def chat_completion( - self, - model_id: str, - messages: list[Message], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = ToolChoice.auto, - tool_prompt_format: ToolPromptFormat | None = ToolPromptFormat.json, - stream: bool | None = False, - tool_config: ToolConfig | None = None, - logprobs: LogProbConfig | None = None, - ) -> AsyncGenerator: - if sampling_params is None: - sampling_params = SamplingParams() - model = await self.model_store.get_model(model_id) - - request = ChatCompletionRequest( - model=model.provider_resource_id, - messages=messages, - sampling_params=sampling_params, - tools=tools or [], - stream=stream, - logprobs=logprobs, - tool_config=tool_config, - ) - request_sambanova = await self.convert_chat_completion_request(request) - - if stream: - return self._stream_chat_completion(request_sambanova) - else: - return await self._nonstream_chat_completion(request_sambanova) - - async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - response = self._get_client().chat.completions.create(**request) - - choice = response.choices[0] - - result = ChatCompletionResponse( - completion_message=CompletionMessage( - content=choice.message.content or "", - stop_reason=self.convert_to_sambanova_finish_reason(choice.finish_reason), - tool_calls=self.convert_to_sambanova_tool_calls(choice.message.tool_calls), - ), - logprobs=None, - ) - - return result - - async def _stream_chat_completion(self, request: ChatCompletionRequest) 
-> AsyncGenerator: - async def _to_async_generator(): - streaming = self._get_client().chat.completions.create(**request) - for chunk in streaming: - yield chunk - - stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, request): - yield chunk - - async def embeddings( - self, - model_id: str, - contents: list[str] | list[InterleavedContentItem], - text_truncation: TextTruncation | None = TextTruncation.none, - output_dimension: int | None = None, - task_type: EmbeddingTaskType | None = None, - ) -> EmbeddingsResponse: - raise NotImplementedError() - - async def convert_chat_completion_request(self, request: ChatCompletionRequest) -> dict: - compatible_request = self.convert_sampling_params(request.sampling_params) - compatible_request["model"] = request.model - compatible_request["messages"] = await self.convert_to_sambanova_messages(request.messages) - compatible_request["stream"] = request.stream - compatible_request["logprobs"] = False - compatible_request["extra_headers"] = { - b"User-Agent": b"llama-stack: sambanova-inference-adapter", - } - compatible_request["tools"] = self.convert_to_sambanova_tool(request.tools) - return compatible_request - - def convert_sampling_params(self, sampling_params: SamplingParams, legacy: bool = False) -> dict: - params = {} - - if sampling_params: - params["frequency_penalty"] = sampling_params.repetition_penalty - - if sampling_params.max_tokens: - if legacy: - params["max_tokens"] = sampling_params.max_tokens - else: - params["max_completion_tokens"] = sampling_params.max_tokens - - if isinstance(sampling_params.strategy, TopPSamplingStrategy): - params["top_p"] = sampling_params.strategy.top_p - if isinstance(sampling_params.strategy, TopKSamplingStrategy): - params["extra_body"]["top_k"] = sampling_params.strategy.top_k - if isinstance(sampling_params.strategy, GreedySamplingStrategy): - params["temperature"] = 0.0 - - return params - - async def convert_to_sambanova_messages(self, messages: list[Message]) -> list[dict]: - conversation = [] - for message in messages: - content = {} - - content["content"] = await self.convert_to_sambanova_content(message) - - if isinstance(message, UserMessage): - content["role"] = "user" - elif isinstance(message, CompletionMessage): - content["role"] = "assistant" - tools = [] - for tool_call in message.tool_calls: - tools.append( - { - "id": tool_call.call_id, - "function": { - "name": tool_call.name, - "arguments": json.dumps(tool_call.arguments), - }, - "type": "function", - } - ) - content["tool_calls"] = tools - elif isinstance(message, ToolResponseMessage): - content["role"] = "tool" - content["tool_call_id"] = message.call_id - elif isinstance(message, SystemMessage): - content["role"] = "system" - - conversation.append(content) - - return conversation - - async def convert_to_sambanova_content(self, message: Message) -> dict: - async def _convert_content(content) -> dict: - if isinstance(content, ImageContentItem): - url = await convert_image_content_to_url(content, download=True) - # A fix to make sure the call sucess. 
- components = url.split(";base64") - url = f"{components[0].lower()};base64{components[1]}" - return { - "type": "image_url", - "image_url": {"url": url}, - } + ) -> str | Iterable[OpenAIChatCompletionContentPartParam]: + async def impl( + content_: InterleavedContent, + ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]: + # Llama Stack and OpenAI spec match for str and text input + if isinstance(content_, str): + return content_ + elif isinstance(content_, TextContentItem): + return OpenAIChatCompletionContentPartTextParam( + type="text", + text=content_.text, + ) + elif isinstance(content_, ImageContentItem): + return OpenAIChatCompletionContentPartImageParam( + type="image_url", + image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_, download=True)), + ) + elif isinstance(content_, list): + return [await impl(item) for item in content_] else: - text = content.text if isinstance(content, TextContentItem) else content - assert isinstance(text, str) - return {"type": "text", "text": text} + raise ValueError(f"Unsupported content type: {type(content_)}") - if isinstance(message.content, list): - # If it is a list, the text content should be wrapped in dict - content = [await _convert_content(c) for c in message.content] + ret = await impl(content) + + # OpenAI*Message expects a str or list + if isinstance(ret, str) or isinstance(ret, list): + return ret else: - content = message.content + return [ret] - return content + out: OpenAIChatCompletionMessage = None + if isinstance(message, UserMessage): + out = OpenAIChatCompletionUserMessage( + role="user", + content=await _convert_message_content(message.content), + ) + elif isinstance(message, CompletionMessage): + out = OpenAIChatCompletionAssistantMessage( + role="assistant", + content=await _convert_message_content(message.content), + tool_calls=[ + OpenAIChatCompletionMessageToolCall( + id=tool.call_id, + function=OpenAIFunction( + name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value, + arguments=json.dumps(tool.arguments), + ), + type="function", + ) + for tool in message.tool_calls + ] + or None, + ) + elif isinstance(message, ToolResponseMessage): + out = OpenAIChatCompletionToolMessage( + role="tool", + tool_call_id=message.call_id, + content=await _convert_message_content(message.content), + ) + elif isinstance(message, SystemMessage): + out = OpenAIChatCompletionSystemMessage( + role="system", + content=await _convert_message_content(message.content), + ) + else: + raise ValueError(f"Unsupported message type: {type(message)}") - def convert_to_sambanova_tool(self, tools: list[ToolDefinition]) -> list[dict]: - if tools is None: - return tools + return out - compatiable_tools = [] - for tool in tools: - properties = {} - compatiable_required = [] - if tool.parameters: - for tool_key, tool_param in tool.parameters.items(): - properties[tool_key] = {"type": tool_param.param_type} - if tool_param.description: - properties[tool_key]["description"] = tool_param.description - if tool_param.default: - properties[tool_key]["default"] = tool_param.default - if tool_param.required: - compatiable_required.append(tool_key) +class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): + _config: SambaNovaImplConfig - compatiable_tool = { - "type": "function", - "function": { - "name": tool.tool_name, - "description": tool.description, - "parameters": { - "type": "object", - "properties": properties, - "required": compatiable_required, - }, + def 
__init__(self, config: SambaNovaImplConfig): + self.config = config + LiteLLMOpenAIMixin.__init__( + self, + model_entries=MODEL_ENTRIES, + api_key_from_config=self.config.api_key, + provider_data_api_key_field="sambanova_api_key", + ) + + def _get_api_key(self) -> str: + config_api_key = self.config.api_key if self.config.api_key else None + if config_api_key: + return config_api_key.get_secret_value() + else: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.sambanova_api_key: + raise ValueError( + 'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": }' + ) + return provider_data.sambanova_api_key + + async def _get_params(self, request: ChatCompletionRequest) -> dict: + input_dict = {} + + input_dict["messages"] = [await convert_message_to_openai_dict_with_b64_images(m) for m in request.messages] + if fmt := request.response_format: + if not isinstance(fmt, JsonSchemaResponseFormat): + raise ValueError( + f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported." + ) + + fmt = fmt.json_schema + name = fmt["title"] + del fmt["title"] + fmt["additionalProperties"] = False + + # Apply additionalProperties: False recursively to all objects + fmt = self._add_additional_properties_recursive(fmt) + + input_dict["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": name, + "schema": fmt, + "strict": True, }, } + if request.tools: + input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools] + if request.tool_config.tool_choice: + input_dict["tool_choice"] = ( + request.tool_config.tool_choice.value + if isinstance(request.tool_config.tool_choice, ToolChoice) + else request.tool_config.tool_choice + ) - compatiable_tools.append(compatiable_tool) + provider_data = self.get_request_provider_data() + key_field = self.provider_data_api_key_field + if provider_data and getattr(provider_data, key_field, None): + api_key = getattr(provider_data, key_field) + else: + api_key = self._get_api_key() - if len(compatiable_tools) > 0: - return compatiable_tools - return None - - def convert_to_sambanova_finish_reason(self, finish_reason: str) -> StopReason: return { - "stop": StopReason.end_of_turn, - "length": StopReason.out_of_tokens, - "tool_calls": StopReason.end_of_message, - }.get(finish_reason, StopReason.end_of_turn) + "model": request.model, + "api_key": api_key, + "api_base": self.config.url, + **input_dict, + "stream": request.stream, + **get_sampling_options(request.sampling_params), + } - def convert_to_sambanova_tool_calls( - self, - tool_calls, - ) -> list[ToolCall]: - if not tool_calls: - return [] + async def initialize(self): + await super().initialize() - compitable_tool_calls = [ - ToolCall( - call_id=call.id, - tool_name=call.function.name, - arguments=json.loads(call.function.arguments), - arguments_json=call.function.arguments, - ) - for call in tool_calls - ] - - return compitable_tool_calls + async def shutdown(self): + await super().shutdown() diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json index e8136b0c3..31f2b93f1 100644 --- a/llama_stack/templates/dependencies.json +++ b/llama_stack/templates/dependencies.json @@ -619,10 +619,11 @@ "fastapi", "fire", "httpx", + "litellm", "matplotlib", + "mcp", "nltk", "numpy", - "openai", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-sdk", "pandas", @@ -637,7 +638,9 @@ "sentencepiece", "tqdm", "transformers", - "uvicorn" + 
"uvicorn", + "sentence-transformers --no-deps", + "torch torchvision --index-url https://download.pytorch.org/whl/cpu" ], "tgi": [ "aiohttp", diff --git a/llama_stack/templates/dev/build.yaml b/llama_stack/templates/dev/build.yaml index df45f1319..afa1614bf 100644 --- a/llama_stack/templates/dev/build.yaml +++ b/llama_stack/templates/dev/build.yaml @@ -8,6 +8,7 @@ distribution_spec: - remote::anthropic - remote::gemini - remote::groq + - remote::sambanova - inline::sentence-transformers vector_io: - inline::sqlite-vec diff --git a/llama_stack/templates/dev/dev.py b/llama_stack/templates/dev/dev.py index 4cf8e7d22..76d5a1fb3 100644 --- a/llama_stack/templates/dev/dev.py +++ b/llama_stack/templates/dev/dev.py @@ -38,6 +38,10 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.openai.models import ( MODEL_ENTRIES as OPENAI_MODEL_ENTRIES, ) +from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig +from llama_stack.providers.remote.inference.sambanova.models import ( + MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES, +) from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig from llama_stack.providers.remote.vector_io.pgvector.config import ( PGVectorVectorIOConfig, @@ -77,6 +81,11 @@ def get_inference_providers() -> tuple[list[Provider], list[ModelInput]]: GROQ_MODEL_ENTRIES, GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"), ), + ( + "sambanova", + SAMBANOVA_MODEL_ENTRIES, + SambaNovaImplConfig.sample_run_config(api_key="${env.SAMBANOVA_API_KEY:}"), + ), ] inference_providers = [] available_models = {} diff --git a/llama_stack/templates/dev/run.yaml b/llama_stack/templates/dev/run.yaml index b77650bfe..b98498d53 100644 --- a/llama_stack/templates/dev/run.yaml +++ b/llama_stack/templates/dev/run.yaml @@ -34,6 +34,11 @@ providers: config: url: https://api.groq.com api_key: ${env.GROQ_API_KEY:} + - provider_id: sambanova + provider_type: remote::sambanova + config: + url: https://api.sambanova.ai/v1 + api_key: ${env.SAMBANOVA_API_KEY:} - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} @@ -413,6 +418,106 @@ models: provider_id: groq provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct model_type: llm +- metadata: {} + model_id: sambanova/Meta-Llama-3.1-8B-Instruct + provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-3.1-8B-Instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-3.1-8B-Instruct + model_type: llm +- metadata: {} + model_id: sambanova/Meta-Llama-3.1-405B-Instruct + provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-3.1-405B-Instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 + provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-3.1-405B-Instruct + model_type: llm +- metadata: {} + model_id: sambanova/Meta-Llama-3.2-1B-Instruct + provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-3.2-1B-Instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-1B-Instruct + provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-3.2-1B-Instruct + model_type: llm +- metadata: {} + model_id: sambanova/Meta-Llama-3.2-3B-Instruct + provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-3.2-3B-Instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-3B-Instruct + 
provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-3.2-3B-Instruct + model_type: llm +- metadata: {} + model_id: sambanova/Meta-Llama-3.3-70B-Instruct + provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-3.3-70B-Instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.3-70B-Instruct + provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-3.3-70B-Instruct + model_type: llm +- metadata: {} + model_id: sambanova/Llama-3.2-11B-Vision-Instruct + provider_id: sambanova + provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: sambanova + provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct + model_type: llm +- metadata: {} + model_id: sambanova/Llama-3.2-90B-Vision-Instruct + provider_id: sambanova + provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: sambanova + provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct + model_type: llm +- metadata: {} + model_id: sambanova/Llama-4-Scout-17B-16E-Instruct + provider_id: sambanova + provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct + provider_id: sambanova + provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct + model_type: llm +- metadata: {} + model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct + provider_id: sambanova + provider_model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct + provider_id: sambanova + provider_model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct + model_type: llm +- metadata: {} + model_id: sambanova/Meta-Llama-Guard-3-8B + provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-Guard-3-8B + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-Guard-3-8B + provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-Guard-3-8B + model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 diff --git a/llama_stack/templates/sambanova/build.yaml b/llama_stack/templates/sambanova/build.yaml index 6fd5b2905..81d90f420 100644 --- a/llama_stack/templates/sambanova/build.yaml +++ b/llama_stack/templates/sambanova/build.yaml @@ -1,9 +1,10 @@ version: '2' distribution_spec: - description: Use SambaNova.AI for running LLM inference + description: Use SambaNova for running LLM inference providers: inference: - remote::sambanova + - inline::sentence-transformers vector_io: - inline::faiss - remote::chromadb @@ -18,4 +19,6 @@ distribution_spec: - remote::brave-search - remote::tavily-search - inline::rag-runtime + - remote::model-context-protocol + - remote::wolfram-alpha image_type: conda diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index 2bf2bf722..620d50307 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -14,6 +14,9 @@ providers: config: url: https://api.sambanova.ai/v1 api_key: ${env.SAMBANOVA_API_KEY} + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} vector_io: - provider_id: faiss provider_type: inline::faiss @@ -68,110 +71,122 @@ providers: - provider_id: rag-runtime provider_type: inline::rag-runtime config: {} + - provider_id: model-context-protocol + 
provider_type: remote::model-context-protocol + config: {} + - provider_id: wolfram-alpha + provider_type: remote::wolfram-alpha + config: + api_key: ${env.WOLFRAM_ALPHA_API_KEY:} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db models: - metadata: {} - model_id: Meta-Llama-3.1-8B-Instruct + model_id: sambanova/Meta-Llama-3.1-8B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.1-8B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-8B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.1-8B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-8B-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.1-70B-Instruct + model_id: sambanova/Meta-Llama-3.1-405B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.1-70B-Instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-70B-Instruct - provider_id: sambanova - provider_model_id: Meta-Llama-3.1-70B-Instruct - model_type: llm -- metadata: {} - model_id: Meta-Llama-3.1-405B-Instruct - provider_id: sambanova - provider_model_id: Meta-Llama-3.1-405B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-405B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 provider_id: sambanova - provider_model_id: Meta-Llama-3.1-405B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-405B-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.2-1B-Instruct + model_id: sambanova/Meta-Llama-3.2-1B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.2-1B-Instruct + provider_model_id: sambanova/Meta-Llama-3.2-1B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-1B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.2-1B-Instruct + provider_model_id: sambanova/Meta-Llama-3.2-1B-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.2-3B-Instruct + model_id: sambanova/Meta-Llama-3.2-3B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.2-3B-Instruct + provider_model_id: sambanova/Meta-Llama-3.2-3B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-3B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.2-3B-Instruct + provider_model_id: sambanova/Meta-Llama-3.2-3B-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.3-70B-Instruct + model_id: sambanova/Meta-Llama-3.3-70B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.3-70B-Instruct + provider_model_id: sambanova/Meta-Llama-3.3-70B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.3-70B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.3-70B-Instruct + provider_model_id: sambanova/Meta-Llama-3.3-70B-Instruct model_type: llm - metadata: {} - model_id: Llama-3.2-11B-Vision-Instruct + model_id: sambanova/Llama-3.2-11B-Vision-Instruct provider_id: sambanova - provider_model_id: Llama-3.2-11B-Vision-Instruct + provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-11B-Vision-Instruct provider_id: sambanova - provider_model_id: Llama-3.2-11B-Vision-Instruct + provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct model_type: llm - metadata: {} - model_id: Llama-3.2-90B-Vision-Instruct + model_id: sambanova/Llama-3.2-90B-Vision-Instruct provider_id: sambanova - provider_model_id: Llama-3.2-90B-Vision-Instruct + 
provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-90B-Vision-Instruct provider_id: sambanova - provider_model_id: Llama-3.2-90B-Vision-Instruct + provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-Guard-3-8B + model_id: sambanova/Llama-4-Scout-17B-16E-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-Guard-3-8B - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-Guard-3-8B - provider_id: sambanova - provider_model_id: Meta-Llama-Guard-3-8B - model_type: llm -- metadata: {} - model_id: Llama-4-Scout-17B-16E-Instruct - provider_id: sambanova - provider_model_id: Llama-4-Scout-17B-16E-Instruct + provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct provider_id: sambanova - provider_model_id: Llama-4-Scout-17B-16E-Instruct + provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct model_type: llm +- metadata: {} + model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct + provider_id: sambanova + provider_model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct + provider_id: sambanova + provider_model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct + model_type: llm +- metadata: {} + model_id: sambanova/Meta-Llama-Guard-3-8B + provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-Guard-3-8B + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-Guard-3-8B + provider_id: sambanova + provider_model_id: sambanova/Meta-Llama-Guard-3-8B + model_type: llm +- metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: sentence-transformers + model_type: embedding shields: - shield_id: meta-llama/Llama-Guard-3-8B vector_dbs: [] @@ -183,5 +198,7 @@ tool_groups: provider_id: tavily-search - toolgroup_id: builtin::rag provider_id: rag-runtime +- toolgroup_id: builtin::wolfram_alpha + provider_id: wolfram-alpha server: port: 8321 diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py index f1862221b..2f8a0b08a 100644 --- a/llama_stack/templates/sambanova/sambanova.py +++ b/llama_stack/templates/sambanova/sambanova.py @@ -6,7 +6,16 @@ from pathlib import Path -from llama_stack.distribution.datatypes import Provider, ShieldInput, ToolGroupInput +from llama_stack.apis.models.models import ModelType +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) +from llama_stack.providers.inline.inference.sentence_transformers import ( + SentenceTransformersInferenceConfig, +) from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig from llama_stack.providers.remote.inference.sambanova.models import MODEL_ENTRIES @@ -23,7 +32,7 @@ from llama_stack.templates.template import ( def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::sambanova"], + "inference": ["remote::sambanova", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], @@ -32,16 +41,29 @@ def get_distribution_template() -> DistributionTemplate: "remote::brave-search", "remote::tavily-search", 
"inline::rag-runtime", + "remote::model-context-protocol", + "remote::wolfram-alpha", ], } name = "sambanova" - inference_provider = Provider( provider_id=name, provider_type=f"remote::{name}", config=SambaNovaImplConfig.sample_run_config(), ) - + embedding_provider = Provider( + provider_id="sentence-transformers", + provider_type="inline::sentence-transformers", + config=SentenceTransformersInferenceConfig.sample_run_config(), + ) + embedding_model = ModelInput( + model_id="all-MiniLM-L6-v2", + provider_id="sentence-transformers", + model_type=ModelType.embedding, + metadata={ + "embedding_dimension": 384, + }, + ) vector_io_providers = [ Provider( provider_id="faiss", @@ -79,23 +101,27 @@ def get_distribution_template() -> DistributionTemplate: toolgroup_id="builtin::rag", provider_id="rag-runtime", ), + ToolGroupInput( + toolgroup_id="builtin::wolfram_alpha", + provider_id="wolfram-alpha", + ), ] return DistributionTemplate( name=name, distro_type="self_hosted", - description="Use SambaNova.AI for running LLM inference", - docker_image=None, + description="Use SambaNova for running LLM inference", + container_image=None, template_path=Path(__file__).parent / "doc_template.md", providers=providers, available_models_by_provider=available_models, run_configs={ "run.yaml": RunConfigSettings( provider_overrides={ - "inference": [inference_provider], + "inference": [inference_provider, embedding_provider], "vector_io": vector_io_providers, }, - default_models=default_models, + default_models=default_models + [embedding_model], default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], default_tool_groups=default_tool_groups, ), @@ -107,7 +133,7 @@ def get_distribution_template() -> DistributionTemplate: ), "SAMBANOVA_API_KEY": ( "", - "SambaNova.AI API Key", + "SambaNova API Key", ), }, ) diff --git a/llama_stack/templates/verification/run.yaml b/llama_stack/templates/verification/run.yaml index fbe491453..a9bf74330 100644 --- a/llama_stack/templates/verification/run.yaml +++ b/llama_stack/templates/verification/run.yaml @@ -502,104 +502,104 @@ models: provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.1-8B-Instruct + model_id: sambanova/Meta-Llama-3.1-8B-Instruct provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-3.1-8B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-8B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-3.1-8B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-8B-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.1-70B-Instruct + model_id: sambanova/Meta-Llama-3.1-405B-Instruct provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-3.1-70B-Instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-70B-Instruct - provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-3.1-70B-Instruct - model_type: llm -- metadata: {} - model_id: Meta-Llama-3.1-405B-Instruct - provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-3.1-405B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-405B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-3.1-405B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-405B-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.2-1B-Instruct + 
model_id: sambanova/Meta-Llama-3.2-1B-Instruct provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-3.2-1B-Instruct + provider_model_id: sambanova/Meta-Llama-3.2-1B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-1B-Instruct provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-3.2-1B-Instruct + provider_model_id: sambanova/Meta-Llama-3.2-1B-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.2-3B-Instruct + model_id: sambanova/Meta-Llama-3.2-3B-Instruct provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-3.2-3B-Instruct + provider_model_id: sambanova/Meta-Llama-3.2-3B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-3B-Instruct provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-3.2-3B-Instruct + provider_model_id: sambanova/Meta-Llama-3.2-3B-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.3-70B-Instruct + model_id: sambanova/Meta-Llama-3.3-70B-Instruct provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-3.3-70B-Instruct + provider_model_id: sambanova/Meta-Llama-3.3-70B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.3-70B-Instruct provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-3.3-70B-Instruct + provider_model_id: sambanova/Meta-Llama-3.3-70B-Instruct model_type: llm - metadata: {} - model_id: Llama-3.2-11B-Vision-Instruct + model_id: sambanova/Llama-3.2-11B-Vision-Instruct provider_id: sambanova-openai-compat - provider_model_id: Llama-3.2-11B-Vision-Instruct + provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-11B-Vision-Instruct provider_id: sambanova-openai-compat - provider_model_id: Llama-3.2-11B-Vision-Instruct + provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct model_type: llm - metadata: {} - model_id: Llama-3.2-90B-Vision-Instruct + model_id: sambanova/Llama-3.2-90B-Vision-Instruct provider_id: sambanova-openai-compat - provider_model_id: Llama-3.2-90B-Vision-Instruct + provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-90B-Vision-Instruct provider_id: sambanova-openai-compat - provider_model_id: Llama-3.2-90B-Vision-Instruct + provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-Guard-3-8B + model_id: sambanova/Llama-4-Scout-17B-16E-Instruct provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-Guard-3-8B - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-Guard-3-8B - provider_id: sambanova-openai-compat - provider_model_id: Meta-Llama-Guard-3-8B - model_type: llm -- metadata: {} - model_id: Llama-4-Scout-17B-16E-Instruct - provider_id: sambanova-openai-compat - provider_model_id: Llama-4-Scout-17B-16E-Instruct + provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct provider_id: sambanova-openai-compat - provider_model_id: Llama-4-Scout-17B-16E-Instruct + provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct + model_type: llm +- metadata: {} + model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct + provider_id: sambanova-openai-compat + provider_model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct + provider_id: sambanova-openai-compat + provider_model_id: 
sambanova/Llama-4-Maverick-17B-128E-Instruct + model_type: llm +- metadata: {} + model_id: sambanova/Meta-Llama-Guard-3-8B + provider_id: sambanova-openai-compat + provider_model_id: sambanova/Meta-Llama-Guard-3-8B + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-Guard-3-8B + provider_id: sambanova-openai-compat + provider_model_id: sambanova/Meta-Llama-Guard-3-8B model_type: llm - metadata: {} model_id: llama3.1-8b diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py index a3cfce4fd..a137d67cc 100644 --- a/tests/integration/inference/test_text_inference.py +++ b/tests/integration/inference/test_text_inference.py @@ -30,12 +30,25 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id): "remote::anthropic", "remote::gemini", "remote::groq", + "remote::sambanova", ) or "openai-compat" in provider.provider_type ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion") +def skip_if_model_doesnt_support_json_schema_structured_output(client_with_models, model_id): + models = {m.identifier: m for m in client_with_models.models.list()} + models.update({m.provider_resource_id: m for m in client_with_models.models.list()}) + provider_id = models[model_id].provider_id + providers = {p.provider_id: p for p in client_with_models.providers.list()} + provider = providers[provider_id] + if provider.provider_type in ("remote::sambanova",): + pytest.skip( + f"Model {model_id} hosted by {provider.provider_type} doesn't support json_schema structured output" + ) + + def get_llama_model(client_with_models, model_id): models = {} for m in client_with_models.models.list(): @@ -384,6 +397,8 @@ def test_text_chat_completion_with_tool_choice_none(client_with_models, text_mod ], ) def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case): + skip_if_model_doesnt_support_json_schema_structured_output(client_with_models, text_model_id) + class NBAStats(BaseModel): year_for_draft: int num_seasons_in_nba: int
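
## Usage sketch (illustrative, not part of the patch)

A minimal client-side sketch of exercising the updated adapter, assuming a server built from the `sambanova` template is listening on the default port 8321. The `provider_data` argument is serialized into the `X-LlamaStack-Provider-Data` header that the adapter's `_get_api_key()` falls back to when `SAMBANOVA_API_KEY` is not set in the run config; the key name `sambanova_api_key` and the model id come from this diff, while the base URL and key value are placeholders.

```python
from llama_stack_client import LlamaStackClient

# Assumption: a server built from the sambanova template is running locally
# on port 8321. provider_data is sent as the X-LlamaStack-Provider-Data
# header, which _get_api_key() reads when no key is configured server-side.
client = LlamaStackClient(
    base_url="http://localhost:8321",
    provider_data={"sambanova_api_key": "your-api-key"},  # placeholder value
)

response = client.inference.chat_completion(
    model_id="sambanova/Meta-Llama-3.3-70B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.completion_message.content)
```

Passing `stream=True` to the same call exercises the streaming path that this PR moves onto LiteLLM.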