Merge branch 'main' into store_registeration_bug_fix

This commit is contained in:
Omar Abdelwahab 2025-09-16 18:18:56 -07:00 committed by GitHub
commit 43a2600158
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
58 changed files with 21962 additions and 343 deletions

View file

@ -93,3 +93,11 @@ class Benchmarks(Protocol):
:param metadata: The metadata to use for the benchmark.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE")
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark.
:param benchmark_id: The ID of the benchmark to unregister.
"""
...

View file

@ -197,3 +197,11 @@ class ScoringFunctions(Protocol):
:param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE")
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
"""Unregister a scoring function.
:param scoring_fn_id: The ID of the scoring function to unregister.
"""
...

View file

@ -56,3 +56,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
provider_resource_id=provider_benchmark_id,
)
await self.register_object(benchmark)
async def unregister_benchmark(self, benchmark_id: str) -> None:
existing_benchmark = await self.get_benchmark(benchmark_id)
await self.unregister_object(existing_benchmark)

View file

@ -64,6 +64,10 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_shield(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.eval:
return await p.unregister_benchmark(obj.identifier)
elif api == Api.scoring:
return await p.unregister_scoring_function(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_toolgroup(obj.identifier)
else:

View file

@ -60,3 +60,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
)
scoring_fn.provider_id = provider_id
await self.register_object(scoring_fn)
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
existing_scoring_fn = await self.get_scoring_function(scoring_fn_id)
await self.unregister_object(existing_scoring_fn)

View file

@ -10,6 +10,7 @@ apis:
- telemetry
- tool_runtime
- vector_io
- files
providers:
inference:
- provider_id: watsonx
@ -94,6 +95,14 @@ providers:
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/watsonx/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/files_metadata.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/registry.db

View file

@ -9,6 +9,7 @@ from pathlib import Path
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
@ -16,7 +17,7 @@ from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
def get_distribution_template() -> DistributionTemplate:
def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
providers = {
"inference": [
BuildProvider(provider_type="remote::watsonx"),
@ -42,6 +43,7 @@ def get_distribution_template() -> DistributionTemplate:
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
}
inference_provider = Provider(
@ -79,9 +81,14 @@ def get_distribution_template() -> DistributionTemplate:
},
)
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
default_models, _ = get_model_registry(available_models)
return DistributionTemplate(
name="watsonx",
name=name,
distro_type="remote_hosted",
description="Use watsonx for running LLM inference",
container_image=None,
@ -92,6 +99,7 @@ def get_distribution_template() -> DistributionTemplate:
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider, embedding_provider],
"files": [files_provider],
},
default_models=default_models + [embedding_model],
default_tool_groups=default_tool_groups,

View file

@ -75,6 +75,13 @@ class MetaReferenceEvalImpl(
)
self.benchmarks[task_def.identifier] = task_def
async def unregister_benchmark(self, benchmark_id: str) -> None:
if benchmark_id in self.benchmarks:
del self.benchmarks[benchmark_id]
key = f"{EVAL_TASKS_PREFIX}{benchmark_id}"
await self.kvstore.delete(key)
async def run_eval(
self,
benchmark_id: str,

View file

@ -63,6 +63,9 @@ class LlmAsJudgeScoringImpl(
async def register_scoring_function(self, function_def: ScoringFn) -> None:
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
self.llm_as_judge_fn.unregister_scoring_fn_def(scoring_fn_id)
async def score_batch(
self,
dataset_id: str,

View file

@ -51,18 +51,23 @@ class NVIDIAEvalImpl(
async def shutdown(self) -> None: ...
async def _evaluator_get(self, path):
async def _evaluator_get(self, path: str):
"""Helper for making GET requests to the evaluator service."""
response = requests.get(url=f"{self.config.evaluator_url}{path}")
response.raise_for_status()
return response.json()
async def _evaluator_post(self, path, data):
async def _evaluator_post(self, path: str, data: dict[str, Any]):
"""Helper for making POST requests to the evaluator service."""
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
response.raise_for_status()
return response.json()
async def _evaluator_delete(self, path: str) -> None:
"""Helper for making DELETE requests to the evaluator service."""
response = requests.delete(url=f"{self.config.evaluator_url}{path}")
response.raise_for_status()
async def register_benchmark(self, task_def: Benchmark) -> None:
"""Register a benchmark as an evaluation configuration."""
await self._evaluator_post(
@ -75,6 +80,10 @@ class NVIDIAEvalImpl(
},
)
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark evaluation configuration from NeMo Evaluator."""
await self._evaluator_delete(f"/v1/evaluation/configs/{DEFAULT_NAMESPACE}/{benchmark_id}")
async def run_eval(
self,
benchmark_id: str,

View file

@ -8,6 +8,7 @@
from collections.abc import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -33,6 +34,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
@ -41,16 +43,15 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_model_input_info,
completion_request_to_prompt_model_input_info,
@ -73,26 +74,49 @@ def build_hf_repo_model_entries():
class _HfAdapter(
OpenAIMixin,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
client: AsyncInferenceClient
url: str
api_key: SecretStr
hf_client: AsyncInferenceClient
max_tokens: int
model_id: str
overwrite_completion_id = True # TGI always returns id=""
def __init__(self) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
}
def get_api_key(self):
return self.api_key.get_secret_value()
def get_base_url(self):
return self.url
async def shutdown(self) -> None:
pass
async def list_models(self) -> list[Model] | None:
models = []
async for model in self.client.models.list():
models.append(
Model(
identifier=model.id,
provider_resource_id=model.id,
provider_id=self.__provider_id__,
metadata={},
model_type=ModelType.llm,
)
)
return models
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
if model.provider_resource_id != self.model_id:
raise ValueError(
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
@ -176,7 +200,7 @@ class _HfAdapter(
params = await self._get_params_for_completion(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.text_generation(**params)
s = await self.hf_client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
finish_reason = None
@ -194,7 +218,7 @@ class _HfAdapter(
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_completion(request)
r = await self.client.text_generation(**params)
r = await self.hf_client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
@ -241,7 +265,7 @@ class _HfAdapter(
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = await self.client.text_generation(**params)
r = await self.hf_client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
@ -256,7 +280,7 @@ class _HfAdapter(
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.text_generation(**params)
s = await self.hf_client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
@ -308,18 +332,21 @@ class TGIAdapter(_HfAdapter):
if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.client.get_endpoint_info()
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
self.url = f"{config.url.rstrip('/')}/v1"
self.api_key = SecretStr("NO_KEY")
class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
endpoint_info = await self.client.get_endpoint_info()
self.hf_client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
# TODO: how do we set url for this?
class InferenceEndpointAdapter(_HfAdapter):
@ -331,6 +358,7 @@ class InferenceEndpointAdapter(_HfAdapter):
endpoint.wait(timeout=60)
# Initialize the adapter
self.client = endpoint.async_client
self.hf_client = endpoint.async_client
self.model_id = endpoint.repository
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
# TODO: how do we set url for this?

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
@ -21,57 +20,84 @@ SAFETY_MODELS_ENTRIES = [
CoreModelId.llama_guard_3_11b_vision.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
ProviderModelEntry(
# source: https://docs.together.ai/docs/serverless-models#embedding-models
EMBEDDING_MODEL_ENTRIES = {
"togethercomputer/m2-bert-80M-32k-retrieval": ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-32k-retrieval",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 32768,
},
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
"BAAI/bge-large-en-v1.5": ProviderModelEntry(
provider_model_id="BAAI/bge-large-en-v1.5",
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
"BAAI/bge-base-en-v1.5": ProviderModelEntry(
provider_model_id="BAAI/bge-base-en-v1.5",
metadata={
"embedding_dimension": 768,
"context_length": 512,
},
),
] + SAFETY_MODELS_ENTRIES
"Alibaba-NLP/gte-modernbert-base": ProviderModelEntry(
provider_model_id="Alibaba-NLP/gte-modernbert-base",
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
"intfloat/multilingual-e5-large-instruct": ProviderModelEntry(
provider_model_id="intfloat/multilingual-e5-large-instruct",
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
}
MODEL_ENTRIES = (
[
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]
+ SAFETY_MODELS_ENTRIES
+ list(EMBEDDING_MODEL_ENTRIES.values())
)

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from collections.abc import AsyncGenerator
from openai import AsyncOpenAI
from openai import NOT_GIVEN, AsyncOpenAI
from together import AsyncTogether
from together.constants import BASE_URL
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -23,12 +23,7 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
ResponseFormatType,
SamplingParams,
@ -38,18 +33,20 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
from llama_stack.apis.models import Model, ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
@ -59,15 +56,22 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES
from .models import EMBEDDING_MODEL_ENTRIES, MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config
self._model_cache: dict[str, Model] = {}
def get_api_key(self):
return self.config.api_key.get_secret_value()
def get_base_url(self):
return BASE_URL
async def initialize(self) -> None:
pass
@ -255,6 +259,37 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
embeddings = [item.embedding for item in r.data]
return EmbeddingsResponse(embeddings=embeddings)
async def list_models(self) -> list[Model] | None:
self._model_cache = {}
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
for m in await self._get_client().models.list():
if m.type == "embedding":
if m.id not in EMBEDDING_MODEL_ENTRIES:
logger.warning(f"Unknown embedding dimension for model {m.id}, skipping.")
continue
self._model_cache[m.id] = Model(
provider_id=self.__provider_id__,
provider_resource_id=EMBEDDING_MODEL_ENTRIES[m.id].provider_model_id,
identifier=m.id,
model_type=ModelType.embedding,
metadata=EMBEDDING_MODEL_ENTRIES[m.id].metadata,
)
else:
self._model_cache[m.id] = Model(
provider_id=self.__provider_id__,
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.llm,
)
return self._model_cache.values()
async def should_refresh_models(self) -> bool:
return True
async def check_model_availability(self, model):
return model in self._model_cache
async def openai_embeddings(
self,
model: str,
@ -263,125 +298,39 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
"""
Together's OpenAI-compatible embeddings endpoint is not compatible with
the standard OpenAI embeddings endpoint.
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
The endpoint -
- does not return usage information
- does not support user param, returns 400 Unrecognized request arguments supplied: user
- does not support dimensions param, returns 400 Unrecognized request arguments supplied: dimensions
- does not support encoding_format param, always returns floats, never base64
"""
# Together support ticket #13332 -> will not fix
if user is not None:
raise ValueError("Together's embeddings endpoint does not support user param.")
# Together support ticket #13333 -> escalated
if dimensions is not None:
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
# Together support ticket #13331 -> will not fix, compute client side
if encoding_format not in (None, NOT_GIVEN, "float"):
raise ValueError("Together's embeddings endpoint only supports encoding_format='float'.")
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
)
return await self._get_openai_client().completions.create(**params) # type: ignore
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
if params.get("stream", False):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
response.model = model # return the user the same model id they provided, avoid exposing the provider model id
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
# together.ai sometimes adds usage data to the stream, even if include_usage is False
# This causes an unexpected final chunk with empty choices array to be sent
# to clients that may not handle it gracefully.
include_usage = False
if params.get("stream_options", None):
include_usage = params["stream_options"].get("include_usage", False)
stream = await self._get_openai_client().chat.completions.create(**params)
# Together support ticket #13330 -> escalated
# - togethercomputer/m2-bert-80M-32k-retrieval *does not* return usage information
if not hasattr(response, "usage") or response.usage is None:
logger.warning(
f"Together's embedding endpoint for {model} did not return usage information, substituting -1s."
)
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
seen_finish_reason = False
async for chunk in stream:
# Final usage chunk with no choices that the user didn't request, so discard
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
break
yield chunk
for choice in chunk.choices:
if choice.finish_reason:
seen_finish_reason = True
break
return response

View file

@ -26,11 +26,11 @@ class WatsonXConfig(BaseModel):
)
api_key: SecretStr | None = Field(
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
description="The watsonx API key, only needed of using the hosted service",
description="The watsonx API key",
)
project_id: str | None = Field(
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
description="The Project ID key, only needed of using the hosted service",
description="The Project ID key",
)
timeout: int = Field(
default=60,

View file

@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
@ -57,14 +58,29 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from . import WatsonXConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::watsonx")
# Note on structured output
# WatsonX returns responses with a json embedded into a string.
# Examples:
# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
# "first_name": "Michael",\n "last_name": "Jordan",\n'...)
# Not even a valid JSON, but we can still extract the JSON from the content
# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
# "year_born": "1963", "year_retired": "2003"\\}}$')
# Find the start of the boxed content
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: WatsonXConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
print(f"Initializing watsonx InferenceAdapter({config.url})...")
logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
self._config = config
self._openai_client: AsyncOpenAI | None = None
self._project_id = self._config.project_id

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
@ -43,6 +44,12 @@ class OpenAIMixin(ABC):
The model_store is set in routing_tables/common.py during provider initialization.
"""
# Allow subclasses to control whether to overwrite the 'id' field in OpenAI responses
# is overwritten with a client-side generated id.
#
# This is useful for providers that do not return a unique id in the response.
overwrite_completion_id: bool = False
@abstractmethod
def get_api_key(self) -> str:
"""
@ -110,6 +117,23 @@ class OpenAIMixin(ABC):
raise ValueError(f"Model {model} has no provider_resource_id")
return model_obj.provider_resource_id
async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
if not self.overwrite_completion_id:
return resp
new_id = f"cltsd-{uuid.uuid4()}"
if stream:
async def _gen():
async for chunk in resp:
chunk.id = new_id
yield chunk
return _gen()
else:
resp.id = new_id
return resp
async def openai_completion(
self,
model: str,
@ -147,7 +171,7 @@ class OpenAIMixin(ABC):
extra_body["guided_choice"] = guided_choice
# TODO: fix openai_completion to return type compatible with OpenAI's API response
return await self.client.completions.create( # type: ignore[no-any-return]
resp = await self.client.completions.create(
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
prompt=prompt,
@ -171,6 +195,8 @@ class OpenAIMixin(ABC):
extra_body=extra_body,
)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
async def openai_chat_completion(
self,
model: str,
@ -200,8 +226,7 @@ class OpenAIMixin(ABC):
"""
Direct OpenAI chat completion API call.
"""
# Type ignore because return types are compatible
return await self.client.chat.completions.create( # type: ignore[no-any-return]
resp = await self.client.chat.completions.create(
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
messages=messages,
@ -229,6 +254,8 @@ class OpenAIMixin(ABC):
)
)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
async def openai_embeddings(
self,
model: str,

View file

@ -7,7 +7,6 @@
from __future__ import annotations # for forward references
import hashlib
import inspect
import json
import os
from collections.abc import Generator
@ -243,11 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
global _current_mode, _current_storage
if _current_mode == InferenceMode.LIVE or _current_storage is None:
# Normal operation
if inspect.iscoroutinefunction(original_method):
return await original_method(self, *args, **kwargs)
else:
if endpoint == "/v1/models":
return original_method(self, *args, **kwargs)
else:
return await original_method(self, *args, **kwargs)
# Get base URL based on client type
if client_type == "openai":
@ -298,10 +296,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
)
elif _current_mode == InferenceMode.RECORD:
if inspect.iscoroutinefunction(original_method):
response = await original_method(self, *args, **kwargs)
else:
if endpoint == "/v1/models":
response = original_method(self, *args, **kwargs)
else:
response = await original_method(self, *args, **kwargs)
# we want to store the result of the iterator, not the iterator itself
if endpoint == "/v1/models":

View file

@ -18,7 +18,7 @@
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"framer-motion": "^12.23.12",
"llama-stack-client": "^0.2.21",
"llama-stack-client": "^0.2.22",
"lucide-react": "^0.542.0",
"next": "15.5.3",
"next-auth": "^4.24.11",
@ -10314,9 +10314,9 @@
"license": "MIT"
},
"node_modules/llama-stack-client": {
"version": "0.2.21",
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.21.tgz",
"integrity": "sha512-rjU2Vx5xStxDYavU8K1An/SYXiQQjroLcK98B+p0Paz/a7OgRao2S0YwvThJjPUyChY4fO03UIXP9LpmHqlXWQ==",
"version": "0.2.22",
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.22.tgz",
"integrity": "sha512-7aW3UQj5MwjV73Brd+yQ1e4W1W33nhozyeHM5tzOgbsVZ88tL78JNiNvyFqDR5w6V9XO4/uSGGiQVG6v83yR4w==",
"license": "MIT",
"dependencies": {
"@types/node": "^18.11.18",

View file

@ -23,7 +23,7 @@
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"framer-motion": "^12.23.12",
"llama-stack-client": "^0.2.21",
"llama-stack-client": "^0.2.22",
"lucide-react": "^0.542.0",
"next": "15.5.3",
"next-auth": "^4.24.11",