mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
apply change and fixed test cases
This commit is contained in:
parent
89ab2be302
commit
b197d3ce1c
5 changed files with 35 additions and 24 deletions
|
@ -8,16 +8,16 @@ import json
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from llama_models.datatypes import CoreModelId, SamplingStrategy
|
from llama_models.datatypes import CoreModelId, SamplingStrategy
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import Message
|
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import (
|
||||||
|
ImageContentItem,
|
||||||
|
InterleavedContent,
|
||||||
|
TextContentItem,
|
||||||
|
)
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
build_model_alias,
|
build_model_alias,
|
||||||
ModelRegistryHelper,
|
ModelRegistryHelper,
|
||||||
|
@ -25,9 +25,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
process_chat_completion_stream_response,
|
process_chat_completion_stream_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
convert_image_media_to_url,
|
convert_image_content_to_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .config import SambaNovaImplConfig
|
from .config import SambaNovaImplConfig
|
||||||
|
@ -86,7 +85,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
content: InterleavedTextMedia,
|
content: InterleavedContent,
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
|
@ -129,6 +128,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
self, request: ChatCompletionRequest
|
self, request: ChatCompletionRequest
|
||||||
) -> ChatCompletionResponse:
|
) -> ChatCompletionResponse:
|
||||||
response = self._get_client().chat.completions.create(**request)
|
response = self._get_client().chat.completions.create(**request)
|
||||||
|
|
||||||
choice = response.choices[0]
|
choice = response.choices[0]
|
||||||
|
|
||||||
result = ChatCompletionResponse(
|
result = ChatCompletionResponse(
|
||||||
|
@ -163,7 +163,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedContent],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -244,21 +244,19 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
|
|
||||||
async def convert_to_sambanova_content(self, message: Message) -> dict:
|
async def convert_to_sambanova_content(self, message: Message) -> dict:
|
||||||
async def _convert_content(content) -> dict:
|
async def _convert_content(content) -> dict:
|
||||||
if isinstance(content, ImageMedia):
|
if isinstance(content, ImageContentItem):
|
||||||
download = False
|
url = await convert_image_content_to_url(content, download=True)
|
||||||
if isinstance(content, ImageMedia) and isinstance(content.image, URL):
|
# Lowercase the part of the URL before ";base64" to make sure the call succeeds.
|
||||||
download = content.image.uri.startswith("https://")
|
components = url.split(";base64")
|
||||||
|
url = f"{components[0].lower()};base64{components[1]}"
|
||||||
return {
|
return {
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {"url": url},
|
||||||
"url": await convert_image_media_to_url(
|
|
||||||
content, download=download
|
|
||||||
),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
assert isinstance(content, str)
|
text = content.text if isinstance(content, TextContentItem) else content
|
||||||
return {"type": "text", "text": content}
|
assert isinstance(text, str)
|
||||||
|
return {"type": "text", "text": text}
|
||||||
|
|
||||||
if isinstance(message.content, list):
|
if isinstance(message.content, list):
|
||||||
# If it is a list, the text content should be wrapped in dict
|
# If it is a list, the text content should be wrapped in dict
|
||||||
|
@ -320,11 +318,14 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
if not tool_calls:
|
if not tool_calls:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
for call in tool_calls:
|
||||||
|
call_function_arguments = json.loads(call.function.arguments)
|
||||||
|
|
||||||
compitable_tool_calls = [
|
compitable_tool_calls = [
|
||||||
ToolCall(
|
ToolCall(
|
||||||
call_id=call.id,
|
call_id=call.id,
|
||||||
tool_name=call.function.name,
|
tool_name=call.function.name,
|
||||||
arguments=call.function.arguments,
|
arguments=call_function_arguments,
|
||||||
)
|
)
|
||||||
for call in tool_calls
|
for call in tool_calls
|
||||||
]
|
]
|
||||||
|
|
|
@ -6,7 +6,8 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.inference import EmbeddingsResponse, ModelType
|
from llama_stack.apis.inference import EmbeddingsResponse
|
||||||
|
from llama_stack.apis.models import ModelType
|
||||||
|
|
||||||
# How to run this test:
|
# How to run this test:
|
||||||
# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py
|
# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py
|
||||||
|
|
|
@ -59,7 +59,7 @@ class TestModelRegistration:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(AssertionError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
await models_impl.register_model(
|
await models_impl.register_model(
|
||||||
model_id="custom-model-2",
|
model_id="custom-model-2",
|
||||||
metadata={
|
metadata={
|
||||||
|
|
|
@ -383,6 +383,12 @@ class TestInference:
|
||||||
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
|
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
|
||||||
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
|
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
|
||||||
|
|
||||||
|
if provider.__provider_spec__.provider_type == "remote::sambanova" and (
|
||||||
|
"-1B-" in inference_model or "-3B-" in inference_model
|
||||||
|
):
|
||||||
|
# TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling works for the 1B/3B models
|
||||||
|
pytest.skip("Sambanova's tool calling for lightweight models don't work")
|
||||||
|
|
||||||
messages = sample_messages + [
|
messages = sample_messages + [
|
||||||
UserMessage(
|
UserMessage(
|
||||||
content="What's the weather like in San Francisco?",
|
content="What's the weather like in San Francisco?",
|
||||||
|
@ -429,6 +435,9 @@ class TestInference:
|
||||||
):
|
):
|
||||||
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
|
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
|
||||||
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
|
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
|
||||||
|
if provider.__provider_spec__.provider_type == "remote::sambanova":
|
||||||
|
# TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling under streaming is supported (we are working on it)
|
||||||
|
pytest.skip("Sambanova's tool calling for streaming doesn't work")
|
||||||
|
|
||||||
messages = sample_messages + [
|
messages = sample_messages + [
|
||||||
UserMessage(
|
UserMessage(
|
||||||
|
|
|
@ -145,7 +145,7 @@ class TestVisionModelInference:
|
||||||
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
||||||
|
|
||||||
content = "".join(
|
content = "".join(
|
||||||
chunk.event.delta
|
chunk.event.delta.text
|
||||||
for chunk in grouped[ChatCompletionResponseEventType.progress]
|
for chunk in grouped[ChatCompletionResponseEventType.progress]
|
||||||
)
|
)
|
||||||
for expected_string in expected_strings:
|
for expected_string in expected_strings:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue