forked from phoenix-oss/llama-stack-mirror
Significantly simpler and malleable test setup (#360)
* Significantly simpler and malleable test setup * convert memory tests * refactor fixtures and add support for composable fixtures * Fix memory to use the newer fixture organization * Get agents tests working * Safety tests work * yet another refactor to make this more general now it accepts --inference-model, --safety-model options also * get multiple providers working for meta-reference (for inference + safety) * Add README.md --------- Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
This commit is contained in:
parent
663883cc29
commit
ffedb81c11
25 changed files with 1491 additions and 790 deletions
62
llama_stack/providers/tests/inference/conftest.py
Normal file
62
llama_stack/providers/tests/inference/conftest.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from .fixtures import INFERENCE_FIXTURES
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--inference-model",
|
||||
action="store",
|
||||
default=None,
|
||||
help="Specify the inference model to use for testing",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line(
|
||||
"markers", "llama_8b: mark test to run only with the given model"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "llama_3b: mark test to run only with the given model"
|
||||
)
|
||||
for fixture_name in INFERENCE_FIXTURES:
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
f"{fixture_name}: marks tests as {fixture_name} specific",
|
||||
)
|
||||
|
||||
|
||||
MODEL_PARAMS = [
|
||||
pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
|
||||
pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
|
||||
]
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
if "inference_model" in metafunc.fixturenames:
|
||||
model = metafunc.config.getoption("--inference-model")
|
||||
if model:
|
||||
params = [pytest.param(model, id="")]
|
||||
else:
|
||||
params = MODEL_PARAMS
|
||||
|
||||
metafunc.parametrize(
|
||||
"inference_model",
|
||||
params,
|
||||
indirect=True,
|
||||
)
|
||||
if "inference_stack" in metafunc.fixturenames:
|
||||
metafunc.parametrize(
|
||||
"inference_stack",
|
||||
[
|
||||
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
|
||||
for fixture_name in INFERENCE_FIXTURES
|
||||
],
|
||||
indirect=True,
|
||||
)
|
120
llama_stack/providers/tests/inference/fixtures.py
Normal file
120
llama_stack/providers/tests/inference/fixtures.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
|
||||
from llama_stack.providers.adapters.inference.fireworks import FireworksImplConfig
|
||||
from llama_stack.providers.adapters.inference.ollama import OllamaImplConfig
|
||||
from llama_stack.providers.adapters.inference.together import TogetherImplConfig
|
||||
from llama_stack.providers.impls.meta_reference.inference import (
|
||||
MetaReferenceInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from ..conftest import ProviderFixture
|
||||
from ..env import get_env_or_fail
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_model(request):
|
||||
if hasattr(request, "param"):
|
||||
return request.param
|
||||
return request.config.getoption("--inference-model", None)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_meta_reference(inference_model) -> ProviderFixture:
|
||||
inference_model = (
|
||||
[inference_model] if isinstance(inference_model, str) else inference_model
|
||||
)
|
||||
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id=f"meta-reference-{i}",
|
||||
provider_type="meta-reference",
|
||||
config=MetaReferenceInferenceConfig(
|
||||
model=m,
|
||||
max_seq_len=4096,
|
||||
create_distributed_process_group=False,
|
||||
checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
|
||||
).model_dump(),
|
||||
)
|
||||
for i, m in enumerate(inference_model)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_ollama(inference_model) -> ProviderFixture:
|
||||
inference_model = (
|
||||
[inference_model] if isinstance(inference_model, str) else inference_model
|
||||
)
|
||||
if "Llama3.1-8B-Instruct" in inference_model:
|
||||
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
|
||||
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="ollama",
|
||||
provider_type="remote::ollama",
|
||||
config=OllamaImplConfig(
|
||||
host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
|
||||
).model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_fireworks() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="fireworks",
|
||||
provider_type="remote::fireworks",
|
||||
config=FireworksImplConfig(
|
||||
api_key=get_env_or_fail("FIREWORKS_API_KEY"),
|
||||
).model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_together() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="together",
|
||||
provider_type="remote::together",
|
||||
config=TogetherImplConfig().model_dump(),
|
||||
)
|
||||
],
|
||||
provider_data=dict(
|
||||
together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def inference_stack(request):
|
||||
fixture_name = request.param
|
||||
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
|
||||
impls = await resolve_impls_for_test_v2(
|
||||
[Api.inference],
|
||||
{"inference": inference_fixture.providers},
|
||||
inference_fixture.provider_data,
|
||||
)
|
||||
|
||||
return (impls[Api.inference], impls[Api.models])
|
|
@ -1,28 +0,0 @@
|
|||
providers:
|
||||
- provider_id: test-ollama
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 11434
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
model: Llama3.2-1B-Instruct
|
||||
- provider_id: test-tgi
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: http://localhost:7001
|
||||
- provider_id: test-remote
|
||||
provider_type: remote
|
||||
config:
|
||||
host: localhost
|
||||
port: 7002
|
||||
- provider_id: test-together
|
||||
provider_type: remote::together
|
||||
config: {}
|
||||
# if a provider needs private keys from the client, they use the
|
||||
# "get_request_provider_data" function (see distribution/request_headers.py)
|
||||
# this is a place to provide such data.
|
||||
provider_data:
|
||||
"test-together":
|
||||
together_api_key: 0xdeadbeefputrealapikeyhere
|
|
@ -5,10 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
|
@ -16,24 +14,12 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
|||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
|
||||
# since it depends on the provider you are testing. On top of that you need
|
||||
# `pytest` and `pytest-asyncio` installed.
|
||||
#
|
||||
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
||||
#
|
||||
# 3. Run:
|
||||
#
|
||||
# ```bash
|
||||
# PROVIDER_ID=<your_provider> \
|
||||
# PROVIDER_CONFIG=provider_config.yaml \
|
||||
# pytest -s llama_stack/providers/tests/inference/test_inference.py \
|
||||
# --tb=short --disable-warnings
|
||||
# ```
|
||||
# pytest -v -s llama_stack/providers/tests/inference/test_inference.py
|
||||
# -m "(fireworks or ollama) and llama_3b"
|
||||
# --env FIREWORKS_API_KEY=<your_api_key>
|
||||
|
||||
|
||||
def group_chunks(response):
|
||||
|
@ -45,45 +31,19 @@ def group_chunks(response):
|
|||
}
|
||||
|
||||
|
||||
Llama_8B = "Llama3.1-8B-Instruct"
|
||||
Llama_3B = "Llama3.2-3B-Instruct"
|
||||
|
||||
|
||||
def get_expected_stop_reason(model: str):
|
||||
return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
|
||||
|
||||
|
||||
if "MODEL_IDS" not in os.environ:
|
||||
MODEL_IDS = [Llama_8B, Llama_3B]
|
||||
else:
|
||||
MODEL_IDS = os.environ["MODEL_IDS"].split(",")
|
||||
|
||||
|
||||
# This is going to create multiple Stack impls without tearing down the previous one
|
||||
# Fix that!
|
||||
@pytest_asyncio.fixture(
|
||||
scope="session",
|
||||
params=[{"model": m} for m in MODEL_IDS],
|
||||
ids=lambda d: d["model"],
|
||||
)
|
||||
async def inference_settings(request):
|
||||
model = request.param["model"]
|
||||
impls = await resolve_impls_for_test(
|
||||
Api.inference,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def common_params(inference_model):
|
||||
return {
|
||||
"impl": impls[Api.inference],
|
||||
"models_impl": impls[Api.models],
|
||||
"common_params": {
|
||||
"model": model,
|
||||
"tool_choice": ToolChoice.auto,
|
||||
"tool_prompt_format": (
|
||||
ToolPromptFormat.json
|
||||
if "Llama3.1" in model
|
||||
else ToolPromptFormat.python_list
|
||||
),
|
||||
},
|
||||
"tool_choice": ToolChoice.auto,
|
||||
"tool_prompt_format": (
|
||||
ToolPromptFormat.json
|
||||
if "Llama3.1" in inference_model
|
||||
else ToolPromptFormat.python_list
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
@ -109,301 +69,309 @@ def sample_tool_definition():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_list(inference_settings):
|
||||
params = inference_settings["common_params"]
|
||||
models_impl = inference_settings["models_impl"]
|
||||
response = await models_impl.list_models()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) >= 1
|
||||
assert all(isinstance(model, ModelDefWithProvider) for model in response)
|
||||
class TestInference:
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_list(self, inference_model, inference_stack):
|
||||
_, models_impl = inference_stack
|
||||
response = await models_impl.list_models()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) >= 1
|
||||
assert all(isinstance(model, ModelDefWithProvider) for model in response)
|
||||
|
||||
model_def = None
|
||||
for model in response:
|
||||
if model.identifier == params["model"]:
|
||||
model_def = model
|
||||
break
|
||||
model_def = None
|
||||
for model in response:
|
||||
if model.identifier == inference_model:
|
||||
model_def = model
|
||||
break
|
||||
|
||||
assert model_def is not None
|
||||
assert model_def.identifier == params["model"]
|
||||
assert model_def is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion(self, inference_model, inference_stack):
|
||||
inference_impl, _ = inference_stack
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion(inference_settings):
|
||||
inference_impl = inference_settings["impl"]
|
||||
params = inference_settings["common_params"]
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::ollama",
|
||||
"remote::tgi",
|
||||
"remote::together",
|
||||
"remote::fireworks",
|
||||
):
|
||||
pytest.skip("Other inference providers don't support completion() yet")
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(params["model"])
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::ollama",
|
||||
"remote::tgi",
|
||||
"remote::together",
|
||||
"remote::fireworks",
|
||||
):
|
||||
pytest.skip("Other inference providers don't support completion() yet")
|
||||
|
||||
response = await inference_impl.completion(
|
||||
content="Micheael Jordan is born in ",
|
||||
stream=False,
|
||||
model=params["model"],
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=50,
|
||||
),
|
||||
)
|
||||
|
||||
assert isinstance(response, CompletionResponse)
|
||||
assert "1963" in response.content
|
||||
|
||||
chunks = [
|
||||
r
|
||||
async for r in await inference_impl.completion(
|
||||
content="Roses are red,",
|
||||
stream=True,
|
||||
model=params["model"],
|
||||
response = await inference_impl.completion(
|
||||
content="Micheael Jordan is born in ",
|
||||
stream=False,
|
||||
model=inference_model,
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=50,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
|
||||
assert len(chunks) >= 1
|
||||
last = chunks[-1]
|
||||
assert last.stop_reason == StopReason.out_of_tokens
|
||||
assert isinstance(response, CompletionResponse)
|
||||
assert "1963" in response.content
|
||||
|
||||
chunks = [
|
||||
r
|
||||
async for r in await inference_impl.completion(
|
||||
content="Roses are red,",
|
||||
stream=True,
|
||||
model=inference_model,
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=50,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip("This test is not quite robust")
|
||||
async def test_completions_structured_output(inference_settings):
|
||||
inference_impl = inference_settings["impl"]
|
||||
params = inference_settings["common_params"]
|
||||
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
|
||||
assert len(chunks) >= 1
|
||||
last = chunks[-1]
|
||||
assert last.stop_reason == StopReason.out_of_tokens
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(params["model"])
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::tgi",
|
||||
"remote::together",
|
||||
"remote::fireworks",
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip("This test is not quite robust")
|
||||
async def test_completions_structured_output(
|
||||
self, inference_model, inference_stack
|
||||
):
|
||||
pytest.skip(
|
||||
"Other inference providers don't support structured output in completions yet"
|
||||
inference_impl, _ = inference_stack
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::tgi",
|
||||
"remote::together",
|
||||
"remote::fireworks",
|
||||
):
|
||||
pytest.skip(
|
||||
"Other inference providers don't support structured output in completions yet"
|
||||
)
|
||||
|
||||
class Output(BaseModel):
|
||||
name: str
|
||||
year_born: str
|
||||
year_retired: str
|
||||
|
||||
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
|
||||
response = await inference_impl.completion(
|
||||
content=user_input,
|
||||
stream=False,
|
||||
model=inference_model,
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=50,
|
||||
),
|
||||
response_format=JsonSchemaResponseFormat(
|
||||
json_schema=Output.model_json_schema(),
|
||||
),
|
||||
)
|
||||
assert isinstance(response, CompletionResponse)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
class Output(BaseModel):
|
||||
name: str
|
||||
year_born: str
|
||||
year_retired: str
|
||||
answer = Output.model_validate_json(response.content)
|
||||
assert answer.name == "Michael Jordan"
|
||||
assert answer.year_born == "1963"
|
||||
assert answer.year_retired == "2003"
|
||||
|
||||
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
|
||||
response = await inference_impl.completion(
|
||||
content=user_input,
|
||||
stream=False,
|
||||
model=params["model"],
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=50,
|
||||
),
|
||||
response_format=JsonSchemaResponseFormat(
|
||||
json_schema=Output.model_json_schema(),
|
||||
),
|
||||
)
|
||||
assert isinstance(response, CompletionResponse)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
answer = Output.parse_raw(response.content)
|
||||
assert answer.name == "Michael Jordan"
|
||||
assert answer.year_born == "1963"
|
||||
assert answer.year_retired == "2003"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
|
||||
inference_impl = inference_settings["impl"]
|
||||
response = await inference_impl.chat_completion(
|
||||
messages=sample_messages,
|
||||
stream=False,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
assert response.completion_message.role == "assistant"
|
||||
assert isinstance(response.completion_message.content, str)
|
||||
assert len(response.completion_message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structured_output(inference_settings):
|
||||
inference_impl = inference_settings["impl"]
|
||||
params = inference_settings["common_params"]
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(params["model"])
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::fireworks",
|
||||
"remote::tgi",
|
||||
"remote::together",
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_non_streaming(
|
||||
self, inference_model, inference_stack, common_params, sample_messages
|
||||
):
|
||||
pytest.skip("Other inference providers don't support structured output yet")
|
||||
|
||||
class AnswerFormat(BaseModel):
|
||||
first_name: str
|
||||
last_name: str
|
||||
year_of_birth: int
|
||||
num_seasons_in_nba: int
|
||||
|
||||
response = await inference_impl.chat_completion(
|
||||
messages=[
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(content="Please give me information about Michael Jordan."),
|
||||
],
|
||||
stream=False,
|
||||
response_format=JsonSchemaResponseFormat(
|
||||
json_schema=AnswerFormat.model_json_schema(),
|
||||
),
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
assert response.completion_message.role == "assistant"
|
||||
assert isinstance(response.completion_message.content, str)
|
||||
|
||||
answer = AnswerFormat.parse_raw(response.completion_message.content)
|
||||
assert answer.first_name == "Michael"
|
||||
assert answer.last_name == "Jordan"
|
||||
assert answer.year_of_birth == 1963
|
||||
assert answer.num_seasons_in_nba == 15
|
||||
|
||||
response = await inference_impl.chat_completion(
|
||||
messages=[
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(content="Please give me information about Michael Jordan."),
|
||||
],
|
||||
stream=False,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
assert isinstance(response.completion_message.content, str)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AnswerFormat.parse_raw(response.completion_message.content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_streaming(inference_settings, sample_messages):
|
||||
inference_impl = inference_settings["impl"]
|
||||
response = [
|
||||
r
|
||||
async for r in await inference_impl.chat_completion(
|
||||
inference_impl, _ = inference_stack
|
||||
response = await inference_impl.chat_completion(
|
||||
model=inference_model,
|
||||
messages=sample_messages,
|
||||
stream=True,
|
||||
**inference_settings["common_params"],
|
||||
stream=False,
|
||||
**common_params,
|
||||
)
|
||||
]
|
||||
|
||||
assert len(response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
||||
)
|
||||
grouped = group_chunks(response)
|
||||
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
||||
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
||||
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
assert response.completion_message.role == "assistant"
|
||||
assert isinstance(response.completion_message.content, str)
|
||||
assert len(response.completion_message.content) > 0
|
||||
|
||||
end = grouped[ChatCompletionResponseEventType.complete][0]
|
||||
assert end.event.stop_reason == StopReason.end_of_turn
|
||||
@pytest.mark.asyncio
|
||||
async def test_structured_output(
|
||||
self, inference_model, inference_stack, common_params
|
||||
):
|
||||
inference_impl, _ = inference_stack
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::fireworks",
|
||||
"remote::tgi",
|
||||
"remote::together",
|
||||
):
|
||||
pytest.skip("Other inference providers don't support structured output yet")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tool_calling(
|
||||
inference_settings,
|
||||
sample_messages,
|
||||
sample_tool_definition,
|
||||
):
|
||||
inference_impl = inference_settings["impl"]
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
class AnswerFormat(BaseModel):
|
||||
first_name: str
|
||||
last_name: str
|
||||
year_of_birth: int
|
||||
num_seasons_in_nba: int
|
||||
|
||||
response = await inference_impl.chat_completion(
|
||||
model=inference_model,
|
||||
messages=[
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(content="Please give me information about Michael Jordan."),
|
||||
],
|
||||
stream=False,
|
||||
response_format=JsonSchemaResponseFormat(
|
||||
json_schema=AnswerFormat.model_json_schema(),
|
||||
),
|
||||
**common_params,
|
||||
)
|
||||
]
|
||||
|
||||
response = await inference_impl.chat_completion(
|
||||
messages=messages,
|
||||
tools=[sample_tool_definition],
|
||||
stream=False,
|
||||
**inference_settings["common_params"],
|
||||
)
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
assert response.completion_message.role == "assistant"
|
||||
assert isinstance(response.completion_message.content, str)
|
||||
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
answer = AnswerFormat.model_validate_json(response.completion_message.content)
|
||||
assert answer.first_name == "Michael"
|
||||
assert answer.last_name == "Jordan"
|
||||
assert answer.year_of_birth == 1963
|
||||
assert answer.num_seasons_in_nba == 15
|
||||
|
||||
message = response.completion_message
|
||||
|
||||
# This is not supported in most providers :/ they don't return eom_id / eot_id
|
||||
# stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
|
||||
# assert message.stop_reason == stop_reason
|
||||
assert message.tool_calls is not None
|
||||
assert len(message.tool_calls) > 0
|
||||
|
||||
call = message.tool_calls[0]
|
||||
assert call.tool_name == "get_weather"
|
||||
assert "location" in call.arguments
|
||||
assert "San Francisco" in call.arguments["location"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tool_calling_streaming(
|
||||
inference_settings,
|
||||
sample_messages,
|
||||
sample_tool_definition,
|
||||
):
|
||||
inference_impl = inference_settings["impl"]
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
response = await inference_impl.chat_completion(
|
||||
model=inference_model,
|
||||
messages=[
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(content="Please give me information about Michael Jordan."),
|
||||
],
|
||||
stream=False,
|
||||
**common_params,
|
||||
)
|
||||
]
|
||||
|
||||
response = [
|
||||
r
|
||||
async for r in await inference_impl.chat_completion(
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
assert isinstance(response.completion_message.content, str)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AnswerFormat.model_validate_json(response.completion_message.content)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_streaming(
|
||||
self, inference_model, inference_stack, common_params, sample_messages
|
||||
):
|
||||
inference_impl, _ = inference_stack
|
||||
response = [
|
||||
r
|
||||
async for r in await inference_impl.chat_completion(
|
||||
model=inference_model,
|
||||
messages=sample_messages,
|
||||
stream=True,
|
||||
**common_params,
|
||||
)
|
||||
]
|
||||
|
||||
assert len(response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
||||
)
|
||||
grouped = group_chunks(response)
|
||||
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
||||
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
||||
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
||||
|
||||
end = grouped[ChatCompletionResponseEventType.complete][0]
|
||||
assert end.event.stop_reason == StopReason.end_of_turn
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tool_calling(
|
||||
self,
|
||||
inference_model,
|
||||
inference_stack,
|
||||
common_params,
|
||||
sample_messages,
|
||||
sample_tool_definition,
|
||||
):
|
||||
inference_impl, _ = inference_stack
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
)
|
||||
]
|
||||
|
||||
response = await inference_impl.chat_completion(
|
||||
model=inference_model,
|
||||
messages=messages,
|
||||
tools=[sample_tool_definition],
|
||||
stream=True,
|
||||
**inference_settings["common_params"],
|
||||
stream=False,
|
||||
**common_params,
|
||||
)
|
||||
]
|
||||
|
||||
assert len(response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
||||
)
|
||||
grouped = group_chunks(response)
|
||||
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
||||
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
||||
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
|
||||
# This is not supported in most providers :/ they don't return eom_id / eot_id
|
||||
# expected_stop_reason = get_expected_stop_reason(
|
||||
# inference_settings["common_params"]["model"]
|
||||
# )
|
||||
# end = grouped[ChatCompletionResponseEventType.complete][0]
|
||||
# assert end.event.stop_reason == expected_stop_reason
|
||||
message = response.completion_message
|
||||
|
||||
model = inference_settings["common_params"]["model"]
|
||||
if "Llama3.1" in model:
|
||||
# This is not supported in most providers :/ they don't return eom_id / eot_id
|
||||
# stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
|
||||
# assert message.stop_reason == stop_reason
|
||||
assert message.tool_calls is not None
|
||||
assert len(message.tool_calls) > 0
|
||||
|
||||
call = message.tool_calls[0]
|
||||
assert call.tool_name == "get_weather"
|
||||
assert "location" in call.arguments
|
||||
assert "San Francisco" in call.arguments["location"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tool_calling_streaming(
|
||||
self,
|
||||
inference_model,
|
||||
inference_stack,
|
||||
common_params,
|
||||
sample_messages,
|
||||
sample_tool_definition,
|
||||
):
|
||||
inference_impl, _ = inference_stack
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
)
|
||||
]
|
||||
|
||||
response = [
|
||||
r
|
||||
async for r in await inference_impl.chat_completion(
|
||||
model=inference_model,
|
||||
messages=messages,
|
||||
tools=[sample_tool_definition],
|
||||
stream=True,
|
||||
**common_params,
|
||||
)
|
||||
]
|
||||
|
||||
assert len(response) > 0
|
||||
assert all(
|
||||
isinstance(chunk.event.delta, ToolCallDelta)
|
||||
for chunk in grouped[ChatCompletionResponseEventType.progress]
|
||||
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
||||
)
|
||||
first = grouped[ChatCompletionResponseEventType.progress][0]
|
||||
assert first.event.delta.parse_status == ToolCallParseStatus.started
|
||||
grouped = group_chunks(response)
|
||||
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
||||
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
||||
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
||||
|
||||
last = grouped[ChatCompletionResponseEventType.progress][-1]
|
||||
# assert last.event.stop_reason == expected_stop_reason
|
||||
assert last.event.delta.parse_status == ToolCallParseStatus.success
|
||||
assert isinstance(last.event.delta.content, ToolCall)
|
||||
# This is not supported in most providers :/ they don't return eom_id / eot_id
|
||||
# expected_stop_reason = get_expected_stop_reason(
|
||||
# inference_settings["common_params"]["model"]
|
||||
# )
|
||||
# end = grouped[ChatCompletionResponseEventType.complete][0]
|
||||
# assert end.event.stop_reason == expected_stop_reason
|
||||
|
||||
call = last.event.delta.content
|
||||
assert call.tool_name == "get_weather"
|
||||
assert "location" in call.arguments
|
||||
assert "San Francisco" in call.arguments["location"]
|
||||
if "Llama3.1" in inference_model:
|
||||
assert all(
|
||||
isinstance(chunk.event.delta, ToolCallDelta)
|
||||
for chunk in grouped[ChatCompletionResponseEventType.progress]
|
||||
)
|
||||
first = grouped[ChatCompletionResponseEventType.progress][0]
|
||||
assert first.event.delta.parse_status == ToolCallParseStatus.started
|
||||
|
||||
last = grouped[ChatCompletionResponseEventType.progress][-1]
|
||||
# assert last.event.stop_reason == expected_stop_reason
|
||||
assert last.event.delta.parse_status == ToolCallParseStatus.success
|
||||
assert isinstance(last.event.delta.content, ToolCall)
|
||||
|
||||
call = last.event.delta.content
|
||||
assert call.tool_name == "get_weather"
|
||||
assert "location" in call.arguments
|
||||
assert "San Francisco" in call.arguments["location"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue