More review comment fixes

Signed-off-by: Bill Murdock <bmurdock@redhat.com>
Bill Murdock 2025-10-06 16:43:41 -04:00
parent e77b7a127c
commit a4b9b1e494
2 changed files with 35 additions and 19 deletions


@@ -16,6 +16,8 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
+    _model_cache: dict[str, Model] = {}
+
     def __init__(self, config: WatsonXConfig):
         LiteLLMOpenAIMixin.__init__(
             self,
@@ -38,10 +40,20 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             params["time_limit"] = self.config.timeout
         return params

-    async def check_model_availability(self, model):
-        return True
+    # Copied from OpenAIMixin
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from the provider's /v1/models.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        if not self._model_cache:
+            await self.list_models()
+        return model in self._model_cache

     async def list_models(self) -> list[Model] | None:
+        self._model_cache = {}
         models = []
         for model_spec in self._get_model_specs():
             functions = [f["id"] for f in model_spec.get("functions", [])]
@@ -52,6 +64,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             # 'label': 'granite-embedding-278m-multilingual',
             # 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
             # ...
+            provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
             if "embedding" in functions:
                 embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
                 context_length = model_spec["model_limits"]["max_sequence_length"]
@@ -59,25 +72,29 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
                     "embedding_dimension": embedding_dimension,
                     "context_length": context_length,
                 }
-                models.append(
-                    Model(
-                        identifier=model_spec["model_id"],
-                        provider_resource_id=f"{self.__provider_id__}/{model_spec['model_id']}",
-                        provider_id=self.__provider_id__,
-                        metadata=embedding_metadata,
-                        model_type=ModelType.embedding,
-                    )
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata=embedding_metadata,
+                    model_type=ModelType.embedding,
                 )
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
             if "text_chat" in functions:
-                models.append(
-                    Model(
-                        identifier=model_spec["model_id"],
-                        provider_resource_id=f"{self.__provider_id__}/{model_spec['model_id']}",
-                        provider_id=self.__provider_id__,
-                        metadata={},
-                        model_type=ModelType.llm,
-                    )
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=ModelType.llm,
                 )
+                # In theory, a model could be both an embedding model and a text chat model.
+                # In that case, the cache records the text chat Model object (overwriting the
+                # embedding entry) while the returned list has both Model objects. That's fine
+                # because the cache is only used for check_model_availability() anyway.
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
         return models

     # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
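
The pattern introduced above is a lazy, dict-backed availability check: list_models() rebuilds _model_cache keyed by provider_resource_id, and check_model_availability() populates the cache on first use, then answers with a plain membership test. The following is a minimal, self-contained sketch of that behavior; FakeModel, FakeAdapter, and the spec data are hypothetical stand-ins for the llama_stack types and the watsonx model listing, not the real implementation.

import asyncio
from dataclasses import dataclass, field


@dataclass
class FakeModel:
    # Stand-in for llama_stack's Model type.
    identifier: str
    provider_resource_id: str
    model_type: str
    metadata: dict = field(default_factory=dict)


class FakeAdapter:
    __provider_id__ = "watsonx"

    def __init__(self):
        self._model_cache: dict[str, FakeModel] = {}

    def _get_model_specs(self):
        # Hypothetical spec: one model that is both an embedding and a chat model.
        return [
            {
                "model_id": "granite-dual",
                "functions": [{"id": "embedding"}, {"id": "text_chat"}],
                "model_limits": {"max_sequence_length": 512, "embedding_dimension": 768},
            }
        ]

    async def list_models(self) -> list[FakeModel]:
        self._model_cache = {}
        models = []
        for spec in self._get_model_specs():
            functions = [f["id"] for f in spec.get("functions", [])]
            provider_resource_id = f"{self.__provider_id__}/{spec['model_id']}"
            if "embedding" in functions:
                model = FakeModel(
                    spec["model_id"],
                    provider_resource_id,
                    "embedding",
                    {"embedding_dimension": spec["model_limits"]["embedding_dimension"]},
                )
                self._model_cache[provider_resource_id] = model
                models.append(model)
            if "text_chat" in functions:
                # Overwrites the embedding entry; harmless, because the cache
                # only answers "is this id available?".
                model = FakeModel(spec["model_id"], provider_resource_id, "llm")
                self._model_cache[provider_resource_id] = model
                models.append(model)
        return models

    async def check_model_availability(self, model: str) -> bool:
        # Lazily populate the cache on first call, then do a membership test.
        if not self._model_cache:
            await self.list_models()
        return model in self._model_cache


async def main():
    adapter = FakeAdapter()
    print(await adapter.check_model_availability("watsonx/granite-dual"))  # True
    print(await adapter.check_model_availability("watsonx/unknown"))       # False
    print(len(await adapter.list_models()))  # 2 -- two list entries, one cache slot


asyncio.run(main())

Running it prints True, False, then 2: the dual-function model yields two list entries but only one cache slot, which is exactly the situation the comment in the diff describes.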


@@ -336,7 +336,6 @@ class LiteLLMOpenAIMixin(
             api_key=self.get_api_key(),
             api_base=self.api_base,
         )
-        logger.info(f"params to litellm (openai compat): {params}")
         return await litellm.acompletion(**params)

     async def check_model_availability(self, model: str) -> bool:
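
The only change to the mixin is deleting the logger.info call, presumably in response to a review comment about noisy or unsafe logging: params includes api_key (set just above), so dumping the whole dict at info level would write credentials to the log. If request logging were ever wanted back, one safe-by-default option (a hypothetical helper, not part of this commit or of LiteLLMOpenAIMixin) would be to redact credential-like keys and log at debug level:

import logging

logger = logging.getLogger(__name__)


def log_litellm_params(params: dict) -> None:
    # Hypothetical helper: mask anything that looks like a credential,
    # and use debug level so the output is off by default.
    redacted = {
        k: "***" if ("key" in k.lower() or "token" in k.lower()) else v
        for k, v in params.items()
    }
    logger.debug("params to litellm (openai compat): %s", redacted)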