mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 17:00:09 +00:00
add completion api support
This commit is contained in:
parent
d3956a1d22
commit
6d41a93188
3 changed files with 208 additions and 7 deletions
|
|
@ -9,6 +9,7 @@ from typing import AsyncIterator, List, Optional, Union
|
|||
|
||||
from llama_models.datatypes import SamplingParams
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
ImageMedia,
|
||||
InterleavedTextMedia,
|
||||
Message,
|
||||
ToolChoice,
|
||||
|
|
@ -22,6 +23,7 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
|
|
@ -37,8 +39,11 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
from . import NVIDIAConfig
|
||||
from .openai_utils import (
|
||||
convert_chat_completion_request,
|
||||
convert_completion_request,
|
||||
convert_openai_chat_completion_choice,
|
||||
convert_openai_chat_completion_stream,
|
||||
convert_openai_completion_choice,
|
||||
convert_openai_completion_stream,
|
||||
)
|
||||
from .utils import _is_nvidia_hosted, check_health
|
||||
|
||||
|
|
@ -115,7 +120,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
def completion(
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
|
@ -124,7 +129,38 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
raise NotImplementedError()
|
||||
if isinstance(content, ImageMedia) or (
|
||||
isinstance(content, list)
|
||||
and any(isinstance(c, ImageMedia) for c in content)
|
||||
):
|
||||
raise NotImplementedError("ImageMedia is not supported")
|
||||
|
||||
await check_health(self._config) # this raises errors
|
||||
|
||||
request = convert_completion_request(
|
||||
request=CompletionRequest(
|
||||
model=self.get_provider_model_id(model_id),
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
),
|
||||
n=1,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._client.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(
|
||||
f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
|
||||
) from e
|
||||
|
||||
if stream:
|
||||
return convert_openai_completion_stream(response)
|
||||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return convert_openai_completion_choice(response.choices[0])
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue