add dynamic clients for all APIs (#348)

* add dynamic clients for all APIs

* fix openapi generator

* inference + memory + agents tests now pass with "remote" providers

* Add docstring which fixes openapi generator :/
This commit is contained in:
Ashwin Bharambe 2024-10-31 14:46:25 -07:00 committed by GitHub
parent f04b566c5c
commit 37b330b4ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 350 additions and 84 deletions

View file

@ -60,7 +60,7 @@ class MemoryBanksProtocolPrivate(Protocol):
class DatasetsProtocolPrivate(Protocol):
async def list_datasets(self) -> List[DatasetDef]: ...
async def register_datasets(self, dataset_def: DatasetDef) -> None: ...
async def register_dataset(self, dataset_def: DatasetDef) -> None: ...
class ScoringFunctionsProtocolPrivate(Protocol):
@ -171,7 +171,7 @@ as being "Llama Stack compatible"
def module(self) -> str:
if self.adapter:
return self.adapter.module
return f"llama_stack.apis.{self.api.value}.client"
return "llama_stack.distribution.client"
@property
def pip_packages(self) -> List[str]:

View file

@ -26,6 +26,7 @@ from dotenv import load_dotenv
#
# ```bash
# PROVIDER_ID=<your_provider> \
# MODEL_ID=<your_model> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/agents/test_agents.py \
# --tb=short --disable-warnings
@ -44,7 +45,7 @@ async def agents_settings():
"impl": impls[Api.agents],
"memory_impl": impls[Api.memory],
"common_params": {
"model": "Llama3.1-8B-Instruct",
"model": os.environ["MODEL_ID"] or "Llama3.1-8B-Instruct",
"instructions": "You are a helpful assistant.",
},
}

View file

@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
@ -73,7 +72,6 @@ async def register_memory_bank(banks_impl: MemoryBanks):
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
provider_id=os.environ["PROVIDER_ID"],
)
await banks_impl.register_memory_bank(bank)