From b52df5fe5b618d74afd2e49ec13cf623d59f5c8a Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Wed, 11 Dec 2024 13:08:38 -0500
Subject: [PATCH] add completion api support to nvidia inference provider
 (#533)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# What does this PR do?

add the completion api to the nvidia inference provider

## Test Plan

while running the meta/llama-3.1-8b-instruct NIM from
https://build.nvidia.com/meta/llama-3_1-8b-instruct?snippet_tab=Docker

```
➜ pytest -s -v --providers inference=nvidia llama_stack/providers/tests/inference/ --env NVIDIA_BASE_URL=http://localhost:8000 -k test_completion --inference-model Llama3.1-8B-Instruct
=============================================== test session starts ===============================================
platform linux -- Python 3.10.15, pytest-8.3.3, pluggy-1.5.0 -- /home/matt/.conda/envs/stack/bin/python
cachedir: .pytest_cache
rootdir: /home/matt/Documents/Repositories/meta-llama/llama-stack
configfile: pyproject.toml
plugins: anyio-4.6.2.post1, asyncio-0.24.0, httpx-0.34.0
asyncio: mode=strict, default_loop_scope=None
collected 20 items / 18 deselected / 2 selected

llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[-nvidia] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[-nvidia] SKIPPED

============================= 1 passed, 1 skipped, 18 deselected, 6 warnings in 5.40s =============================
```

the structured output functionality works, but the accuracy check fails

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.
---
 .../remote/inference/nvidia/nvidia.py         |  40 ++++-
 .../remote/inference/nvidia/openai_utils.py   | 169 +++++++++++++++++-
 .../tests/inference/test_text_inference.py    |   6 +-
 3 files changed, 208 insertions(+), 7 deletions(-)
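As a quick illustration of the request mapping implemented in
`convert_completion_request` below, here is a minimal sketch; the model name
and prompt are placeholders, and the other `CompletionRequest` fields are
assumed to take their defaults:

```python
# Sketch only: exercise the request conversion added in this patch.
from llama_stack.apis.inference import CompletionRequest
from llama_stack.providers.remote.inference.nvidia.openai_utils import (
    convert_completion_request,
)

request = CompletionRequest(
    model="meta/llama-3.1-8b-instruct",  # placeholder provider model id
    content="The capital of France is",
)

# n=1 mirrors what NVIDIAInferenceAdapter.completion() passes, so exactly
# one choice comes back from the NIM.
payload = convert_completion_request(request, n=1)

# With default sampling params the greedy branch applies, so the payload
# looks roughly like (values elided where they depend on defaults):
# {
#     "model": "meta/llama-3.1-8b-instruct",
#     "prompt": "The capital of France is",
#     "stream": False,
#     "n": 1,
#     "temperature": ...,
#     "extra_body": {"nvext": {"repetition_penalty": ..., "top_k": -1}},
#     "extra_headers": {b"User-Agent": b"llama-stack: nvidia-inference-adapter"},
# }
print(payload)
```

These are exactly the kwargs the adapter forwards to
`self._client.completions.create(**request)`.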
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index f38aa7112..a97882497 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -9,6 +9,7 @@ from typing import AsyncIterator, List, Optional, Union
 from llama_models.datatypes import SamplingParams
 from llama_models.llama3.api.datatypes import (
+    ImageMedia,
     InterleavedTextMedia,
     Message,
     ToolChoice,
@@ -22,6 +23,7 @@ from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
+    CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
@@ -37,8 +39,11 @@ from llama_stack.providers.utils.inference.model_registry import (
 from . import NVIDIAConfig
 from .openai_utils import (
     convert_chat_completion_request,
+    convert_completion_request,
     convert_openai_chat_completion_choice,
     convert_openai_chat_completion_stream,
+    convert_openai_completion_choice,
+    convert_openai_completion_stream,
 )
 from .utils import _is_nvidia_hosted, check_health
@@ -115,7 +120,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             timeout=self._config.timeout,
         )

-    def completion(
+    async def completion(
         self,
         model_id: str,
         content: InterleavedTextMedia,
@@ -124,7 +129,38 @@
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
-        raise NotImplementedError()
+        if isinstance(content, ImageMedia) or (
+            isinstance(content, list)
+            and any(isinstance(c, ImageMedia) for c in content)
+        ):
+            raise NotImplementedError("ImageMedia is not supported")
+
+        await check_health(self._config)  # this raises errors
+
+        request = convert_completion_request(
+            request=CompletionRequest(
+                model=self.get_provider_model_id(model_id),
+                content=content,
+                sampling_params=sampling_params,
+                response_format=response_format,
+                stream=stream,
+                logprobs=logprobs,
+            ),
+            n=1,
+        )
+
+        try:
+            response = await self._client.completions.create(**request)
+        except APIConnectionError as e:
+            raise ConnectionError(
+                f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
+            ) from e
+
+        if stream:
+            return convert_openai_completion_stream(response)
+        else:
+            # we pass n=1 to get only one completion
+            return convert_openai_completion_choice(response.choices[0])

     async def embeddings(
         self,
diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
index b74aa05da..ba8ff0fa4 100644
--- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py
+++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
@@ -17,7 +17,6 @@ from llama_models.llama3.api.datatypes import (
     ToolDefinition,
 )
 from openai import AsyncStream
-
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
     ChatCompletionChunk as OpenAIChatCompletionChunk,
@@ -31,10 +30,11 @@ from openai.types.chat.chat_completion import (
     Choice as OpenAIChoice,
     ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
 )
-
 from openai.types.chat.chat_completion_message_tool_call_param import (
     Function as OpenAIFunction,
 )
+from openai.types.completion import Completion as OpenAICompletion
+from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs

 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -42,6 +42,9 @@ from llama_stack.apis.inference import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
     JsonSchemaResponseFormat,
     Message,
     SystemMessage,
@@ -579,3 +582,165 @@
                 stop_reason=stop_reason,
             )
         )
+
+
+def convert_completion_request(
+    request: CompletionRequest,
+    n: int = 1,
+) -> dict:
+    """
+    Convert a CompletionRequest to an OpenAI API-compatible dictionary.
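+
+    For example (illustrative values only):
+        CompletionRequest(model="m", content="hi", stream=False)
+          -> {"model": "m", "prompt": "hi", "stream": False, "n": 1,
+              "extra_body": {"nvext": {...}}, "extra_headers": {...}}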
+ """ + # model -> model + # prompt -> prompt + # sampling_params TODO(mattf): review strategy + # strategy=greedy -> nvext.top_k = -1, temperature = temperature + # strategy=top_p -> nvext.top_k = -1, top_p = top_p + # strategy=top_k -> nvext.top_k = top_k + # temperature -> temperature + # top_p -> top_p + # top_k -> nvext.top_k + # max_tokens -> max_tokens + # repetition_penalty -> nvext.repetition_penalty + # response_format -> nvext.guided_json + # stream -> stream + # logprobs.top_k -> logprobs + + nvext = {} + payload: Dict[str, Any] = dict( + model=request.model, + prompt=request.content, + stream=request.stream, + extra_body=dict(nvext=nvext), + extra_headers={ + b"User-Agent": b"llama-stack: nvidia-inference-adapter", + }, + n=n, + ) + + if request.response_format: + # this is not openai compliant, it is a nim extension + nvext.update(guided_json=request.response_format.json_schema) + + if request.logprobs: + payload.update(logprobs=request.logprobs.top_k) + + if request.sampling_params: + nvext.update(repetition_penalty=request.sampling_params.repetition_penalty) + + if request.sampling_params.max_tokens: + payload.update(max_tokens=request.sampling_params.max_tokens) + + if request.sampling_params.strategy == "top_p": + nvext.update(top_k=-1) + payload.update(top_p=request.sampling_params.top_p) + elif request.sampling_params.strategy == "top_k": + if ( + request.sampling_params.top_k != -1 + and request.sampling_params.top_k < 1 + ): + warnings.warn("top_k must be -1 or >= 1") + nvext.update(top_k=request.sampling_params.top_k) + elif request.sampling_params.strategy == "greedy": + nvext.update(top_k=-1) + payload.update(temperature=request.sampling_params.temperature) + + return payload + + +def _convert_openai_completion_logprobs( + logprobs: Optional[OpenAICompletionLogprobs], +) -> Optional[List[TokenLogProbs]]: + """ + Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs. + + OpenAI CompletionLogprobs: + text_offset: Optional[List[int]] + token_logprobs: Optional[List[float]] + tokens: Optional[List[str]] + top_logprobs: Optional[List[Dict[str, float]]] + + -> + + TokenLogProbs: + logprobs_by_token: Dict[str, float] + - token, logprob + """ + if not logprobs: + return None + + return [ + TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs + ] + + +def convert_openai_completion_choice( + choice: OpenAIChoice, +) -> CompletionResponse: + """ + Convert an OpenAI Completion Choice into a CompletionResponse. + + OpenAI Completion Choice: + text: str + finish_reason: str + logprobs: Optional[ChoiceLogprobs] + + -> + + CompletionResponse: + completion_message: CompletionMessage + logprobs: Optional[List[TokenLogProbs]] + + CompletionMessage: + role: Literal["assistant"] + content: str | ImageMedia | List[str | ImageMedia] + stop_reason: StopReason + tool_calls: List[ToolCall] + + class StopReason(Enum): + end_of_turn = "end_of_turn" + end_of_message = "end_of_message" + out_of_tokens = "out_of_tokens" + """ + return CompletionResponse( + content=choice.text, + stop_reason=_convert_openai_finish_reason(choice.finish_reason), + logprobs=_convert_openai_completion_logprobs(choice.logprobs), + ) + + +async def convert_openai_completion_stream( + stream: AsyncStream[OpenAICompletion], +) -> AsyncGenerator[CompletionResponse, None]: + """ + Convert a stream of OpenAI Completions into a stream + of ChatCompletionResponseStreamChunks. 
+
+    OpenAI Completion:
+        id: str
+        choices: List[OpenAICompletionChoice]
+        created: int
+        model: str
+        system_fingerprint: Optional[str]
+        usage: Optional[OpenAICompletionUsage]
+
+    OpenAI CompletionChoice:
+        finish_reason: str
+        index: int
+        logprobs: Optional[OpenAICompletionLogprobs]
+        text: str
+
+    ->
+
+    CompletionResponseStreamChunk:
+        delta: str
+        stop_reason: Optional[StopReason]
+        logprobs: Optional[List[TokenLogProbs]]
+    """
+    async for chunk in stream:
+        choice = chunk.choices[0]
+        yield CompletionResponseStreamChunk(
+            delta=choice.text,
+            stop_reason=_convert_openai_finish_reason(choice.finish_reason),
+            logprobs=_convert_openai_completion_logprobs(choice.logprobs),
+        )
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index b84761219..741b61c5c 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -94,6 +94,7 @@
             "remote::tgi",
             "remote::together",
             "remote::fireworks",
+            "remote::nvidia",
             "remote::cerebras",
         ):
             pytest.skip("Other inference providers don't support completion() yet")
@@ -129,9 +130,7 @@

     @pytest.mark.asyncio
     @pytest.mark.skip("This test is not quite robust")
-    async def test_completions_structured_output(
-        self, inference_model, inference_stack
-    ):
+    async def test_completion_structured_output(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack

         provider = inference_impl.routing_table.get_provider_impl(inference_model)
@@ -140,6 +139,7 @@
             "remote::tgi",
             "remote::together",
             "remote::fireworks",
+            "remote::nvidia",
             "remote::vllm",
             "remote::cerebras",
         ):
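
For completeness, a minimal sketch of consuming the streaming path added
above. This assumes an already-constructed `NVIDIAInferenceAdapter` with the
model alias registered; `stream_demo` is a hypothetical helper, not part of
the patch:

```python
# Sketch only: completion() with stream=True returns an async iterator of
# CompletionResponseStreamChunk, so the text arrives as chunk.delta pieces.
async def stream_demo(adapter) -> None:
    chunks = await adapter.completion(
        model_id="Llama3.1-8B-Instruct",  # placeholder alias
        content="The capital of France is",
        stream=True,
    )
    async for chunk in chunks:
        print(chunk.delta, end="")
```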