From 560467e6fe767b787ccc00baf21a707aa51c6cb5 Mon Sep 17 00:00:00 2001
From: Vladimir Ivic
Date: Thu, 21 Nov 2024 15:35:55 -0800
Subject: [PATCH] Add Ollama inference mocks

Summary:
This commit adds mock support for Ollama inference testing. Pass `--mock-overrides` when running the tests:

```
pytest llama_stack/providers/tests/inference/test_text_inference.py -m "ollama" --mock-overrides inference=ollama --inference-model Llama3.2-1B-Instruct
```

The tests then run against the Ollama provider through a mocked adapter.

Test Plan:
Run the tests:

```
pytest llama_stack/providers/tests/inference/test_text_inference.py -m "ollama" --mock-overrides inference=ollama --inference-model Llama3.2-1B-Instruct -v -s --tb=short --disable-warnings

====================================================================================================== test session starts ======================================================================================================
platform darwin -- Python 3.11.10, pytest-8.3.3, pluggy-1.5.0 -- /opt/homebrew/Caskroom/miniconda/base/envs/llama-stack/bin/python
cachedir: .pytest_cache
rootdir: /Users/vivic/Code/llama-stack
configfile: pyproject.toml
plugins: asyncio-0.24.0, anyio-4.6.2.post1
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 56 items / 48 deselected / 8 selected

llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[-ollama] Overriding inference=ollama with mocks from inference_ollama_mocks
Resolved 4 providers
 inner-inference => ollama
 models => __routing_table__
 inference => __autorouted__
 inspect => __builtin__

Models: Llama3.2-1B-Instruct served by ollama

PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[-ollama] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completions_structured_output[-ollama] SKIPPED (This test is not quite robust)
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[-ollama] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[-ollama] SKIPPED (Other inference providers don't support structured output yet)
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[-ollama] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[-ollama] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[-ollama] PASSED

==================================================================================== 6 passed, 2 skipped, 48 deselected, 6 warnings in 0.11s ====================================================================================
```
---
 .../providers/tests/inference/fixtures.py |  17 +-
 .../providers/tests/inference/mocks.py    | 208 ++++++++++++++++++
 2 files changed, 223 insertions(+), 2 deletions(-)
 create mode 100644 llama_stack/providers/tests/inference/mocks.py

diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index a53ddf639..a3a94c360 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -23,8 +23,9 @@ from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
-from ..conftest import ProviderFixture, remote_stack_fixture
+from ..conftest import ProviderFixture, remote_stack_fixture, should_use_mock_overrides
 from ..env import get_env_or_fail
+from .mocks import *  # noqa


 @pytest.fixture(scope="session")
@@ -182,6 +183,18 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
+
+    # Set up mocks if they are specified via the command line and defined
+    if should_use_mock_overrides(
+        request, f"inference={fixture_name}", f"inference_{fixture_name}_mocks"
+    ):
+        try:
+            request.getfixturevalue(f"inference_{fixture_name}_mocks")
+        except pytest.FixtureLookupError:
+            print(
+                f"Fixture inference_{fixture_name}_mocks not implemented, skipping mocks."
+            )
+
     test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
@@ -189,4 +202,4 @@ async def inference_stack(request, inference_model):
         models=[ModelInput(model_id=inference_model)],
     )

-    return test_stack.impls[Api.inference], test_stack.impls[Api.models]
+    yield test_stack.impls[Api.inference], test_stack.impls[Api.models]
diff --git a/llama_stack/providers/tests/inference/mocks.py b/llama_stack/providers/tests/inference/mocks.py
new file mode 100644
index 000000000..68e706fb4
--- /dev/null
+++ b/llama_stack/providers/tests/inference/mocks.py
@@ -0,0 +1,208 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Type
+
+from unittest.mock import create_autospec, MagicMock, Mock, patch
+
+import pytest
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+
+from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
+    Inference,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
+from llama_stack.apis.models import Model
+from llama_stack.distribution.routers import ModelsRoutingTable
+from llama_stack.distribution.routers.routers import InferenceRouter
+from llama_stack.providers.remote.inference.ollama.ollama import OllamaInferenceAdapter
+
+
+class Stubs:
+    completion_stub_matchers = {
+        "stream=False": {
+            "content=Micheael Jordan is born in ": CompletionResponse(
+                content="1963",
+                stop_reason="end_of_message",
+                logprobs=None,
+            )
+        },
+        "stream=True": {
+            "content=Roses are red,": CompletionResponseStreamChunk(
+                delta="", stop_reason="out_of_tokens", logprobs=None
+            )
+        },
+    }
+
+    @staticmethod
+    async def process_completion(*args, **kwargs):
+        if kwargs["stream"]:
+            stream_mock = MagicMock()
+            stream_mock.__aiter__.return_value = [
+                Stubs.completion_stub_matchers["stream=True"][
+                    f"content={kwargs['content']}"
+                ]
+            ]
+            return stream_mock
+        return Stubs.completion_stub_matchers["stream=False"][
+            f"content={kwargs['content']}"
+        ]
+
+    chat_completion_stub_matchers = {
+        "stream=False": {
+            "content=You are a helpful assistant.|What's the weather like today?": ChatCompletionResponse(
+                completion_message=CompletionMessage(
+                    role="assistant",
+                    content="Hello world",
+                    stop_reason="end_of_message",
+                )
+            ),
Francisco?": ChatCompletionResponse( + completion_message=CompletionMessage( + role="assistant", + content="Hello world", + stop_reason="end_of_message", + tool_calls=[ + ToolCall( + call_id="get_weather", + tool_name="get_weather", + arguments={"location": "San Francisco"}, + ) + ], + ) + ), + }, + "stream=True": { + "content=You are a helpful assistant.|What's the weather like today?": [ + ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.start, + delta="Hello", + ) + ), + ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta="world", + ) + ), + ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta="this is a test", + stop_reason="end_of_turn", + ) + ), + ], + "content=You are a helpful assistant.|What's the weather like today?|What's the weather like in San Francisco?": [ + ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.start, + delta=ToolCallDelta( + content=ToolCall( + call_id="get_weather", + tool_name="get_weather", + arguments={"location": "San Francisco"}, + ), + parse_status=ToolCallParseStatus.success, + ), + ), + ), + ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content=ToolCall( + call_id="get_weather", + tool_name="get_weather", + arguments={"location": "San Francisco"}, + ), + parse_status=ToolCallParseStatus.success, + ), + ), + ), + ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta=ToolCallDelta( + content=ToolCall( + call_id="get_weather", + tool_name="get_weather", + arguments={"location": "San Francisco"}, + ), + parse_status=ToolCallParseStatus.success, + ), + ) + ), + ], + }, + } + + @staticmethod + async def chat_completion(*args, **kwargs): + query_content = "|".join([msg.content for msg in kwargs["messages"]]) + if kwargs["stream"]: + stream_mock = MagicMock() + stream_mock.__aiter__.return_value = Stubs.chat_completion_stub_matchers[ + "stream=True" + ][f"content={query_content}"] + return stream_mock + return Stubs.chat_completion_stub_matchers["stream=False"][ + f"content={query_content}" + ] + + +def setup_models_stubs(model_mock: Model, routing_table_mock: Type[ModelsRoutingTable]): + routing_table_mock.return_value.list_models.return_value = [model_mock] + + +def setup_provider_stubs( + model_mock: Model, routing_table_mock: Type[ModelsRoutingTable] +): + provider_mock = Mock() + provider_mock.__provider_spec__ = Mock() + provider_mock.__provider_spec__.provider_type = model_mock.provider_type + routing_table_mock.return_value.get_provider_impl.return_value = provider_mock + + +def setup_inference_router_stubs(adapter_class: Type[Inference]): + # Set up competion stubs + InferenceRouter.completion = create_autospec(adapter_class.completion) + InferenceRouter.completion.side_effect = Stubs.process_completion + + # Set up chat completion stubs + InferenceRouter.chat_completion = create_autospec(adapter_class.chat_completion) + InferenceRouter.chat_completion.side_effect = Stubs.chat_completion + + +@pytest.fixture(scope="session") +def inference_ollama_mocks(inference_model): + with patch( + "llama_stack.providers.remote.inference.ollama.get_adapter_impl", + autospec=True, + ) as 
+    ) as get_adapter_impl_mock, patch(
+        "llama_stack.distribution.routers.ModelsRoutingTable",
+        autospec=True,
+    ) as ModelsRoutingTableMock:  # noqa N806
+        model_mock = create_autospec(Model)
+        model_mock.identifier = inference_model
+        model_mock.provider_id = "ollama"
+        model_mock.provider_type = "remote::ollama"
+
+        setup_models_stubs(model_mock, ModelsRoutingTableMock)
+        setup_provider_stubs(model_mock, ModelsRoutingTableMock)
+        setup_inference_router_stubs(OllamaInferenceAdapter)
+
+        impl_mock = create_autospec(OllamaInferenceAdapter)
+        get_adapter_impl_mock.return_value = impl_mock
+        yield
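
Note: the patch imports `should_use_mock_overrides` from `..conftest` and relies on a `--mock-overrides` pytest option, neither of which is part of this diff. The sketch below shows one plausible shape for that hook in `llama_stack/providers/tests/conftest.py`; the option name and the `should_use_mock_overrides(request, override, fixture_name)` call shape come from the patch, while the option wiring and the function body are assumptions, not the actual conftest code.

```
# Hypothetical sketch of the conftest.py side of --mock-overrides (not part of this patch).
# Only the option name and the call shape are taken from the diff above.


def pytest_addoption(parser):
    # Register the --mock-overrides option used in the test command above.
    parser.addoption(
        "--mock-overrides",
        action="append",
        default=[],
        help="api=provider pairs (e.g. inference=ollama) whose tests should run against mocks",
    )


def should_use_mock_overrides(request, override: str, mock_fixture_name: str) -> bool:
    # True when the given api=provider pair was passed via --mock-overrides.
    overrides = request.config.getoption("--mock-overrides") or []
    if override not in overrides:
        return False
    print(f"Overriding {override} with mocks from {mock_fixture_name}")
    return True
```

With a hook of this shape, `--mock-overrides inference=ollama` makes `inference_stack` request the `inference_ollama_mocks` fixture, which patches `get_adapter_impl`, `ModelsRoutingTable`, and the `InferenceRouter` methods before the stack is constructed; that is consistent with the `Overriding inference=ollama with mocks from inference_ollama_mocks` line in the test output above.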