From 6cf6791de1772fc44bc2192da0a3241babc8e60c Mon Sep 17 00:00:00 2001 From: Sajikumar JS <35679404+Sajikumarjs@users.noreply.github.com> Date: Sat, 26 Apr 2025 22:47:52 +0530 Subject: [PATCH] fix: updated watsonx inference chat apis with new repo changes (#2033) # What does this PR do? Recent changes in the repo require some additional functions to be added to the inference provider; this PR adds them. It also needs one additional parameter to pass some extra arguments to watsonx.ai. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) --------- Co-authored-by: Sajikumar JS --- .../remote/inference/watsonx/watsonx.py | 182 +++++++++++++++--- 1 file changed, 150 insertions(+), 32 deletions(-) diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index d5d87ec01..fa9cc4391 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -4,10 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import AsyncGenerator, List, Optional, Union +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from ibm_watson_machine_learning.foundation_models import Model from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams +from openai import AsyncOpenAI from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem from llama_stack.apis.inference import ( @@ -27,10 +28,21 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import ( + GreedySamplingStrategy, + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAICompletion, + OpenAIMessageParam, + OpenAIResponseFormatParam, + TopKSamplingStrategy, + TopPSamplingStrategy, +) from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, + prepare_openai_completion_params, process_chat_completion_response, process_chat_completion_stream_response, process_completion_response, @@ -95,6 +107,14 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): return Model(model_id=model_id, credentials=credentials, project_id=project_id) + def _get_openai_client(self) -> AsyncOpenAI: + if not self._openai_client: + self._openai_client = AsyncOpenAI( + base_url=f"{self._config.url}/openai/v1", + api_key=self._config.api_key, + ) + return self._openai_client + async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse: params = await self._get_params(request) r = self._get_client(request.model).generate(**params) @@ -213,36 +233,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens if request.sampling_params.repetition_penalty: 
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty - if request.sampling_params.additional_params.get("top_p"): - input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"] - if request.sampling_params.additional_params.get("top_k"): - input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"] - if request.sampling_params.additional_params.get("temperature"): - input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"] - if request.sampling_params.additional_params.get("length_penalty"): - input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[ - "length_penalty" - ] - if request.sampling_params.additional_params.get("random_seed"): - input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"] - if request.sampling_params.additional_params.get("min_new_tokens"): - input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[ - "min_new_tokens" - ] - if request.sampling_params.additional_params.get("stop_sequences"): - input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[ - "stop_sequences" - ] - if request.sampling_params.additional_params.get("time_limit"): - input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"] - if request.sampling_params.additional_params.get("truncate_input_tokens"): - input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[ - "truncate_input_tokens" - ] - if request.sampling_params.additional_params.get("return_options"): - input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[ - "return_options" - ] + + if isinstance(request.sampling_params.strategy, TopPSamplingStrategy): + input_dict["params"][GenParams.TOP_P] = 
request.sampling_params.strategy.top_p + input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature + if isinstance(request.sampling_params.strategy, TopKSamplingStrategy): + input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k + if isinstance(request.sampling_params.strategy, GreedySamplingStrategy): + input_dict["params"][GenParams.TEMPERATURE] = 0.0 + + input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"] params = { **input_dict, @@ -257,4 +257,122 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): output_dimension: Optional[int] = None, task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: - pass + raise NotImplementedError("embedding is not supported for watsonx") + + async def openai_completion( + self, + model: str, + prompt: Union[str, List[str], List[int], List[List[int]]], + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, + ) -> OpenAICompletion: + model_obj = await self.model_store.get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + 
temperature=temperature, + top_p=top_p, + user=user, + ) + return await self._get_openai_client().completions.create(**params) # type: ignore + + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[OpenAIResponseFormatParam] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + model_obj = await self.model_store.get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + if params.get("stream", False): + return 
self._stream_openai_chat_completion(params) + return await self._get_openai_client().chat.completions.create(**params) # type: ignore + + async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator: + # watsonx.ai sometimes adds usage data to the stream + include_usage = False + if params.get("stream_options", None): + include_usage = params["stream_options"].get("include_usage", False) + stream = await self._get_openai_client().chat.completions.create(**params) + + seen_finish_reason = False + async for chunk in stream: + # Final usage chunk with no choices that the user didn't request, so discard + if not include_usage and seen_finish_reason and len(chunk.choices) == 0: + break + yield chunk + for choice in chunk.choices: + if choice.finish_reason: + seen_finish_reason = True + break