From 787e78d7d4808b56db35d710ac3bddee9ebb2ae9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 6 Feb 2025 13:45:38 +0100 Subject: [PATCH] chore: update return type to Optional[str] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Updated the return type of the methods `get_provider_model_id` and `get_llama_model` in the `ModelRegistryHelper` class to `Optional[str]` to indicate that they may return a string or None when no match is found. This change improves the clarity of the methods' behavior and supports better type safety. Replaced explicit `if-else` checks with `dict.get()` for cleaner code. Signed-off-by: Sébastien Han --- .../providers/utils/inference/model_registry.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 5746af4ba..dea951395 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -57,17 +57,11 @@ class ModelRegistryHelper(ModelsProtocolPrivate): self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model - def get_provider_model_id(self, identifier: str) -> str: - if identifier in self.alias_to_provider_id_map: - return self.alias_to_provider_id_map[identifier] - else: - return None + def get_provider_model_id(self, identifier: str) -> Optional[str]: + return self.alias_to_provider_id_map.get(identifier, None) - def get_llama_model(self, provider_model_id: str) -> str: - if provider_model_id in self.provider_id_to_llama_model_map: - return self.provider_id_to_llama_model_map[provider_model_id] - else: - return None + def get_llama_model(self, provider_model_id: str) -> Optional[str]: + return self.provider_id_to_llama_model_map.get(provider_model_id, None) async def register_model(self, model: Model) -> Model: if model.model_type == ModelType.embedding: