diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index a53ddf639..2007818e5 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -18,6 +18,7 @@ from llama_stack.providers.inline.inference.meta_reference import ( from llama_stack.providers.remote.inference.bedrock import BedrockConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig +from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.providers.remote.inference.together import TogetherImplConfig from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig @@ -142,6 +143,19 @@ def inference_bedrock() -> ProviderFixture: ) +@pytest.fixture(scope="session") +def inference_nvidia() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + config=NVIDIAConfig().model_dump(), + ) + ], + ) + + def get_model_short_name(model_name: str) -> str: """Convert model name to a short test identifier. @@ -175,6 +189,7 @@ INFERENCE_FIXTURES = [ "vllm_remote", "remote", "bedrock", + "nvidia", ] diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 6e263432a..b35d947ec 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -361,7 +361,10 @@ class TestInference: for chunk in grouped[ChatCompletionResponseEventType.progress] ) first = grouped[ChatCompletionResponseEventType.progress][0] - assert first.event.delta.parse_status == ToolCallParseStatus.started + if not isinstance( + first.event.delta.content, ToolCall + ): # first chunk may contain entire call + assert first.event.delta.parse_status == ToolCallParseStatus.started last = grouped[ChatCompletionResponseEventType.progress][-1] # assert last.event.stop_reason == expected_stop_reason