llama-stack-mirror/llama_stack/providers/tests/inference/test_inference.py
Ashwin Bharambe dba7caf1d0 Fix fireworks and update the test
Don't look for eom_id / eot_id since, sadly, providers don't return the
last token
2024-10-08 17:23:02 -07:00

294 lines
9 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import itertools
import json
import os
from datetime import datetime
import pytest
import pytest_asyncio
import yaml
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing
def group_chunks(response):
    """Bucket streamed chunks by their ``event.event_type``.

    A bare ``itertools.groupby`` only merges *consecutive* equal keys, so an
    event type that reappears later in the stream would overwrite its earlier
    group when collected into a dict. Here every chunk is kept, and chunks
    inside each bucket retain their original stream order.

    Args:
        response: iterable of stream chunks; each must expose
            ``chunk.event.event_type``.

    Returns:
        dict mapping each event type to the list of chunks of that type, with
        keys ordered by first appearance in ``response``.
    """
    grouped = {}
    for chunk in response:
        grouped.setdefault(chunk.event.event_type, []).append(chunk)
    return grouped
# Model identifiers used to parametrize the `inference_settings` fixture.
Llama_8B = "Llama3.1-8B-Instruct"
Llama_3B = "Llama3.2-3B-Instruct"


def get_expected_stop_reason(model: str):
    """Return the stop reason expected from *model*.

    Models whose name contains ``"Llama3.1"`` are expected to stop with
    ``StopReason.end_of_message``; every other model with
    ``StopReason.end_of_turn``.
    """
    if "Llama3.1" in model:
        return StopReason.end_of_message
    return StopReason.end_of_turn
