apply change and fixed test cases

This commit is contained in:
Edward Ma 2025-01-14 13:24:08 -08:00
parent 89ab2be302
commit b197d3ce1c
5 changed files with 35 additions and 24 deletions

View file

@ -8,16 +8,16 @@ import json
from typing import AsyncGenerator
from llama_models.datatypes import CoreModelId, SamplingStrategy
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
TextContentItem,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
@ -25,9 +25,8 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url,
convert_image_content_to_url,
)
from .config import SambaNovaImplConfig
@ -86,7 +85,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -129,6 +128,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
response = self._get_client().chat.completions.create(**request)
choice = response.choices[0]
result = ChatCompletionResponse(
@ -163,7 +163,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()
@ -244,21 +244,19 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
async def convert_to_sambanova_content(self, message: Message) -> dict:
async def _convert_content(content) -> dict:
if isinstance(content, ImageMedia):
download = False
if isinstance(content, ImageMedia) and isinstance(content.image, URL):
download = content.image.uri.startswith("https://")
if isinstance(content, ImageContentItem):
url = await convert_image_content_to_url(content, download=True)
# A fix to make sure the call sucess.
components = url.split(";base64")
url = f"{components[0].lower()};base64{components[1]}"
return {
"type": "image_url",
"image_url": {
"url": await convert_image_media_to_url(
content, download=download
),
},
"image_url": {"url": url},
}
else:
assert isinstance(content, str)
return {"type": "text", "text": content}
text = content.text if isinstance(content, TextContentItem) else content
assert isinstance(text, str)
return {"type": "text", "text": text}
if isinstance(message.content, list):
# If it is a list, the text content should be wrapped in dict
@ -320,11 +318,14 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
if not tool_calls:
return []
for call in tool_calls:
call_function_arguments = json.loads(call.function.arguments)
compitable_tool_calls = [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=call.function.arguments,
arguments=call_function_arguments,
)
for call in tool_calls
]