Merge branch 'main' into chroma

Bwook (Byoungwook) Kim authored on 2025-09-19 22:53:03 +09:00, committed by GitHub
commit c71bcd5479
124 changed files with 25574 additions and 2425 deletions


@@ -51,18 +51,23 @@ class NVIDIAEvalImpl(
async def shutdown(self) -> None: ...
async def _evaluator_get(self, path):
async def _evaluator_get(self, path: str):
"""Helper for making GET requests to the evaluator service."""
response = requests.get(url=f"{self.config.evaluator_url}{path}")
response.raise_for_status()
return response.json()
async def _evaluator_post(self, path, data):
async def _evaluator_post(self, path: str, data: dict[str, Any]):
"""Helper for making POST requests to the evaluator service."""
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
response.raise_for_status()
return response.json()
async def _evaluator_delete(self, path: str) -> None:
"""Helper for making DELETE requests to the evaluator service."""
response = requests.delete(url=f"{self.config.evaluator_url}{path}")
response.raise_for_status()
async def register_benchmark(self, task_def: Benchmark) -> None:
"""Register a benchmark as an evaluation configuration."""
await self._evaluator_post(
@@ -75,6 +80,10 @@ class NVIDIAEvalImpl(
},
)
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark evaluation configuration from NeMo Evaluator."""
await self._evaluator_delete(f"/v1/evaluation/configs/{DEFAULT_NAMESPACE}/{benchmark_id}")
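For illustration, a minimal sketch of the DELETE path that unregister_benchmark composes; the DEFAULT_NAMESPACE value and the benchmark id below are assumed examples, not taken from this diff.

DEFAULT_NAMESPACE = "default"  # assumed example value; defined elsewhere in the module
benchmark_id = "my-benchmark"  # hypothetical id
path = f"/v1/evaluation/configs/{DEFAULT_NAMESPACE}/{benchmark_id}"
print(path)  # -> /v1/evaluation/configs/default/my-benchmark
# _evaluator_delete() then issues DELETE {config.evaluator_url}{path} and raises on non-2xx.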
async def run_eval(
self,
benchmark_id: str,


@@ -7,12 +7,10 @@
import asyncio
import base64
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from ollama import AsyncClient # type: ignore[attr-defined]
from openai import AsyncOpenAI
from ollama import AsyncClient as AsyncOllamaClient
from llama_stack.apis.common.content_types import (
ImageContentItem,
@@ -37,9 +35,6 @@ from llama_stack.apis.inference import (
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
@@ -64,15 +59,14 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
b64_encode_openai_embeddings_response,
get_sampling_options,
prepare_openai_completion_params,
prepare_openai_embeddings_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
@@ -89,6 +83,7 @@ logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
OpenAIMixin,
InferenceProvider,
ModelsProtocolPrivate,
):
@@ -98,23 +93,21 @@ class OllamaInferenceAdapter(
def __init__(self, config: OllamaImplConfig) -> None:
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
self.config = config
self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
self._openai_client = None
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
@property
def client(self) -> AsyncClient:
def ollama_client(self) -> AsyncOllamaClient:
# ollama client attaches itself to the current event loop (sadly?)
loop = asyncio.get_running_loop()
if loop not in self._clients:
self._clients[loop] = AsyncClient(host=self.config.url)
self._clients[loop] = AsyncOllamaClient(host=self.config.url)
return self._clients[loop]
@property
def openai_client(self) -> AsyncOpenAI:
if self._openai_client is None:
url = self.config.url.rstrip("/")
self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
return self._openai_client
def get_api_key(self):
return "NO_KEY"
def get_base_url(self):
return self.config.url.rstrip("/") + "/v1"
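A hedged sketch of what these two hooks imply, assuming OpenAIMixin builds its OpenAI client from get_base_url() and get_api_key() (that wiring is not shown in this diff): Ollama's OpenAI-compatible /v1 endpoint is reached with a placeholder key.

from openai import AsyncOpenAI

# Assumed example: with config.url = "http://localhost:11434", the hooks above
# yield a client equivalent to this one.
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="NO_KEY")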
async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
@@ -129,7 +122,7 @@ class OllamaInferenceAdapter(
async def list_models(self) -> list[Model] | None:
provider_id = self.__provider_id__
response = await self.client.list()
response = await self.ollama_client.list()
# always add the two embedding models which can be pulled on demand
models = [
@@ -189,7 +182,7 @@ class OllamaInferenceAdapter(
HealthResponse: A dictionary containing the health status.
"""
try:
await self.client.ps()
await self.ollama_client.ps()
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@@ -238,7 +231,7 @@ class OllamaInferenceAdapter(
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.generate(**params)
s = await self.ollama_client.generate(**params)
async for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
@@ -254,7 +247,7 @@ class OllamaInferenceAdapter(
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.client.generate(**params)
r = await self.ollama_client.generate(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
@@ -346,9 +339,9 @@ class OllamaInferenceAdapter(
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await self.client.chat(**params)
r = await self.ollama_client.chat(**params)
else:
r = await self.client.generate(**params)
r = await self.ollama_client.generate(**params)
if "message" in r:
choice = OpenAICompatCompletionChoice(
@@ -372,9 +365,9 @@ class OllamaInferenceAdapter(
async def _generate_and_convert_to_openai_compat():
if "messages" in params:
s = await self.client.chat(**params)
s = await self.ollama_client.chat(**params)
else:
s = await self.client.generate(**params)
s = await self.ollama_client.generate(**params)
async for chunk in s:
if "message" in chunk:
choice = OpenAICompatCompletionChoice(
@@ -407,7 +400,7 @@ class OllamaInferenceAdapter(
assert all(not content_has_media(content) for content in contents), (
"Ollama does not support media for embeddings"
)
response = await self.client.embed(
response = await self.ollama_client.embed(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
)
@@ -422,14 +415,14 @@ class OllamaInferenceAdapter(
pass # Ignore statically unknown model, will check live listing
if model.model_type == ModelType.embedding:
response = await self.client.list()
response = await self.ollama_client.list()
if model.provider_resource_id not in [m.model for m in response.models]:
await self.client.pull(model.provider_resource_id)
await self.ollama_client.pull(model.provider_resource_id)
# we use list() here instead of ps() -
# - ps() only lists running models, not available models
# - models not currently running are run by the ollama server as needed
response = await self.client.list()
response = await self.ollama_client.list()
available_models = [m.model for m in response.models]
provider_resource_id = model.provider_resource_id
@@ -448,90 +441,6 @@ class OllamaInferenceAdapter(
return model
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_obj = await self._get_model(model)
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id set")
# Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
params = prepare_openai_embeddings_params(
model=model_obj.provider_resource_id,
input=input,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
)
response = await self.openai_client.embeddings.create(**params)
data = b64_encode_openai_embeddings_response(response.data, encoding_format)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
# TODO: Investigate why model_obj.identifier is used instead of response.model
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.identifier,
usage=usage,
)
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
if not isinstance(prompt, str):
raise ValueError("Ollama does not support non-string prompts for completion")
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
return await self.openai_client.completions.create(**params) # type: ignore
async def openai_chat_completion(
self,
model: str,
@@ -599,25 +508,7 @@ class OllamaInferenceAdapter(
top_p=top_p,
user=user,
)
response = await self.openai_client.chat.completions.create(**params)
return await self._adjust_ollama_chat_completion_response_ids(response)
async def _adjust_ollama_chat_completion_response_ids(
self,
response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
id = f"chatcmpl-{uuid.uuid4()}"
if isinstance(response, AsyncIterator):
async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
async for chunk in response:
chunk.id = id
yield chunk
return stream_with_chunk_ids()
else:
response.id = id
return response
return await OpenAIMixin.openai_chat_completion(self, **params)
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:


@@ -8,6 +8,7 @@
from collections.abc import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from llama_stack.apis.common.content_types import (
InterleavedContent,
@@ -33,6 +34,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -41,16 +43,15 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_model_input_info,
completion_request_to_prompt_model_input_info,
@@ -73,26 +74,49 @@ def build_hf_repo_model_entries():
class _HfAdapter(
OpenAIMixin,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
client: AsyncInferenceClient
url: str
api_key: SecretStr
hf_client: AsyncInferenceClient
max_tokens: int
model_id: str
overwrite_completion_id = True # TGI always returns id=""
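The presumed effect of this flag mirrors the id-rewriting helper removed from the Ollama adapter earlier in this diff: TGI returns id="", so a locally generated id is substituted. A hedged sketch of that pattern (assumed mixin behavior, hypothetical helper name):

import uuid
from collections.abc import AsyncIterator

async def overwrite_chunk_ids(chunks: AsyncIterator) -> AsyncIterator:
    # Give every chunk of one completion the same freshly generated id
    # instead of TGI's empty string.
    new_id = f"chatcmpl-{uuid.uuid4()}"
    async for chunk in chunks:
        chunk.id = new_id
        yield chunk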
def __init__(self) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
}
def get_api_key(self):
return self.api_key.get_secret_value()
def get_base_url(self):
return self.url
async def shutdown(self) -> None:
pass
async def list_models(self) -> list[Model] | None:
models = []
async for model in self.client.models.list():
models.append(
Model(
identifier=model.id,
provider_resource_id=model.id,
provider_id=self.__provider_id__,
metadata={},
model_type=ModelType.llm,
)
)
return models
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
if model.provider_resource_id != self.model_id:
raise ValueError(
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
@@ -176,7 +200,7 @@ class _HfAdapter(
params = await self._get_params_for_completion(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.text_generation(**params)
s = await self.hf_client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
finish_reason = None
@@ -194,7 +218,7 @@ class _HfAdapter(
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_completion(request)
r = await self.client.text_generation(**params)
r = await self.hf_client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
@@ -241,7 +265,7 @@ class _HfAdapter(
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = await self.client.text_generation(**params)
r = await self.hf_client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
@@ -256,7 +280,7 @@ class _HfAdapter(
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.text_generation(**params)
s = await self.hf_client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
@@ -308,18 +332,21 @@ class TGIAdapter(_HfAdapter):
if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.client.get_endpoint_info()
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
self.url = f"{config.url.rstrip('/')}/v1"
self.api_key = SecretStr("NO_KEY")
class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
endpoint_info = await self.client.get_endpoint_info()
self.hf_client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
# TODO: how do we set url for this?
class InferenceEndpointAdapter(_HfAdapter):
@@ -331,6 +358,7 @@ class InferenceEndpointAdapter(_HfAdapter):
endpoint.wait(timeout=60)
# Initialize the adapter
self.client = endpoint.async_client
self.hf_client = endpoint.async_client
self.model_id = endpoint.repository
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
# TODO: how do we set url for this?


@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
@@ -21,57 +20,84 @@ SAFETY_MODELS_ENTRIES = [
CoreModelId.llama_guard_3_11b_vision.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
ProviderModelEntry(
# source: https://docs.together.ai/docs/serverless-models#embedding-models
EMBEDDING_MODEL_ENTRIES = {
"togethercomputer/m2-bert-80M-32k-retrieval": ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-32k-retrieval",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 32768,
},
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
"BAAI/bge-large-en-v1.5": ProviderModelEntry(
provider_model_id="BAAI/bge-large-en-v1.5",
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
"BAAI/bge-base-en-v1.5": ProviderModelEntry(
provider_model_id="BAAI/bge-base-en-v1.5",
metadata={
"embedding_dimension": 768,
"context_length": 512,
},
),
] + SAFETY_MODELS_ENTRIES
"Alibaba-NLP/gte-modernbert-base": ProviderModelEntry(
provider_model_id="Alibaba-NLP/gte-modernbert-base",
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
"intfloat/multilingual-e5-large-instruct": ProviderModelEntry(
provider_model_id="intfloat/multilingual-e5-large-instruct",
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
}
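For reference, a quick lookup against the entries above (a minimal illustrative snippet; values taken from the table):

entry = EMBEDDING_MODEL_ENTRIES["BAAI/bge-large-en-v1.5"]
assert entry.metadata["embedding_dimension"] == 1024
assert entry.metadata["context_length"] == 512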
MODEL_ENTRIES = (
[
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]
+ SAFETY_MODELS_ENTRIES
+ list(EMBEDDING_MODEL_ENTRIES.values())
)


@@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from collections.abc import AsyncGenerator
from openai import AsyncOpenAI
from openai import NOT_GIVEN, AsyncOpenAI
from together import AsyncTogether
from together.constants import BASE_URL
from llama_stack.apis.common.content_types import (
InterleavedContent,
@@ -23,12 +23,7 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
ResponseFormatType,
SamplingParams,
@@ -38,18 +33,20 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
from llama_stack.apis.models import Model, ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
@@ -59,15 +56,22 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES
from .models import EMBEDDING_MODEL_ENTRIES, MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config
self._model_cache: dict[str, Model] = {}
def get_api_key(self):
return self.config.api_key.get_secret_value()
def get_base_url(self):
return BASE_URL
async def initialize(self) -> None:
pass
@@ -255,6 +259,37 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
embeddings = [item.embedding for item in r.data]
return EmbeddingsResponse(embeddings=embeddings)
async def list_models(self) -> list[Model] | None:
self._model_cache = {}
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
for m in await self._get_client().models.list():
if m.type == "embedding":
if m.id not in EMBEDDING_MODEL_ENTRIES:
logger.warning(f"Unknown embedding dimension for model {m.id}, skipping.")
continue
self._model_cache[m.id] = Model(
provider_id=self.__provider_id__,
provider_resource_id=EMBEDDING_MODEL_ENTRIES[m.id].provider_model_id,
identifier=m.id,
model_type=ModelType.embedding,
metadata=EMBEDDING_MODEL_ENTRIES[m.id].metadata,
)
else:
self._model_cache[m.id] = Model(
provider_id=self.__provider_id__,
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.llm,
)
return self._model_cache.values()
async def should_refresh_models(self) -> bool:
return True
async def check_model_availability(self, model):
return model in self._model_cache
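A hedged usage sketch of the cache flow above: check_model_availability() consults the cache that list_models() populates, so a listing has to happen first (the adapter instance below is assumed to be already initialized).

async def demo(adapter) -> None:
    await adapter.list_models()  # populates adapter._model_cache
    print(await adapter.check_model_availability("togethercomputer/m2-bert-80M-32k-retrieval"))

# asyncio.run(demo(adapter))  # with an initialized TogetherInferenceAdapter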
async def openai_embeddings(
self,
model: str,
@@ -263,125 +298,39 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
"""
Together's OpenAI-compatible embeddings endpoint is not compatible with
the standard OpenAI embeddings endpoint.
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
The endpoint -
- does not return usage information
- does not support user param, returns 400 Unrecognized request arguments supplied: user
- does not support dimensions param, returns 400 Unrecognized request arguments supplied: dimensions
- does not support encoding_format param, always returns floats, never base64
"""
# Together support ticket #13332 -> will not fix
if user is not None:
raise ValueError("Together's embeddings endpoint does not support user param.")
# Together support ticket #13333 -> escalated
if dimensions is not None:
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
# Together support ticket #13331 -> will not fix, compute client side
if encoding_format not in (None, NOT_GIVEN, "float"):
raise ValueError("Together's embeddings endpoint only supports encoding_format='float'.")
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
)
return await self._get_openai_client().completions.create(**params) # type: ignore
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
if params.get("stream", False):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
response.model = model # return the user the same model id they provided, avoid exposing the provider model id
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
# together.ai sometimes adds usage data to the stream, even if include_usage is False
# This causes an unexpected final chunk with empty choices array to be sent
# to clients that may not handle it gracefully.
include_usage = False
if params.get("stream_options", None):
include_usage = params["stream_options"].get("include_usage", False)
stream = await self._get_openai_client().chat.completions.create(**params)
# Together support ticket #13330 -> escalated
# - togethercomputer/m2-bert-80M-32k-retrieval *does not* return usage information
if not hasattr(response, "usage") or response.usage is None:
logger.warning(
f"Together's embedding endpoint for {model} did not return usage information, substituting -1s."
)
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
seen_finish_reason = False
async for chunk in stream:
# Final usage chunk with no choices that the user didn't request, so discard
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
break
yield chunk
for choice in chunk.choices:
if choice.finish_reason:
seen_finish_reason = True
break
return response
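A hedged usage sketch of the constraints documented above (hypothetical call; the adapter instance and model id are assumed examples):

async def embed_demo(adapter) -> None:
    # dimensions, user, and any encoding_format other than "float" are rejected up front.
    resp = await adapter.openai_embeddings(
        model="togethercomputer/m2-bert-80M-32k-retrieval",
        input=["hello world"],
    )
    # resp.model echoes the id passed in; usage is -1/-1 when Together omits it.
    print(resp.model, resp.usage.prompt_tokens, resp.usage.total_tokens)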


@@ -4,9 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import VLLMInferenceAdapterConfig
class VLLMProviderDataValidator(BaseModel):
vllm_api_token: str | None = None
async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
from .vllm import VLLMInferenceAdapter


@@ -4,8 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from urllib.parse import urljoin
import httpx
from openai import APIConnectionError, AsyncOpenAI
@@ -55,6 +56,7 @@ from llama_stack.providers.datatypes import (
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
@@ -62,6 +64,7 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
UnparseableToolCall,
convert_message_to_openai_dict,
convert_openai_chat_completion_stream,
convert_tool_call,
get_sampling_options,
process_chat_completion_stream_response,
@@ -281,15 +284,31 @@ async def _process_vllm_chat_completion_stream_response(
yield c
class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
build_hf_repo_model_entries(),
litellm_provider_name="vllm",
api_key_from_config=config.api_token,
provider_data_api_key_field="vllm_api_token",
openai_compat_api_base=config.url,
)
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""Get the base URL from config."""
if not self.config.url:
raise ValueError("No base URL configured")
return self.config.url
async def initialize(self) -> None:
if not self.config.url:
raise ValueError(
@@ -297,6 +316,7 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
)
async def should_refresh_models(self) -> bool:
# Strictly respecting the refresh_models directive
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
@@ -325,13 +345,19 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
Performs a health check by verifying connectivity to the remote vLLM server.
This method is used by the Provider API to verify
that the service is running correctly.
Uses the unauthenticated /health endpoint.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
_ = [m async for m in self.client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK)
base_url = self.get_base_url()
health_url = urljoin(base_url, "health")
async with httpx.AsyncClient() as client:
response = await client.get(health_url)
response.raise_for_status()
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
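One subtlety worth noting: urljoin() resolves "health" against the directory of the base path, so a typical vLLM base URL ending in /v1 (no trailing slash) maps to the server-root /health endpoint. A quick illustration with assumed example URLs:

from urllib.parse import urljoin

print(urljoin("http://localhost:8000/v1", "health"))   # http://localhost:8000/health
print(urljoin("http://localhost:8000/v1/", "health"))  # http://localhost:8000/v1/health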
@@ -340,16 +366,10 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
def get_api_key(self):
return self.config.api_token
def get_base_url(self):
return self.config.url
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
async def completion(
async def completion( # type: ignore[override] # Return type more specific than base class, which allows for both streaming and non-streaming responses.
self,
model_id: str,
content: InterleavedContent,
@@ -411,13 +431,14 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request, self.client)
return self._stream_chat_completion_with_client(request, self.client)
else:
return await self._nonstream_chat_completion(request, self.client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> ChatCompletionResponse:
assert self.client is not None
params = await self._get_params(request)
r = await client.chat.completions.create(**params)
choice = r.choices[0]
@@ -431,9 +452,24 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
)
return result
async def _stream_chat_completion(
async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
# This method is called from LiteLLMOpenAIMixin.chat_completion
# The response parameter contains the litellm response
# We need to convert it to our format
async def _stream_generator():
async for chunk in response:
yield chunk
async for chunk in convert_openai_chat_completion_stream(
_stream_generator(), enable_incremental_tool_calls=True
):
yield chunk
async def _stream_chat_completion_with_client(
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
"""Helper method for streaming with explicit client parameter."""
assert self.client is not None
params = await self._get_params(request)
stream = await client.chat.completions.create(**params)
@@ -445,7 +481,8 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
assert self.client is not None
if self.client is None:
raise RuntimeError("Client is not initialized")
params = await self._get_params(request)
r = await self.client.completions.create(**params)
return process_completion_response(r)
@@ -453,7 +490,8 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
async def _stream_completion(
self, request: CompletionRequest
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
assert self.client is not None
if self.client is None:
raise RuntimeError("Client is not initialized")
params = await self._get_params(request)
stream = await self.client.completions.create(**params)


@@ -26,11 +26,11 @@ class WatsonXConfig(BaseModel):
)
api_key: SecretStr | None = Field(
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
description="The watsonx API key, only needed of using the hosted service",
description="The watsonx API key",
)
project_id: str | None = Field(
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
description="The Project ID key, only needed of using the hosted service",
description="The Project ID key",
)
timeout: int = Field(
default=60,


@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
@@ -57,14 +58,29 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from . import WatsonXConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::watsonx")
# Note on structured output
# WatsonX returns responses with JSON embedded in a string.
# Examples:
# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
# "first_name": "Michael",\n "last_name": "Jordan",\n'...)
# Not even valid JSON, but we can still extract the JSON from the content.
# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
# "year_born": "1963", "year_retired": "2003"\\}}$')
# Find the start of the boxed content
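A minimal sketch of one way such wrapped JSON could be recovered; this is a hypothetical helper for illustration, not code from this file:

import json
import re

def extract_json(content: str) -> dict | None:
    """Hypothetical helper: recover a flat JSON object from fenced or boxed text."""
    # WatsonX may escape braces inside LaTeX (\{ ... \}); undo that first.
    text = content.replace("\\{", "{").replace("\\}", "}")
    # Return the first brace-delimited candidate (no nested braces) that parses.
    for candidate in re.findall(r"\{[^{}]*\}", text):
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue
    return None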
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: WatsonXConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
print(f"Initializing watsonx InferenceAdapter({config.url})...")
logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
self._config = config
self._openai_client: AsyncOpenAI | None = None
self._project_id = self._config.project_id


@@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import hashlib
import uuid
from typing import Any
@@ -49,10 +50,13 @@ def convert_id(_id: str) -> str:
Converts any string into a UUID string based on a seed.
Qdrant accepts UUID strings and unsigned integers as point ID.
We use a seed to convert each string into a UUID string deterministically.
We use a SHA-256 hash to convert each string into a UUID string deterministically.
This allows us to overwrite the same point with the original ID.
"""
return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))
hash_input = f"qdrant_id:{_id}".encode()
sha256_hash = hashlib.sha256(hash_input).hexdigest()
# Use the first 32 characters to create a valid UUID
return str(uuid.UUID(sha256_hash[:32]))
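Because the mapping is a pure hash of the input, re-upserting a chunk under the same original ID always yields the same Qdrant point ID, which is what makes overwrites possible. A quick illustrative check:

# Deterministic: the same input always maps to the same UUID-shaped point ID.
assert convert_id("doc-1:chunk-0") == convert_id("doc-1:chunk-0")
# Distinct inputs map to distinct point IDs (up to hash collisions).
assert convert_id("doc-1:chunk-0") != convert_id("doc-1:chunk-1")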
class QdrantIndex(EmbeddingIndex):