fix: update watsonx inference chat APIs with new repo changes (#2033)

# What does this PR do?
Recent changes in the repository require some additional functions in the watsonx inference provider; this PR adds them. It also adds one extra parameter so that additional arguments can be passed to watsonx.ai.
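For context, the new methods delegate to an OpenAI-compatible client pointed at `{url}/openai/v1` on the watsonx.ai endpoint. A minimal sketch of that call path is below; the URL, API key, and model id are placeholders, not part of this change:

```python
# Sketch only: mirrors how the adapter's new chat API delegates to an
# OpenAI-compatible client. Endpoint URL, API key, and model id are placeholders.
import asyncio
from openai import AsyncOpenAI

async def main() -> None:
    client = AsyncOpenAI(
        base_url="https://us-south.ml.cloud.ibm.com/openai/v1",  # hypothetical watsonx.ai endpoint
        api_key="YOUR_WATSONX_API_KEY",
    )
    response = await client.chat.completions.create(
        model="meta-llama/llama-3-3-70b-instruct",  # placeholder model id
        messages=[{"role": "user", "content": "Hello, watsonx!"}],
        temperature=0.2,
    )
    print(response.choices[0].message.content)

asyncio.run(main())
```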

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]

[//]: # (## Documentation)

---------

Co-authored-by: Sajikumar JS <sajikumar.js@ibm.com>
@@ -4,10 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
@@ -27,10 +28,21 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
GreedySamplingStrategy,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
@@ -95,6 +107,14 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client:
self._openai_client = AsyncOpenAI(
base_url=f"{self._config.url}/openai/v1",
api_key=self._config.api_key,
)
return self._openai_client
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client(request.model).generate(**params)
@@ -213,36 +233,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
if request.sampling_params.repetition_penalty:
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
if request.sampling_params.additional_params.get("top_p"):
input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"]
if request.sampling_params.additional_params.get("top_k"):
input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"]
if request.sampling_params.additional_params.get("temperature"):
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
if request.sampling_params.additional_params.get("length_penalty"):
input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[
"length_penalty"
]
if request.sampling_params.additional_params.get("random_seed"):
input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"]
if request.sampling_params.additional_params.get("min_new_tokens"):
input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[
"min_new_tokens"
]
if request.sampling_params.additional_params.get("stop_sequences"):
input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[
"stop_sequences"
]
if request.sampling_params.additional_params.get("time_limit"):
input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"]
if request.sampling_params.additional_params.get("truncate_input_tokens"):
input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[
"truncate_input_tokens"
]
if request.sampling_params.additional_params.get("return_options"):
input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[
"return_options"
]
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
input_dict["params"][GenParams.TEMPERATURE] = 0.0
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
params = {
**input_dict,
@@ -257,4 +257,122 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
pass
raise NotImplementedError("embedding is not supported for watsonx")
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params) # type: ignore
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
if params.get("stream", False):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
# watsonx.ai sometimes adds usage data to the stream
include_usage = False
if params.get("stream_options", None):
include_usage = params["stream_options"].get("include_usage", False)
stream = await self._get_openai_client().chat.completions.create(**params)
seen_finish_reason = False
async for chunk in stream:
# Final usage chunk with no choices that the user didn't request, so discard
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
break
yield chunk
for choice in chunk.choices:
if choice.finish_reason:
seen_finish_reason = True
break
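
For reference, a minimal sketch of the streaming behavior that `_stream_openai_chat_completion` filters for: unless the caller sets `stream_options={"include_usage": True}`, the trailing usage-only chunk emitted by watsonx.ai is discarded. The endpoint URL, API key, and model id below are placeholders.

```python
# Sketch only: consumes an OpenAI-compatible stream the way the new adapter
# method does. Endpoint URL, API key, and model id are placeholders.
import asyncio
from openai import AsyncOpenAI

async def main() -> None:
    client = AsyncOpenAI(
        base_url="https://us-south.ml.cloud.ibm.com/openai/v1",  # hypothetical watsonx.ai endpoint
        api_key="YOUR_WATSONX_API_KEY",
    )
    stream = await client.chat.completions.create(
        model="meta-llama/llama-3-3-70b-instruct",  # placeholder model id
        messages=[{"role": "user", "content": "Stream a short haiku."}],
        stream=True,
        # Ask for the trailing usage chunk; without this option the adapter drops it.
        stream_options={"include_usage": True},
    )
    async for chunk in stream:
        if chunk.choices:
            print(chunk.choices[0].delta.content or "", end="")
        elif chunk.usage:
            print(f"\n[usage] total_tokens={chunk.usage.total_tokens}")

asyncio.run(main())
```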