diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 9fc258880..b69cd822f 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -40,7 +40,9 @@ from llama_stack.apis.inference.inference import (
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
+from llama_stack.apis.models.models import Model
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.exceptions import UnsupportedModelError
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
@@ -90,6 +92,12 @@ class LiteLLMOpenAIMixin(
     async def shutdown(self):
         pass
 
+    async def register_model(self, model: Model) -> Model:
+        model_id = self.get_provider_model_id(model.provider_resource_id)
+        if model_id is None:
+            raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
+        return model
+
     def get_litellm_model_name(self, model_id: str) -> str:
         # users may be using openai/ prefix in their model names. the openai/models.py did this by default.
         # model_id.startswith("openai/") is for backwards compatibility.
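
Note: below is a minimal standalone sketch of the guard this hunk adds, assuming ModelRegistryHelper's get_provider_model_id() behaves like a plain dict lookup over alias_to_provider_id_map. The FakeRegistry class and the model ids are invented for illustration and are not part of the patch.

# Standalone sketch of the register_model guard above (illustration only).
# get_provider_model_id() returns None when the requested id is neither a
# known alias nor a provider model id; the patch turns that None into an
# UnsupportedModelError instead of silently registering an unknown model.

class FakeRegistry:
    """Hypothetical stand-in for ModelRegistryHelper's alias lookup."""

    def __init__(self, alias_to_provider_id_map: dict[str, str]):
        self.alias_to_provider_id_map = alias_to_provider_id_map

    def get_provider_model_id(self, model_id: str) -> str | None:
        return self.alias_to_provider_id_map.get(model_id)


registry = FakeRegistry({"llama3.1-8b": "provider/llama-3.1-8b-instruct"})

# A registered alias resolves, so register_model would return the model.
assert registry.get_provider_model_id("llama3.1-8b") is not None

# An unknown id resolves to None; the new code raises UnsupportedModelError here.
assert registry.get_provider_model_id("not-a-model") is None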