pre-commit issues fix

This commit is contained in:
Sajikumar JS 2025-04-17 23:45:27 +05:30
parent 34a3f1a749
commit efe5b124f3
7 changed files with 207 additions and 31 deletions


@@ -5,10 +5,11 @@
 # the root directory of this source tree.
 import os
-from typing import Optional, Dict, Any
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field, SecretStr
 from llama_stack.schema_utils import json_schema_type
-from pydantic import BaseModel, Field, SecretStr

 class WatsonXProviderDataValidator(BaseModel):
@@ -19,7 +20,6 @@ class WatsonXProviderDataValidator(BaseModel):
 @json_schema_type
 class WatsonXConfig(BaseModel):
     url: str = Field(
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base url for accessing the Watsonx.ai",
@@ -42,5 +42,5 @@ class WatsonXConfig(BaseModel):
         return {
             "url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
             "api_key": "${env.WATSONX_API_KEY:}",
-            "project_id": "${env.WATSONX_PROJECT_ID:}"
-        }
+            "project_id": "${env.WATSONX_PROJECT_ID:}",
+        }
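A note on the `${env.VAR:default}` strings in `sample_run_config` above: llama-stack treats these as environment-variable placeholders that are expanded when the run config is loaded, with the text after the colon as the fallback. A minimal sketch of that substitution semantics (an illustration only, not the project's actual resolver) could look like:

import os
import re

# Matches "${env.NAME:default}" placeholders; NAME is the environment
# variable to read, "default" is used when it is unset.
_ENV_PLACEHOLDER = re.compile(r"\$\{env\.(?P<name>[A-Za-z_][A-Za-z0-9_]*):(?P<default>[^}]*)\}")


def resolve_env_placeholders(value: str) -> str:
    """Expand each ${env.NAME:default} placeholder in a config string."""
    return _ENV_PLACEHOLDER.sub(lambda m: os.getenv(m.group("name"), m.group("default")), value)


# With WATSONX_BASE_URL unset, the default after the colon wins:
print(resolve_env_placeholders("${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}"))
# -> https://us-south.ml.cloud.ibm.com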


@@ -43,7 +43,5 @@ MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta-llama/llama-guard-3-11b-vision",
         CoreModelId.llama_guard_3_11b_vision.value,
-    )
+    ),
 ]


@@ -6,8 +6,10 @@
 from typing import AsyncGenerator, List, Optional, Union
+
+from ibm_watson_machine_learning.foundation_models import Model
+from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
 from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -18,7 +20,6 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
-    ResponseFormatType,
     SamplingParams,
     TextTruncation,
     ToolChoice,
@@ -26,6 +27,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
@@ -41,14 +43,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )
 from . import WatsonXConfig
-from ibm_watson_machine_learning.foundation_models import Model
-from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
 from .models import MODEL_ENTRIES

 class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
     def __init__(self, config: WatsonXConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
@@ -94,12 +91,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
         config_url = self._config.url
         project_id = self._config.project_id
-        credentials = {
-            "url": config_url,
-            "apikey": config_api_key
-        }
+        credentials = {"url": config_url, "apikey": config_api_key}

-        return Model(model_id=model_id,credentials=credentials, project_id=project_id)
+        return Model(model_id=model_id, credentials=credentials, project_id=project_id)

     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
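For context on the `_get_client` cleanup above: the credentials dict feeds the `ibm_watson_machine_learning` `Model` client directly. A hedged usage sketch follows; the model id, API key, and project id are illustrative placeholders, not values from this commit:

from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams

# Illustrative placeholders; in the adapter these come from WatsonXConfig / env vars.
credentials = {"url": "https://us-south.ml.cloud.ibm.com", "apikey": "YOUR_API_KEY"}
model = Model(model_id="meta-llama/llama-3-8b-instruct", credentials=credentials, project_id="YOUR_PROJECT_ID")

# generate_text accepts GenParams metanames, the same keys _get_params assembles.
text = model.generate_text(prompt="Say hello.", params={GenParams.MAX_NEW_TOKENS: 64})
print(text)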
@@ -186,6 +180,7 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
         model_id = request.model
+
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
             s = self._get_client(model_id).generate_text_stream(**params)
@@ -225,19 +220,29 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         if request.sampling_params.additional_params.get("temperature"):
             input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
         if request.sampling_params.additional_params.get("length_penalty"):
-            input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params["length_penalty"]
+            input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[
+                "length_penalty"
+            ]
         if request.sampling_params.additional_params.get("random_seed"):
             input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"]
         if request.sampling_params.additional_params.get("min_new_tokens"):
-            input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params["min_new_tokens"]
+            input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[
+                "min_new_tokens"
+            ]
         if request.sampling_params.additional_params.get("stop_sequences"):
-            input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params["stop_sequences"]
+            input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[
+                "stop_sequences"
+            ]
         if request.sampling_params.additional_params.get("time_limit"):
             input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"]
         if request.sampling_params.additional_params.get("truncate_input_tokens"):
-            input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params["truncate_input_tokens"]
+            input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[
+                "truncate_input_tokens"
+            ]
         if request.sampling_params.additional_params.get("return_options"):
-            input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params["return_options"]
+            input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[
+                "return_options"
+            ]

         params = {
             **input_dict,
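An observation on the `_get_params` hunk above: every branch maps an `additional_params` key to its matching `GenParams` metaname with the same truthiness check, so the chain could be table-driven. A hypothetical refactor sketch, not part of this commit (`collect_gen_params` and `_GEN_PARAM_MAP` are invented names):

from typing import Any, Dict

from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams

# 1:1 mapping of request keys to watsonx generation-parameter metanames,
# taken from the if-chain in the diff above.
_GEN_PARAM_MAP = {
    "temperature": GenParams.TEMPERATURE,
    "length_penalty": GenParams.LENGTH_PENALTY,
    "random_seed": GenParams.RANDOM_SEED,
    "min_new_tokens": GenParams.MIN_NEW_TOKENS,
    "stop_sequences": GenParams.STOP_SEQUENCES,
    "time_limit": GenParams.TIME_LIMIT,
    "truncate_input_tokens": GenParams.TRUNCATE_INPUT_TOKENS,
    "return_options": GenParams.RETURN_OPTIONS,
}


def collect_gen_params(additional_params: Dict[str, Any]) -> Dict[Any, Any]:
    """Keep only recognized, truthy keys, mirroring the if-chain above."""
    return {meta: additional_params[key] for key, meta in _GEN_PARAM_MAP.items() if additional_params.get(key)}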