Ran precommit

This commit is contained in:
Omar Abdelwahab 2025-10-06 13:27:19 -07:00
parent 9886520b40
commit 9fc0d966f6
7 changed files with 153 additions and 310 deletions

View file

@ -9,6 +9,7 @@ from typing import Any
from ibm_watsonx_ai.foundation_models import Model
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from openai import AsyncOpenAI
from llama_stack.apis.inference import (
ChatCompletionRequest,
@ -33,7 +34,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
completion_request_to_prompt,
request_has_media,
)
from openai import AsyncOpenAI
from . import WatsonXConfig
from .models import MODEL_ENTRIES
@ -65,9 +65,7 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
self._project_id = self._config.project_id
def _get_client(self, model_id) -> Model:
config_api_key = (
self._config.api_key.get_secret_value() if self._config.api_key else None
)
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
config_url = self._config.url
project_id = self._config.project_id
credentials = {"url": config_url, "apikey": config_api_key}
@ -82,46 +80,28 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
)
return self._openai_client
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {"params": {}}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, llama_model
)
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
else:
assert (
not media_present
), "Together does not support media for Completion requests"
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
if request.sampling_params:
if request.sampling_params.strategy:
input_dict["params"][
GenParams.DECODING_METHOD
] = request.sampling_params.strategy.type
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
if request.sampling_params.max_tokens:
input_dict["params"][
GenParams.MAX_NEW_TOKENS
] = request.sampling_params.max_tokens
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
if request.sampling_params.repetition_penalty:
input_dict["params"][
GenParams.REPETITION_PENALTY
] = request.sampling_params.repetition_penalty
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
input_dict["params"][
GenParams.TOP_P
] = request.sampling_params.strategy.top_p
input_dict["params"][
GenParams.TEMPERATURE
] = request.sampling_params.strategy.temperature
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
input_dict["params"][
GenParams.TOP_K
] = request.sampling_params.strategy.top_k
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
input_dict["params"][GenParams.TEMPERATURE] = 0.0