mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Add really basic testing for memory API
weaviate does not work; the cluster URL seems malformed
commit 4ab6e1b81a (parent dba7caf1d0)
10 changed files with 220 additions and 81 deletions
llama_stack/distribution/resolver.py
@@ -147,10 +147,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
         inner_impls = {}
         if isinstance(provider.spec, RoutingTableProviderSpec):
-            for entry in provider.spec.registry:
-                inner_impls[entry.provider_id] = inner_impls_by_provider_id[
-                    f"inner-{provider.spec.router_api.value}"
-                ][entry.provider_id]
+            inner_impls = inner_impls_by_provider_id[
+                f"inner-{provider.spec.router_api.value}"
+            ]
 
         impl = await instantiate_provider(
             provider,
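The effect of this hunk: instead of filtering inner implementations through a pre-declared `registry` on the spec, the routing-table provider now receives the complete map of inner implementations for its routed API, keyed by provider_id. A minimal standalone sketch of the new shape (the values are placeholders, not real impls):

    # Sketch only: what inner_impls_by_provider_id might hold for the memory API.
    inner_impls_by_provider_id = {
        "inner-memory": {"test-faiss": object(), "test-weaviate": object()},
    }
    router_api_value = "memory"  # provider.spec.router_api.value in the real code
    inner_impls = inner_impls_by_provider_id[f"inner-{router_api_value}"]
    assert set(inner_impls) == {"test-faiss", "test-weaviate"}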
llama_stack/distribution/routing_tables.py
@@ -70,8 +70,12 @@ class CommonRoutingTableImpl(RoutingTable):
 
     def get_provider_impl(self, routing_key: str) -> Any:
         if routing_key not in self.routing_key_to_object:
-            raise ValueError(f"Could not find provider for {routing_key}")
+            raise ValueError(f"Object `{routing_key}` not registered")
+
         obj = self.routing_key_to_object[routing_key]
+        if obj.provider_id not in self.impls_by_provider_id:
+            raise ValueError(f"Provider `{obj.provider_id}` not found")
+
         return self.impls_by_provider_id[obj.provider_id]
 
     def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]:
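Lookup is now two-step: the routing key resolves to a registered object, and the object's provider_id resolves to the implementation, with a distinct error at each step. A self-contained sketch of the pattern (the names and types here are illustrative stand-ins, not the real classes):

    from dataclasses import dataclass
    from typing import Any, Dict

    @dataclass
    class RoutableObjectSketch:  # stand-in for a registered object
        identifier: str
        provider_id: str

    class RoutingTableSketch:
        def __init__(self, objects: Dict[str, RoutableObjectSketch], impls: Dict[str, Any]):
            self.routing_key_to_object = objects  # routing key -> registered object
            self.impls_by_provider_id = impls     # provider_id -> implementation

        def get_provider_impl(self, routing_key: str) -> Any:
            if routing_key not in self.routing_key_to_object:
                raise ValueError(f"Object `{routing_key}` not registered")
            obj = self.routing_key_to_object[routing_key]
            if obj.provider_id not in self.impls_by_provider_id:
                raise ValueError(f"Provider `{obj.provider_id}` not found")
            return self.impls_by_provider_id[obj.provider_id]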
llama_stack/providers/adapters/memory/weaviate/__init__.py
@@ -1,4 +1,11 @@
-from .config import WeaviateConfig
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import WeaviateConfig, WeaviateRequestProviderData  # noqa: F401
 
 
 async def get_adapter_impl(config: WeaviateConfig, _deps):
     from .weaviate import WeaviateMemoryAdapter
llama_stack/distribution/datatypes.py
@@ -43,8 +43,6 @@ class ProviderSpec(BaseModel):
 
 
 class RoutingTable(Protocol):
-    def get_routing_keys(self) -> List[str]: ...
-
     def get_provider_impl(self, routing_key: str) -> Any: ...
 
 
llama_stack/providers/registry/memory.py
@@ -62,6 +62,7 @@ def available_providers() -> List[ProviderSpec]:
             adapter_type="weaviate",
             pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
             module="llama_stack.providers.adapters.memory.weaviate",
+            config_class="llama_stack.providers.adapters.memory.weaviate.WeaviateConfig",
             provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData",
         ),
     ),
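Registering config_class wires up provider instantiation for weaviate, and provider_data_validator lets a client ship per-request credentials that get validated into WeaviateRequestProviderData. The real model lives in the adapter's config.py; judging from the test YAML later in this diff, it plausibly looks like the sketch below (field names inferred from that YAML, so treat this as an assumption):

    from pydantic import BaseModel

    class WeaviateRequestProviderData(BaseModel):
        # Fields inferred from the provider_data block in the test config.
        weaviate_api_key: str
        weaviate_cluster_url: str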
llama_stack/providers/tests/inference/test_inference.py
@@ -5,21 +5,15 @@
 # the root directory of this source tree.
 
 import itertools
-import json
-import os
-from datetime import datetime
 
 import pytest
 import pytest_asyncio
-import yaml
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
-from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
-from llama_stack.distribution.request_headers import set_request_provider_data
-from llama_stack.distribution.resolver import resolve_impls_with_routing
+from llama_stack.providers.tests.resolver import resolve_impls_for_test
 
 
 def group_chunks(response):
@@ -39,68 +33,6 @@ def get_expected_stop_reason(model: str):
     return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
 
 
-async def stack_impls(model):
-    if "PROVIDER_CONFIG" not in os.environ:
-        raise ValueError(
-            "You must set PROVIDER_CONFIG to a YAML file containing provider config"
-        )
-
-    with open(os.environ["PROVIDER_CONFIG"], "r") as f:
-        config_dict = yaml.safe_load(f)
-
-    if "providers" not in config_dict:
-        raise ValueError("Config file should contain a `providers` key")
-
-    providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
-    if len(providers_by_id) == 0:
-        raise ValueError("No providers found in config file")
-
-    if "PROVIDER_ID" in os.environ:
-        provider_id = os.environ["PROVIDER_ID"]
-        if provider_id not in providers_by_id:
-            raise ValueError(f"Provider ID {provider_id} not found in config file")
-        provider = providers_by_id[provider_id]
-    else:
-        provider = list(providers_by_id.values())[0]
-        provider_id = provider["provider_id"]
-        print(f"No provider ID specified, picking first `{provider_id}`")
-
-    run_config = dict(
-        built_at=datetime.now(),
-        image_name="test-fixture",
-        apis=[
-            Api.inference,
-            Api.models,
-        ],
-        providers=dict(
-            inference=[
-                Provider(**provider),
-            ]
-        ),
-        models=[
-            ModelDef(
-                identifier=model,
-                llama_model=model,
-                provider_id=provider["provider_id"],
-            )
-        ],
-        shields=[],
-        memory_banks=[],
-    )
-    run_config = parse_and_maybe_upgrade_config(run_config)
-    impls = await resolve_impls_with_routing(run_config)
-
-    # may need something cleaner here
-    if "provider_data" in config_dict:
-        provider_data = config_dict["provider_data"].get(provider_id, {})
-        if provider_data:
-            set_request_provider_data(
-                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
-            )
-
-    return impls
-
-
 # This is going to create multiple Stack impls without tearing down the previous one
 # Fix that!
 @pytest_asyncio.fixture(
@@ -113,7 +45,17 @@ async def stack_impls(model):
 )
 async def inference_settings(request):
     model = request.param["model"]
-    impls = await stack_impls(model)
+    impls = await resolve_impls_for_test(
+        Api.inference,
+        models=[
+            ModelDef(
+                identifier=model,
+                llama_model=model,
+                provider_id="",
+            )
+        ],
+    )
 
     return {
         "impl": impls[Api.inference],
         "common_params": {
@@ -266,12 +208,11 @@ async def test_chat_completion_with_tool_calling_streaming(
     assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
     assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
 
-    end = grouped[ChatCompletionResponseEventType.complete][0]
-
     # This is not supported in most providers :/ they don't return eom_id / eot_id
     # expected_stop_reason = get_expected_stop_reason(
     #     inference_settings["common_params"]["model"]
     # )
+    # end = grouped[ChatCompletionResponseEventType.complete][0]
     # assert end.event.stop_reason == expected_stop_reason
 
     model = inference_settings["common_params"]["model"]
llama_stack/providers/tests/memory/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
New provider-config YAML for the memory tests (new file, 24 lines; the diff view does not show its path)
@@ -0,0 +1,24 @@
+providers:
+  - provider_id: test-faiss
+    provider_type: meta-reference
+    config: {}
+  - provider_id: test-chroma
+    provider_type: remote::chroma
+    config:
+      host: localhost
+      port: 6001
+  - provider_id: test-remote
+    provider_type: remote
+    config:
+      host: localhost
+      port: 7002
+  - provider_id: test-weaviate
+    provider_type: remote::weaviate
+    config: {}
+# if a provider needs private keys from the client, they use the
+# "get_request_provider_data" function (see distribution/request_headers.py)
+# this is a place to provide such data.
+provider_data:
+  "test-weaviate":
+    weaviate_api_key: 0xdeadbeefputrealapikeyhere
+    weaviate_cluster_url: http://foobarbaz
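The provider_data block reaches the adapter the same way a real client's credentials would: serialized into the X-LlamaStack-ProviderData request header. The test resolver added below does exactly this; shown here in isolation:

    import json
    from llama_stack.distribution.request_headers import set_request_provider_data

    # Mirror of what resolve_impls_for_test does with the YAML's provider_data:
    # JSON-encode it into the same header an HTTP client would send.
    provider_data = {
        "weaviate_api_key": "0xdeadbeefputrealapikeyhere",
        "weaviate_cluster_url": "http://foobarbaz",
    }
    set_request_provider_data({"X-LlamaStack-ProviderData": json.dumps(provider_data)})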
llama_stack/providers/tests/memory/test_memory.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.providers.tests.resolver import resolve_impls_for_test
+
+
+@pytest_asyncio.fixture(scope="session")
+async def memory_impl():
+    impls = await resolve_impls_for_test(
+        Api.memory,
+        memory_banks=[],
+    )
+    return impls[Api.memory]
+
+
+@pytest.fixture
+def sample_document():
+    return MemoryBankDocument(
+        document_id="doc1",
+        content="This is a sample document for testing.",
+        mime_type="text/plain",
+        metadata={"author": "Test Author"},
+    )
+
+
+async def register_memory_bank(memory_impl: Memory):
+    bank = VectorMemoryBankDef(
+        identifier="test_bank",
+        provider_id="",
+        embedding_model="all-MiniLM-L6-v2",
+        chunk_size_in_tokens=512,
+        overlap_size_in_tokens=64,
+    )
+
+    await memory_impl.register_memory_bank(bank)
+
+
+@pytest.mark.asyncio
+async def test_query_documents(memory_impl, sample_document):
+    with pytest.raises(ValueError):
+        await memory_impl.insert_documents("test_bank", [sample_document])
+
+    await register_memory_bank(memory_impl)
+    await memory_impl.insert_documents("test_bank", [sample_document])
+
+    query = ["sample ", "document"]
+    response = await memory_impl.query_documents("test_bank", query)
+
+    assert isinstance(response, QueryDocumentsResponse)
+    assert len(response.chunks) > 0
+    assert len(response.scores) > 0
+    assert len(response.chunks) == len(response.scores)
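Since resolve_impls_for_test reads PROVIDER_CONFIG (and optionally PROVIDER_ID) from the environment, driving the new test looks roughly like this; the config path is an assumption, because the diff view does not show the YAML file's name:

    import os
    import pytest

    os.environ["PROVIDER_CONFIG"] = "llama_stack/providers/tests/memory/provider_config.yaml"  # hypothetical path
    os.environ["PROVIDER_ID"] = "test-faiss"  # omit to pick the first provider in the file
    pytest.main(["-v", "llama_stack/providers/tests/memory/test_memory.py"])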
llama_stack/providers/tests/resolver.py (new file, 100 lines)
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+import os
+from datetime import datetime
+
+import yaml
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+from llama_stack.distribution.request_headers import set_request_provider_data
+from llama_stack.distribution.resolver import resolve_impls_with_routing
+
+
+async def resolve_impls_for_test(
+    api: Api,
+    models: List[ModelDef] = None,
+    memory_banks: List[MemoryBankDef] = None,
+    shields: List[ShieldDef] = None,
+):
+    if "PROVIDER_CONFIG" not in os.environ:
+        raise ValueError(
+            "You must set PROVIDER_CONFIG to a YAML file containing provider config"
+        )
+
+    with open(os.environ["PROVIDER_CONFIG"], "r") as f:
+        config_dict = yaml.safe_load(f)
+
+    if "providers" not in config_dict:
+        raise ValueError("Config file should contain a `providers` key")
+
+    providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
+    if len(providers_by_id) == 0:
+        raise ValueError("No providers found in config file")
+
+    if "PROVIDER_ID" in os.environ:
+        provider_id = os.environ["PROVIDER_ID"]
+        if provider_id not in providers_by_id:
+            raise ValueError(f"Provider ID {provider_id} not found in config file")
+        provider = providers_by_id[provider_id]
+    else:
+        provider = list(providers_by_id.values())[0]
+        provider_id = provider["provider_id"]
+        print(f"No provider ID specified, picking first `{provider_id}`")
+
+    models = models or []
+    shields = shields or []
+    memory_banks = memory_banks or []
+
+    models = [
+        ModelDef(
+            **{
+                **m.dict(),
+                "provider_id": provider_id,
+            }
+        )
+        for m in models
+    ]
+    shields = [
+        ShieldDef(
+            **{
+                **s.dict(),
+                "provider_id": provider_id,
+            }
+        )
+        for s in shields
+    ]
+    memory_banks = [
+        MemoryBankDef(
+            **{
+                **m.dict(),
+                "provider_id": provider_id,
+            }
+        )
+        for m in memory_banks
+    ]
+    run_config = dict(
+        built_at=datetime.now(),
+        image_name="test-fixture",
+        apis=[api],
+        providers={api.value: [Provider(**provider)]},
+        models=models,
+        memory_banks=memory_banks,
+        shields=shields,
+    )
+    run_config = parse_and_maybe_upgrade_config(run_config)
+    impls = await resolve_impls_with_routing(run_config)
+
+    if "provider_data" in config_dict:
+        provider_data = config_dict["provider_data"].get(provider_id, {})
+        if provider_data:
+            set_request_provider_data(
+                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
+            )
+
+    return impls
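One subtlety: the helper rewrites provider_id on every ModelDef / ShieldDef / MemoryBankDef it is handed, substituting the provider chosen from the config file, which is why the fixtures above can pass provider_id="". A standalone illustration of the dict-splat override pattern used here (ModelDefSketch is a stand-in, not the real ModelDef):

    from pydantic import BaseModel

    class ModelDefSketch(BaseModel):
        identifier: str
        llama_model: str
        provider_id: str = ""

    m = ModelDefSketch(identifier="Llama3.1-8B-Instruct", llama_model="Llama3.1-8B-Instruct")
    # .dict() yields a plain dict; the outer **{...} lets one key be overridden.
    patched = ModelDefSketch(**{**m.dict(), "provider_id": "test-faiss"})
    assert patched.provider_id == "test-faiss"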