From bba6717ef531a48b5a639fd8b3a865696b4b94c2 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Fri, 1 Nov 2024 13:07:59 -0700
Subject: [PATCH] Significantly simpler and malleable test setup

---
 .gitignore                                    |   2 +-
 llama_stack/providers/tests/conftest.py       |  41 ++
 llama_stack/providers/tests/env.py            |  24 +
 .../providers/tests/inference/conftest.py     | 139 ++++
 .../inference/provider_config_example.yaml    |  28 -
 .../tests/inference/test_inference.py         | 592 +++++++++---------
 llama_stack/providers/tests/resolver.py       |  24 +-
 7 files changed, 511 insertions(+), 339 deletions(-)
 create mode 100644 llama_stack/providers/tests/conftest.py
 create mode 100644 llama_stack/providers/tests/env.py
 create mode 100644 llama_stack/providers/tests/inference/conftest.py
 delete mode 100644 llama_stack/providers/tests/inference/provider_config_example.yaml

diff --git a/.gitignore b/.gitignore
index 897494f21..90470f8b3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,5 +15,5 @@ Package.resolved
 *.ipynb_checkpoints*
 .idea
 .venv/
-.idea
+.vscode
 _build
diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py
new file mode 100644
index 000000000..40c826fb8
--- /dev/null
+++ b/llama_stack/providers/tests/conftest.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+from pathlib import Path
+
+from dotenv import load_dotenv
+from termcolor import colored
+
+
+def pytest_configure(config):
+    """Load environment variables at start of test run"""
+    # Load from .env file if it exists
+    env_file = Path(__file__).parent / ".env"
+    if env_file.exists():
+        load_dotenv(env_file)
+
+    # Load any environment variables passed via --env
+    env_vars = config.getoption("--env") or []
+    for env_var in env_vars:
+        key, value = env_var.split("=", 1)
+        os.environ[key] = value
+
+
+def pytest_addoption(parser):
+    """Add custom command line options"""
+    parser.addoption(
+        "--env", action="append", help="Set environment variables, e.g. --env KEY=value"
+    )
+
+
+def pytest_itemcollected(item):
+    # Get all markers as a list
+    filtered = ("asyncio", "parametrize")
+    marks = [mark.name for mark in item.iter_markers() if mark.name not in filtered]
+    if marks:
+        marks = colored(",".join(marks), "yellow")
+        item.name = f"{item.name}[{marks}]"
diff --git a/llama_stack/providers/tests/env.py b/llama_stack/providers/tests/env.py
new file mode 100644
index 000000000..1dac43333
--- /dev/null
+++ b/llama_stack/providers/tests/env.py
@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+
+
+class MissingCredentialError(Exception):
+    pass
+
+
+def get_env_or_fail(key: str) -> str:
+    """Get environment variable or raise helpful error"""
+    value = os.getenv(key)
+    if not value:
+        raise MissingCredentialError(
+            f"\nMissing {key} in environment. Please set it using one of these methods:"
+            f"\n1. Export in shell: export {key}=your-key"
+            f"\n2. Create .env file in project root with: {key}=your-key"
+            f"\n3. Pass directly to pytest: pytest --env {key}=your-key"
+        )
+    return value
diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py
new file mode 100644
index 000000000..ae679a1b7
--- /dev/null
+++ b/llama_stack/providers/tests/inference/conftest.py
@@ -0,0 +1,139 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+from typing import Any, Dict, Tuple
+
+import pytest
+import pytest_asyncio
+
+from llama_stack.distribution.datatypes import Api, Provider
+from llama_stack.providers.adapters.inference.fireworks import FireworksImplConfig
+from llama_stack.providers.adapters.inference.ollama import OllamaImplConfig
+from llama_stack.providers.adapters.inference.together import TogetherImplConfig
+from llama_stack.providers.impls.meta_reference.inference import (
+    MetaReferenceInferenceConfig,
+)
+from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from ..env import get_env_or_fail
+
+
+MODEL_PARAMS = [
+    pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
+    pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
+]
+
+
+@pytest.fixture(scope="session", params=MODEL_PARAMS)
+def llama_model(request):
+    return request.param
+
+
+@pytest.fixture(scope="session")
+def meta_reference(llama_model) -> Provider:
+    return Provider(
+        provider_id="meta-reference",
+        provider_type="meta-reference",
+        config=MetaReferenceInferenceConfig(
+            model=llama_model,
+            max_seq_len=512,
+            create_distributed_process_group=False,
+            checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
+        ).model_dump(),
+    )
+
+
+@pytest.fixture(scope="session")
+def ollama(llama_model) -> Provider:
+    if llama_model == "Llama3.1-8B-Instruct":
+        pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
+
+    return Provider(
+        provider_id="ollama",
+        provider_type="remote::ollama",
+        config=(
+            OllamaImplConfig(
+                host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
+            ).model_dump()
+        ),
+    )
+
+
+@pytest.fixture(scope="session")
+def fireworks(llama_model) -> Provider:
+    return Provider(
+        provider_id="fireworks",
+        provider_type="remote::fireworks",
+        config=FireworksImplConfig(
+            api_key=get_env_or_fail("FIREWORKS_API_KEY"),
+        ).model_dump(),
+    )
+
+
+@pytest.fixture(scope="session")
+def together(llama_model) -> Tuple[Provider, Dict[str, Any]]:
+    provider = Provider(
+        provider_id="together",
+        provider_type="remote::together",
+        config=TogetherImplConfig().model_dump(),
+    )
+    return provider, dict(
+        together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
+    )
+
+
+PROVIDER_PARAMS = [
+    pytest.param("meta_reference", marks=pytest.mark.meta_reference),
+    pytest.param("ollama", marks=pytest.mark.ollama),
+    pytest.param("fireworks", marks=pytest.mark.fireworks),
+    pytest.param("together", marks=pytest.mark.together),
+]
+
+
+@pytest_asyncio.fixture(
+    scope="session",
+    params=PROVIDER_PARAMS,
+)
+async def stack_impls(request):
+    provider_fixture = request.param
+    provider = request.getfixturevalue(provider_fixture)
+    if isinstance(provider, tuple):
+        provider, provider_data = provider
+    else:
+        provider_data = None
+
+    impls = await resolve_impls_for_test_v2(
+        [Api.inference],
+        {"inference": [provider.model_dump()]},
+        provider_data,
+    )
+
+    return (impls[Api.inference], impls[Api.models])
+
+
+def pytest_configure(config):
+    config.addinivalue_line(
+        "markers", "llama_8b: mark test to run only with the given model"
+    )
+    config.addinivalue_line(
+        "markers", "llama_3b: mark test to run only with the given model"
+    )
+    config.addinivalue_line(
+        "markers",
+        "meta_reference: marks tests as metaref specific",
+    )
+    config.addinivalue_line(
+        "markers",
+        "ollama: marks tests as ollama specific",
+    )
+    config.addinivalue_line(
+        "markers",
+        "fireworks: marks tests as fireworks specific",
+    )
+    config.addinivalue_line(
+        "markers",
+        "together: marks tests as together specific",
+    )
diff --git a/llama_stack/providers/tests/inference/provider_config_example.yaml b/llama_stack/providers/tests/inference/provider_config_example.yaml
deleted file mode 100644
index 675ece1ea..000000000
--- a/llama_stack/providers/tests/inference/provider_config_example.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-providers:
-  - provider_id: test-ollama
-    provider_type: remote::ollama
-    config:
-      host: localhost
-      port: 11434
-  - provider_id: meta-reference
-    provider_type: meta-reference
-    config:
-      model: Llama3.2-1B-Instruct
-  - provider_id: test-tgi
-    provider_type: remote::tgi
-    config:
-      url: http://localhost:7001
-  - provider_id: test-remote
-    provider_type: remote
-    config:
-      host: localhost
-      port: 7002
-  - provider_id: test-together
-    provider_type: remote::together
-    config: {}
-# if a provider needs private keys from the client, they use the
-# "get_request_provider_data" function (see distribution/request_headers.py)
-# this is a place to provide such data.
-provider_data:
-  "test-together":
-    together_api_key: 0xdeadbeefputrealapikeyhere
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 3063eb431..d96bae649 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -5,10 +5,8 @@
 # the root directory of this source tree.
 
 import itertools
