From 5ab3773577081701017f1e3ce7fb7931c0614764 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Thu, 21 Nov 2024 16:03:55 -0500
Subject: [PATCH] add nvidia provider for inference tests

---
 llama_stack/providers/tests/inference/fixtures.py | 15 +++++++++++++++
 .../tests/inference/test_text_inference.py        |  5 ++++-
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index a53ddf639..2007818e5 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -18,6 +18,7 @@ from llama_stack.providers.inline.inference.meta_reference import (
 
 from llama_stack.providers.remote.inference.bedrock import BedrockConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
+from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
@@ -142,6 +143,19 @@ def inference_bedrock() -> ProviderFixture:
     )
 
 
+@pytest.fixture(scope="session")
+def inference_nvidia() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="nvidia",
+                provider_type="remote::nvidia",
+                config=NVIDIAConfig().model_dump(),
+            )
+        ],
+    )
+
+
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.
 
@@ -175,6 +189,7 @@ INFERENCE_FIXTURES = [
     "vllm_remote",
     "remote",
     "bedrock",
+    "nvidia",
 ]
 
 
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 6e263432a..b35d947ec 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -361,7 +361,10 @@ class TestInference:
             for chunk in grouped[ChatCompletionResponseEventType.progress]
         )
         first = grouped[ChatCompletionResponseEventType.progress][0]
-        assert first.event.delta.parse_status == ToolCallParseStatus.started
+        if not isinstance(
+            first.event.delta.content, ToolCall
+        ):  # first chunk may contain entire call
+            assert first.event.delta.parse_status == ToolCallParseStatus.started
         last = grouped[ChatCompletionResponseEventType.progress][-1]
         # assert last.event.stop_reason == expected_stop_reason
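
For reference, a minimal standalone sketch of the provider entry the new inference_nvidia fixture yields. The Provider import path is an assumption (it is not visible in the hunk context above); the rest mirrors the fixture body in the patch:

    # Hypothetical reproduction of the fixture's Provider entry, outside pytest.
    # Assumption: Provider is importable from llama_stack.distribution.datatypes;
    # fixtures.py may import it from elsewhere.
    from llama_stack.distribution.datatypes import Provider
    from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig

    nvidia_provider = Provider(
        provider_id="nvidia",
        provider_type="remote::nvidia",
        config=NVIDIAConfig().model_dump(),  # serialize the default NVIDIA config, as the fixture does
    )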