From bf11cc0450722ac7ec728f0a57f3388545ce4c8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Wed, 12 Feb 2025 07:10:28 +0100 Subject: [PATCH] chore: update return type to Optional[str] (#982) --- .../providers/utils/inference/model_registry.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 5746af4ba..dea951395 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -57,17 +57,11 @@ class ModelRegistryHelper(ModelsProtocolPrivate): self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model - def get_provider_model_id(self, identifier: str) -> str: - if identifier in self.alias_to_provider_id_map: - return self.alias_to_provider_id_map[identifier] - else: - return None + def get_provider_model_id(self, identifier: str) -> Optional[str]: + return self.alias_to_provider_id_map.get(identifier, None) - def get_llama_model(self, provider_model_id: str) -> str: - if provider_model_id in self.provider_id_to_llama_model_map: - return self.provider_id_to_llama_model_map[provider_model_id] - else: - return None + def get_llama_model(self, provider_model_id: str) -> Optional[str]: + return self.provider_id_to_llama_model_map.get(provider_model_id, None) async def register_model(self, model: Model) -> Model: if model.model_type == ModelType.embedding: