From 4ab6e1b81aae3538c08d3014369ce7f8c71b01d9 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 7 Oct 2024 22:34:53 -0700 Subject: [PATCH] Add really basic testing for memory API weaviate does not work; the cluster URL seems malformed --- llama_stack/distribution/resolver.py | 7 +- .../distribution/routers/routing_tables.py | 6 +- .../adapters/memory/weaviate/__init__.py | 11 +- llama_stack/providers/datatypes.py | 2 - llama_stack/providers/registry/memory.py | 1 + .../tests/inference/test_inference.py | 85 +++------------ .../providers/tests/memory/__init__.py | 5 + .../tests/memory/provider_config_example.yaml | 24 +++++ .../providers/tests/memory/test_memory.py | 60 +++++++++++ llama_stack/providers/tests/resolver.py | 100 ++++++++++++++++++ 10 files changed, 220 insertions(+), 81 deletions(-) create mode 100644 llama_stack/providers/tests/memory/__init__.py create mode 100644 llama_stack/providers/tests/memory/provider_config_example.yaml create mode 100644 llama_stack/providers/tests/memory/test_memory.py create mode 100644 llama_stack/providers/tests/resolver.py diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 2d3679177..4db72d29e 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -147,10 +147,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An inner_impls = {} if isinstance(provider.spec, RoutingTableProviderSpec): - for entry in provider.spec.registry: - inner_impls[entry.provider_id] = inner_impls_by_provider_id[ - f"inner-{provider.spec.router_api.value}" - ][entry.provider_id] + inner_impls = inner_impls_by_provider_id[ + f"inner-{provider.spec.router_api.value}" + ] impl = await instantiate_provider( provider, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 3d89aa19f..73e26dd2e 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -70,8 +70,12 @@ class CommonRoutingTableImpl(RoutingTable): def get_provider_impl(self, routing_key: str) -> Any: if routing_key not in self.routing_key_to_object: - raise ValueError(f"Could not find provider for {routing_key}") + raise ValueError(f"Object `{routing_key}` not registered") + obj = self.routing_key_to_object[routing_key] + if obj.provider_id not in self.impls_by_provider_id: + raise ValueError(f"Provider `{obj.provider_id}` not found") + return self.impls_by_provider_id[obj.provider_id] def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]: diff --git a/llama_stack/providers/adapters/memory/weaviate/__init__.py b/llama_stack/providers/adapters/memory/weaviate/__init__.py index b564eabf4..504bd1508 100644 --- a/llama_stack/providers/adapters/memory/weaviate/__init__.py +++ b/llama_stack/providers/adapters/memory/weaviate/__init__.py @@ -1,8 +1,15 @@ -from .config import WeaviateConfig +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
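Note on the `routing_tables.py` hunk above: `get_provider_impl` now resolves in two steps, so a missing registration and a missing provider surface as distinct errors. A standalone sketch of that lookup (class names here are hypothetical stand-ins, not the real datatypes):

```python
from typing import Any, Dict


class RoutableObjectSketch:
    def __init__(self, identifier: str, provider_id: str) -> None:
        self.identifier = identifier
        self.provider_id = provider_id


class RoutingTableSketch:
    def __init__(self, impls_by_provider_id: Dict[str, Any]) -> None:
        self.impls_by_provider_id = impls_by_provider_id
        self.routing_key_to_object: Dict[str, RoutableObjectSketch] = {}

    def register(self, obj: RoutableObjectSketch) -> None:
        self.routing_key_to_object[obj.identifier] = obj

    def get_provider_impl(self, routing_key: str) -> Any:
        # first failure mode: the object was never registered
        if routing_key not in self.routing_key_to_object:
            raise ValueError(f"Object `{routing_key}` not registered")
        obj = self.routing_key_to_object[routing_key]
        # second failure mode: the object names a provider we don't have
        if obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")
        return self.impls_by_provider_id[obj.provider_id]


table = RoutingTableSketch({"test-faiss": object()})
table.register(RoutableObjectSketch("test_bank", "test-faiss"))
impl = table.get_provider_impl("test_bank")
```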
+ +from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401 + async def get_adapter_impl(config: WeaviateConfig, _deps): from .weaviate import WeaviateMemoryAdapter impl = WeaviateMemoryAdapter(config) await impl.initialize() - return impl \ No newline at end of file + return impl diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index a254e2808..0c8f6ad21 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -43,8 +43,6 @@ class ProviderSpec(BaseModel): class RoutingTable(Protocol): - def get_routing_keys(self) -> List[str]: ... - def get_provider_impl(self, routing_key: str) -> Any: ... diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index a3f0bdb6f..a8d776c3f 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -62,6 +62,7 @@ def available_providers() -> List[ProviderSpec]: adapter_type="weaviate", pip_packages=EMBEDDING_DEPS + ["weaviate-client"], module="llama_stack.providers.adapters.memory.weaviate", + config_class="llama_stack.providers.adapters.memory.weaviate.WeaviateConfig", provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData", ), ), diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index 094ee5924..de8241b20 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -5,21 +5,15 @@ # the root directory of this source tree. import itertools -import json -import os -from datetime import datetime import pytest import pytest_asyncio -import yaml from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 -from llama_stack.distribution.configure import parse_and_maybe_upgrade_config -from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.resolver import resolve_impls_with_routing +from llama_stack.providers.tests.resolver import resolve_impls_for_test def group_chunks(response): @@ -39,68 +33,6 @@ def get_expected_stop_reason(model: str): return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn -async def stack_impls(model): - if "PROVIDER_CONFIG" not in os.environ: - raise ValueError( - "You must set PROVIDER_CONFIG to a YAML file containing provider config" - ) - - with open(os.environ["PROVIDER_CONFIG"], "r") as f: - config_dict = yaml.safe_load(f) - - if "providers" not in config_dict: - raise ValueError("Config file should contain a `providers` key") - - providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]} - if len(providers_by_id) == 0: - raise ValueError("No providers found in config file") - - if "PROVIDER_ID" in os.environ: - provider_id = os.environ["PROVIDER_ID"] - if provider_id not in providers_by_id: - raise ValueError(f"Provider ID {provider_id} not found in config file") - provider = providers_by_id[provider_id] - else: - provider = list(providers_by_id.values())[0] - provider_id = provider["provider_id"] - print(f"No provider ID specified, picking first `{provider_id}`") - - run_config = dict( - built_at=datetime.now(), - image_name="test-fixture", - apis=[ - Api.inference, - Api.models, - ], - providers=dict( - inference=[ - Provider(**provider), - ] - ), - 
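Note on the weaviate `__init__.py` hunk above: `get_adapter_impl` keeps the `WeaviateMemoryAdapter` import inside the function body, so the optional `weaviate-client` dependency is only needed when the adapter is actually instantiated. A minimal runnable sketch of that convention, with stand-in classes rather than the real ones:

```python
import asyncio


class ConfigSketch:  # stand-in for WeaviateConfig
    pass


class AdapterSketch:  # stand-in for WeaviateMemoryAdapter
    def __init__(self, config: ConfigSketch) -> None:
        self.config = config
        self.initialized = False

    async def initialize(self) -> None:
        # a real adapter would open its client connection here
        self.initialized = True


async def get_adapter_impl(config: ConfigSketch, _deps):
    # the real module does `from .weaviate import WeaviateMemoryAdapter`
    # here, at call time; a stand-in is inlined to keep the sketch runnable
    impl = AdapterSketch(config)
    await impl.initialize()
    return impl


impl = asyncio.run(get_adapter_impl(ConfigSketch(), _deps=None))
assert impl.initialized
```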
models=[ - ModelDef( - identifier=model, - llama_model=model, - provider_id=provider["provider_id"], - ) - ], - shields=[], - memory_banks=[], - ) - run_config = parse_and_maybe_upgrade_config(run_config) - impls = await resolve_impls_with_routing(run_config) - - # may need something cleaner here - if "provider_data" in config_dict: - provider_data = config_dict["provider_data"].get(provider_id, {}) - if provider_data: - set_request_provider_data( - {"X-LlamaStack-ProviderData": json.dumps(provider_data)} - ) - - return impls - - # This is going to create multiple Stack impls without tearing down the previous one # Fix that! @pytest_asyncio.fixture( @@ -113,7 +45,17 @@ async def stack_impls(model): ) async def inference_settings(request): model = request.param["model"] - impls = await stack_impls(model) + impls = await resolve_impls_for_test( + Api.inference, + models=[ + ModelDef( + identifier=model, + llama_model=model, + provider_id="", + ) + ], + ) + return { "impl": impls[Api.inference], "common_params": { @@ -266,12 +208,11 @@ async def test_chat_completion_with_tool_calling_streaming( assert len(grouped[ChatCompletionResponseEventType.progress]) > 0 assert len(grouped[ChatCompletionResponseEventType.complete]) == 1 - end = grouped[ChatCompletionResponseEventType.complete][0] - # This is not supported in most providers :/ they don't return eom_id / eot_id # expected_stop_reason = get_expected_stop_reason( # inference_settings["common_params"]["model"] # ) + # end = grouped[ChatCompletionResponseEventType.complete][0] # assert end.event.stop_reason == expected_stop_reason model = inference_settings["common_params"]["model"] diff --git a/llama_stack/providers/tests/memory/__init__.py b/llama_stack/providers/tests/memory/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/tests/memory/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/tests/memory/provider_config_example.yaml b/llama_stack/providers/tests/memory/provider_config_example.yaml new file mode 100644 index 000000000..cac1adde5 --- /dev/null +++ b/llama_stack/providers/tests/memory/provider_config_example.yaml @@ -0,0 +1,24 @@ +providers: + - provider_id: test-faiss + provider_type: meta-reference + config: {} + - provider_id: test-chroma + provider_type: remote::chroma + config: + host: localhost + port: 6001 + - provider_id: test-remote + provider_type: remote + config: + host: localhost + port: 7002 + - provider_id: test-weaviate + provider_type: remote::weaviate + config: {} +# if a provider needs private keys from the client, they use the +# "get_request_provider_data" function (see distribution/request_headers.py) +# this is a place to provide such data. +provider_data: + "test-weaviate": + weaviate_api_key: 0xdeadbeefputrealapikeyhere + weaviate_cluster_url: http://foobarbaz diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py new file mode 100644 index 000000000..4f6dadb14 --- /dev/null +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
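The `PROVIDER_CONFIG` / `PROVIDER_ID` selection logic that `resolve_impls_for_test` (added at the end of this patch) applies to files like `provider_config_example.yaml` above, distilled into a hypothetical helper:

```python
import os

import yaml


def pick_provider(path: str) -> dict:
    with open(path, "r") as f:
        config = yaml.safe_load(f)
    if "providers" not in config:
        raise ValueError("Config file should contain a `providers` key")
    providers_by_id = {p["provider_id"]: p for p in config["providers"]}
    if not providers_by_id:
        raise ValueError("No providers found in config file")
    provider_id = os.environ.get("PROVIDER_ID")
    if provider_id is None:
        # fall back to the first provider listed in the file
        provider_id = next(iter(providers_by_id))
    if provider_id not in providers_by_id:
        raise ValueError(f"Provider ID {provider_id} not found in config file")
    return providers_by_id[provider_id]
```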
+ +import pytest +import pytest_asyncio + +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.tests.resolver import resolve_impls_for_test + + +@pytest_asyncio.fixture(scope="session") +async def memory_impl(): + impls = await resolve_impls_for_test( + Api.memory, + memory_banks=[], + ) + return impls[Api.memory] + + +@pytest.fixture +def sample_document(): + return MemoryBankDocument( + document_id="doc1", + content="This is a sample document for testing.", + mime_type="text/plain", + metadata={"author": "Test Author"}, + ) + + +async def register_memory_bank(memory_impl: Memory): + bank = VectorMemoryBankDef( + identifier="test_bank", + provider_id="", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ) + + await memory_impl.register_memory_bank(bank) + + +@pytest.mark.asyncio +async def test_query_documents(memory_impl, sample_document): + with pytest.raises(ValueError): + await memory_impl.insert_documents("test_bank", [sample_document]) + + await register_memory_bank(memory_impl) + await memory_impl.insert_documents("test_bank", [sample_document]) + + query = ["sample ", "document"] + response = await memory_impl.query_documents("test_bank", query) + + assert isinstance(response, QueryDocumentsResponse) + assert len(response.chunks) > 0 + assert len(response.scores) > 0 + assert len(response.chunks) == len(response.scores) diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py new file mode 100644 index 000000000..266f252e4 --- /dev/null +++ b/llama_stack/providers/tests/resolver.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
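The contract `test_memory.py` above exercises — inserting into an unregistered bank must raise `ValueError`, and a query returns parallel lists of chunks and scores — shown against an entirely hypothetical in-memory stand-in:

```python
from typing import Dict, List, Tuple


class ToyMemoryImpl:
    """Hypothetical in-memory stand-in for a Memory provider."""

    def __init__(self) -> None:
        self.banks: Dict[str, List[str]] = {}

    def register_memory_bank(self, bank_id: str) -> None:
        self.banks[bank_id] = []

    def insert_documents(self, bank_id: str, documents: List[str]) -> None:
        if bank_id not in self.banks:
            # matches the ValueError the test expects before registration
            raise ValueError(f"Bank `{bank_id}` not registered")
        self.banks[bank_id].extend(documents)

    def query_documents(self, bank_id: str, query: str) -> Tuple[List[str], List[float]]:
        chunks = [d for d in self.banks[bank_id] if query.strip() in d]
        scores = [1.0] * len(chunks)  # placeholder scoring; one score per chunk
        return chunks, scores


impl = ToyMemoryImpl()
try:
    impl.insert_documents("test_bank", ["sample document"])
except ValueError:
    pass  # expected: bank not registered yet
impl.register_memory_bank("test_bank")
impl.insert_documents("test_bank", ["This is a sample document for testing."])
chunks, scores = impl.query_documents("test_bank", "sample")
assert len(chunks) == len(scores) > 0
```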
+ +import json +import os +from datetime import datetime + +import yaml + +from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.distribution.configure import parse_and_maybe_upgrade_config +from llama_stack.distribution.request_headers import set_request_provider_data +from llama_stack.distribution.resolver import resolve_impls_with_routing + + +async def resolve_impls_for_test( + api: Api, + models: List[ModelDef] = None, + memory_banks: List[MemoryBankDef] = None, + shields: List[ShieldDef] = None, +): + if "PROVIDER_CONFIG" not in os.environ: + raise ValueError( + "You must set PROVIDER_CONFIG to a YAML file containing provider config" + ) + + with open(os.environ["PROVIDER_CONFIG"], "r") as f: + config_dict = yaml.safe_load(f) + + if "providers" not in config_dict: + raise ValueError("Config file should contain a `providers` key") + + providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]} + if len(providers_by_id) == 0: + raise ValueError("No providers found in config file") + + if "PROVIDER_ID" in os.environ: + provider_id = os.environ["PROVIDER_ID"] + if provider_id not in providers_by_id: + raise ValueError(f"Provider ID {provider_id} not found in config file") + provider = providers_by_id[provider_id] + else: + provider = list(providers_by_id.values())[0] + provider_id = provider["provider_id"] + print(f"No provider ID specified, picking first `{provider_id}`") + + models = models or [] + shields = shields or [] + memory_banks = memory_banks or [] + + models = [ + ModelDef( + **{ + **m.dict(), + "provider_id": provider_id, + } + ) + for m in models + ] + shields = [ + ShieldDef( + **{ + **s.dict(), + "provider_id": provider_id, + } + ) + for s in shields + ] + memory_banks = [ + MemoryBankDef( + **{ + **m.dict(), + "provider_id": provider_id, + } + ) + for m in memory_banks + ] + run_config = dict( + built_at=datetime.now(), + image_name="test-fixture", + apis=[api], + providers={api.value: [Provider(**provider)]}, + models=models, + memory_banks=memory_banks, + shields=shields, + ) + run_config = parse_and_maybe_upgrade_config(run_config) + impls = await resolve_impls_with_routing(run_config) + + if "provider_data" in config_dict: + provider_data = config_dict["provider_data"].get(provider_id, {}) + if provider_data: + set_request_provider_data( + {"X-LlamaStack-ProviderData": json.dumps(provider_data)} + ) + + return impls
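For reference, a sketch of how the `provider_data` block in the test YAML reaches a provider: `resolve_impls_for_test` JSON-encodes the selected provider's entry under the `X-LlamaStack-ProviderData` header via `set_request_provider_data` (values below copied from the example YAML):

```python
import json

# the `provider_data` block from provider_config_example.yaml, keyed by provider_id
provider_data = {
    "test-weaviate": {
        "weaviate_api_key": "0xdeadbeefputrealapikeyhere",
        "weaviate_cluster_url": "http://foobarbaz",
    }
}

# resolve_impls_for_test forwards the selected provider's entry as a
# JSON-encoded header, the same shape a client would send over HTTP
headers = {"X-LlamaStack-ProviderData": json.dumps(provider_data["test-weaviate"])}
print(headers)
```

Assuming the standard pytest entry point, a run would then look like `PROVIDER_CONFIG=llama_stack/providers/tests/memory/provider_config_example.yaml pytest -s llama_stack/providers/tests/memory/test_memory.py`, optionally with `PROVIDER_ID` set to pin a specific provider.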