Add provider data passing for library client

This commit is contained in:
Dinesh Yeduguru 2025-01-13 13:25:36 -08:00
parent 78727aad26
commit 33b4e8df49
10 changed files with 49 additions and 19 deletions

View file

@ -13,12 +13,29 @@ from llama_stack_client import LlamaStackClient
@pytest.fixture(scope="session")
def llama_stack_client():
def provider_data():
# check env for tavily secret, brave secret and inject all into provider data
provider_data = {}
if os.environ.get("TAVILY_SEARCH_API_KEY"):
provider_data["tavily_search_api_key"] = os.environ["TAVILY_SEARCH_API_KEY"]
if os.environ.get("BRAVE_SEARCH_API_KEY"):
provider_data["brave_search_api_key"] = os.environ["BRAVE_SEARCH_API_KEY"]
return provider_data if len(provider_data) > 0 else None
@pytest.fixture(scope="session")
def llama_stack_client(provider_data):
if os.environ.get("LLAMA_STACK_CONFIG"):
client = LlamaStackAsLibraryClient(get_env_or_fail("LLAMA_STACK_CONFIG"))
client = LlamaStackAsLibraryClient(
get_env_or_fail("LLAMA_STACK_CONFIG"),
provider_data=provider_data,
)
client.initialize()
elif os.environ.get("LLAMA_STACK_BASE_URL"):
client = LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
client = LlamaStackClient(
base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"),
provider_data=provider_data,
)
else:
raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
return client