mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-31 21:24:33 +00:00
apply change and fixed test cases
This commit is contained in:
parent
89ab2be302
commit
b197d3ce1c
5 changed files with 35 additions and 24 deletions
|
|
@ -6,7 +6,8 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import EmbeddingsResponse, ModelType
|
||||
from llama_stack.apis.inference import EmbeddingsResponse
|
||||
from llama_stack.apis.models import ModelType
|
||||
|
||||
# How to run this test:
|
||||
# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class TestModelRegistration:
|
|||
},
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError) as exc_info:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await models_impl.register_model(
|
||||
model_id="custom-model-2",
|
||||
metadata={
|
||||
|
|
|
|||
|
|
@ -383,6 +383,12 @@ class TestInference:
|
|||
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
|
||||
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
|
||||
|
||||
if provider.__provider_spec__.provider_type == "remote::sambanova" and (
|
||||
"-1B-" in inference_model or "-3B-" in inference_model
|
||||
):
|
||||
# TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling for 1B/3B models works
|
||||
pytest.skip("Sambanova's tool calling for lightweight models don't work")
|
||||
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
|
|
@ -429,6 +435,9 @@ class TestInference:
|
|||
):
|
||||
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
|
||||
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
|
||||
if provider.__provider_spec__.provider_type == "remote::sambanova":
|
||||
# TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling under streaming is supported (we are working on it)
|
||||
pytest.skip("Sambanova's tool calling for streaming doesn't work")
|
||||
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ class TestVisionModelInference:
|
|||
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
||||
|
||||
content = "".join(
|
||||
chunk.event.delta
|
||||
chunk.event.delta.text
|
||||
for chunk in grouped[ChatCompletionResponseEventType.progress]
|
||||
)
|
||||
for expected_string in expected_strings:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue