diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 5746af4ba..dea951395 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -57,17 +57,11 @@ class ModelRegistryHelper(ModelsProtocolPrivate): self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model - def get_provider_model_id(self, identifier: str) -> str: - if identifier in self.alias_to_provider_id_map: - return self.alias_to_provider_id_map[identifier] - else: - return None + def get_provider_model_id(self, identifier: str) -> Optional[str]: + return self.alias_to_provider_id_map.get(identifier, None) - def get_llama_model(self, provider_model_id: str) -> str: - if provider_model_id in self.provider_id_to_llama_model_map: - return self.provider_id_to_llama_model_map[provider_model_id] - else: - return None + def get_llama_model(self, provider_model_id: str) -> Optional[str]: + return self.provider_id_to_llama_model_map.get(provider_model_id, None) async def register_model(self, model: Model) -> Model: if model.model_type == ModelType.embedding: