add WatsonX inference adapter

This commit is contained in:
Sajikumar JS 2025-03-13 00:08:55 +05:30
parent a9c5d3cd3d
commit 44c51efc55
9 changed files with 395 additions and 0 deletions

View file

@ -258,4 +258,14 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="watsonx",
pip_packages=["ibm_watson_machine_learning"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
),
),
]

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import WatsonXConfig
async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
# import dynamically so `llama stack build` does not fail due to missing dependencies
from .watsonx import WatsonXInferenceAdapter
if not isinstance(config, WatsonXConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = WatsonXInferenceAdapter(config)
return adapter
__all__ = ["get_adapter_impl", "WatsonXConfig"]

View file

@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Optional, Dict, Any
from llama_stack.schema_utils import json_schema_type
from pydantic import BaseModel, Field
class WatsonXProviderDataValidator(BaseModel):
url: str
api_key: str
project_id: str
@json_schema_type
class WatsonXConfig(BaseModel):
url: str = Field(
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the Watsonx.ai",
)
api_key: Optional[str] = Field(
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
description="The Watsonx API key, only needed of using the hosted service",
)
project_id: Optional[str] = Field(
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
description="The Project ID key, only needed of using the hosted service",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
"api_key": "${env.WATSONX_API_KEY:}",
"project_id": "${env.WATSONX_PROJECT_ID:}"
}

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/llama-3-3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
)
]

View file

@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Union, AsyncIterator
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import Inference, Message, ToolChoice, ResponseFormat, LogProbConfig, ToolConfig, \
ChatCompletionResponse, ChatCompletionResponseStreamChunk, EmbeddingsResponse, TextTruncation, EmbeddingTaskType
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from . import WatsonXConfig
from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from .models import MODEL_ENTRIES
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: WatsonXConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
print(f"Initializing WatsonXInferenceAdapter({config.url})...")
self._config = config
self._credential = {
"url": self._config.url,
"apikey": self._config.api_key
}
self._project_id = self._config.project_id
self.params = {
GenParams.MAX_NEW_TOKENS: 4096,
GenParams.STOP_SEQUENCES: ["<|endoftext|>"]
}
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
):
pass
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
pass
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
):
# Language model
model = Model(
model_id=model_id,
credentials=self._credential,
project_id=self._project_id,
)
prompt = "\n".join(messages) + "\nAI: "
response = model.generate_text(prompt=prompt, params=self.params)
return response

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .watsonx import get_distribution_template # noqa: F401

View file

@ -0,0 +1,30 @@
version: '2'
distribution_spec:
description: Use WatsonX for running LLM inference
providers:
inference:
- remote::watsonx
vector_io:
- inline::faiss
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime
- remote::model-context-protocol
image_type: venv

View file

@ -0,0 +1,87 @@
version: '2'
image_name: watsonx
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: watsonx
provider_type: remote::watsonx
config:
url: ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}
api_key: ${env.WATSONX_API_KEY:}
project_id: ${env.WATSONX_PROJECT_ID:}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config: {}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llamastack-watsonx}/agents_store.db
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llamastack-watsonx}/faiss_store.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config: {}
- provider_id: localfs
provider_type: inline::localfs
config: {}
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
post_training:
- provider_id: torchtune
provider_type: inline::torchtune
config: {}
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/llamastack-watsonx/trace_store.db}
metadata_store: null
models:
- metadata: {}
model_id: meta-llama/llama-3-3-70b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-3-70b-instruct
model_type: llm
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.distribution.datatypes import Provider, ToolGroupInput
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::watsonx"],
"vector_io": ["inline::faiss"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
],
}
inference_provider = Provider(
provider_id="watsonx",
provider_type="remote::watsonx",
config=WatsonXConfig.sample_run_config(),
)
available_models = {
"watsonx": MODEL_ENTRIES,
}
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
default_models = get_model_registry(available_models)
return DistributionTemplate(
name="watsonx",
distro_type="remote_hosted",
description="Use WatsonX for running LLM inference",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
},
default_models=default_models,
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"LLAMASTACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"WATSONX_API_KEY": (
"",
"Watsonx API Key",
),
"WATSONX_PROJECT_ID": (
"",
"Watsonx Project ID",
),
},
)