mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 14:08:00 +00:00
update SambaNovaInferenceAdapter to use _get_params from LiteLLMOpenAIMixin by adding extra params to the mixin
parent 037d28f08e
commit 6711fd4f5a
2 changed files with 12 additions and 58 deletions
llama_stack/providers/remote/inference/sambanova/sambanova.py

@@ -6,19 +6,9 @@
 import requests
 
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    JsonSchemaResponseFormat,
-    ToolChoice,
-)
 from llama_stack.apis.models import Model
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
-from llama_stack.providers.utils.inference.openai_compat import (
-    convert_message_to_openai_dict_new,
-    convert_tooldef_to_openai_tool,
-    get_sampling_options,
-)
 
 from .config import SambaNovaImplConfig
 from .models import MODEL_ENTRIES
@@ -39,54 +29,10 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
             api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
             provider_data_api_key_field="sambanova_api_key",
             openai_compat_api_base=self.config.url,
+            download_images=True,  # SambaNova requires base64 image encoding
+            json_schema_strict=False,  # SambaNova doesn't support strict=True yet
         )
 
-    async def _get_params(self, request: ChatCompletionRequest) -> dict:
-        input_dict = {}
-
-        input_dict["messages"] = [
-            await convert_message_to_openai_dict_new(m, download_images=True) for m in request.messages
-        ]
-        if fmt := request.response_format:
-            if not isinstance(fmt, JsonSchemaResponseFormat):
-                raise ValueError(
-                    f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
-                )
-
-            fmt = fmt.json_schema
-            name = fmt["title"]
-            del fmt["title"]
-            fmt["additionalProperties"] = False
-
-            # Apply additionalProperties: False recursively to all objects
-            fmt = self._add_additional_properties_recursive(fmt)
-
-            input_dict["response_format"] = {
-                "type": "json_schema",
-                "json_schema": {
-                    "name": name,
-                    "schema": fmt,
-                    "strict": False,
-                },
-            }
-        if request.tools:
-            input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
-            if request.tool_config.tool_choice:
-                input_dict["tool_choice"] = (
-                    request.tool_config.tool_choice.value
-                    if isinstance(request.tool_config.tool_choice, ToolChoice)
-                    else request.tool_config.tool_choice
-                )
-
-        return {
-            "model": request.model,
-            "api_key": self.get_api_key(),
-            "api_base": self.api_base,
-            **input_dict,
-            "stream": request.stream,
-            **get_sampling_options(request.sampling_params),
-        }
-
     async def register_model(self, model: Model) -> Model:
         model_id = self.get_provider_model_id(model.provider_resource_id)
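With download_images and json_schema_strict hoisted into the mixin, the adapter-side _get_params override (deleted above) is no longer needed. A minimal sketch of the resulting adapter __init__, reconstructed from the context lines in this diff; the litellm_provider_name value is an assumption, not confirmed by this hunk:

from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from .config import SambaNovaImplConfig
from .models import MODEL_ENTRIES


class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
    def __init__(self, config: SambaNovaImplConfig):
        self.config = config
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            litellm_provider_name="sambanova",  # assumed value; used for model lookups
            api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
            provider_data_api_key_field="sambanova_api_key",
            openai_compat_api_base=self.config.url,
            download_images=True,  # SambaNova requires base64 image encoding
            json_schema_strict=False,  # SambaNova doesn't support strict=True yet
        )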
llama_stack/providers/utils/inference/litellm_openai_mixin.py

@@ -72,6 +72,8 @@ class LiteLLMOpenAIMixin(
         api_key_from_config: str | None,
         provider_data_api_key_field: str,
         openai_compat_api_base: str | None = None,
+        download_images: bool = False,
+        json_schema_strict: bool = True,
     ):
         """
         Initialize the LiteLLMOpenAIMixin.
@@ -81,6 +83,8 @@ class LiteLLMOpenAIMixin(
         :param provider_data_api_key_field: The field in the provider data that contains the API key.
         :param litellm_provider_name: The name of the provider, used for model lookups.
         :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
+        :param download_images: Whether to download images and convert to base64 for message conversion.
+        :param json_schema_strict: Whether to use strict mode for JSON schema validation.
         """
         ModelRegistryHelper.__init__(self, model_entries)
@@ -88,6 +92,8 @@ class LiteLLMOpenAIMixin(
         self.api_key_from_config = api_key_from_config
         self.provider_data_api_key_field = provider_data_api_key_field
         self.api_base = openai_compat_api_base
+        self.download_images = download_images
+        self.json_schema_strict = json_schema_strict
 
         if openai_compat_api_base:
             self.is_openai_compat = True
@@ -206,7 +212,9 @@ class LiteLLMOpenAIMixin(
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         input_dict = {}
 
-        input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
+        input_dict["messages"] = [
+            await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
+        ]
         if fmt := request.response_format:
             if not isinstance(fmt, JsonSchemaResponseFormat):
                 raise ValueError(
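For context, the download_images flag controls whether convert_message_to_openai_dict_new fetches remote image URLs and inlines them as base64 (per the docstring added above: "download images and convert to base64 for message conversion"). A hypothetical standalone sketch of that conversion, not the actual helper's implementation:

import base64

import requests


def inline_image_as_base64(url: str) -> dict:
    """Hypothetical sketch: fetch an image and return an OpenAI-style
    image_url content part with the bytes inlined as a data URL."""
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    mime = resp.headers.get("Content-Type", "image/png")
    encoded = base64.b64encode(resp.content).decode("ascii")
    return {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{encoded}"}}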
@@ -226,7 +234,7 @@ class LiteLLMOpenAIMixin(
                 "json_schema": {
                     "name": name,
                     "schema": fmt,
-                    "strict": True,
+                    "strict": self.json_schema_strict,
                 },
             }
         if request.tools:
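Taken together, json_schema_strict now flows from the constructor into the response_format payload the mixin hands to LiteLLM. An illustrative standalone rendition of that logic (the recursive additionalProperties pass is omitted here), so the effect of the flag is visible:

def build_response_format(json_schema: dict, strict: bool) -> dict:
    # Mirrors the mixin logic above: the schema title becomes the name,
    # additionalProperties is pinned to False, and strict is now a knob.
    schema = dict(json_schema)
    name = schema.pop("title")
    schema["additionalProperties"] = False
    return {
        "type": "json_schema",
        "json_schema": {"name": name, "schema": schema, "strict": strict},
    }


# With the SambaNova setting (json_schema_strict=False) the provider treats
# the schema as best-effort; with True, exact adherence is requested.
print(build_response_format({"title": "Answer", "type": "object"}, strict=False))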