Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-07 02:58:21 +00:00)
add WatsonX inference adapter
Commit 44c51efc55 (parent a9c5d3cd3d)
9 changed files with 395 additions and 0 deletions
@@ -258,4 +258,14 @@ def available_providers() -> List[ProviderSpec]:
             provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
         ),
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="watsonx",
+            pip_packages=["ibm_watson_machine_learning"],
+            module="llama_stack.providers.remote.inference.watsonx",
+            config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
+        ),
+    ),
 ]
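The new registry entry is made up entirely of dotted-path strings (`module`, `config_class`, `provider_data_validator`). A minimal sketch, assuming nothing about the stack's actual loader, of how such strings can be turned into live objects at runtime:

import importlib

# Illustration only: resolve a dotted path such as a registry entry's
# config_class into the object it names.
def load_symbol(dotted_path: str):
    module_path, _, name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), name)

# e.g. config_cls = load_symbol(
#     "llama_stack.providers.remote.inference.watsonx.WatsonXConfig")

The `pip_packages` list lets the build tooling install the IBM SDK only for distributions that actually select this provider.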

llama_stack/providers/remote/inference/watsonx/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import Inference

from .config import WatsonXConfig


async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
    # Import dynamically so `llama stack build` does not fail due to missing dependencies.
    from .watsonx import WatsonXInferenceAdapter

    if not isinstance(config, WatsonXConfig):
        raise RuntimeError(f"Unexpected config type: {type(config)}")
    adapter = WatsonXInferenceAdapter(config)
    return adapter


__all__ = ["get_adapter_impl", "WatsonXConfig"]
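Note that `WatsonXInferenceAdapter` is imported inside the function body rather than at module top level. A minimal sketch of the effect, with a hypothetical `heavy_sdk` standing in for the real dependency:

def get_impl(config):
    # Deferred import: `heavy_sdk` (hypothetical) is only required when
    # get_impl is actually called, so merely importing this package (as
    # `llama stack build` does) works without the SDK installed.
    import heavy_sdk

    return heavy_sdk.Client(config)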

llama_stack/providers/remote/inference/watsonx/config.py (new file, 46 lines)
@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class WatsonXProviderDataValidator(BaseModel):
    url: str
    api_key: str
    project_id: str


@json_schema_type
class WatsonXConfig(BaseModel):
    url: str = Field(
        default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
        description="The base URL for accessing the watsonx.ai service",
    )
    api_key: Optional[str] = Field(
        default_factory=lambda: os.getenv("WATSONX_API_KEY"),
        description="The watsonx.ai API key, only needed if using the hosted service",
    )
    project_id: Optional[str] = Field(
        default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
        description="The watsonx.ai project ID, only needed if using the hosted service",
    )
    timeout: int = Field(
        default=60,
        description="Timeout (in seconds) for HTTP requests",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
            "api_key": "${env.WATSONX_API_KEY:}",
            "project_id": "${env.WATSONX_PROJECT_ID:}",
        }
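Because every credential field uses a `default_factory` that reads the environment at instantiation time, the config can be constructed with no arguments. A small usage sketch; the credential values are placeholders:

import os

from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig

# Placeholder values for illustration only.
os.environ["WATSONX_API_KEY"] = "my-api-key"
os.environ["WATSONX_PROJECT_ID"] = "my-project-id"

config = WatsonXConfig()  # each default_factory reads WATSONX_* here
print(config.url)  # defaults to https://us-south.ml.cloud.ibm.com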

llama_stack/providers/remote/inference/watsonx/models.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry

MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "meta-llama/llama-3-3-70b-instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    )
]
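`build_hf_repo_model_entry` pairs the provider-side identifier with the canonical `CoreModelId`, so a request for either name resolves to the same model. Conceptually (a plain-dict illustration, not the registry's real implementation, and assuming the descriptor value is `Llama3.3-70B-Instruct`):

# Illustration only: what registering the entry above achieves.
ALIASES = {
    "Llama3.3-70B-Instruct": "meta-llama/llama-3-3-70b-instruct",
}

def resolve_provider_model_id(model_id: str) -> str:
    # Unregistered identifiers pass through unchanged.
    return ALIASES.get(model_id, model_id)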

llama_stack/providers/remote/inference/watsonx/watsonx.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncIterator, List, Optional, Union

from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams

from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    TextTruncation,
    ToolChoice,
    ToolConfig,
)
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

from . import WatsonXConfig
from .models import MODEL_ENTRIES


class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
    def __init__(self, config: WatsonXConfig) -> None:
        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)

        print(f"Initializing WatsonXInferenceAdapter({config.url})...")

        self._config = config
        self._credential = {
            "url": self._config.url,
            "apikey": self._config.api_key,
        }

        self._project_id = self._config.project_id
        self.params = {
            GenParams.MAX_NEW_TOKENS: 4096,
            GenParams.STOP_SEQUENCES: ["<|endoftext|>"],
        }

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ):
        # Not implemented for watsonx yet.
        pass

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        # Not implemented for watsonx yet.
        pass

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ):
        # Instantiate the watsonx foundation-model client for the requested model.
        model = Model(
            model_id=model_id,
            credentials=self._credential,
            project_id=self._project_id,
        )
        # Flatten the chat history into a plain-text prompt; join the message
        # contents (Message objects themselves are not strings). Assumes
        # plain-text content.
        prompt = "\n".join(str(m.content) for m in messages) + "\nAI: "

        response = model.generate_text(prompt=prompt, params=self.params)

        return response
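For reference, the same call path through the IBM SDK directly, outside the adapter. A sketch with placeholder credentials:

from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams

# Placeholder credentials for illustration only.
credentials = {"url": "https://us-south.ml.cloud.ibm.com", "apikey": "my-api-key"}

model = Model(
    model_id="meta-llama/llama-3-3-70b-instruct",
    credentials=credentials,
    project_id="my-project-id",
)
params = {GenParams.MAX_NEW_TOKENS: 128, GenParams.STOP_SEQUENCES: ["<|endoftext|>"]}
print(model.generate_text(prompt="User: Hello\nAI: ", params=params))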

llama_stack/templates/watsonx/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .watsonx import get_distribution_template  # noqa: F401

llama_stack/templates/watsonx/build.yaml (new file, 30 lines)
@@ -0,0 +1,30 @@
version: '2'
distribution_spec:
  description: Use WatsonX for running LLM inference
  providers:
    inference:
    - remote::watsonx
    vector_io:
    - inline::faiss
    safety:
    - inline::llama-guard
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
    eval:
    - inline::meta-reference
    datasetio:
    - remote::huggingface
    - inline::localfs
    scoring:
    - inline::basic
    - inline::llm-as-judge
    - inline::braintrust
    tool_runtime:
    - remote::brave-search
    - remote::tavily-search
    - inline::code-interpreter
    - inline::rag-runtime
    - remote::model-context-protocol
image_type: venv
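With the standard llama-stack workflow, this template can typically be built with `llama stack build --template watsonx --image-type venv`, which pulls in the `ibm_watson_machine_learning` dependency declared by the provider spec.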

llama_stack/templates/watsonx/run.yaml (new file, 87 lines)
@@ -0,0 +1,87 @@
version: '2'
image_name: watsonx
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: watsonx
    provider_type: remote::watsonx
    config:
      url: ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}
      api_key: ${env.WATSONX_API_KEY:}
      project_id: ${env.WATSONX_PROJECT_ID:}
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llamastack-watsonx}/agents_store.db
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llamastack-watsonx}/faiss_store.db
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config: {}
  - provider_id: localfs
    provider_type: inline::localfs
    config: {}
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
  post_training:
  - provider_id: torchtune
    provider_type: inline::torchtune
    config: {}
  tool_runtime:
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/llamastack-watsonx/trace_store.db}
metadata_store: null
models:
- metadata: {}
  model_id: meta-llama/llama-3-3-70b-instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-3-70b-instruct
  model_type: llm
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
server:
  port: 8321
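The `${env.VAR:default}` placeholders are substituted from the environment when the stack loads this file. A simplified sketch of the semantics, not the stack's actual implementation:

import os
import re

def substitute_env(value: str) -> str:
    # Replace ${env.VAR:default} with os.environ["VAR"], else the default.
    pattern = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):?([^}]*)\}")
    return pattern.sub(lambda m: os.getenv(m.group(1)) or m.group(2), value)

print(substitute_env("${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}"))
# -> https://us-south.ml.cloud.ibm.com when WATSONX_BASE_URL is unset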

llama_stack/templates/watsonx/watsonx.py (new file, 90 lines)
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

from llama_stack.distribution.datatypes import Provider, ToolGroupInput
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry


def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::watsonx"],
        "vector_io": ["inline::faiss"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
        "eval": ["inline::meta-reference"],
        "datasetio": ["remote::huggingface", "inline::localfs"],
        "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
        "tool_runtime": [
            "remote::brave-search",
            "remote::tavily-search",
            "inline::code-interpreter",
            "inline::rag-runtime",
            "remote::model-context-protocol",
        ],
    }

    inference_provider = Provider(
        provider_id="watsonx",
        provider_type="remote::watsonx",
        config=WatsonXConfig.sample_run_config(),
    )

    available_models = {
        "watsonx": MODEL_ENTRIES,
    }
    default_tool_groups = [
        ToolGroupInput(
            toolgroup_id="builtin::websearch",
            provider_id="tavily-search",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::rag",
            provider_id="rag-runtime",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::code_interpreter",
            provider_id="code-interpreter",
        ),
    ]

    default_models = get_model_registry(available_models)
    return DistributionTemplate(
        name="watsonx",
        distro_type="remote_hosted",
        description="Use WatsonX for running LLM inference",
        container_image=None,
        template_path=Path(__file__).parent / "doc_template.md",
        providers=providers,
        available_models_by_provider=available_models,
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": [inference_provider],
                },
                default_models=default_models,
                default_tool_groups=default_tool_groups,
            ),
        },
        run_config_env_vars={
            "LLAMASTACK_PORT": (
                "5001",
                "Port for the Llama Stack distribution server",
            ),
            "WATSONX_API_KEY": (
                "",
                "Watsonx API Key",
            ),
            "WATSONX_PROJECT_ID": (
                "",
                "Watsonx Project ID",
            ),
        },
    )
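A quick sanity check of the template entry point; attribute names here mirror the constructor arguments above (a sketch, assuming `DistributionTemplate` exposes them as attributes):

from llama_stack.templates.watsonx import get_distribution_template

template = get_distribution_template()
print(template.name)  # watsonx
print(template.description)  # Use WatsonX for running LLM inference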