Address review comments

Signed-off-by: Bill Murdock <bmurdock@redhat.com>
Bill Murdock 2025-10-06 15:45:24 -04:00
parent ca771cd921
commit 1d941b6aa0
9 changed files with 61 additions and 200 deletions

View file

@@ -3,3 +3,5 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+from .watsonx import get_distribution_template  # noqa: F401

View file

@@ -3,44 +3,33 @@ distribution_spec:
   description: Use watsonx for running LLM inference
   providers:
     inference:
-    - provider_id: watsonx
-      provider_type: remote::watsonx
-    - provider_id: sentence-transformers
-      provider_type: inline::sentence-transformers
+    - provider_type: remote::watsonx
+    - provider_type: inline::sentence-transformers
     vector_io:
-    - provider_id: faiss
-      provider_type: inline::faiss
+    - provider_type: inline::faiss
     safety:
-    - provider_id: llama-guard
-      provider_type: inline::llama-guard
+    - provider_type: inline::llama-guard
     agents:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     telemetry:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     eval:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     datasetio:
-    - provider_id: huggingface
-      provider_type: remote::huggingface
-    - provider_id: localfs
-      provider_type: inline::localfs
+    - provider_type: remote::huggingface
+    - provider_type: inline::localfs
     scoring:
-    - provider_id: basic
-      provider_type: inline::basic
-    - provider_id: llm-as-judge
-      provider_type: inline::llm-as-judge
-    - provider_id: braintrust
-      provider_type: inline::braintrust
+    - provider_type: inline::basic
+    - provider_type: inline::llm-as-judge
+    - provider_type: inline::braintrust
     tool_runtime:
     - provider_type: remote::brave-search
     - provider_type: remote::tavily-search
     - provider_type: inline::rag-runtime
    - provider_type: remote::model-context-protocol
+    files:
+    - provider_type: inline::localfs
 image_type: venv
 additional_pip_packages:
 - aiosqlite
-- sqlalchemy[asyncio]
-- aiosqlite
+- aiosqlite

View file

@@ -4,13 +4,13 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - safety
 - scoring
 - telemetry
 - tool_runtime
 - vector_io
-- files
 providers:
   inference:
   - provider_id: watsonx
@@ -19,8 +19,6 @@ providers:
       url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
       api_key: ${env.WATSONX_API_KEY:=}
       project_id: ${env.WATSONX_PROJECT_ID:=}
-  - provider_id: sentence-transformers
-    provider_type: inline::sentence-transformers
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -48,7 +46,7 @@ providers:
     provider_type: inline::meta-reference
     config:
       service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
-      sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
+      sinks: ${env.TELEMETRY_SINKS:=sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
       otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
   eval:
@@ -109,102 +107,7 @@ metadata_store:
 inference_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
-models:
-- metadata: {}
-  model_id: meta-llama/llama-3-3-70b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-3-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.3-70B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-3-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-2-13b-chat
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-2-13b-chat
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-2-13b
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-2-13b-chat
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-1-70b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-1-8b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-8b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-8b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-11b-vision-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-1b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-1b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-1B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-1b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-3b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-3b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-3B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-3b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-90b-vision-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-guard-3-11b-vision
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-guard-3-11b-vision
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-Guard-3-11B-Vision
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-guard-3-11b-vision
-  model_type: llm
-- metadata:
-    embedding_dimension: 384
-  model_id: all-MiniLM-L6-v2
-  provider_id: sentence-transformers
-  model_type: embedding
+models: []
 shields: []
 vector_dbs: []
 datasets: []
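
A note on the ${env.VAR:=default} placeholders in this run.yaml: the value comes from the named environment variable when it is set, otherwise the literal default after := is used (an empty default, as in ${env.WATSONX_API_KEY:=}, resolves to an empty string). Below is a minimal sketch of that substitution rule — illustrative only, not llama-stack's actual resolver:

import os
import re

# Matches ${env.VAR:=default}: use the environment variable when set,
# otherwise fall back to the (possibly empty) default.
_ENV_PATTERN = re.compile(r"\$\{env\.(?P<name>\w+):=(?P<default>[^}]*)\}")

def resolve_env_defaults(value: str) -> str:
    """Substitute ${env.VAR:=default} placeholders in a config string."""
    def _sub(match: re.Match) -> str:
        return os.environ.get(match.group("name"), match.group("default"))
    return _ENV_PATTERN.sub(_sub, value)

# Example: with WATSONX_BASE_URL unset, the default URL is used.
print(resolve_env_defaults("${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}"))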

View file

@@ -4,17 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
 from pathlib import Path
 
-from llama_stack.apis.models import ModelType
-from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
-from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
+from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
+from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
 from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
-from llama_stack.providers.inline.inference.sentence_transformers import (
-    SentenceTransformersInferenceConfig,
-)
 from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
-from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
-
 
 def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
@@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
         config=WatsonXConfig.sample_run_config(),
     )
 
-    embedding_provider = Provider(
-        provider_id="sentence-transformers",
-        provider_type="inline::sentence-transformers",
-        config=SentenceTransformersInferenceConfig.sample_run_config(),
-    )
-    available_models = {
-        "watsonx": MODEL_ENTRIES,
-    }
-
     default_tool_groups = [
         ToolGroupInput(
             toolgroup_id="builtin::websearch",
@@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
         ),
     ]
-
-    embedding_model = ModelInput(
-        model_id="all-MiniLM-L6-v2",
-        provider_id="sentence-transformers",
-        model_type=ModelType.embedding,
-        metadata={
-            "embedding_dimension": 384,
-        },
-    )
     files_provider = Provider(
         provider_id="meta-reference-files",
         provider_type="inline::localfs",
         config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
     )
-    default_models, _ = get_model_registry(available_models)
 
     return DistributionTemplate(
         name=name,
         distro_type="remote_hosted",
         description="Use watsonx for running LLM inference",
         container_image=None,
-        template_path=Path(__file__).parent / "doc_template.md",
+        template_path=None,
         providers=providers,
-        available_models_by_provider=available_models,
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider, embedding_provider],
+                    "inference": [inference_provider],
                     "files": [files_provider],
                 },
-                default_models=default_models + [embedding_model],
+                default_models=[],
                 default_tool_groups=default_tool_groups,
             ),
         },

View file

@@ -277,7 +277,7 @@ Available Models:
             api=Api.inference,
             adapter_type="watsonx",
             provider_type="remote::watsonx",
-            pip_packages=["ibm_watsonx_ai"],
+            pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.watsonx",
             config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
             provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",

View file

@@ -25,12 +25,16 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base url for accessing the watsonx.ai",
     )
+    # This seems like it should be required, but none of the other remote inference
+    # providers require it, so this is optional here too for consistency.
+    # The OpenAIConfig uses default=None instead, so this is following that precedent.
     api_key: SecretStr | None = Field(
-        default_factory=lambda: os.getenv("WATSONX_API_KEY"),
+        default=None,
         description="The watsonx.ai API key",
     )
+    # As above, this is optional here too for consistency.
     project_id: str | None = Field(
-        default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
+        default=None,
         description="The watsonx.ai project ID",
     )
     timeout: int = Field(
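
With default=None, the config class itself no longer reads WATSONX_API_KEY or WATSONX_PROJECT_ID from the environment; the run.yaml entries above feed them in through the ${env...} placeholders instead. A minimal sketch of constructing the config directly, assuming the remaining fields keep their defaults (the import path is taken from the registry entry in this commit; the credential values are hypothetical):

from pydantic import SecretStr

from llama_stack.providers.remote.inference.watsonx import WatsonXConfig

# Credentials are passed in explicitly (e.g. resolved from run.yaml),
# not read from the environment by the config class itself.
config = WatsonXConfig(
    api_key=SecretStr("my-api-key"),  # hypothetical value
    project_id="my-project-id",       # hypothetical value
)
print(config.url)  # falls back to https://us-south.ml.cloud.ibm.com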

View file

@@ -16,9 +16,6 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 
 class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
     _config: WatsonXConfig
-
-    __provider_id__: str = "watsonx"
-
     def __init__(self, config: WatsonXConfig):
         LiteLLMOpenAIMixin.__init__(
             self,
@@ -29,17 +26,9 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
         self.available_models = None
         self.config = config
 
-    # get_api_key = LiteLLMOpenAIMixin.get_api_key
-
     def get_base_url(self) -> str:
         return self.config.url
 
-    async def initialize(self):
-        await super().initialize()
-
-    async def shutdown(self):
-        await super().shutdown()
-
     async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
         # Get base parameters from parent
         params = await super()._get_params(request)

View file

@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import base64
+import struct
 from collections.abc import AsyncIterator
 from typing import Any
@@ -16,6 +18,7 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
+    OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
@@ -26,7 +29,6 @@ from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
 from llama_stack.providers.utils.inference.openai_compat import (
-    b64_encode_openai_embeddings_response,
     convert_message_to_openai_dict_new,
     convert_tooldef_to_openai_tool,
     get_sampling_options,
@@ -334,6 +336,7 @@ class LiteLLMOpenAIMixin(
             api_key=self.get_api_key(),
             api_base=self.api_base,
         )
+        logger.info(f"params to litellm (openai compat): {params}")
         return await litellm.acompletion(**params)
 
     async def check_model_availability(self, model: str) -> bool:
@@ -349,3 +352,28 @@
             return False
 
         return model in litellm.models_by_provider[self.litellm_provider_name]
+
+
+def b64_encode_openai_embeddings_response(
+    response_data: list[dict], encoding_format: str | None = "float"
+) -> list[OpenAIEmbeddingData]:
+    """
+    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
+    """
+    data = []
+    for i, embedding_data in enumerate(response_data):
+        if encoding_format == "base64":
+            byte_array = bytearray()
+            for embedding_value in embedding_data["embedding"]:
+                byte_array.extend(struct.pack("f", float(embedding_value)))
+
+            response_embedding = base64.b64encode(byte_array).decode("utf-8")
+        else:
+            response_embedding = embedding_data["embedding"]
+        data.append(
+            OpenAIEmbeddingData(
+                embedding=response_embedding,
+                index=i,
+            )
+        )
+    return data
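
The helper above packs each embedding value into 4 bytes with struct.pack("f", ...) (float32, native byte order) and base64-encodes the concatenation; a consumer reverses both steps. A self-contained sketch of that round trip, using a made-up embedding vector:

import base64
import struct

embedding = [0.1, -0.25, 0.5]  # hypothetical embedding vector

# Encode as the helper does: float32-pack each value, then base64.
packed = bytearray()
for value in embedding:
    packed.extend(struct.pack("f", float(value)))
encoded = base64.b64encode(bytes(packed)).decode("utf-8")

# Decoding reverses both steps; values match up to float32 precision.
raw = base64.b64decode(encoded)
decoded = struct.unpack(f"{len(raw) // 4}f", raw)
assert all(abs(a - b) < 1e-6 for a, b in zip(embedding, decoded))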

View file

@@ -3,9 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import base64
 import json
-import struct
 import time
 import uuid
 import warnings
@@ -103,7 +101,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     Message,
     OpenAIChatCompletion,
-    OpenAIEmbeddingData,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     SamplingParams,
@@ -1402,28 +1399,3 @@ def prepare_openai_embeddings_params(
         params["user"] = user
 
     return params
-
-
-def b64_encode_openai_embeddings_response(
-    response_data: list[dict], encoding_format: str | None = "float"
-) -> list[OpenAIEmbeddingData]:
-    """
-    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
-    """
-    data = []
-    for i, embedding_data in enumerate(response_data):
-        if encoding_format == "base64":
-            byte_array = bytearray()
-            for embedding_value in embedding_data["embedding"]:
-                byte_array.extend(struct.pack("f", float(embedding_value)))
-
-            response_embedding = base64.b64encode(byte_array).decode("utf-8")
-        else:
-            response_embedding = embedding_data["embedding"]
-        data.append(
-            OpenAIEmbeddingData(
-                embedding=response_embedding,
-                index=i,
-            )
-        )
-    return data