add NVIDIA NIM inference adapter (#355)
# What does this PR do?

This PR adds a basic inference adapter to NVIDIA NIMs.

What it does:

- chat completion api
  - tool calls
  - streaming
  - structured output
  - logprobs
- support hosted NIM on integrate.api.nvidia.com
- support downloaded NIM containers

What it does not do:

- completion api
- embedding api
- vision models
- builtin tools
- have certainty that sampling strategies are correct

## Feature/Issue validation/testing/test plan

`pytest -s -v --providers inference=nvidia llama_stack/providers/tests/inference/ --env NVIDIA_API_KEY=...`

All tests should pass. There are pydantic v1 warnings.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Was this discussed/approved via a Github issue? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
- [x] Did you write any new necessary tests?

Thanks for contributing 🎉!
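As a side note for reviewers: the hosted NIMs behind integrate.api.nvidia.com speak an OpenAI-compatible chat-completions protocol, so the endpoint the adapter targets can be smoke-tested directly, independent of the adapter code. A minimal sketch (not part of this PR; the `openai` client and the model id are illustrative assumptions):

```python
# Sketch only: direct call to the hosted NIM endpoint the adapter targets.
# Assumes the `openai` package is installed, NVIDIA_API_KEY is exported,
# and the (illustrative) model id is available to your key.
import os

from openai import OpenAI

client = OpenAI(
    base_url="https://integrate.api.nvidia.com/v1",
    api_key=os.environ["NVIDIA_API_KEY"],
)

response = client.chat.completions.create(
    model="meta/llama-3.1-8b-instruct",  # illustrative model id
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.choices[0].message.content)
```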
parent 2cfc41e13b · commit 4e6c984c26

10 changed files with 934 additions and 10 deletions
```diff
@@ -6,6 +6,8 @@
 import pytest
 
+from ..conftest import get_provider_fixture_overrides
+
 from .fixtures import INFERENCE_FIXTURES
```
```diff
@@ -67,11 +69,12 @@ def pytest_generate_tests(metafunc):
             indirect=True,
         )
     if "inference_stack" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "inference_stack",
-            [
-                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-                for fixture_name in INFERENCE_FIXTURES
-            ],
-            indirect=True,
-        )
+        fixtures = INFERENCE_FIXTURES
+        if filtered_stacks := get_provider_fixture_overrides(
+            metafunc.config,
+            {
+                "inference": INFERENCE_FIXTURES,
+            },
+        ):
+            fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
+        metafunc.parametrize("inference_stack", fixtures, indirect=True)
```
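With this hunk, a command-line override such as `--providers inference=nvidia` (as in the test plan above) flows through `get_provider_fixture_overrides`, so `inference_stack` is parametrized with just the requested provider rather than every entry in `INFERENCE_FIXTURES`.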
---

```diff
@@ -18,6 +18,7 @@ from llama_stack.providers.inline.inference.meta_reference import (
 from llama_stack.providers.remote.inference.bedrock import BedrockConfig
 
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
+from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
```
```diff
@@ -142,6 +143,19 @@ def inference_bedrock() -> ProviderFixture:
     )
 
 
+@pytest.fixture(scope="session")
+def inference_nvidia() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="nvidia",
+                provider_type="remote::nvidia",
+                config=NVIDIAConfig().model_dump(),
+            )
+        ],
+    )
+
+
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.
 
```
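Note that `NVIDIAConfig()` is instantiated with no arguments, so the fixture presumably relies on the config's defaults: the hosted integrate.api.nvidia.com endpoint and the `NVIDIA_API_KEY` value supplied via `--env` in the test plan. Running against a downloaded NIM container would instead require overriding the default URL.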
```diff
@@ -175,6 +189,7 @@ INFERENCE_FIXTURES = [
     "vllm_remote",
     "remote",
     "bedrock",
+    "nvidia",
 ]
 
 
```

---
```diff
@@ -198,6 +198,7 @@ class TestInference:
             "remote::fireworks",
             "remote::tgi",
             "remote::together",
+            "remote::nvidia",
         ):
             pytest.skip("Other inference providers don't support structured output yet")
 
```
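This adds `remote::nvidia` to the short allow-list of providers that run the structured output test, matching the feature list above.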
```diff
@@ -361,7 +362,10 @@ class TestInference:
             for chunk in grouped[ChatCompletionResponseEventType.progress]
         )
         first = grouped[ChatCompletionResponseEventType.progress][0]
-        assert first.event.delta.parse_status == ToolCallParseStatus.started
+        if not isinstance(
+            first.event.delta.content, ToolCall
+        ):  # first chunk may contain entire call
+            assert first.event.delta.parse_status == ToolCallParseStatus.started
 
         last = grouped[ChatCompletionResponseEventType.progress][-1]
         # assert last.event.stop_reason == expected_stop_reason
```
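The relaxed assertion accounts for providers whose first streamed progress chunk already carries a complete `ToolCall`; in that case there is no separate `started` parse status to observe, per the inline comment.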