diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index d5d87ec01..fa9cc4391 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -4,10 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from ibm_watson_machine_learning.foundation_models import Model
 from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
+from openai import AsyncOpenAI
 
 from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
@@ -27,10 +28,21 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.apis.inference.inference import (
+    GreedySamplingStrategy,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    prepare_openai_completion_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
@@ -95,6 +107,14 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
 
         return Model(model_id=model_id, credentials=credentials, project_id=project_id)
 
+    def _get_openai_client(self) -> AsyncOpenAI:
+        if not self._openai_client:
+            self._openai_client = AsyncOpenAI(
+                base_url=f"{self._config.url}/openai/v1",
+                api_key=self._config.api_key,
+            )
+        return self._openai_client
+
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = self._get_client(request.model).generate(**params)
@@ -213,36 +233,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
             input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
         if request.sampling_params.repetition_penalty:
             input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
-        if request.sampling_params.additional_params.get("top_p"):
-            input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"]
-        if request.sampling_params.additional_params.get("top_k"):
-            input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"]
-        if request.sampling_params.additional_params.get("temperature"):
-            input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
-        if request.sampling_params.additional_params.get("length_penalty"):
-            input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[
-                "length_penalty"
-            ]
-        if request.sampling_params.additional_params.get("random_seed"):
-            input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"]
-        if request.sampling_params.additional_params.get("min_new_tokens"):
-            input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[
-                "min_new_tokens"
-            ]
-        if request.sampling_params.additional_params.get("stop_sequences"):
-            input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[
-                "stop_sequences"
-            ]
-        if request.sampling_params.additional_params.get("time_limit"):
-            input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"]
-        if request.sampling_params.additional_params.get("truncate_input_tokens"):
-            input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[
-                "truncate_input_tokens"
-            ]
-        if request.sampling_params.additional_params.get("return_options"):
-            input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[
-                "return_options"
-            ]
+
+        if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
+            input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
+            input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
+        if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
+            input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
+        if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
+            input_dict["params"][GenParams.TEMPERATURE] = 0.0
+
+        input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
 
         params = {
             **input_dict,
@@ -257,4 +257,122 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        pass
+        raise NotImplementedError("embedding is not supported for watsonx")
+
+    async def openai_completion(
+        self,
+        model: str,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        best_of: Optional[int] = None,
+        echo: Optional[bool] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
+    ) -> OpenAICompletion:
+        model_obj = await self.model_store.get_model(model)
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            prompt=prompt,
+            best_of=best_of,
+            echo=echo,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            top_p=top_p,
+            user=user,
+        )
+        return await self._get_openai_client().completions.create(**params)  # type: ignore
+
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: List[OpenAIMessageParam],
+        frequency_penalty: Optional[float] = None,
+        function_call: Optional[Union[str, Dict[str, Any]]] = None,
+        functions: Optional[List[Dict[str, Any]]] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        parallel_tool_calls: Optional[bool] = None,
+        presence_penalty: Optional[float] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        top_logprobs: Optional[int] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        model_obj = await self.model_store.get_model(model)
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            n=n,
+            parallel_tool_calls=parallel_tool_calls,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
+        if params.get("stream", False):
+            return self._stream_openai_chat_completion(params)
+        return await self._get_openai_client().chat.completions.create(**params)  # type: ignore
+
+    async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
+        # watsonx.ai sometimes adds usage data to the stream
+        include_usage = False
+        if params.get("stream_options", None):
+            include_usage = params["stream_options"].get("include_usage", False)
+        stream = await self._get_openai_client().chat.completions.create(**params)
+
+        seen_finish_reason = False
+        async for chunk in stream:
+            # Final usage chunk with no choices that the user didn't request, so discard
+            if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
+                break
+            yield chunk
+            for choice in chunk.choices:
+                if choice.finish_reason:
+                    seen_finish_reason = True
+                    break
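Usage sketch (illustrative only, not part of the patch): a minimal example of how the OpenAI-compatible chat path added above might be exercised. The adapter instance, the model id, and the plain-dict messages are placeholders assumed for demonstration; they are not defined in this diff.

async def demo(adapter) -> None:
    # Non-streaming call: awaiting returns a single OpenAIChatCompletion.
    completion = await adapter.openai_chat_completion(
        model="example/watsonx-model",  # placeholder: any model registered with the adapter
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        temperature=0.7,
    )
    print(completion.choices[0].message.content)

    # Streaming call: awaiting returns an async iterator of OpenAIChatCompletionChunk,
    # produced by _stream_openai_chat_completion above.
    stream = await adapter.openai_chat_completion(
        model="example/watsonx-model",
        messages=[{"role": "user", "content": "Count to three."}],
        stream=True,
    )
    async for chunk in stream:
        for choice in chunk.choices:
            if choice.delta and choice.delta.content:
                print(choice.delta.content, end="", flush=True)

# Run with asyncio.run(demo(adapter)) once `adapter` is a configured WatsonXInferenceAdapter.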