-import os
 
 import pytest
-import pytest_asyncio
 
 from pydantic import BaseModel, ValidationError
 
@@ -16,24 +14,14 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403
 
-from llama_stack.providers.tests.resolver import resolve_impls_for_test
+from .conftest import MODEL_PARAMS, PROVIDER_PARAMS
 
 # How to run this test:
 #
-# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
-#    since it depends on the provider you are testing. On top of that you need
-#    `pytest` and `pytest-asyncio` installed.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID= \
-#   PROVIDER_CONFIG=provider_config.yaml \
-#   pytest -s llama_stack/providers/tests/inference/test_inference.py \
-#   --tb=short --disable-warnings
-# ```
+# pytest llama_stack/providers/tests/inference/test_inference.py
+#   -m "(fireworks or ollama) and llama_3b"
+#   -v -s --tb=short --disable-warnings
+#   --env FIREWORKS_API_KEY=
 
 
 def group_chunks(response):
@@ -45,45 +33,19 @@ def group_chunks(response):
     return {
         event_type: list(group)
         for event_type, group in itertools.groupby(
            response, key=lambda chunk: chunk.event.event_type
         )
     }
 
 
-Llama_8B = "Llama3.1-8B-Instruct"
-Llama_3B = "Llama3.2-3B-Instruct"
-
-
 def get_expected_stop_reason(model: str):
     return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
 
 
