Significantly simpler and more malleable test setup

Ashwin Bharambe 2024-11-01 13:07:59 -07:00 committed by Ashwin Bharambe
parent c9bf1d7d0b
commit bba6717ef5
7 changed files with 511 additions and 339 deletions

.gitignore (vendored)
View file

@@ -15,5 +15,5 @@ Package.resolved
 *.ipynb_checkpoints*
 .idea
 .venv/
-.idea
+.vscode
 _build

View file

@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from pathlib import Path

from dotenv import load_dotenv
from termcolor import colored


def pytest_configure(config):
    """Load environment variables at start of test run"""
    # Load from .env file if it exists
    env_file = Path(__file__).parent / ".env"
    if env_file.exists():
        load_dotenv(env_file)

    # Load any environment variables passed via --env
    env_vars = config.getoption("--env") or []
    for env_var in env_vars:
        key, value = env_var.split("=", 1)
        os.environ[key] = value


def pytest_addoption(parser):
    """Add custom command line options"""
    parser.addoption(
        "--env", action="append", help="Set environment variables, e.g. --env KEY=value"
    )


def pytest_itemcollected(item):
    # Get all markers as a list
    filtered = ("asyncio", "parametrize")
    marks = [mark.name for mark in item.iter_markers() if mark.name not in filtered]
    if marks:
        marks = colored(",".join(marks), "yellow")
        item.name = f"{item.name}[{marks}]"
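Taken together, these hooks mean `pytest_addoption` registers a repeatable `--env KEY=value` flag, `pytest_configure` copies each pair (plus anything in a local `.env` file) into `os.environ` before collection, and `pytest_itemcollected` appends the non-housekeeping markers to each test name. A minimal sketch of what the `--env` plumbing enables; the test module and the `EXAMPLE_API_KEY` variable are invented for illustration:

```python
# Hypothetical test module, e.g. test_env_plumbing.py, run with:
#   pytest --env EXAMPLE_API_KEY=abc123 test_env_plumbing.py
import os


def test_env_flag_reaches_os_environ():
    # pytest_configure above copies --env pairs into os.environ before
    # collection, so the value is visible to any test or fixture.
    assert os.environ.get("EXAMPLE_API_KEY") == "abc123"
```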

View file

@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os


class MissingCredentialError(Exception):
    pass


def get_env_or_fail(key: str) -> str:
    """Get environment variable or raise helpful error"""
    value = os.getenv(key)
    if not value:
        raise MissingCredentialError(
            f"\nMissing {key} in environment. Please set it using one of these methods:"
            f"\n1. Export in shell: export {key}=your-key"
            f"\n2. Create .env file in project root with: {key}=your-key"
            f"\n3. Pass directly to pytest: pytest --env {key}=your-key"
        )
    return value
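The provider fixtures below use `get_env_or_fail` to pull credentials such as `FIREWORKS_API_KEY`, so a missing key fails fast with the recovery instructions above instead of an opaque auth error. A small sketch of the failure path, assuming this module lands at `llama_stack/providers/tests/env.py` (consistent with the `from ..env import get_env_or_fail` in the inference conftest below); the variable name is made up:

```python
import pytest

# Assumed import path for the new module; adjust if it lives elsewhere.
from llama_stack.providers.tests.env import MissingCredentialError, get_env_or_fail


def test_missing_credential_raises(monkeypatch):
    # Ensure the (made-up) variable is unset, then expect the helpful error.
    monkeypatch.delenv("SOME_MADE_UP_KEY", raising=False)
    with pytest.raises(MissingCredentialError):
        get_env_or_fail("SOME_MADE_UP_KEY")
```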

View file

@@ -0,0 +1,139 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Any, Dict, Tuple

import pytest
import pytest_asyncio

from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.adapters.inference.fireworks import FireworksImplConfig
from llama_stack.providers.adapters.inference.ollama import OllamaImplConfig
from llama_stack.providers.adapters.inference.together import TogetherImplConfig
from llama_stack.providers.impls.meta_reference.inference import (
    MetaReferenceInferenceConfig,
)
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2

from ..env import get_env_or_fail

MODEL_PARAMS = [
    pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
    pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
]


@pytest.fixture(scope="session", params=MODEL_PARAMS)
def llama_model(request):
    return request.param


@pytest.fixture(scope="session")
def meta_reference(llama_model) -> Provider:
    return Provider(
        provider_id="meta-reference",
        provider_type="meta-reference",
        config=MetaReferenceInferenceConfig(
            model=llama_model,
            max_seq_len=512,
            create_distributed_process_group=False,
            checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
        ).model_dump(),
    )


@pytest.fixture(scope="session")
def ollama(llama_model) -> Provider:
    if llama_model == "Llama3.1-8B-Instruct":
        pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
    return Provider(
        provider_id="ollama",
        provider_type="remote::ollama",
        config=(
            OllamaImplConfig(
                host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
            ).model_dump()
        ),
    )


@pytest.fixture(scope="session")
def fireworks(llama_model) -> Provider:
    return Provider(
        provider_id="fireworks",
        provider_type="remote::fireworks",
        config=FireworksImplConfig(
            api_key=get_env_or_fail("FIREWORKS_API_KEY"),
        ).model_dump(),
    )


@pytest.fixture(scope="session")
def together(llama_model) -> Tuple[Provider, Dict[str, Any]]:
    provider = Provider(
        provider_id="together",
        provider_type="remote::together",
        config=TogetherImplConfig().model_dump(),
    )
    return provider, dict(
        together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
    )


PROVIDER_PARAMS = [
    pytest.param("meta_reference", marks=pytest.mark.meta_reference),
    pytest.param("ollama", marks=pytest.mark.ollama),
    pytest.param("fireworks", marks=pytest.mark.fireworks),
    pytest.param("together", marks=pytest.mark.together),
]


@pytest_asyncio.fixture(
    scope="session",
    params=PROVIDER_PARAMS,
)
async def stack_impls(request):
    provider_fixture = request.param
    provider = request.getfixturevalue(provider_fixture)
    if isinstance(provider, tuple):
        provider, provider_data = provider
    else:
        provider_data = None

    impls = await resolve_impls_for_test_v2(
        [Api.inference],
        {"inference": [provider.model_dump()]},
        provider_data,
    )
    return (impls[Api.inference], impls[Api.models])


def pytest_configure(config):
    config.addinivalue_line(
        "markers", "llama_8b: mark test to run only with the given model"
    )
    config.addinivalue_line(
        "markers", "llama_3b: mark test to run only with the given model"
    )
    config.addinivalue_line(
        "markers",
        "meta_reference: marks tests as metaref specific",
    )
    config.addinivalue_line(
        "markers",
        "ollama: marks tests as ollama specific",
    )
    config.addinivalue_line(
        "markers",
        "fireworks: marks tests as fireworks specific",
    )
    config.addinivalue_line(
        "markers",
        "together: marks tests as together specific",
    )
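The `stack_impls` fixture is the core of the new setup: its `params` are provider fixture names, and `request.getfixturevalue(...)` resolves only the provider selected for the current parametrization, so unused providers are never configured. A stripped-down, llama-stack-independent sketch of that pattern; the `alpha`/`beta` fixtures are invented for illustration:

```python
import pytest

PROVIDERS = ["alpha", "beta"]  # stand-ins for meta_reference / ollama / ...


@pytest.fixture
def alpha():
    return "config-for-alpha"


@pytest.fixture
def beta():
    return "config-for-beta"


@pytest.fixture(params=PROVIDERS)
def provider_config(request):
    # request.param is a fixture *name*; getfixturevalue resolves it lazily,
    # mirroring how stack_impls dispatches to the provider fixtures above.
    return request.getfixturevalue(request.param)


def test_provider_config(provider_config):
    assert provider_config.startswith("config-for-")
```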

View file

@@ -1,28 +0,0 @@
providers:
  - provider_id: test-ollama
    provider_type: remote::ollama
    config:
      host: localhost
      port: 11434
  - provider_id: meta-reference
    provider_type: meta-reference
    config:
      model: Llama3.2-1B-Instruct
  - provider_id: test-tgi
    provider_type: remote::tgi
    config:
      url: http://localhost:7001
  - provider_id: test-remote
    provider_type: remote
    config:
      host: localhost
      port: 7002
  - provider_id: test-together
    provider_type: remote::together
    config: {}

# if a provider needs private keys from the client, they use the
# "get_request_provider_data" function (see distribution/request_headers.py)
# this is a place to provide such data.
provider_data:
  "test-together":
    together_api_key: 0xdeadbeefputrealapikeyhere

View file

@@ -5,10 +5,8 @@
 # the root directory of this source tree.
 
 import itertools
-import os
 
 import pytest
-import pytest_asyncio
 
 from pydantic import BaseModel, ValidationError
@@ -16,24 +14,14 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.distribution.datatypes import * # noqa: F403
-from llama_stack.providers.tests.resolver import resolve_impls_for_test
+from .conftest import MODEL_PARAMS, PROVIDER_PARAMS
 
 # How to run this test:
 #
-# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
-#    since it depends on the provider you are testing. On top of that you need
-#    `pytest` and `pytest-asyncio` installed.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID=<your_provider> \
-#   PROVIDER_CONFIG=provider_config.yaml \
-#   pytest -s llama_stack/providers/tests/inference/test_inference.py \
-#   --tb=short --disable-warnings
-# ```
+# pytest llama_stack/providers/tests/inference/test_inference.py
+#   -m "(fireworks or ollama) and llama_3b"
+#   -v -s --tb=short --disable-warnings
+#   --env FIREWORKS_API_KEY=<your_api_key>
 
 
 def group_chunks(response):
@@ -45,45 +33,19 @@ def group_chunks(response):
     }
 
 
-Llama_8B = "Llama3.1-8B-Instruct"
-Llama_3B = "Llama3.2-3B-Instruct"
-
-
 def get_expected_stop_reason(model: str):
     return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
 
 
-if "MODEL_IDS" not in os.environ:
-    MODEL_IDS = [Llama_8B, Llama_3B]
-else:
-    MODEL_IDS = os.environ["MODEL_IDS"].split(",")
-
-
-# This is going to create multiple Stack impls without tearing down the previous one
-# Fix that!
-@pytest_asyncio.fixture(
-    scope="session",
-    params=[{"model": m} for m in MODEL_IDS],
-    ids=lambda d: d["model"],
-)
-async def inference_settings(request):
-    model = request.param["model"]
-    impls = await resolve_impls_for_test(
-        Api.inference,
-    )
+@pytest.fixture
+def common_params(llama_model):
     return {
-        "impl": impls[Api.inference],
-        "models_impl": impls[Api.models],
-        "common_params": {
-            "model": model,
-            "tool_choice": ToolChoice.auto,
-            "tool_prompt_format": (
-                ToolPromptFormat.json
-                if "Llama3.1" in model
-                else ToolPromptFormat.python_list
-            ),
-        },
+        "tool_choice": ToolChoice.auto,
+        "tool_prompt_format": (
+            ToolPromptFormat.json
+            if "Llama3.1" in llama_model
+            else ToolPromptFormat.python_list
+        ),
     }
@@ -109,10 +71,16 @@ def sample_tool_definition():
     )
 
 
-@pytest.mark.asyncio
-async def test_model_list(inference_settings):
-    params = inference_settings["common_params"]
-    models_impl = inference_settings["models_impl"]
+@pytest.mark.parametrize("llama_model", MODEL_PARAMS, indirect=True)
+@pytest.mark.parametrize(
+    "stack_impls",
+    PROVIDER_PARAMS,
+    indirect=True,
+)
+class TestInference:
+    @pytest.mark.asyncio
+    async def test_model_list(self, llama_model, stack_impls):
+        _, models_impl = stack_impls
 
         response = await models_impl.list_models()
         assert isinstance(response, list)
         assert len(response) >= 1
@@ -120,20 +88,17 @@ async def test_model_list(inference_settings):
         model_def = None
         for model in response:
-        if model.identifier == params["model"]:
+            if model.identifier == llama_model:
                 model_def = model
                 break
 
         assert model_def is not None
-    assert model_def.identifier == params["model"]
 
-
-@pytest.mark.asyncio
-async def test_completion(inference_settings):
-    inference_impl = inference_settings["impl"]
-    params = inference_settings["common_params"]
-    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    @pytest.mark.asyncio
+    async def test_completion(self, llama_model, stack_impls, common_params):
+        inference_impl, _ = stack_impls
+        provider = inference_impl.routing_table.get_provider_impl(llama_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::ollama",
@@ -146,7 +111,7 @@ async def test_completion(inference_settings):
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
-        model=params["model"],
+            model=llama_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -160,7 +125,7 @@ async def test_completion(inference_settings):
             async for r in await inference_impl.completion(
                 content="Roses are red,",
                 stream=True,
-            model=params["model"],
+                model=llama_model,
                 sampling_params=SamplingParams(
                     max_tokens=50,
                 ),
@@ -172,14 +137,14 @@ async def test_completion(inference_settings):
         last = chunks[-1]
         assert last.stop_reason == StopReason.out_of_tokens
 
-
-@pytest.mark.asyncio
-@pytest.mark.skip("This test is not quite robust")
-async def test_completions_structured_output(inference_settings):
-    inference_impl = inference_settings["impl"]
-    params = inference_settings["common_params"]
-    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    @pytest.mark.asyncio
+    @pytest.mark.skip("This test is not quite robust")
+    async def test_completions_structured_output(
+        self, llama_model, stack_impls, common_params
+    ):
+        inference_impl, _ = stack_impls
+        provider = inference_impl.routing_table.get_provider_impl(llama_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::tgi",
@@ -199,7 +164,7 @@ async def test_completions_structured_output(inference_settings):
         response = await inference_impl.completion(
             content=user_input,
             stream=False,
-        model=params["model"],
+            model=llama_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -210,19 +175,21 @@ async def test_completions_structured_output(inference_settings):
         assert isinstance(response, CompletionResponse)
         assert isinstance(response.content, str)
 
-    answer = Output.parse_raw(response.content)
+        answer = Output.model_validate_json(response.content)
         assert answer.name == "Michael Jordan"
         assert answer.year_born == "1963"
         assert answer.year_retired == "2003"
 
-
-@pytest.mark.asyncio
-async def test_chat_completion_non_streaming(inference_settings, sample_messages):
-    inference_impl = inference_settings["impl"]
+    @pytest.mark.asyncio
+    async def test_chat_completion_non_streaming(
+        self, llama_model, stack_impls, common_params, sample_messages
+    ):
+        inference_impl, _ = stack_impls
         response = await inference_impl.chat_completion(
+            model=llama_model,
             messages=sample_messages,
             stream=False,
-        **inference_settings["common_params"],
+            **common_params,
         )
 
         assert isinstance(response, ChatCompletionResponse)
@@ -230,13 +197,11 @@ async def test_chat_completion_non_streaming(inference_settings, sample_messages
         assert isinstance(response.completion_message.content, str)
         assert len(response.completion_message.content) > 0
 
-
-@pytest.mark.asyncio
-async def test_structured_output(inference_settings):
-    inference_impl = inference_settings["impl"]
-    params = inference_settings["common_params"]
-    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    @pytest.mark.asyncio
+    async def test_structured_output(self, llama_model, stack_impls, common_params):
+        inference_impl, _ = stack_impls
+        provider = inference_impl.routing_table.get_provider_impl(llama_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::fireworks",
@@ -252,6 +217,7 @@ async def test_structured_output(inference_settings):
             num_seasons_in_nba: int
 
         response = await inference_impl.chat_completion(
+            model=llama_model,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -260,44 +226,47 @@ async def test_structured_output(inference_settings):
             response_format=JsonSchemaResponseFormat(
                 json_schema=AnswerFormat.model_json_schema(),
             ),
-        **inference_settings["common_params"],
+            **common_params,
         )
 
         assert isinstance(response, ChatCompletionResponse)
         assert response.completion_message.role == "assistant"
         assert isinstance(response.completion_message.content, str)
 
-    answer = AnswerFormat.parse_raw(response.completion_message.content)
+        answer = AnswerFormat.model_validate_json(response.completion_message.content)
         assert answer.first_name == "Michael"
         assert answer.last_name == "Jordan"
         assert answer.year_of_birth == 1963
         assert answer.num_seasons_in_nba == 15
 
         response = await inference_impl.chat_completion(
+            model=llama_model,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
             ],
             stream=False,
-        **inference_settings["common_params"],
+            **common_params,
         )
 
         assert isinstance(response, ChatCompletionResponse)
         assert isinstance(response.completion_message.content, str)
 
         with pytest.raises(ValidationError):
-        AnswerFormat.parse_raw(response.completion_message.content)
+            AnswerFormat.model_validate_json(response.completion_message.content)
 
-
-@pytest.mark.asyncio
-async def test_chat_completion_streaming(inference_settings, sample_messages):
-    inference_impl = inference_settings["impl"]
+    @pytest.mark.asyncio
+    async def test_chat_completion_streaming(
+        self, llama_model, stack_impls, common_params, sample_messages
+    ):
+        inference_impl, _ = stack_impls
         response = [
             r
             async for r in await inference_impl.chat_completion(
+                model=llama_model,
                 messages=sample_messages,
                 stream=True,
-            **inference_settings["common_params"],
+                **common_params,
             )
         ]
@@ -313,14 +282,16 @@ async def test_chat_completion_streaming(inference_settings, sample_messages):
         end = grouped[ChatCompletionResponseEventType.complete][0]
         assert end.event.stop_reason == StopReason.end_of_turn
 
-
-@pytest.mark.asyncio
-async def test_chat_completion_with_tool_calling(
-    inference_settings,
+    @pytest.mark.asyncio
+    async def test_chat_completion_with_tool_calling(
+        self,
+        llama_model,
+        stack_impls,
+        common_params,
         sample_messages,
         sample_tool_definition,
     ):
-    inference_impl = inference_settings["impl"]
+        inference_impl, _ = stack_impls
         messages = sample_messages + [
             UserMessage(
                 content="What's the weather like in San Francisco?",
@@ -328,10 +299,11 @@ async def test_chat_completion_with_tool_calling(
         ]
 
         response = await inference_impl.chat_completion(
+            model=llama_model,
             messages=messages,
             tools=[sample_tool_definition],
             stream=False,
-        **inference_settings["common_params"],
+            **common_params,
         )
 
         assert isinstance(response, ChatCompletionResponse)
@@ -349,14 +321,16 @@ async def test_chat_completion_with_tool_calling(
         assert "location" in call.arguments
         assert "San Francisco" in call.arguments["location"]
 
-
-@pytest.mark.asyncio
-async def test_chat_completion_with_tool_calling_streaming(
-    inference_settings,
+    @pytest.mark.asyncio
+    async def test_chat_completion_with_tool_calling_streaming(
+        self,
+        llama_model,
+        stack_impls,
+        common_params,
         sample_messages,
         sample_tool_definition,
     ):
-    inference_impl = inference_settings["impl"]
+        inference_impl, _ = stack_impls
         messages = sample_messages + [
             UserMessage(
                 content="What's the weather like in San Francisco?",
@@ -366,10 +340,11 @@ async def test_chat_completion_with_tool_calling_streaming(
         response = [
             r
             async for r in await inference_impl.chat_completion(
+                model=llama_model,
                 messages=messages,
                 tools=[sample_tool_definition],
                 stream=True,
-            **inference_settings["common_params"],
+                **common_params,
             )
         ]
@@ -389,8 +364,7 @@ async def test_chat_completion_with_tool_calling_streaming(
         # end = grouped[ChatCompletionResponseEventType.complete][0]
         # assert end.event.stop_reason == expected_stop_reason
 
-    model = inference_settings["common_params"]["model"]
-    if "Llama3.1" in model:
+        if "Llama3.1" in llama_model:
             assert all(
                 isinstance(chunk.event.delta, ToolCallDelta)
                 for chunk in grouped[ChatCompletionResponseEventType.progress]

View file

@@ -7,7 +7,7 @@
 import json
 import os
 from datetime import datetime
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import yaml
@@ -18,6 +18,28 @@ from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_impls
 
 
+async def resolve_impls_for_test_v2(
+    apis: List[Api],
+    providers: Dict[str, Provider],
+    provider_data: Optional[Dict[str, Any]] = None,
+):
+    run_config = dict(
+        built_at=datetime.now(),
+        image_name="test-fixture",
+        apis=apis,
+        providers=providers,
+    )
+    run_config = parse_and_maybe_upgrade_config(run_config)
+    impls = await resolve_impls(run_config, get_provider_registry())
+    if provider_data:
+        set_request_provider_data(
+            {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
+        )
+    return impls
+
+
 async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
     if "PROVIDER_CONFIG" not in os.environ:
         raise ValueError(