add dynamic clients for all APIs (#348)

* add dynamic clients for all APIs

* fix openapi generator

* inference + memory + agents tests now pass with "remote" providers

* Add docstring which fixes openapi generator :/
This commit is contained in:
Ashwin Bharambe 2024-10-31 14:46:25 -07:00 committed by GitHub
parent f04b566c5c
commit 37b330b4ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 350 additions and 84 deletions

View file

@ -22,6 +22,13 @@ def get_impl_api(p: Any) -> Api:
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if obj.provider_id == "remote":
# if this is just a passthrough, we want to let the remote
# end actually do the registration with the correct provider
obj = obj.model_copy(deep=True)
obj.provider_id = ""
if api == Api.inference:
await p.register_model(obj)
elif api == Api.safety:
@ -51,11 +58,22 @@ class CommonRoutingTableImpl(RoutingTable):
async def initialize(self) -> None:
self.registry: Registry = {}
def add_objects(objs: List[RoutableObjectWithProvider]) -> None:
def add_objects(
objs: List[RoutableObjectWithProvider], provider_id: str, cls
) -> None:
for obj in objs:
if obj.identifier not in self.registry:
self.registry[obj.identifier] = []
if cls is None:
obj.provider_id = provider_id
else:
if provider_id == "remote":
# if this is just a passthrough, we got the *WithProvider object
# so we should just override the provider in-place
obj.provider_id = provider_id
else:
obj = cls(**obj.model_dump(), provider_id=provider_id)
self.registry[obj.identifier].append(obj)
for pid, p in self.impls_by_provider_id.items():
@ -63,47 +81,27 @@ class CommonRoutingTableImpl(RoutingTable):
if api == Api.inference:
p.model_store = self
models = await p.list_models()
add_objects(
[ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models]
)
add_objects(models, pid, ModelDefWithProvider)
elif api == Api.safety:
p.shield_store = self
shields = await p.list_shields()
add_objects(
[
ShieldDefWithProvider(**s.dict(), provider_id=pid)
for s in shields
]
)
add_objects(shields, pid, ShieldDefWithProvider)
elif api == Api.memory:
p.memory_bank_store = self
memory_banks = await p.list_memory_banks()
# do in-memory updates due to pesky Annotated unions
for m in memory_banks:
m.provider_id = pid
add_objects(memory_banks)
add_objects(memory_banks, pid, None)
elif api == Api.datasetio:
p.dataset_store = self
datasets = await p.list_datasets()
# do in-memory updates due to pesky Annotated unions
for d in datasets:
d.provider_id = pid
add_objects(datasets, pid, DatasetDefWithProvider)
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
add_objects(
[
ScoringFnDefWithProvider(**s.dict(), provider_id=pid)
for s in scoring_functions
]
)
add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():