add nvidia provider for inference tests

This commit is contained in:
Matthew Farrellee 2024-11-21 16:03:55 -05:00
parent 1e18791fff
commit 5ab3773577
2 changed files with 19 additions and 1 deletions

View file

@ -18,6 +18,7 @@ from llama_stack.providers.inline.inference.meta_reference import (
from llama_stack.providers.remote.inference.bedrock import BedrockConfig from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
@ -142,6 +143,19 @@ def inference_bedrock() -> ProviderFixture:
) )
@pytest.fixture(scope="session")
def inference_nvidia() -> ProviderFixture:
    """Session-scoped fixture supplying the remote NVIDIA inference provider.

    Builds a single ``Provider`` entry for the ``remote::nvidia`` adapter,
    serializing the default ``NVIDIAConfig`` into the plain-dict form the
    provider registry expects.
    """
    nvidia_provider = Provider(
        provider_id="nvidia",
        provider_type="remote::nvidia",
        config=NVIDIAConfig().model_dump(),
    )
    return ProviderFixture(providers=[nvidia_provider])
def get_model_short_name(model_name: str) -> str: def get_model_short_name(model_name: str) -> str:
"""Convert model name to a short test identifier. """Convert model name to a short test identifier.
@ -175,6 +189,7 @@ INFERENCE_FIXTURES = [
"vllm_remote", "vllm_remote",
"remote", "remote",
"bedrock", "bedrock",
"nvidia",
] ]

View file

@ -361,7 +361,10 @@ class TestInference:
for chunk in grouped[ChatCompletionResponseEventType.progress] for chunk in grouped[ChatCompletionResponseEventType.progress]
) )
first = grouped[ChatCompletionResponseEventType.progress][0] first = grouped[ChatCompletionResponseEventType.progress][0]
assert first.event.delta.parse_status == ToolCallParseStatus.started if not isinstance(
first.event.delta.content, ToolCall
): # first chunk may contain entire call
assert first.event.delta.parse_status == ToolCallParseStatus.started
last = grouped[ChatCompletionResponseEventType.progress][-1] last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason # assert last.event.stop_reason == expected_stop_reason