model update and delete for provider

This commit is contained in:
Dinesh Yeduguru 2024-11-14 16:16:44 -08:00
parent e8b699797c
commit 428995286d
7 changed files with 38 additions and 0 deletions

View file

@ -57,6 +57,8 @@ async def update_object_with_provider(
api = get_impl_api(p)
if api == Api.memory:
return await p.update_memory_bank(obj)
elif api == Api.inference:
return await p.update_model(obj)
else:
raise ValueError(f"Update not supported for {api}")
@ -65,6 +67,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.memory:
return await p.unregister_memory_bank(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")

View file

@ -45,6 +45,10 @@ class Api(Enum):
# Structural (typing.Protocol) interface declaring the async model hooks:
# register, update and unregister. Every method body is `...`, so this class
# only declares signatures and carries no behavior of its own.
class ModelsProtocolPrivate(Protocol):
# Takes the full Model object being registered.
async def register_model(self, model: Model) -> None: ...
# Takes the full Model object being updated.
async def update_model(self, model: Model) -> None: ...
# Takes only the model's identifier string, not the Model object.
async def unregister_model(self, model_id: str) -> None: ...
class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ...

View file

@ -71,6 +71,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
# Model-update hook. Currently a no-op: the body is `pass`, so `model` is
# ignored and the coroutine resolves to None.
# NOTE(review): presumably nothing needs refreshing here on update — confirm
# a silent no-op is intended rather than raising for unsupported updates.
async def update_model(self, model: Model) -> None:
pass
# Model-unregister hook. Currently a no-op: the body is `pass`, so `model_id`
# is ignored and the coroutine resolves to None.
# NOTE(review): presumably there is no per-model state to release — confirm.
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,

View file

@ -108,6 +108,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
return VLLMSamplingParams(**kwargs)
# Model-update hook. Currently a no-op: the body is `pass`, so `model` is
# ignored and the coroutine resolves to None.
# NOTE(review): confirm a silent no-op is intended rather than raising for
# unsupported updates.
async def update_model(self, model: Model) -> None:
pass
# Model-unregister hook. Currently a no-op: the body is `pass`, so `model_id`
# is ignored and the coroutine resolves to None.
# NOTE(review): presumably there is no per-model state to release — confirm.
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,

View file

@ -93,6 +93,12 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
# Model-update hook. Currently a no-op: the body is `pass`, so `model` is
# ignored and the coroutine resolves to None.
# NOTE(review): confirm a silent no-op is intended rather than raising for
# unsupported updates.
async def update_model(self, model: Model) -> None:
pass
# Model-unregister hook. Currently a no-op: the body is `pass`, so `model_id`
# is ignored and the coroutine resolves to None.
# NOTE(review): presumably nothing is removed on the remote side — confirm.
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,

View file

@ -69,6 +69,12 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
# Model-update hook. Currently a no-op: the body is `pass`, so `model` is
# ignored and the coroutine resolves to None.
# NOTE(review): confirm a silent no-op is intended rather than raising for
# unsupported updates.
async def update_model(self, model: Model) -> None:
pass
# Model-unregister hook. Currently a no-op: the body is `pass`, so `model_id`
# is ignored and the coroutine resolves to None.
# NOTE(review): presumably there is no per-model state to release — confirm.
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model: str,

View file

@ -58,6 +58,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
# Model-update hook. Currently a no-op: the body is `pass`, so `model` is
# ignored and the coroutine resolves to None.
# NOTE(review): confirm a silent no-op is intended rather than raising for
# unsupported updates.
async def update_model(self, model: Model) -> None:
pass
# Model-unregister hook. Currently a no-op: the body is `pass`, so `model_id`
# is ignored and the coroutine resolves to None.
# NOTE(review): presumably nothing is removed on the remote side — confirm.
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,