mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 10:54:19 +00:00
47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
import os
|
|
|
|
import pytest
|
|
|
|
from llama_stack import LlamaStackAsLibraryClient
|
|
from llama_stack.providers.tests.env import get_env_or_fail
|
|
from llama_stack_client import LlamaStackClient
|
|
|
|
|
|
def pytest_configure(config):
    """Apply session-wide pytest reporting defaults.

    Tracebacks are shortened and the warnings summary is suppressed.

    Args:
        config: The pytest ``Config`` object; only its ``option`` namespace
            is modified.
    """
    opts = config.option
    opts.disable_warnings = True
    opts.tbstyle = "short"
|
|
|
|
|
|
@pytest.fixture(scope="session")
def provider_data():
    """Collect search-tool API keys from the environment into provider data.

    Each environment variable listed below is copied into the result under its
    provider-data key if it is set to a non-empty value.

    Returns:
        A dict that may contain ``tavily_search_api_key`` and
        ``brave_search_api_key``, or ``None`` when neither variable is set
        (callers get ``None`` rather than an empty dict).
    """
    # (environment variable, provider-data key) pairs; this order is the
    # insertion order of the resulting dict.
    env_to_key = (
        ("TAVILY_SEARCH_API_KEY", "tavily_search_api_key"),
        ("BRAVE_SEARCH_API_KEY", "brave_search_api_key"),
    )
    data = {}
    for env_var, key in env_to_key:
        # Single lookup per variable; empty strings are treated as unset.
        value = os.environ.get(env_var)
        if value:
            data[key] = value
    # An empty dict is falsy, so this returns None when nothing was found.
    return data or None
|
|
|
|
|
|
@pytest.fixture(scope="session")
def llama_stack_client(provider_data):
    """Build the Llama Stack client shared by the whole test session.

    ``LLAMA_STACK_CONFIG`` takes precedence. When it is set to a non-empty
    value, a ``LlamaStackAsLibraryClient`` is built from it and
    ``initialize()`` is called before it is returned. Otherwise, if
    ``LLAMA_STACK_BASE_URL`` is set, a ``LlamaStackClient`` targeting that URL
    is returned.

    Args:
        provider_data: Value produced by the ``provider_data`` fixture;
            forwarded unchanged to whichever client is built.

    Raises:
        ValueError: If neither environment variable is set.
    """
    if os.environ.get("LLAMA_STACK_CONFIG"):
        library_client = LlamaStackAsLibraryClient(
            get_env_or_fail("LLAMA_STACK_CONFIG"),
            provider_data=provider_data,
            # NOTE(review): presumably keeps existing logging handlers in
            # place during tests — confirm against LlamaStackAsLibraryClient.
            skip_logger_removal=True,
        )
        library_client.initialize()
        return library_client

    if os.environ.get("LLAMA_STACK_BASE_URL"):
        return LlamaStackClient(
            base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"),
            provider_data=provider_data,
        )

    raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
|