Add really basic testing for memory API

weaviate does not work; the cluster URL seems malformed

Authored and committed by Ashwin Bharambe on 2024-10-07 22:34:53 -07:00
commit 4ab6e1b81a, parent dba7caf1d0
10 changed files with 220 additions and 81 deletions


```diff
@@ -147,10 +147,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
             inner_impls = {}
             if isinstance(provider.spec, RoutingTableProviderSpec):
-                for entry in provider.spec.registry:
-                    inner_impls[entry.provider_id] = inner_impls_by_provider_id[
-                        f"inner-{provider.spec.router_api.value}"
-                    ][entry.provider_id]
+                inner_impls = inner_impls_by_provider_id[
+                    f"inner-{provider.spec.router_api.value}"
+                ]

             impl = await instantiate_provider(
                 provider,
```
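In other words, the routing-table provider now receives the whole inner-impl map for its routed API rather than re-keying it per registry entry. A rough sketch of the assumed shape (keys and values illustrative, not the actual types):

```python
# Illustrative only: inner impls are grouped under a synthetic
# "inner-<api>" key, one map per routed API.
inner_impls_by_provider_id = {
    "inner-memory": {
        "test-faiss": object(),     # stand-in for a provider impl
        "test-weaviate": object(),  # stand-in for a provider impl
    },
}

# Old behavior: copy entries listed in provider.spec.registry one by one.
# New behavior: hand the whole per-API map to the routing table.
inner_impls = inner_impls_by_provider_id["inner-memory"]
```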


```diff
@@ -70,8 +70,12 @@ class CommonRoutingTableImpl(RoutingTable):
     def get_provider_impl(self, routing_key: str) -> Any:
         if routing_key not in self.routing_key_to_object:
-            raise ValueError(f"Could not find provider for {routing_key}")
+            raise ValueError(f"Object `{routing_key}` not registered")

         obj = self.routing_key_to_object[routing_key]
+        if obj.provider_id not in self.impls_by_provider_id:
+            raise ValueError(f"Provider `{obj.provider_id}` not found")
+
         return self.impls_by_provider_id[obj.provider_id]

     def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]:
```
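This split matters for the new memory test below, which expects a ValueError when inserting into a bank that was never registered. A standalone sketch of the two failure modes, with plain dicts standing in for the real routing table:

```python
# Simplified stand-ins for the routing table's two mappings.
routing_key_to_object = {}                       # nothing registered yet
impls_by_provider_id = {"test-faiss": object()}  # providers that did start

def lookup(routing_key: str):
    # Failure 1: the routing key was never registered at all.
    if routing_key not in routing_key_to_object:
        raise ValueError(f"Object `{routing_key}` not registered")
    obj = routing_key_to_object[routing_key]
    # Failure 2: the object exists but references a provider that
    # was never instantiated in this run.
    if obj["provider_id"] not in impls_by_provider_id:
        raise ValueError(f"Provider `{obj['provider_id']}` not found")
    return impls_by_provider_id[obj["provider_id"]]

try:
    lookup("test_bank")  # what test_query_documents triggers pre-registration
except ValueError as e:
    print(e)  # Object `test_bank` not registered
```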


```diff
@@ -1,4 +1,11 @@
-from .config import WeaviateConfig
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import WeaviateConfig, WeaviateRequestProviderData  # noqa: F401


 async def get_adapter_impl(config: WeaviateConfig, _deps):
     from .weaviate import WeaviateMemoryAdapter
```


```diff
@@ -43,8 +43,6 @@ class ProviderSpec(BaseModel):

 class RoutingTable(Protocol):
-    def get_routing_keys(self) -> List[str]: ...
-
     def get_provider_impl(self, routing_key: str) -> Any: ...
```


```diff
@@ -62,6 +62,7 @@ def available_providers() -> List[ProviderSpec]:
             adapter_type="weaviate",
             pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
             module="llama_stack.providers.adapters.memory.weaviate",
+            config_class="llama_stack.providers.adapters.memory.weaviate.WeaviateConfig",
             provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData",
         ),
     ),
```


```diff
@@ -5,21 +5,15 @@
 # the root directory of this source tree.

 import itertools
-import json
-import os
-from datetime import datetime

 import pytest
 import pytest_asyncio
-import yaml

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403
-from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
-from llama_stack.distribution.request_headers import set_request_provider_data
-from llama_stack.distribution.resolver import resolve_impls_with_routing
+from llama_stack.providers.tests.resolver import resolve_impls_for_test


 def group_chunks(response):
```
```diff
@@ -39,68 +33,6 @@ def get_expected_stop_reason(model: str):
     return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn


-async def stack_impls(model):
-    if "PROVIDER_CONFIG" not in os.environ:
-        raise ValueError(
-            "You must set PROVIDER_CONFIG to a YAML file containing provider config"
-        )
-
-    with open(os.environ["PROVIDER_CONFIG"], "r") as f:
-        config_dict = yaml.safe_load(f)
-
-    if "providers" not in config_dict:
-        raise ValueError("Config file should contain a `providers` key")
-
-    providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
-    if len(providers_by_id) == 0:
-        raise ValueError("No providers found in config file")
-
-    if "PROVIDER_ID" in os.environ:
-        provider_id = os.environ["PROVIDER_ID"]
-        if provider_id not in providers_by_id:
-            raise ValueError(f"Provider ID {provider_id} not found in config file")
-        provider = providers_by_id[provider_id]
-    else:
-        provider = list(providers_by_id.values())[0]
-        provider_id = provider["provider_id"]
-        print(f"No provider ID specified, picking first `{provider_id}`")
-
-    run_config = dict(
-        built_at=datetime.now(),
-        image_name="test-fixture",
-        apis=[
-            Api.inference,
-            Api.models,
-        ],
-        providers=dict(
-            inference=[
-                Provider(**provider),
-            ]
-        ),
-        models=[
-            ModelDef(
-                identifier=model,
-                llama_model=model,
-                provider_id=provider["provider_id"],
-            )
-        ],
-        shields=[],
-        memory_banks=[],
-    )
-    run_config = parse_and_maybe_upgrade_config(run_config)
-    impls = await resolve_impls_with_routing(run_config)
-
-    # may need something cleaner here
-    if "provider_data" in config_dict:
-        provider_data = config_dict["provider_data"].get(provider_id, {})
-        if provider_data:
-            set_request_provider_data(
-                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
-            )
-
-    return impls
-
-
 # This is going to create multiple Stack impls without tearing down the previous one
 # Fix that!
 @pytest_asyncio.fixture(
```
```diff
@@ -113,7 +45,17 @@ async def stack_impls(model):
 )
 async def inference_settings(request):
     model = request.param["model"]
-    impls = await stack_impls(model)
+    impls = await resolve_impls_for_test(
+        Api.inference,
+        models=[
+            ModelDef(
+                identifier=model,
+                llama_model=model,
+                provider_id="",
+            )
+        ],
+    )
     return {
         "impl": impls[Api.inference],
         "common_params": {
```
```diff
@@ -266,12 +208,11 @@ async def test_chat_completion_with_tool_calling_streaming(
     assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
     assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

-    end = grouped[ChatCompletionResponseEventType.complete][0]
-
     # This is not supported in most providers :/ they don't return eom_id / eot_id
     # expected_stop_reason = get_expected_stop_reason(
     #     inference_settings["common_params"]["model"]
     # )
+    # end = grouped[ChatCompletionResponseEventType.complete][0]
     # assert end.event.stop_reason == expected_stop_reason

     model = inference_settings["common_params"]["model"]
```


```diff
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
```


```diff
@@ -0,0 +1,24 @@
+providers:
+  - provider_id: test-faiss
+    provider_type: meta-reference
+    config: {}
+  - provider_id: test-chroma
+    provider_type: remote::chroma
+    config:
+      host: localhost
+      port: 6001
+  - provider_id: test-remote
+    provider_type: remote
+    config:
+      host: localhost
+      port: 7002
+  - provider_id: test-weaviate
+    provider_type: remote::weaviate
+    config: {}
+
+# if a provider needs private keys from the client, they use the
+# "get_request_provider_data" function (see distribution/request_headers.py)
+# this is a place to provide such data.
+provider_data:
+  "test-weaviate":
+    weaviate_api_key: 0xdeadbeefputrealapikeyhere
+    weaviate_cluster_url: http://foobarbaz
```
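The test resolver at the end of this commit forwards the matching provider_data block by JSON-serializing it into the X-LlamaStack-ProviderData header via set_request_provider_data, which is the same shape an HTTP client would send. A minimal sketch of that shape:

```python
import json

# The block keyed by the selected provider_id in the YAML above.
provider_data = {
    "weaviate_api_key": "0xdeadbeefputrealapikeyhere",
    "weaviate_cluster_url": "http://foobarbaz",
}

# Mirrors what resolve_impls_for_test does in-process via
# set_request_provider_data.
headers = {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
print(headers)
```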


```diff
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+from llama_stack.providers.tests.resolver import resolve_impls_for_test
+
+
+@pytest_asyncio.fixture(scope="session")
+async def memory_impl():
+    impls = await resolve_impls_for_test(
+        Api.memory,
+        memory_banks=[],
+    )
+    return impls[Api.memory]
+
+
+@pytest.fixture
+def sample_document():
+    return MemoryBankDocument(
+        document_id="doc1",
+        content="This is a sample document for testing.",
+        mime_type="text/plain",
+        metadata={"author": "Test Author"},
+    )
+
+
+async def register_memory_bank(memory_impl: Memory):
+    bank = VectorMemoryBankDef(
+        identifier="test_bank",
+        provider_id="",
+        embedding_model="all-MiniLM-L6-v2",
+        chunk_size_in_tokens=512,
+        overlap_size_in_tokens=64,
+    )
+    await memory_impl.register_memory_bank(bank)
+
+
+@pytest.mark.asyncio
+async def test_query_documents(memory_impl, sample_document):
+    with pytest.raises(ValueError):
+        await memory_impl.insert_documents("test_bank", [sample_document])
+
+    await register_memory_bank(memory_impl)
+    await memory_impl.insert_documents("test_bank", [sample_document])
+
+    query = ["sample ", "document"]
+    response = await memory_impl.query_documents("test_bank", query)
+
+    assert isinstance(response, QueryDocumentsResponse)
+    assert len(response.chunks) > 0
+    assert len(response.scores) > 0
+    assert len(response.chunks) == len(response.scores)
```
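Since resolve_impls_for_test reads the config path from the environment, running these presumably looks like `PROVIDER_CONFIG=<path-to-fixture>.yaml pytest -s <path-to>/test_memory.py`, optionally with `PROVIDER_ID=test-weaviate` (or another id from the fixture) to pin a specific provider; the exact paths depend on where the new files live in the tree.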


```diff
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+import os
+from datetime import datetime
+
+import yaml
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+from llama_stack.distribution.request_headers import set_request_provider_data
+from llama_stack.distribution.resolver import resolve_impls_with_routing
+
+
+async def resolve_impls_for_test(
+    api: Api,
+    models: List[ModelDef] = None,
+    memory_banks: List[MemoryBankDef] = None,
+    shields: List[ShieldDef] = None,
+):
+    if "PROVIDER_CONFIG" not in os.environ:
+        raise ValueError(
+            "You must set PROVIDER_CONFIG to a YAML file containing provider config"
+        )
+
+    with open(os.environ["PROVIDER_CONFIG"], "r") as f:
+        config_dict = yaml.safe_load(f)
+
+    if "providers" not in config_dict:
+        raise ValueError("Config file should contain a `providers` key")
+
+    providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
+    if len(providers_by_id) == 0:
+        raise ValueError("No providers found in config file")
+
+    if "PROVIDER_ID" in os.environ:
+        provider_id = os.environ["PROVIDER_ID"]
+        if provider_id not in providers_by_id:
+            raise ValueError(f"Provider ID {provider_id} not found in config file")
+        provider = providers_by_id[provider_id]
+    else:
+        provider = list(providers_by_id.values())[0]
+        provider_id = provider["provider_id"]
+        print(f"No provider ID specified, picking first `{provider_id}`")
+
+    models = models or []
+    shields = shields or []
+    memory_banks = memory_banks or []
+
+    models = [
+        ModelDef(
+            **{
+                **m.dict(),
+                "provider_id": provider_id,
+            }
+        )
+        for m in models
+    ]
+    shields = [
+        ShieldDef(
+            **{
+                **s.dict(),
+                "provider_id": provider_id,
+            }
+        )
+        for s in shields
+    ]
+    memory_banks = [
+        MemoryBankDef(
+            **{
+                **m.dict(),
+                "provider_id": provider_id,
+            }
+        )
+        for m in memory_banks
+    ]
+    run_config = dict(
+        built_at=datetime.now(),
+        image_name="test-fixture",
+        apis=[api],
+        providers={api.value: [Provider(**provider)]},
+        models=models,
+        memory_banks=memory_banks,
+        shields=shields,
+    )
+    run_config = parse_and_maybe_upgrade_config(run_config)
+    impls = await resolve_impls_with_routing(run_config)
+
+    if "provider_data" in config_dict:
+        provider_data = config_dict["provider_data"].get(provider_id, {})
+        if provider_data:
+            set_request_provider_data(
+                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
+            )
+
+    return impls
```
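Used from a test, the helper looks like the memory fixture earlier in this commit; a minimal standalone sketch (assuming PROVIDER_CONFIG is set in the environment):

```python
import asyncio

from llama_stack.distribution.datatypes import *  # noqa: F403
from llama_stack.providers.tests.resolver import resolve_impls_for_test


async def main():
    # Requires PROVIDER_CONFIG to point at a provider YAML such as the
    # memory fixture above.
    impls = await resolve_impls_for_test(Api.memory, memory_banks=[])
    memory = impls[Api.memory]
    print(type(memory))


asyncio.run(main())
```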