mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 00:22:36 +00:00
enable streaming support, use openai-python instead of httpx
This commit is contained in:
parent
2dd8c4bcb6
commit
dbe665ed19
7 changed files with 1037 additions and 341 deletions
|
|
@@ -5,9 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from llama_models.datatypes import SamplingParams
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
InterleavedTextMedia,
|
||||
|
|
@@ -17,6 +16,7 @@ from llama_models.llama3.api.datatypes import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.sku_list import CoreModelId
|
||||
from openai import APIConnectionError, AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
|
|
@@ -32,7 +32,12 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
|
||||
from ._config import NVIDIAConfig
|
||||
from ._utils import check_health, convert_chat_completion_request, parse_completion
|
||||
from ._openai_utils import (
|
||||
convert_chat_completion_request,
|
||||
convert_openai_chat_completion_choice,
|
||||
convert_openai_chat_completion_stream,
|
||||
)
|
||||
from ._utils import check_health
|
||||
|
||||
SUPPORTED_MODELS: Dict[CoreModelId, str] = {
|
||||
CoreModelId.llama3_8b_instruct: "meta/llama3-8b-instruct",
|
||||
|
|
@@ -71,17 +76,12 @@ class NVIDIAInferenceAdapter(Inference):
|
|||
# )
|
||||
|
||||
self._config = config
|
||||
|
||||
@property
|
||||
def _headers(self) -> dict:
|
||||
return {
|
||||
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
|
||||
**(
|
||||
{b"Authorization": f"Bearer {self._config.api_key}"}
|
||||
if self._config.api_key
|
||||
else {}
|
||||
),
|
||||
}
|
||||
# make sure the client lives longer than any async calls
|
||||
self._client = AsyncOpenAI(
|
||||
base_url=f"{self._config.base_url}/v1",
|
||||
api_key=self._config.api_key or "NO KEY",
|
||||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
# TODO(mf): filter by available models
|
||||
|
|
@@ -98,7 +98,7 @@ class NVIDIAInferenceAdapter(Inference):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def embeddings(
|
||||
|
|
@@ -121,56 +121,37 @@ class NVIDIAInferenceAdapter(Inference):
|
|||
] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
|
||||
) -> Union[
|
||||
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
||||
]:
|
||||
if tool_prompt_format:
|
||||
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
|
||||
|
||||
if stream:
|
||||
raise ValueError("Streamed completions are not supported")
|
||||
|
||||
await check_health(self._config) # this raises errors
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model=SUPPORTED_MODELS[CoreModelId(model)],
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
request = convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=SUPPORTED_MODELS[CoreModelId(model)],
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
),
|
||||
n=1,
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._config.timeout) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{self._config.base_url}/v1/chat/completions",
|
||||
headers=self._headers,
|
||||
json=convert_chat_completion_request(request, n=1),
|
||||
)
|
||||
except httpx.ReadTimeout as e:
|
||||
raise TimeoutError(
|
||||
f"Request timed out. timeout set to {self._config.timeout}. Use `llama stack configure ...` to adjust it."
|
||||
) from e
|
||||
|
||||
if response.status_code == 401:
|
||||
raise PermissionError(
|
||||
"Unauthorized. Please check your API key, reconfigure, and try again."
|
||||
)
|
||||
|
||||
if response.status_code == 400:
|
||||
raise ValueError(
|
||||
f"Bad request. Please check the request and try again. Detail: {response.text}"
|
||||
)
|
||||
|
||||
if response.status_code == 404:
|
||||
raise ValueError(
|
||||
"Model not found. Please check the model name and try again."
|
||||
)
|
||||
|
||||
assert (
|
||||
response.status_code == 200
|
||||
), f"Failed to get completion: {response.text}"
|
||||
try:
|
||||
response = await self._client.chat.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(
|
||||
f"Failed to connect to NVIDIA NIM at {self._config.base_url}: {e}"
|
||||
) from e
|
||||
|
||||
if stream:
|
||||
return convert_openai_chat_completion_stream(response)
|
||||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return parse_completion(response.json()["choices"][0])
|
||||
return convert_openai_chat_completion_choice(response.choices[0])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue