add completion api support to nvidia inference provider (#533)

# What does this PR do?

add the completion api to the nvidia inference provider


## Test Plan

while running the meta/llama-3.1-8b-instruct NIM from
https://build.nvidia.com/meta/llama-3_1-8b-instruct?snippet_tab=Docker

```
➜ pytest -s -v --providers inference=nvidia llama_stack/providers/tests/inference/ --env NVIDIA_BASE_URL=http://localhost:8000 -k test_completion --inference-model Llama3.1-8B-Instruct
=============================================== test session starts ===============================================
platform linux -- Python 3.10.15, pytest-8.3.3, pluggy-1.5.0 -- /home/matt/.conda/envs/stack/bin/python
cachedir: .pytest_cache
rootdir: /home/matt/Documents/Repositories/meta-llama/llama-stack
configfile: pyproject.toml
plugins: anyio-4.6.2.post1, asyncio-0.24.0, httpx-0.34.0
asyncio: mode=strict, default_loop_scope=None
collected 20 items / 18 deselected / 2 selected                                                                             

llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[-nvidia] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[-nvidia] SKIPPED

============================= 1 passed, 1 skipped, 18 deselected, 6 warnings in 5.40s =============================
```

the structured output functionality works but the accuracy fails

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.
This commit is contained in:
Matthew Farrellee 2024-12-11 13:08:38 -05:00 committed by GitHub
parent 07c72c4256
commit b52df5fe5b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 208 additions and 7 deletions

View file

@ -9,6 +9,7 @@ from typing import AsyncIterator, List, Optional, Union
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import (
ImageMedia,
InterleavedTextMedia,
Message,
ToolChoice,
@ -22,6 +23,7 @@ from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
@ -37,8 +39,11 @@ from llama_stack.providers.utils.inference.model_registry import (
from . import NVIDIAConfig
from .openai_utils import (
convert_chat_completion_request,
convert_completion_request,
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
convert_openai_completion_choice,
convert_openai_completion_stream,
)
from .utils import _is_nvidia_hosted, check_health
@ -115,7 +120,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
timeout=self._config.timeout,
)
def completion(
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
@ -124,7 +129,38 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
raise NotImplementedError()
if isinstance(content, ImageMedia) or (
isinstance(content, list)
and any(isinstance(c, ImageMedia) for c in content)
):
raise NotImplementedError("ImageMedia is not supported")
await check_health(self._config) # this raises errors
request = convert_completion_request(
request=CompletionRequest(
model=self.get_provider_model_id(model_id),
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
),
n=1,
)
try:
response = await self._client.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(
f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
) from e
if stream:
return convert_openai_completion_stream(response)
else:
# we pass n=1 to get only one completion
return convert_openai_completion_choice(response.choices[0])
async def embeddings(
self,

View file

@ -17,7 +17,6 @@ from llama_models.llama3.api.datatypes import (
ToolDefinition,
)
from openai import AsyncStream
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
ChatCompletionChunk as OpenAIChatCompletionChunk,
@ -31,10 +30,11 @@ from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as OpenAIFunction,
)
from openai.types.completion import Completion as OpenAICompletion
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
from llama_stack.apis.inference import (
ChatCompletionRequest,
@ -42,6 +42,9 @@ from llama_stack.apis.inference import (
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
JsonSchemaResponseFormat,
Message,
SystemMessage,
@ -579,3 +582,165 @@ async def convert_openai_chat_completion_stream(
stop_reason=stop_reason,
)
)
def convert_completion_request(
request: CompletionRequest,
n: int = 1,
) -> dict:
"""
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
"""
# model -> model
# prompt -> prompt
# sampling_params TODO(mattf): review strategy
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
# strategy=top_k -> nvext.top_k = top_k
# temperature -> temperature
# top_p -> top_p
# top_k -> nvext.top_k
# max_tokens -> max_tokens
# repetition_penalty -> nvext.repetition_penalty
# response_format -> nvext.guided_json
# stream -> stream
# logprobs.top_k -> logprobs
nvext = {}
payload: Dict[str, Any] = dict(
model=request.model,
prompt=request.content,
stream=request.stream,
extra_body=dict(nvext=nvext),
extra_headers={
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
},
n=n,
)
if request.response_format:
# this is not openai compliant, it is a nim extension
nvext.update(guided_json=request.response_format.json_schema)
if request.logprobs:
payload.update(logprobs=request.logprobs.top_k)
if request.sampling_params:
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
if request.sampling_params.max_tokens:
payload.update(max_tokens=request.sampling_params.max_tokens)
if request.sampling_params.strategy == "top_p":
nvext.update(top_k=-1)
payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k":
if (
request.sampling_params.top_k != -1
and request.sampling_params.top_k < 1
):
warnings.warn("top_k must be -1 or >= 1")
nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy":
nvext.update(top_k=-1)
payload.update(temperature=request.sampling_params.temperature)
return payload
def _convert_openai_completion_logprobs(
logprobs: Optional[OpenAICompletionLogprobs],
) -> Optional[List[TokenLogProbs]]:
"""
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
OpenAI CompletionLogprobs:
text_offset: Optional[List[int]]
token_logprobs: Optional[List[float]]
tokens: Optional[List[str]]
top_logprobs: Optional[List[Dict[str, float]]]
->
TokenLogProbs:
logprobs_by_token: Dict[str, float]
- token, logprob
"""
if not logprobs:
return None
return [
TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs
]
def convert_openai_completion_choice(
choice: OpenAIChoice,
) -> CompletionResponse:
"""
Convert an OpenAI Completion Choice into a CompletionResponse.
OpenAI Completion Choice:
text: str
finish_reason: str
logprobs: Optional[ChoiceLogprobs]
->
CompletionResponse:
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]]
CompletionMessage:
role: Literal["assistant"]
content: str | ImageMedia | List[str | ImageMedia]
stop_reason: StopReason
tool_calls: List[ToolCall]
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
return CompletionResponse(
content=choice.text,
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
)
async def convert_openai_completion_stream(
stream: AsyncStream[OpenAICompletion],
) -> AsyncGenerator[CompletionResponse, None]:
"""
Convert a stream of OpenAI Completions into a stream
of ChatCompletionResponseStreamChunks.
OpenAI Completion:
id: str
choices: List[OpenAICompletionChoice]
created: int
model: str
system_fingerprint: Optional[str]
usage: Optional[OpenAICompletionUsage]
OpenAI CompletionChoice:
finish_reason: str
index: int
logprobs: Optional[OpenAILogprobs]
text: str
->
CompletionResponseStreamChunk:
delta: str
stop_reason: Optional[StopReason]
logprobs: Optional[List[TokenLogProbs]]
"""
async for chunk in stream:
choice = chunk.choices[0]
yield CompletionResponseStreamChunk(
delta=choice.text,
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
)

View file

@ -94,6 +94,7 @@ class TestInference:
"remote::tgi",
"remote::together",
"remote::fireworks",
"remote::nvidia",
"remote::cerebras",
):
pytest.skip("Other inference providers don't support completion() yet")
@ -129,9 +130,7 @@ class TestInference:
@pytest.mark.asyncio
@pytest.mark.skip("This test is not quite robust")
async def test_completions_structured_output(
self, inference_model, inference_stack
):
async def test_completion_structured_output(self, inference_model, inference_stack):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
@ -140,6 +139,7 @@ class TestInference:
"remote::tgi",
"remote::together",
"remote::fireworks",
"remote::nvidia",
"remote::vllm",
"remote::cerebras",
):