temp commit

This commit is contained in:
Botao Chen 2024-12-12 21:44:03 -08:00
parent 8efe33646d
commit de44af1501
9 changed files with 153 additions and 53 deletions

View file

@ -90,6 +90,7 @@ class InferenceRouter(Inference):
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
print("inference router")
await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata, model_type
)

View file

@ -32,6 +32,7 @@ def get_impl_api(p: Any) -> Api:
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
api = get_impl_api(p)
print("registering object with provider", api)
assert obj.provider_id != "remote", "Remote provider should not be registered"
@ -169,6 +170,7 @@ class CommonRoutingTableImpl(RoutingTable):
async def register_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
# Get existing objects from registry
existing_obj = await self.dist_registry.get(obj.type, obj.identifier)
@ -181,7 +183,12 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id]
if obj is None:
print("obj is None")
registered_obj = await register_object_with_provider(obj, p)
if registered_obj is None:
print("registered_obj is None")
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
@ -211,6 +218,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model:
print("register_model", model_id)
if provider_model_id is None:
provider_model_id = model_id
if provider_id is None:
@ -239,7 +247,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
metadata=metadata,
model_type=model_type,
)
if model is None:
print("model is None!!!")
print("before registered_model")
registered_model = await self.register_object(model)
print("after registered_model")
return registered_model
async def unregister_model(self, model_id: str) -> None: