Significantly simpler and malleable test setup

Ashwin Bharambe 2024-11-01 13:07:59 -07:00 committed by Ashwin Bharambe
parent c9bf1d7d0b
commit bba6717ef5
7 changed files with 511 additions and 339 deletions

.gitignore
View file

@@ -15,5 +15,5 @@ Package.resolved
*.ipynb_checkpoints*
.idea
.venv/
.vscode
_build

View file

@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from pathlib import Path

from dotenv import load_dotenv
from termcolor import colored


def pytest_configure(config):
    """Load environment variables at start of test run"""
    # Load from .env file if it exists
    env_file = Path(__file__).parent / ".env"
    if env_file.exists():
        load_dotenv(env_file)

    # Load any environment variables passed via --env
    env_vars = config.getoption("--env") or []
    for env_var in env_vars:
        key, value = env_var.split("=", 1)
        os.environ[key] = value


def pytest_addoption(parser):
    """Add custom command line options"""
    parser.addoption(
        "--env", action="append", help="Set environment variables, e.g. --env KEY=value"
    )


def pytest_itemcollected(item):
    # Get all markers as a list
    filtered = ("asyncio", "parametrize")
    marks = [mark.name for mark in item.iter_markers() if mark.name not in filtered]
    if marks:
        marks = colored(",".join(marks), "yellow")
        item.name = f"{item.name}[{marks}]"
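
For illustration, a minimal sketch (not part of this commit; the test name and key are hypothetical) of how a test can rely on the hook above: any `KEY=value` passed via `--env`, or present in a `.env` file in the same directory as this conftest, is copied into `os.environ` before collection.

```python
import os

import pytest


def test_env_passthrough():
    # Hypothetical example: run as `pytest --env MY_SERVICE_TOKEN=abc123 ...`;
    # pytest_configure above copies the value into os.environ before tests run.
    token = os.environ.get("MY_SERVICE_TOKEN")
    if token is None:
        pytest.skip("MY_SERVICE_TOKEN was not provided via --env or .env")
    assert token == "abc123"
```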

View file

@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os


class MissingCredentialError(Exception):
    pass


def get_env_or_fail(key: str) -> str:
    """Get environment variable or raise helpful error"""
    value = os.getenv(key)
    if not value:
        raise MissingCredentialError(
            f"\nMissing {key} in environment. Please set it using one of these methods:"
            f"\n1. Export in shell: export {key}=your-key"
            f"\n2. Create .env file in project root with: {key}=your-key"
            f"\n3. Pass directly to pytest: pytest --env {key}=your-key"
        )
    return value
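
A short usage sketch (hypothetical call site; the absolute import path is an assumption based on the `from ..env import get_env_or_fail` line in the inference conftest below):

```python
from llama_stack.providers.tests.env import MissingCredentialError, get_env_or_fail

try:
    api_key = get_env_or_fail("TOGETHER_API_KEY")
except MissingCredentialError as exc:
    # The message lists the three supported ways to provide the key:
    # shell export, a .env file, or `pytest --env TOGETHER_API_KEY=...`.
    print(exc)
```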

View file

@@ -0,0 +1,139 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Any, Dict, Tuple

import pytest
import pytest_asyncio

from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.adapters.inference.fireworks import FireworksImplConfig
from llama_stack.providers.adapters.inference.ollama import OllamaImplConfig
from llama_stack.providers.adapters.inference.together import TogetherImplConfig
from llama_stack.providers.impls.meta_reference.inference import (
    MetaReferenceInferenceConfig,
)
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2

from ..env import get_env_or_fail

MODEL_PARAMS = [
    pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
    pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
]


@pytest.fixture(scope="session", params=MODEL_PARAMS)
def llama_model(request):
    return request.param


@pytest.fixture(scope="session")
def meta_reference(llama_model) -> Provider:
    return Provider(
        provider_id="meta-reference",
        provider_type="meta-reference",
        config=MetaReferenceInferenceConfig(
            model=llama_model,
            max_seq_len=512,
            create_distributed_process_group=False,
            checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
        ).model_dump(),
    )


@pytest.fixture(scope="session")
def ollama(llama_model) -> Provider:
    if llama_model == "Llama3.1-8B-Instruct":
        pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
    return Provider(
        provider_id="ollama",
        provider_type="remote::ollama",
        config=(
            OllamaImplConfig(
                host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
            ).model_dump()
        ),
    )


@pytest.fixture(scope="session")
def fireworks(llama_model) -> Provider:
    return Provider(
        provider_id="fireworks",
        provider_type="remote::fireworks",
        config=FireworksImplConfig(
            api_key=get_env_or_fail("FIREWORKS_API_KEY"),
        ).model_dump(),
    )


@pytest.fixture(scope="session")
def together(llama_model) -> Tuple[Provider, Dict[str, Any]]:
    provider = Provider(
        provider_id="together",
        provider_type="remote::together",
        config=TogetherImplConfig().model_dump(),
    )
    return provider, dict(
        together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
    )


PROVIDER_PARAMS = [
    pytest.param("meta_reference", marks=pytest.mark.meta_reference),
    pytest.param("ollama", marks=pytest.mark.ollama),
    pytest.param("fireworks", marks=pytest.mark.fireworks),
    pytest.param("together", marks=pytest.mark.together),
]


@pytest_asyncio.fixture(
    scope="session",
    params=PROVIDER_PARAMS,
)
async def stack_impls(request):
    provider_fixture = request.param
    provider = request.getfixturevalue(provider_fixture)
    if isinstance(provider, tuple):
        provider, provider_data = provider
    else:
        provider_data = None

    impls = await resolve_impls_for_test_v2(
        [Api.inference],
        {"inference": [provider.model_dump()]},
        provider_data,
    )
    return (impls[Api.inference], impls[Api.models])


def pytest_configure(config):
    config.addinivalue_line(
        "markers", "llama_8b: mark test to run only with the given model"
    )
    config.addinivalue_line(
        "markers", "llama_3b: mark test to run only with the given model"
    )
    config.addinivalue_line(
        "markers",
        "meta_reference: marks tests as metaref specific",
    )
    config.addinivalue_line(
        "markers",
        "ollama: marks tests as ollama specific",
    )
    config.addinivalue_line(
        "markers",
        "fireworks: marks tests as fireworks specific",
    )
    config.addinivalue_line(
        "markers",
        "together: marks tests as together specific",
    )
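
To show how these fixtures compose, here is a hedged sketch of a test that consumes them (it mirrors the pattern `test_inference.py` adopts later in this commit; the test name is hypothetical). `llama_model` selects the model, `stack_impls` resolves the selected provider, and the registered markers let `-m` expressions such as `"fireworks and llama_3b"` pick combinations.

```python
import pytest


@pytest.mark.asyncio
async def test_list_models_example(llama_model, stack_impls):
    # stack_impls yields (inference impl, models impl), as returned above.
    _, models_impl = stack_impls
    models = await models_impl.list_models()
    assert any(m.identifier == llama_model for m in models)
```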

View file

@@ -1,28 +0,0 @@
providers:
  - provider_id: test-ollama
    provider_type: remote::ollama
    config:
      host: localhost
      port: 11434
  - provider_id: meta-reference
    provider_type: meta-reference
    config:
      model: Llama3.2-1B-Instruct
  - provider_id: test-tgi
    provider_type: remote::tgi
    config:
      url: http://localhost:7001
  - provider_id: test-remote
    provider_type: remote
    config:
      host: localhost
      port: 7002
  - provider_id: test-together
    provider_type: remote::together
    config: {}
# if a provider needs private keys from the client, they use the
# "get_request_provider_data" function (see distribution/request_headers.py)
# this is a place to provide such data.
provider_data:
  "test-together":
    together_api_key: 0xdeadbeefputrealapikeyhere

View file

@@ -5,10 +5,8 @@
# the root directory of this source tree.

import itertools

import pytest

from pydantic import BaseModel, ValidationError
@@ -16,24 +14,14 @@
from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.distribution.datatypes import *  # noqa: F403

from .conftest import MODEL_PARAMS, PROVIDER_PARAMS

# How to run this test:
#
# pytest llama_stack/providers/tests/inference/test_inference.py
#   -m "(fireworks or ollama) and llama_3b"
#   -v -s --tb=short --disable-warnings
#   --env FIREWORKS_API_KEY=<your_api_key>


def group_chunks(response):
@@ -45,45 +33,19 @@ def group_chunks(response):
    }


def get_expected_stop_reason(model: str):
    return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn


@pytest.fixture
def common_params(llama_model):
    return {
        "tool_choice": ToolChoice.auto,
        "tool_prompt_format": (
            ToolPromptFormat.json
            if "Llama3.1" in llama_model
            else ToolPromptFormat.python_list
        ),
    }
@@ -109,301 +71,313 @@ def sample_tool_definition():
    )


@pytest.mark.parametrize("llama_model", MODEL_PARAMS, indirect=True)
@pytest.mark.parametrize(
    "stack_impls",
    PROVIDER_PARAMS,
    indirect=True,
)
class TestInference:
    @pytest.mark.asyncio
    async def test_model_list(self, llama_model, stack_impls):
        _, models_impl = stack_impls
        response = await models_impl.list_models()
        assert isinstance(response, list)
        assert len(response) >= 1
        assert all(isinstance(model, ModelDefWithProvider) for model in response)

        model_def = None
        for model in response:
            if model.identifier == llama_model:
                model_def = model
                break

        assert model_def is not None

    @pytest.mark.asyncio
    async def test_completion(self, llama_model, stack_impls, common_params):
        inference_impl, _ = stack_impls

        provider = inference_impl.routing_table.get_provider_impl(llama_model)
        if provider.__provider_spec__.provider_type not in (
            "meta-reference",
            "remote::ollama",
            "remote::tgi",
            "remote::together",
            "remote::fireworks",
        ):
            pytest.skip("Other inference providers don't support completion() yet")

        response = await inference_impl.completion(
            content="Micheael Jordan is born in ",
            stream=False,
            model=llama_model,
            sampling_params=SamplingParams(
                max_tokens=50,
            ),
        )

        assert isinstance(response, CompletionResponse)
        assert "1963" in response.content

        chunks = [
            r
            async for r in await inference_impl.completion(
                content="Roses are red,",
                stream=True,
                model=llama_model,
                sampling_params=SamplingParams(
                    max_tokens=50,
                ),
            )
        ]

        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
        assert len(chunks) >= 1
        last = chunks[-1]
        assert last.stop_reason == StopReason.out_of_tokens

    @pytest.mark.asyncio
    @pytest.mark.skip("This test is not quite robust")
    async def test_completions_structured_output(
        self, llama_model, stack_impls, common_params
    ):
        inference_impl, _ = stack_impls

        provider = inference_impl.routing_table.get_provider_impl(llama_model)
        if provider.__provider_spec__.provider_type not in (
            "meta-reference",
            "remote::tgi",
            "remote::together",
            "remote::fireworks",
        ):
            pytest.skip(
                "Other inference providers don't support structured output in completions yet"
            )

        class Output(BaseModel):
            name: str
            year_born: str
            year_retired: str

        user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
        response = await inference_impl.completion(
            content=user_input,
            stream=False,
            model=llama_model,
            sampling_params=SamplingParams(
                max_tokens=50,
            ),
            response_format=JsonSchemaResponseFormat(
                json_schema=Output.model_json_schema(),
            ),
        )
        assert isinstance(response, CompletionResponse)
        assert isinstance(response.content, str)

        answer = Output.model_validate_json(response.content)
        assert answer.name == "Michael Jordan"
        assert answer.year_born == "1963"
        assert answer.year_retired == "2003"

    @pytest.mark.asyncio
    async def test_chat_completion_non_streaming(
        self, llama_model, stack_impls, common_params, sample_messages
    ):
        inference_impl, _ = stack_impls
        response = await inference_impl.chat_completion(
            model=llama_model,
            messages=sample_messages,
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert response.completion_message.role == "assistant"
        assert isinstance(response.completion_message.content, str)
        assert len(response.completion_message.content) > 0

    @pytest.mark.asyncio
    async def test_structured_output(self, llama_model, stack_impls, common_params):
        inference_impl, _ = stack_impls
        provider = inference_impl.routing_table.get_provider_impl(llama_model)
        if provider.__provider_spec__.provider_type not in (
            "meta-reference",
            "remote::fireworks",
            "remote::tgi",
            "remote::together",
        ):
            pytest.skip("Other inference providers don't support structured output yet")

        class AnswerFormat(BaseModel):
            first_name: str
            last_name: str
            year_of_birth: int
            num_seasons_in_nba: int

        response = await inference_impl.chat_completion(
            model=llama_model,
            messages=[
                SystemMessage(content="You are a helpful assistant."),
                UserMessage(content="Please give me information about Michael Jordan."),
            ],
            stream=False,
            response_format=JsonSchemaResponseFormat(
                json_schema=AnswerFormat.model_json_schema(),
            ),
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert response.completion_message.role == "assistant"
        assert isinstance(response.completion_message.content, str)

        answer = AnswerFormat.model_validate_json(response.completion_message.content)
        assert answer.first_name == "Michael"
        assert answer.last_name == "Jordan"
        assert answer.year_of_birth == 1963
        assert answer.num_seasons_in_nba == 15

        response = await inference_impl.chat_completion(
            model=llama_model,
            messages=[
                SystemMessage(content="You are a helpful assistant."),
                UserMessage(content="Please give me information about Michael Jordan."),
            ],
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)
        assert isinstance(response.completion_message.content, str)

        with pytest.raises(ValidationError):
            AnswerFormat.model_validate_json(response.completion_message.content)

    @pytest.mark.asyncio
    async def test_chat_completion_streaming(
        self, llama_model, stack_impls, common_params, sample_messages
    ):
        inference_impl, _ = stack_impls
        response = [
            r
            async for r in await inference_impl.chat_completion(
                model=llama_model,
                messages=sample_messages,
                stream=True,
                **common_params,
            )
        ]

        assert len(response) > 0
        assert all(
            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
        )
        grouped = group_chunks(response)
        assert len(grouped[ChatCompletionResponseEventType.start]) == 1
        assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
        assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

        end = grouped[ChatCompletionResponseEventType.complete][0]
        assert end.event.stop_reason == StopReason.end_of_turn

    @pytest.mark.asyncio
    async def test_chat_completion_with_tool_calling(
        self,
        llama_model,
        stack_impls,
        common_params,
        sample_messages,
        sample_tool_definition,
    ):
        inference_impl, _ = stack_impls
        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
            )
        ]

        response = await inference_impl.chat_completion(
            model=llama_model,
            messages=messages,
            tools=[sample_tool_definition],
            stream=False,
            **common_params,
        )

        assert isinstance(response, ChatCompletionResponse)

        message = response.completion_message

        # This is not supported in most providers :/ they don't return eom_id / eot_id
        # stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
        # assert message.stop_reason == stop_reason
        assert message.tool_calls is not None
        assert len(message.tool_calls) > 0

        call = message.tool_calls[0]
        assert call.tool_name == "get_weather"
        assert "location" in call.arguments
        assert "San Francisco" in call.arguments["location"]

    @pytest.mark.asyncio
    async def test_chat_completion_with_tool_calling_streaming(
        self,
        llama_model,
        stack_impls,
        common_params,
        sample_messages,
        sample_tool_definition,
    ):
        inference_impl, _ = stack_impls
        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
            )
        ]

        response = [
            r
            async for r in await inference_impl.chat_completion(
                model=llama_model,
                messages=messages,
                tools=[sample_tool_definition],
                stream=True,
                **common_params,
            )
        ]

        assert len(response) > 0
        assert all(
            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
        )
        grouped = group_chunks(response)
        assert len(grouped[ChatCompletionResponseEventType.start]) == 1
        assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
        assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

        # This is not supported in most providers :/ they don't return eom_id / eot_id
        # expected_stop_reason = get_expected_stop_reason(
        #     inference_settings["common_params"]["model"]
        # )
        # end = grouped[ChatCompletionResponseEventType.complete][0]
        # assert end.event.stop_reason == expected_stop_reason

        if "Llama3.1" in llama_model:
            assert all(
                isinstance(chunk.event.delta, ToolCallDelta)
                for chunk in grouped[ChatCompletionResponseEventType.progress]
            )
            first = grouped[ChatCompletionResponseEventType.progress][0]
            assert first.event.delta.parse_status == ToolCallParseStatus.started

            last = grouped[ChatCompletionResponseEventType.progress][-1]
            # assert last.event.stop_reason == expected_stop_reason
            assert last.event.delta.parse_status == ToolCallParseStatus.success
            assert isinstance(last.event.delta.content, ToolCall)

            call = last.event.delta.content
            assert call.tool_name == "get_weather"
            assert "location" in call.arguments
            assert "San Francisco" in call.arguments["location"]

View file

@@ -7,7 +7,7 @@
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

import yaml
@@ -18,6 +18,28 @@
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls


async def resolve_impls_for_test_v2(
    apis: List[Api],
    providers: Dict[str, Provider],
    provider_data: Optional[Dict[str, Any]] = None,
):
    run_config = dict(
        built_at=datetime.now(),
        image_name="test-fixture",
        apis=apis,
        providers=providers,
    )
    run_config = parse_and_maybe_upgrade_config(run_config)
    impls = await resolve_impls(run_config, get_provider_registry())
    if provider_data:
        set_request_provider_data(
            {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
        )
    return impls


async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
    if "PROVIDER_CONFIG" not in os.environ:
        raise ValueError(
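
For reference, a hedged sketch of how `resolve_impls_for_test_v2` is intended to be driven from a fixture (this mirrors the `stack_impls` fixture added in the inference conftest above; the Together provider and placeholder API key are illustrative):

```python
import pytest_asyncio

from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.adapters.inference.together import TogetherImplConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2


@pytest_asyncio.fixture(scope="session")
async def together_impls():
    provider = Provider(
        provider_id="together",
        provider_type="remote::together",
        config=TogetherImplConfig().model_dump(),
    )
    impls = await resolve_impls_for_test_v2(
        [Api.inference],
        {"inference": [provider.model_dump()]},
        # Optional: forwarded to providers as the X-LlamaStack-ProviderData header.
        provider_data={"together_api_key": "<your-api-key>"},
    )
    return impls[Api.inference], impls[Api.models]
```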