async def stack_impls(model):
    """Build a stack with a single inference provider serving *model*.

    The provider configuration is read from the YAML file named by the
    ``PROVIDER_CONFIG`` environment variable. The file must contain a
    ``providers`` list whose entries each carry a ``provider_id``; it may also
    contain a ``provider_data`` mapping keyed by provider id. The
    ``PROVIDER_ID`` environment variable selects the provider; when it is
    unset, the first provider listed is used.

    Args:
        model: model identifier, used both as the model definition's
            ``identifier`` and its ``llama_model``.

    Returns:
        Whatever ``resolve_impls_with_routing`` returns for the built run config
        (callers in this file index it by ``Api``).

    Raises:
        ValueError: if ``PROVIDER_CONFIG`` is unset, the config file is empty or
            lacks a ``providers`` key, no providers are listed, or
            ``PROVIDER_ID`` names a provider that is not in the file.
    """
    if "PROVIDER_CONFIG" not in os.environ:
        raise ValueError(
            "You must set PROVIDER_CONFIG to a YAML file containing provider config"
        )
    with open(os.environ["PROVIDER_CONFIG"], "r", encoding="utf-8") as f:
        config_dict = yaml.safe_load(f)

    # An empty YAML file loads as None (and a scalar file as a non-dict); report
    # it like any other config missing `providers` instead of a TypeError.
    if not isinstance(config_dict, dict) or "providers" not in config_dict:
        raise ValueError("Config file should contain a `providers` key")

    # `providers:` with no value loads as None; treat it as an empty list.
    providers_by_id = {
        x["provider_id"]: x for x in (config_dict["providers"] or [])
    }
    if not providers_by_id:
        raise ValueError("No providers found in config file")

    if "PROVIDER_ID" in os.environ:
        provider_id = os.environ["PROVIDER_ID"]
        if provider_id not in providers_by_id:
            raise ValueError(f"Provider ID {provider_id} not found in config file")
        provider = providers_by_id[provider_id]
    else:
        # dicts preserve insertion order, so this is the first provider listed.
        provider = next(iter(providers_by_id.values()))
        provider_id = provider["provider_id"]
        print(f"No provider ID specified, picking first `{provider_id}`")

    run_config = dict(
        built_at=datetime.now(),
        image_name="test-fixture",
        apis=[
            Api.inference,
            Api.models,
        ],
        providers=dict(
            inference=[
                Provider(**provider),
            ]
        ),
        models=[
            ModelDef(
                identifier=model,
                llama_model=model,
                provider_id=provider_id,
            )
        ],
        shields=[],
        memory_banks=[],
    )
    run_config = parse_and_maybe_upgrade_config(run_config)
    impls = await resolve_impls_with_routing(run_config)

    # may need something cleaner here
    # Optional per-provider request data is JSON-encoded into the
    # X-LlamaStack-ProviderData header; `provider_data:` with no value is
    # treated as absent.
    provider_data = (config_dict.get("provider_data") or {}).get(provider_id)
    if provider_data:
        set_request_provider_data(
            {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
        )
    return impls
# TODO: every parametrization builds a brand-new stack via `stack_impls`, and
# the previously built one is never torn down. Fix that!
@pytest_asyncio.fixture(
    scope="session",
    params=[
        {"model": Llama_8B},
        {"model": Llama_3B},
    ],
    ids=lambda d: d["model"],
)
async def inference_settings(request):
    """Provide the inference implementation and shared request kwargs for one model.

    Returns a dict with:
        ``impl``: the ``Api.inference`` implementation from the built stack.
        ``common_params``: ``model``, ``tool_choice`` (always auto) and
            ``tool_prompt_format`` (JSON for Llama 3.1 models, python_list
            otherwise).
    """
    model_name = request.param["model"]
    stack = await stack_impls(model_name)

    if "Llama3.1" in model_name:
        prompt_format = ToolPromptFormat.json
    else:
        prompt_format = ToolPromptFormat.python_list

    shared_kwargs = dict(
        model=model_name,
        tool_choice=ToolChoice.auto,
        tool_prompt_format=prompt_format,
    )
    return {"impl": stack[Api.inference], "common_params": shared_kwargs}
@pytest.fixture
def sample_messages():
    """Return a two-message conversation: a system prompt, then a weather question."""
    system_prompt = SystemMessage(content="You are a helpful assistant.")
    user_question = UserMessage(content="What's the weather like today?")
    return [system_prompt, user_question]
@pytest.fixture
def sample_tool_definition():
    """Return a ``get_weather`` tool definition with a single string ``location`` parameter."""
    location = ToolParamDefinition(
        param_type="string",
        description="The city and state, e.g. San Francisco, CA",
    )
    return ToolDefinition(
        tool_name="get_weather",
        description="Get the current weather",
        parameters={"location": location},
    )
@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
    """A non-streaming completion yields exactly one non-empty assistant reply."""
    impl = inference_settings["impl"]
    results = []
    async for item in impl.chat_completion(
        messages=sample_messages,
        stream=False,
        **inference_settings["common_params"],
    ):
        results.append(item)

    assert len(results) == 1
    only = results[0]
    assert isinstance(only, ChatCompletionResponse)
    reply = only.completion_message
    assert reply.role == "assistant"
    assert isinstance(reply.content, str)
    assert len(reply.content) > 0
@pytest.mark.asyncio
async def test_chat_completion_streaming(inference_settings, sample_messages):
    """Streaming yields one start chunk, at least one progress chunk and one
    complete chunk, and the complete chunk reports ``end_of_turn``."""
    impl = inference_settings["impl"]
    chunks = []
    async for chunk in impl.chat_completion(
        messages=sample_messages,
        stream=True,
        **inference_settings["common_params"],
    ):
        chunks.append(chunk)

    assert len(chunks) > 0
    assert all(isinstance(c, ChatCompletionResponseStreamChunk) for c in chunks)

    event_types = ChatCompletionResponseEventType
    by_type = group_chunks(chunks)
    assert len(by_type[event_types.start]) == 1
    assert len(by_type[event_types.progress]) > 0
    completes = by_type[event_types.complete]
    assert len(completes) == 1
    assert completes[0].event.stop_reason == StopReason.end_of_turn
@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling(
    inference_settings,
    sample_messages,
    sample_tool_definition,
):
    """A non-streaming request offering ``get_weather`` returns a tool call to it
    whose ``location`` argument mentions San Francisco."""
    impl = inference_settings["impl"]
    conversation = [
        *sample_messages,
        UserMessage(content="What's the weather like in San Francisco?"),
    ]
    results = []
    async for item in impl.chat_completion(
        messages=conversation,
        tools=[sample_tool_definition],
        stream=False,
        **inference_settings["common_params"],
    ):
        results.append(item)

    assert len(results) == 1
    assert isinstance(results[0], ChatCompletionResponse)
    reply = results[0].completion_message

    # TODO: also assert reply.stop_reason == get_expected_stop_reason(model).
    # Disabled because most providers don't return eom_id / eot_id.

    assert reply.tool_calls is not None
    assert len(reply.tool_calls) > 0
    first_call = reply.tool_calls[0]
    assert first_call.tool_name == "get_weather"
    assert "location" in first_call.arguments
    assert "San Francisco" in first_call.arguments["location"]
@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling_streaming(
    inference_settings,
    sample_messages,
    sample_tool_definition,
):
    """Streaming a tool-enabled request yields one start, at least one progress
    and one complete chunk.

    For Llama 3.1 models, every progress delta must additionally be a
    ``ToolCallDelta``; the first must have status ``started`` and the last
    ``success``, carrying a ``get_weather`` call whose ``location`` argument
    mentions San Francisco.
    """
    impl = inference_settings["impl"]
    conversation = [
        *sample_messages,
        UserMessage(content="What's the weather like in San Francisco?"),
    ]
    chunks = []
    async for chunk in impl.chat_completion(
        messages=conversation,
        tools=[sample_tool_definition],
        stream=True,
        **inference_settings["common_params"],
    ):
        chunks.append(chunk)

    assert len(chunks) > 0
    assert all(isinstance(c, ChatCompletionResponseStreamChunk) for c in chunks)

    event_types = ChatCompletionResponseEventType
    by_type = group_chunks(chunks)
    assert len(by_type[event_types.start]) == 1
    assert len(by_type[event_types.progress]) > 0
    assert len(by_type[event_types.complete]) == 1

    # TODO: also assert the complete chunk's (and the last progress chunk's)
    # stop_reason == get_expected_stop_reason(model). Disabled because most
    # providers don't return eom_id / eot_id.

    model = inference_settings["common_params"]["model"]
    if "Llama3.1" not in model:
        return

    progress = by_type[event_types.progress]
    assert all(isinstance(c.event.delta, ToolCallDelta) for c in progress)
    assert progress[0].event.delta.parse_status == ToolCallParseStatus.started

    final_delta = progress[-1].event.delta
    assert final_delta.parse_status == ToolCallParseStatus.success
    assert isinstance(final_delta.content, ToolCall)
    call = final_delta.content
    assert call.tool_name == "get_weather"
    assert "location" in call.arguments
    assert "San Francisco" in call.arguments["location"]