Apply change and fix test cases

Edward Ma 2025-01-14 13:24:08 -08:00
parent 89ab2be302
commit b197d3ce1c
5 changed files with 35 additions and 24 deletions


@@ -8,16 +8,16 @@ import json
 from typing import AsyncGenerator

 from llama_models.datatypes import CoreModelId, SamplingStrategy
 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI

+from llama_stack.apis.common.content_types import (
+    ImageContentItem,
+    InterleavedContent,
+    TextContentItem,
+)
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
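
These import changes track llama-stack's move from the old media types (InterleavedTextMedia, ImageMedia) to the content types in llama_stack.apis.common.content_types. A minimal sketch of the new text type, assuming a standard pydantic-style keyword constructor (only the .text attribute is confirmed by the adapter hunk below; the constructor form is an assumption):

from llama_stack.apis.common.content_types import TextContentItem

# Keyword construction is assumed; the adapter code in this commit only
# reads the .text attribute of a TextContentItem.
item = TextContentItem(text="Describe this image.")
assert item.text == "Describe this image."
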
@@ -25,9 +25,8 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
-    convert_image_media_to_url,
+    convert_image_content_to_url,
 )

 from .config import SambaNovaImplConfig
@@ -86,7 +85,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -129,6 +128,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         response = self._get_client().chat.completions.create(**request)
+
         choice = response.choices[0]
         result = ChatCompletionResponse(
@@ -163,7 +163,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
@@ -244,21 +244,19 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     async def convert_to_sambanova_content(self, message: Message) -> dict:
         async def _convert_content(content) -> dict:
-            if isinstance(content, ImageMedia):
-                download = False
-                if isinstance(content, ImageMedia) and isinstance(content.image, URL):
-                    download = content.image.uri.startswith("https://")
+            if isinstance(content, ImageContentItem):
+                url = await convert_image_content_to_url(content, download=True)
+                # A fix to make sure the call succeeds.
+                components = url.split(";base64")
+                url = f"{components[0].lower()};base64{components[1]}"
                 return {
                     "type": "image_url",
-                    "image_url": {
-                        "url": await convert_image_media_to_url(
-                            content, download=download
-                        ),
-                    },
+                    "image_url": {"url": url},
                 }
             else:
-                assert isinstance(content, str)
-                return {"type": "text", "text": content}
+                text = content.text if isinstance(content, TextContentItem) else content
+                assert isinstance(text, str)
+                return {"type": "text", "text": text}

         if isinstance(message.content, list):
             # If it is a list, the text content should be wrapped in dict
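
The new branch works around a SambaNova quirk: the data URL returned by convert_image_content_to_url can carry an upper-case media type (e.g. data:image/PNG;base64,...), which, per the in-line comment, makes the call fail, so the hunk lowercases everything before the ";base64" marker. A standalone sketch of that normalization (the helper name is illustrative, not from the repo):

def normalize_data_url(url: str) -> str:
    # Lowercase only the media-type prefix of a data URL, leaving the
    # base64 payload untouched:
    # "data:image/PNG;base64,iVBOR..." -> "data:image/png;base64,iVBOR..."
    prefix, sep, payload = url.partition(";base64")
    return f"{prefix.lower()}{sep}{payload}"

assert normalize_data_url("data:image/PNG;base64,AAAA") == "data:image/png;base64,AAAA"
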
@@ -320,11 +318,14 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         if not tool_calls:
             return []

+        for call in tool_calls:
+            call_function_arguments = json.loads(call.function.arguments)
+
         compitable_tool_calls = [
             ToolCall(
                 call_id=call.id,
                 tool_name=call.function.name,
-                arguments=call.function.arguments,
+                arguments=call_function_arguments,
             )
             for call in tool_calls
         ]
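
The added loop decodes each OpenAI-style tool call's JSON-encoded function.arguments string into a dict before the ToolCalls are built. Note that, as written, call_function_arguments is rebound on every iteration, so when several tool calls are present each ToolCall receives the last call's arguments. A hedged sketch that parses per call inside the comprehension instead (variable name normalized for the sketch):

import json

compatible_tool_calls = [
    ToolCall(
        call_id=call.id,
        tool_name=call.function.name,
        # Parse this call's own JSON argument string.
        arguments=json.loads(call.function.arguments),
    )
    for call in tool_calls
]
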


@@ -6,7 +6,8 @@
 import pytest

-from llama_stack.apis.inference import EmbeddingsResponse, ModelType
+from llama_stack.apis.inference import EmbeddingsResponse
+from llama_stack.apis.models import ModelType

 # How to run this test:
 # pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py


@@ -59,7 +59,7 @@ class TestModelRegistration:
             },
         )

-        with pytest.raises(AssertionError) as exc_info:
+        with pytest.raises(ValueError) as exc_info:
             await models_impl.register_model(
                 model_id="custom-model-2",
                 metadata={


@@ -383,6 +383,12 @@ class TestInference:
                 # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
                 pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
+            if provider.__provider_spec__.provider_type == "remote::sambanova" and (
+                "-1B-" in inference_model or "-3B-" in inference_model
+            ):
+                # TODO(snova-edwardm): Remove this skip once SambaNova's tool calling for 1B/3B models works
+                pytest.skip("SambaNova's tool calling for lightweight models doesn't work")
+
         messages = sample_messages + [
             UserMessage(
                 content="What's the weather like in San Francisco?",
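
The new skip guard identifies the lightweight Llama 3.2 variants by substring-matching the model id. A hypothetical helper that makes the intent explicit (name and placement are illustrative, not part of this commit):

def is_lightweight_llama(model_id: str) -> bool:
    # Llama 3.2 1B/3B model ids contain "-1B-" or "-3B-".
    return "-1B-" in model_id or "-3B-" in model_id

# Usage in the skip guard above would read:
# if provider_type == "remote::sambanova" and is_lightweight_llama(inference_model):
#     pytest.skip(...)
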
@@ -429,6 +435,9 @@ class TestInference:
         ):
             # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
             pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
+        if provider.__provider_spec__.provider_type == "remote::sambanova":
+            # TODO(snova-edwardm): Remove this skip once SambaNova's tool calling under streaming is supported (we are working on it)
+            pytest.skip("SambaNova's tool calling for streaming doesn't work")
         messages = sample_messages + [
             UserMessage(


@@ -145,7 +145,7 @@ class TestVisionModelInference:
         assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

         content = "".join(
-            chunk.event.delta
+            chunk.event.delta.text
             for chunk in grouped[ChatCompletionResponseEventType.progress]
         )
         for expected_string in expected_strings: