Inference to use provider resource id to register and validate (#428)

This PR changes the way model id gets translated to the final model name
that gets passed through the provider.
Major changes include:
1) Providers are responsible for registering an object and, as part of
the registration, returning the object with the correct provider-specific
name of the model (provider_resource_id)
2) To help with common lookups of the different names, a new ModelLookup
class is created.



Tested all inference providers including together, fireworks, vllm,
ollama, meta reference and bedrock
This commit is contained in:
Dinesh Yeduguru 2024-11-12 20:02:00 -08:00 committed by GitHub
parent e51107e019
commit fdff24e77a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 460 additions and 290 deletions

View file

@ -28,7 +28,9 @@ def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
api = get_impl_api(p)
if obj.provider_id == "remote":
@ -42,7 +44,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
obj.provider_id = ""
if api == Api.inference:
await p.register_model(obj)
return await p.register_model(obj)
elif api == Api.safety:
await p.register_shield(obj)
elif api == Api.memory:
@ -167,7 +169,9 @@ class CommonRoutingTableImpl(RoutingTable):
assert len(objects) == 1
return objects[0]
async def register_object(self, obj: RoutableObjectWithProvider):
async def register_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
# Get existing objects from registry
existing_objects = await self.dist_registry.get(obj.type, obj.identifier)
@ -177,7 +181,7 @@ class CommonRoutingTableImpl(RoutingTable):
print(
f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
)
return
return existing_obj
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
@ -188,8 +192,15 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id]
await register_object_with_provider(obj, p)
await self.dist_registry.register(obj)
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj
else:
await self.dist_registry.register(obj)
return obj
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
@ -228,8 +239,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id=provider_id,
metadata=metadata,
)
await self.register_object(model)
return model
registered_model = await self.register_object(model)
return registered_model
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):