diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 4c5393947..dcb3ef945 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4191,6 +4191,12 @@
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence."
+ },
+ "additional_params": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ }
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index a24f1a9db..c0a704230 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2907,6 +2907,10 @@ components:
description: >-
Up to 4 sequences where the API will stop generating further tokens. The
returned text will not contain the stop sequence.
+ additional_params:
+ type: object
+ additionalProperties:
+ type: string
additionalProperties: false
required:
- strategy
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 309171f20..cb72c3b76 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -82,6 +82,7 @@ class SamplingParams(BaseModel):
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
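+    # Additional provider-specific sampling parameters, expressed as string key/value pairs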
+    additional_params: Optional[Dict[str, str]] = Field(default_factory=dict)
class LogProbConfig(BaseModel):
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index d5d87ec01..63484c888 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -4,10 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
+from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
@@ -27,10 +28,18 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
+from llama_stack.apis.inference.inference import (
+ OpenAIChatCompletion,
+ OpenAIChatCompletionChunk,
+ OpenAICompletion,
+ OpenAIMessageParam,
+ OpenAIResponseFormatParam,
+)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
+ prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
@@ -95,6 +104,14 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
+ def _get_openai_client(self) -> AsyncOpenAI:
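+        # Lazily create and cache an AsyncOpenAI client pointed at the
+        # OpenAI-compatible endpoint exposed under the watsonx.ai base URL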
+ if not self._openai_client:
+ self._openai_client = AsyncOpenAI(
+ base_url=f"{self._config.url}/openai/v1",
+ api_key=self._config.api_key,
+ )
+ return self._openai_client
+
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client(request.model).generate(**params)
@@ -257,4 +274,122 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
- pass
+ raise NotImplementedError("embedding is not supported for watsonx")
+
+ async def openai_completion(
+ self,
+ model: str,
+ prompt: Union[str, List[str], List[int], List[List[int]]],
+ best_of: Optional[int] = None,
+ echo: Optional[bool] = None,
+ frequency_penalty: Optional[float] = None,
+ logit_bias: Optional[Dict[str, float]] = None,
+ logprobs: Optional[bool] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[float] = None,
+ seed: Optional[int] = None,
+ stop: Optional[Union[str, List[str]]] = None,
+ stream: Optional[bool] = None,
+ stream_options: Optional[Dict[str, Any]] = None,
+ temperature: Optional[float] = None,
+ top_p: Optional[float] = None,
+ user: Optional[str] = None,
+ guided_choice: Optional[List[str]] = None,
+ prompt_logprobs: Optional[int] = None,
+ ) -> OpenAICompletion:
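+        # Resolve the provider's model id and normalize the arguments into
+        # keyword arguments accepted by the OpenAI client, dropping unset values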
+ model_obj = await self.model_store.get_model(model)
+ params = await prepare_openai_completion_params(
+ model=model_obj.provider_resource_id,
+ prompt=prompt,
+ best_of=best_of,
+ echo=echo,
+ frequency_penalty=frequency_penalty,
+ logit_bias=logit_bias,
+ logprobs=logprobs,
+ max_tokens=max_tokens,
+ n=n,
+ presence_penalty=presence_penalty,
+ seed=seed,
+ stop=stop,
+ stream=stream,
+ stream_options=stream_options,
+ temperature=temperature,
+ top_p=top_p,
+ user=user,
+ )
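+        # Forward the request to the watsonx.ai OpenAI-compatible completions endpoint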
+ return await self._get_openai_client().completions.create(**params) # type: ignore
+
+ async def openai_chat_completion(
+ self,
+ model: str,
+ messages: List[OpenAIMessageParam],
+ frequency_penalty: Optional[float] = None,
+ function_call: Optional[Union[str, Dict[str, Any]]] = None,
+ functions: Optional[List[Dict[str, Any]]] = None,
+ logit_bias: Optional[Dict[str, float]] = None,
+ logprobs: Optional[bool] = None,
+ max_completion_tokens: Optional[int] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ parallel_tool_calls: Optional[bool] = None,
+ presence_penalty: Optional[float] = None,
+ response_format: Optional[OpenAIResponseFormatParam] = None,
+ seed: Optional[int] = None,
+ stop: Optional[Union[str, List[str]]] = None,
+ stream: Optional[bool] = None,
+ stream_options: Optional[Dict[str, Any]] = None,
+ temperature: Optional[float] = None,
+ tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+ tools: Optional[List[Dict[str, Any]]] = None,
+ top_logprobs: Optional[int] = None,
+ top_p: Optional[float] = None,
+ user: Optional[str] = None,
+ ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+ model_obj = await self.model_store.get_model(model)
+ params = await prepare_openai_completion_params(
+ model=model_obj.provider_resource_id,
+ messages=messages,
+ frequency_penalty=frequency_penalty,
+ function_call=function_call,
+ functions=functions,
+ logit_bias=logit_bias,
+ logprobs=logprobs,
+ max_completion_tokens=max_completion_tokens,
+ max_tokens=max_tokens,
+ n=n,
+ parallel_tool_calls=parallel_tool_calls,
+ presence_penalty=presence_penalty,
+ response_format=response_format,
+ seed=seed,
+ stop=stop,
+ stream=stream,
+ stream_options=stream_options,
+ temperature=temperature,
+ tool_choice=tool_choice,
+ tools=tools,
+ top_logprobs=top_logprobs,
+ top_p=top_p,
+ user=user,
+ )
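+        # For streaming requests, delegate to a wrapper that filters out the
+        # trailing usage-only chunk when the caller did not ask for usage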
+ if params.get("stream", False):
+ return self._stream_openai_chat_completion(params)
+ return await self._get_openai_client().chat.completions.create(**params) # type: ignore
+
+ async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
+        # watsonx.ai may append a final usage chunk to the stream even when it was not requested
+ include_usage = False
+ if params.get("stream_options", None):
+ include_usage = params["stream_options"].get("include_usage", False)
+ stream = await self._get_openai_client().chat.completions.create(**params)
+
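+        # Track when a finish_reason has been seen so the trailing usage-only chunk can be recognized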
+ seen_finish_reason = False
+ async for chunk in stream:
+            # After a finish_reason, a chunk with no choices is the unrequested usage chunk; drop it and stop
+ if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
+ break
+ yield chunk
+ for choice in chunk.choices:
+ if choice.finish_reason:
+ seen_finish_reason = True
+ break