Significantly simpler and malleable test setup

Ashwin Bharambe 2024-11-01 13:07:59 -07:00 committed by Ashwin Bharambe
parent c9bf1d7d0b
commit bba6717ef5
7 changed files with 511 additions and 339 deletions

.gitignore

@@ -15,5 +15,5 @@ Package.resolved
 *.ipynb_checkpoints*
 .idea
 .venv/
-.idea
+.vscode
 _build

llama_stack/providers/tests/conftest.py (new file)

@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from pathlib import Path

from dotenv import load_dotenv
from termcolor import colored


def pytest_configure(config):
    """Load environment variables at start of test run"""
    # Load from .env file if it exists
    env_file = Path(__file__).parent / ".env"
    if env_file.exists():
        load_dotenv(env_file)

    # Load any environment variables passed via --env
    env_vars = config.getoption("--env") or []
    for env_var in env_vars:
        key, value = env_var.split("=", 1)
        os.environ[key] = value


def pytest_addoption(parser):
    """Add custom command line options"""
    parser.addoption(
        "--env", action="append", help="Set environment variables, e.g. --env KEY=value"
    )


def pytest_itemcollected(item):
    # Get all markers as a list
    filtered = ("asyncio", "parametrize")
    marks = [mark.name for mark in item.iter_markers() if mark.name not in filtered]
    if marks:
        marks = colored(",".join(marks), "yellow")
        item.name = f"{item.name}[{marks}]"
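A quick sketch of how a `--env` flag flows through `pytest_configure` into the process environment (the command line and key below are illustrative, not from this commit):

```python
# Illustrative invocation:
#   pytest llama_stack/providers/tests --env TOGETHER_API_KEY=abc123
#
# pytest_configure splits each KEY=value pair only on the first "=",
# so values that themselves contain "=" survive intact:
key, value = "TOGETHER_API_KEY=abc=123".split("=", 1)
assert (key, value) == ("TOGETHER_API_KEY", "abc=123")
```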

llama_stack/providers/tests/env.py (new file)

@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os


class MissingCredentialError(Exception):
    pass


def get_env_or_fail(key: str) -> str:
    """Get environment variable or raise helpful error"""
    value = os.getenv(key)
    if not value:
        raise MissingCredentialError(
            f"\nMissing {key} in environment. Please set it using one of these methods:"
            f"\n1. Export in shell: export {key}=your-key"
            f"\n2. Create .env file in project root with: {key}=your-key"
            f"\n3. Pass directly to pytest: pytest --env {key}=your-key"
        )
    return value
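A usage sketch for `get_env_or_fail` (the key and value are made up for illustration):

```python
import os

os.environ["FIREWORKS_API_KEY"] = "fw-example-key"  # example value only
assert get_env_or_fail("FIREWORKS_API_KEY") == "fw-example-key"

del os.environ["FIREWORKS_API_KEY"]
try:
    get_env_or_fail("FIREWORKS_API_KEY")
except MissingCredentialError as err:
    print(err)  # lists the export / .env / --env remedies
```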

llama_stack/providers/tests/inference/conftest.py (new file)

@@ -0,0 +1,139 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Any, Dict, Tuple

import pytest
import pytest_asyncio

from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.adapters.inference.fireworks import FireworksImplConfig
from llama_stack.providers.adapters.inference.ollama import OllamaImplConfig
from llama_stack.providers.adapters.inference.together import TogetherImplConfig
from llama_stack.providers.impls.meta_reference.inference import (
    MetaReferenceInferenceConfig,
)
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2

from ..env import get_env_or_fail

MODEL_PARAMS = [
    pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
    pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
]


@pytest.fixture(scope="session", params=MODEL_PARAMS)
def llama_model(request):
    return request.param


@pytest.fixture(scope="session")
def meta_reference(llama_model) -> Provider:
    return Provider(
        provider_id="meta-reference",
        provider_type="meta-reference",
        config=MetaReferenceInferenceConfig(
            model=llama_model,
            max_seq_len=512,
            create_distributed_process_group=False,
            checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
        ).model_dump(),
    )


@pytest.fixture(scope="session")
def ollama(llama_model) -> Provider:
    if llama_model == "Llama3.1-8B-Instruct":
        pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
    return Provider(
        provider_id="ollama",
        provider_type="remote::ollama",
        config=(
            OllamaImplConfig(
                # os.getenv returns a string; be explicit about the int port
                host="localhost", port=int(os.getenv("OLLAMA_PORT", 11434))
            ).model_dump()
        ),
    )


@pytest.fixture(scope="session")
def fireworks(llama_model) -> Provider:
    return Provider(
        provider_id="fireworks",
        provider_type="remote::fireworks",
        config=FireworksImplConfig(
            api_key=get_env_or_fail("FIREWORKS_API_KEY"),
        ).model_dump(),
    )


@pytest.fixture(scope="session")
def together(llama_model) -> Tuple[Provider, Dict[str, Any]]:
    provider = Provider(
        provider_id="together",
        provider_type="remote::together",
        config=TogetherImplConfig().model_dump(),
    )
    return provider, dict(
        together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
    )


PROVIDER_PARAMS = [
    pytest.param("meta_reference", marks=pytest.mark.meta_reference),
    pytest.param("ollama", marks=pytest.mark.ollama),
    pytest.param("fireworks", marks=pytest.mark.fireworks),
    pytest.param("together", marks=pytest.mark.together),
]


@pytest_asyncio.fixture(
    scope="session",
    params=PROVIDER_PARAMS,
)
async def stack_impls(request):
    provider_fixture = request.param
    provider = request.getfixturevalue(provider_fixture)
    if isinstance(provider, tuple):
        provider, provider_data = provider
    else:
        provider_data = None

    impls = await resolve_impls_for_test_v2(
        [Api.inference],
        {"inference": [provider.model_dump()]},
        provider_data,
    )
    return (impls[Api.inference], impls[Api.models])


def pytest_configure(config):
    config.addinivalue_line(
        "markers", "llama_8b: mark test to run only with the given model"
    )
    config.addinivalue_line(
        "markers", "llama_3b: mark test to run only with the given model"
    )
    config.addinivalue_line(
        "markers",
        "meta_reference: marks tests as metaref specific",
    )
    config.addinivalue_line(
        "markers",
        "ollama: marks tests as ollama specific",
    )
    config.addinivalue_line(
        "markers",
        "fireworks: marks tests as fireworks specific",
    )
    config.addinivalue_line(
        "markers",
        "together: marks tests as together specific",
    )
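Because models and providers are both parametrized via markers, a run can be narrowed with ordinary `-m` expressions; for example (mirroring the usage comment in test_inference.py below, API key placeholder included):

```bash
# Run only the 3B model against Fireworks or Ollama:
pytest llama_stack/providers/tests/inference/test_inference.py \
    -m "(fireworks or ollama) and llama_3b" \
    -v -s --tb=short --disable-warnings \
    --env FIREWORKS_API_KEY=<your_api_key>
```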

llama_stack/providers/tests/inference/provider_config_example.yaml (deleted)

@@ -1,28 +0,0 @@
providers:
  - provider_id: test-ollama
    provider_type: remote::ollama
    config:
      host: localhost
      port: 11434
  - provider_id: meta-reference
    provider_type: meta-reference
    config:
      model: Llama3.2-1B-Instruct
  - provider_id: test-tgi
    provider_type: remote::tgi
    config:
      url: http://localhost:7001
  - provider_id: test-remote
    provider_type: remote
    config:
      host: localhost
      port: 7002
  - provider_id: test-together
    provider_type: remote::together
    config: {}

# if a provider needs private keys from the client, they use the
# "get_request_provider_data" function (see distribution/request_headers.py)
# this is a place to provide such data.
provider_data:
  "test-together":
    together_api_key: 0xdeadbeefputrealapikeyhere
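The static `provider_data` block above is superseded by the fixture flow: `resolve_impls_for_test_v2` (see resolver.py at the end of this diff) serializes the same payload into a request header. A minimal sketch of the equivalence:

```python
import json

# What the deleted YAML expressed statically is now passed at resolve time:
provider_data = {"together_api_key": "0xdeadbeefputrealapikeyhere"}
headers = {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
```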

llama_stack/providers/tests/inference/test_inference.py

@@ -5,10 +5,8 @@
# the root directory of this source tree.
import itertools
import os
import pytest
import pytest_asyncio
from pydantic import BaseModel, ValidationError
@@ -16,24 +14,14 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.providers.tests.resolver import resolve_impls_for_test
from .conftest import MODEL_PARAMS, PROVIDER_PARAMS
# How to run this test:
#
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
# since it depends on the provider you are testing. On top of that you need
# `pytest` and `pytest-asyncio` installed.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/inference/test_inference.py \
# --tb=short --disable-warnings
# ```
# pytest llama_stack/providers/tests/inference/test_inference.py
# -m "(fireworks or ollama) and llama_3b"
# -v -s --tb=short --disable-warnings
# --env FIREWORKS_API_KEY=<your_api_key>
def group_chunks(response):
@@ -45,45 +33,19 @@ def group_chunks(response):
}
Llama_8B = "Llama3.1-8B-Instruct"
Llama_3B = "Llama3.2-3B-Instruct"
def get_expected_stop_reason(model: str):
return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
if "MODEL_IDS" not in os.environ:
MODEL_IDS = [Llama_8B, Llama_3B]
else:
MODEL_IDS = os.environ["MODEL_IDS"].split(",")
# This is going to create multiple Stack impls without tearing down the previous one
# Fix that!
@pytest_asyncio.fixture(
scope="session",
params=[{"model": m} for m in MODEL_IDS],
ids=lambda d: d["model"],
)
async def inference_settings(request):
model = request.param["model"]
impls = await resolve_impls_for_test(
Api.inference,
)
@pytest.fixture
def common_params(llama_model):
return {
"impl": impls[Api.inference],
"models_impl": impls[Api.models],
"common_params": {
"model": model,
"tool_choice": ToolChoice.auto,
"tool_prompt_format": (
ToolPromptFormat.json
if "Llama3.1" in model
else ToolPromptFormat.python_list
),
},
"tool_choice": ToolChoice.auto,
"tool_prompt_format": (
ToolPromptFormat.json
if "Llama3.1" in llama_model
else ToolPromptFormat.python_list
),
}
@@ -109,301 +71,313 @@ def sample_tool_definition():
)
@pytest.mark.asyncio
async def test_model_list(inference_settings):
params = inference_settings["common_params"]
models_impl = inference_settings["models_impl"]
response = await models_impl.list_models()
assert isinstance(response, list)
assert len(response) >= 1
assert all(isinstance(model, ModelDefWithProvider) for model in response)
@pytest.mark.parametrize("llama_model", MODEL_PARAMS, indirect=True)
@pytest.mark.parametrize(
"stack_impls",
PROVIDER_PARAMS,
indirect=True,
)
class TestInference:
@pytest.mark.asyncio
async def test_model_list(self, llama_model, stack_impls):
_, models_impl = stack_impls
response = await models_impl.list_models()
assert isinstance(response, list)
assert len(response) >= 1
assert all(isinstance(model, ModelDefWithProvider) for model in response)
model_def = None
for model in response:
if model.identifier == params["model"]:
model_def = model
break
model_def = None
for model in response:
if model.identifier == llama_model:
model_def = model
break
assert model_def is not None
assert model_def.identifier == params["model"]
assert model_def is not None
@pytest.mark.asyncio
async def test_completion(self, llama_model, stack_impls, common_params):
inference_impl, _ = stack_impls
@pytest.mark.asyncio
async def test_completion(inference_settings):
inference_impl = inference_settings["impl"]
params = inference_settings["common_params"]
provider = inference_impl.routing_table.get_provider_impl(llama_model)
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::ollama",
"remote::tgi",
"remote::together",
"remote::fireworks",
):
pytest.skip("Other inference providers don't support completion() yet")
provider = inference_impl.routing_table.get_provider_impl(params["model"])
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::ollama",
"remote::tgi",
"remote::together",
"remote::fireworks",
):
pytest.skip("Other inference providers don't support completion() yet")
response = await inference_impl.completion(
content="Micheael Jordan is born in ",
stream=False,
model=params["model"],
sampling_params=SamplingParams(
max_tokens=50,
),
)
assert isinstance(response, CompletionResponse)
assert "1963" in response.content
chunks = [
r
async for r in await inference_impl.completion(
content="Roses are red,",
stream=True,
model=params["model"],
response = await inference_impl.completion(
content="Micheael Jordan is born in ",
stream=False,
model=llama_model,
sampling_params=SamplingParams(
max_tokens=50,
),
)
]
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
assert len(chunks) >= 1
last = chunks[-1]
assert last.stop_reason == StopReason.out_of_tokens
assert isinstance(response, CompletionResponse)
assert "1963" in response.content
chunks = [
r
async for r in await inference_impl.completion(
content="Roses are red,",
stream=True,
model=llama_model,
sampling_params=SamplingParams(
max_tokens=50,
),
)
]
@pytest.mark.asyncio
@pytest.mark.skip("This test is not quite robust")
async def test_completions_structured_output(inference_settings):
inference_impl = inference_settings["impl"]
params = inference_settings["common_params"]
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
assert len(chunks) >= 1
last = chunks[-1]
assert last.stop_reason == StopReason.out_of_tokens
provider = inference_impl.routing_table.get_provider_impl(params["model"])
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::tgi",
"remote::together",
"remote::fireworks",
@pytest.mark.asyncio
@pytest.mark.skip("This test is not quite robust")
async def test_completions_structured_output(
self, llama_model, stack_impls, common_params
):
pytest.skip(
"Other inference providers don't support structured output in completions yet"
inference_impl, _ = stack_impls
provider = inference_impl.routing_table.get_provider_impl(llama_model)
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::tgi",
"remote::together",
"remote::fireworks",
):
pytest.skip(
"Other inference providers don't support structured output in completions yet"
)
class Output(BaseModel):
name: str
year_born: str
year_retired: str
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
response = await inference_impl.completion(
content=user_input,
stream=False,
model=llama_model,
sampling_params=SamplingParams(
max_tokens=50,
),
response_format=JsonSchemaResponseFormat(
json_schema=Output.model_json_schema(),
),
)
assert isinstance(response, CompletionResponse)
assert isinstance(response.content, str)
class Output(BaseModel):
name: str
year_born: str
year_retired: str
answer = Output.model_validate_json(response.content)
assert answer.name == "Michael Jordan"
assert answer.year_born == "1963"
assert answer.year_retired == "2003"
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
response = await inference_impl.completion(
content=user_input,
stream=False,
model=params["model"],
sampling_params=SamplingParams(
max_tokens=50,
),
response_format=JsonSchemaResponseFormat(
json_schema=Output.model_json_schema(),
),
)
assert isinstance(response, CompletionResponse)
assert isinstance(response.content, str)
answer = Output.parse_raw(response.content)
assert answer.name == "Michael Jordan"
assert answer.year_born == "1963"
assert answer.year_retired == "2003"
@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]
response = await inference_impl.chat_completion(
messages=sample_messages,
stream=False,
**inference_settings["common_params"],
)
assert isinstance(response, ChatCompletionResponse)
assert response.completion_message.role == "assistant"
assert isinstance(response.completion_message.content, str)
assert len(response.completion_message.content) > 0
@pytest.mark.asyncio
async def test_structured_output(inference_settings):
inference_impl = inference_settings["impl"]
params = inference_settings["common_params"]
provider = inference_impl.routing_table.get_provider_impl(params["model"])
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::fireworks",
"remote::tgi",
"remote::together",
@pytest.mark.asyncio
async def test_chat_completion_non_streaming(
self, llama_model, stack_impls, common_params, sample_messages
):
pytest.skip("Other inference providers don't support structured output yet")
class AnswerFormat(BaseModel):
first_name: str
last_name: str
year_of_birth: int
num_seasons_in_nba: int
response = await inference_impl.chat_completion(
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Please give me information about Michael Jordan."),
],
stream=False,
response_format=JsonSchemaResponseFormat(
json_schema=AnswerFormat.model_json_schema(),
),
**inference_settings["common_params"],
)
assert isinstance(response, ChatCompletionResponse)
assert response.completion_message.role == "assistant"
assert isinstance(response.completion_message.content, str)
answer = AnswerFormat.parse_raw(response.completion_message.content)
assert answer.first_name == "Michael"
assert answer.last_name == "Jordan"
assert answer.year_of_birth == 1963
assert answer.num_seasons_in_nba == 15
response = await inference_impl.chat_completion(
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Please give me information about Michael Jordan."),
],
stream=False,
**inference_settings["common_params"],
)
assert isinstance(response, ChatCompletionResponse)
assert isinstance(response.completion_message.content, str)
with pytest.raises(ValidationError):
AnswerFormat.parse_raw(response.completion_message.content)
@pytest.mark.asyncio
async def test_chat_completion_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]
response = [
r
async for r in await inference_impl.chat_completion(
inference_impl, _ = stack_impls
response = await inference_impl.chat_completion(
model=llama_model,
messages=sample_messages,
stream=True,
**inference_settings["common_params"],
stream=False,
**common_params,
)
]
assert len(response) > 0
assert all(
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
)
grouped = group_chunks(response)
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
assert isinstance(response, ChatCompletionResponse)
assert response.completion_message.role == "assistant"
assert isinstance(response.completion_message.content, str)
assert len(response.completion_message.content) > 0
end = grouped[ChatCompletionResponseEventType.complete][0]
assert end.event.stop_reason == StopReason.end_of_turn
@pytest.mark.asyncio
async def test_structured_output(self, llama_model, stack_impls, common_params):
inference_impl, _ = stack_impls
provider = inference_impl.routing_table.get_provider_impl(llama_model)
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::fireworks",
"remote::tgi",
"remote::together",
):
pytest.skip("Other inference providers don't support structured output yet")
@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling(
inference_settings,
sample_messages,
sample_tool_definition,
):
inference_impl = inference_settings["impl"]
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
class AnswerFormat(BaseModel):
first_name: str
last_name: str
year_of_birth: int
num_seasons_in_nba: int
response = await inference_impl.chat_completion(
model=llama_model,
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Please give me information about Michael Jordan."),
],
stream=False,
response_format=JsonSchemaResponseFormat(
json_schema=AnswerFormat.model_json_schema(),
),
**common_params,
)
]
response = await inference_impl.chat_completion(
messages=messages,
tools=[sample_tool_definition],
stream=False,
**inference_settings["common_params"],
)
assert isinstance(response, ChatCompletionResponse)
assert response.completion_message.role == "assistant"
assert isinstance(response.completion_message.content, str)
assert isinstance(response, ChatCompletionResponse)
answer = AnswerFormat.model_validate_json(response.completion_message.content)
assert answer.first_name == "Michael"
assert answer.last_name == "Jordan"
assert answer.year_of_birth == 1963
assert answer.num_seasons_in_nba == 15
message = response.completion_message
# This is not supported in most providers :/ they don't return eom_id / eot_id
# stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
# assert message.stop_reason == stop_reason
assert message.tool_calls is not None
assert len(message.tool_calls) > 0
call = message.tool_calls[0]
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]
@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling_streaming(
inference_settings,
sample_messages,
sample_tool_definition,
):
inference_impl = inference_settings["impl"]
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
response = await inference_impl.chat_completion(
model=llama_model,
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Please give me information about Michael Jordan."),
],
stream=False,
**common_params,
)
]
response = [
r
async for r in await inference_impl.chat_completion(
assert isinstance(response, ChatCompletionResponse)
assert isinstance(response.completion_message.content, str)
with pytest.raises(ValidationError):
AnswerFormat.model_validate_json(response.completion_message.content)
@pytest.mark.asyncio
async def test_chat_completion_streaming(
self, llama_model, stack_impls, common_params, sample_messages
):
inference_impl, _ = stack_impls
response = [
r
async for r in await inference_impl.chat_completion(
model=llama_model,
messages=sample_messages,
stream=True,
**common_params,
)
]
assert len(response) > 0
assert all(
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
)
grouped = group_chunks(response)
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
end = grouped[ChatCompletionResponseEventType.complete][0]
assert end.event.stop_reason == StopReason.end_of_turn
@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling(
self,
llama_model,
stack_impls,
common_params,
sample_messages,
sample_tool_definition,
):
inference_impl, _ = stack_impls
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
)
]
response = await inference_impl.chat_completion(
model=llama_model,
messages=messages,
tools=[sample_tool_definition],
stream=True,
**inference_settings["common_params"],
stream=False,
**common_params,
)
]
assert len(response) > 0
assert all(
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
)
grouped = group_chunks(response)
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
assert isinstance(response, ChatCompletionResponse)
# This is not supported in most providers :/ they don't return eom_id / eot_id
# expected_stop_reason = get_expected_stop_reason(
# inference_settings["common_params"]["model"]
# )
# end = grouped[ChatCompletionResponseEventType.complete][0]
# assert end.event.stop_reason == expected_stop_reason
message = response.completion_message
model = inference_settings["common_params"]["model"]
if "Llama3.1" in model:
# This is not supported in most providers :/ they don't return eom_id / eot_id
# stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
# assert message.stop_reason == stop_reason
assert message.tool_calls is not None
assert len(message.tool_calls) > 0
call = message.tool_calls[0]
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]
@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling_streaming(
self,
llama_model,
stack_impls,
common_params,
sample_messages,
sample_tool_definition,
):
inference_impl, _ = stack_impls
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
)
]
response = [
r
async for r in await inference_impl.chat_completion(
model=llama_model,
messages=messages,
tools=[sample_tool_definition],
stream=True,
**common_params,
)
]
assert len(response) > 0
assert all(
isinstance(chunk.event.delta, ToolCallDelta)
for chunk in grouped[ChatCompletionResponseEventType.progress]
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
)
first = grouped[ChatCompletionResponseEventType.progress][0]
assert first.event.delta.parse_status == ToolCallParseStatus.started
grouped = group_chunks(response)
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.success
assert isinstance(last.event.delta.content, ToolCall)
# This is not supported in most providers :/ they don't return eom_id / eot_id
# expected_stop_reason = get_expected_stop_reason(
# inference_settings["common_params"]["model"]
# )
# end = grouped[ChatCompletionResponseEventType.complete][0]
# assert end.event.stop_reason == expected_stop_reason
call = last.event.delta.content
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]
if "Llama3.1" in llama_model:
assert all(
isinstance(chunk.event.delta, ToolCallDelta)
for chunk in grouped[ChatCompletionResponseEventType.progress]
)
first = grouped[ChatCompletionResponseEventType.progress][0]
assert first.event.delta.parse_status == ToolCallParseStatus.started
last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.success
assert isinstance(last.event.delta.content, ToolCall)
call = last.event.delta.content
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]

llama_stack/providers/tests/resolver.py

@@ -7,7 +7,7 @@
import json
import os
from datetime import datetime
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import yaml
@@ -18,6 +18,28 @@ from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
async def resolve_impls_for_test_v2(
    apis: List[Api],
    providers: Dict[str, Provider],
    provider_data: Optional[Dict[str, Any]] = None,
):
    run_config = dict(
        built_at=datetime.now(),
        image_name="test-fixture",
        apis=apis,
        providers=providers,
    )
    run_config = parse_and_maybe_upgrade_config(run_config)
    impls = await resolve_impls(run_config, get_provider_registry())

    if provider_data:
        set_request_provider_data(
            {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
        )

    return impls
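For reference, a condensed sketch of calling `resolve_impls_for_test_v2` directly, mirroring the `stack_impls` fixture in inference/conftest.py (`provider` comes from one of the provider fixtures there; the key value is illustrative):

```python
impls = await resolve_impls_for_test_v2(
    [Api.inference],
    {"inference": [provider.model_dump()]},
    provider_data={"together_api_key": "<real-key-here>"},  # only if needed
)
inference_impl = impls[Api.inference]
models_impl = impls[Api.models]
```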
async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
    if "PROVIDER_CONFIG" not in os.environ:
        raise ValueError(