add support for remote providers in tests

This commit is contained in:
Ashwin Bharambe 2024-11-04 19:57:40 -08:00
parent 0763a0b85f
commit 7cf4c905f3
11 changed files with 79 additions and 15 deletions

View file

@ -6,6 +6,7 @@
import json
import os
import tempfile
from datetime import datetime
from typing import Any, Dict, List, Optional
@ -16,6 +17,8 @@ from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
async def resolve_impls_for_test_v2(
@ -30,7 +33,11 @@ async def resolve_impls_for_test_v2(
providers=providers,
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls(run_config, get_provider_registry())
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name))
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
if provider_data:
set_request_provider_data(