diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 6e9897a9b..9c203a8d0 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -8,16 +8,16 @@ import json from typing import AsyncGenerator from llama_models.datatypes import CoreModelId, SamplingStrategy - from llama_models.llama3.api.chat_format import ChatFormat - -from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer - from openai import OpenAI +from llama_stack.apis.common.content_types import ( + ImageContentItem, + InterleavedContent, + TextContentItem, +) from llama_stack.apis.inference import * # noqa: F403 - from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, ModelRegistryHelper, @@ -25,9 +25,8 @@ from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.openai_compat import ( process_chat_completion_stream_response, ) - from llama_stack.providers.utils.inference.prompt_adapter import ( - convert_image_media_to_url, + convert_image_content_to_url, ) from .config import SambaNovaImplConfig @@ -86,7 +85,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): async def completion( self, model_id: str, - content: InterleavedTextMedia, + content: InterleavedContent, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, @@ -129,6 +128,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): self, request: ChatCompletionRequest ) -> ChatCompletionResponse: response = self._get_client().chat.completions.create(**request) + choice = response.choices[0] result = ChatCompletionResponse( @@ -163,7 +163,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): async def 
embeddings( self, model_id: str, - contents: List[InterleavedTextMedia], + contents: List[InterleavedContent], ) -> EmbeddingsResponse: raise NotImplementedError() @@ -244,21 +244,19 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): async def convert_to_sambanova_content(self, message: Message) -> dict: async def _convert_content(content) -> dict: - if isinstance(content, ImageMedia): - download = False - if isinstance(content, ImageMedia) and isinstance(content.image, URL): - download = content.image.uri.startswith("https://") + if isinstance(content, ImageContentItem): + url = await convert_image_content_to_url(content, download=True) + # A fix to make sure the call succeeds. + components = url.split(";base64") + url = f"{components[0].lower()};base64{components[1]}" return { "type": "image_url", - "image_url": { - "url": await convert_image_media_to_url( - content, download=download - ), - }, + "image_url": {"url": url}, } else: - assert isinstance(content, str) - return {"type": "text", "text": content} + text = content.text if isinstance(content, TextContentItem) else content + assert isinstance(text, str) + return {"type": "text", "text": text} if isinstance(message.content, list): # If it is a list, the text content should be wrapped in dict @@ -320,11 +318,14 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): if not tool_calls: return [] + for call in tool_calls: + call_function_arguments = json.loads(call.function.arguments) + compitable_tool_calls = [ ToolCall( call_id=call.id, tool_name=call.function.name, - arguments=call.function.arguments, + arguments=call_function_arguments, ) for call in tool_calls ] diff --git a/llama_stack/providers/tests/inference/test_embeddings.py b/llama_stack/providers/tests/inference/test_embeddings.py index bf09896c1..ca0276ed6 100644 --- a/llama_stack/providers/tests/inference/test_embeddings.py +++ b/llama_stack/providers/tests/inference/test_embeddings.py @@ -6,7 +6,8 @@ import pytest 
-from llama_stack.apis.inference import EmbeddingsResponse, ModelType +from llama_stack.apis.inference import EmbeddingsResponse +from llama_stack.apis.models import ModelType # How to run this test: # pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index 3cd7b2496..96a34ec0e 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -59,7 +59,7 @@ class TestModelRegistration: }, ) - with pytest.raises(AssertionError) as exc_info: + with pytest.raises(ValueError) as exc_info: await models_impl.register_model( model_id="custom-model-2", metadata={ diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 932ae36e6..19bd30ec7 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -383,6 +383,12 @@ class TestInference: # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well") + if provider.__provider_spec__.provider_type == "remote::sambanova" and ( + "-1B-" in inference_model or "-3B-" in inference_model + ): + # TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling for 1B/3B models works + pytest.skip("Sambanova's tool calling for lightweight models don't work") + messages = sample_messages + [ UserMessage( content="What's the weather like in San Francisco?", @@ -429,6 +435,9 @@ class TestInference: ): # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well") + if provider.__provider_spec__.provider_type == "remote::sambanova": + # 
TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling under streaming is supported (we are working on it) + pytest.skip("Sambanova's tool calling for streaming doesn't work") messages = sample_messages + [ UserMessage( diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py index c4216d300..6374310f3 100644 --- a/llama_stack/providers/tests/inference/test_vision_inference.py +++ b/llama_stack/providers/tests/inference/test_vision_inference.py @@ -145,7 +145,7 @@ class TestVisionModelInference: assert len(grouped[ChatCompletionResponseEventType.complete]) == 1 content = "".join( - chunk.event.delta + chunk.event.delta.text for chunk in grouped[ChatCompletionResponseEventType.progress] ) for expected_string in expected_strings: