From 1d6ef73dd74c90b393ea90d2dab48a2ce2b8a06e Mon Sep 17 00:00:00 2001
From: Sajikumar JS
Date: Sat, 26 Apr 2025 01:09:46 +0530
Subject: [PATCH] Add additional params and new functions required for watsonx

---
 docs/_static/llama-stack-spec.html      |   6 +
 docs/_static/llama-stack-spec.yaml      |   4 +
 llama_stack/apis/inference/inference.py |   1 +
 .../remote/inference/watsonx/watsonx.py | 139 +++++++++++++++++-
 4 files changed, 148 insertions(+), 2 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 4c5393947..dcb3ef945 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4191,6 +4191,12 @@
                             "type": "string"
                         },
                         "description": "Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence."
+                    },
+                    "additional_params": {
+                        "type": "object",
+                        "additionalProperties": {
+                            "type": "string"
+                        }
                     }
                 },
                 "additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index a24f1a9db..c0a704230 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2907,6 +2907,10 @@ components:
           description: >-
             Up to 4 sequences where the API will stop generating further tokens.
             The returned text will not contain the stop sequence.
+        additional_params:
+          type: object
+          additionalProperties:
+            type: string
       additionalProperties: false
       required:
         - strategy
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 309171f20..cb72c3b76 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -82,6 +82,7 @@ class SamplingParams(BaseModel):
     max_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop: Optional[List[str]] = None
+    additional_params: Optional[Dict[str, str]] = {}
 
 
 class LogProbConfig(BaseModel):
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index d5d87ec01..63484c888 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -4,10 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from ibm_watson_machine_learning.foundation_models import Model
 from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
+from openai import AsyncOpenAI
 
 from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 
@@ -27,10 +28,18 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    prepare_openai_completion_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
@@ -95,6 +104,14 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
 
         return Model(model_id=model_id, credentials=credentials, project_id=project_id)
 
+    def _get_openai_client(self) -> AsyncOpenAI:
+        if not self._openai_client:
+            self._openai_client = AsyncOpenAI(
+                base_url=f"{self._config.url}/openai/v1",
+                api_key=self._config.api_key,
+            )
+        return self._openai_client
+
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = self._get_client(request.model).generate(**params)
@@ -257,4 +274,122 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        pass
+        raise NotImplementedError("embedding is not supported for watsonx")
+
+    async def openai_completion(
+        self,
+        model: str,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        best_of: Optional[int] = None,
+        echo: Optional[bool] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
+    ) -> OpenAICompletion:
+        model_obj = await self.model_store.get_model(model)
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            prompt=prompt,
+            best_of=best_of,
+            echo=echo,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            top_p=top_p,
+            user=user,
+        )
+        return await self._get_openai_client().completions.create(**params)  # type: ignore
+
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: List[OpenAIMessageParam],
+        frequency_penalty: Optional[float] = None,
+        function_call: Optional[Union[str, Dict[str, Any]]] = None,
+        functions: Optional[List[Dict[str, Any]]] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        parallel_tool_calls: Optional[bool] = None,
+        presence_penalty: Optional[float] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        top_logprobs: Optional[int] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        model_obj = await self.model_store.get_model(model)
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            n=n,
+            parallel_tool_calls=parallel_tool_calls,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
+        if params.get("stream", False):
+            return self._stream_openai_chat_completion(params)
+        return await self._get_openai_client().chat.completions.create(**params)  # type: ignore
+
+    async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
+        # watsonx.ai sometimes adds usage data to the stream
+        include_usage = False
+        if params.get("stream_options", None):
+            include_usage = params["stream_options"].get("include_usage", False)
+        stream = await self._get_openai_client().chat.completions.create(**params)
+
+        seen_finish_reason = False
+        async for chunk in stream:
+            # Final usage chunk with no choices that the user didn't request, so discard
+            if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
+                break
+            yield chunk
+            for choice in chunk.choices:
+                if choice.finish_reason:
+                    seen_finish_reason = True
+                    break
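
Usage sketch (illustrative only, not part of the patch): the adapter above proxies requests to watsonx.ai's OpenAI-compatible endpoint at "{url}/openai/v1", the same base URL that _get_openai_client() constructs. The snippet below shows how a streaming chat completion against that endpoint could look from a plain openai-python client; the base URL, API key, and model id are placeholders, not values taken from the patch.

# Illustrative sketch -- assumes a reachable watsonx.ai deployment.
# The base URL, API key, and model id below are placeholders.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # Mirrors the client shape built in _get_openai_client(): "<watsonx url>/openai/v1".
    client = AsyncOpenAI(
        base_url="https://us-south.ml.cloud.ibm.com/openai/v1",  # placeholder URL
        api_key="YOUR_WATSONX_API_KEY",  # placeholder key
    )
    stream = await client.chat.completions.create(
        model="meta-llama/llama-3-3-70b-instruct",  # placeholder model id
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        stream=True,
    )
    async for chunk in stream:
        # Print streamed deltas as they arrive.
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())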