Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-12 12:06:04 +00:00
More review comment fixes

Signed-off-by: Bill Murdock <bmurdock@redhat.com>

commit a4b9b1e494 (parent e77b7a127c)
2 changed files with 35 additions and 19 deletions
@@ -16,6 +16,8 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 
 
 class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
+    _model_cache: dict[str, Model] = {}
+
     def __init__(self, config: WatsonXConfig):
         LiteLLMOpenAIMixin.__init__(
             self,
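A note on the new class attribute: list_models() below rebinds self._model_cache, so the empty dict declared on the class only acts as a default until the first refresh, after which each adapter instance carries its own cache. A small stand-alone sketch of that behaviour (the Adapter class and the cached value are made up for illustration):

# Sketch of how a class-level _model_cache annotation behaves at runtime;
# Adapter is a plain stand-in, not the real WatsonXInferenceAdapter.
class Adapter:
    _model_cache: dict[str, str] = {}  # shared class attribute until shadowed

    def refresh(self) -> None:
        # Mirrors list_models(): rebinding via self creates an instance
        # attribute, so this adapter gets its own cache from now on.
        self._model_cache = {"watsonx/some-model": "some-model"}


a, b = Adapter(), Adapter()
print(a._model_cache is b._model_cache)  # True: both still see the class dict
a.refresh()
print(a._model_cache is b._model_cache)  # False: a now has its own dict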
@@ -38,10 +40,20 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
         params["time_limit"] = self.config.timeout
         return params
 
-    async def check_model_availability(self, model):
-        return True
+    # Copied from OpenAIMixin
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from the provider's /v1/models.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        if not self._model_cache:
+            await self.list_models()
+        return model in self._model_cache
 
     async def list_models(self) -> list[Model] | None:
+        self._model_cache = {}
         models = []
         for model_spec in self._get_model_specs():
             functions = [f["id"] for f in model_spec.get("functions", [])]
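For context, the rewritten check_model_availability() is a lazy-cache lookup: the first call populates the cache via list_models(), and later calls are plain dictionary membership tests. A minimal, self-contained sketch of the same pattern, with a hypothetical FakeCatalog standing in for the watsonx.ai-backed adapter:

# Minimal sketch of the lazy-cache pattern used by check_model_availability().
# FakeCatalog and its model ids are hypothetical stand-ins for the real adapter.
import asyncio


class FakeCatalog:
    def __init__(self) -> None:
        self._model_cache: dict[str, str] = {}

    async def list_models(self) -> list[str]:
        # Pretend this queried the provider's model-listing endpoint.
        self._model_cache = {m: m for m in ("acme/chat-1", "acme/embed-1")}
        return list(self._model_cache)

    async def check_model_availability(self, model: str) -> bool:
        # Populate the cache on first use, then answer from memory.
        if not self._model_cache:
            await self.list_models()
        return model in self._model_cache


async def main() -> None:
    catalog = FakeCatalog()
    print(await catalog.check_model_availability("acme/chat-1"))   # True
    print(await catalog.check_model_availability("acme/unknown"))  # False


asyncio.run(main())

One trade-off of the pattern: an empty provider catalogue is indistinguishable from an unpopulated cache, so in that case every availability check would re-query the provider.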
@@ -52,6 +64,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
             # 'label': 'granite-embedding-278m-multilingual',
             # 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
             # ...
+            provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
             if "embedding" in functions:
                 embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
                 context_length = model_spec["model_limits"]["max_sequence_length"]
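The commented-out example above only hints at the shape of the spec dictionaries returned by _get_model_specs(). The snippet below shows how the two derivations in this hunk behave on an illustrative spec; the field values are assumptions, not an actual watsonx.ai response:

# Illustrative model_spec, shaped after the commented example in the diff;
# the concrete values are assumptions, not real API output.
model_spec = {
    "model_id": "ibm/granite-embedding-278m-multilingual",
    "label": "granite-embedding-278m-multilingual",
    "functions": [{"id": "embedding"}],
    "model_limits": {"max_sequence_length": 512, "embedding_dimension": 768},
}

provider_id = "watsonx"  # assumed value of self.__provider_id__

# Same derivations as in the loop body above.
functions = [f["id"] for f in model_spec.get("functions", [])]
provider_resource_id = f"{provider_id}/{model_spec['model_id']}"

print(functions)             # ['embedding']
print(provider_resource_id)  # watsonx/ibm/granite-embedding-278m-multilingual

provider_resource_id namespaces the provider's model_id under __provider_id__, and that string is the key used for the model cache below.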
@@ -59,25 +72,29 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
                     "embedding_dimension": embedding_dimension,
                     "context_length": context_length,
                 }
-                models.append(
-                    Model(
-                        identifier=model_spec["model_id"],
-                        provider_resource_id=f"{self.__provider_id__}/{model_spec['model_id']}",
-                        provider_id=self.__provider_id__,
-                        metadata=embedding_metadata,
-                        model_type=ModelType.embedding,
-                    )
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata=embedding_metadata,
+                    model_type=ModelType.embedding,
                 )
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
             if "text_chat" in functions:
-                models.append(
-                    Model(
-                        identifier=model_spec["model_id"],
-                        provider_resource_id=f"{self.__provider_id__}/{model_spec['model_id']}",
-                        provider_id=self.__provider_id__,
-                        metadata={},
-                        model_type=ModelType.llm,
-                    )
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=ModelType.llm,
                 )
+                # In theory, I guess it is possible that a model could be both an embedding model and a text chat model.
+                # In that case, the cache will record the generator Model object, and the list which we return will have
+                # both the generator Model object and the text chat Model object. That's fine because the cache is
+                # only used for check_model_availability() anyway.
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
         return models
 
     # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
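The new comment covers the corner case where one spec advertises both functions: the returned list gets two Model entries, while the cache keeps only the last one written under the shared provider_resource_id. A plain-dict sketch of that flow (the spec, identifiers, and provider id are invented; llama_stack's Model class is replaced with dictionaries):

# Sketch of the dual-function case described in the comment above, using plain
# dicts instead of llama_stack's Model class; spec and provider id are made up.
model_spec = {
    "model_id": "acme/do-everything-1",
    "functions": [{"id": "embedding"}, {"id": "text_chat"}],
}
provider_id = "watsonx"

functions = [f["id"] for f in model_spec.get("functions", [])]
provider_resource_id = f"{provider_id}/{model_spec['model_id']}"

model_cache: dict[str, dict] = {}
models: list[dict] = []

if "embedding" in functions:
    model = {"identifier": model_spec["model_id"], "model_type": "embedding"}
    model_cache[provider_resource_id] = model
    models.append(model)
if "text_chat" in functions:
    model = {"identifier": model_spec["model_id"], "model_type": "llm"}
    model_cache[provider_resource_id] = model  # overwrites the embedding entry
    models.append(model)

# The list keeps both entries; the cache holds only the last one written,
# which is fine because it only backs check_model_availability().
print(len(models))                                      # 2
print(model_cache[provider_resource_id]["model_type"])  # llm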
@@ -336,7 +336,6 @@ class LiteLLMOpenAIMixin(
             api_key=self.get_api_key(),
             api_base=self.api_base,
         )
-        logger.info(f"params to litellm (openai compat): {params}")
         return await litellm.acompletion(**params)
 
     async def check_model_availability(self, model: str) -> bool:
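The hunk above appears to drop the params dump, which would have logged the api_key passed to litellm.acompletion() at INFO level. If a debug dump is ever wanted here, a redacting variant is one option; this sketch is not part of the commit, and the key names it masks are assumptions about what a LiteLLM params dict may contain:

# Hedged sketch, not part of the commit: mask credential-like keys before
# logging a params dict. Key names and example values are assumptions.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

SENSITIVE_KEYS = {"api_key", "api_token", "authorization"}


def redacted(params: dict) -> dict:
    """Return a copy of params with credential-like values masked."""
    return {k: ("***" if k in SENSITIVE_KEYS else v) for k, v in params.items()}


params = {"model": "watsonx/some-model", "api_key": "secret", "api_base": "https://example.invalid"}
logger.info("params to litellm (openai compat): %s", redacted(params))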