Another round of simplification and clarity for models/shields/memory_banks stuff

This commit is contained in:
Ashwin Bharambe 2024-10-09 19:19:26 -07:00
parent 73a0a34e39
commit b55034c0de
27 changed files with 454 additions and 444 deletions

View file

@@ -28,7 +28,7 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/memory/test_inference.py \
# pytest -s llama_stack/providers/tests/inference/test_inference.py \
# --tb=short --disable-warnings
# ```
@@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str):
scope="session",
params=[
{"model": Llama_8B},
{"model": Llama_3B},
# {"model": Llama_3B},
],
ids=lambda d: d["model"],
)
@@ -64,16 +64,11 @@ async def inference_settings(request):
model = request.param["model"]
impls = await resolve_impls_for_test(
Api.inference,
models=[
ModelDef(
identifier=model,
llama_model=model,
)
],
)
return {
"impl": impls[Api.inference],
"models_impl": impls[Api.models],
"common_params": {
"model": model,
"tool_choice": ToolChoice.auto,
@@ -108,6 +103,25 @@ def sample_tool_definition():
)
@pytest.mark.asyncio
async def test_model_list(inference_settings):
    """Check that the models API lists the model used by this test session.

    Reads the models implementation from ``inference_settings["models_impl"]``
    and the expected identifier from
    ``inference_settings["common_params"]["model"]``.
    """
    params = inference_settings["common_params"]
    models_impl = inference_settings["models_impl"]

    listing = await models_impl.list_models()
    assert isinstance(listing, list)
    assert len(listing) >= 1
    assert all(isinstance(entry, ModelDefWithProvider) for entry in listing)

    # First listed entry whose identifier matches, or None when there is none.
    wanted = params["model"]
    matched = next(
        (entry for entry in listing if entry.identifier == wanted),
        None,
    )

    assert matched is not None
    assert matched.identifier == params["model"]
@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]

View file

@@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
@@ -30,12 +31,14 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
@pytest_asyncio.fixture(scope="session")
async def memory_impl():
async def memory_settings():
impls = await resolve_impls_for_test(
Api.memory,
memory_banks=[],
)
return impls[Api.memory]
return {
"memory_impl": impls[Api.memory],
"memory_banks_impl": impls[Api.memory_banks],
}
@pytest.fixture
@@ -64,23 +67,35 @@ def sample_documents():
]
async def register_memory_bank(memory_impl: Memory):
async def register_memory_bank(banks_impl: MemoryBanks):
bank = VectorMemoryBankDef(
identifier="test_bank",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
provider_id=os.environ["PROVIDER_ID"],
)
await memory_impl.register_memory_bank(bank)
await banks_impl.register_memory_bank(bank)
@pytest.mark.asyncio
async def test_query_documents(memory_impl, sample_documents):
async def test_banks_list(memory_settings):
    """Check that ``list_memory_banks()`` returns a list with no entries.

    The memory-banks implementation is read from
    ``memory_settings["memory_banks_impl"]``.
    """
    listed = await memory_settings["memory_banks_impl"].list_memory_banks()

    assert isinstance(listed, list)
    # An empty list is falsy, so this is the same check as len(...) == 0.
    assert not listed
@pytest.mark.asyncio
async def test_query_documents(memory_settings, sample_documents):
memory_impl = memory_settings["memory_impl"]
banks_impl = memory_settings["memory_banks_impl"]
with pytest.raises(ValueError):
await memory_impl.insert_documents("test_bank", sample_documents)
await register_memory_bank(memory_impl)
await register_memory_bank(banks_impl)
await memory_impl.insert_documents("test_bank", sample_documents)
query1 = "programming language"

View file

@@ -18,9 +18,6 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
async def resolve_impls_for_test(
api: Api,
models: List[ModelDef] = None,
memory_banks: List[MemoryBankDef] = None,
shields: List[ShieldDef] = None,
):
if "PROVIDER_CONFIG" not in os.environ:
raise ValueError(
@@ -47,45 +44,11 @@ async def resolve_impls_for_test(
provider_id = provider["provider_id"]
print(f"No provider ID specified, picking first `{provider_id}`")
models = models or []
shields = shields or []
memory_banks = memory_banks or []
models = [
ModelDef(
**{
**m.dict(),
"provider_id": provider_id,
}
)
for m in models
]
shields = [
ShieldDef(
**{
**s.dict(),
"provider_id": provider_id,
}
)
for s in shields
]
memory_banks = [
MemoryBankDef(
**{
**m.dict(),
"provider_id": provider_id,
}
)
for m in memory_banks
]
run_config = dict(
built_at=datetime.now(),
image_name="test-fixture",
apis=[api],
providers={api.value: [Provider(**provider)]},
models=models,
memory_banks=memory_banks,
shields=shields,
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls_with_routing(run_config)