More work towards making remote stacks usable from tests

This commit is contained in:
Ashwin Bharambe 2024-11-12 17:09:31 -08:00
parent 8645f8bc9e
commit 8b7be87bec
7 changed files with 91 additions and 99 deletions

View file

@ -85,4 +85,5 @@ async def agents_stack(request, inference_model, safety_shield):
],
shields=[safety_shield],
)
return impls[Api.agents], impls[Api.memory]

View file

@ -186,12 +186,7 @@ async def inference_stack(request, inference_model):
[Api.inference],
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
models=[
ModelInput(
model_id=inference_model,
provider_id=inference_fixture.providers[0].provider_id,
)
],
models=[ModelInput(model_id=inference_model)],
)
return (impls[Api.inference], impls[Api.models])

View file

@ -17,6 +17,7 @@ from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.stack import construct_stack
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
@ -25,12 +26,12 @@ async def resolve_impls_for_test_v2(
apis: List[Api],
providers: Dict[str, List[Provider]],
provider_data: Optional[Dict[str, Any]] = None,
models: Optional[List[Model]] = None,
shields: Optional[List[Shield]] = None,
memory_banks: Optional[List[MemoryBank]] = None,
datasets: Optional[List[Dataset]] = None,
scoring_fns: Optional[List[ScoringFn]] = None,
eval_tasks: Optional[List[EvalTask]] = None,
models: Optional[List[ModelInput]] = None,
shields: Optional[List[ShieldInput]] = None,
memory_banks: Optional[List[MemoryBankInput]] = None,
datasets: Optional[List[DatasetInput]] = None,
scoring_fns: Optional[List[ScoringFnInput]] = None,
eval_tasks: Optional[List[EvalTaskInput]] = None,
):
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
run_config = dict(