# What does this PR do?

Clean up telemetry code since the telemetry API has been removed.

- moved telemetry files out of providers to core
- removed telemetry from `Api`

## Test Plan

```
❯ OTEL_SERVICE_NAME=llama_stack OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 uv run llama stack run starter

❯ curl http://localhost:8321/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "openai/gpt-4o-mini",
    "messages": [
      { "role": "user", "content": "Hello!" }
    ]
  }'
```

-> verify traces in Grafana

CI
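Beyond eyeballing Grafana, the same check can be scripted; a minimal sketch, assuming the stack from the Test Plan is running on localhost:8321 and the `requests` package is available:

```python
# Minimal sanity check against the locally running stack from the Test Plan.
import requests

resp = requests.post(
    "http://localhost:8321/v1/chat/completions",
    json={
        "model": "openai/gpt-4o-mini",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    timeout=60,
)
resp.raise_for_status()
# A usage block should be present in the response body.
print(resp.json().get("usage"))
```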
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator
from typing import Any

import litellm
import requests

from llama_stack.apis.inference.inference import (
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAIChatCompletionUsage,
    OpenAICompletion,
    OpenAICompletionRequestWithExtraBody,
    OpenAIEmbeddingsRequestWithExtraBody,
    OpenAIEmbeddingsResponse,
)
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.core.telemetry.tracing import get_current_span
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params

logger = get_logger(name=__name__, category="providers::remote::watsonx")


class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
    _model_cache: dict[str, Model] = {}

    provider_data_api_key_field: str = "watsonx_api_key"

    def __init__(self, config: WatsonXConfig):
        self.available_models = None
        self.config = config
        api_key = config.auth_credential.get_secret_value() if config.auth_credential else None
        LiteLLMOpenAIMixin.__init__(
            self,
            litellm_provider_name="watsonx",
            api_key_from_config=api_key,
            provider_data_api_key_field="watsonx_api_key",
            openai_compat_api_base=self.get_base_url(),
        )
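
    # Illustrative only (not part of the runtime path): the stack constructs this
    # adapter from a WatsonXConfig, roughly like (field values are placeholders):
    #
    #   config = WatsonXConfig(url="https://us-south.ml.cloud.ibm.com", ...)
    #   adapter = WatsonXInferenceAdapter(config)
    #   models = await adapter.list_models()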

    async def openai_chat_completion(
        self,
        params: OpenAIChatCompletionRequestWithExtraBody,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        """
        Override parent method to add a timeout and inject a usage object when missing.
        This works around a LiteLLM defect where the usage block is sometimes dropped.
        """

        # Add usage tracking for streaming when telemetry is active
        stream_options = params.stream_options
        if params.stream and get_current_span() is not None:
            if stream_options is None:
                stream_options = {"include_usage": True}
            elif "include_usage" not in stream_options:
                stream_options = {**stream_options, "include_usage": True}

        model_obj = await self.model_store.get_model(params.model)

        request_params = await prepare_openai_completion_params(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            messages=params.messages,
            frequency_penalty=params.frequency_penalty,
            function_call=params.function_call,
            functions=params.functions,
            logit_bias=params.logit_bias,
            logprobs=params.logprobs,
            max_completion_tokens=params.max_completion_tokens,
            max_tokens=params.max_tokens,
            n=params.n,
            parallel_tool_calls=params.parallel_tool_calls,
            presence_penalty=params.presence_penalty,
            response_format=params.response_format,
            seed=params.seed,
            stop=params.stop,
            stream=params.stream,
            stream_options=stream_options,
            temperature=params.temperature,
            tool_choice=params.tool_choice,
            tools=params.tools,
            top_logprobs=params.top_logprobs,
            top_p=params.top_p,
            user=params.user,
            api_key=self.get_api_key(),
            api_base=self.api_base,
            # These are watsonx-specific parameters
            timeout=self.config.timeout,
            project_id=self.config.project_id,
        )

        result = await litellm.acompletion(**request_params)

        # If not streaming, check and inject usage if missing
        if not params.stream:
            # Use getattr to safely handle cases where the usage attribute might not exist
            if getattr(result, "usage", None) is None:
                # Create a usage object with zeros
                usage_obj = OpenAIChatCompletionUsage(
                    prompt_tokens=0,
                    completion_tokens=0,
                    total_tokens=0,
                )
                # Use model_copy to create a new response with the usage injected
                result = result.model_copy(update={"usage": usage_obj})
            return result

        # For streaming, wrap the iterator to normalize chunks
        return self._normalize_stream(result)

    def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk:
        """
        Normalize a chunk to ensure it has all expected attributes.
        This works around LiteLLM not always including all expected attributes.
        """
        # Ensure the chunk has a usage attribute with zeros if missing
        if not hasattr(chunk, "usage") or chunk.usage is None:
            usage_obj = OpenAIChatCompletionUsage(
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
            )
            chunk = chunk.model_copy(update={"usage": usage_obj})

        # Ensure all delta objects in choices have expected attributes
        if hasattr(chunk, "choices") and chunk.choices:
            normalized_choices = []
            for choice in chunk.choices:
                if hasattr(choice, "delta") and choice.delta:
                    delta = choice.delta
                    # Build an update dict for missing attributes
                    delta_updates = {}
                    if not hasattr(delta, "refusal"):
                        delta_updates["refusal"] = None
                    if not hasattr(delta, "reasoning_content"):
                        delta_updates["reasoning_content"] = None

                    # If we need to update the delta, create a new choice with the updated delta
                    if delta_updates:
                        new_delta = delta.model_copy(update=delta_updates)
                        new_choice = choice.model_copy(update={"delta": new_delta})
                        normalized_choices.append(new_choice)
                    else:
                        normalized_choices.append(choice)
                else:
                    normalized_choices.append(choice)

            # If we modified any choices, create a new chunk with the updated choices
            if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))):
                chunk = chunk.model_copy(update={"choices": normalized_choices})

        return chunk
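    # For example, a chunk whose delta lacks `refusal` or `reasoning_content` comes
    # back with both set to None, and a missing `usage` becomes an all-zero
    # OpenAIChatCompletionUsage, so downstream consumers can rely on these fields.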

    async def _normalize_stream(
        self, stream: AsyncIterator[OpenAIChatCompletionChunk]
    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
        """
        Normalize all chunks in the stream to ensure they have expected attributes.
        This works around LiteLLM sometimes not including expected attributes.
        """
        try:
            async for chunk in stream:
                # Normalize and yield each chunk immediately
                yield self._normalize_chunk(chunk)
        except Exception as e:
            logger.error(f"Error normalizing stream: {e}", exc_info=True)
            raise

    async def openai_completion(
        self,
        params: OpenAICompletionRequestWithExtraBody,
    ) -> OpenAICompletion:
        """
        Override parent method to add watsonx-specific parameters.
        """
        model_obj = await self.model_store.get_model(params.model)

        request_params = await prepare_openai_completion_params(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            prompt=params.prompt,
            best_of=params.best_of,
            echo=params.echo,
            frequency_penalty=params.frequency_penalty,
            logit_bias=params.logit_bias,
            logprobs=params.logprobs,
            max_tokens=params.max_tokens,
            n=params.n,
            presence_penalty=params.presence_penalty,
            seed=params.seed,
            stop=params.stop,
            stream=params.stream,
            stream_options=params.stream_options,
            temperature=params.temperature,
            top_p=params.top_p,
            user=params.user,
            suffix=params.suffix,
            api_key=self.get_api_key(),
            api_base=self.api_base,
            # These are watsonx-specific parameters
            timeout=self.config.timeout,
            project_id=self.config.project_id,
        )
        return await litellm.atext_completion(**request_params)

    async def openai_embeddings(
        self,
        params: OpenAIEmbeddingsRequestWithExtraBody,
    ) -> OpenAIEmbeddingsResponse:
        """
        Override parent method to add watsonx-specific parameters.
        """
        model_obj = await self.model_store.get_model(params.model)

        # Convert the input to a list if it's a string
        input_list = [params.input] if isinstance(params.input, str) else params.input

        # Call the litellm embedding function with watsonx-specific parameters
        response = litellm.embedding(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            input=input_list,
            api_key=self.get_api_key(),
            api_base=self.api_base,
            dimensions=params.dimensions,
            # These are watsonx-specific parameters
            timeout=self.config.timeout,
            project_id=self.config.project_id,
        )

        # Convert the response to OpenAI format
        from llama_stack.apis.inference import OpenAIEmbeddingUsage
        from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response

        data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)

        usage = OpenAIEmbeddingUsage(
            prompt_tokens=response["usage"]["prompt_tokens"],
            total_tokens=response["usage"]["total_tokens"],
        )

        return OpenAIEmbeddingsResponse(
            data=data,
            model=model_obj.provider_resource_id,
            usage=usage,
        )

    def get_base_url(self) -> str:
        return self.config.url

    # Copied from OpenAIMixin
    async def check_model_availability(self, model: str) -> bool:
        """
        Check if a specific model is available from the provider's /v1/models.

        :param model: The model identifier to check.
        :return: True if the model is available dynamically, False otherwise.
        """
        if not self._model_cache:
            await self.list_models()
        return model in self._model_cache
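
    # Illustrative only: the cache is keyed by provider_resource_id, i.e.
    # "<provider_id>/<model_id>", so (assuming the provider is registered as
    # "watsonx") a check looks roughly like:
    #
    #   available = await adapter.check_model_availability(
    #       "watsonx/ibm/granite-embedding-278m-multilingual"
    #   )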

    async def list_models(self) -> list[Model] | None:
        self._model_cache = {}
        models = []
        for model_spec in self._get_model_specs():
            functions = [f["id"] for f in model_spec.get("functions", [])]

            # Example of an embedding model:
            # {'model_id': 'ibm/granite-embedding-278m-multilingual',
            #  'label': 'granite-embedding-278m-multilingual',
            #  'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
            #  ...
            provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
            if "embedding" in functions:
                embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
                context_length = model_spec["model_limits"]["max_sequence_length"]
                # Metadata format: {"embedding_dimension": 1536, "context_length": 8192}
                embedding_metadata = {
                    "embedding_dimension": embedding_dimension,
                    "context_length": context_length,
                }
                model = Model(
                    identifier=model_spec["model_id"],
                    provider_resource_id=provider_resource_id,
                    provider_id=self.__provider_id__,
                    metadata=embedding_metadata,
                    model_type=ModelType.embedding,
                )
                self._model_cache[provider_resource_id] = model
                models.append(model)
            if "text_chat" in functions:
                model = Model(
                    identifier=model_spec["model_id"],
                    provider_resource_id=provider_resource_id,
                    provider_id=self.__provider_id__,
                    metadata={},
                    model_type=ModelType.llm,
                )
                # In theory, a model could be both an embedding model and a text chat model.
                # In that case, the cache records the text chat Model object, while the list we
                # return contains both the embedding Model object and the text chat Model object.
                # That's fine because the cache is only used for check_model_availability() anyway.
                self._model_cache[provider_resource_id] = model
                models.append(model)
        return models

    # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
    # So we need to implement our own method to list models by calling the watsonx.ai API.
    def _get_model_specs(self) -> list[dict[str, Any]]:
        """
        Retrieves foundation model specifications from the watsonx.ai API.
        """
        url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
        headers = {
            # Note that there is no authorization header. Listing models does not require authentication.
            "Content-Type": "application/json",
        }

        response = requests.get(url, headers=headers)

        # Raise an exception for bad status codes (4xx or 5xx)
        response.raise_for_status()

        # If the request is successful, parse and return the JSON response,
        # which should contain a list of model specifications under "resources".
        response_data = response.json()
        if "resources" not in response_data:
            raise ValueError("Resources not found in response")
        return response_data["resources"]
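
    # Illustrative shape of one returned spec, abridged to the fields consumed by
    # list_models() above (the real payload contains more fields):
    #
    #   {
    #       "model_id": "ibm/granite-embedding-278m-multilingual",
    #       "functions": [{"id": "embedding"}],
    #       "model_limits": {"max_sequence_length": 512, "embedding_dimension": 768},
    #   }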