Mirror of https://github.com/meta-llama/llama-stack.git
Add really basic testing for memory API
Weaviate does not work; the cluster URL seems malformed.
This commit is contained in:
parent dba7caf1d0
commit 4ab6e1b81a

10 changed files with 220 additions and 81 deletions
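The new memory tests themselves are not part of the hunks expanded below (only the inference conftest is shown), but the fixture pattern this commit moves to suggests their shape. The following is a hypothetical sketch only: Api.memory, MemoryBankDocument, insert_documents, query_documents, and the chunks attribute are assumed names, not taken from this commit.

# Hypothetical sketch of a "really basic" memory test; every memory-API name
# below is an assumption, not something shown in this diff.
import pytest
import pytest_asyncio

from llama_stack.apis.memory import *  # noqa: F403
from llama_stack.distribution.datatypes import *  # noqa: F403
from llama_stack.providers.tests.resolver import resolve_impls_for_test


@pytest_asyncio.fixture(scope="session")
async def memory_impl():
    # Provider selection comes from the PROVIDER_CONFIG / PROVIDER_ID env vars,
    # mirroring the inference fixture touched in this commit.
    impls = await resolve_impls_for_test(Api.memory)
    return impls[Api.memory]


@pytest.mark.asyncio
async def test_insert_and_query(memory_impl):
    # A memory bank named "test-bank" is assumed to be registered via the provider config.
    docs = [
        MemoryBankDocument(  # assumed document type
            document_id="doc-1",
            content="Llama Stack exposes a memory API for retrieval.",
            mime_type="text/plain",
            metadata={},
        )
    ]
    await memory_impl.insert_documents(bank_id="test-bank", documents=docs)
    response = await memory_impl.query_documents(bank_id="test-bank", query="memory API")
    assert len(response.chunks) > 0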
@@ -5,21 +5,15 @@
# the root directory of this source tree.

import itertools
import json
import os
from datetime import datetime

import pytest
import pytest_asyncio
import yaml

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403

from llama_stack.distribution.datatypes import *  # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing
from llama_stack.providers.tests.resolver import resolve_impls_for_test


def group_chunks(response):
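The body of group_chunks is collapsed out of this hunk. Given the itertools import above and the grouped[ChatCompletionResponseEventType...] assertions in the last hunk, it plausibly buckets streamed response chunks by event type, along the lines of this sketch (not copied from the repository):

# Plausible reconstruction, not the repository's exact body: group streamed
# chunks into a dict keyed by their event type.
def group_chunks(response):
    return {
        event_type: list(chunks)
        for event_type, chunks in itertools.groupby(
            response, key=lambda chunk: chunk.event.event_type
        )
    }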
@@ -39,68 +33,6 @@ def get_expected_stop_reason(model: str):
    return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn


async def stack_impls(model):
    if "PROVIDER_CONFIG" not in os.environ:
        raise ValueError(
            "You must set PROVIDER_CONFIG to a YAML file containing provider config"
        )

    with open(os.environ["PROVIDER_CONFIG"], "r") as f:
        config_dict = yaml.safe_load(f)

    if "providers" not in config_dict:
        raise ValueError("Config file should contain a `providers` key")

    providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
    if len(providers_by_id) == 0:
        raise ValueError("No providers found in config file")

    if "PROVIDER_ID" in os.environ:
        provider_id = os.environ["PROVIDER_ID"]
        if provider_id not in providers_by_id:
            raise ValueError(f"Provider ID {provider_id} not found in config file")
        provider = providers_by_id[provider_id]
    else:
        provider = list(providers_by_id.values())[0]
        provider_id = provider["provider_id"]
        print(f"No provider ID specified, picking first `{provider_id}`")

    run_config = dict(
        built_at=datetime.now(),
        image_name="test-fixture",
        apis=[
            Api.inference,
            Api.models,
        ],
        providers=dict(
            inference=[
                Provider(**provider),
            ]
        ),
        models=[
            ModelDef(
                identifier=model,
                llama_model=model,
                provider_id=provider["provider_id"],
            )
        ],
        shields=[],
        memory_banks=[],
    )
    run_config = parse_and_maybe_upgrade_config(run_config)
    impls = await resolve_impls_with_routing(run_config)

    # may need something cleaner here
    if "provider_data" in config_dict:
        provider_data = config_dict["provider_data"].get(provider_id, {})
        if provider_data:
            set_request_provider_data(
                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
            )

    return impls


# This is going to create multiple Stack impls without tearing down the previous one
# Fix that!
@pytest_asyncio.fixture(
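For context on the helper being deleted here: stack_impls expects PROVIDER_CONFIG to point at a YAML file with a top-level providers list, where each entry is splatted into Provider (so provider_id, provider_type, and config keys are assumed), plus an optional provider_data mapping keyed by provider id that gets forwarded as the X-LlamaStack-ProviderData header. A minimal sketch of producing such a file; the provider values are placeholders:

# Sketch only: "test-provider", its provider_type, and the config/api_key values
# are placeholders, not a real provider entry from the repository.
import os
import yaml

provider_config = {
    "providers": [
        {
            "provider_id": "test-provider",  # selected via PROVIDER_ID, else the first entry is used
            "provider_type": "remote::example",  # placeholder; assumed Provider field
            "config": {"url": "http://localhost:7002"},  # placeholder provider settings
        }
    ],
    # Optional: per-provider data sent as the X-LlamaStack-ProviderData header.
    "provider_data": {
        "test-provider": {"api_key": "dummy"},
    },
}

with open("provider_config.yaml", "w") as f:
    yaml.safe_dump(provider_config, f)

os.environ["PROVIDER_CONFIG"] = "provider_config.yaml"
os.environ["PROVIDER_ID"] = "test-provider"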
@@ -113,7 +45,17 @@ async def stack_impls(model)
)
async def inference_settings(request):
    model = request.param["model"]
    impls = await stack_impls(model)
    impls = await resolve_impls_for_test(
        Api.inference,
        models=[
            ModelDef(
                identifier=model,
                llama_model=model,
                provider_id="",
            )
        ],
    )

    return {
        "impl": impls[Api.inference],
        "common_params": {
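The hunk above is truncated inside the fixture's return value; tests receive the resolved inference implementation plus a common_params dict. A sketch of how a test might drive it, assuming chat_completion accepts messages/stream alongside those common params and that the non-streaming response carries a completion_message; neither detail is visible in this diff:

# Sketch only: chat_completion's exact signature and the completion_message
# attribute are assumptions based on the inference API of this period.
@pytest.mark.asyncio
async def test_chat_completion_smoke(inference_settings):
    inference_impl = inference_settings["impl"]
    response = await inference_impl.chat_completion(
        messages=[UserMessage(content="Reply with a single word.")],
        stream=False,
        **inference_settings["common_params"],
    )
    assert response.completion_message.content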
@@ -266,12 +208,11 @@ async def test_chat_completion_with_tool_calling_streaming(
    assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
    assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

    end = grouped[ChatCompletionResponseEventType.complete][0]

    # This is not supported in most providers :/ they don't return eom_id / eot_id
    # expected_stop_reason = get_expected_stop_reason(
    #     inference_settings["common_params"]["model"]
    # )
    # end = grouped[ChatCompletionResponseEventType.complete][0]
    # assert end.event.stop_reason == expected_stop_reason

    model = inference_settings["common_params"]["model"]