Apply changes and fix test cases

This commit is contained in:
Edward Ma 2025-01-14 13:24:08 -08:00
parent 89ab2be302
commit b197d3ce1c
5 changed files with 35 additions and 24 deletions

View file

@ -6,7 +6,8 @@
import pytest
from llama_stack.apis.inference import EmbeddingsResponse, ModelType
from llama_stack.apis.inference import EmbeddingsResponse
from llama_stack.apis.models import ModelType
# How to run this test:
# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py

View file

@ -59,7 +59,7 @@ class TestModelRegistration:
},
)
with pytest.raises(AssertionError) as exc_info:
with pytest.raises(ValueError) as exc_info:
await models_impl.register_model(
model_id="custom-model-2",
metadata={

View file

@ -383,6 +383,12 @@ class TestInference:
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
if provider.__provider_spec__.provider_type == "remote::sambanova" and (
"-1B-" in inference_model or "-3B-" in inference_model
):
# TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling for 1B/3B models works
pytest.skip("Sambanova's tool calling for lightweight models don't work")
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
@ -429,6 +435,9 @@ class TestInference:
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
if provider.__provider_spec__.provider_type == "remote::sambanova":
# TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling under streaming is supported (we are working on it)
pytest.skip("Sambanova's tool calling for streaming doesn't work")
messages = sample_messages + [
UserMessage(

View file

@ -145,7 +145,7 @@ class TestVisionModelInference:
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
content = "".join(
chunk.event.delta
chunk.event.delta.text
for chunk in grouped[ChatCompletionResponseEventType.progress]
)
for expected_string in expected_strings: