Merge branch 'main' into feat/litellm_sambanova_usage

This commit is contained in:
jhpiedrahitao 2025-05-05 11:49:58 -05:00
commit b7f16ac7a6
535 changed files with 23539 additions and 8112 deletions

View file

@ -4,15 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel
from .config import AnthropicConfig
class AnthropicProviderDataValidator(BaseModel):
anthropic_api_key: Optional[str] = None
anthropic_api_key: str | None = None
async def get_adapter_impl(config: AnthropicConfig, _deps):

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type
class AnthropicProviderDataValidator(BaseModel):
anthropic_api_key: Optional[str] = Field(
anthropic_api_key: str | None = Field(
default=None,
description="API key for Anthropic models",
)
@ -20,13 +20,13 @@ class AnthropicProviderDataValidator(BaseModel):
@json_schema_type
class AnthropicConfig(BaseModel):
api_key: Optional[str] = Field(
api_key: str | None = Field(
default=None,
description="API key for Anthropic models",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"api_key": api_key,
}

View file

@ -1,18 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import BedrockConfig
async def get_adapter_impl(config: BedrockConfig, _deps):
from .bedrock import BedrockInferenceAdapter
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
impl = BedrockInferenceAdapter(config)
await impl.initialize()
return impl
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import BedrockConfig
async def get_adapter_impl(config: BedrockConfig, _deps):
from .bedrock import BedrockInferenceAdapter
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
impl = BedrockInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from collections.abc import AsyncGenerator, AsyncIterator
from botocore.client import BaseClient
@ -79,26 +79,26 @@ class BedrockInferenceAdapter(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
@ -151,7 +151,7 @@ class BedrockInferenceAdapter(
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
bedrock_model = request.model
sampling_params = request.sampling_params
@ -176,10 +176,10 @@ class BedrockInferenceAdapter(
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embeddings = []

View file

@ -1,11 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
class BedrockConfig(BedrockBaseConfig):
pass
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
class BedrockConfig(BedrockBaseConfig):
pass

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional, Union
from collections.abc import AsyncGenerator
from cerebras.cloud.sdk import AsyncCerebras
@ -79,10 +79,10 @@ class CerebrasInferenceAdapter(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -120,15 +120,15 @@ class CerebrasInferenceAdapter(
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -166,7 +166,7 @@ class CerebrasInferenceAdapter(
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
raise ValueError("`top_k` not supported by Cerebras")
@ -188,9 +188,9 @@ class CerebrasInferenceAdapter(
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field, SecretStr
@ -20,13 +20,13 @@ class CerebrasImplConfig(BaseModel):
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API",
)
api_key: Optional[SecretStr] = Field(
api_key: SecretStr | None = Field(
default=os.environ.get("CEREBRAS_API_KEY"),
description="Cerebras API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"base_url": DEFAULT_BASE_URL,
"api_key": "${env.CEREBRAS_API_KEY}",

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type
class CerebrasProviderDataValidator(BaseModel):
cerebras_api_key: Optional[str] = Field(
cerebras_api_key: str | None = Field(
default=None,
description="API key for Cerebras models",
)
@ -20,7 +20,7 @@ class CerebrasProviderDataValidator(BaseModel):
@json_schema_type
class CerebrasCompatConfig(BaseModel):
api_key: Optional[str] = Field(
api_key: str | None = Field(
default=None,
description="The Cerebras API key",
)
@ -31,7 +31,7 @@ class CerebrasCompatConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.cerebras.ai/v1",
"api_key": api_key,

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel, Field
@ -28,7 +28,7 @@ class DatabricksImplConfig(BaseModel):
url: str = "${env.DATABRICKS_URL}",
api_token: str = "${env.DATABRICKS_API_TOKEN}",
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
return {
"url": url,
"api_token": api_token,

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional
from collections.abc import AsyncGenerator
from openai import OpenAI
@ -78,25 +78,25 @@ class DatabricksInferenceAdapter(
self,
model: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -146,9 +146,9 @@ class DatabricksInferenceAdapter(
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field, SecretStr
@ -17,13 +17,13 @@ class FireworksImplConfig(BaseModel):
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server",
)
api_key: Optional[SecretStr] = Field(
api_key: SecretStr | None = Field(
default=None,
description="The Fireworks.ai API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.fireworks.ai/inference/v1",
"api_key": api_key,

View file

@ -4,7 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from fireworks.client import Fireworks
from openai import AsyncOpenAI
@ -105,10 +106,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -146,9 +147,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
def _build_options(
self,
sampling_params: Optional[SamplingParams],
sampling_params: SamplingParams | None,
fmt: ResponseFormat,
logprobs: Optional[LogProbConfig],
logprobs: LogProbConfig | None,
) -> dict:
options = get_sampling_options(sampling_params)
options.setdefault("max_tokens", 512)
@ -177,15 +178,15 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -229,7 +230,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
input_dict = {}
media_present = request_has_media(request)
@ -263,10 +264,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
@ -288,24 +289,24 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
@ -338,30 +339,63 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
# Divert Llama Models through Llama Stack inference APIs because
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
self,
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
params = await prepare_openai_completion_params(
messages=messages,
frequency_penalty=frequency_penalty,
@ -387,11 +421,4 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
user=user,
)
# Divert Llama Models through Llama Stack inference APIs because
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type
class FireworksProviderDataValidator(BaseModel):
fireworks_api_key: Optional[str] = Field(
fireworks_api_key: str | None = Field(
default=None,
description="API key for Fireworks models",
)
@ -20,7 +20,7 @@ class FireworksProviderDataValidator(BaseModel):
@json_schema_type
class FireworksCompatConfig(BaseModel):
api_key: Optional[str] = Field(
api_key: str | None = Field(
default=None,
description="The Fireworks API key",
)
@ -31,7 +31,7 @@ class FireworksCompatConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.fireworks.ai/inference/v1",
"api_key": api_key,

View file

@ -4,15 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel
from .config import GeminiConfig
class GeminiProviderDataValidator(BaseModel):
gemini_api_key: Optional[str] = None
gemini_api_key: str | None = None
async def get_adapter_impl(config: GeminiConfig, _deps):

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type
class GeminiProviderDataValidator(BaseModel):
gemini_api_key: Optional[str] = Field(
gemini_api_key: str | None = Field(
default=None,
description="API key for Gemini models",
)
@ -20,13 +20,13 @@ class GeminiProviderDataValidator(BaseModel):
@json_schema_type
class GeminiConfig(BaseModel):
api_key: Optional[str] = Field(
api_key: str | None = Field(
default=None,
description="API key for Gemini models",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"api_key": api_key,
}

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type
class GroqProviderDataValidator(BaseModel):
groq_api_key: Optional[str] = Field(
groq_api_key: str | None = Field(
default=None,
description="API key for Groq models",
)
@ -20,7 +20,7 @@ class GroqProviderDataValidator(BaseModel):
@json_schema_type
class GroqConfig(BaseModel):
api_key: Optional[str] = Field(
api_key: str | None = Field(
# The Groq client library loads the GROQ_API_KEY environment variable by default
default=None,
description="The Groq API key",
@ -32,7 +32,7 @@ class GroqConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.groq.com",
"api_key": api_key,

View file

@ -4,7 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from collections.abc import AsyncIterator
from typing import Any
from openai import AsyncOpenAI
@ -59,29 +60,29 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
# Groq does not support json_schema response format, so we need to convert it to json_object

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type
class GroqProviderDataValidator(BaseModel):
groq_api_key: Optional[str] = Field(
groq_api_key: str | None = Field(
default=None,
description="API key for Groq models",
)
@ -20,7 +20,7 @@ class GroqProviderDataValidator(BaseModel):
@json_schema_type
class GroqCompatConfig(BaseModel):
api_key: Optional[str] = Field(
api_key: str | None = Field(
default=None,
description="The Groq API key",
)
@ -31,7 +31,7 @@ class GroqCompatConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.groq.com/openai/v1",
"api_key": api_key,

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import LlamaCompatConfig
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference:
# import dynamically so the import is used only when it is needed
from .llama import LlamaCompatInferenceAdapter
adapter = LlamaCompatInferenceAdapter(config)
return adapter

View file

@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class LlamaProviderDataValidator(BaseModel):
llama_api_key: str | None = Field(
default=None,
description="API key for api.llama models",
)
@json_schema_type
class LlamaCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Llama API key",
)
openai_compat_api_base: str = Field(
default="https://api.llama.com/compat/v1/",
description="The URL for the Llama API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.llama.com/compat/v1/",
"api_key": api_key,
}

View file

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.llama_openai_compat.config import (
LlamaCompatConfig,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from .models import MODEL_ENTRIES
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: LlamaCompatConfig
def __init__(self, config: LlamaCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="llama_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"Llama-4-Scout-17B-16E-Instruct-FP8",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]

View file

@ -0,0 +1,85 @@
# NVIDIA Inference Provider for LlamaStack
This provider enables running inference using NVIDIA NIM.
## Features
- Endpoints for completions, chat completions, and embeddings for registered models
## Getting Started
### Prerequisites
- LlamaStack with NVIDIA configuration
- Access to NVIDIA NIM deployment
- NIM for model to use for inference is deployed
### Setup
Build the NVIDIA environment:
```bash
llama stack build --template nvidia --image-type conda
```
### Basic Usage using the LlamaStack Python Client
#### Initialize the client
```python
import os
os.environ["NVIDIA_API_KEY"] = (
"" # Required if using hosted NIM endpoint. If self-hosted, not required.
)
os.environ["NVIDIA_BASE_URL"] = "http://nim.test" # NIM URL
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```
### Create Completion
```python
response = client.completion(
model_id="meta-llama/Llama-3.1-8b-Instruct",
content="Complete the sentence using one word: Roses are red, violets are :",
stream=False,
sampling_params={
"max_tokens": 50,
},
)
print(f"Response: {response.content}")
```
### Create Chat Completion
```python
response = client.chat_completion(
model_id="meta-llama/Llama-3.1-8b-Instruct",
messages=[
{
"role": "system",
"content": "You must respond to each message with only one word",
},
{
"role": "user",
"content": "Complete the sentence using one word: Roses are red, violets are:",
},
],
stream=False,
sampling_params={
"max_tokens": 50,
},
)
print(f"Response: {response.completion_message.content}")
```
### Create Embeddings
```python
response = client.embeddings(
model_id="meta-llama/Llama-3.1-8b-Instruct", contents=["foo", "bar", "baz"]
)
print(f"Embeddings: {response.embeddings}")
```

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field, SecretStr
@ -39,7 +39,7 @@ class NVIDIAConfig(BaseModel):
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
description="A base url for accessing the NVIDIA NIM",
)
api_key: Optional[SecretStr] = Field(
api_key: SecretStr | None = Field(
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
description="The NVIDIA API key, only needed of using the hosted service",
)
@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel):
default=60,
description="Timeout for the HTTP requests",
)
append_api_version: bool = Field(
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
"api_key": "${env.NVIDIA_API_KEY:}",
"append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
}

View file

@ -48,6 +48,10 @@ MODEL_ENTRIES = [
"meta/llama-3.2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
# NeMo Retriever Text Embedding models -
#
# https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html

View file

@ -6,8 +6,9 @@
import logging
import warnings
from collections.abc import AsyncIterator
from functools import lru_cache
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from typing import Any
from openai import APIConnectionError, AsyncOpenAI, BadRequestError
@ -33,7 +34,6 @@ from llama_stack.apis.inference import (
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
@ -42,7 +42,11 @@ from llama_stack.apis.inference.inference import (
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -120,21 +124,29 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
}
base_url = f"{self._config.url}/v1"
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
base_url = special_model_urls[provider_model_id]
return _get_client_for_base_url(base_url)
async def _get_provider_model_id(self, model_id: str) -> str:
if not self.model_store:
raise RuntimeError("Model store is not set")
model = await self.model_store.get_model(model_id)
if model is None:
raise ValueError(f"Model {model_id} is unknown")
return model.provider_model_id
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
if content_has_media(content):
@ -144,7 +156,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# removing this health check as NeMo customizer endpoint health check is returning 404
# await check_health(self._config) # this raises errors
provider_model_id = self.get_provider_model_id(model_id)
provider_model_id = await self._get_provider_model_id(model_id)
request = convert_completion_request(
request=CompletionRequest(
model=provider_model_id,
@ -171,24 +183,24 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
if any(content_has_media(content) for content in contents):
raise NotImplementedError("Media is not supported")
#
# Llama Stack: contents = List[str] | List[InterleavedContentItem]
# Llama Stack: contents = list[str] | list[InterleavedContentItem]
# ->
# OpenAI: input = str | List[str]
# OpenAI: input = str | list[str]
#
# we can ignore str and always pass List[str] to OpenAI
# we can ignore str and always pass list[str] to OpenAI
#
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
model = self.get_provider_model_id(model_id)
provider_model_id = await self._get_provider_model_id(model_id)
extra_body = {}
@ -211,8 +223,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self._get_client(model).embeddings.create(
model=model,
response = await self._get_client(provider_model_id).embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
)
@ -220,25 +232,25 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
raise ValueError(f"Failed to get embeddings: {e}") from e
#
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
# ->
# Llama Stack: EmbeddingsResponse(embeddings=List[List[float]])
# Llama Stack: EmbeddingsResponse(embeddings=list[list[float]])
#
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
if tool_prompt_format:
@ -246,10 +258,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# await check_health(self._config) # this raises errors
provider_model_id = self.get_provider_model_id(model_id)
provider_model_id = await self._get_provider_model_id(model_id)
request = await convert_chat_completion_request(
request=ChatCompletionRequest(
model=self.get_provider_model_id(model_id),
model=provider_model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
@ -275,26 +287,26 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
provider_model_id = self.get_provider_model_id(model)
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
@ -324,30 +336,30 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
provider_model_id = self.get_provider_model_id(model)
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
@ -379,3 +391,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
return await self._get_client(provider_model_id).chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
async def register_model(self, model: Model) -> Model:
"""
Allow non-llama model registration.
Non-llama model registration: API Catalogue models, post-training models, etc.
client = LlamaStackAsLibraryClient("nvidia")
client.models.register(
model_id="mistralai/mixtral-8x7b-instruct-v0.1",
model_type=ModelType.llm,
provider_id="nvidia",
provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
)
NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
"""
if model.model_type == ModelType.embedding:
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
else:
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
if provider_resource_id:
model.provider_resource_id = provider_resource_id
else:
llama_model = model.metadata.get("llama_model")
existing_llama_model = self.get_llama_model(model.provider_resource_id)
if existing_llama_model:
if existing_llama_model != llama_model:
raise ValueError(
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
)
else:
# not llama model
if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
)
else:
self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
return model

View file

@ -5,7 +5,8 @@
# the root directory of this source tree.
import warnings
from typing import Any, AsyncGenerator, Dict, List, Optional
from collections.abc import AsyncGenerator
from typing import Any
from openai import AsyncStream
from openai.types.chat.chat_completion import (
@ -64,7 +65,7 @@ async def convert_chat_completion_request(
)
nvext = {}
payload: Dict[str, Any] = dict(
payload: dict[str, Any] = dict(
model=request.model,
messages=[await convert_message_to_openai_dict_new(message) for message in request.messages],
stream=request.stream,
@ -137,7 +138,7 @@ def convert_completion_request(
# logprobs.top_k -> logprobs
nvext = {}
payload: Dict[str, Any] = dict(
payload: dict[str, Any] = dict(
model=request.model,
prompt=request.content,
stream=request.stream,
@ -176,8 +177,8 @@ def convert_completion_request(
def _convert_openai_completion_logprobs(
logprobs: Optional[OpenAICompletionLogprobs],
) -> Optional[List[TokenLogProbs]]:
logprobs: OpenAICompletionLogprobs | None,
) -> list[TokenLogProbs] | None:
"""
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
"""

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import logging
from typing import Tuple
import httpx
@ -18,7 +17,7 @@ def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
return "integrate.api.nvidia.com" in config.url
async def _get_health(url: str) -> Tuple[bool, bool]:
async def _get_health(url: str) -> tuple[bool, bool]:
"""
Query {url}/v1/health/{live,ready} to check if the server is running and ready

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
@ -15,5 +15,5 @@ class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs) -> dict[str, Any]:
return {"url": url}

View file

@ -5,10 +5,11 @@
# the root directory of this source tree.
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
import httpx
from ollama import AsyncClient
from ollama import AsyncClient # type: ignore[attr-defined]
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import (
@ -130,10 +131,10 @@ class OllamaInferenceAdapter(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
if sampling_params is None:
sampling_params = SamplingParams()
@ -188,15 +189,15 @@ class OllamaInferenceAdapter(
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
if sampling_params is None:
sampling_params = SamplingParams()
@ -216,7 +217,7 @@ class OllamaInferenceAdapter(
else:
return await self._nonstream_chat_completion(request)
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens.
@ -314,10 +315,10 @@ class OllamaInferenceAdapter(
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self._get_model(model_id)
@ -333,7 +334,10 @@ class OllamaInferenceAdapter(
return EmbeddingsResponse(embeddings=embeddings)
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
try:
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
if model.model_type == ModelType.embedding:
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
await self.client.pull(model.provider_resource_id)
@ -342,9 +346,12 @@ class OllamaInferenceAdapter(
# - models not currently running are run by the ollama server as needed
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id not in available_models:
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
if provider_resource_id is None:
provider_resource_id = model.provider_resource_id
if provider_resource_id not in available_models:
available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
if model.provider_resource_id in available_models_latest:
if provider_resource_id in available_models_latest:
logger.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
)
@ -352,30 +359,31 @@ class OllamaInferenceAdapter(
raise ValueError(
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
)
model.provider_resource_id = provider_resource_id
return model
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
if not isinstance(prompt, str):
raise ValueError("Ollama does not support non-string prompts for completion")
@ -409,30 +417,36 @@ class OllamaInferenceAdapter(
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self._get_model(model)
# ollama still makes tool calls even when tool_choice is "none"
# so we need to remove the tools in that case
if tool_choice == "none" and tools is not None:
tools = None
params = {
k: v
for k, v in {
@ -467,27 +481,27 @@ class OllamaInferenceAdapter(
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
async def _convert_content(content) -> dict:
if isinstance(content, ImageContentItem):
return {

View file

@ -4,15 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel
from .config import OpenAIConfig
class OpenAIProviderDataValidator(BaseModel):
openai_api_key: Optional[str] = None
openai_api_key: str | None = None
async def get_adapter_impl(config: OpenAIConfig, _deps):

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type
class OpenAIProviderDataValidator(BaseModel):
openai_api_key: Optional[str] = Field(
openai_api_key: str | None = Field(
default=None,
description="API key for OpenAI models",
)
@ -20,13 +20,13 @@ class OpenAIProviderDataValidator(BaseModel):
@json_schema_type
class OpenAIConfig(BaseModel):
api_key: Optional[str] = Field(
api_key: str | None = Field(
default=None,
description="API key for OpenAI models",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"api_key": api_key,
}

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field, SecretStr
@ -18,13 +18,13 @@ class PassthroughImplConfig(BaseModel):
description="The URL for the passthrough endpoint",
)
api_key: Optional[SecretStr] = Field(
api_key: SecretStr | None = Field(
default=None,
description="API Key for the passthrouth endpoint",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "${env.PASSTHROUGH_URL}",
"api_key": "${env.PASSTHROUGH_API_KEY}",

View file

@ -4,7 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from llama_stack_client import AsyncLlamaStackClient
@ -93,10 +94,10 @@ class PassthroughInferenceAdapter(Inference):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -123,15 +124,15 @@ class PassthroughInferenceAdapter(Inference):
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -165,7 +166,7 @@ class PassthroughInferenceAdapter(Inference):
else:
return await self._nonstream_chat_completion(json_params)
async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse:
async def _nonstream_chat_completion(self, json_params: dict[str, Any]) -> ChatCompletionResponse:
client = self._get_client()
response = await client.inference.chat_completion(**json_params)
@ -178,7 +179,7 @@ class PassthroughInferenceAdapter(Inference):
logprobs=response.logprobs,
)
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
async def _stream_chat_completion(self, json_params: dict[str, Any]) -> AsyncGenerator:
client = self._get_client()
stream_response = await client.inference.chat_completion(**json_params)
@ -193,10 +194,10 @@ class PassthroughInferenceAdapter(Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[InterleavedContent],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
client = self._get_client()
model = await self.model_store.get_model(model_id)
@ -212,24 +213,24 @@ class PassthroughInferenceAdapter(Inference):
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
@ -261,29 +262,29 @@ class PassthroughInferenceAdapter(Inference):
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
@ -315,7 +316,7 @@ class PassthroughInferenceAdapter(Inference):
return await client.inference.openai_chat_completion(**params)
def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
json_params = {}
for key, value in request_params.items():
json_input = convert_pydantic_to_json_value(value)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -13,17 +13,17 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class RunpodImplConfig(BaseModel):
url: Optional[str] = Field(
url: str | None = Field(
default=None,
description="The URL for the Runpod model serving endpoint",
)
api_token: Optional[str] = Field(
api_token: str | None = Field(
default=None,
description="The API token",
)
@classmethod
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"url": "${env.RUNPOD_URL:}",
"api_token": "${env.RUNPOD_API_TOKEN:}",

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from collections.abc import AsyncGenerator
from openai import OpenAI

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field, SecretStr
@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: Optional[str] = Field(
sambanova_api_key: str | None = Field(
default=None,
description="Sambanova Cloud API key",
)
@ -24,13 +24,13 @@ class SambaNovaImplConfig(BaseModel):
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server",
)
api_key: Optional[SecretStr] = Field(
api_key: SecretStr | None = Field(
default=None,
description="The SambaNova cloud API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"api_key": api_key,

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
from typing import Dict, Iterable, List, Union
from collections.abc import Iterable
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
@ -72,7 +72,7 @@ logger = get_logger(name=__name__, category="inference")
async def convert_message_to_openai_dict_with_b64_images(
message: Message | Dict,
message: Message | dict,
) -> OpenAIChatCompletionMessage:
"""
Convert a Message to an OpenAI API-compatible dictionary.
@ -101,10 +101,10 @@ async def convert_message_to_openai_dict_with_b64_images(
# List[...] -> List[...]
async def _convert_message_content(
content: InterleavedContent,
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
) -> str | Iterable[OpenAIChatCompletionContentPartParam]:
async def impl(
content_: InterleavedContent,
) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]:
# Llama Stack and OpenAI spec match for str and text input
if isinstance(content_, str):
return content_

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: Optional[str] = Field(
sambanova_api_key: str | None = Field(
default=None,
description="API key for SambaNova models",
)
@ -20,7 +20,7 @@ class SambaNovaProviderDataValidator(BaseModel):
@json_schema_type
class SambaNovaCompatConfig(BaseModel):
api_key: Optional[str] = Field(
api_key: str | None = Field(
default=None,
description="The SambaNova API key",
)
@ -31,7 +31,7 @@ class SambaNovaCompatConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.sambanova.ai/v1",
"api_key": api_key,

View file

@ -4,13 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Union
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
async def get_adapter_impl(
config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
config: InferenceAPIImplConfig | InferenceEndpointImplConfig | TGIImplConfig,
_deps,
):
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel, Field, SecretStr
@ -29,7 +28,7 @@ class InferenceEndpointImplConfig(BaseModel):
endpoint_name: str = Field(
description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
)
api_token: Optional[SecretStr] = Field(
api_token: SecretStr | None = Field(
default=None,
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
)
@ -52,7 +51,7 @@ class InferenceAPIImplConfig(BaseModel):
huggingface_repo: str = Field(
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
)
api_token: Optional[SecretStr] = Field(
api_token: SecretStr | None = Field(
default=None,
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
)

View file

@ -6,7 +6,7 @@
import logging
from typing import AsyncGenerator, List, Optional
from collections.abc import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi
@ -105,10 +105,10 @@ class _HfAdapter(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -134,7 +134,7 @@ class _HfAdapter(
def _build_options(
self,
sampling_params: Optional[SamplingParams] = None,
sampling_params: SamplingParams | None = None,
fmt: ResponseFormat = None,
):
options = get_sampling_options(sampling_params)
@ -209,15 +209,15 @@ class _HfAdapter(
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -284,10 +284,10 @@ class _HfAdapter(
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field, SecretStr
@ -17,13 +17,13 @@ class TogetherImplConfig(BaseModel):
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",
)
api_key: Optional[SecretStr] = Field(
api_key: SecretStr | None = Field(
default=None,
description="The Together AI API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY:}",

View file

@ -4,7 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from openai import AsyncOpenAI
from together import AsyncTogether
@ -76,17 +77,20 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def shutdown(self) -> None:
if self._client:
await self._client.close()
# Together client has no close method, so just set to None
self._client = None
if self._openai_client:
await self._openai_client.close()
self._openai_client = None
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -144,8 +148,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
def _build_options(
self,
sampling_params: Optional[SamplingParams],
logprobs: Optional[LogProbConfig],
sampling_params: SamplingParams | None,
logprobs: LogProbConfig | None,
fmt: ResponseFormat,
) -> dict:
options = get_sampling_options(sampling_params)
@ -172,15 +176,15 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -221,7 +225,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
@ -246,10 +250,10 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert all(not content_has_media(content) for content in contents), (
@ -266,24 +270,24 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
@ -310,29 +314,29 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
@ -359,7 +363,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
top_p=top_p,
user=user,
)
if params.get("stream", True):
if params.get("stream", False):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type
class TogetherProviderDataValidator(BaseModel):
together_api_key: Optional[str] = Field(
together_api_key: str | None = Field(
default=None,
description="API key for Together models",
)
@ -20,7 +20,7 @@ class TogetherProviderDataValidator(BaseModel):
@json_schema_type
class TogetherCompatConfig(BaseModel):
api_key: Optional[str] = Field(
api_key: str | None = Field(
default=None,
description="The Together API key",
)
@ -31,7 +31,7 @@ class TogetherCompatConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.TOGETHER_API_KEY}", **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.TOGETHER_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.together.xyz/v1",
"api_key": api_key,

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel, Field
@ -13,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class VLLMInferenceAdapterConfig(BaseModel):
url: Optional[str] = Field(
url: str | None = Field(
default=None,
description="The URL for the vLLM model serving endpoint",
)
@ -21,7 +20,7 @@ class VLLMInferenceAdapterConfig(BaseModel):
default=4096,
description="Maximum number of tokens to generate.",
)
api_token: Optional[str] = Field(
api_token: str | None = Field(
default="fake",
description="The API token",
)

View file

@ -5,7 +5,8 @@
# the root directory of this source tree.
import json
import logging
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
import httpx
from openai import AsyncOpenAI
@ -94,7 +95,7 @@ def build_hf_repo_model_entries():
def _convert_to_vllm_tool_calls_in_response(
tool_calls,
) -> List[ToolCall]:
) -> list[ToolCall]:
if not tool_calls:
return []
@ -109,7 +110,7 @@ def _convert_to_vllm_tool_calls_in_response(
]
def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
def _convert_to_vllm_tools_in_request(tools: list[ToolDefinition]) -> list[dict]:
compat_tools = []
for tool in tools:
@ -231,12 +232,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.client = None
async def initialize(self) -> None:
log.info(f"Initializing VLLM client with base_url={self.config.url}")
self.client = AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
)
pass
async def shutdown(self) -> None:
pass
@ -249,15 +245,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
def _lazy_initialize_client(self):
if self.client is not None:
return
log.info(f"Initializing vLLM client with base_url={self.config.url}")
self.client = self._create_client()
def _create_client(self):
return AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
)
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
self._lazy_initialize_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
@ -277,16 +288,17 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
self._lazy_initialize_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
@ -357,9 +369,15 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
async def register_model(self, model: Model) -> Model:
assert self.client is not None
model = await self.register_helper.register_model(model)
res = await self.client.models.list()
# register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
# Changing this may lead to unpredictable behavior.
client = self._create_client() if self.client is None else self.client
try:
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
res = await client.models.list()
available_models = [m.id async for m in res]
if model.provider_resource_id not in available_models:
raise ValueError(
@ -368,13 +386,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
)
return model
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
options = get_sampling_options(request.sampling_params)
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
input_dict: dict[str, Any] = {}
if isinstance(request, ChatCompletionRequest) and request.tools is not None:
# Only include the 'tools' param if there is any. It can break things if an empty list is sent to the vLLM.
if isinstance(request, ChatCompletionRequest) and request.tools:
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
if isinstance(request, ChatCompletionRequest):
@ -404,11 +423,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
self._lazy_initialize_client()
assert self.client is not None
model = await self._get_model(model_id)
@ -429,28 +449,29 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
self._lazy_initialize_client()
model_obj = await self._get_model(model)
extra_body: Dict[str, Any] = {}
extra_body: dict[str, Any] = {}
if prompt_logprobs is not None and prompt_logprobs >= 0:
extra_body["prompt_logprobs"] = prompt_logprobs
if guided_choice:
@ -481,29 +502,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
self._lazy_initialize_client()
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
@ -535,21 +557,21 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import WatsonXConfig
async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
# import dynamically so `llama stack build` does not fail due to missing dependencies
from .watsonx import WatsonXInferenceAdapter
if not isinstance(config, WatsonXConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = WatsonXInferenceAdapter(config)
return adapter
__all__ = ["get_adapter_impl", "WatsonXConfig"]

View file

@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
class WatsonXProviderDataValidator(BaseModel):
url: str
api_key: str
project_id: str
@json_schema_type
class WatsonXConfig(BaseModel):
url: str = Field(
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai",
)
api_key: SecretStr | None = Field(
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
description="The watsonx API key, only needed of using the hosted service",
)
project_id: str | None = Field(
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
description="The Project ID key, only needed of using the hosted service",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
"api_key": "${env.WATSONX_API_KEY:}",
"project_id": "${env.WATSONX_PROJECT_ID:}",
}

View file

@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/llama-3-3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-2-13b-chat",
CoreModelId.llama2_13b.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]

View file

@ -0,0 +1,379 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
GreedySamplingStrategy,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
request_has_media,
)
from . import WatsonXConfig
from .models import MODEL_ENTRIES
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: WatsonXConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
print(f"Initializing watsonx InferenceAdapter({config.url})...")
self._config = config
self._project_id = self._config.project_id
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
def _get_client(self, model_id) -> Model:
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
config_url = self._config.url
project_id = self._config.project_id
credentials = {"url": config_url, "apikey": config_api_key}
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client:
self._openai_client = AsyncOpenAI(
base_url=f"{self._config.url}/openai/v1",
api_key=self._config.api_key,
)
return self._openai_client
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client(request.model).generate(**params)
choices = []
if "results" in r:
for result in r["results"]:
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
text=result["generated_text"],
)
choices.append(choice)
response = OpenAICompatCompletionResponse(
choices=choices,
)
return process_completion_response(response)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = self._get_client(request.model).generate_text_stream(**params)
for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=None,
text=chunk,
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_completion_stream_response(stream):
yield chunk
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client(request.model).generate(**params)
choices = []
if "results" in r:
for result in r["results"]:
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
text=result["generated_text"],
)
choices.append(choice)
response = OpenAICompatCompletionResponse(
choices=choices,
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
model_id = request.model
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
s = self._get_client(model_id).generate_text_stream(**params)
for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=None,
text=chunk,
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
input_dict = {"params": {}}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
else:
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
if request.sampling_params:
if request.sampling_params.strategy:
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
if request.sampling_params.max_tokens:
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
if request.sampling_params.repetition_penalty:
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
input_dict["params"][GenParams.TEMPERATURE] = 0.0
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
params = {
**input_dict,
}
return params
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError("embedding is not supported for watsonx")
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params) # type: ignore
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
if params.get("stream", False):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
# watsonx.ai sometimes adds usage data to the stream
include_usage = False
if params.get("stream_options", None):
include_usage = params["stream_options"].get("include_usage", False)
stream = await self._get_openai_client().chat.completions.create(**params)
seen_finish_reason = False
async for chunk in stream:
# Final usage chunk with no choices that the user didn't request, so discard
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
break
yield chunk
for choice in chunk.choices:
if choice.finish_reason:
seen_finish_reason = True
break