From 428995286d31f15d223f1655a3aa2cf520b4f169 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Thu, 14 Nov 2024 16:16:44 -0800
Subject: [PATCH] model update and delete for provider

---
 llama_stack/distribution/routers/routing_tables.py         | 4 ++++
 llama_stack/providers/datatypes.py                         | 4 ++++
 .../providers/inline/inference/meta_reference/inference.py | 6 ++++++
 llama_stack/providers/inline/inference/vllm/vllm.py        | 6 ++++++
 llama_stack/providers/remote/inference/ollama/ollama.py    | 6 ++++++
 llama_stack/providers/remote/inference/tgi/tgi.py          | 6 ++++++
 llama_stack/providers/remote/inference/vllm/vllm.py        | 6 ++++++
 7 files changed, 38 insertions(+)

diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index d0d588a91..d196d3557 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -57,6 +57,8 @@ async def update_object_with_provider(
     api = get_impl_api(p)
     if api == Api.memory:
         return await p.update_memory_bank(obj)
+    elif api == Api.inference:
+        return await p.update_model(obj)
     else:
         raise ValueError(f"Update not supported for {api}")
@@ -65,6 +67,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
     api = get_impl_api(p)
     if api == Api.memory:
         return await p.unregister_memory_bank(obj.identifier)
+    elif api == Api.inference:
+        return await p.unregister_model(obj.identifier)
     else:
         raise ValueError(f"Unregister not supported for {api}")
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 05fc3a33a..1b5d0ebd1 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -45,6 +45,10 @@ class Api(Enum):
 class ModelsProtocolPrivate(Protocol):
     async def register_model(self, model: Model) -> None: ...

+    async def update_model(self, model: Model) -> None: ...
+
+    async def unregister_model(self, model_id: str) -> None: ...
+

 class ShieldsProtocolPrivate(Protocol):
     async def register_shield(self, shield: Shield) -> None: ...
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 4f5c0c8c2..18128f354 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -71,6 +71,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )

+    async def update_model(self, model: Model) -> None:
+        pass
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
+
     async def completion(
         self,
         model_id: str,
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 8869cc07f..e4742240d 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -108,6 +108,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):

         return VLLMSamplingParams(**kwargs)

+    async def update_model(self, model: Model) -> None:
+        pass
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
+
     async def completion(
         self,
         model_id: str,
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 297eecbdc..208e5036b 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -93,6 +93,12 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

+    async def update_model(self, model: Model) -> None:
+        pass
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
+
     async def completion(
         self,
         model_id: str,
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 8d3d1f86d..0aef9d706 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -69,6 +69,12 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

+    async def update_model(self, model: Model) -> None:
+        pass
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
+
     async def completion(
         self,
         model: str,
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 696cfb15d..4dc8220f1 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -58,6 +58,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

+    async def update_model(self, model: Model) -> None:
+        pass
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
+
     async def completion(
         self,
         model_id: str,
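
Reviewer note: every provider in this patch stubs the two new protocol methods with `pass`; the patch only threads update/unregister through the routing table's dispatch. As a minimal sketch of what a concrete implementation might look like, here is a hypothetical in-memory provider. The class name and `_model_registry` dict are illustrative and not part of this patch, and the imports assume `Model` is reachable from `llama_stack.providers.datatypes`, as the protocol definition above suggests.

    # Hypothetical sketch, not from this patch: a provider satisfying
    # ModelsProtocolPrivate by tracking registered models in a local dict.
    from typing import Dict

    from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate


    class InMemoryModelProvider(ModelsProtocolPrivate):
        def __init__(self) -> None:
            # identifier -> Model, populated by register_model()
            self._model_registry: Dict[str, Model] = {}

        async def register_model(self, model: Model) -> None:
            self._model_registry[model.identifier] = model

        async def update_model(self, model: Model) -> None:
            # Reject updates for unknown models so the routing table surfaces
            # a clear error instead of silently registering a new entry.
            if model.identifier not in self._model_registry:
                raise ValueError(f"Model {model.identifier} is not registered")
            self._model_registry[model.identifier] = model

        async def unregister_model(self, model_id: str) -> None:
            # pop() with a default keeps unregistering idempotent.
            self._model_registry.pop(model_id, None)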
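
With a provider like the sketch above, the new inference branches in update_object_with_provider() and unregister_object_from_provider() would resolve as shown below, assuming get_impl_api() maps the impl to Api.inference; the Model constructor fields used here are assumed for illustration only.

    # Hypothetical usage, not from this patch.
    import asyncio

    async def main() -> None:
        provider = InMemoryModelProvider()
        model = Model(identifier="my-model", provider_id="demo")  # field names assumed
        await provider.register_model(model)
        # unregister_object_from_provider(model, provider) now forwards to:
        await provider.unregister_model(model.identifier)

    asyncio.run(main())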