-if "MODEL_IDS" not in os.environ:
-    MODEL_IDS = [Llama_8B, Llama_3B]
-else:
-    MODEL_IDS = os.environ["MODEL_IDS"].split(",")
-
-
-# This is going to create multiple Stack impls without tearing down the previous one
-# Fix that!
-@pytest_asyncio.fixture(
-    scope="session",
-    params=[{"model": m} for m in MODEL_IDS],
-    ids=lambda d: d["model"],
-)
-async def inference_settings(request):
-    model = request.param["model"]
-    impls = await resolve_impls_for_test(
-        Api.inference,
-    )
-
+@pytest.fixture
+def common_params(llama_model):
     return {
-        "impl": impls[Api.inference],
-        "models_impl": impls[Api.models],
-        "common_params": {
-            "model": model,
-            "tool_choice": ToolChoice.auto,
-            "tool_prompt_format": (
-                ToolPromptFormat.json
-                if "Llama3.1" in model
-                else ToolPromptFormat.python_list
-            ),
-        },
+        "tool_choice": ToolChoice.auto,
+        "tool_prompt_format": (
+            ToolPromptFormat.json
+            if "Llama3.1" in llama_model
+            else ToolPromptFormat.python_list
+        ),
     }
 
 
@@ -109,301 +71,313 @@ def sample_tool_definition():
     )
 
 
-@pytest.mark.asyncio
-async def test_model_list(inference_settings):
-    params = inference_settings["common_params"]
-    models_impl = inference_settings["models_impl"]
-    response = await models_impl.list_models()
-    assert isinstance(response, list)
-    assert len(response) >= 1
-    assert all(isinstance(model, ModelDefWithProvider) for model in response)
+@pytest.mark.parametrize("llama_model", MODEL_PARAMS, indirect=True)
+@pytest.mark.parametrize(
+    "stack_impls",
+    PROVIDER_PARAMS,
+    indirect=True,
+)
+class TestInference:
+    @pytest.mark.asyncio
+    async def test_model_list(self, llama_model, stack_impls):
+        _, models_impl = stack_impls
+        response = await models_impl.list_models()
+        assert isinstance(response, list)
+        assert len(response) >= 1
+        assert all(isinstance(model, ModelDefWithProvider) for model in response)
 
-    model_def = None
-    for model in response:
-        if model.identifier == params["model"]:
-            model_def = model
-            break
+        model_def = None
+        for model in response:
+            if model.identifier == llama_model:
+                model_def = model
+                break
 
-    assert model_def is not None
-    assert model_def.identifier == params["model"]
+        assert model_def is not None
 
+    @pytest.mark.asyncio
+    async def test_completion(self, llama_model, stack_impls, common_params):
+        inference_impl, _ = stack_impls
 
-@pytest.mark.asyncio
-async def test_completion(inference_settings):
-    inference_impl = inference_settings["impl"]
-    params = inference_settings["common_params"]
+        provider = inference_impl.routing_table.get_provider_impl(llama_model)
+        if provider.__provider_spec__.provider_type not in (
+            "meta-reference",
+            "remote::ollama",
+            "remote::tgi",
+            "remote::together",
+            "remote::fireworks",
+        ):
+            pytest.skip("Other inference providers don't support completion() yet")
 
