Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-28 17:01:59 +00:00
pre-commit issues fix

parent 34a3f1a749
commit efe5b124f3

7 changed files with 207 additions and 31 deletions
@@ -5,10 +5,11 @@
 # the root directory of this source tree.

 import os

-from typing import Optional, Dict, Any
+from typing import Any, Dict, Optional

+from pydantic import BaseModel, Field, SecretStr
+
 from llama_stack.schema_utils import json_schema_type
-from pydantic import BaseModel, Field, SecretStr


 class WatsonXProviderDataValidator(BaseModel):
@@ -19,7 +20,6 @@ class WatsonXProviderDataValidator(BaseModel):

 @json_schema_type
 class WatsonXConfig(BaseModel):
-
     url: str = Field(
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base url for accessing the Watsonx.ai",
@@ -42,5 +42,5 @@ class WatsonXConfig(BaseModel):
         return {
             "url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
             "api_key": "${env.WATSONX_API_KEY:}",
-            "project_id": "${env.WATSONX_PROJECT_ID:}"
-        }
+            "project_id": "${env.WATSONX_PROJECT_ID:}",
+        }
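The sample-config hunk above uses the `${env.VAR:default}` substitution syntax, while the Field defaults read the same variables at construction time. A minimal sketch of how a config class like this behaves, keeping `WatsonXConfig`, `url`, and the env var names from the diff; the `api_key` field shape shown here is an assumption:

import os
from typing import Optional

from pydantic import BaseModel, Field, SecretStr


class WatsonXConfig(BaseModel):
    # Falls back to the public us-south endpoint when WATSONX_BASE_URL is unset.
    url: str = Field(
        default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
    )
    # Assumed field definition; SecretStr keeps the key out of repr() and logs.
    api_key: Optional[SecretStr] = Field(default_factory=lambda: os.getenv("WATSONX_API_KEY"))


config = WatsonXConfig()
print(config.url)  # https://us-south.ml.cloud.ibm.com unless overridden
if config.api_key:
    print(config.api_key)                     # masked: **********
    print(config.api_key.get_secret_value())  # the actual key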
@@ -43,7 +43,5 @@ MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta-llama/llama-guard-3-11b-vision",
         CoreModelId.llama_guard_3_11b_vision.value,
-    )
-
-
+    ),
 ]
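For reference, `build_hf_repo_model_entry` pairs the provider-side model id with the canonical `CoreModelId` descriptor so the registry can resolve either name. A rough, hypothetical illustration of the resolution this enables, not the actual `ModelRegistryHelper` internals:

# Hypothetical sketch of alias resolution; real entries carry more metadata.
MODEL_ALIASES = {
    # provider-side id                      -> canonical descriptor (assumed value)
    "meta-llama/llama-guard-3-11b-vision": "Llama-Guard-3-11B-Vision",
}


def resolve_model_id(requested: str) -> str:
    # Accept either the provider alias or the canonical name unchanged.
    return MODEL_ALIASES.get(requested, requested)


assert resolve_model_id("meta-llama/llama-guard-3-11b-vision") == "Llama-Guard-3-11B-Vision"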
@ -6,8 +6,10 @@
|
|||
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from ibm_watson_machine_learning.foundation_models import Model
|
||||
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
|
@@ -18,7 +20,6 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
-    ResponseFormatType,
     SamplingParams,
     TextTruncation,
     ToolChoice,
@@ -26,6 +27,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
@@ -41,14 +43,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )

 from . import WatsonXConfig
-
-from ibm_watson_machine_learning.foundation_models import Model
-from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
-
-
 from .models import MODEL_ENTRIES


 class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
     def __init__(self, config: WatsonXConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
@@ -94,12 +91,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
         config_url = self._config.url
         project_id = self._config.project_id
-        credentials = {
-            "url": config_url,
-            "apikey": config_api_key
-        }
+        credentials = {"url": config_url, "apikey": config_api_key}

-        return Model(model_id=model_id,credentials=credentials, project_id=project_id)
+        return Model(model_id=model_id, credentials=credentials, project_id=project_id)

     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
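The `_get_client` change above collapses the credentials dict onto one line; it still has the flat `{"url", "apikey"}` shape the `ibm_watson_machine_learning` `Model` constructor takes. A hedged usage sketch of that client; the model id, key, and project id below are placeholders:

from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams

# Flat credentials dict, matching the shape built in _get_client above.
credentials = {"url": "https://us-south.ml.cloud.ibm.com", "apikey": "REPLACE_ME"}
client = Model(
    model_id="meta-llama/llama-3-70b-instruct",  # placeholder model id
    credentials=credentials,
    project_id="REPLACE_ME",
)
# Decoding options are keyed by GenParams constants, as in _get_params below.
print(client.generate_text(prompt="Hello", params={GenParams.MAX_NEW_TOKENS: 32}))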
@@ -186,6 +180,7 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
         model_id = request.model
+
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
             s = self._get_client(model_id).generate_text_stream(**params)
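The `_to_async_generator` wrapper above bridges the SDK's synchronous `generate_text_stream` into the adapter's async interface. A self-contained sketch of the same pattern, where `sync_stream` is a stand-in for the real SDK call:

import asyncio
from typing import AsyncGenerator, Iterator


def sync_stream() -> Iterator[str]:
    # Stand-in for generate_text_stream(**params), which yields chunks synchronously.
    yield from ("Hello", ", ", "world")


async def to_async_generator() -> AsyncGenerator[str, None]:
    # Simplest bridge: drain the blocking generator inside the coroutine.
    # Each next() call blocks the event loop, which is why the comment in the
    # diff notes that a真 async client would make this wrapper unnecessary.
    for chunk in sync_stream():
        yield chunk


async def main() -> None:
    async for chunk in to_async_generator():
        print(chunk, end="")


asyncio.run(main())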
@@ -225,19 +220,29 @@
         if request.sampling_params.additional_params.get("temperature"):
             input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
         if request.sampling_params.additional_params.get("length_penalty"):
-            input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params["length_penalty"]
+            input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[
+                "length_penalty"
+            ]
         if request.sampling_params.additional_params.get("random_seed"):
             input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"]
         if request.sampling_params.additional_params.get("min_new_tokens"):
-            input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params["min_new_tokens"]
+            input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[
+                "min_new_tokens"
+            ]
         if request.sampling_params.additional_params.get("stop_sequences"):
-            input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params["stop_sequences"]
+            input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[
+                "stop_sequences"
+            ]
         if request.sampling_params.additional_params.get("time_limit"):
             input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"]
         if request.sampling_params.additional_params.get("truncate_input_tokens"):
-            input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params["truncate_input_tokens"]
+            input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[
+                "truncate_input_tokens"
+            ]
         if request.sampling_params.additional_params.get("return_options"):
-            input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params["return_options"]
+            input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[
+                "return_options"
+            ]

         params = {
             **input_dict,
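The reflowed if-chain above maps each entry in `sampling_params.additional_params` onto the matching `GenParams` key. An equivalent, table-driven sketch using only names that appear in the diff:

_GEN_PARAM_MAP = {
    "temperature": GenParams.TEMPERATURE,
    "length_penalty": GenParams.LENGTH_PENALTY,
    "random_seed": GenParams.RANDOM_SEED,
    "min_new_tokens": GenParams.MIN_NEW_TOKENS,
    "stop_sequences": GenParams.STOP_SEQUENCES,
    "time_limit": GenParams.TIME_LIMIT,
    "truncate_input_tokens": GenParams.TRUNCATE_INPUT_TOKENS,
    "return_options": GenParams.RETURN_OPTIONS,
}

for key, gen_param in _GEN_PARAM_MAP.items():
    value = request.sampling_params.additional_params.get(key)
    if value:  # mirrors the truthiness checks in the diff
        input_dict["params"][gen_param] = value

This keeps the behavior identical, including skipping falsy values, while avoiding the long lines the reflow above is working around.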