enable streaming support, use openai-python instead of httpx
parent 2dd8c4bcb6, commit dbe665ed19
7 changed files with 1037 additions and 341 deletions
@@ -5,9 +5,8 @@
# the root directory of this source tree.

import warnings
from typing import Dict, List, Optional, Union
from typing import AsyncIterator, Dict, List, Optional, Union

import httpx
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import (
    InterleavedTextMedia,
@@ -17,6 +16,7 @@ from llama_models.llama3.api.datatypes import (
    ToolPromptFormat,
)
from llama_models.sku_list import CoreModelId
from openai import APIConnectionError, AsyncOpenAI

from llama_stack.apis.inference import (
    ChatCompletionRequest,
@@ -32,7 +32,12 @@ from llama_stack.apis.inference import (
)

from ._config import NVIDIAConfig
from ._utils import check_health, convert_chat_completion_request, parse_completion
from ._openai_utils import (
    convert_chat_completion_request,
    convert_openai_chat_completion_choice,
    convert_openai_chat_completion_stream,
)
from ._utils import check_health

SUPPORTED_MODELS: Dict[CoreModelId, str] = {
    CoreModelId.llama3_8b_instruct: "meta/llama3-8b-instruct",
@@ -71,17 +76,12 @@ class NVIDIAInferenceAdapter(Inference):
        # )

        self._config = config

    @property
    def _headers(self) -> dict:
        return {
            b"User-Agent": b"llama-stack: nvidia-inference-adapter",
            **(
                {b"Authorization": f"Bearer {self._config.api_key}"}
                if self._config.api_key
                else {}
            ),
        }
        # make sure the client lives longer than any async calls
        self._client = AsyncOpenAI(
            base_url=f"{self._config.base_url}/v1",
            api_key=self._config.api_key or "NO KEY",
            timeout=self._config.timeout,
        )

    async def list_models(self) -> List[ModelDef]:
        # TODO(mf): filter by available models
@@ -98,7 +98,7 @@ class NVIDIAInferenceAdapter(Inference):
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        raise NotImplementedError()

    async def embeddings(
@@ -121,56 +121,37 @@ class NVIDIAInferenceAdapter(Inference):
        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        if tool_prompt_format:
            warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")

        if stream:
            raise ValueError("Streamed completions are not supported")

        await check_health(self._config)  # this raises errors

        request = ChatCompletionRequest(
            model=SUPPORTED_MODELS[CoreModelId(model)],
            messages=messages,
            sampling_params=sampling_params,
            tools=tools,
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        request = convert_chat_completion_request(
            request=ChatCompletionRequest(
                model=SUPPORTED_MODELS[CoreModelId(model)],
                messages=messages,
                sampling_params=sampling_params,
                tools=tools,
                tool_choice=tool_choice,
                tool_prompt_format=tool_prompt_format,
                stream=stream,
                logprobs=logprobs,
            ),
            n=1,
        )

        async with httpx.AsyncClient(timeout=self._config.timeout) as client:
            try:
                response = await client.post(
                    f"{self._config.base_url}/v1/chat/completions",
                    headers=self._headers,
                    json=convert_chat_completion_request(request, n=1),
                )
            except httpx.ReadTimeout as e:
                raise TimeoutError(
                    f"Request timed out. timeout set to {self._config.timeout}. Use `llama stack configure ...` to adjust it."
                ) from e

            if response.status_code == 401:
                raise PermissionError(
                    "Unauthorized. Please check your API key, reconfigure, and try again."
                )

            if response.status_code == 400:
                raise ValueError(
                    f"Bad request. Please check the request and try again. Detail: {response.text}"
                )

            if response.status_code == 404:
                raise ValueError(
                    "Model not found. Please check the model name and try again."
                )

            assert (
                response.status_code == 200
            ), f"Failed to get completion: {response.text}"
        try:
            response = await self._client.chat.completions.create(**request)
        except APIConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to NVIDIA NIM at {self._config.base_url}: {e}"
            ) from e

        if stream:
            return convert_openai_chat_completion_stream(response)
        else:
            # we pass n=1 to get only one completion
            return parse_completion(response.json()["choices"][0])
            return convert_openai_chat_completion_choice(response.choices[0])
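Not part of the diff: a minimal usage sketch of the reworked adapter, assuming an already-configured NVIDIAInferenceAdapter instance (adapter) and a supported model id; names follow the functional tests below. With stream=False a single ChatCompletionResponse comes back, with stream=True an AsyncIterator of ChatCompletionResponseStreamChunk does.

from llama_stack.apis.inference import UserMessage

async def demo(adapter, model: str) -> None:
    # non-streaming: one complete ChatCompletionResponse
    response = await adapter.chat_completion(
        model=model,
        messages=[UserMessage(content="How are you?")],
        stream=False,
    )
    print(response.completion_message.content)

    # streaming: iterate start / progress / complete events as they arrive
    stream = await adapter.chat_completion(
        model=model,
        messages=[UserMessage(content="How are you?")],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.event.event_type, chunk.event.delta)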
430  llama_stack/providers/adapters/inference/nvidia/_openai_utils.py  (new file)
@@ -0,0 +1,430 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import warnings
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional

from llama_models.llama3.api.datatypes import (
    CompletionMessage,
    StopReason,
    TokenLogProbs,
    ToolCall,
)
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk as OpenAIChatCompletionChunk
from openai.types.chat.chat_completion import (
    Choice as OpenAIChoice,
    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_message_tool_call import (
    ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall,
)

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    Message,
    ToolCallDelta,
    ToolCallParseStatus,
)


def _convert_message(message: Message) -> Dict:
    """
    Convert a Message to an OpenAI API-compatible dictionary.
    """
    out_dict = message.dict()
    # Llama Stack uses role="ipython" for tool call messages, OpenAI uses "tool"
    if out_dict["role"] == "ipython":
        out_dict.update(role="tool")

    if "stop_reason" in out_dict:
        out_dict.update(stop_reason=out_dict["stop_reason"].value)

    # TODO(mf): tool_calls

    return out_dict


def convert_chat_completion_request(
    request: ChatCompletionRequest,
    n: int = 1,
) -> dict:
    """
    Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
    """
    # model -> model
    # messages -> messages
    # sampling_params  TODO(mattf): review strategy
    #  strategy=greedy -> nvext.top_k = -1, temperature = temperature
    #  strategy=top_p -> nvext.top_k = -1, top_p = top_p
    #  strategy=top_k -> nvext.top_k = top_k
    #  temperature -> temperature
    #  top_p -> top_p
    #  top_k -> nvext.top_k
    #  max_tokens -> max_tokens
    #  repetition_penalty -> nvext.repetition_penalty
    # tools -> tools
    # tool_choice ("auto", "required") -> tool_choice
    # tool_prompt_format -> TBD
    # stream -> stream
    # logprobs -> logprobs

    nvext = {}
    payload: Dict[str, Any] = dict(
        model=request.model,
        messages=[_convert_message(message) for message in request.messages],
        stream=request.stream,
        n=n,
        extra_body=dict(nvext=nvext),
        extra_headers={
            b"User-Agent": b"llama-stack: nvidia-inference-adapter",
        },
    )

    if request.tools:
        payload.update(tools=request.tools)
        if request.tool_choice:
            payload.update(
                tool_choice=request.tool_choice.value
            )  # we cannot include tool_choice w/o tools, server will complain

    if request.logprobs:
        payload.update(logprobs=True)
        payload.update(top_logprobs=request.logprobs.top_k)

    if request.sampling_params:
        nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)

        if request.sampling_params.max_tokens:
            payload.update(max_tokens=request.sampling_params.max_tokens)

        if request.sampling_params.strategy == "top_p":
            nvext.update(top_k=-1)
            payload.update(top_p=request.sampling_params.top_p)
        elif request.sampling_params.strategy == "top_k":
            if (
                request.sampling_params.top_k != -1
                and request.sampling_params.top_k < 1
            ):
                warnings.warn("top_k must be -1 or >= 1")
            nvext.update(top_k=request.sampling_params.top_k)
        elif request.sampling_params.strategy == "greedy":
            nvext.update(top_k=-1)
            payload.update(temperature=request.sampling_params.temperature)

    return payload


def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
    """
    Convert an OpenAI chat completion finish_reason to a StopReason.

    finish_reason: Literal["stop", "length", "tool_calls", ...]
        - stop: model hit a natural stop point or a provided stop sequence
        - length: maximum number of tokens specified in the request was reached
        - tool_calls: model called a tool

    ->

    class StopReason(Enum):
        end_of_turn = "end_of_turn"
        end_of_message = "end_of_message"
        out_of_tokens = "out_of_tokens"
    """

    # TODO(mf): are end_of_turn and end_of_message semantics correct?
    return {
        "stop": StopReason.end_of_turn,
        "length": StopReason.out_of_tokens,
        "tool_calls": StopReason.end_of_message,
    }.get(finish_reason, StopReason.end_of_turn)


def _convert_openai_tool_calls(
    tool_calls: List[OpenAIChatCompletionMessageToolCall],
) -> List[ToolCall]:
    """
    Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.

    OpenAI ChatCompletionMessageToolCall:
        id: str
        function: Function
        type: Literal["function"]

    OpenAI Function:
        arguments: str
        name: str

    ->

    ToolCall:
        call_id: str
        tool_name: str
        arguments: Dict[str, ...]
    """
    if not tool_calls:
        return []  # CompletionMessage tool_calls is not optional

    return [
        ToolCall(
            call_id=call.id,
            tool_name=call.function.name,
            arguments=json.loads(call.function.arguments),
        )
        for call in tool_calls
    ]


def _convert_openai_logprobs(
    logprobs: OpenAIChoiceLogprobs,
) -> Optional[List[TokenLogProbs]]:
    """
    Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.

    OpenAI ChoiceLogprobs:
        content: Optional[List[ChatCompletionTokenLogprob]]

    OpenAI ChatCompletionTokenLogprob:
        token: str
        logprob: float
        top_logprobs: List[TopLogprob]

    OpenAI TopLogprob:
        token: str
        logprob: float

    ->

    TokenLogProbs:
        logprobs_by_token: Dict[str, float]
            - token, logprob
    """
    if not logprobs:
        return None

    return [
        TokenLogProbs(
            logprobs_by_token={
                logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
            }
        )
        for content in logprobs.content
    ]


def convert_openai_chat_completion_choice(
    choice: OpenAIChoice,
) -> ChatCompletionResponse:
    """
    Convert an OpenAI Choice into a ChatCompletionResponse.

    OpenAI Choice:
        message: ChatCompletionMessage
        finish_reason: str
        logprobs: Optional[ChoiceLogprobs]

    OpenAI ChatCompletionMessage:
        role: Literal["assistant"]
        content: Optional[str]
        tool_calls: Optional[List[ChatCompletionMessageToolCall]]

    ->

    ChatCompletionResponse:
        completion_message: CompletionMessage
        logprobs: Optional[List[TokenLogProbs]]

    CompletionMessage:
        role: Literal["assistant"]
        content: str | ImageMedia | List[str | ImageMedia]
        stop_reason: StopReason
        tool_calls: List[ToolCall]

    class StopReason(Enum):
        end_of_turn = "end_of_turn"
        end_of_message = "end_of_message"
        out_of_tokens = "out_of_tokens"
    """
    assert (
        hasattr(choice, "message") and choice.message
    ), "error in server response: message not found"
    assert (
        hasattr(choice, "finish_reason") and choice.finish_reason
    ), "error in server response: finish_reason not found"

    return ChatCompletionResponse(
        completion_message=CompletionMessage(
            content=choice.message.content
            or "",  # CompletionMessage content is not optional
            stop_reason=_convert_openai_finish_reason(choice.finish_reason),
            tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
        ),
        logprobs=_convert_openai_logprobs(choice.logprobs),
    )


async def convert_openai_chat_completion_stream(
    stream: AsyncStream[OpenAIChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
    """
    Convert a stream of OpenAI chat completion chunks into a stream
    of ChatCompletionResponseStreamChunk.

    OpenAI ChatCompletionChunk:
        choices: List[Choice]

    OpenAI Choice:  # different from the non-streamed Choice
        delta: ChoiceDelta
        finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
        logprobs: Optional[ChoiceLogprobs]

    OpenAI ChoiceDelta:
        content: Optional[str]
        role: Optional[Literal["system", "user", "assistant", "tool"]]
        tool_calls: Optional[List[ChoiceDeltaToolCall]]

    OpenAI ChoiceDeltaToolCall:
        index: int
        id: Optional[str]
        function: Optional[ChoiceDeltaToolCallFunction]
        type: Optional[Literal["function"]]

    OpenAI ChoiceDeltaToolCallFunction:
        name: Optional[str]
        arguments: Optional[str]

    ->

    ChatCompletionResponseStreamChunk:
        event: ChatCompletionResponseEvent

    ChatCompletionResponseEvent:
        event_type: ChatCompletionResponseEventType
        delta: Union[str, ToolCallDelta]
        logprobs: Optional[List[TokenLogProbs]]
        stop_reason: Optional[StopReason]

    ChatCompletionResponseEventType:
        start = "start"
        progress = "progress"
        complete = "complete"

    ToolCallDelta:
        content: Union[str, ToolCall]
        parse_status: ToolCallParseStatus

    ToolCall:
        call_id: str
        tool_name: str
        arguments: str

    ToolCallParseStatus:
        started = "started"
        in_progress = "in_progress"
        failure = "failure"
        success = "success"

    TokenLogProbs:
        logprobs_by_token: Dict[str, float]
            - token, logprob

    StopReason:
        end_of_turn = "end_of_turn"
        end_of_message = "end_of_message"
        out_of_tokens = "out_of_tokens"
    """

    # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
    def _event_type_generator() -> (
        Generator[ChatCompletionResponseEventType, None, None]
    ):
        yield ChatCompletionResponseEventType.start
        while True:
            yield ChatCompletionResponseEventType.progress

    event_type = _event_type_generator()

    # we implement NIM specific semantics, the main difference from OpenAI
    # is that tool_calls are always produced as a complete call. there is no
    # intermediate / partial tool call streamed. because of this, we can
    # simplify the logic and not concern outselves with parse_status of
    # started/in_progress/failed. we can always assume success.
    #
    # a stream of ChatCompletionResponseStreamChunk consists of
    #  0. a start event
    #  1. zero or more progress events
    #   - each progress event has a delta
    #   - each progress event may have a stop_reason
    #   - each progress event may have logprobs
    #   - each progress event may have tool_calls
    #     if a progress event has tool_calls,
    #      it is fully formed and
    #      can be emitted with a parse_status of success
    #  2. a complete event

    stop_reason = None

    async for chunk in stream:
        choice = chunk.choices[0]  # assuming only one choice per chunk

        # we assume there's only one finish_reason in the stream
        stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason

        # if there's a tool call, emit an event for each tool in the list
        # if tool call and content, emit both separately

        if choice.delta.tool_calls:
            # the call may have content and a tool call. ChatCompletionResponseEvent
            # does not support both, so we emit the content first
            if choice.delta.content:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=next(event_type),
                        delta=choice.delta.content,
                        logprobs=_convert_openai_logprobs(choice.logprobs),
                    )
                )

            # it is possible to have parallel tool calls in stream, but
            # ChatCompletionResponseEvent only supports one per stream
            if len(choice.delta.tool_calls) > 1:
                warnings.warn(
                    "multiple tool calls found in a single delta, using the first, ignoring the rest"
                )

            # NIM only produces fully formed tool calls, so we can assume success
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=next(event_type),
                    delta=ToolCallDelta(
                        content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
                        parse_status=ToolCallParseStatus.success,
                    ),
                    logprobs=_convert_openai_logprobs(choice.logprobs),
                )
            )
        else:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=next(event_type),
                    delta=choice.delta.content or "",  # content is not optional
                    logprobs=_convert_openai_logprobs(choice.logprobs),
                )
            )

    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
            delta="",
            stop_reason=stop_reason,
        )
    )
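As a rough illustration (not part of the commit, values made up), a request using the top_k sampling strategy would be converted by convert_chat_completion_request into keyword arguments for AsyncOpenAI roughly like this, with the NIM-specific knobs tucked under extra_body["nvext"]:

# hypothetical input:
#   ChatCompletionRequest(
#       model="meta/llama3-8b-instruct",
#       messages=[UserMessage(content="How are you?")],
#       sampling_params=SamplingParams(strategy="top_k", top_k=40, max_tokens=64),
#       stream=True,
#   )
#
# approximate result of convert_chat_completion_request(request, n=1):
payload = {
    "model": "meta/llama3-8b-instruct",
    "messages": [{"role": "user", "content": "How are you?"}],  # message.dict() may carry more fields
    "stream": True,
    "n": 1,
    "max_tokens": 64,
    "extra_body": {"nvext": {"repetition_penalty": 1.0, "top_k": 40}},  # 1.0 assumes the SamplingParams default
    "extra_headers": {b"User-Agent": b"llama-stack: nvidia-inference-adapter"},
}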
@@ -4,43 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from typing import Any, Dict, List, Optional, Tuple
from typing import Tuple

import httpx
from llama_models.llama3.api.datatypes import (
    CompletionMessage,
    StopReason,
    TokenLogProbs,
    ToolCall,
)

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    Message,
)

from ._config import NVIDIAConfig


def convert_message(message: Message) -> dict:
    """
    Convert a Message to an OpenAI API-compatible dictionary.
    """
    out_dict = message.dict()
    # Llama Stack uses role="ipython" for tool call messages, OpenAI uses "tool"
    if out_dict["role"] == "ipython":
        out_dict.update(role="tool")

    if "stop_reason" in out_dict:
        out_dict.update(stop_reason=out_dict["stop_reason"].value)

    # TODO(mf): tool_calls

    return out_dict


async def _get_health(url: str) -> Tuple[bool, bool]:
    """
    Query {url}/v1/health/{live,ready} to check if the server is running and ready
@@ -78,251 +48,3 @@ async def check_health(config: NVIDIAConfig) -> None:
        # TODO(mf): should we wait for the server to be ready?
    except httpx.ConnectError as e:
        raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e


def convert_chat_completion_request(
    request: ChatCompletionRequest,
    n: int = 1,
) -> dict:
    """
    Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
    """
    # model -> model
    # messages -> messages
    # sampling_params  TODO(mattf): review strategy
    #  strategy=greedy -> nvext.top_k = -1, temperature = temperature
    #  strategy=top_p -> nvext.top_k = -1, top_p = top_p
    #  strategy=top_k -> nvext.top_k = top_k
    #  temperature -> temperature
    #  top_p -> top_p
    #  top_k -> nvext.top_k
    #  max_tokens -> max_tokens
    #  repetition_penalty -> nvext.repetition_penalty
    # tools -> tools
    # tool_choice ("auto", "required") -> tool_choice
    # tool_prompt_format -> TBD
    # stream -> stream
    # logprobs -> logprobs

    print(f"sampling_params: {request.sampling_params}")

    payload: Dict[str, Any] = dict(
        model=request.model,
        messages=[convert_message(message) for message in request.messages],
        stream=request.stream,
        nvext={},
        n=n,
    )
    nvext = payload["nvext"]

    if request.tools:
        payload.update(tools=request.tools)
        if request.tool_choice:
            payload.update(
                tool_choice=request.tool_choice.value
            )  # we cannot include tool_choice w/o tools, server will complain

    if request.logprobs:
        payload.update(logprobs=True)
        payload.update(top_logprobs=request.logprobs.top_k)

    if request.sampling_params:
        nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)

        if request.sampling_params.max_tokens:
            payload.update(max_tokens=request.sampling_params.max_tokens)

        if request.sampling_params.strategy == "top_p":
            nvext.update(top_k=-1)
            payload.update(top_p=request.sampling_params.top_p)
        elif request.sampling_params.strategy == "top_k":
            if (
                request.sampling_params.top_k != -1
                and request.sampling_params.top_k < 1
            ):
                warnings.warn("top_k must be -1 or >= 1")
            nvext.update(top_k=request.sampling_params.top_k)
        elif request.sampling_params.strategy == "greedy":
            nvext.update(top_k=-1)
            payload.update(temperature=request.sampling_params.temperature)

    return payload


def _parse_content(completion: dict) -> str:
    """
    Get the content from an OpenAI completion response.

    OpenAI completion response format -
        {
            ...
            "message": {"role": "assistant", "content": ..., ...},
            ...
        }
    """
    # content is nullable in the OpenAI response, common for tool calls
    return completion["message"]["content"] or ""


def _parse_stop_reason(completion: dict) -> StopReason:
    """
    Get the StopReason from an OpenAI completion response.

    OpenAI completion response format -
        {
            ...
            "finish_reason": "length" or "stop" or "tool_calls",
            ...
        }
    """

    # StopReason options are end_of_turn, end_of_message, out_of_tokens
    # TODO(mf): is end_of_turn and end_of_message usage correct?
    stop_reason = StopReason.end_of_turn
    if completion["finish_reason"] == "length":
        stop_reason = StopReason.out_of_tokens
    elif completion["finish_reason"] == "stop":
        stop_reason = StopReason.end_of_message
    elif completion["finish_reason"] == "tool_calls":
        stop_reason = StopReason.end_of_turn
    return stop_reason


def _parse_tool_calls(completion: dict) -> List[ToolCall]:
    """
    Get the tool calls from an OpenAI completion response.

    OpenAI completion response format -
        {
            ...,
            "message": {
                ...,
                "tool_calls": [
                    {
                        "id": X,
                        "type": "function",
                        "function": {
                            "name": Y,
                            "arguments": Z,
                        },
                    }*
                ],
            },
        }
    ->
        [
            ToolCall(call_id=X, tool_name=Y, arguments=Z),
            ...
        ]
    """
    tool_calls = []
    if "tool_calls" in completion["message"]:
        assert isinstance(
            completion["message"]["tool_calls"], list
        ), "error in server response: tool_calls not a list"
        for call in completion["message"]["tool_calls"]:
            assert "id" in call, "error in server response: tool call id not found"
            assert (
                "function" in call
            ), "error in server response: tool call function not found"
            assert (
                "name" in call["function"]
            ), "error in server response: tool call function name not found"
            assert (
                "arguments" in call["function"]
            ), "error in server response: tool call function arguments not found"
            tool_calls.append(
                ToolCall(
                    call_id=call["id"],
                    tool_name=call["function"]["name"],
                    arguments=call["function"]["arguments"],
                )
            )

    return tool_calls


def _parse_logprobs(completion: dict) -> Optional[List[TokenLogProbs]]:
    """
    Extract logprobs from OpenAI as a list of TokenLogProbs.

    OpenAI completion response format -
        {
            ...
            "logprobs": {
                content: [
                    {
                        ...,
                        top_logprobs: [{token: X, logprob: Y, bytes: [...]}+]
                    }+
                ]
            },
            ...
        }
    ->
        [
            TokenLogProbs(
                logprobs_by_token={X: Y, ...}
            ),
            ...
        ]
    """
    if not (logprobs := completion.get("logprobs")):
        return None

    return [
        TokenLogProbs(
            logprobs_by_token={
                logprobs["token"]: logprobs["logprob"]
                for logprobs in content["top_logprobs"]
            }
        )
        for content in logprobs["content"]
    ]


def parse_completion(
    completion: dict,
) -> ChatCompletionResponse:
    """
    Parse an OpenAI completion response into a CompletionMessage and logprobs.

    OpenAI completion response format -
        {
            "message": {
                "role": "assistant",
                "content": ...,
                "tool_calls": [
                    {
                        ...
                        "id": ...,
                        "function": {
                            "name": ...,
                            "arguments": ...,
                        },
                    }*
                ]?,
            "finish_reason": ...,
            "logprobs": {
                "content": [
                    {
                        ...,
                        "top_logprobs": [{"token": ..., "logprob": ..., ...}+]
                    }+
                ]
            }?
        }
    """
    assert "message" in completion, "error in server response: message not found"
    assert (
        "finish_reason" in completion
    ), "error in server response: finish_reason not found"

    return ChatCompletionResponse(
        completion_message=CompletionMessage(
            content=_parse_content(completion),
            stop_reason=_parse_stop_reason(completion),
            tool_calls=_parse_tool_calls(completion),
        ),
        logprobs=_parse_logprobs(completion),
    )
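One note on the health-check helpers kept by this file: the body of _get_health is elided by the hunk above. For orientation only (not part of the diff, and the real helper's status handling may differ), its documented behavior of querying {url}/v1/health/{live,ready} amounts to something like:

import httpx

async def _get_health_sketch(url: str) -> tuple[bool, bool]:
    # ask the NIM liveness and readiness endpoints and report (is_live, is_ready)
    async with httpx.AsyncClient() as client:
        live = await client.get(f"{url}/v1/health/live")
        ready = await client.get(f"{url}/v1/health/ready")
    return live.status_code == 200, ready.status_code == 200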
@@ -144,7 +144,9 @@ def available_providers() -> List[ProviderSpec]:
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_type="nvidia",
            pip_packages=[],  # TODO(mf): need to specify httpx if it's already a llama-stack dep?
            pip_packages=[
                "openai",
            ],
            module="llama_stack.providers.adapters.inference.nvidia",
            config_class="llama_stack.providers.adapters.inference.nvidia.NVIDIAConfig",
        ),
@@ -8,11 +8,15 @@ import itertools
from typing import Generator, List, Tuple

import pytest
from llama_models.datatypes import SamplingParams

from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    Inference,
    # LogProbConfig,
    Message,
    StopReason,
    SystemMessage,
@@ -96,6 +100,70 @@ async def test_chat_completion_messages(
    assert response.completion_message.tool_calls == []


async def test_chat_completion_basic(
    client: Inference,
    model: str,
):
    """
    Test the chat completion endpoint with basic messages, with and without streaming.
    """
    client = await client
    messages = [
        UserMessage(content="How are you?"),
    ]

    response = await client.chat_completion(
        model=model,
        messages=messages,
        stream=False,
    )

    assert isinstance(response, ChatCompletionResponse)
    assert isinstance(response.completion_message.content, str)
    # we're not testing accuracy, so no assertions on the result.completion_message.content
    assert response.completion_message.role == "assistant"
    assert isinstance(response.completion_message.stop_reason, StopReason)
    assert response.completion_message.tool_calls == []


async def test_chat_completion_stream_basic(
    client: Inference,
    model: str,
):
    """
    Test the chat completion endpoint with basic messages, with and without streaming.
    """
    client = await client
    messages = [
        UserMessage(content="How are you?"),
    ]

    response = await client.chat_completion(
        model=model,
        messages=messages,
        stream=True,
        sampling_params=SamplingParams(max_tokens=5),
        # logprobs=LogProbConfig(top_k=3),
    )

    chunks = [chunk async for chunk in response]
    assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in chunks)
    assert all(isinstance(chunk.event.delta, str) for chunk in chunks)
    assert chunks[0].event.event_type == ChatCompletionResponseEventType.start
    assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
    if len(chunks) > 2:
        assert all(
            chunk.event.event_type == ChatCompletionResponseEventType.progress
            for chunk in chunks[1:-1]
        )
    # we're not testing accuracy, so no assertions on the result.completion_message.content
    assert all(
        chunk.event.stop_reason is None
        or isinstance(chunk.event.stop_reason, StopReason)
        for chunk in chunks
    )


async def test_bad_base_url(
    model: str,
):
@@ -157,7 +157,7 @@ async def test_tools(
                "type": "function",
                "function": {
                    "name": "magic",
                    "arguments": {"input": 3},
                    "arguments": '{"input": 3}',
                },
            },
            {
@@ -165,7 +165,7 @@ async def test_tools(
                "type": "function",
                "function": {
                    "name": "magic!",
                    "arguments": {"input": 42},
                    "arguments": '{"input": 42}',
                },
            },
        ],
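A short aside, and an assumption about why the test_tools fixtures changed: the OpenAI types carry function.arguments as a JSON-encoded string rather than a dict, and _convert_openai_tool_calls() json.loads() it back into a dict, so the mocked tool calls now encode their arguments, e.g.:

import json

arguments = json.dumps({"input": 3})  # '{"input": 3}', the form the fixture now carries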
493  tests/nvidia/unit/test_openai_utils.py  (new file)
@@ -0,0 +1,493 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator, List

import pytest
from llama_models.llama3.api.datatypes import StopReason

from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
)
from llama_stack.providers.adapters.inference.nvidia._openai_utils import (
    convert_openai_chat_completion_choice,
    convert_openai_chat_completion_stream,
)
from openai.types.chat import (
    ChatCompletionChunk as OpenAIChatCompletionChunk,
    ChatCompletionMessage,
    ChatCompletionMessageToolCall,
    ChatCompletionTokenLogprob,
)
from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
from openai.types.chat.chat_completion_chunk import (
    Choice as ChoiceChunk,
    ChoiceDelta,
    ChoiceDeltaToolCall,
    ChoiceDeltaToolCallFunction,
)
from openai.types.chat.chat_completion_token_logprob import TopLogprob


def test_convert_openai_chat_completion_choice_basic():
    response = Choice(
        index=0,
        message=ChatCompletionMessage(
            role="assistant",
            content="Hello, world!",
        ),
        finish_reason="stop",
    )
    result = convert_openai_chat_completion_choice(response)
    assert isinstance(result, ChatCompletionResponse)
    assert result.completion_message.content == "Hello, world!"
    assert result.completion_message.stop_reason == StopReason.end_of_turn
    assert result.completion_message.tool_calls == []
    assert result.logprobs is None


def test_convert_openai_chat_completion_choice_basic_with_tool_calls():
    response = Choice(
        index=0,
        message=ChatCompletionMessage(
            role="assistant",
            content="Hello, world!",
            tool_calls=[
                ChatCompletionMessageToolCall(
                    id="tool_call_id",
                    type="function",
                    function={
                        "name": "test_function",
                        "arguments": '{"test_args": "test_value"}',
                    },
                )
            ],
        ),
        finish_reason="tool_calls",
    )

    result = convert_openai_chat_completion_choice(response)
    assert isinstance(result, ChatCompletionResponse)
    assert result.completion_message.content == "Hello, world!"
    assert result.completion_message.stop_reason == StopReason.end_of_message
    assert len(result.completion_message.tool_calls) == 1
    assert result.completion_message.tool_calls[0].tool_name == "test_function"
    assert result.completion_message.tool_calls[0].arguments == {
        "test_args": "test_value"
    }
    assert result.logprobs is None


def test_convert_openai_chat_completion_choice_basic_with_logprobs():
    response = Choice(
        index=0,
        message=ChatCompletionMessage(
            role="assistant",
            content="Hello world",
        ),
        finish_reason="stop",
        logprobs=ChoiceLogprobs(
            content=[
                ChatCompletionTokenLogprob(
                    token="Hello",
                    logprob=-1.0,
                    bytes=[72, 101, 108, 108, 111],
                    top_logprobs=[
                        TopLogprob(
                            token="Hello", logprob=-1.0, bytes=[72, 101, 108, 108, 111]
                        ),
                        TopLogprob(
                            token="Greetings",
                            logprob=-1.5,
                            bytes=[71, 114, 101, 101, 116, 105, 110, 103, 115],
                        ),
                    ],
                ),
                ChatCompletionTokenLogprob(
                    token="world",
                    logprob=-1.5,
                    bytes=[119, 111, 114, 108, 100],
                    top_logprobs=[
                        TopLogprob(
                            token="world", logprob=-1.5, bytes=[119, 111, 114, 108, 100]
                        ),
                        TopLogprob(
                            token="planet",
                            logprob=-2.0,
                            bytes=[112, 108, 97, 110, 101, 116],
                        ),
                    ],
                ),
            ]
        ),
    )

    result = convert_openai_chat_completion_choice(response)
    assert isinstance(result, ChatCompletionResponse)
    assert result.completion_message.content == "Hello world"
    assert result.completion_message.stop_reason == StopReason.end_of_turn
    assert result.completion_message.tool_calls == []
    assert result.logprobs is not None
    assert len(result.logprobs) == 2
    assert len(result.logprobs[0].logprobs_by_token) == 2
    assert result.logprobs[0].logprobs_by_token["Hello"] == -1.0
    assert result.logprobs[0].logprobs_by_token["Greetings"] == -1.5
    assert len(result.logprobs[1].logprobs_by_token) == 2
    assert result.logprobs[1].logprobs_by_token["world"] == -1.5
    assert result.logprobs[1].logprobs_by_token["planet"] == -2.0


def test_convert_openai_chat_completion_choice_missing_message():
    response = Choice(
        index=0,
        message=ChatCompletionMessage(
            role="assistant",
            content="Hello, world!",
        ),
        finish_reason="stop",
    )

    response.message = None
    with pytest.raises(
        AssertionError, match="error in server response: message not found"
    ):
        convert_openai_chat_completion_choice(response)

    del response.message
    with pytest.raises(
        AssertionError, match="error in server response: message not found"
    ):
        convert_openai_chat_completion_choice(response)


def test_convert_openai_chat_completion_choice_missing_finish_reason():
    response = Choice(
        index=0,
        message=ChatCompletionMessage(
            role="assistant",
            content="Hello, world!",
        ),
        finish_reason="stop",
    )

    response.finish_reason = None
    with pytest.raises(
        AssertionError, match="error in server response: finish_reason not found"
    ):
        convert_openai_chat_completion_choice(response)

    del response.finish_reason
    with pytest.raises(
        AssertionError, match="error in server response: finish_reason not found"
    ):
        convert_openai_chat_completion_choice(response)


# we want to test convert_openai_chat_completion_stream
# we need to produce a stream of OpenAIChatCompletionChunk
# streams to produce -
#  0. basic stream with one chunk, should produce 3 (start, progress, complete)
#  1. stream with 3 chunks, should produce 5 events (start, progress, progress, progress, complete)
#  2. stream with a tool call, should produce 4 events (start, progress w/ tool_call, complete)


@pytest.mark.asyncio
async def test_convert_openai_chat_completion_stream_basic():
    chunks = [
        OpenAIChatCompletionChunk(
            id="1",
            created=1234567890,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="Hello, world!",
                    ),
                    finish_reason="stop",
                )
            ],
        )
    ]

    async def async_generator_from_list(items: List) -> AsyncGenerator:
        for item in items:
            yield item

    results = [
        result
        async for result in convert_openai_chat_completion_stream(
            async_generator_from_list(chunks)
        )
    ]

    assert len(results) == 2
    assert all(
        isinstance(result, ChatCompletionResponseStreamChunk) for result in results
    )
    assert results[0].event.event_type == ChatCompletionResponseEventType.start
    assert results[0].event.delta == "Hello, world!"
    assert results[1].event.event_type == ChatCompletionResponseEventType.complete
    assert results[1].event.stop_reason == StopReason.end_of_turn


@pytest.mark.asyncio
async def test_convert_openai_chat_completion_stream_basic_empty():
    chunks = [
        OpenAIChatCompletionChunk(
            id="1",
            created=1234567890,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                    ),
                    finish_reason="stop",
                )
            ],
        ),
        OpenAIChatCompletionChunk(
            id="1",
            created=1234567890,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="Hello, world!",
                    ),
                    finish_reason="stop",
                )
            ],
        ),
    ]

    async def async_generator_from_list(items: List) -> AsyncGenerator:
        for item in items:
            yield item

    results = [
        result
        async for result in convert_openai_chat_completion_stream(
            async_generator_from_list(chunks)
        )
    ]

    print(results)

    assert len(results) == 3
    assert all(
        isinstance(result, ChatCompletionResponseStreamChunk) for result in results
    )
    assert results[0].event.event_type == ChatCompletionResponseEventType.start
    assert results[1].event.event_type == ChatCompletionResponseEventType.progress
    assert results[1].event.delta == "Hello, world!"
    assert results[2].event.event_type == ChatCompletionResponseEventType.complete
    assert results[2].event.stop_reason == StopReason.end_of_turn


@pytest.mark.asyncio
async def test_convert_openai_chat_completion_stream_multiple_chunks():
    chunks = [
        OpenAIChatCompletionChunk(
            id="1",
            created=1234567890,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="Hello, world!",
                    ),
                    # finish_reason="continue",
                )
            ],
        ),
        OpenAIChatCompletionChunk(
            id="2",
            created=1234567891,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="How are you?",
                    ),
                    # finish_reason="continue",
                )
            ],
        ),
        OpenAIChatCompletionChunk(
            id="3",
            created=1234567892,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="I'm good, thanks!",
                    ),
                    finish_reason="stop",
                )
            ],
        ),
    ]

    async def async_generator_from_list(items: List) -> AsyncGenerator:
        for item in items:
            yield item

    results = [
        result
        async for result in convert_openai_chat_completion_stream(
            async_generator_from_list(chunks)
        )
    ]

    assert len(results) == 4
    assert all(
        isinstance(result, ChatCompletionResponseStreamChunk) for result in results
    )
    assert results[0].event.event_type == ChatCompletionResponseEventType.start
    assert results[0].event.delta == "Hello, world!"
    assert not results[0].event.stop_reason
    assert results[1].event.event_type == ChatCompletionResponseEventType.progress
    assert results[1].event.delta == "How are you?"
    assert not results[1].event.stop_reason
    assert results[2].event.event_type == ChatCompletionResponseEventType.progress
    assert results[2].event.delta == "I'm good, thanks!"
    assert not results[2].event.stop_reason
    assert results[3].event.event_type == ChatCompletionResponseEventType.complete
    assert results[3].event.stop_reason == StopReason.end_of_turn


@pytest.mark.asyncio
async def test_convert_openai_chat_completion_stream_with_tool_call_and_content():
    chunks = [
        OpenAIChatCompletionChunk(
            id="1",
            created=1234567890,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="Hello, world!",
                        tool_calls=[
                            ChoiceDeltaToolCall(
                                index=0,
                                id="tool_call_id",
                                type="function",
                                function=ChoiceDeltaToolCallFunction(
                                    name="test_function",
                                    arguments='{"test_args": "test_value"}',
                                ),
                            )
                        ],
                    ),
                    finish_reason="tool_calls",
                )
            ],
        )
    ]

    async def async_generator_from_list(items: List) -> AsyncGenerator:
        for item in items:
            yield item

    results = [
        result
        async for result in convert_openai_chat_completion_stream(
            async_generator_from_list(chunks)
        )
    ]

    assert len(results) == 3
    assert all(
        isinstance(result, ChatCompletionResponseStreamChunk) for result in results
    )
    assert results[0].event.event_type == ChatCompletionResponseEventType.start
    assert results[0].event.delta == "Hello, world!"
    assert not results[0].event.stop_reason
    assert results[1].event.event_type == ChatCompletionResponseEventType.progress
    assert not isinstance(results[1].event.delta, str)
    assert results[1].event.delta.content.tool_name == "test_function"
    assert results[1].event.delta.content.arguments == {"test_args": "test_value"}
    assert not results[1].event.stop_reason
    assert results[2].event.event_type == ChatCompletionResponseEventType.complete
    assert results[2].event.stop_reason == StopReason.end_of_message


@pytest.mark.asyncio
async def test_convert_openai_chat_completion_stream_with_tool_call_and_no_content():
    chunks = [
        OpenAIChatCompletionChunk(
            id="1",
            created=1234567890,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        tool_calls=[
                            ChoiceDeltaToolCall(
                                index=0,
                                id="tool_call_id",
                                type="function",
                                function=ChoiceDeltaToolCallFunction(
                                    name="test_function",
                                    arguments='{"test_args": "test_value"}',
                                ),
                            )
                        ],
                    ),
                    finish_reason="tool_calls",
                )
            ],
        )
    ]

    async def async_generator_from_list(items: List) -> AsyncGenerator:
        for item in items:
            yield item

    results = [
        result
        async for result in convert_openai_chat_completion_stream(
            async_generator_from_list(chunks)
        )
    ]

    assert len(results) == 2
    assert all(
        isinstance(result, ChatCompletionResponseStreamChunk) for result in results
    )
    assert results[0].event.event_type == ChatCompletionResponseEventType.start
    assert not isinstance(results[0].event.delta, str)
    assert results[0].event.delta.content.tool_name == "test_function"
    assert results[0].event.delta.content.arguments == {"test_args": "test_value"}
    assert not results[0].event.stop_reason
    assert results[1].event.event_type == ChatCompletionResponseEventType.complete
    assert results[1].event.stop_reason == StopReason.end_of_message
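Outside pytest, the same converter can be exercised by hand with asyncio; a quick sketch (not part of the commit), reusing any of the chunk lists constructed in the tests above:

import asyncio

async def drive(chunks) -> None:
    async def gen():
        for chunk in chunks:
            yield chunk

    async for stream_chunk in convert_openai_chat_completion_stream(gen()):
        print(stream_chunk.event.event_type, stream_chunk.event.delta)

# asyncio.run(drive(chunks))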