-    provider = inference_impl.routing_table.get_provider_impl(params["model"])
-    if provider.__provider_spec__.provider_type not in (
-        "meta-reference",
-        "remote::ollama",
-        "remote::tgi",
-        "remote::together",
-        "remote::fireworks",
-    ):
-        pytest.skip("Other inference providers don't support completion() yet")
-
-    response = await inference_impl.completion(
-        content="Micheael Jordan is born in ",
-        stream=False,
-        model=params["model"],
-        sampling_params=SamplingParams(
-            max_tokens=50,
-        ),
-    )
-
-    assert isinstance(response, CompletionResponse)
-    assert "1963" in response.content
-
-    chunks = [
-        r
-        async for r in await inference_impl.completion(
-            content="Roses are red,",
-            stream=True,
-            model=params["model"],
+        response = await inference_impl.completion(
+            content="Michael Jordan was born in ",
+            stream=False,
+            model=llama_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
         )
-    ]
 
-    assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
-    assert len(chunks) >= 1
-    last = chunks[-1]
-    assert last.stop_reason == StopReason.out_of_tokens
+        assert isinstance(response, CompletionResponse)
+        assert "1963" in response.content
+
+        chunks = [
+            r
+            async for r in await inference_impl.completion(
+                content="Roses are red,",
+                stream=True,
+                model=llama_model,
+                sampling_params=SamplingParams(
+                    max_tokens=50,
+                ),
+            )
+        ]
+
+        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
+        assert len(chunks) >= 1
+        last = chunks[-1]
+        assert last.stop_reason == StopReason.out_of_tokens
 
-
-@pytest.mark.asyncio
-@pytest.mark.skip("This test is not quite robust")
-async def test_completions_structured_output(inference_settings):
-    inference_impl = inference_settings["impl"]
-    params = inference_settings["common_params"]
+    @pytest.mark.asyncio
+    @pytest.mark.skip("This test is not quite robust")
+    async def test_completions_structured_output(
+        self, llama_model, stack_impls, common_params
+    ):
+        inference_impl, _ = stack_impls
 
-    provider = inference_impl.routing_table.get_provider_impl(params["model"])
-    if provider.__provider_spec__.provider_type not in (
-        "meta-reference",
-        "remote::tgi",
-        "remote::together",
-        "remote::fireworks",
-    ):
-        pytest.skip(
-            "Other inference providers don't support structured output in completions yet"
-        )
+        provider = inference_impl.routing_table.get_provider_impl(llama_model)
+        if provider.__provider_spec__.provider_type not in (
+            "meta-reference",
+            "remote::tgi",
+            "remote::together",
+            "remote::fireworks",
+        ):
+            pytest.skip(
+                "Other inference providers don't support structured output in completions yet"
+            )
 
-    class Output(BaseModel):
-        name: str
-        year_born: str
-        year_retired: str
+        class Output(BaseModel):
+            name: str
+            year_born: str
+            year_retired: str
 
-    user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
-    response = await inference_impl.completion(
-        content=user_input,
-        stream=False,
-        model=params["model"],
-        sampling_params=SamplingParams(
-            max_tokens=50,
-        ),
-        response_format=JsonSchemaResponseFormat(
-            json_schema=Output.model_json_schema(),
-        ),
-    )
-    assert isinstance(response, CompletionResponse)
-    assert isinstance(response.content, str)
-
-    answer = Output.parse_raw(response.content)
-    assert answer.name == "Michael Jordan"
-    assert answer.year_born == "1963"
-    assert answer.year_retired == "2003"
+        user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
+        response = await inference_impl.completion(
+            content=user_input,
+            stream=False,
+            model=llama_model,
+            sampling_params=SamplingParams(
+                max_tokens=50,
+            ),
+            response_format=JsonSchemaResponseFormat(
+                json_schema=Output.model_json_schema(),
+            ),
+        )
+        assert isinstance(response, CompletionResponse)
+        assert isinstance(response.content, str)
+
+        answer = Output.model_validate_json(response.content)
+        assert answer.name == "Michael Jordan"
+        assert answer.year_born == "1963"
+        assert answer.year_retired == "2003"
 
-
-@pytest.mark.asyncio
-async def test_chat_completion_non_streaming(inference_settings, sample_messages):
-    inference_impl = inference_settings["impl"]
-    response = await inference_impl.chat_completion(
-        messages=sample_messages,
-        stream=False,
-        **inference_settings["common_params"],
-    )
-
-    assert isinstance(response, ChatCompletionResponse)
-    assert response.completion_message.role == "assistant"
-    assert isinstance(response.completion_message.content, str)
-    assert len(response.completion_message.content) > 0
+    @pytest.mark.asyncio
+    async def test_chat_completion_non_streaming(
+        self, llama_model, stack_impls, common_params, sample_messages
+    ):
+        inference_impl, _ = stack_impls
+        response = await inference_impl.chat_completion(
+            model=llama_model,
+            messages=sample_messages,
+            stream=False,
+            **common_params,
+        )
 
+        assert isinstance(response, ChatCompletionResponse)
+        assert response.completion_message.role == "assistant"
+        assert isinstance(response.completion_message.content, str)
+        assert len(response.completion_message.content) > 0
 
-@pytest.mark.asyncio
-async def test_structured_output(inference_settings):
-    inference_impl = inference_settings["impl"]
-    params = inference_settings["common_params"]
+    @pytest.mark.asyncio
+    async def test_structured_output(self, llama_model, stack_impls, common_params):
+        inference_impl, _ = stack_impls
 
-    provider = inference_impl.routing_table.get_provider_impl(params["model"])
-    if provider.__provider_spec__.provider_type not in (
-        "meta-reference",
-        "remote::fireworks",
-        "remote::tgi",
-        "remote::together",
-    ):
-        pytest.skip("Other inference providers don't support structured output yet")
+        provider = inference_impl.routing_table.get_provider_impl(llama_model)
+        if provider.__provider_spec__.provider_type not in (
+            "meta-reference",
+            "remote::fireworks",
+            "remote::tgi",
+            "remote::together",
+        ):
+            pytest.skip("Other inference providers don't support structured output yet")
 
-    class AnswerFormat(BaseModel):
-        first_name: str
-        last_name: str
-        year_of_birth: int
-        num_seasons_in_nba: int
+        class AnswerFormat(BaseModel):
+            first_name: str
+            last_name: str
+            year_of_birth: int
+            num_seasons_in_nba: int
 
-    response = await inference_impl.chat_completion(
-        messages=[
-            SystemMessage(content="You are a helpful assistant."),
-            UserMessage(content="Please give me information about Michael Jordan."),
-        ],
-        stream=False,
-        response_format=JsonSchemaResponseFormat(
-            json_schema=AnswerFormat.model_json_schema(),
-        ),
-        **inference_settings["common_params"],
-    )
+        response = await inference_impl.chat_completion(
+            model=llama_model,
+            messages=[
+                SystemMessage(content="You are a helpful assistant."),
+                UserMessage(content="Please give me information about Michael Jordan."),
+            ],
+            stream=False,
+            response_format=JsonSchemaResponseFormat(
+                json_schema=AnswerFormat.model_json_schema(),
+            ),
+            **common_params,
+        )
 
-    assert isinstance(response, ChatCompletionResponse)
-    assert response.completion_message.role == "assistant"
-    assert isinstance(response.completion_message.content, str)
+        assert isinstance(response, ChatCompletionResponse)
+        assert response.completion_message.role == "assistant"
+        assert isinstance(response.completion_message.content, str)
 
-    answer = AnswerFormat.parse_raw(response.completion_message.content)
-    assert answer.first_name == "Michael"
-    assert answer.last_name == "Jordan"
-    assert answer.year_of_birth == 1963
-    assert answer.num_seasons_in_nba == 15
+        answer = AnswerFormat.model_validate_json(response.completion_message.content)
+        assert answer.first_name == "Michael"
+        assert answer.last_name == "Jordan"
+        assert answer.year_of_birth == 1963
+        assert answer.num_seasons_in_nba == 15
 
-    response = await inference_impl.chat_completion(
-        messages=[
-            SystemMessage(content="You are a helpful assistant."),
-            UserMessage(content="Please give me information about Michael Jordan."),
-        ],
-        stream=False,
-        **inference_settings["common_params"],
-    )
+        response = await inference_impl.chat_completion(
+            model=llama_model,
+            messages=[
+                SystemMessage(content="You are a helpful assistant."),
+                UserMessage(content="Please give me information about Michael Jordan."),
+            ],
+            stream=False,
+            **common_params,
+        )
 
-    assert isinstance(response, ChatCompletionResponse)
-    assert isinstance(response.completion_message.content, str)
+        assert isinstance(response, ChatCompletionResponse)
+        assert isinstance(response.completion_message.content, str)
 
-    with pytest.raises(ValidationError):
-        AnswerFormat.parse_raw(response.completion_message.content)
+        with pytest.raises(ValidationError):
+            AnswerFormat.model_validate_json(response.completion_message.content)
 
-
-@pytest.mark.asyncio
-async def test_chat_completion_streaming(inference_settings, sample_messages):
-    inference_impl = inference_settings["impl"]
-    response = [
-        r
-        async for r in await inference_impl.chat_completion(
-            messages=sample_messages,
-            stream=True,
-            **inference_settings["common_params"],
-        )
-    ]
-
-    assert len(response) > 0
-    assert all(
-        isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
-    )
-    grouped = group_chunks(response)
-    assert len(grouped[ChatCompletionResponseEventType.start]) == 1
-    assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
-    assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
-
-    end = grouped[ChatCompletionResponseEventType.complete][0]
-    assert end.event.stop_reason == StopReason.end_of_turn
+    @pytest.mark.asyncio
+    async def test_chat_completion_streaming(
+        self, llama_model, stack_impls, common_params, sample_messages
+    ):
+        inference_impl, _ = stack_impls
+        response = [
+            r
+            async for r in await inference_impl.chat_completion(
+                model=llama_model,
+                messages=sample_messages,
+                stream=True,
+                **common_params,
+            )
+        ]
 
+        assert len(response) > 0
+        assert all(
+            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
+        )
+        grouped = group_chunks(response)
+        assert len(grouped[ChatCompletionResponseEventType.start]) == 1
+        assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
+        assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
 
-@pytest.mark.asyncio
-async def test_chat_completion_with_tool_calling(
-    inference_settings,
-    sample_messages,
-    sample_tool_definition,
-):
-    inference_impl = inference_settings["impl"]
-    messages = sample_messages + [
-        UserMessage(
-            content="What's the weather like in San Francisco?",
-        )
-    ]
+        end = grouped[ChatCompletionResponseEventType.complete][0]
+        assert end.event.stop_reason == StopReason.end_of_turn
 
-    response = await inference_impl.chat_completion(
-        messages=messages,
-        tools=[sample_tool_definition],
-        stream=False,
-        **inference_settings["common_params"],
-    )
+    @pytest.mark.asyncio
+    async def test_chat_completion_with_tool_calling(
+        self,
+        llama_model,
+        stack_impls,
+        common_params,
+        sample_messages,
+        sample_tool_definition,
+    ):
+        inference_impl, _ = stack_impls
+        messages = sample_messages + [
+            UserMessage(
+                content="What's the weather like in San Francisco?",
+            )
+        ]
 
-    assert isinstance(response, ChatCompletionResponse)
+        response = await inference_impl.chat_completion(
+            model=llama_model,
+            messages=messages,
+            tools=[sample_tool_definition],
+            stream=False,
+            **common_params,
+        )
 
-    message = response.completion_message
+        assert isinstance(response, ChatCompletionResponse)
 
-    # This is not supported in most providers :/ they don't return eom_id / eot_id
-    # stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
-    # assert message.stop_reason == stop_reason
-    assert message.tool_calls is not None
-    assert len(message.tool_calls) > 0
+        message = response.completion_message
 
-    call = message.tool_calls[0]
-    assert call.tool_name == "get_weather"
-    assert "location" in call.arguments
-    assert "San Francisco" in call.arguments["location"]
+        # This is not supported in most providers :/ they don't return eom_id / eot_id
+        # stop_reason = get_expected_stop_reason(llama_model)
+        # assert message.stop_reason == stop_reason
+        assert message.tool_calls is not None
+        assert len(message.tool_calls) > 0
 
+        call = message.tool_calls[0]
+        assert call.tool_name == "get_weather"
+        assert "location" in call.arguments
+        assert "San Francisco" in call.arguments["location"]
 
-@pytest.mark.asyncio
-async def test_chat_completion_with_tool_calling_streaming(
-    inference_settings,
-    sample_messages,
-    sample_tool_definition,
-):
-    inference_impl = inference_settings["impl"]
-    messages = sample_messages + [
-        UserMessage(
-            content="What's the weather like in San Francisco?",
-        )
-    ]
-
-    response = [
-        r
-        async for r in await inference_impl.chat_completion(
-            messages=messages,
-            tools=[sample_tool_definition],
-            stream=True,
-            **inference_settings["common_params"],
-        )
-    ]
-
-    assert len(response) > 0
-    assert all(
-        isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
-    )
-    grouped = group_chunks(response)
-    assert len(grouped[ChatCompletionResponseEventType.start]) == 1
-    assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
-    assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
-
-    # This is not supported in most providers :/ they don't return eom_id / eot_id
-    # expected_stop_reason = get_expected_stop_reason(
-    #     inference_settings["common_params"]["model"]
-    # )
-    # end = grouped[ChatCompletionResponseEventType.complete][0]
-    # assert end.event.stop_reason == expected_stop_reason
-
-    model = inference_settings["common_params"]["model"]
-    if "Llama3.1" in model:
-        assert all(
-            isinstance(chunk.event.delta, ToolCallDelta)
-            for chunk in grouped[ChatCompletionResponseEventType.progress]
-        )
-        first = grouped[ChatCompletionResponseEventType.progress][0]
-        assert first.event.delta.parse_status == ToolCallParseStatus.started
-
-        last = grouped[ChatCompletionResponseEventType.progress][-1]
-        # assert last.event.stop_reason == expected_stop_reason
-        assert last.event.delta.parse_status == ToolCallParseStatus.success
-        assert isinstance(last.event.delta.content, ToolCall)
-
-        call = last.event.delta.content
-        assert call.tool_name == "get_weather"
-        assert "location" in call.arguments
-        assert "San Francisco" in call.arguments["location"]
+    @pytest.mark.asyncio
+    async def test_chat_completion_with_tool_calling_streaming(
+        self,
+        llama_model,
+        stack_impls,
+        common_params,
+        sample_messages,
+        sample_tool_definition,
+    ):
+        inference_impl, _ = stack_impls
+        messages = sample_messages + [
+            UserMessage(
+                content="What's the weather like in San Francisco?",
+            )
+        ]
+
+        response = [
+            r
+            async for r in await inference_impl.chat_completion(
+                model=llama_model,
+                messages=messages,
+                tools=[sample_tool_definition],
+                stream=True,
+                **common_params,
+            )
+        ]
+
+        assert len(response) > 0
+        assert all(
+            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
+        )
+        grouped = group_chunks(response)
+        assert len(grouped[ChatCompletionResponseEventType.start]) == 1
+        assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
+        assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
+
+        # This is not supported in most providers :/ they don't return eom_id / eot_id
+        # expected_stop_reason = get_expected_stop_reason(
+        #     llama_model
+        # )
+        # end = grouped[ChatCompletionResponseEventType.complete][0]
+        # assert end.event.stop_reason == expected_stop_reason
+
+        if "Llama3.1" in llama_model:
+            assert all(
+                isinstance(chunk.event.delta, ToolCallDelta)
+                for chunk in grouped[ChatCompletionResponseEventType.progress]
+            )
+            first = grouped[ChatCompletionResponseEventType.progress][0]
+            assert first.event.delta.parse_status == ToolCallParseStatus.started
+
+            last = grouped[ChatCompletionResponseEventType.progress][-1]
+            # assert last.event.stop_reason == expected_stop_reason
+            assert last.event.delta.parse_status == ToolCallParseStatus.success
+            assert isinstance(last.event.delta.content, ToolCall)
+
+            call = last.event.delta.content
+            assert call.tool_name == "get_weather"
+            assert "location" in call.arguments
+            assert "San Francisco" in call.arguments["location"]
diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py
index f211cc7d3..a03b25aba 100644
--- a/llama_stack/providers/tests/resolver.py
+++ b/llama_stack/providers/tests/resolver.py
@@ -7,7 +7,7 @@
 import json
 import os
 from datetime import datetime
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import yaml
 
@@ -18,6 +18,28 @@ from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_impls
 
 
+async def resolve_impls_for_test_v2(
+    apis: List[Api],
+    providers: Dict[str, List[Provider]],
+    provider_data: Optional[Dict[str, Any]] = None,
+):
+    run_config = dict(
+        built_at=datetime.now(),
+        image_name="test-fixture",
+        apis=apis,
+        providers=providers,
+    )
+    run_config = parse_and_maybe_upgrade_config(run_config)
+    impls = await resolve_impls(run_config, get_provider_registry())
+
+    if provider_data:
+        set_request_provider_data(
+            {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
+        )
+
+    return impls
+
+
 async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
     if "PROVIDER_CONFIG" not in os.environ:
         raise ValueError(