diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index 0d3af70f6..e7f96405a 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -16,6 +16,8 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
 
 
 class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
+    _model_cache: dict[str, Model] = {}
+
     def __init__(self, config: WatsonXConfig):
         LiteLLMOpenAIMixin.__init__(
             self,
@@ -38,10 +40,20 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             params["time_limit"] = self.config.timeout
         return params
 
-    async def check_model_availability(self, model):
-        return True
+    # Copied from OpenAIMixin
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from the provider's /v1/models.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        if not self._model_cache:
+            await self.list_models()
+        return model in self._model_cache
 
     async def list_models(self) -> list[Model] | None:
+        self._model_cache = {}
         models = []
         for model_spec in self._get_model_specs():
             functions = [f["id"] for f in model_spec.get("functions", [])]
@@ -52,6 +64,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             #  'label': 'granite-embedding-278m-multilingual',
             #  'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
             #  ...
+            provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
             if "embedding" in functions:
                 embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
                 context_length = model_spec["model_limits"]["max_sequence_length"]
@@ -59,25 +72,29 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
                     "embedding_dimension": embedding_dimension,
                     "context_length": context_length,
                 }
-                models.append(
-                    Model(
-                        identifier=model_spec["model_id"],
-                        provider_resource_id=f"{self.__provider_id__}/{model_spec['model_id']}",
-                        provider_id=self.__provider_id__,
-                        metadata=embedding_metadata,
-                        model_type=ModelType.embedding,
-                    )
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata=embedding_metadata,
+                    model_type=ModelType.embedding,
                 )
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
             if "text_chat" in functions:
-                models.append(
-                    Model(
-                        identifier=model_spec["model_id"],
-                        provider_resource_id=f"{self.__provider_id__}/{model_spec['model_id']}",
-                        provider_id=self.__provider_id__,
-                        metadata={},
-                        model_type=ModelType.llm,
-                    )
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=ModelType.llm,
                 )
+                # In theory a model could be both an embedding model and a text chat model. In that case the
+                # cache ends up holding the text chat Model object (the later assignment wins), while the list
+                # we return contains both the embedding Model object and the text chat Model object. That is
+                # fine because the cache is only used by check_model_availability() anyway.
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
         return models
 
     # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 444a6a715..6bef97dd5 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -336,7 +336,6 @@ class LiteLLMOpenAIMixin(
             api_key=self.get_api_key(),
             api_base=self.api_base,
         )
-        logger.info(f"params to litellm (openai compat): {params}")
         return await litellm.acompletion(**params)
 
     async def check_model_availability(self, model: str) -> bool:
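
A minimal usage sketch of the new behaviour (not part of the patch): check_model_availability() now answers from the cache that list_models() fills, keyed by "<provider_id>/<model_id>". It assumes an already-constructed WatsonXInferenceAdapter; the helper name is_model_served and the model id format shown in the comment are illustrative only.

# Usage sketch under the assumptions above; adapter construction/registration is omitted.
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter


async def is_model_served(adapter: WatsonXInferenceAdapter, model_id: str) -> bool:
    # The first call triggers list_models() to populate _model_cache; subsequent
    # calls are resolved from the cache without re-querying watsonx.ai. The
    # model_id must match the cached provider_resource_id ("<provider_id>/<model_id>").
    return await adapter.check_model_availability(model_id)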