Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-16 23:03:49 +00:00)

Commit dd049d5727 (parent a42fbea1b8)
refactor fixtures and add support for composable fixtures

10 changed files with 485 additions and 270 deletions
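The diff below covers the test module's side of the refactor: the old `llama_model` and `stack_impls` fixtures become `inference_model` and `inference_stack`, and the provider list moves from `PROVIDER_PARAMS` in `conftest` to `INFERENCE_FIXTURES` in a new `fixtures` module. That module is not part of this excerpt, so the following is only a rough sketch of the composable-fixture pattern it appears to enable; the fixture bodies, provider names, and model id here are assumptions, not actual llama-stack code.

# Hypothetical sketch of a composable pytest fixture module: one fixture per
# inference provider plus a dispatching "inference_stack" fixture. Names and
# bodies are illustrative only; the real fixtures.py is not shown in this diff.
import pytest

MODEL_PARAMS = ["Llama3.1-8B-Instruct"]            # assumed example model id
INFERENCE_FIXTURES = ["meta_reference", "ollama"]  # assumed provider fixture names


@pytest.fixture
def inference_model(request):
    # Filled in via indirect parametrization with MODEL_PARAMS.
    return request.param


@pytest.fixture
def inference_meta_reference(inference_model):
    # Would construct and return (inference_impl, models_impl) for the local
    # meta-reference provider; construction elided in this sketch.
    ...


@pytest.fixture
def inference_ollama(inference_model):
    # Would construct the same pair against a running Ollama server; elided.
    ...


@pytest.fixture
def inference_stack(request):
    # Compose: request.param carries a provider name, and the matching provider
    # fixture is resolved by name at setup time, so fixtures can be mixed freely.
    return request.getfixturevalue(f"inference_{request.param}")

The composition keeps `inference_stack` provider-agnostic: each test case names a provider fixture via parametrization, and the stack fixture resolves it at setup time.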
@@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
-from .conftest import MODEL_PARAMS, PROVIDER_PARAMS
+from .fixtures import INFERENCE_FIXTURES, MODEL_PARAMS
 
 # How to run this test:
 #
@@ -38,12 +38,12 @@ def get_expected_stop_reason(model: str):
 
 
 @pytest.fixture
-def common_params(llama_model):
+def common_params(inference_model):
     return {
         "tool_choice": ToolChoice.auto,
         "tool_prompt_format": (
             ToolPromptFormat.json
-            if "Llama3.1" in llama_model
+            if "Llama3.1" in inference_model
             else ToolPromptFormat.python_list
         ),
     }
@@ -71,16 +71,19 @@ def sample_tool_definition():
     )
 
 
-@pytest.mark.parametrize("llama_model", MODEL_PARAMS, indirect=True)
+@pytest.mark.parametrize("inference_model", MODEL_PARAMS, indirect=True)
 @pytest.mark.parametrize(
-    "stack_impls",
-    PROVIDER_PARAMS,
+    "inference_stack",
+    [
+        pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
+        for fixture_name in INFERENCE_FIXTURES
+    ],
     indirect=True,
 )
 class TestInference:
     @pytest.mark.asyncio
-    async def test_model_list(self, llama_model, stack_impls):
-        _, models_impl = stack_impls
+    async def test_model_list(self, inference_model, inference_stack):
+        _, models_impl = inference_stack
         response = await models_impl.list_models()
         assert isinstance(response, list)
         assert len(response) >= 1
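In the parametrization above, each entry in `INFERENCE_FIXTURES` is wrapped in `pytest.param(..., marks=getattr(pytest.mark, fixture_name))`, so every generated test case carries a marker named after its provider fixture, and a single backend can presumably be selected with pytest's `-m` option. A minimal standalone illustration of that pattern (hypothetical backend names, not the actual fixtures module):

# Standalone sketch of marker-based backend selection with indirect
# parametrization; running "pytest -m ollama" would keep only that case.
import pytest

BACKENDS = ["meta_reference", "ollama"]  # hypothetical backend names


@pytest.fixture
def backend(request):
    # Indirect parametrization: the parametrized value arrives as request.param.
    return request.param


@pytest.mark.parametrize(
    "backend",
    [pytest.param(name, marks=getattr(pytest.mark, name)) for name in BACKENDS],
    indirect=True,
)
def test_backend_is_known(backend):
    assert backend in BACKENDS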
@@ -88,17 +91,17 @@ class TestInference:
 
         model_def = None
         for model in response:
-            if model.identifier == llama_model:
+            if model.identifier == inference_model:
                 model_def = model
                 break
 
         assert model_def is not None
 
     @pytest.mark.asyncio
-    async def test_completion(self, llama_model, stack_impls, common_params):
-        inference_impl, _ = stack_impls
+    async def test_completion(self, inference_model, inference_stack):
+        inference_impl, _ = inference_stack
 
-        provider = inference_impl.routing_table.get_provider_impl(llama_model)
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::ollama",
@@ -111,7 +114,7 @@ class TestInference:
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
-            model=llama_model,
+            model=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -125,7 +128,7 @@ class TestInference:
             async for r in await inference_impl.completion(
                 content="Roses are red,",
                 stream=True,
-                model=llama_model,
+                model=inference_model,
                 sampling_params=SamplingParams(
                     max_tokens=50,
                 ),
@@ -140,11 +143,11 @@ class TestInference:
     @pytest.mark.asyncio
     @pytest.mark.skip("This test is not quite robust")
     async def test_completions_structured_output(
-        self, llama_model, stack_impls, common_params
+        self, inference_model, inference_stack
     ):
-        inference_impl, _ = stack_impls
+        inference_impl, _ = inference_stack
 
-        provider = inference_impl.routing_table.get_provider_impl(llama_model)
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::tgi",
@@ -164,7 +167,7 @@ class TestInference:
         response = await inference_impl.completion(
             content=user_input,
             stream=False,
-            model=llama_model,
+            model=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -182,11 +185,11 @@ class TestInference:
 
     @pytest.mark.asyncio
     async def test_chat_completion_non_streaming(
-        self, llama_model, stack_impls, common_params, sample_messages
+        self, inference_model, inference_stack, common_params, sample_messages
     ):
-        inference_impl, _ = stack_impls
+        inference_impl, _ = inference_stack
         response = await inference_impl.chat_completion(
-            model=llama_model,
+            model=inference_model,
             messages=sample_messages,
             stream=False,
             **common_params,
@@ -198,10 +201,12 @@ class TestInference:
         assert len(response.completion_message.content) > 0
 
     @pytest.mark.asyncio
-    async def test_structured_output(self, llama_model, stack_impls, common_params):
-        inference_impl, _ = stack_impls
+    async def test_structured_output(
+        self, inference_model, inference_stack, common_params
+    ):
+        inference_impl, _ = inference_stack
 
-        provider = inference_impl.routing_table.get_provider_impl(llama_model)
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::fireworks",
@@ -217,7 +222,7 @@ class TestInference:
             num_seasons_in_nba: int
 
         response = await inference_impl.chat_completion(
-            model=llama_model,
+            model=inference_model,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -240,7 +245,7 @@ class TestInference:
         assert answer.num_seasons_in_nba == 15
 
         response = await inference_impl.chat_completion(
-            model=llama_model,
+            model=inference_model,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -257,13 +262,13 @@ class TestInference:
 
     @pytest.mark.asyncio
     async def test_chat_completion_streaming(
-        self, llama_model, stack_impls, common_params, sample_messages
+        self, inference_model, inference_stack, common_params, sample_messages
    ):
-        inference_impl, _ = stack_impls
+        inference_impl, _ = inference_stack
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=llama_model,
+                model=inference_model,
                 messages=sample_messages,
                 stream=True,
                 **common_params,
@@ -285,13 +290,13 @@ class TestInference:
     @pytest.mark.asyncio
     async def test_chat_completion_with_tool_calling(
         self,
-        llama_model,
-        stack_impls,
+        inference_model,
+        inference_stack,
         common_params,
         sample_messages,
         sample_tool_definition,
     ):
-        inference_impl, _ = stack_impls
+        inference_impl, _ = inference_stack
         messages = sample_messages + [
             UserMessage(
                 content="What's the weather like in San Francisco?",
@@ -299,7 +304,7 @@ class TestInference:
         ]
 
         response = await inference_impl.chat_completion(
-            model=llama_model,
+            model=inference_model,
             messages=messages,
             tools=[sample_tool_definition],
             stream=False,
@@ -324,13 +329,13 @@ class TestInference:
     @pytest.mark.asyncio
     async def test_chat_completion_with_tool_calling_streaming(
         self,
-        llama_model,
-        stack_impls,
+        inference_model,
+        inference_stack,
         common_params,
         sample_messages,
         sample_tool_definition,
     ):
-        inference_impl, _ = stack_impls
+        inference_impl, _ = inference_stack
         messages = sample_messages + [
             UserMessage(
                 content="What's the weather like in San Francisco?",
@@ -340,7 +345,7 @@ class TestInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=llama_model,
+                model=inference_model,
                 messages=messages,
                 tools=[sample_tool_definition],
                 stream=True,
@@ -364,7 +369,7 @@ class TestInference:
         # end = grouped[ChatCompletionResponseEventType.complete][0]
         # assert end.event.stop_reason == expected_stop_reason
 
-        if "Llama3.1" in llama_model:
+        if "Llama3.1" in inference_model:
             assert all(
                 isinstance(chunk.event.delta, ToolCallDelta)
                 for chunk in grouped[ChatCompletionResponseEventType.progress]
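Putting the pieces together, a single provider could presumably be exercised with something like `pytest -m <fixture_name> -k <test_name> <path to this test module>`, where `<fixture_name>` is one of the entries in `INFERENCE_FIXTURES`; the actual names are defined in the fixtures module, which is not part of this diff.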