Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 16:32:38 +00:00)
Add Ollama inference mocks
Summary:

This commit adds mock support for Ollama inference testing. Pass `--mock-overrides` when running the tests:

```
pytest llama_stack/providers/tests/inference/test_text_inference.py -m "ollama" --mock-overrides inference=ollama --inference-model Llama3.2-1B-Instruct
```

The tests then run against the Ollama provider with a mocked adapter instead of a live Ollama server.

Test Plan:

Run the tests:

```
pytest llama_stack/providers/tests/inference/test_text_inference.py -m "ollama" --mock-overrides inference=ollama --inference-model Llama3.2-1B-Instruct -v -s --tb=short --disable-warnings
====================================================================================================== test session starts ======================================================================================================
platform darwin -- Python 3.11.10, pytest-8.3.3, pluggy-1.5.0 -- /opt/homebrew/Caskroom/miniconda/base/envs/llama-stack/bin/python
cachedir: .pytest_cache
rootdir: /Users/vivic/Code/llama-stack
configfile: pyproject.toml
plugins: asyncio-0.24.0, anyio-4.6.2.post1
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 56 items / 48 deselected / 8 selected

llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[-ollama]
Overriding inference=ollama with mocks from inference_ollama_mocks
Resolved 4 providers
 inner-inference => ollama
 models => __routing_table__
 inference => __autorouted__
 inspect => __builtin__
Models: Llama3.2-1B-Instruct served by ollama
PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[-ollama] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completions_structured_output[-ollama] SKIPPED (This test is not quite robust)
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[-ollama] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[-ollama] SKIPPED (Other inference providers don't support structured output yet)
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[-ollama] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[-ollama] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[-ollama] PASSED

==================================================================================== 6 passed, 2 skipped, 48 deselected, 6 warnings in 0.11s ====================================================================================
```
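Note: the `--mock-overrides` pytest option and the `should_use_mock_overrides` helper used in the fixture diff below come from the shared `conftest.py`, which is outside this diff. A minimal sketch of what that wiring could look like follows; only the option name and the helper's call signature are taken from this commit, the bodies are assumptions for illustration.

```python
# Hypothetical conftest.py sketch -- NOT part of this commit. Only the option
# name (--mock-overrides) and the helper signature come from the diff below;
# the implementations here are assumptions.


def pytest_addoption(parser):
    parser.addoption(
        "--mock-overrides",
        action="append",
        default=[],
        help="Provider overrides (e.g. inference=ollama) that should run with mocks",
    )


def should_use_mock_overrides(request, override_key, mock_fixture_name):
    """Return True when the given override (e.g. 'inference=ollama') was requested.

    mock_fixture_name (e.g. 'inference_ollama_mocks') is accepted to mirror the
    call site; the caller handles a missing mock fixture via FixtureLookupError.
    """
    overrides = request.config.getoption("--mock-overrides")
    return override_key in overrides
```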
Parent: ac1791f8b1
Commit: 560467e6fe

2 changed files with 223 additions and 2 deletions
@@ -23,8 +23,9 @@ from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 
-from ..conftest import ProviderFixture, remote_stack_fixture
+from ..conftest import ProviderFixture, remote_stack_fixture, should_use_mock_overrides
 from ..env import get_env_or_fail
+from .mocks import * # noqa
 
 
 @pytest.fixture(scope="session")
@@ -182,6 +183,18 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
+
+    # Setup mocks if they are specified via the command line and they are defined
+    if should_use_mock_overrides(
+        request, f"inference={fixture_name}", f"inference_{fixture_name}_mocks"
+    ):
+        try:
+            request.getfixturevalue(f"inference_{fixture_name}_mocks")
+        except pytest.FixtureLookupError:
+            print(
+                f"Fixture inference_{fixture_name}_mocks not implemented, skipping mocks."
+            )
+
     test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
@@ -189,4 +202,4 @@ async def inference_stack(request, inference_model):
         models=[ModelInput(model_id=inference_model)],
     )
 
-    return test_stack.impls[Api.inference], test_stack.impls[Api.models]
+    yield test_stack.impls[Api.inference], test_stack.impls[Api.models]
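For orientation, a test in `test_text_inference.py` consumes this fixture roughly as in the sketch below (illustrative only; the test bodies themselves are not touched by this commit, and the exact assertions may differ):

```python
import pytest


class TestInference:
    @pytest.mark.asyncio
    async def test_model_list(self, inference_model, inference_stack):
        # The fixture yields (inference_impl, models_impl); when run with
        # --mock-overrides inference=ollama, provider calls are answered by
        # the stubs registered in mocks.py instead of a live Ollama server.
        _, models_impl = inference_stack
        models = await models_impl.list_models()
        assert any(m.identifier == inference_model for m in models)
```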
llama_stack/providers/tests/inference/mocks.py (new file, 208 lines)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Type

from unittest.mock import create_autospec, MagicMock, Mock, patch

import pytest
from llama_models.llama3.api.datatypes import *  # noqa: F403

from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionResponseStreamChunk,
    Inference,
    ToolCallDelta,
    ToolCallParseStatus,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.routers import ModelsRoutingTable
from llama_stack.distribution.routers.routers import InferenceRouter
from llama_stack.providers.remote.inference.ollama.ollama import OllamaInferenceAdapter


class Stubs:
    completion_stub_matchers = {
        "stream=False": {
            "content=Micheael Jordan is born in ": CompletionResponse(
                content="1963",
                stop_reason="end_of_message",
                logprobs=None,
            )
        },
        "stream=True": {
            "content=Roses are red,": CompletionResponseStreamChunk(
                delta="", stop_reason="out_of_tokens", logprobs=None
            )
        },
    }

    @staticmethod
    async def process_completion(*args, **kwargs):
        if kwargs["stream"]:
            stream_mock = MagicMock()
            stream_mock.__aiter__.return_value = [
                Stubs.completion_stub_matchers["stream=True"][
                    f"content={kwargs['content']}"
                ]
            ]
            return stream_mock
        return Stubs.completion_stub_matchers["stream=False"][
            f"content={kwargs['content']}"
        ]

    chat_completion_stub_matchers = {
        "stream=False": {
            "content=You are a helpful assistant.|What's the weather like today?": ChatCompletionResponse(
                completion_message=CompletionMessage(
                    role="assistant",
                    content="Hello world",
                    stop_reason="end_of_message",
                )
            ),
            "content=You are a helpful assistant.|What's the weather like today?|What's the weather like in San Francisco?": ChatCompletionResponse(
                completion_message=CompletionMessage(
                    role="assistant",
                    content="Hello world",
                    stop_reason="end_of_message",
                    tool_calls=[
                        ToolCall(
                            call_id="get_weather",
                            tool_name="get_weather",
                            arguments={"location": "San Francisco"},
                        )
                    ],
                )
            ),
        },
        "stream=True": {
            "content=You are a helpful assistant.|What's the weather like today?": [
                ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.start,
                        delta="Hello",
                    )
                ),
                ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta="world",
                    )
                ),
                ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.complete,
                        delta="this is a test",
                        stop_reason="end_of_turn",
                    )
                ),
            ],
            "content=You are a helpful assistant.|What's the weather like today?|What's the weather like in San Francisco?": [
                ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.start,
                        delta=ToolCallDelta(
                            content=ToolCall(
                                call_id="get_weather",
                                tool_name="get_weather",
                                arguments={"location": "San Francisco"},
                            ),
                            parse_status=ToolCallParseStatus.success,
                        ),
                    ),
                ),
                ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content=ToolCall(
                                call_id="get_weather",
                                tool_name="get_weather",
                                arguments={"location": "San Francisco"},
                            ),
                            parse_status=ToolCallParseStatus.success,
                        ),
                    ),
                ),
                ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.complete,
                        delta=ToolCallDelta(
                            content=ToolCall(
                                call_id="get_weather",
                                tool_name="get_weather",
                                arguments={"location": "San Francisco"},
                            ),
                            parse_status=ToolCallParseStatus.success,
                        ),
                    )
                ),
            ],
        },
    }

    @staticmethod
    async def chat_completion(*args, **kwargs):
        query_content = "|".join([msg.content for msg in kwargs["messages"]])
        if kwargs["stream"]:
            stream_mock = MagicMock()
            stream_mock.__aiter__.return_value = Stubs.chat_completion_stub_matchers[
                "stream=True"
            ][f"content={query_content}"]
            return stream_mock
        return Stubs.chat_completion_stub_matchers["stream=False"][
            f"content={query_content}"
        ]


def setup_models_stubs(model_mock: Model, routing_table_mock: Type[ModelsRoutingTable]):
    routing_table_mock.return_value.list_models.return_value = [model_mock]


def setup_provider_stubs(
    model_mock: Model, routing_table_mock: Type[ModelsRoutingTable]
):
    provider_mock = Mock()
    provider_mock.__provider_spec__ = Mock()
    provider_mock.__provider_spec__.provider_type = model_mock.provider_type
    routing_table_mock.return_value.get_provider_impl.return_value = provider_mock


def setup_inference_router_stubs(adapter_class: Type[Inference]):
    # Set up completion stubs
    InferenceRouter.completion = create_autospec(adapter_class.completion)
    InferenceRouter.completion.side_effect = Stubs.process_completion

    # Set up chat completion stubs
    InferenceRouter.chat_completion = create_autospec(adapter_class.chat_completion)
    InferenceRouter.chat_completion.side_effect = Stubs.chat_completion


@pytest.fixture(scope="session")
def inference_ollama_mocks(inference_model):
    with patch(
        "llama_stack.providers.remote.inference.ollama.get_adapter_impl",
        autospec=True,
    ) as get_adapter_impl_mock, patch(
        "llama_stack.distribution.routers.ModelsRoutingTable",
        autospec=True,
    ) as ModelsRoutingTableMock:  # noqa N806
        model_mock = create_autospec(Model)
        model_mock.identifier = inference_model
        model_mock.provider_id = "ollama"
        model_mock.provider_type = "remote::ollama"

        setup_models_stubs(model_mock, ModelsRoutingTableMock)
        setup_provider_stubs(model_mock, ModelsRoutingTableMock)
        setup_inference_router_stubs(OllamaInferenceAdapter)

        impl_mock = create_autospec(OllamaInferenceAdapter)
        get_adapter_impl_mock.return_value = impl_mock
        yield
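To make the stub keying concrete, the snippet below (illustrative, not part of the commit) resolves a non-streaming chat completion directly against `Stubs`. Message contents are joined with `|` to form the lookup key, so a system prompt plus one user turn hits the first `stream=False` entry; the `_Msg` class is a hypothetical stand-in for the real message types.

```python
import asyncio

from llama_stack.providers.tests.inference.mocks import Stubs


class _Msg:
    # Minimal stand-in for a message object with a .content attribute.
    def __init__(self, content: str):
        self.content = content


async def demo():
    response = await Stubs.chat_completion(
        messages=[
            _Msg("You are a helpful assistant."),
            _Msg("What's the weather like today?"),
        ],
        stream=False,
    )
    # Resolves to the matching "stream=False" stub defined above.
    print(response.completion_message.content)  # -> "Hello world"


asyncio.run(demo())
```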