diff --git a/distributions/dependencies.json b/distributions/dependencies.json index 59b0c9e62..3f0423b74 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -496,10 +496,11 @@ "fastapi", "fire", "httpx", + "litellm", "matplotlib", + "mcp", "nltk", "numpy", - "openai", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-sdk", "pandas", @@ -514,7 +515,9 @@ "sentencepiece", "tqdm", "transformers", - "uvicorn" + "uvicorn", + "sentence-transformers --no-deps", + "torch torchvision --index-url https://download.pytorch.org/whl/cpu" ], "tgi": [ "aiohttp", diff --git a/docs/source/distributions/self_hosted_distro/sambanova.md b/docs/source/distributions/self_hosted_distro/sambanova.md index a7f738261..04beea8c1 100644 --- a/docs/source/distributions/self_hosted_distro/sambanova.md +++ b/docs/source/distributions/self_hosted_distro/sambanova.md @@ -16,10 +16,10 @@ The `llamastack/distribution-sambanova` distribution consists of the following p | API | Provider(s) | |-----|-------------| | agents | `inline::meta-reference` | -| inference | `remote::sambanova` | +| inference | `remote::sambanova`, `inline::sentence-transformers` | | safety | `inline::llama-guard` | | telemetry | `inline::meta-reference` | -| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` | | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | @@ -28,21 +28,21 @@ The `llamastack/distribution-sambanova` distribution consists of the following p The following environment variables can be configured: - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) -- `SAMBANOVA_API_KEY`: SambaNova.AI API Key (default: ``) +- `SAMBANOVA_API_KEY`: SambaNova API Key (default: ``) ### Models The following models are available by default: -- `Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)` -- `Meta-Llama-3.1-70B-Instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)` -- `Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)` -- `Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)` -- `Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)` -- `Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)` -- `Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)` -- `Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)` -- `Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)` +- `sambanova/Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)` +- `sambanova/Meta-Llama-3.1-70B-Instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)` +- `sambanova/Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)` +- `sambanova/Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)` +- `sambanova/Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)` +- `sambanova/Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)` +- `sambanova/Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)` +- `sambanova/Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)` +- `sambanova/Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)` ### Prerequisite: API Keys diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index d5f095740..8bf8a1961 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -241,11 +241,10 @@ def available_providers() -> List[ProviderSpec]: api=Api.inference, adapter=AdapterSpec( adapter_type="sambanova", - pip_packages=[ - "openai", - ], + pip_packages=["litellm"], module="llama_stack.providers.remote.inference.sambanova", config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", ), ), remote_provider_spec( diff --git a/llama_stack/providers/remote/inference/sambanova/__init__.py b/llama_stack/providers/remote/inference/sambanova/__init__.py index 3e682e69c..a3a7b8fbd 100644 --- a/llama_stack/providers/remote/inference/sambanova/__init__.py +++ b/llama_stack/providers/remote/inference/sambanova/__init__.py @@ -4,16 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from pydantic import BaseModel +from llama_stack.apis.inference import Inference from .config import SambaNovaImplConfig -class SambaNovaProviderDataValidator(BaseModel): - sambanova_api_key: str - - -async def get_adapter_impl(config: SambaNovaImplConfig, _deps): +async def get_adapter_impl(config: SambaNovaImplConfig, _deps) -> Inference: from .sambanova import SambaNovaInferenceAdapter assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}" diff --git a/llama_stack/providers/remote/inference/sambanova/config.py b/llama_stack/providers/remote/inference/sambanova/config.py index a30c29b74..f76262914 100644 --- a/llama_stack/providers/remote/inference/sambanova/config.py +++ b/llama_stack/providers/remote/inference/sambanova/config.py @@ -11,6 +11,13 @@ from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type +class SambaNovaProviderDataValidator(BaseModel): + sambanova_api_key: Optional[str] = Field( + default=None, + description="Sambanova Cloud API key", + ) + + @json_schema_type class SambaNovaImplConfig(BaseModel): url: str = Field( @@ -19,7 +26,7 @@ class SambaNovaImplConfig(BaseModel): ) api_key: Optional[str] = Field( default=None, - description="The SambaNova.ai API Key", + description="The SambaNova cloud API Key", ) @classmethod diff --git a/llama_stack/providers/remote/inference/sambanova/models.py b/llama_stack/providers/remote/inference/sambanova/models.py index 2231be22d..04212e331 100644 --- a/llama_stack/providers/remote/inference/sambanova/models.py +++ b/llama_stack/providers/remote/inference/sambanova/models.py @@ -11,39 +11,39 @@ from llama_stack.providers.utils.inference.model_registry import ( MODEL_ENTRIES = [ build_hf_repo_model_entry( - "Meta-Llama-3.1-8B-Instruct", + "sambanova/Meta-Llama-3.1-8B-Instruct", CoreModelId.llama3_1_8b_instruct.value, ), build_hf_repo_model_entry( - "Meta-Llama-3.1-70B-Instruct", + "sambanova/Meta-Llama-3.1-70B-Instruct", CoreModelId.llama3_1_70b_instruct.value, ), build_hf_repo_model_entry( - "Meta-Llama-3.1-405B-Instruct", + "sambanova/Meta-Llama-3.1-405B-Instruct", CoreModelId.llama3_1_405b_instruct.value, ), build_hf_repo_model_entry( - "Meta-Llama-3.2-1B-Instruct", + "sambanova/Meta-Llama-3.2-1B-Instruct", CoreModelId.llama3_2_1b_instruct.value, ), build_hf_repo_model_entry( - "Meta-Llama-3.2-3B-Instruct", + "sambanova/Meta-Llama-3.2-3B-Instruct", CoreModelId.llama3_2_3b_instruct.value, ), build_hf_repo_model_entry( - "Meta-Llama-3.3-70B-Instruct", + "sambanova/Meta-Llama-3.3-70B-Instruct", CoreModelId.llama3_3_70b_instruct.value, ), build_hf_repo_model_entry( - "Llama-3.2-11B-Vision-Instruct", + "sambanova/Llama-3.2-11B-Vision-Instruct", CoreModelId.llama3_2_11b_vision_instruct.value, ), build_hf_repo_model_entry( - "Llama-3.2-90B-Vision-Instruct", + "sambanova/Llama-3.2-90B-Vision-Instruct", CoreModelId.llama3_2_90b_vision_instruct.value, ), build_hf_repo_model_entry( - "Meta-Llama-Guard-3-8B", + "sambanova/Meta-Llama-Guard-3-8B", CoreModelId.llama_guard_3_8b.value, ), ] diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index a5e17c2a3..731342da0 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -4,305 +4,26 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json -from typing import AsyncGenerator, List, Optional +from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig +from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin -from openai import OpenAI - -from llama_stack.apis.common.content_types import ( - ImageContentItem, - InterleavedContent, - InterleavedContentItem, - TextContentItem, -) -from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - CompletionMessage, - EmbeddingsResponse, - EmbeddingTaskType, - Inference, - LogProbConfig, - Message, - ResponseFormat, - SamplingParams, - StopReason, - SystemMessage, - TextTruncation, - ToolCall, - ToolChoice, - ToolConfig, - ToolDefinition, - ToolPromptFormat, - ToolResponseMessage, - UserMessage, -) -from llama_stack.models.llama.datatypes import ( - GreedySamplingStrategy, - TopKSamplingStrategy, - TopPSamplingStrategy, -) -from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, -) -from llama_stack.providers.utils.inference.openai_compat import ( - process_chat_completion_stream_response, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - convert_image_content_to_url, -) - -from .config import SambaNovaImplConfig from .models import MODEL_ENTRIES -class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): - def __init__(self, config: SambaNovaImplConfig) -> None: - ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) +class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): + _config: SambaNovaImplConfig + + def __init__(self, config: SambaNovaImplConfig): + LiteLLMOpenAIMixin.__init__( + self, + model_entries=MODEL_ENTRIES, + api_key_from_config=config.api_key, + provider_data_api_key_field="sambanova_api_key", + ) self.config = config - async def initialize(self) -> None: - return + async def initialize(self): + await super().initialize() - async def shutdown(self) -> None: - pass - - def _get_client(self) -> OpenAI: - return OpenAI(base_url=self.config.url, api_key=self.config.api_key) - - async def completion( - self, - model_id: str, - content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - ) -> AsyncGenerator: - raise NotImplementedError() - - async def chat_completion( - self, - model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, - stream: Optional[bool] = False, - tool_config: Optional[ToolConfig] = None, - logprobs: Optional[LogProbConfig] = None, - ) -> AsyncGenerator: - if sampling_params is None: - sampling_params = SamplingParams() - model = await self.model_store.get_model(model_id) - - request = ChatCompletionRequest( - model=model.provider_resource_id, - messages=messages, - sampling_params=sampling_params, - tools=tools or [], - stream=stream, - logprobs=logprobs, - tool_config=tool_config, - ) - request_sambanova = await self.convert_chat_completion_request(request) - - if stream: - return self._stream_chat_completion(request_sambanova) - else: - return await self._nonstream_chat_completion(request_sambanova) - - async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - response = self._get_client().chat.completions.create(**request) - - choice = response.choices[0] - - result = ChatCompletionResponse( - completion_message=CompletionMessage( - content=choice.message.content or "", - stop_reason=self.convert_to_sambanova_finish_reason(choice.finish_reason), - tool_calls=self.convert_to_sambanova_tool_calls(choice.message.tool_calls), - ), - logprobs=None, - ) - - return result - - async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: - async def _to_async_generator(): - streaming = self._get_client().chat.completions.create(**request) - for chunk in streaming: - yield chunk - - stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, request): - yield chunk - - async def embeddings( - self, - model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, - ) -> EmbeddingsResponse: - raise NotImplementedError() - - async def convert_chat_completion_request(self, request: ChatCompletionRequest) -> dict: - compatible_request = self.convert_sampling_params(request.sampling_params) - compatible_request["model"] = request.model - compatible_request["messages"] = await self.convert_to_sambanova_messages(request.messages) - compatible_request["stream"] = request.stream - compatible_request["logprobs"] = False - compatible_request["extra_headers"] = { - b"User-Agent": b"llama-stack: sambanova-inference-adapter", - } - compatible_request["tools"] = self.convert_to_sambanova_tool(request.tools) - return compatible_request - - def convert_sampling_params(self, sampling_params: SamplingParams, legacy: bool = False) -> dict: - params = {} - - if sampling_params: - params["frequency_penalty"] = sampling_params.repetition_penalty - - if sampling_params.max_tokens: - if legacy: - params["max_tokens"] = sampling_params.max_tokens - else: - params["max_completion_tokens"] = sampling_params.max_tokens - - if isinstance(sampling_params.strategy, TopPSamplingStrategy): - params["top_p"] = sampling_params.strategy.top_p - if isinstance(sampling_params.strategy, TopKSamplingStrategy): - params["extra_body"]["top_k"] = sampling_params.strategy.top_k - if isinstance(sampling_params.strategy, GreedySamplingStrategy): - params["temperature"] = 0.0 - - return params - - async def convert_to_sambanova_messages(self, messages: List[Message]) -> List[dict]: - conversation = [] - for message in messages: - content = {} - - content["content"] = await self.convert_to_sambanova_content(message) - - if isinstance(message, UserMessage): - content["role"] = "user" - elif isinstance(message, CompletionMessage): - content["role"] = "assistant" - tools = [] - for tool_call in message.tool_calls: - tools.append( - { - "id": tool_call.call_id, - "function": { - "name": tool_call.name, - "arguments": json.dumps(tool_call.arguments), - }, - "type": "function", - } - ) - content["tool_calls"] = tools - elif isinstance(message, ToolResponseMessage): - content["role"] = "tool" - content["tool_call_id"] = message.call_id - elif isinstance(message, SystemMessage): - content["role"] = "system" - - conversation.append(content) - - return conversation - - async def convert_to_sambanova_content(self, message: Message) -> dict: - async def _convert_content(content) -> dict: - if isinstance(content, ImageContentItem): - url = await convert_image_content_to_url(content, download=True) - # A fix to make sure the call sucess. - components = url.split(";base64") - url = f"{components[0].lower()};base64{components[1]}" - return { - "type": "image_url", - "image_url": {"url": url}, - } - else: - text = content.text if isinstance(content, TextContentItem) else content - assert isinstance(text, str) - return {"type": "text", "text": text} - - if isinstance(message.content, list): - # If it is a list, the text content should be wrapped in dict - content = [await _convert_content(c) for c in message.content] - else: - content = message.content - - return content - - def convert_to_sambanova_tool(self, tools: List[ToolDefinition]) -> List[dict]: - if tools is None: - return tools - - compatiable_tools = [] - - for tool in tools: - properties = {} - compatiable_required = [] - if tool.parameters: - for tool_key, tool_param in tool.parameters.items(): - properties[tool_key] = {"type": tool_param.param_type} - if tool_param.description: - properties[tool_key]["description"] = tool_param.description - if tool_param.default: - properties[tool_key]["default"] = tool_param.default - if tool_param.required: - compatiable_required.append(tool_key) - - compatiable_tool = { - "type": "function", - "function": { - "name": tool.tool_name, - "description": tool.description, - "parameters": { - "type": "object", - "properties": properties, - "required": compatiable_required, - }, - }, - } - - compatiable_tools.append(compatiable_tool) - - if len(compatiable_tools) > 0: - return compatiable_tools - return None - - def convert_to_sambanova_finish_reason(self, finish_reason: str) -> StopReason: - return { - "stop": StopReason.end_of_turn, - "length": StopReason.out_of_tokens, - "tool_calls": StopReason.end_of_message, - }.get(finish_reason, StopReason.end_of_turn) - - def convert_to_sambanova_tool_calls( - self, - tool_calls, - ) -> List[ToolCall]: - if not tool_calls: - return [] - - for call in tool_calls: - call_function_arguments = json.loads(call.function.arguments) - - compitable_tool_calls = [ - ToolCall( - call_id=call.id, - tool_name=call.function.name, - arguments=call_function_arguments, - ) - for call in tool_calls - ] - - return compitable_tool_calls + async def shutdown(self): + await super().shutdown() diff --git a/llama_stack/templates/sambanova/build.yaml b/llama_stack/templates/sambanova/build.yaml index ca5ffe618..9b09b9e90 100644 --- a/llama_stack/templates/sambanova/build.yaml +++ b/llama_stack/templates/sambanova/build.yaml @@ -1,9 +1,10 @@ version: '2' distribution_spec: - description: Use SambaNova.AI for running LLM inference + description: Use SambaNova for running LLM inference providers: inference: - remote::sambanova + - inline::sentence-transformers vector_io: - inline::faiss - remote::chromadb @@ -19,4 +20,6 @@ distribution_spec: - remote::tavily-search - inline::code-interpreter - inline::rag-runtime + - remote::model-context-protocol + - remote::wolfram-alpha image_type: conda diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index cfa0cc194..828839cc9 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -14,6 +14,9 @@ providers: config: url: https://api.sambanova.ai/v1 api_key: ${env.SAMBANOVA_API_KEY} + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} vector_io: - provider_id: faiss provider_type: inline::faiss @@ -70,100 +73,111 @@ providers: - provider_id: rag-runtime provider_type: inline::rag-runtime config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} + - provider_id: wolfram-alpha + provider_type: remote::wolfram-alpha + config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db models: - metadata: {} - model_id: Meta-Llama-3.1-8B-Instruct + model_id: sambanova/Meta-Llama-3.1-8B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.1-8B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-8B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.1-8B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-8B-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.1-70B-Instruct + model_id: sambanova/Meta-Llama-3.1-70B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.1-70B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-70B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.1-70B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.1-70B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-70B-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.1-405B-Instruct + model_id: sambanova/Meta-Llama-3.1-405B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.1-405B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-405B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 provider_id: sambanova - provider_model_id: Meta-Llama-3.1-405B-Instruct + provider_model_id: sambanova/Meta-Llama-3.1-405B-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.2-1B-Instruct + model_id: sambanova/Meta-Llama-3.2-1B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.2-1B-Instruct + provider_model_id: sambanova/Meta-Llama-3.2-1B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-1B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.2-1B-Instruct + provider_model_id: sambanova/Meta-Llama-3.2-1B-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.2-3B-Instruct + model_id: sambanova/Meta-Llama-3.2-3B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.2-3B-Instruct + provider_model_id: sambanova/Meta-Llama-3.2-3B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-3B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.2-3B-Instruct + provider_model_id: sambanova/Meta-Llama-3.2-3B-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-3.3-70B-Instruct + model_id: sambanova/Meta-Llama-3.3-70B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.3-70B-Instruct + provider_model_id: sambanova/Meta-Llama-3.3-70B-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.3-70B-Instruct provider_id: sambanova - provider_model_id: Meta-Llama-3.3-70B-Instruct + provider_model_id: sambanova/Meta-Llama-3.3-70B-Instruct model_type: llm - metadata: {} - model_id: Llama-3.2-11B-Vision-Instruct + model_id: sambanova/Llama-3.2-11B-Vision-Instruct provider_id: sambanova - provider_model_id: Llama-3.2-11B-Vision-Instruct + provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-11B-Vision-Instruct provider_id: sambanova - provider_model_id: Llama-3.2-11B-Vision-Instruct + provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct model_type: llm - metadata: {} - model_id: Llama-3.2-90B-Vision-Instruct + model_id: sambanova/Llama-3.2-90B-Vision-Instruct provider_id: sambanova - provider_model_id: Llama-3.2-90B-Vision-Instruct + provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-90B-Vision-Instruct provider_id: sambanova - provider_model_id: Llama-3.2-90B-Vision-Instruct + provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct model_type: llm - metadata: {} - model_id: Meta-Llama-Guard-3-8B + model_id: sambanova/Meta-Llama-Guard-3-8B provider_id: sambanova - provider_model_id: Meta-Llama-Guard-3-8B + provider_model_id: sambanova/Meta-Llama-Guard-3-8B model_type: llm - metadata: {} model_id: meta-llama/Llama-Guard-3-8B provider_id: sambanova - provider_model_id: Meta-Llama-Guard-3-8B + provider_model_id: sambanova/Meta-Llama-Guard-3-8B model_type: llm +- metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: sentence-transformers + model_type: embedding shields: - shield_id: meta-llama/Llama-Guard-3-8B vector_dbs: [] @@ -177,5 +191,7 @@ tool_groups: provider_id: rag-runtime - toolgroup_id: builtin::code_interpreter provider_id: code-interpreter +- toolgroup_id: builtin::wolfram_alpha + provider_id: wolfram-alpha server: port: 8321 diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py index 0b7e82751..31c01e1eb 100644 --- a/llama_stack/templates/sambanova/sambanova.py +++ b/llama_stack/templates/sambanova/sambanova.py @@ -6,11 +6,16 @@ from pathlib import Path +from llama_stack.apis.models.models import ModelType from llama_stack.distribution.datatypes import ( + ModelInput, Provider, ShieldInput, ToolGroupInput, ) +from llama_stack.providers.inline.inference.sentence_transformers import ( + SentenceTransformersInferenceConfig, +) from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig from llama_stack.providers.remote.inference.sambanova.models import MODEL_ENTRIES @@ -21,7 +26,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::sambanova"], + "inference": ["remote::sambanova", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], @@ -31,16 +36,29 @@ def get_distribution_template() -> DistributionTemplate: "remote::tavily-search", "inline::code-interpreter", "inline::rag-runtime", + "remote::model-context-protocol", + "remote::wolfram-alpha", ], } name = "sambanova" - inference_provider = Provider( provider_id=name, provider_type=f"remote::{name}", config=SambaNovaImplConfig.sample_run_config(), ) - + embedding_provider = Provider( + provider_id="sentence-transformers", + provider_type="inline::sentence-transformers", + config=SentenceTransformersInferenceConfig.sample_run_config(), + ) + embedding_model = ModelInput( + model_id="all-MiniLM-L6-v2", + provider_id="sentence-transformers", + model_type=ModelType.embedding, + metadata={ + "embedding_dimension": 384, + }, + ) vector_io_providers = [ Provider( provider_id="faiss", @@ -82,23 +100,27 @@ def get_distribution_template() -> DistributionTemplate: toolgroup_id="builtin::code_interpreter", provider_id="code-interpreter", ), + ToolGroupInput( + toolgroup_id="builtin::wolfram_alpha", + provider_id="wolfram-alpha", + ), ] return DistributionTemplate( name=name, distro_type="self_hosted", - description="Use SambaNova.AI for running LLM inference", - docker_image=None, + description="Use SambaNova for running LLM inference", + container_image=None, template_path=Path(__file__).parent / "doc_template.md", providers=providers, available_models_by_provider=available_models, run_configs={ "run.yaml": RunConfigSettings( provider_overrides={ - "inference": [inference_provider], + "inference": [inference_provider, embedding_provider], "vector_io": vector_io_providers, }, - default_models=default_models, + default_models=default_models + [embedding_model], default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], default_tool_groups=default_tool_groups, ), @@ -110,7 +132,7 @@ def get_distribution_template() -> DistributionTemplate: ), "SAMBANOVA_API_KEY": ( "", - "SambaNova.AI API Key", + "SambaNova API Key", ), }, ) diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py index 7e3e14dbc..52aed2976 100644 --- a/tests/integration/inference/test_text_inference.py +++ b/tests/integration/inference/test_text_inference.py @@ -21,7 +21,13 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id): provider_id = models[model_id].provider_id providers = {p.provider_id: p for p in client_with_models.providers.list()} provider = providers[provider_id] - if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini", "remote::groq"): + if provider.provider_type in ( + "remote::openai", + "remote::anthropic", + "remote::gemini", + "remote::groq", + "remote::sambanova", + ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")