mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 05:38:38 +00:00
# What does this PR do? Use SecretStr for OpenAIMixin providers: RemoteInferenceProviderConfig now has `auth_credential: SecretStr`; the default alias is `api_key` (the most common name); some providers override it to use `api_token` (RunPod, vLLM, Databricks); some providers exclude it (Ollama, TGI, Vertex AI). Addresses #3517. ## Test Plan: CI with new tests.
123 lines
5.3 KiB
Python
123 lines
5.3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import Any
|
|
|
|
import requests
|
|
|
|
from llama_stack.apis.inference import ChatCompletionRequest
|
|
from llama_stack.apis.models import Model
|
|
from llama_stack.apis.models.models import ModelType
|
|
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
|
|
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
|
|
|
|
|
class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
    """
    Inference adapter for IBM watsonx.ai.

    Builds on LiteLLMOpenAIMixin with the LiteLLM provider name ``watsonx``. LiteLLM does
    not offer model listing for watsonx.ai, so model discovery is implemented here by
    querying the watsonx.ai foundation model specs endpoint.
    """

    # Maps provider_resource_id ("<provider_id>/<model_id>") -> Model.
    # list_models() always binds a fresh dict on the instance, so this class-level default
    # is only ever read (before the first listing) and is never mutated.
    _model_cache: dict[str, Model] = {}

    def __init__(self, config: WatsonXConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            litellm_provider_name="watsonx",
            # auth_credential is optional; unwrap the secret only when one is configured.
            api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
            provider_data_api_key_field="watsonx_api_key",
        )
        # NOTE(review): not read anywhere in this file.
        self.available_models = None
        self.config = config

    def get_base_url(self) -> str:
        """Return the configured watsonx.ai base URL."""
        return self.config.url

    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """
        Build the call parameters, adding watsonx.ai-specific fields.

        :param request: The chat completion request being served.
        :return: The parent's parameters plus ``project_id`` and ``time_limit`` taken
            from the adapter config.
        """
        # Get base parameters from parent
        params = await super()._get_params(request)

        # Add watsonx.ai specific parameters
        params["project_id"] = self.config.project_id
        params["time_limit"] = self.config.timeout
        return params

    # Copied from OpenAIMixin
    async def check_model_availability(self, model: str) -> bool:
        """
        Check whether a model is available from watsonx.ai.

        The cache is populated via list_models() whenever it is empty.

        :param model: The provider resource id to check, i.e. "<provider_id>/<model_id>"
            as built by list_models().
        :return: True if the model is available dynamically, False otherwise.
        """
        if not self._model_cache:
            await self.list_models()
        return model in self._model_cache

    async def list_models(self) -> list[Model] | None:
        """
        Fetch the foundation model specs from watsonx.ai and convert them to Model objects.

        Specs whose functions include ``embedding`` become ``ModelType.embedding`` models,
        with their embedding dimension and context length as metadata. Specs whose functions
        include ``text_chat`` become ``ModelType.llm`` models. The model cache is rebuilt
        from the fetched specs on every successful call.

        :return: The list of discovered models.
        """
        import asyncio

        # _get_model_specs() uses the blocking `requests` library; run it in a worker
        # thread so the event loop is not stalled while waiting on the network.
        model_specs = await asyncio.to_thread(self._get_model_specs)

        # Build the new cache locally and swap it in at the end, so a failed fetch or a
        # malformed spec does not leave the adapter with an empty/half-populated cache.
        model_cache: dict[str, Model] = {}
        models: list[Model] = []
        for model_spec in model_specs:
            functions = [f["id"] for f in model_spec.get("functions", [])]
            # Example of an embedding model spec (abridged):
            # {'model_id': 'ibm/granite-embedding-278m-multilingual',
            #  'label': 'granite-embedding-278m-multilingual',
            #  'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
            #  ...
            provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
            if "embedding" in functions:
                model_limits = model_spec["model_limits"]
                # Metadata format: {"embedding_dimension": 768, "context_length": 512}
                embedding_metadata = {
                    "embedding_dimension": model_limits["embedding_dimension"],
                    "context_length": model_limits["max_sequence_length"],
                }
                model = Model(
                    identifier=model_spec["model_id"],
                    provider_resource_id=provider_resource_id,
                    provider_id=self.__provider_id__,
                    metadata=embedding_metadata,
                    model_type=ModelType.embedding,
                )
                model_cache[provider_resource_id] = model
                models.append(model)
            if "text_chat" in functions:
                model = Model(
                    identifier=model_spec["model_id"],
                    provider_resource_id=provider_resource_id,
                    provider_id=self.__provider_id__,
                    metadata={},
                    model_type=ModelType.llm,
                )
                # A model could in principle be both an embedding model and a text chat
                # model. In that case the cache keeps only the text chat (LLM) Model object,
                # since it is written last under the same key. The returned list contains
                # both the embedding and the text chat Model objects. That's fine because the
                # cache is only used by check_model_availability(), which tests key membership.
                model_cache[provider_resource_id] = model
                models.append(model)

        self._model_cache = model_cache
        return models

    # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
    # So we need to implement our own method to list models by calling the watsonx.ai API.
    def _get_model_specs(self) -> list[dict[str, Any]]:
        """
        Retrieve foundation model specifications from the watsonx.ai API.

        This is a blocking HTTP call made with ``requests``.

        :return: The ``resources`` list from the JSON response, one dict per model spec.
        :raises requests.HTTPError: If the API responds with a 4xx/5xx status.
        :raises ValueError: If the response JSON has no ``resources`` key.
        """
        # Strip a trailing slash so a base URL such as "https://host/" does not
        # produce "https://host//ml/v1/...".
        base_url = str(self.config.url).rstrip("/")
        url = f"{base_url}/ml/v1/foundation_model_specs?version=2023-10-25"
        headers = {
            # Note that there is no authorization header. Listing models does not require authentication.
            "Content-Type": "application/json",
        }

        # Bound the request so an unresponsive endpoint cannot hang model listing forever.
        # NOTE(review): assumes config.timeout is in seconds (the same value is also sent
        # as the completion `time_limit` in _get_params) — confirm units.
        response = requests.get(url, headers=headers, timeout=self.config.timeout)

        # --- Process the Response ---
        # Raise an exception for bad status codes (4xx or 5xx)
        response.raise_for_status()

        # If the request is successful, parse and return the JSON response.
        # The response should contain a list of model specifications
        response_data = response.json()
        if "resources" not in response_data:
            raise ValueError("Resources not found in response")
        return response_data["resources"]
|