Mirror of https://github.com/meta-llama/llama-stack.git

Commit effe7609a9 (parent 1136daf310): Fixed WatsonX bugs

3 changed files with 236 additions and 19 deletions
In the inference provider registry:

@@ -271,7 +271,7 @@ Available Models:
         pip_packages=["litellm"],
         module="llama_stack.providers.remote.inference.watsonx",
         config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
-        provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
+        provider_data_validator="llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
         description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
     ),
     RemoteProviderSpec(
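Note on the registry fix above: provider_data_validator is a dotted "module.ClassName" string that llama-stack resolves at runtime, and WatsonXProviderDataValidator lives in the provider's config module, so the old value (which pointed at the package rather than config.py) presumably failed to resolve. A minimal sketch of that kind of dotted-path resolution, not llama-stack's actual loader:

import importlib

def resolve_dotted_path(path: str):
    # Split "pkg.module.ClassName" into module path and attribute name,
    # import the module, then pull the attribute off it.
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# With the corrected value, the class is found inside config.py:
validator_cls = resolve_dotted_path(
    "llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator"
)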
In llama_stack/providers/remote/inference/watsonx/config.py:

@@ -7,18 +7,18 @@
 import os
 from typing import Any
 
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, Field
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 class WatsonXProviderDataValidator(BaseModel):
-    model_config = ConfigDict(
-        from_attributes=True,
-        extra="forbid",
+    watsonx_project_id: str | None = Field(
+        default=None,
+        description="IBM WatsonX project ID",
     )
-    watsonx_api_key: str | None
+    watsonx_api_key: str | None = None
 
 
 @json_schema_type
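The validator change above replaces the pydantic ConfigDict block with explicit per-request fields. A hedged usage sketch, assuming the usual llama-stack pattern of passing provider data as JSON in the X-LlamaStack-Provider-Data request header; the server URL, port, and model id below are illustrative placeholders, not values from this commit:

import json

import requests

provider_data = {
    "watsonx_api_key": "my-watsonx-api-key",      # placeholder credential
    "watsonx_project_id": "my-watsonx-project",   # placeholder project id
}

resp = requests.post(
    "http://localhost:8321/v1/openai/v1/chat/completions",
    headers={"X-LlamaStack-Provider-Data": json.dumps(provider_data)},
    json={
        "model": "watsonx/meta-llama/llama-3-3-70b-instruct",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.json())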
In the WatsonX inference adapter (module llama_stack.providers.remote.inference.watsonx):

@@ -4,42 +4,259 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from collections.abc import AsyncIterator
 from typing import Any
 
+import litellm
 import requests
 
-from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionUsage,
+    OpenAICompletion,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+)
 from llama_stack.apis.models import Model
 from llama_stack.apis.models.models import ModelType
+from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
+logger = get_logger(name=__name__, category="providers::remote::watsonx")
 
 
 class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
     _model_cache: dict[str, Model] = {}
 
+    provider_data_api_key_field: str = "watsonx_api_key"
+
     def __init__(self, config: WatsonXConfig):
+        self.available_models = None
+        self.config = config
+        api_key = config.auth_credential.get_secret_value() if config.auth_credential else None
         LiteLLMOpenAIMixin.__init__(
             self,
             litellm_provider_name="watsonx",
-            api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
+            api_key_from_config=api_key,
             provider_data_api_key_field="watsonx_api_key",
+            openai_compat_api_base=self.get_base_url(),
+        )
+
+    async def openai_chat_completion(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        """
+        Override parent method to add timeout and inject usage object when missing.
+        This works around a LiteLLM defect where the usage block is sometimes dropped.
+        """
+        # Add usage tracking for streaming when telemetry is active
+        stream_options = params.stream_options
+        if params.stream and get_current_span() is not None:
+            if stream_options is None:
+                stream_options = {"include_usage": True}
+            elif "include_usage" not in stream_options:
+                stream_options = {**stream_options, "include_usage": True}
+
+        model_obj = await self.model_store.get_model(params.model)
+
+        request_params = await prepare_openai_completion_params(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            messages=params.messages,
+            frequency_penalty=params.frequency_penalty,
+            function_call=params.function_call,
+            functions=params.functions,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_completion_tokens=params.max_completion_tokens,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            parallel_tool_calls=params.parallel_tool_calls,
+            presence_penalty=params.presence_penalty,
+            response_format=params.response_format,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=stream_options,
+            temperature=params.temperature,
+            tool_choice=params.tool_choice,
+            tools=params.tools,
+            top_logprobs=params.top_logprobs,
+            top_p=params.top_p,
+            user=params.user,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+
+        result = await litellm.acompletion(**request_params)
+
+        # If not streaming, check and inject usage if missing
+        if not params.stream:
+            # Use getattr to safely handle cases where usage attribute might not exist
+            if getattr(result, "usage", None) is None:
+                # Create usage object with zeros
+                usage_obj = OpenAIChatCompletionUsage(
+                    prompt_tokens=0,
+                    completion_tokens=0,
+                    total_tokens=0,
+                )
+                # Use model_copy to create a new response with the usage injected
+                result = result.model_copy(update={"usage": usage_obj})
+            return result
+
+        # For streaming, wrap the iterator to normalize chunks
+        return self._normalize_stream(result)
+
+    def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk:
+        """
+        Normalize a chunk to ensure it has all expected attributes.
+        This works around LiteLLM not always including all expected attributes.
+        """
+        # Ensure chunk has usage attribute with zeros if missing
+        if not hasattr(chunk, "usage") or chunk.usage is None:
+            usage_obj = OpenAIChatCompletionUsage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0,
+            )
+            chunk = chunk.model_copy(update={"usage": usage_obj})
+
+        # Ensure all delta objects in choices have expected attributes
+        if hasattr(chunk, "choices") and chunk.choices:
+            normalized_choices = []
+            for choice in chunk.choices:
+                if hasattr(choice, "delta") and choice.delta:
+                    delta = choice.delta
+                    # Build update dict for missing attributes
+                    delta_updates = {}
+                    if not hasattr(delta, "refusal"):
+                        delta_updates["refusal"] = None
+                    if not hasattr(delta, "reasoning_content"):
+                        delta_updates["reasoning_content"] = None
+
+                    # If we need to update delta, create a new choice with updated delta
+                    if delta_updates:
+                        new_delta = delta.model_copy(update=delta_updates)
+                        new_choice = choice.model_copy(update={"delta": new_delta})
+                        normalized_choices.append(new_choice)
+                    else:
+                        normalized_choices.append(choice)
+                else:
+                    normalized_choices.append(choice)
+
+            # If we modified any choices, create a new chunk with updated choices
+            if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))):
+                chunk = chunk.model_copy(update={"choices": normalized_choices})
+
+        return chunk
+
+    async def _normalize_stream(
+        self, stream: AsyncIterator[OpenAIChatCompletionChunk]
+    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
+        """
+        Normalize all chunks in the stream to ensure they have expected attributes.
+        This works around LiteLLM sometimes not including expected attributes.
+        """
+        try:
+            async for chunk in stream:
+                # Normalize and yield each chunk immediately
+                yield self._normalize_chunk(chunk)
+        except Exception as e:
+            logger.error(f"Error normalizing stream: {e}", exc_info=True)
+            raise
+
+    async def openai_completion(
+        self,
+        params: OpenAICompletionRequestWithExtraBody,
+    ) -> OpenAICompletion:
+        """
+        Override parent method to add watsonx-specific parameters.
+        """
+        from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+
+        model_obj = await self.model_store.get_model(params.model)
+
+        request_params = await prepare_openai_completion_params(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            prompt=params.prompt,
+            best_of=params.best_of,
+            echo=params.echo,
+            frequency_penalty=params.frequency_penalty,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            presence_penalty=params.presence_penalty,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=params.stream_options,
+            temperature=params.temperature,
+            top_p=params.top_p,
+            user=params.user,
+            suffix=params.suffix,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+        return await litellm.atext_completion(**request_params)
+
+    async def openai_embeddings(
+        self,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
+    ) -> OpenAIEmbeddingsResponse:
+        """
+        Override parent method to add watsonx-specific parameters.
+        """
+        model_obj = await self.model_store.get_model(params.model)
+
+        # Convert input to list if it's a string
+        input_list = [params.input] if isinstance(params.input, str) else params.input
+
+        # Call litellm embedding function with watsonx-specific parameters
+        response = litellm.embedding(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            input=input_list,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            dimensions=params.dimensions,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+
+        # Convert response to OpenAI format
+        from llama_stack.apis.inference import OpenAIEmbeddingUsage
+        from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response
+
+        data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)
+
+        usage = OpenAIEmbeddingUsage(
+            prompt_tokens=response["usage"]["prompt_tokens"],
+            total_tokens=response["usage"]["total_tokens"],
+        )
+
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=model_obj.provider_resource_id,
+            usage=usage,
         )
-        self.available_models = None
-        self.config = config
 
     def get_base_url(self) -> str:
         return self.config.url
 
-    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
-        # Get base parameters from parent
-        params = await super()._get_params(request)
-
-        # Add watsonx.ai specific parameters
-        params["project_id"] = self.config.project_id
-        params["time_limit"] = self.config.timeout
-        return params
-
     # Copied from OpenAIMixin
     async def check_model_availability(self, model: str) -> bool:
         """
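The usage-injection and chunk-normalization overrides above both rely on pydantic's model_copy(update=...) to produce a patched copy of an immutable response object. A standalone sketch of that pattern with simplified stand-in models (not the actual LiteLLM/OpenAI types):

from pydantic import BaseModel


class Usage(BaseModel):
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class Chunk(BaseModel):
    id: str
    usage: Usage | None = None


def ensure_usage(chunk: Chunk) -> Chunk:
    # Mirror of the adapter's workaround: if the upstream response dropped
    # the usage block, attach a zeroed one instead of failing downstream.
    if getattr(chunk, "usage", None) is None:
        return chunk.model_copy(update={"usage": Usage()})
    return chunk


chunk = Chunk(id="chunk-0")  # arrives without usage attached
fixed = ensure_usage(chunk)
assert fixed.usage is not None and fixed.usage.total_tokens == 0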