diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 8d38b4d4c..6e9897a9b 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -23,11 +23,12 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    process_chat_completion_response,
     process_chat_completion_stream_response,
 )
-from llama_stack.providers.utils.inference.prompt_adapter import convert_message_to_dict
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_media_to_url,
+)
 
 from .config import SambaNovaImplConfig
@@ -69,6 +70,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             self,
             model_aliases=MODEL_ALIASES,
         )
+        self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -118,24 +120,38 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         )
         request_sambanova = await self.convert_chat_completion_request(request)
 
-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_key)
         if stream:
-            return self._stream_chat_completion(request_sambanova, client)
+            return self._stream_chat_completion(request_sambanova)
         else:
-            return await self._nonstream_chat_completion(request_sambanova, client)
+            return await self._nonstream_chat_completion(request_sambanova)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        r = client.chat.completions.create(**request)
-        return process_chat_completion_response(r, self.formatter)
+        response = self._get_client().chat.completions.create(**request)
+        choice = response.choices[0]
+
+        result = ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=choice.message.content or "",
+                stop_reason=self.convert_to_sambanova_finish_reason(
+                    choice.finish_reason
+                ),
+                tool_calls=self.convert_to_sambanova_tool_calls(
+                    choice.message.tool_calls
+                ),
+            ),
+            logprobs=None,
+        )
+
+        return result
 
     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
         async def _to_async_generator():
-            s = client.chat.completions.create(**request)
-            for chunk in s:
+            streaming = self._get_client().chat.completions.create(**request)
+            for chunk in streaming:
                 yield chunk
 
         stream = _to_async_generator()
@@ -156,7 +172,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> dict:
         compatible_request = self.convert_sampling_params(request.sampling_params)
         compatible_request["model"] = request.model
-        compatible_request["messages"] = await self.convert_to_sambanova_message(
+        compatible_request["messages"] = await self.convert_to_sambanova_messages(
             request.messages
         )
         compatible_request["stream"] = request.stream
@@ -164,6 +180,7 @@
         compatible_request["extra_headers"] = {
             b"User-Agent": b"llama-stack: sambanova-inference-adapter",
         }
+        compatible_request["tools"] = self.convert_to_sambanova_tool(request.tools)
         return compatible_request
 
     def convert_sampling_params(
@@ -189,12 +206,15 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         return params
 
-    async def convert_to_sambanova_message(self, messages: List[Message]) -> List[dict]:
+    async def convert_to_sambanova_messages(
+        self, messages: List[Message]
+    ) -> List[dict]:
         conversation = []
         for message in messages:
-            content = await convert_message_to_dict(message)
+            content = {}
+
+            content["content"] = await self.convert_to_sambanova_content(message)
 
-            # Need to override role
             if isinstance(message, UserMessage):
                 content["role"] = "user"
             elif isinstance(message, CompletionMessage):
@@ -221,3 +241,92 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             conversation.append(content)
 
         return conversation
+
+    async def convert_to_sambanova_content(self, message: Message) -> dict:
+        async def _convert_content(content) -> dict:
+            if isinstance(content, ImageMedia):
+                download = False
+                if isinstance(content, ImageMedia) and isinstance(content.image, URL):
+                    download = content.image.uri.startswith("https://")
+                return {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": await convert_image_media_to_url(
+                            content, download=download
+                        ),
+                    },
+                }
+            else:
+                assert isinstance(content, str)
+                return {"type": "text", "text": content}
+
+        if isinstance(message.content, list):
+            # If it is a list, the text content should be wrapped in a dict
+            content = [await _convert_content(c) for c in message.content]
+        else:
+            content = message.content
+
+        return content
+
+    def convert_to_sambanova_tool(self, tools: List[ToolDefinition]) -> List[dict]:
+        if tools is None:
+            return tools
+
+        compatible_tools = []
+
+        for tool in tools:
+            properties = {}
+            compatible_required = []
+            if tool.parameters:
+                for tool_key, tool_param in tool.parameters.items():
+                    properties[tool_key] = {"type": tool_param.param_type}
+                    if tool_param.description:
+                        properties[tool_key]["description"] = tool_param.description
+                    if tool_param.default:
+                        properties[tool_key]["default"] = tool_param.default
+                    if tool_param.required:
+                        compatible_required.append(tool_key)
+
+            compatible_tool = {
+                "type": "function",
+                "function": {
+                    "name": tool.tool_name,
+                    "description": tool.description,
+                    "parameters": {
+                        "type": "object",
+                        "properties": properties,
+                        "required": compatible_required,
+                    },
+                },
+            }
+
+            compatible_tools.append(compatible_tool)
+
+        if len(compatible_tools) > 0:
+            return compatible_tools
+        return None
+
+    def convert_to_sambanova_finish_reason(self, finish_reason: str) -> StopReason:
+        return {
+            "stop": StopReason.end_of_turn,
+            "length": StopReason.out_of_tokens,
+            "tool_calls": StopReason.end_of_message,
+        }.get(finish_reason, StopReason.end_of_turn)
+
+    def convert_to_sambanova_tool_calls(
+        self,
+        tool_calls,
+    ) -> List[ToolCall]:
+        if not tool_calls:
+            return []
+
+        compatible_tool_calls = [
+            ToolCall(
+                call_id=call.id,
+                tool_name=call.function.name,
+                arguments=call.function.arguments,
+            )
+            for call in tool_calls
+        ]
+
+        return compatible_tool_calls
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index a427eef12..65008de66 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -20,6 +20,7 @@ from llama_stack.providers.remote.inference.bedrock import BedrockConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
+from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
 from llama_stack.providers.remote.inference.tgi import TGIImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
@@ -173,6 +174,24 @@ def inference_tgi() -> ProviderFixture:
     )
 
 
+@pytest.fixture(scope="session")
+def inference_sambanova() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="sambanova",
+                provider_type="remote::sambanova",
+                config=SambaNovaImplConfig(
+                    api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
+                ).model_dump(),
+            )
+        ],
+        provider_data=dict(
+            sambanova_api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
+        ),
+    )
+
+
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.
 
@@ -208,6 +227,7 @@ INFERENCE_FIXTURES = [
     "bedrock",
     "nvidia",
     "tgi",
+    "sambanova",
 ]
diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py
index 2c222ffa1..0e043c765 100644
--- a/llama_stack/providers/tests/inference/test_prompt_adapter.py
+++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py
@@ -24,7 +24,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 UserMessage(content=content),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -41,7 +41,7 @@
                 ToolDefinition(tool_name=BuiltinTool.brave_search),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -69,7 +69,7 @@
             ],
             tool_prompt_format=ToolPromptFormat.json,
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -99,7 +99,7 @@
                 ),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -121,7 +121,7 @@
                 ToolDefinition(tool_name=BuiltinTool.code_interpreter),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2, messages)
         self.assertTrue(messages[0].content.endswith(system_prompt))
diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py
index 56fa4c075..44b5ae7ab 100644
--- a/llama_stack/providers/tests/inference/test_vision_inference.py
+++ b/llama_stack/providers/tests/inference/test_vision_inference.py
@@ -49,6 +49,7 @@ class TestVisionModelInference:
             "remote::fireworks",
             "remote::ollama",
             "remote::vllm",
+            "remote::sambanova",
         ):
             pytest.skip(
                 "Other inference providers don't support vision chat completion() yet"
             )
@@ -83,6 +84,7 @@
"remote::fireworks", "remote::ollama", "remote::vllm", + "remote::sambanova", ): pytest.skip( "Other inference providers don't support vision chat completion() yet"