mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-01 15:10:00 +00:00
Add provider data passing for library client
This commit is contained in:
parent
78727aad26
commit
33b4e8df49
10 changed files with 49 additions and 19 deletions
|
|
@ -13,12 +13,29 @@ from llama_stack_client import LlamaStackClient
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama_stack_client():
|
||||
def provider_data():
|
||||
# check env for tavily secret, brave secret and inject all into provider data
|
||||
provider_data = {}
|
||||
if os.environ.get("TAVILY_SEARCH_API_KEY"):
|
||||
provider_data["tavily_search_api_key"] = os.environ["TAVILY_SEARCH_API_KEY"]
|
||||
if os.environ.get("BRAVE_SEARCH_API_KEY"):
|
||||
provider_data["brave_search_api_key"] = os.environ["BRAVE_SEARCH_API_KEY"]
|
||||
return provider_data if len(provider_data) > 0 else None
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama_stack_client(provider_data):
|
||||
if os.environ.get("LLAMA_STACK_CONFIG"):
|
||||
client = LlamaStackAsLibraryClient(get_env_or_fail("LLAMA_STACK_CONFIG"))
|
||||
client = LlamaStackAsLibraryClient(
|
||||
get_env_or_fail("LLAMA_STACK_CONFIG"),
|
||||
provider_data=provider_data,
|
||||
)
|
||||
client.initialize()
|
||||
elif os.environ.get("LLAMA_STACK_BASE_URL"):
|
||||
client = LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
|
||||
client = LlamaStackClient(
|
||||
base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"),
|
||||
provider_data=provider_data,
|
||||
)
|
||||
else:
|
||||
raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
|
||||
return client
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue