mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 20:12:33 +00:00
Merge branch 'main' into use-secret-str
This commit is contained in:
commit
39854f4562
14 changed files with 196 additions and 473 deletions
|
|
@ -18,7 +18,7 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
|
|||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
|
||||
| `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
|
||||
| `project_id` | `str \| None` | No | | The Project ID key |
|
||||
| `project_id` | `str \| None` | No | | The watsonx.ai project ID |
|
||||
| `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
|
||||
|
||||
## Sample Configuration
|
||||
|
|
|
|||
|
|
@ -611,7 +611,7 @@ class InferenceRouter(Inference):
|
|||
completion_text += "".join(choice_data["content_parts"])
|
||||
|
||||
# Add metrics to the chunk
|
||||
if self.telemetry and chunk.usage:
|
||||
if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
|
|
|
|||
|
|
@ -3,3 +3,5 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .watsonx import get_distribution_template # noqa: F401
|
||||
|
|
|
|||
|
|
@ -3,44 +3,33 @@ distribution_spec:
|
|||
description: Use watsonx for running LLM inference
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: watsonx
|
||||
provider_type: remote::watsonx
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
- provider_type: remote::watsonx
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
- provider_type: inline::faiss
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
- provider_type: inline::llama-guard
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
- provider_type: remote::huggingface
|
||||
- provider_type: inline::localfs
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
- provider_type: inline::basic
|
||||
- provider_type: inline::llm-as-judge
|
||||
- provider_type: inline::braintrust
|
||||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
- aiosqlite
|
||||
- aiosqlite
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ apis:
|
|||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
- files
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: watsonx
|
||||
|
|
@ -19,8 +19,6 @@ providers:
|
|||
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
|
||||
api_key: ${env.WATSONX_API_KEY:=}
|
||||
project_id: ${env.WATSONX_PROJECT_ID:=}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
|
|
@ -48,7 +46,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
@ -109,102 +107,7 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-3-70b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-2-13b-chat
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-2-13b-chat
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-2-13b
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-2-13b-chat
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-1-70b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-70B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-1-8b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-1b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-3b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-3B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-guard-3-11b-vision
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-Guard-3-11B-Vision
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: sentence-transformers
|
||||
model_type: embedding
|
||||
models: []
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
|
|
|
|||
|
|
@ -4,17 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
||||
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
||||
|
|
@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
|||
config=WatsonXConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
embedding_provider = Provider(
|
||||
provider_id="sentence-transformers",
|
||||
provider_type="inline::sentence-transformers",
|
||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
available_models = {
|
||||
"watsonx": MODEL_ENTRIES,
|
||||
}
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
|
|
@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
|||
),
|
||||
]
|
||||
|
||||
embedding_model = ModelInput(
|
||||
model_id="all-MiniLM-L6-v2",
|
||||
provider_id="sentence-transformers",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
)
|
||||
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
provider_type="inline::localfs",
|
||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
default_models, _ = get_model_registry(available_models)
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="remote_hosted",
|
||||
description="Use watsonx for running LLM inference",
|
||||
container_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
template_path=None,
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider, embedding_provider],
|
||||
"inference": [inference_provider],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_models=default_models + [embedding_model],
|
||||
default_models=[],
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -268,7 +268,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter_type="watsonx",
|
||||
provider_type="remote::watsonx",
|
||||
pip_packages=["ibm_watsonx_ai"],
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.watsonx",
|
||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||
|
|
|
|||
|
|
@ -4,19 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import WatsonXConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
|
||||
# import dynamically so `llama stack build` does not fail due to missing dependencies
|
||||
async def get_adapter_impl(config: WatsonXConfig, _deps):
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .watsonx import WatsonXInferenceAdapter
|
||||
|
||||
if not isinstance(config, WatsonXConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
adapter = WatsonXInferenceAdapter(config)
|
||||
return adapter
|
||||
|
||||
|
||||
__all__ = ["get_adapter_impl", "WatsonXConfig"]
|
||||
|
|
|
|||
|
|
@ -7,16 +7,18 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class WatsonXProviderDataValidator(BaseModel):
|
||||
url: str
|
||||
api_key: str
|
||||
project_id: str
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="forbid",
|
||||
)
|
||||
watsonx_api_key: str | None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -26,8 +28,8 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
|
|||
description="A base url for accessing the watsonx.ai",
|
||||
)
|
||||
project_id: str | None = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
|
||||
description="The Project ID key",
|
||||
default=None,
|
||||
description="The watsonx.ai project ID",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=60,
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-3-70b-instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-2-13b-chat",
|
||||
CoreModelId.llama2_13b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-1-70b-instruct",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-1-8b-instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-11b-vision-instruct",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-1b-instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-3b-instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-90b-vision-instruct",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-guard-3-11b-vision",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
]
|
||||
|
|
@ -4,240 +4,120 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from ibm_watsonx_ai.foundation_models import Model
|
||||
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
|
||||
from openai import AsyncOpenAI
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
GreedySamplingStrategy,
|
||||
Inference,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
prepare_openai_completion_params,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from . import WatsonXConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::watsonx")
|
||||
from llama_stack.apis.inference import ChatCompletionRequest
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
|
||||
# Note on structured output
|
||||
# WatsonX returns responses with a json embedded into a string.
|
||||
# Examples:
|
||||
class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_model_cache: dict[str, Model] = {}
|
||||
|
||||
# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
|
||||
# "first_name": "Michael",\n "last_name": "Jordan",\n'...)
|
||||
# Not even a valid JSON, but we can still extract the JSON from the content
|
||||
def __init__(self, config: WatsonXConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="watsonx",
|
||||
api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
|
||||
provider_data_api_key_field="watsonx_api_key",
|
||||
)
|
||||
self.available_models = None
|
||||
self.config = config
|
||||
|
||||
# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
|
||||
# "year_born": "1963", "year_retired": "2003"\\}}$')
|
||||
# Find the start of the boxed content
|
||||
def get_base_url(self) -> str:
|
||||
return self.config.url
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
|
||||
# Get base parameters from parent
|
||||
params = await super()._get_params(request)
|
||||
|
||||
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
||||
def __init__(self, config: WatsonXConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
||||
logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
|
||||
self._config = config
|
||||
self._openai_client: AsyncOpenAI | None = None
|
||||
|
||||
self._project_id = self._config.project_id
|
||||
|
||||
def _get_client(self, model_id) -> Model:
|
||||
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
|
||||
config_url = self._config.url
|
||||
project_id = self._config.project_id
|
||||
credentials = {"url": config_url, "apikey": config_api_key}
|
||||
|
||||
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
if not self._openai_client:
|
||||
self._openai_client = AsyncOpenAI(
|
||||
base_url=f"{self._config.url}/openai/v1",
|
||||
api_key=self._config.api_key,
|
||||
)
|
||||
return self._openai_client
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
|
||||
input_dict = {"params": {}}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||
else:
|
||||
assert not media_present, "Together does not support media for Completion requests"
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
if request.sampling_params:
|
||||
if request.sampling_params.strategy:
|
||||
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
|
||||
if request.sampling_params.max_tokens:
|
||||
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
|
||||
if request.sampling_params.repetition_penalty:
|
||||
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
|
||||
|
||||
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
|
||||
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
|
||||
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
|
||||
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
|
||||
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
|
||||
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
|
||||
input_dict["params"][GenParams.TEMPERATURE] = 0.0
|
||||
|
||||
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
|
||||
|
||||
params = {
|
||||
**input_dict,
|
||||
}
|
||||
# Add watsonx.ai specific parameters
|
||||
params["project_id"] = self.config.project_id
|
||||
params["time_limit"] = self.config.timeout
|
||||
return params
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
# Copied from OpenAIMixin
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available from the provider's /v1/models.
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self._get_openai_client().completions.create(**params) # type: ignore
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
"""
|
||||
if not self._model_cache:
|
||||
await self.list_models()
|
||||
return model in self._model_cache
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if params.get("stream", False):
|
||||
return self._stream_openai_chat_completion(params)
|
||||
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._model_cache = {}
|
||||
models = []
|
||||
for model_spec in self._get_model_specs():
|
||||
functions = [f["id"] for f in model_spec.get("functions", [])]
|
||||
# Format: {"embedding_dimension": 1536, "context_length": 8192}
|
||||
|
||||
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
|
||||
# watsonx.ai sometimes adds usage data to the stream
|
||||
include_usage = False
|
||||
if params.get("stream_options", None):
|
||||
include_usage = params["stream_options"].get("include_usage", False)
|
||||
stream = await self._get_openai_client().chat.completions.create(**params)
|
||||
# Example of an embedding model:
|
||||
# {'model_id': 'ibm/granite-embedding-278m-multilingual',
|
||||
# 'label': 'granite-embedding-278m-multilingual',
|
||||
# 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
|
||||
# ...
|
||||
provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
|
||||
if "embedding" in functions:
|
||||
embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
|
||||
context_length = model_spec["model_limits"]["max_sequence_length"]
|
||||
embedding_metadata = {
|
||||
"embedding_dimension": embedding_dimension,
|
||||
"context_length": context_length,
|
||||
}
|
||||
model = Model(
|
||||
identifier=model_spec["model_id"],
|
||||
provider_resource_id=provider_resource_id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata=embedding_metadata,
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
self._model_cache[provider_resource_id] = model
|
||||
models.append(model)
|
||||
if "text_chat" in functions:
|
||||
model = Model(
|
||||
identifier=model_spec["model_id"],
|
||||
provider_resource_id=provider_resource_id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
# In theory, I guess it is possible that a model could be both an embedding model and a text chat model.
|
||||
# In that case, the cache will record the generator Model object, and the list which we return will have
|
||||
# both the generator Model object and the text chat Model object. That's fine because the cache is
|
||||
# only used for check_model_availability() anyway.
|
||||
self._model_cache[provider_resource_id] = model
|
||||
models.append(model)
|
||||
return models
|
||||
|
||||
seen_finish_reason = False
|
||||
async for chunk in stream:
|
||||
# Final usage chunk with no choices that the user didn't request, so discard
|
||||
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
|
||||
break
|
||||
yield chunk
|
||||
for choice in chunk.choices:
|
||||
if choice.finish_reason:
|
||||
seen_finish_reason = True
|
||||
break
|
||||
# LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
|
||||
# So we need to implement our own method to list models by calling the watsonx.ai API.
|
||||
def _get_model_specs(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieves foundation model specifications from the watsonx.ai API.
|
||||
"""
|
||||
url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
|
||||
headers = {
|
||||
# Note that there is no authorization header. Listing models does not require authentication.
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
# --- Process the Response ---
|
||||
# Raise an exception for bad status codes (4xx or 5xx)
|
||||
response.raise_for_status()
|
||||
|
||||
# If the request is successful, parse and return the JSON response.
|
||||
# The response should contain a list of model specifications
|
||||
response_data = response.json()
|
||||
if "resources" not in response_data:
|
||||
raise ValueError("Resources not found in response")
|
||||
return response_data["resources"]
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import struct
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -16,6 +18,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
|
|
@ -26,7 +29,6 @@ from llama_stack.core.request_headers import NeedsRequestProviderData
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
b64_encode_openai_embeddings_response,
|
||||
convert_message_to_openai_dict_new,
|
||||
convert_tooldef_to_openai_tool,
|
||||
get_sampling_options,
|
||||
|
|
@ -349,3 +351,28 @@ class LiteLLMOpenAIMixin(
|
|||
return False
|
||||
|
||||
return model in litellm.models_by_provider[self.litellm_provider_name]
|
||||
|
||||
|
||||
def b64_encode_openai_embeddings_response(
|
||||
response_data: list[dict], encoding_format: str | None = "float"
|
||||
) -> list[OpenAIEmbeddingData]:
|
||||
"""
|
||||
Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
|
||||
"""
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response_data):
|
||||
if encoding_format == "base64":
|
||||
byte_array = bytearray()
|
||||
for embedding_value in embedding_data["embedding"]:
|
||||
byte_array.extend(struct.pack("f", float(embedding_value)))
|
||||
|
||||
response_embedding = base64.b64encode(byte_array).decode("utf-8")
|
||||
else:
|
||||
response_embedding = embedding_data["embedding"]
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=response_embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -3,9 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import base64
|
||||
import json
|
||||
import struct
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
|
|
@ -103,7 +101,6 @@ from llama_stack.apis.inference import (
|
|||
JsonSchemaResponseFormat,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
SamplingParams,
|
||||
|
|
@ -1402,28 +1399,3 @@ def prepare_openai_embeddings_params(
|
|||
params["user"] = user
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def b64_encode_openai_embeddings_response(
|
||||
response_data: dict, encoding_format: str | None = "float"
|
||||
) -> list[OpenAIEmbeddingData]:
|
||||
"""
|
||||
Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
|
||||
"""
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response_data):
|
||||
if encoding_format == "base64":
|
||||
byte_array = bytearray()
|
||||
for embedding_value in embedding_data.embedding:
|
||||
byte_array.extend(struct.pack("f", float(embedding_value)))
|
||||
|
||||
response_embedding = base64.b64encode(byte_array).decode("utf-8")
|
||||
else:
|
||||
response_embedding = embedding_data.embedding
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=response_embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
|||
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
|
||||
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
|
||||
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
@ -58,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_valida
|
|||
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||
):
|
||||
assert inference_adapter.client.api_key == api_key
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_cls,adapter_cls,provider_data_validator",
|
||||
[
|
||||
(
|
||||
WatsonXConfig,
|
||||
WatsonXInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
|
||||
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
|
||||
assumption that there is an OpenAI-compatible client object."""
|
||||
|
||||
inference_adapter = adapter_cls(config=config_cls())
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
|
||||
|
||||
for api_key in ["test1", "test2"]:
|
||||
with request_provider_data_context(
|
||||
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||
):
|
||||
assert inference_adapter.get_api_key() == api_key
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue