# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

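# Unit tests for the NVIDIA inference adapter's OpenAI compatibility helpers:
# convert_openai_chat_completion_choice maps a non-streaming OpenAI Choice into
# a ChatCompletionResponse, and convert_openai_chat_completion_stream maps a
# stream of OpenAIChatCompletionChunk objects into
# ChatCompletionResponseStreamChunk events.
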
from typing import AsyncGenerator, List

import pytest
from llama_models.llama3.api.datatypes import StopReason

from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
)
from llama_stack.providers.adapters.inference.nvidia._openai_utils import (
    convert_openai_chat_completion_choice,
    convert_openai_chat_completion_stream,
)
from openai.types.chat import (
    ChatCompletionChunk as OpenAIChatCompletionChunk,
    ChatCompletionMessage,
    ChatCompletionMessageToolCall,
    ChatCompletionTokenLogprob,
)
from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
from openai.types.chat.chat_completion_chunk import (
    Choice as ChoiceChunk,
    ChoiceDelta,
    ChoiceDeltaToolCall,
    ChoiceDeltaToolCallFunction,
)
from openai.types.chat.chat_completion_token_logprob import TopLogprob


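# A plain assistant message with finish_reason="stop" should become a
# completion message with StopReason.end_of_turn and no tool calls or logprobs.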
def test_convert_openai_chat_completion_choice_basic():
    response = Choice(
        index=0,
        message=ChatCompletionMessage(
            role="assistant",
            content="Hello, world!",
        ),
        finish_reason="stop",
    )
    result = convert_openai_chat_completion_choice(response)
    assert isinstance(result, ChatCompletionResponse)
    assert result.completion_message.content == "Hello, world!"
    assert result.completion_message.stop_reason == StopReason.end_of_turn
    assert result.completion_message.tool_calls == []
    assert result.logprobs is None


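# An OpenAI tool call should be converted to a Llama Stack ToolCall with the
# JSON arguments string parsed into a dict, and finish_reason="tool_calls"
# should map to StopReason.end_of_message.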
def test_convert_openai_chat_completion_choice_basic_with_tool_calls():
    response = Choice(
        index=0,
        message=ChatCompletionMessage(
            role="assistant",
            content="Hello, world!",
            tool_calls=[
                ChatCompletionMessageToolCall(
                    id="tool_call_id",
                    type="function",
                    function={
                        "name": "test_function",
                        "arguments": '{"test_args": "test_value"}',
                    },
                )
            ],
        ),
        finish_reason="tool_calls",
    )

    result = convert_openai_chat_completion_choice(response)
    assert isinstance(result, ChatCompletionResponse)
    assert result.completion_message.content == "Hello, world!"
    assert result.completion_message.stop_reason == StopReason.end_of_message
    assert len(result.completion_message.tool_calls) == 1
    assert result.completion_message.tool_calls[0].tool_name == "test_function"
    assert result.completion_message.tool_calls[0].arguments == {
        "test_args": "test_value"
    }
    assert result.logprobs is None


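# OpenAI logprobs are converted per token: each ChatCompletionTokenLogprob's
# top_logprobs collapse into a token -> logprob mapping in logprobs_by_token.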
def test_convert_openai_chat_completion_choice_basic_with_logprobs():
    response = Choice(
        index=0,
        message=ChatCompletionMessage(
            role="assistant",
            content="Hello world",
        ),
        finish_reason="stop",
        logprobs=ChoiceLogprobs(
            content=[
                ChatCompletionTokenLogprob(
                    token="Hello",
                    logprob=-1.0,
                    bytes=[72, 101, 108, 108, 111],
                    top_logprobs=[
                        TopLogprob(
                            token="Hello", logprob=-1.0, bytes=[72, 101, 108, 108, 111]
                        ),
                        TopLogprob(
                            token="Greetings",
                            logprob=-1.5,
                            bytes=[71, 114, 101, 101, 116, 105, 110, 103, 115],
                        ),
                    ],
                ),
                ChatCompletionTokenLogprob(
                    token="world",
                    logprob=-1.5,
                    bytes=[119, 111, 114, 108, 100],
                    top_logprobs=[
                        TopLogprob(
                            token="world", logprob=-1.5, bytes=[119, 111, 114, 108, 100]
                        ),
                        TopLogprob(
                            token="planet",
                            logprob=-2.0,
                            bytes=[112, 108, 97, 110, 101, 116],
                        ),
                    ],
                ),
            ]
        ),
    )

    result = convert_openai_chat_completion_choice(response)
    assert isinstance(result, ChatCompletionResponse)
    assert result.completion_message.content == "Hello world"
    assert result.completion_message.stop_reason == StopReason.end_of_turn
    assert result.completion_message.tool_calls == []
    assert result.logprobs is not None
    assert len(result.logprobs) == 2
    assert len(result.logprobs[0].logprobs_by_token) == 2
    assert result.logprobs[0].logprobs_by_token["Hello"] == -1.0
    assert result.logprobs[0].logprobs_by_token["Greetings"] == -1.5
    assert len(result.logprobs[1].logprobs_by_token) == 2
    assert result.logprobs[1].logprobs_by_token["world"] == -1.5
    assert result.logprobs[1].logprobs_by_token["planet"] == -2.0


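# A response missing its message is malformed; the converter should fail fast
# with an assertion error instead of fabricating a completion.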
def test_convert_openai_chat_completion_choice_missing_message():
    response = Choice(
        index=0,
        message=ChatCompletionMessage(
            role="assistant",
            content="Hello, world!",
        ),
        finish_reason="stop",
    )

    response.message = None
    with pytest.raises(
        AssertionError, match="error in server response: message not found"
    ):
        convert_openai_chat_completion_choice(response)

    del response.message
    with pytest.raises(
        AssertionError, match="error in server response: message not found"
    ):
        convert_openai_chat_completion_choice(response)


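# Likewise, a response missing its finish_reason is malformed and should be
# rejected with an assertion error.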
def test_convert_openai_chat_completion_choice_missing_finish_reason():
    response = Choice(
        index=0,
        message=ChatCompletionMessage(
            role="assistant",
            content="Hello, world!",
        ),
        finish_reason="stop",
    )

    response.finish_reason = None
    with pytest.raises(
        AssertionError, match="error in server response: finish_reason not found"
    ):
        convert_openai_chat_completion_choice(response)

    del response.finish_reason
    with pytest.raises(
        AssertionError, match="error in server response: finish_reason not found"
    ):
        convert_openai_chat_completion_choice(response)


# To test convert_openai_chat_completion_stream we feed it hand-built streams
# of OpenAIChatCompletionChunk objects and check the events it emits:
#  0. a basic stream with one content chunk produces 2 events (start, complete)
#  1. a stream with 3 content chunks produces 4 events
#     (start, progress, progress, complete)
#  2. a stream with a tool call produces 3 events when content accompanies it
#     (start, progress w/ tool_call, complete) and 2 events when it does not


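# Case 0: a single chunk carrying both content and finish_reason="stop" yields
# a start event with the delta, then a complete event.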
@pytest.mark.asyncio
async def test_convert_openai_chat_completion_stream_basic():
    chunks = [
        OpenAIChatCompletionChunk(
            id="1",
            created=1234567890,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="Hello, world!",
                    ),
                    finish_reason="stop",
                )
            ],
        )
    ]

    async def async_generator_from_list(items: List) -> AsyncGenerator:
        for item in items:
            yield item

    results = [
        result
        async for result in convert_openai_chat_completion_stream(
            async_generator_from_list(chunks)
        )
    ]

    assert len(results) == 2
    assert all(
        isinstance(result, ChatCompletionResponseStreamChunk) for result in results
    )
    assert results[0].event.event_type == ChatCompletionResponseEventType.start
    assert results[0].event.delta == "Hello, world!"
    assert results[1].event.event_type == ChatCompletionResponseEventType.complete
    assert results[1].event.stop_reason == StopReason.end_of_turn


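# A leading chunk with an empty delta still opens the stream with a start
# event; the content then arrives as a progress event, for 3 events total.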
@pytest.mark.asyncio
async def test_convert_openai_chat_completion_stream_basic_empty():
    chunks = [
        OpenAIChatCompletionChunk(
            id="1",
            created=1234567890,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                    ),
                    finish_reason="stop",
                )
            ],
        ),
        OpenAIChatCompletionChunk(
            id="1",
            created=1234567890,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="Hello, world!",
                    ),
                    finish_reason="stop",
                )
            ],
        ),
    ]

    async def async_generator_from_list(items: List) -> AsyncGenerator:
        for item in items:
            yield item

    results = [
        result
        async for result in convert_openai_chat_completion_stream(
            async_generator_from_list(chunks)
        )
    ]

    assert len(results) == 3
    assert all(
        isinstance(result, ChatCompletionResponseStreamChunk) for result in results
    )
    assert results[0].event.event_type == ChatCompletionResponseEventType.start
    assert results[1].event.event_type == ChatCompletionResponseEventType.progress
    assert results[1].event.delta == "Hello, world!"
    assert results[2].event.event_type == ChatCompletionResponseEventType.complete
    assert results[2].event.stop_reason == StopReason.end_of_turn


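# Case 1: three content chunks yield start, progress, progress, and finally a
# complete event once the last chunk's finish_reason="stop" is seen.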
@pytest.mark.asyncio
async def test_convert_openai_chat_completion_stream_multiple_chunks():
    chunks = [
        OpenAIChatCompletionChunk(
            id="1",
            created=1234567890,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="Hello, world!",
                    ),
                    # no finish_reason: this is an intermediate chunk
                )
            ],
        ),
        OpenAIChatCompletionChunk(
            id="2",
            created=1234567891,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="How are you?",
                    ),
                    # no finish_reason: this is an intermediate chunk
                )
            ],
        ),
        OpenAIChatCompletionChunk(
            id="3",
            created=1234567892,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="I'm good, thanks!",
                    ),
                    finish_reason="stop",
                )
            ],
        ),
    ]

    async def async_generator_from_list(items: List) -> AsyncGenerator:
        for item in items:
            yield item

    results = [
        result
        async for result in convert_openai_chat_completion_stream(
            async_generator_from_list(chunks)
        )
    ]

    assert len(results) == 4
    assert all(
        isinstance(result, ChatCompletionResponseStreamChunk) for result in results
    )
    assert results[0].event.event_type == ChatCompletionResponseEventType.start
    assert results[0].event.delta == "Hello, world!"
    assert not results[0].event.stop_reason
    assert results[1].event.event_type == ChatCompletionResponseEventType.progress
    assert results[1].event.delta == "How are you?"
    assert not results[1].event.stop_reason
    assert results[2].event.event_type == ChatCompletionResponseEventType.progress
    assert results[2].event.delta == "I'm good, thanks!"
    assert not results[2].event.stop_reason
    assert results[3].event.event_type == ChatCompletionResponseEventType.complete
    assert results[3].event.stop_reason == StopReason.end_of_turn


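# Case 2: a chunk with both content and a tool call yields a start event for
# the content plus a separate progress event whose delta carries the tool call.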
@pytest.mark.asyncio
async def test_convert_openai_chat_completion_stream_with_tool_call_and_content():
    chunks = [
        OpenAIChatCompletionChunk(
            id="1",
            created=1234567890,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="Hello, world!",
                        tool_calls=[
                            ChoiceDeltaToolCall(
                                index=0,
                                id="tool_call_id",
                                type="function",
                                function=ChoiceDeltaToolCallFunction(
                                    name="test_function",
                                    arguments='{"test_args": "test_value"}',
                                ),
                            )
                        ],
                    ),
                    finish_reason="tool_calls",
                )
            ],
        )
    ]

    async def async_generator_from_list(items: List) -> AsyncGenerator:
        for item in items:
            yield item

    results = [
        result
        async for result in convert_openai_chat_completion_stream(
            async_generator_from_list(chunks)
        )
    ]

    assert len(results) == 3
    assert all(
        isinstance(result, ChatCompletionResponseStreamChunk) for result in results
    )
    assert results[0].event.event_type == ChatCompletionResponseEventType.start
    assert results[0].event.delta == "Hello, world!"
    assert not results[0].event.stop_reason
    assert results[1].event.event_type == ChatCompletionResponseEventType.progress
    assert not isinstance(results[1].event.delta, str)
    assert results[1].event.delta.content.tool_name == "test_function"
    assert results[1].event.delta.content.arguments == {"test_args": "test_value"}
    assert not results[1].event.stop_reason
    assert results[2].event.event_type == ChatCompletionResponseEventType.complete
    assert results[2].event.stop_reason == StopReason.end_of_message


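# Without content, the tool call itself becomes the start event's delta and
# the stream still closes with a complete event.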
@pytest.mark.asyncio
async def test_convert_openai_chat_completion_stream_with_tool_call_and_no_content():
    chunks = [
        OpenAIChatCompletionChunk(
            id="1",
            created=1234567890,
            model="mock-model",
            object="chat.completion.chunk",
            choices=[
                ChoiceChunk(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        tool_calls=[
                            ChoiceDeltaToolCall(
                                index=0,
                                id="tool_call_id",
                                type="function",
                                function=ChoiceDeltaToolCallFunction(
                                    name="test_function",
                                    arguments='{"test_args": "test_value"}',
                                ),
                            )
                        ],
                    ),
                    finish_reason="tool_calls",
                )
            ],
        )
    ]

    async def async_generator_from_list(items: List) -> AsyncGenerator:
        for item in items:
            yield item

    results = [
        result
        async for result in convert_openai_chat_completion_stream(
            async_generator_from_list(chunks)
        )
    ]

    assert len(results) == 2
    assert all(
        isinstance(result, ChatCompletionResponseStreamChunk) for result in results
    )
    assert results[0].event.event_type == ChatCompletionResponseEventType.start
    assert not isinstance(results[0].event.delta, str)
    assert results[0].event.delta.content.tool_name == "test_function"
    assert results[0].event.delta.content.arguments == {"test_args": "test_value"}
    assert not results[0].event.stop_reason
    assert results[1].event.event_type == ChatCompletionResponseEventType.complete
    assert results[1].event.stop_reason == StopReason.end_of_message