diff --git a/docs/source/providers/inference/remote_watsonx.md b/docs/source/providers/inference/remote_watsonx.md
index 0eb8a6fc4..e885a07fc 100644
--- a/docs/source/providers/inference/remote_watsonx.md
+++ b/docs/source/providers/inference/remote_watsonx.md
@@ -9,8 +9,8 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `url` | `` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key, only needed of using the hosted service |
-| `project_id` | `str \| None` | No | | The Project ID key, only needed of using the hosted service |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key |
+| `project_id` | `str \| None` | No | | The Project ID key |
 | `timeout` | `` | No | 60 | Timeout for the HTTP requests |
 
 ## Sample Configuration
diff --git a/llama_stack/distributions/watsonx/run.yaml b/llama_stack/distributions/watsonx/run.yaml
index f5fe31bef..92f367910 100644
--- a/llama_stack/distributions/watsonx/run.yaml
+++ b/llama_stack/distributions/watsonx/run.yaml
@@ -10,6 +10,7 @@ apis:
 - telemetry
 - tool_runtime
 - vector_io
+- files
 providers:
   inference:
   - provider_id: watsonx
@@ -94,6 +95,14 @@ providers:
     provider_type: inline::rag-runtime
   - provider_id: model-context-protocol
     provider_type: remote::model-context-protocol
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/watsonx/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/files_metadata.db
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/registry.db
diff --git a/llama_stack/distributions/watsonx/watsonx.py b/llama_stack/distributions/watsonx/watsonx.py
index 1ef2ef339..c3cab5d1b 100644
--- a/llama_stack/distributions/watsonx/watsonx.py
+++ b/llama_stack/distributions/watsonx/watsonx.py
@@ -9,6 +9,7 @@ from pathlib import Path
 from llama_stack.apis.models import ModelType
 from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
 from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
+from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
@@ -16,7 +17,7 @@ from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
 from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
 
 
-def get_distribution_template() -> DistributionTemplate:
+def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
     providers = {
         "inference": [
             BuildProvider(provider_type="remote::watsonx"),
@@ -42,6 +43,7 @@ def get_distribution_template() -> DistributionTemplate:
             BuildProvider(provider_type="inline::rag-runtime"),
             BuildProvider(provider_type="remote::model-context-protocol"),
         ],
+        "files": [BuildProvider(provider_type="inline::localfs")],
     }
 
     inference_provider = Provider(
@@ -79,9 +81,14 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
+    files_provider = Provider(
+        provider_id="meta-reference-files",
+        provider_type="inline::localfs",
+        config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+    )
     default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
-        name="watsonx",
+        name=name,
         distro_type="remote_hosted",
         description="Use watsonx for running LLM inference",
         container_image=None,
@@ -92,6 +99,7 @@ def get_distribution_template() -> DistributionTemplate:
             "run.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [inference_provider, embedding_provider],
+                    "files": [files_provider],
                 },
                 default_models=default_models + [embedding_model],
                 default_tool_groups=default_tool_groups,
diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py
index ae4bd55c1..42c25d93e 100644
--- a/llama_stack/providers/remote/inference/watsonx/config.py
+++ b/llama_stack/providers/remote/inference/watsonx/config.py
@@ -26,11 +26,11 @@ class WatsonXConfig(BaseModel):
     )
     api_key: SecretStr | None = Field(
         default_factory=lambda: os.getenv("WATSONX_API_KEY"),
-        description="The watsonx API key, only needed of using the hosted service",
+        description="The watsonx API key",
     )
     project_id: str | None = Field(
         default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
-        description="The Project ID key, only needed of using the hosted service",
+        description="The Project ID key",
     )
     timeout: int = Field(
         default=60,
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index cb7fc175f..ab5ca76db 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
@@ -57,14 +58,29 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from . import WatsonXConfig
 from .models import MODEL_ENTRIES
 
+logger = get_logger(name=__name__, category="inference::watsonx")
+
+
+# Note on structured output
+# WatsonX returns responses with a json embedded into a string.
+# Examples:
+
+# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
+# "first_name": "Michael",\n "last_name": "Jordan",\n'...)
+# Not even a valid JSON, but we can still extract the JSON from the content
+
+# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
+# "year_born": "1963", "year_retired": "2003"\\}}$')
+# Find the start of the boxed content
+
 
 class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
     def __init__(self, config: WatsonXConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
-        print(f"Initializing watsonx InferenceAdapter({config.url})...")
-
+        logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
         self._config = config
+        self._openai_client: AsyncOpenAI | None = None
 
         self._project_id = self._config.project_id
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index 35869276b..e7db6ef8f 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -58,6 +58,7 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         # does not work with the specified model, gpt-5-mini. Please choose different model and try
         # again. You can learn more about which models can be used with each operation here:
         # https://go.microsoft.com/fwlink/?linkid=2197993.'}}"}
+        "remote::watsonx",  # return 404 when hitting the /openai/v1 endpoint
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
@@ -110,6 +111,8 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode
         "remote::cerebras",
         "remote::databricks",
         "remote::runpod",
+        "remote::tgi",
+        "remote::watsonx",  # watsonx returns 404 when hitting the /openai/v1 endpoint
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")
diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py
index 621084231..a5f95a963 100644
--- a/tests/integration/inference/test_text_inference.py
+++ b/tests/integration/inference/test_text_inference.py
@@ -45,7 +45,7 @@ def skip_if_model_doesnt_support_json_schema_structured_output(client_with_model
     provider_id = models[model_id].provider_id
     providers = {p.provider_id: p for p in client_with_models.providers.list()}
     provider = providers[provider_id]
-    if provider.provider_type in ("remote::sambanova", "remote::azure"):
+    if provider.provider_type in ("remote::sambanova", "remote::azure", "remote::watsonx"):
         pytest.skip(
             f"Model {model_id} hosted by {provider.provider_type} doesn't support json_schema structured output"
         )
@@ -211,6 +211,7 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id,
 )
 def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
+    skip_if_model_doesnt_support_json_schema_structured_output(client_with_models, text_model_id)
 
     class AnswerFormat(BaseModel):
         name: str
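Note on the structured-output comments added to watsonx.py above: the two example payloads show the JSON arriving either inside a markdown code fence or inside a LaTeX \boxed{...} expression. The sketch below is a minimal illustration of how such content could be unwrapped; it is not code from this patch, and the helper name, regex, and brace-unescaping are assumptions for illustration only.

```python
import json
import re


# Hypothetical helper, not part of this change: extract the JSON object embedded
# in the two watsonx response shapes described in the comments above.
def extract_embedded_json(content: str) -> dict | None:
    # Shape 1: JSON wrapped in a markdown code fence, e.g. '```json\n{...}\n```'.
    fenced = re.search(r"```(?:json)?\s*(\{.*\})\s*```", content, re.DOTALL)
    if fenced:
        candidate = fenced.group(1)
    else:
        # Shape 2: JSON inside a LaTeX \boxed{...}; find the start of the boxed
        # content, take everything up to the closing brace, and unescape the
        # braces that wrap the JSON object itself.
        start = content.find("\\boxed{")
        if start == -1:
            return None
        candidate = content[start + len("\\boxed{") : content.rfind("}")]
        candidate = candidate.replace("\\{", "{").replace("\\}", "}")
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        # Truncated or otherwise malformed JSON, as in the first example above.
        return None
```

For instance, applied to the boxed CompletionResponse example quoted in the comments, this would recover {"name": "Michael Jordan", "year_born": "1963", "year_retired": "2003"}; a truncated fenced payload would simply return None.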