mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-09 21:18:38 +00:00
Merge branch 'main' into agents-openai-migration
This commit is contained in:
commit
724322eeb2
673 changed files with 164269 additions and 14378 deletions
|
@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
|
|||
""",
|
||||
)
|
||||
|
||||
pip_packages: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
|
||||
|
||||
# used internally by the resolver; this is a hack for now
|
||||
|
@ -145,45 +154,8 @@ class RoutingTable(Protocol):
|
|||
async def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||
|
||||
|
||||
# TODO: this can now be inlined into RemoteProviderSpec
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_type: str = Field(
|
||||
...,
|
||||
description="Unique identifier for this adapter",
|
||||
)
|
||||
module: str = Field(
|
||||
default_factory=str,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
""",
|
||||
)
|
||||
pip_packages: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
config_class: str = Field(
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
A description of the provider. This is used to display in the documentation.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InlineProviderSpec(ProviderSpec):
|
||||
pip_packages: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
container_image: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
|
@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_pack
|
|||
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
|
||||
""",
|
||||
)
|
||||
# module field is inherited from ProviderSpec
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
|
@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class RemoteProviderSpec(ProviderSpec):
|
||||
adapter: AdapterSpec = Field(
|
||||
adapter_type: str = Field(
|
||||
...,
|
||||
description="Unique identifier for this adapter",
|
||||
)
|
||||
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||
API responses, specify the adapter here.
|
||||
A description of the provider. This is used to display in the documentation.
|
||||
""",
|
||||
)
|
||||
|
||||
|
@ -234,33 +207,6 @@ API responses, specify the adapter here.
|
|||
def container_image(self) -> str | None:
|
||||
return None
|
||||
|
||||
# module field is inherited from ProviderSpec
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return self.adapter.pip_packages
|
||||
|
||||
@property
|
||||
def provider_data_validator(self) -> str | None:
|
||||
return self.adapter.provider_data_validator
|
||||
|
||||
|
||||
def remote_provider_spec(
|
||||
api: Api,
|
||||
adapter: AdapterSpec,
|
||||
api_dependencies: list[Api] | None = None,
|
||||
optional_api_dependencies: list[Api] | None = None,
|
||||
) -> RemoteProviderSpec:
|
||||
return RemoteProviderSpec(
|
||||
api=api,
|
||||
provider_type=f"remote::{adapter.adapter_type}",
|
||||
config_class=adapter.config_class,
|
||||
module=adapter.module,
|
||||
adapter=adapter,
|
||||
api_dependencies=api_dependencies or [],
|
||||
optional_api_dependencies=optional_api_dependencies or [],
|
||||
)
|
||||
|
||||
|
||||
class HealthStatus(StrEnum):
|
||||
OK = "OK"
|
||||
|
|
|
@ -830,6 +830,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
items=param.items,
|
||||
title=param.title,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
|
@ -873,6 +875,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
items=param.items,
|
||||
title=param.title,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
|
@ -952,7 +956,7 @@ async def get_raw_document_text(document: Document) -> str:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
|
||||
elif not (document.mime_type.startswith("text/") or document.mime_type in ("application/yaml", "application/json")):
|
||||
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
|
||||
|
||||
if isinstance(document.content, URL):
|
||||
|
|
|
@ -237,6 +237,7 @@ class OpenAIResponsesImpl:
|
|||
response_tools=tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
inputs=input,
|
||||
)
|
||||
|
||||
# Create orchestrator and delegate streaming logic
|
||||
|
|
|
@ -10,10 +10,12 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
AllowedToolsFilter,
|
||||
ApprovalFilter,
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseContentPartOutputText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
|
@ -50,6 +52,36 @@ from .utils import convert_chat_choice_to_response_message, is_function_tool_cal
|
|||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
def convert_tooldef_to_chat_tool(tool_def):
|
||||
"""Convert a ToolDef to OpenAI ChatCompletionToolParam format.
|
||||
|
||||
Args:
|
||||
tool_def: ToolDef from the tools API
|
||||
|
||||
Returns:
|
||||
ChatCompletionToolParam suitable for OpenAI chat completion
|
||||
"""
|
||||
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
internal_tool_def = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
items=param.items,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(internal_tool_def)
|
||||
|
||||
|
||||
class StreamingResponseOrchestrator:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -117,10 +149,17 @@ class StreamingResponseOrchestrator:
|
|||
raise ValueError("Streaming chunk processor failed to return completion data")
|
||||
current_response = self._build_chat_completion(completion_result_data)
|
||||
|
||||
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
|
||||
function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
|
||||
current_response, messages
|
||||
)
|
||||
|
||||
# add any approval requests required
|
||||
for tool_call in approvals:
|
||||
async for evt in self._add_mcp_approval_request(
|
||||
tool_call.function.name, tool_call.function.arguments, output_messages
|
||||
):
|
||||
yield evt
|
||||
|
||||
# Handle choices with no tool calls
|
||||
for choice in current_response.choices:
|
||||
if not (choice.message.tool_calls and self.ctx.response_tools):
|
||||
|
@ -164,10 +203,11 @@ class StreamingResponseOrchestrator:
|
|||
# Emit response.completed
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||
|
||||
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
|
||||
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
|
||||
"""Separate tool calls into function and non-function categories."""
|
||||
function_tool_calls = []
|
||||
non_function_tool_calls = []
|
||||
approvals = []
|
||||
next_turn_messages = messages.copy()
|
||||
|
||||
for choice in current_response.choices:
|
||||
|
@ -178,9 +218,23 @@ class StreamingResponseOrchestrator:
|
|||
if is_function_tool_call(tool_call, self.ctx.response_tools):
|
||||
function_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
if self._approval_required(tool_call.function.name):
|
||||
approval_response = self.ctx.approval_response(
|
||||
tool_call.function.name, tool_call.function.arguments
|
||||
)
|
||||
if approval_response:
|
||||
if approval_response.approve:
|
||||
logger.info(f"Approval granted for {tool_call.id} on {tool_call.function.name}")
|
||||
non_function_tool_calls.append(tool_call)
|
||||
else:
|
||||
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
|
||||
else:
|
||||
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
|
||||
approvals.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
|
||||
return function_tool_calls, non_function_tool_calls, next_turn_messages
|
||||
return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
|
||||
|
||||
async def _process_streaming_chunks(
|
||||
self, completion_result, output_messages: list[OpenAIResponseOutput]
|
||||
|
@ -556,23 +610,7 @@ class StreamingResponseOrchestrator:
|
|||
continue
|
||||
if not always_allowed or t.name in always_allowed:
|
||||
# Add to chat tools for inference
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=t.name,
|
||||
description=t.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in t.parameters
|
||||
},
|
||||
)
|
||||
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
openai_tool = convert_tooldef_to_chat_tool(t)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool)
|
||||
|
@ -632,3 +670,46 @@ class StreamingResponseOrchestrator:
|
|||
# TODO: Emit mcp_list_tools.failed event if needed
|
||||
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
|
||||
raise
|
||||
|
||||
def _approval_required(self, tool_name: str) -> bool:
|
||||
if tool_name not in self.mcp_tool_to_server:
|
||||
return False
|
||||
mcp_server = self.mcp_tool_to_server[tool_name]
|
||||
if mcp_server.require_approval == "always":
|
||||
return True
|
||||
if mcp_server.require_approval == "never":
|
||||
return False
|
||||
if isinstance(mcp_server, ApprovalFilter):
|
||||
if tool_name in mcp_server.always:
|
||||
return True
|
||||
if tool_name in mcp_server.never:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _add_mcp_approval_request(
|
||||
self, tool_name: str, arguments: str, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
mcp_server = self.mcp_tool_to_server[tool_name]
|
||||
mcp_approval_request = OpenAIResponseMCPApprovalRequest(
|
||||
arguments=arguments,
|
||||
id=f"approval_{uuid.uuid4()}",
|
||||
name=tool_name,
|
||||
server_label=mcp_server.server_label,
|
||||
)
|
||||
output_messages.append(mcp_approval_request)
|
||||
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=mcp_approval_request,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=mcp_approval_request,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
|
|
@ -10,7 +10,10 @@ from openai.types.chat import ChatCompletionToolParam
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
)
|
||||
|
@ -58,3 +61,37 @@ class ChatCompletionContext(BaseModel):
|
|||
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
||||
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
|
||||
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
response_tools: list[OpenAIResponseInputTool] | None,
|
||||
temperature: float | None,
|
||||
response_format: OpenAIResponseFormatParam,
|
||||
inputs: list[OpenAIResponseInput] | str,
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_tools=response_tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
)
|
||||
if not isinstance(inputs, str):
|
||||
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
|
||||
self.approval_responses = {
|
||||
input.approval_request_id: input for input in inputs if input.type == "mcp_approval_response"
|
||||
}
|
||||
|
||||
def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
|
||||
request = self._approval_request(tool_name, arguments)
|
||||
return self.approval_responses.get(request.id, None) if request else None
|
||||
|
||||
def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
|
||||
for request in self.approval_requests:
|
||||
if request.name == tool_name and request.arguments == arguments:
|
||||
return request
|
||||
return None
|
||||
|
|
|
@ -13,6 +13,8 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
|
@ -149,6 +151,11 @@ async def convert_response_input_to_chat_messages(
|
|||
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
|
||||
# the tool list will be handled separately
|
||||
pass
|
||||
elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
|
||||
input_item, OpenAIResponseMCPApprovalResponse
|
||||
):
|
||||
# these are handled by the responses impl itself and not pass through to chat completions
|
||||
pass
|
||||
else:
|
||||
content = await convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await get_message_type_by_role(input_item.role)
|
||||
|
|
|
@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType
|
|||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
|
||||
from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
|
@ -75,6 +75,13 @@ class MetaReferenceEvalImpl(
|
|||
)
|
||||
self.benchmarks[task_def.identifier] = task_def
|
||||
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
|
||||
if benchmark_id in self.benchmarks:
|
||||
del self.benchmarks[benchmark_id]
|
||||
|
||||
key = f"{EVAL_TASKS_PREFIX}{benchmark_id}"
|
||||
await self.kvstore.delete(key)
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
@ -152,31 +159,40 @@ class MetaReferenceEvalImpl(
|
|||
) -> list[dict[str, Any]]:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
|
||||
sampling_params = {"max_tokens": candidate.sampling_params.max_tokens}
|
||||
|
||||
generations = []
|
||||
for x in tqdm(input_rows):
|
||||
if ColumnName.completion_input.value in x:
|
||||
if candidate.sampling_params.stop:
|
||||
sampling_params["stop"] = candidate.sampling_params.stop
|
||||
|
||||
input_content = json.loads(x[ColumnName.completion_input.value])
|
||||
response = await self.inference_api.completion(
|
||||
response = await self.inference_api.openai_completion(
|
||||
model=candidate.model,
|
||||
content=input_content,
|
||||
sampling_params=candidate.sampling_params,
|
||||
prompt=input_content,
|
||||
**sampling_params,
|
||||
)
|
||||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
|
||||
input_messages = [
|
||||
OpenAIUserMessageParam(**x) for x in chat_completion_input_json if x["role"] == "user"
|
||||
]
|
||||
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
|
||||
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
|
||||
messages += input_messages
|
||||
response = await self.inference_api.chat_completion(
|
||||
model_id=candidate.model,
|
||||
response = await self.inference_api.openai_chat_completion(
|
||||
model=candidate.model,
|
||||
messages=messages,
|
||||
sampling_params=candidate.sampling_params,
|
||||
**sampling_params,
|
||||
)
|
||||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
|
||||
else:
|
||||
raise ValueError("Invalid input row")
|
||||
|
||||
|
|
|
@ -9,11 +9,12 @@ import uuid
|
|||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import File, Form, Response, UploadFile
|
||||
from fastapi import Depends, File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import (
|
||||
ExpiresAfter,
|
||||
Files,
|
||||
ListOpenAIFileResponse,
|
||||
OpenAIFileDeleteResponse,
|
||||
|
@ -22,6 +23,7 @@ from llama_stack.apis.files import (
|
|||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
@ -44,7 +46,7 @@ class LocalfsFilesImpl(Files):
|
|||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize SQL store for metadata
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
|
||||
await self.sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
|
@ -74,7 +76,7 @@ class LocalfsFilesImpl(Files):
|
|||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
|
||||
|
||||
|
@ -86,14 +88,13 @@ class LocalfsFilesImpl(Files):
|
|||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
|
||||
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
|
||||
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
|
||||
) -> OpenAIFileObject:
|
||||
"""Upload a file that can be used across various endpoints."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
if expires_after_anchor is not None or expires_after_seconds is not None:
|
||||
if expires_after is not None:
|
||||
raise NotImplementedError("File expiration is not supported by this provider")
|
||||
|
||||
file_id = self._generate_file_id()
|
||||
|
@ -150,7 +151,6 @@ class LocalfsFilesImpl(Files):
|
|||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
policy=self.policy,
|
||||
where=where_conditions if where_conditions else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
|
|
|
@ -18,8 +18,6 @@ from llama_stack.apis.common.content_types import (
|
|||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
BatchChatCompletionResponse,
|
||||
BatchCompletionResponse,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
|
@ -219,41 +217,6 @@ class MetaReferenceInferenceImpl(
|
|||
results = await self._nonstream_completion([request])
|
||||
return results[0]
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
content_batch = [
|
||||
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
|
||||
]
|
||||
|
||||
request_batch = []
|
||||
for content in content_batch:
|
||||
request = CompletionRequest(
|
||||
model=model_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request = await convert_request_to_raw(request)
|
||||
request_batch.append(request)
|
||||
|
||||
results = await self._nonstream_completion(request_batch)
|
||||
return BatchCompletionResponse(batch=results)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
|
@ -399,49 +362,6 @@ class MetaReferenceInferenceImpl(
|
|||
results = await self._nonstream_chat_completion([request])
|
||||
return results[0]
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request_batch = []
|
||||
for messages in messages_batch:
|
||||
request = ChatCompletionRequest(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config or ToolConfig(),
|
||||
)
|
||||
self.check_model(request)
|
||||
|
||||
# augment and rewrite messages depending on the model
|
||||
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
|
||||
# download media and convert to raw content so we can send it to the model
|
||||
request = await convert_request_to_raw(request)
|
||||
request_batch.append(request)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
if SEMAPHORE.locked():
|
||||
raise RuntimeError("Only one concurrent request is supported")
|
||||
|
||||
results = await self._nonstream_chat_completion(request_batch)
|
||||
return BatchChatCompletionResponse(batch=results)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request_batch: list[ChatCompletionRequest]
|
||||
) -> list[ChatCompletionResponse]:
|
||||
|
|
|
@ -290,13 +290,13 @@ class LlamaGuardShield:
|
|||
else:
|
||||
shield_input_message = self.build_text_shield_input(messages)
|
||||
|
||||
# TODO: llama-stack inference protocol has issues with non-streaming inference code
|
||||
response = await self.inference_api.chat_completion(
|
||||
model_id=self.model,
|
||||
response = await self.inference_api.openai_chat_completion(
|
||||
model=self.model,
|
||||
messages=[shield_input_message],
|
||||
stream=False,
|
||||
temperature=0.0, # default is 1, which is too high for safety
|
||||
)
|
||||
content = response.completion_message.content
|
||||
content = response.choices[0].message.content
|
||||
content = content.strip()
|
||||
return self.get_shield_response(content)
|
||||
|
||||
|
|
|
@ -63,6 +63,9 @@ class LlmAsJudgeScoringImpl(
|
|||
async def register_scoring_function(self, function_def: ScoringFn) -> None:
|
||||
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
|
||||
|
||||
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
|
||||
self.llm_as_judge_fn.unregister_scoring_fn_def(scoring_fn_id)
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
|
@ -224,10 +224,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
# Always log to console if console sink is enabled (debug)
|
||||
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||
logger.debug(f"METRIC: {event.metric}={event.value} {event.unit} {event.attributes}")
|
||||
|
||||
# Add metric as an event to the current span
|
||||
try:
|
||||
with self._lock:
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
from jinja2 import Template
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import UserMessage
|
||||
from llama_stack.apis.inference import OpenAIUserMessageParam
|
||||
from llama_stack.apis.tools.rag_tool import (
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
|
@ -61,16 +61,16 @@ async def llm_rag_query_generator(
|
|||
messages = [interleaved_content_as_str(content)]
|
||||
|
||||
template = Template(config.template)
|
||||
content = template.render({"messages": messages})
|
||||
rendered_content: str = template.render({"messages": messages})
|
||||
|
||||
model = config.model
|
||||
message = UserMessage(content=content)
|
||||
response = await inference_api.chat_completion(
|
||||
model_id=model,
|
||||
message = OpenAIUserMessageParam(content=rendered_content)
|
||||
response = await inference_api.openai_chat_completion(
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
query = response.completion_message.content
|
||||
query = response.choices[0].message.content
|
||||
|
||||
return query
|
||||
|
|
|
@ -45,10 +45,7 @@ from llama_stack.apis.vector_io import (
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
content_from_doc,
|
||||
parse_data_url,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import parse_data_url
|
||||
|
||||
from .config import RagToolRuntimeConfig
|
||||
from .context_retriever import generate_rag_query
|
||||
|
@ -60,6 +57,47 @@ def make_random_string(length: int = 8):
|
|||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||
|
||||
|
||||
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
|
||||
"""Get raw binary data and mime type from a RAGDocument for file upload."""
|
||||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
parts = parse_data_url(doc.content.uri)
|
||||
mime_type = parts["mimetype"]
|
||||
data = parts["data"]
|
||||
|
||||
if parts["is_base64"]:
|
||||
file_data = base64.b64decode(data)
|
||||
else:
|
||||
file_data = data.encode("utf-8")
|
||||
|
||||
return file_data, mime_type
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(doc.content.uri)
|
||||
r.raise_for_status()
|
||||
mime_type = r.headers.get("content-type", "application/octet-stream")
|
||||
return r.content, mime_type
|
||||
else:
|
||||
if isinstance(doc.content, str):
|
||||
content_str = doc.content
|
||||
else:
|
||||
content_str = interleaved_content_as_str(doc.content)
|
||||
|
||||
if content_str.startswith("data:"):
|
||||
parts = parse_data_url(content_str)
|
||||
mime_type = parts["mimetype"]
|
||||
data = parts["data"]
|
||||
|
||||
if parts["is_base64"]:
|
||||
file_data = base64.b64decode(data)
|
||||
else:
|
||||
file_data = data.encode("utf-8")
|
||||
|
||||
return file_data, mime_type
|
||||
else:
|
||||
return content_str.encode("utf-8"), "text/plain"
|
||||
|
||||
|
||||
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -95,46 +133,52 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
return
|
||||
|
||||
for doc in documents:
|
||||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
parts = parse_data_url(doc.content.uri)
|
||||
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
|
||||
mime_type = parts["mimetype"]
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(doc.content.uri)
|
||||
file_data = response.content
|
||||
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
|
||||
else:
|
||||
content_str = await content_from_doc(doc)
|
||||
file_data = content_str.encode("utf-8")
|
||||
mime_type = doc.mime_type or "text/plain"
|
||||
try:
|
||||
try:
|
||||
file_data, mime_type = await raw_data_from_doc(doc)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to extract content from document {doc.document_id}: {e}")
|
||||
continue
|
||||
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
|
||||
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
|
||||
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
|
||||
|
||||
file_obj = io.BytesIO(file_data)
|
||||
file_obj.name = filename
|
||||
file_obj = io.BytesIO(file_data)
|
||||
file_obj.name = filename
|
||||
|
||||
upload_file = UploadFile(file=file_obj, filename=filename)
|
||||
upload_file = UploadFile(file=file_obj, filename=filename)
|
||||
|
||||
created_file = await self.files_api.openai_upload_file(
|
||||
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
|
||||
)
|
||||
try:
|
||||
created_file = await self.files_api.openai_upload_file(
|
||||
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
|
||||
continue
|
||||
|
||||
chunking_strategy = VectorStoreChunkingStrategyStatic(
|
||||
static=VectorStoreChunkingStrategyStaticConfig(
|
||||
max_chunk_size_tokens=chunk_size_in_tokens,
|
||||
chunk_overlap_tokens=chunk_size_in_tokens // 4,
|
||||
chunking_strategy = VectorStoreChunkingStrategyStatic(
|
||||
static=VectorStoreChunkingStrategyStaticConfig(
|
||||
max_chunk_size_tokens=chunk_size_in_tokens,
|
||||
chunk_overlap_tokens=chunk_size_in_tokens // 4,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
await self.vector_io_api.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_db_id,
|
||||
file_id=created_file.id,
|
||||
attributes=doc.metadata,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
try:
|
||||
await self.vector_io_api.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_db_id,
|
||||
file_id=created_file.id,
|
||||
attributes=doc.metadata,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Unexpected error processing document {doc.document_id}: {e}")
|
||||
continue
|
||||
|
||||
async def query(
|
||||
self,
|
||||
|
@ -274,7 +318,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
if query_config:
|
||||
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
|
||||
else:
|
||||
# handle someone passing an empty dict
|
||||
query_config = RAGQueryConfig()
|
||||
|
||||
query = kwargs["query"]
|
||||
|
@ -285,6 +328,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=result.content,
|
||||
content=result.content or [],
|
||||
metadata=result.metadata,
|
||||
)
|
||||
|
|
|
@ -13,7 +13,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.batches,
|
||||
provider_type="inline::reference",
|
||||
pip_packages=["openai"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.inline.batches.reference",
|
||||
config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
|
||||
api_dependencies=[
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -25,28 +24,26 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api_dependencies=[],
|
||||
description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.datasetio,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="huggingface",
|
||||
pip_packages=[
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
module="llama_stack.providers.remote.datasetio.huggingface",
|
||||
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
|
||||
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
|
||||
),
|
||||
adapter_type="huggingface",
|
||||
provider_type="remote::huggingface",
|
||||
pip_packages=[
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
module="llama_stack.providers.remote.datasetio.huggingface",
|
||||
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
|
||||
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.datasetio,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
module="llama_stack.providers.remote.datasetio.nvidia",
|
||||
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
|
||||
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
module="llama_stack.providers.remote.datasetio.nvidia",
|
||||
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
|
||||
pip_packages=[
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
|
@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
|
|||
],
|
||||
description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.eval,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"requests",
|
||||
],
|
||||
module="llama_stack.providers.remote.eval.nvidia",
|
||||
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
|
||||
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"requests",
|
||||
],
|
||||
provider_type="remote::nvidia",
|
||||
module="llama_stack.providers.remote.eval.nvidia",
|
||||
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
|
||||
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
|
|
|
@ -4,13 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
|
||||
|
||||
|
||||
|
@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
|
||||
description="Local filesystem-based file storage provider for managing files and documents locally.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.files,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="s3",
|
||||
pip_packages=["boto3"] + sql_store_pip_packages,
|
||||
module="llama_stack.providers.remote.files.s3",
|
||||
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
||||
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
||||
),
|
||||
provider_type="remote::s3",
|
||||
adapter_type="s3",
|
||||
pip_packages=["boto3"] + sql_store_pip_packages,
|
||||
module="llama_stack.providers.remote.files.s3",
|
||||
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
||||
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
META_REFERENCE_DEPS = [
|
||||
|
@ -49,180 +48,167 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
|
||||
description="Sentence Transformers inference provider for text embeddings and similarity search.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="cerebras",
|
||||
pip_packages=[
|
||||
"cerebras_cloud_sdk",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.cerebras",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
|
||||
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
|
||||
),
|
||||
adapter_type="cerebras",
|
||||
provider_type="remote::cerebras",
|
||||
pip_packages=[
|
||||
"cerebras_cloud_sdk",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.cerebras",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
|
||||
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="ollama",
|
||||
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
|
||||
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
|
||||
module="llama_stack.providers.remote.inference.ollama",
|
||||
description="Ollama inference provider for running local models through the Ollama runtime.",
|
||||
),
|
||||
adapter_type="ollama",
|
||||
provider_type="remote::ollama",
|
||||
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
|
||||
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
|
||||
module="llama_stack.providers.remote.inference.ollama",
|
||||
description="Ollama inference provider for running local models through the Ollama runtime.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="vllm",
|
||||
pip_packages=["openai"],
|
||||
module="llama_stack.providers.remote.inference.vllm",
|
||||
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
|
||||
description="Remote vLLM inference provider for connecting to vLLM servers.",
|
||||
),
|
||||
adapter_type="vllm",
|
||||
provider_type="remote::vllm",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.vllm",
|
||||
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
|
||||
description="Remote vLLM inference provider for connecting to vLLM servers.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="tgi",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
|
||||
description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
|
||||
),
|
||||
adapter_type="tgi",
|
||||
provider_type="remote::tgi",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
|
||||
description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="hf::serverless",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
|
||||
description="HuggingFace Inference API serverless provider for on-demand model inference.",
|
||||
),
|
||||
adapter_type="hf::serverless",
|
||||
provider_type="remote::hf::serverless",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
|
||||
description="HuggingFace Inference API serverless provider for on-demand model inference.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="hf::endpoint",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
|
||||
description="HuggingFace Inference Endpoints provider for dedicated model serving.",
|
||||
),
|
||||
provider_type="remote::hf::endpoint",
|
||||
adapter_type="hf::endpoint",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
|
||||
description="HuggingFace Inference Endpoints provider for dedicated model serving.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="fireworks",
|
||||
pip_packages=[
|
||||
"fireworks-ai<=0.17.16",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.fireworks",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
|
||||
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
|
||||
),
|
||||
adapter_type="fireworks",
|
||||
provider_type="remote::fireworks",
|
||||
pip_packages=[
|
||||
"fireworks-ai<=0.17.16",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.fireworks",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
|
||||
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="together",
|
||||
pip_packages=[
|
||||
"together",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.together",
|
||||
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||
description="Together AI inference provider for open-source models and collaborative AI development.",
|
||||
),
|
||||
adapter_type="together",
|
||||
provider_type="remote::together",
|
||||
pip_packages=[
|
||||
"together",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.together",
|
||||
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||
description="Together AI inference provider for open-source models and collaborative AI development.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.inference.bedrock",
|
||||
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
|
||||
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
|
||||
),
|
||||
adapter_type="bedrock",
|
||||
provider_type="remote::bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.inference.bedrock",
|
||||
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
|
||||
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="databricks",
|
||||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.databricks",
|
||||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
|
||||
),
|
||||
adapter_type="databricks",
|
||||
provider_type="remote::databricks",
|
||||
pip_packages=["databricks-sdk"],
|
||||
module="llama_stack.providers.remote.inference.databricks",
|
||||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.nvidia",
|
||||
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
||||
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.nvidia",
|
||||
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
||||
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="runpod",
|
||||
pip_packages=["openai"],
|
||||
module="llama_stack.providers.remote.inference.runpod",
|
||||
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
|
||||
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
|
||||
),
|
||||
adapter_type="runpod",
|
||||
provider_type="remote::runpod",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.runpod",
|
||||
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
|
||||
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="openai",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.openai",
|
||||
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
||||
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
|
||||
),
|
||||
adapter_type="openai",
|
||||
provider_type="remote::openai",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.openai",
|
||||
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
||||
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="anthropic",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.anthropic",
|
||||
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
|
||||
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
|
||||
),
|
||||
adapter_type="anthropic",
|
||||
provider_type="remote::anthropic",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.anthropic",
|
||||
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
|
||||
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="gemini",
|
||||
pip_packages=["litellm", "openai"],
|
||||
module="llama_stack.providers.remote.inference.gemini",
|
||||
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
|
||||
),
|
||||
adapter_type="gemini",
|
||||
provider_type="remote::gemini",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.gemini",
|
||||
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="vertexai",
|
||||
pip_packages=["litellm", "google-cloud-aiplatform", "openai"],
|
||||
module="llama_stack.providers.remote.inference.vertexai",
|
||||
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
|
||||
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
|
||||
adapter_type="vertexai",
|
||||
provider_type="remote::vertexai",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
"google-cloud-aiplatform",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.vertexai",
|
||||
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
|
||||
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
|
||||
|
||||
• Enterprise-grade security: Uses Google Cloud's security controls and IAM
|
||||
• Better integration: Seamless integration with other Google Cloud services
|
||||
|
@ -242,61 +228,73 @@ Available Models:
|
|||
- vertex_ai/gemini-2.0-flash
|
||||
- vertex_ai/gemini-2.5-flash
|
||||
- vertex_ai/gemini-2.5-pro""",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="groq",
|
||||
pip_packages=["litellm", "openai"],
|
||||
module="llama_stack.providers.remote.inference.groq",
|
||||
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
|
||||
),
|
||||
adapter_type="groq",
|
||||
provider_type="remote::groq",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.groq",
|
||||
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="llama-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.llama_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
|
||||
),
|
||||
adapter_type="llama-openai-compat",
|
||||
provider_type="remote::llama-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.llama_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova",
|
||||
pip_packages=["litellm", "openai"],
|
||||
module="llama_stack.providers.remote.inference.sambanova",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
|
||||
),
|
||||
adapter_type="sambanova",
|
||||
provider_type="remote::sambanova",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.sambanova",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="passthrough",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.passthrough",
|
||||
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
||||
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
|
||||
),
|
||||
adapter_type="passthrough",
|
||||
provider_type="remote::passthrough",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.passthrough",
|
||||
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
||||
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="watsonx",
|
||||
pip_packages=["ibm_watsonx_ai"],
|
||||
module="llama_stack.providers.remote.inference.watsonx",
|
||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
|
||||
),
|
||||
adapter_type="watsonx",
|
||||
provider_type="remote::watsonx",
|
||||
pip_packages=["ibm_watsonx_ai"],
|
||||
module="llama_stack.providers.remote.inference.watsonx",
|
||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="remote::azure",
|
||||
adapter_type="azure",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.azure",
|
||||
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
|
||||
description="""
|
||||
Azure OpenAI inference provider for accessing GPT models and other Azure services.
|
||||
Provider documentation
|
||||
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
|
||||
""",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
from typing import cast
|
||||
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
|
||||
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
|
||||
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
|
||||
|
@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]:
|
|||
],
|
||||
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.post_training,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=["requests", "aiohttp"],
|
||||
module="llama_stack.providers.remote.post_training.nvidia",
|
||||
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
|
||||
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
pip_packages=["requests", "aiohttp"],
|
||||
module="llama_stack.providers.remote.post_training.nvidia",
|
||||
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
|
||||
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
|
||||
description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.safety.bedrock",
|
||||
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
|
||||
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
|
||||
),
|
||||
adapter_type="bedrock",
|
||||
provider_type="remote::bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.safety.bedrock",
|
||||
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
|
||||
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=["requests"],
|
||||
module="llama_stack.providers.remote.safety.nvidia",
|
||||
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
||||
description="NVIDIA's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
pip_packages=["requests"],
|
||||
module="llama_stack.providers.remote.safety.nvidia",
|
||||
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
||||
description="NVIDIA's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova",
|
||||
pip_packages=["litellm", "requests"],
|
||||
module="llama_stack.providers.remote.safety.sambanova",
|
||||
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
adapter_type="sambanova",
|
||||
provider_type="remote::sambanova",
|
||||
pip_packages=["litellm", "requests"],
|
||||
module="llama_stack.providers.remote.safety.sambanova",
|
||||
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -38,7 +38,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.scoring,
|
||||
provider_type="inline::braintrust",
|
||||
pip_packages=["autoevals", "openai"],
|
||||
pip_packages=["autoevals"],
|
||||
module="llama_stack.providers.inline.scoring.braintrust",
|
||||
config_class="llama_stack.providers.inline.scoring.braintrust.BraintrustScoringConfig",
|
||||
api_dependencies=[
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api_dependencies=[Api.vector_io, Api.inference, Api.files],
|
||||
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="brave-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.brave_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
||||
description="Brave Search tool for web search capabilities with privacy-focused results.",
|
||||
),
|
||||
adapter_type="brave-search",
|
||||
provider_type="remote::brave-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.brave_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
||||
description="Brave Search tool for web search capabilities with privacy-focused results.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bing-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.bing_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
|
||||
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
|
||||
),
|
||||
adapter_type="bing-search",
|
||||
provider_type="remote::bing-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.bing_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
|
||||
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="tavily-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.tavily_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
|
||||
description="Tavily Search tool for AI-optimized web search with structured results.",
|
||||
),
|
||||
adapter_type="tavily-search",
|
||||
provider_type="remote::tavily-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.tavily_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
|
||||
description="Tavily Search tool for AI-optimized web search with structured results.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="wolfram-alpha",
|
||||
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
|
||||
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
|
||||
),
|
||||
adapter_type="wolfram-alpha",
|
||||
provider_type="remote::wolfram-alpha",
|
||||
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
|
||||
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="model-context-protocol",
|
||||
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
|
||||
pip_packages=["mcp>=1.8.1"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
|
||||
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
|
||||
),
|
||||
adapter_type="model-context-protocol",
|
||||
provider_type="remote::model-context-protocol",
|
||||
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
|
||||
pip_packages=["mcp>=1.8.1"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
|
||||
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -300,14 +299,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
|
|||
Please refer to the sqlite-vec provider documentation.
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="chromadb",
|
||||
pip_packages=["chromadb-client"],
|
||||
module="llama_stack.providers.remote.vector_io.chroma",
|
||||
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
|
||||
description="""
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="chromadb",
|
||||
provider_type="remote::chromadb",
|
||||
pip_packages=["chromadb-client"],
|
||||
module="llama_stack.providers.remote.vector_io.chroma",
|
||||
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[Chroma](https://www.trychroma.com/) is an inline and remote vector
|
||||
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
|
||||
That means you're not limited to storing vectors in memory or in a separate service.
|
||||
|
@ -340,9 +341,6 @@ pip install chromadb
|
|||
## Documentation
|
||||
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
@ -387,14 +385,16 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="pgvector",
|
||||
pip_packages=["psycopg2-binary"],
|
||||
module="llama_stack.providers.remote.vector_io.pgvector",
|
||||
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
|
||||
description="""
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="pgvector",
|
||||
provider_type="remote::pgvector",
|
||||
pip_packages=["psycopg2-binary"],
|
||||
module="llama_stack.providers.remote.vector_io.pgvector",
|
||||
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly in memory.
|
||||
That means you'll get fast and efficient vector retrieval.
|
||||
|
@ -410,7 +410,7 @@ There are three implementations of search for PGVectoIndex available:
|
|||
- How it works:
|
||||
- Uses PostgreSQL's vector extension (pgvector) to perform similarity search
|
||||
- Compares query embeddings against stored embeddings using Cosine distance or other distance metrics
|
||||
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
|
||||
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
|
||||
|
||||
-Characteristics:
|
||||
- Semantic understanding - finds documents similar in meaning even if they don't share keywords
|
||||
|
@ -495,19 +495,18 @@ docker pull pgvector/pgvector:pg17
|
|||
## Documentation
|
||||
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
|
||||
""",
|
||||
),
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="weaviate",
|
||||
provider_type="remote::weaviate",
|
||||
pip_packages=["weaviate-client"],
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="weaviate",
|
||||
pip_packages=["weaviate-client"],
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
description="""
|
||||
description="""
|
||||
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
|
||||
It allows you to store and query vectors directly within a Weaviate database.
|
||||
That means you're not limited to storing vectors in memory or in a separate service.
|
||||
|
@ -538,9 +537,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
|
|||
## Documentation
|
||||
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
@ -594,28 +590,29 @@ docker pull qdrant/qdrant
|
|||
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="qdrant",
|
||||
pip_packages=["qdrant-client"],
|
||||
module="llama_stack.providers.remote.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
description="""
|
||||
Please refer to the inline provider documentation.
|
||||
""",
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="qdrant",
|
||||
provider_type="remote::qdrant",
|
||||
pip_packages=["qdrant-client"],
|
||||
module="llama_stack.providers.remote.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
Please refer to the inline provider documentation.
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="milvus",
|
||||
pip_packages=["pymilvus>=2.4.10"],
|
||||
module="llama_stack.providers.remote.vector_io.milvus",
|
||||
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
|
||||
description="""
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="milvus",
|
||||
provider_type="remote::milvus",
|
||||
pip_packages=["pymilvus>=2.4.10"],
|
||||
module="llama_stack.providers.remote.vector_io.milvus",
|
||||
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly within a Milvus database.
|
||||
That means you're not limited to storing vectors in memory or in a separate service.
|
||||
|
@ -636,7 +633,13 @@ To use Milvus in your Llama Stack project, follow these steps:
|
|||
|
||||
## Installation
|
||||
|
||||
You can install Milvus using pymilvus:
|
||||
If you want to use inline Milvus, you can install:
|
||||
|
||||
```bash
|
||||
pip install pymilvus[milvus-lite]
|
||||
```
|
||||
|
||||
If you want to use remote Milvus, you can install:
|
||||
|
||||
```bash
|
||||
pip install pymilvus
|
||||
|
@ -806,14 +809,11 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m
|
|||
|
||||
For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::milvus",
|
||||
pip_packages=["pymilvus>=2.4.10"],
|
||||
pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
|
||||
module="llama_stack.providers.inline.vector_io.milvus",
|
||||
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
|
@ -14,7 +14,6 @@ from llama_stack.apis.datasets import Datasets
|
|||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.scoring import Scoring, ScoringResult
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from .....apis.common.job_types import Job, JobStatus
|
||||
|
@ -45,24 +44,29 @@ class NVIDIAEvalImpl(
|
|||
self.inference_api = inference_api
|
||||
self.agents_api = agents_api
|
||||
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self)
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def _evaluator_get(self, path):
|
||||
async def _evaluator_get(self, path: str):
|
||||
"""Helper for making GET requests to the evaluator service."""
|
||||
response = requests.get(url=f"{self.config.evaluator_url}{path}")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _evaluator_post(self, path, data):
|
||||
async def _evaluator_post(self, path: str, data: dict[str, Any]):
|
||||
"""Helper for making POST requests to the evaluator service."""
|
||||
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _evaluator_delete(self, path: str) -> None:
|
||||
"""Helper for making DELETE requests to the evaluator service."""
|
||||
response = requests.delete(url=f"{self.config.evaluator_url}{path}")
|
||||
response.raise_for_status()
|
||||
|
||||
async def register_benchmark(self, task_def: Benchmark) -> None:
|
||||
"""Register a benchmark as an evaluation configuration."""
|
||||
await self._evaluator_post(
|
||||
|
@ -75,6 +79,10 @@ class NVIDIAEvalImpl(
|
|||
},
|
||||
)
|
||||
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
|
||||
"""Unregister a benchmark evaluation configuration from NeMo Evaluator."""
|
||||
await self._evaluator_delete(f"/v1/evaluation/configs/{DEFAULT_NAMESPACE}/{benchmark_id}")
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import Annotated, Any
|
|||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
|
||||
from fastapi import File, Form, Response, UploadFile
|
||||
from fastapi import Depends, File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
|
@ -23,6 +23,7 @@ from llama_stack.apis.files import (
|
|||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
@ -137,7 +138,7 @@ class S3FilesImpl(Files):
|
|||
where: dict[str, str | dict] = {"id": file_id}
|
||||
if not return_expired:
|
||||
where["expires_at"] = {">": self._now()}
|
||||
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
|
||||
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
return row
|
||||
|
||||
|
@ -164,7 +165,7 @@ class S3FilesImpl(Files):
|
|||
self._client = _create_s3_client(self._config)
|
||||
await _create_bucket_if_not_exists(self._client, self._config)
|
||||
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
|
||||
await self._sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
|
@ -195,8 +196,7 @@ class S3FilesImpl(Files):
|
|||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
|
||||
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
|
||||
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
|
||||
) -> OpenAIFileObject:
|
||||
file_id = f"file-{uuid.uuid4().hex}"
|
||||
|
||||
|
@ -204,14 +204,6 @@ class S3FilesImpl(Files):
|
|||
|
||||
created_at = self._now()
|
||||
|
||||
expires_after = None
|
||||
if expires_after_anchor is not None or expires_after_seconds is not None:
|
||||
# we use ExpiresAfter to validate input
|
||||
expires_after = ExpiresAfter(
|
||||
anchor=expires_after_anchor, # type: ignore[arg-type]
|
||||
seconds=expires_after_seconds, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# the default is no expiration.
|
||||
# to implement no expiration we set an expiration beyond the max.
|
||||
# we'll hide this fact from users when returning the file object.
|
||||
|
@ -268,7 +260,6 @@ class S3FilesImpl(Files):
|
|||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
policy=self.policy,
|
||||
where=where_conditions,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
|
|
|
@ -4,15 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import AnthropicConfig
|
||||
|
||||
|
||||
class AnthropicProviderDataValidator(BaseModel):
|
||||
anthropic_api_key: str | None = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: AnthropicConfig, _deps):
|
||||
from .anthropic import AnthropicInferenceAdapter
|
||||
|
||||
|
|
|
@ -8,14 +8,24 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
|
|||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AnthropicConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
|
||||
# TODO: add support for voyageai, which is where these models are hosted
|
||||
# embedding_model_metadata = {
|
||||
# "voyage-3-large": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-3.5": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-3.5-lite": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-code-3": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-finance-2": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# "voyage-law-2": {"embedding_dimension": 1024, "context_length": 16000},
|
||||
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# }
|
||||
|
||||
def __init__(self, config: AnthropicConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="anthropic",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="anthropic_api_key",
|
||||
|
|
|
@ -1,40 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"claude-3-5-sonnet-latest",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-5-haiku-latest",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
MODEL_ENTRIES = (
|
||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="voyage-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="voyage-3-lite",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 512, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="voyage-code-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
]
|
||||
+ SAFETY_MODELS_ENTRIES
|
||||
)
|
15
llama_stack/providers/remote/inference/azure/__init__.py
Normal file
15
llama_stack/providers/remote/inference/azure/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import AzureConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: AzureConfig, _deps):
|
||||
from .azure import AzureInferenceAdapter
|
||||
|
||||
impl = AzureInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
62
llama_stack/providers/remote/inference/azure/azure.py
Normal file
62
llama_stack/providers/remote/inference/azure/azure.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
||||
LiteLLMOpenAIMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AzureConfig
|
||||
|
||||
|
||||
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: AzureConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="azure",
|
||||
api_key_from_config=config.api_key.get_secret_value(),
|
||||
provider_data_api_key_field="azure_api_key",
|
||||
openai_compat_api_base=str(config.api_base),
|
||||
)
|
||||
self.config = config
|
||||
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the Azure API base URL.
|
||||
|
||||
Returns the Azure API base URL from the configuration.
|
||||
"""
|
||||
return urljoin(str(self.config.api_base), "/openai/v1")
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
|
||||
# Get base parameters from parent
|
||||
params = await super()._get_params(request)
|
||||
|
||||
# Add Azure specific parameters
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data:
|
||||
if getattr(provider_data, "azure_api_key", None):
|
||||
params["api_key"] = provider_data.azure_api_key
|
||||
if getattr(provider_data, "azure_api_base", None):
|
||||
params["api_base"] = provider_data.azure_api_base
|
||||
if getattr(provider_data, "azure_api_version", None):
|
||||
params["api_version"] = provider_data.azure_api_version
|
||||
if getattr(provider_data, "azure_api_type", None):
|
||||
params["api_type"] = provider_data.azure_api_type
|
||||
else:
|
||||
params["api_key"] = self.config.api_key.get_secret_value()
|
||||
params["api_base"] = str(self.config.api_base)
|
||||
params["api_version"] = self.config.api_version
|
||||
params["api_type"] = self.config.api_type
|
||||
|
||||
return params
|
63
llama_stack/providers/remote/inference/azure/config.py
Normal file
63
llama_stack/providers/remote/inference/azure/config.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class AzureProviderDataValidator(BaseModel):
|
||||
azure_api_key: SecretStr = Field(
|
||||
description="Azure API key for Azure",
|
||||
)
|
||||
azure_api_base: HttpUrl = Field(
|
||||
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
|
||||
)
|
||||
azure_api_version: str | None = Field(
|
||||
default=None,
|
||||
description="Azure API version for Azure (e.g., 2024-06-01)",
|
||||
)
|
||||
azure_api_type: str | None = Field(
|
||||
default="azure",
|
||||
description="Azure API type for Azure (e.g., azure)",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AzureConfig(BaseModel):
|
||||
api_key: SecretStr = Field(
|
||||
description="Azure API key for Azure",
|
||||
)
|
||||
api_base: HttpUrl = Field(
|
||||
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default_factory=lambda: os.getenv("AZURE_API_VERSION"),
|
||||
description="Azure API version for Azure (e.g., 2024-12-01-preview)",
|
||||
)
|
||||
api_type: str | None = Field(
|
||||
default_factory=lambda: os.getenv("AZURE_API_TYPE", "azure"),
|
||||
description="Azure API type for Azure (e.g., azure)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
api_key: str = "${env.AZURE_API_KEY:=}",
|
||||
api_base: str = "${env.AZURE_API_BASE:=}",
|
||||
api_version: str = "${env.AZURE_API_VERSION:=}",
|
||||
api_type: str = "${env.AZURE_API_TYPE:=}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"api_version": api_version,
|
||||
"api_type": api_type,
|
||||
}
|
|
@ -11,21 +11,17 @@ from botocore.client import BaseClient
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -47,12 +43,47 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
content_has_media,
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
REGION_PREFIX_MAP = {
|
||||
"us": "us.",
|
||||
"eu": "eu.",
|
||||
"ap": "ap.",
|
||||
}
|
||||
|
||||
|
||||
def _get_region_prefix(region: str | None) -> str:
|
||||
# AWS requires region prefixes for inference profiles
|
||||
if region is None:
|
||||
return "us." # default to US when we don't know
|
||||
|
||||
# Handle case insensitive region matching
|
||||
region_lower = region.lower()
|
||||
for prefix in REGION_PREFIX_MAP:
|
||||
if region_lower.startswith(f"{prefix}-"):
|
||||
return REGION_PREFIX_MAP[prefix]
|
||||
|
||||
# Fallback to US for anything we don't recognize
|
||||
return "us."
|
||||
|
||||
|
||||
def _to_inference_profile_id(model_id: str, region: str = None) -> str:
|
||||
# Return ARNs unchanged
|
||||
if model_id.startswith("arn:"):
|
||||
return model_id
|
||||
|
||||
# Return inference profile IDs that already have regional prefixes
|
||||
if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
|
||||
return model_id
|
||||
|
||||
# Default to US East when no region is provided
|
||||
if region is None:
|
||||
region = "us-east-1"
|
||||
|
||||
return _get_region_prefix(region) + model_id
|
||||
|
||||
|
||||
class BedrockInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
|
@ -61,7 +92,7 @@ class BedrockInferenceAdapter(
|
|||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
self._config = config
|
||||
self._client = None
|
||||
|
||||
|
@ -166,8 +197,13 @@ class BedrockInferenceAdapter(
|
|||
options["repetition_penalty"] = sampling_params.repetition_penalty
|
||||
|
||||
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
|
||||
|
||||
# Convert foundation model ID to inference profile ID
|
||||
region_name = self.client.meta.region_name
|
||||
inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
|
||||
|
||||
return {
|
||||
"modelId": bedrock_model,
|
||||
"modelId": inference_profile_id,
|
||||
"body": json.dumps(
|
||||
{
|
||||
"prompt": prompt,
|
||||
|
@ -176,31 +212,6 @@ class BedrockInferenceAdapter(
|
|||
),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embeddings = []
|
||||
for content in contents:
|
||||
assert not content_has_media(content), "Bedrock does not support media for embeddings"
|
||||
input_text = interleaved_content_as_str(content)
|
||||
input_body = {"inputText": input_text}
|
||||
body = json.dumps(input_body)
|
||||
response = self.client.invoke_model(
|
||||
body=body,
|
||||
modelId=model.provider_resource_id,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
embeddings.append(response_body.get("embedding"))
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -5,26 +5,23 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from cerebras.cloud.sdk import AsyncCerebras
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -35,42 +32,41 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import CerebrasImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class CerebrasInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
# TODO: make this use provider data, etc. like other providers
|
||||
self.client = AsyncCerebras(
|
||||
self._cerebras_client = AsyncCerebras(
|
||||
base_url=self.config.base_url,
|
||||
api_key=self.config.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return urljoin(self.config.base_url, "v1")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
|
@ -107,14 +103,14 @@ class CerebrasInferenceAdapter(
|
|||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
|
||||
r = await self.client.completions.create(**params)
|
||||
r = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await self.client.completions.create(**params)
|
||||
stream = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
@ -156,14 +152,14 @@ class CerebrasInferenceAdapter(
|
|||
async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
|
||||
r = await self.client.completions.create(**params)
|
||||
r = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await self.client.completions.create(**params)
|
||||
stream = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
@ -187,16 +183,6 @@ class CerebrasInferenceAdapter(
|
|||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -20,8 +20,8 @@ class CerebrasImplConfig(BaseModel):
|
|||
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
|
||||
description="Base URL for the Cerebras API",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=os.environ.get("CEREBRAS_API_KEY"),
|
||||
api_key: SecretStr = Field(
|
||||
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
|
||||
description="Cerebras API Key",
|
||||
)
|
||||
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# https://inference-docs.cerebras.ai/models
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1-8b",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-3.3-70b",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
|
@ -5,10 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
|
||||
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
|
||||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = DatabricksInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
@ -17,16 +17,16 @@ class DatabricksImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
api_token: str = Field(
|
||||
default=None,
|
||||
api_token: SecretStr = Field(
|
||||
default=SecretStr(None),
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.DATABRICKS_URL:=}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
|
||||
url: str = "${env.DATABRICKS_HOST:=}",
|
||||
api_token: str = "${env.DATABRICKS_TOKEN:=}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -4,74 +4,59 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from openai import OpenAI
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
Model,
|
||||
OpenAICompletion,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"databricks-meta-llama-3-1-70b-instruct",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"databricks-meta-llama-3-1-405b-instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
logger = get_logger(name=__name__, category="inference::databricks")
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
OpenAIMixin,
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
|
||||
embedding_model_metadata = {
|
||||
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
|
||||
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
self.config = config
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_token.get_secret_value()
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/serving-endpoints"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
|
@ -80,89 +65,80 @@ class DatabricksInferenceAdapter(
|
|||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request, client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
async def _to_async_generator():
|
||||
s = client.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": request.model,
|
||||
"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._model_cache = {} # from OpenAIMixin
|
||||
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async
|
||||
endpoints = ws_client.serving_endpoints.list()
|
||||
for endpoint in endpoints:
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__,
|
||||
provider_resource_id=endpoint.name,
|
||||
identifier=endpoint.name,
|
||||
)
|
||||
if endpoint.task == "llm/v1/chat":
|
||||
model.model_type = ModelType.llm # this is redundant, but informative
|
||||
elif endpoint.task == "llm/v1/embeddings":
|
||||
if endpoint.name not in self.embedding_model_metadata:
|
||||
logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
|
||||
continue
|
||||
model.model_type = ModelType.embedding
|
||||
model.metadata = self.embedding_model_metadata[endpoint.name]
|
||||
else:
|
||||
logger.warning(f"Unknown model type, skipping: {endpoint}")
|
||||
continue
|
||||
|
||||
self._model_cache[endpoint.name] = model
|
||||
|
||||
return list(self._model_cache.values())
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
|
|
@ -4,36 +4,24 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fireworks.client import Fireworks
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -45,33 +33,35 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
embedding_model_metadata = {
|
||||
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
|
||||
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
|
||||
}
|
||||
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||
ModelRegistryHelper.__init__(self)
|
||||
self.config = config
|
||||
self.allowed_models = config.allowed_models
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
@ -79,7 +69,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
def get_api_key(self) -> str:
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
return config_api_key
|
||||
|
@ -91,15 +81,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
)
|
||||
return provider_data.fireworks_api_key
|
||||
|
||||
def _get_base_url(self) -> str:
|
||||
def get_base_url(self) -> str:
|
||||
return "https://api.fireworks.ai/inference/v1"
|
||||
|
||||
def _get_client(self) -> Fireworks:
|
||||
fireworks_api_key = self._get_api_key()
|
||||
fireworks_api_key = self.get_api_key()
|
||||
return Fireworks(api_key=fireworks_api_key)
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
|
||||
def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
|
||||
"""Remove BOS token as Fireworks automatically prepends it"""
|
||||
if prompt.startswith("<|begin_of_text|>"):
|
||||
return prompt[len("<|begin_of_text|>") :]
|
||||
return prompt
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
@ -260,178 +253,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
logger.debug(f"params to fireworks: {params}")
|
||||
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
kwargs = {}
|
||||
if model.metadata.get("embedding_dimension"):
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
|
||||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Fireworks does not support media for embeddings"
|
||||
)
|
||||
response = self._get_client().embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Fireworks always prepends with BOS
|
||||
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
|
||||
prompt = prompt[len("<|begin_of_text|>") :]
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return await self._get_openai_client().completions.create(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Divert Llama Models through Llama Stack inference APIs because
|
||||
# Fireworks chat completions OpenAI-compatible API does not support
|
||||
# tool calls properly.
|
||||
llama_model = self.get_llama_model(model_obj.provider_resource_id)
|
||||
|
||||
if llama_model:
|
||||
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
|
||||
self,
|
||||
model=model,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
logger.debug(f"fireworks params: {params}")
|
||||
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
|
||||
|
|
|
@ -1,70 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-guard-3-8b",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-guard-3-11b-vision",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
]
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p1-70b-instruct",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p1-405b-instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p2-3b-instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p3-70b-instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama4-scout-instruct-basic",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama4-maverick-instruct-basic",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="nomic-ai/nomic-embed-text-v1.5",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
|
@ -4,15 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import GeminiConfig
|
||||
|
||||
|
||||
class GeminiProviderDataValidator(BaseModel):
|
||||
gemini_api_key: str | None = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GeminiConfig, _deps):
|
||||
from .gemini import GeminiInferenceAdapter
|
||||
|
||||
|
|
|
@ -8,14 +8,16 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
|
|||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import GeminiConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
embedding_model_metadata = {
|
||||
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
|
||||
}
|
||||
|
||||
def __init__(self, config: GeminiConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="gemini",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="gemini_api_key",
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"gemini-1.5-flash",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-lite",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-pro",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
MODEL_ENTRIES = (
|
||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="text-embedding-004",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768, "context_length": 2048},
|
||||
),
|
||||
]
|
||||
+ SAFETY_MODELS_ENTRIES
|
||||
)
|
|
@ -4,12 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import GroqConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: GroqConfig, _deps):
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .groq import GroqInferenceAdapter
|
||||
|
||||
|
|
|
@ -9,8 +9,6 @@ from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
|||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
_config: GroqConfig
|
||||
|
@ -18,7 +16,6 @@ class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
def __init__(self, config: GroqConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="groq",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="groq_api_key",
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_list import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_hf_repo_model_entry,
|
||||
build_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"llama3-8b-8192",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama-3.1-8b-instant",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3-70b-8192",
|
||||
CoreModelId.llama3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-3.3-70b-versatile",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
# Groq only contains a preview version for llama-3.2-3b
|
||||
# Preview models aren't recommended for production use, but we include this one
|
||||
# to pass the test fixture
|
||||
# TODO(aidand): Replace this with a stable model once Groq supports it
|
||||
build_hf_repo_model_entry(
|
||||
"llama-3.2-3b-preview",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
|
@ -8,8 +8,6 @@ from llama_stack.providers.remote.inference.llama_openai_compat.config import Ll
|
|||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
|
||||
|
||||
|
||||
|
@ -30,7 +28,6 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
def __init__(self, config: LlamaCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="meta_llama",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="llama_api_key",
|
||||
|
|
|
@ -1,25 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"Llama-3.3-70B-Instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
]
|
|
@ -39,32 +39,13 @@ client = LlamaStackAsLibraryClient("nvidia")
|
|||
client.initialize()
|
||||
```
|
||||
|
||||
### Create Completion
|
||||
|
||||
The following example shows how to create a completion for an NVIDIA NIM.
|
||||
|
||||
> [!NOTE]
|
||||
> The hosted NVIDIA Llama NIMs (for example ```meta-llama/Llama-3.1-8B-Instruct```) that have ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` do not support the ```completion``` method, while locally deployed NIMs do.
|
||||
|
||||
```python
|
||||
response = client.inference.completion(
|
||||
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||
content="Complete the sentence using one word: Roses are red, violets are :",
|
||||
stream=False,
|
||||
sampling_params={
|
||||
"max_tokens": 50,
|
||||
},
|
||||
)
|
||||
print(f"Response: {response.content}")
|
||||
```
|
||||
|
||||
### Create Chat Completion
|
||||
|
||||
The following example shows how to create a chat completion for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
response = client.inference.chat_completion(
|
||||
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
|
@ -76,11 +57,9 @@ response = client.inference.chat_completion(
|
|||
},
|
||||
],
|
||||
stream=False,
|
||||
sampling_params={
|
||||
"max_tokens": 50,
|
||||
},
|
||||
max_tokens=50,
|
||||
)
|
||||
print(f"Response: {response.completion_message.content}")
|
||||
print(f"Response: {response.choices[0].message.content}")
|
||||
```
|
||||
|
||||
### Tool Calling Example ###
|
||||
|
@ -108,15 +87,15 @@ tool_definition = ToolDefinition(
|
|||
},
|
||||
)
|
||||
|
||||
tool_response = client.inference.chat_completion(
|
||||
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||
tool_response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
|
||||
tools=[tool_definition],
|
||||
)
|
||||
|
||||
print(f"Tool Response: {tool_response.completion_message.content}")
|
||||
if tool_response.completion_message.tool_calls:
|
||||
for tool_call in tool_response.completion_message.tool_calls:
|
||||
print(f"Tool Response: {tool_response.choices[0].message.content}")
|
||||
if tool_response.choices[0].message.tool_calls:
|
||||
for tool_call in tool_response.choices[0].message.tool_calls:
|
||||
print(f"Tool Called: {tool_call.tool_name}")
|
||||
print(f"Arguments: {tool_call.arguments}")
|
||||
```
|
||||
|
@ -142,8 +121,8 @@ response_format = JsonSchemaResponseFormat(
|
|||
type=ResponseFormatType.json_schema, json_schema=person_schema
|
||||
)
|
||||
|
||||
structured_response = client.inference.chat_completion(
|
||||
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||
structured_response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
|
@ -153,7 +132,7 @@ structured_response = client.inference.chat_completion(
|
|||
response_format=response_format,
|
||||
)
|
||||
|
||||
print(f"Structured Response: {structured_response.completion_message.content}")
|
||||
print(f"Structured Response: {structured_response.choices[0].message.content}")
|
||||
```
|
||||
|
||||
### Create Embeddings
|
||||
|
@ -186,8 +165,8 @@ def load_image_as_base64(image_path):
|
|||
image_path = {path_to_the_image}
|
||||
demo_image_b64 = load_image_as_base64(image_path)
|
||||
|
||||
vlm_response = client.inference.chat_completion(
|
||||
model_id="nvidia/vila",
|
||||
vlm_response = client.chat.completions.create(
|
||||
model="nvidia/vila",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
|
@ -207,5 +186,5 @@ vlm_response = client.inference.chat_completion(
|
|||
],
|
||||
)
|
||||
|
||||
print(f"VLM Response: {vlm_response.completion_message.content}")
|
||||
print(f"VLM Response: {vlm_response.choices[0].message.content}")
|
||||
```
|
||||
|
|
|
@ -1,109 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama3-8b-instruct",
|
||||
CoreModelId.llama3_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama3-70b-instruct",
|
||||
CoreModelId.llama3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.1-8b-instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.1-70b-instruct",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.1-405b-instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.2-1b-instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.2-3b-instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.2-11b-vision-instruct",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.2-90b-vision-instruct",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.3-70b-instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="nvidia/vila",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
# NeMo Retriever Text Embedding models -
|
||||
#
|
||||
# https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
|
||||
#
|
||||
# +-----------------------------------+--------+-----------+-----------+------------+
|
||||
# | Model ID | Max | Publisher | Embedding | Dynamic |
|
||||
# | | Tokens | | Dimension | Embeddings |
|
||||
# +-----------------------------------+--------+-----------+-----------+------------+
|
||||
# | nvidia/llama-3.2-nv-embedqa-1b-v2 | 8192 | NVIDIA | 2048 | Yes |
|
||||
# | nvidia/nv-embedqa-e5-v5 | 512 | NVIDIA | 1024 | No |
|
||||
# | nvidia/nv-embedqa-mistral-7b-v2 | 512 | NVIDIA | 4096 | No |
|
||||
# | snowflake/arctic-embed-l | 512 | Snowflake | 1024 | No |
|
||||
# +-----------------------------------+--------+-----------+-----------+------------+
|
||||
ProviderModelEntry(
|
||||
provider_model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 2048,
|
||||
"context_length": 8192,
|
||||
},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="nvidia/nv-embedqa-e5-v5",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 1024,
|
||||
"context_length": 512,
|
||||
},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="nvidia/nv-embedqa-mistral-7b-v2",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 4096,
|
||||
"context_length": 512,
|
||||
},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="snowflake/arctic-embed-l",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 1024,
|
||||
"context_length": 512,
|
||||
},
|
||||
),
|
||||
# TODO(mf): how do we handle Nemotron models?
|
||||
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
|
||||
] + SAFETY_MODELS_ENTRIES
|
|
@ -11,8 +11,6 @@ from openai import NOT_GIVEN, APIConnectionError
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
|
@ -21,8 +19,6 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -31,15 +27,11 @@ from llama_stack.apis.inference import (
|
|||
OpenAIEmbeddingUsage,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_openai_chat_completion_choice,
|
||||
convert_openai_chat_completion_stream,
|
||||
|
@ -48,7 +40,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
|
||||
|
||||
from . import NVIDIAConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
from .openai_utils import (
|
||||
convert_chat_completion_request,
|
||||
convert_completion_request,
|
||||
|
@ -60,7 +51,7 @@ from .utils import _is_nvidia_hosted
|
|||
logger = get_logger(name=__name__, category="inference::nvidia")
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
||||
"""
|
||||
NVIDIA Inference Adapter for Llama Stack.
|
||||
|
||||
|
@ -74,10 +65,15 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
|||
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
|
||||
"""
|
||||
|
||||
def __init__(self, config: NVIDIAConfig) -> None:
|
||||
# TODO(mf): filter by available models
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
|
||||
embedding_model_metadata = {
|
||||
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
|
||||
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
|
||||
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
|
||||
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
|
||||
}
|
||||
|
||||
def __init__(self, config: NVIDIAConfig) -> None:
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
||||
|
||||
if _is_nvidia_hosted(config):
|
||||
|
@ -155,60 +151,6 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
|||
# we pass n=1 to get only one completion
|
||||
return convert_openai_completion_choice(response.choices[0])
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
if any(content_has_media(content) for content in contents):
|
||||
raise NotImplementedError("Media is not supported")
|
||||
|
||||
#
|
||||
# Llama Stack: contents = list[str] | list[InterleavedContentItem]
|
||||
# ->
|
||||
# OpenAI: input = str | list[str]
|
||||
#
|
||||
# we can ignore str and always pass list[str] to OpenAI
|
||||
#
|
||||
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
|
||||
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
|
||||
provider_model_id = await self._get_provider_model_id(model_id)
|
||||
|
||||
extra_body = {}
|
||||
|
||||
if text_truncation is not None:
|
||||
text_truncation_options = {
|
||||
TextTruncation.none: "NONE",
|
||||
TextTruncation.end: "END",
|
||||
TextTruncation.start: "START",
|
||||
}
|
||||
extra_body["truncate"] = text_truncation_options[text_truncation]
|
||||
|
||||
if output_dimension is not None:
|
||||
extra_body["dimensions"] = output_dimension
|
||||
|
||||
if task_type is not None:
|
||||
task_type_options = {
|
||||
EmbeddingTaskType.document: "passage",
|
||||
EmbeddingTaskType.query: "query",
|
||||
}
|
||||
extra_body["input_type"] = task_type_options[task_type]
|
||||
|
||||
response = await self.client.embeddings.create(
|
||||
model=provider_model_id,
|
||||
input=input,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
#
|
||||
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
|
||||
# ->
|
||||
# Llama Stack: EmbeddingsResponse(embeddings=list[list[float]])
|
||||
#
|
||||
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -1,106 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
build_hf_repo_model_entry,
|
||||
build_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
# The Llama Guard models don't have their full fp16 versions
|
||||
# so we are going to alias their default version to the canonical SKU
|
||||
build_hf_repo_model_entry(
|
||||
"llama-guard3:8b",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-guard3:1b",
|
||||
CoreModelId.llama_guard_3_1b.value,
|
||||
),
|
||||
]
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1:8b-instruct-fp16",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.1:8b",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1:70b-instruct-fp16",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.1:70b",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1:405b-instruct-fp16",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.1:405b",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.2:1b-instruct-fp16",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.2:1b",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.2:3b-instruct-fp16",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.2:3b",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.2-vision:11b-instruct-fp16",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.2-vision:latest",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.2-vision:90b-instruct-fp16",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.2-vision:90b",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.3:70b",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="all-minilm:l6-v2",
|
||||
aliases=["all-minilm"],
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="nomic-embed-text",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
|
@ -6,18 +6,14 @@
|
|||
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ollama import AsyncClient # type: ignore[attr-defined]
|
||||
from openai import AsyncOpenAI
|
||||
from ollama import AsyncClient as AsyncOllamaClient
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.common.errors import UnsupportedModelError
|
||||
|
@ -28,30 +24,21 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
|
@ -60,61 +47,93 @@ from llama_stack.providers.datatypes import (
|
|||
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
b64_encode_openai_embeddings_response,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
prepare_openai_embeddings_params,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
convert_image_content_to_url,
|
||||
interleaved_content_as_str,
|
||||
localize_image_content,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::ollama")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
ModelRegistryHelper,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
|
||||
embedding_model_metadata = {
|
||||
"all-minilm:l6-v2": {
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
"nomic-embed-text:latest": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
"nomic-embed-text:v1.5": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
"nomic-embed-text:137m-v1.5-fp16": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, config: OllamaImplConfig) -> None:
|
||||
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
|
||||
# TODO: remove ModelRegistryHelper.__init__ when completion and
|
||||
# chat_completion are. this exists to satisfy the input /
|
||||
# output processing for llama models. specifically,
|
||||
# tool_calling is handled by raw template processing,
|
||||
# instead of using the /api/chat endpoint w/ tools=...
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
model_entries=[
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.2:3b-instruct-fp16",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-guard3:1b",
|
||||
CoreModelId.llama_guard_3_1b.value,
|
||||
),
|
||||
],
|
||||
)
|
||||
self.config = config
|
||||
self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
|
||||
self._openai_client = None
|
||||
# Ollama does not support image urls, so we need to download the image and convert it to base64
|
||||
self.download_images = True
|
||||
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
def ollama_client(self) -> AsyncOllamaClient:
|
||||
# ollama client attaches itself to the current event loop (sadly?)
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop not in self._clients:
|
||||
self._clients[loop] = AsyncClient(host=self.config.url)
|
||||
self._clients[loop] = AsyncOllamaClient(host=self.config.url)
|
||||
return self._clients[loop]
|
||||
|
||||
@property
|
||||
def openai_client(self) -> AsyncOpenAI:
|
||||
if self._openai_client is None:
|
||||
url = self.config.url.rstrip("/")
|
||||
self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
|
||||
return self._openai_client
|
||||
def get_api_key(self):
|
||||
return "NO_KEY"
|
||||
|
||||
def get_base_url(self):
|
||||
return self.config.url.rstrip("/") + "/v1"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
|
||||
|
@ -127,59 +146,6 @@ class OllamaInferenceAdapter(
|
|||
async def should_refresh_models(self) -> bool:
|
||||
return self.config.refresh_models
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
provider_id = self.__provider_id__
|
||||
response = await self.client.list()
|
||||
|
||||
# always add the two embedding models which can be pulled on demand
|
||||
models = [
|
||||
Model(
|
||||
identifier="all-minilm:l6-v2",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
# add all-minilm alias
|
||||
Model(
|
||||
identifier="all-minilm",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
Model(
|
||||
identifier="nomic-embed-text",
|
||||
provider_resource_id="nomic-embed-text:latest",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
for m in response.models:
|
||||
# kill embedding models since we don't know dimensions for them
|
||||
if "bert" in m.details.family:
|
||||
continue
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.model,
|
||||
provider_resource_id=m.model,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the Ollama server.
|
||||
|
@ -189,7 +155,7 @@ class OllamaInferenceAdapter(
|
|||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
await self.client.ps()
|
||||
await self.ollama_client.ps()
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
@ -197,9 +163,6 @@ class OllamaInferenceAdapter(
|
|||
async def shutdown(self) -> None:
|
||||
self._clients.clear()
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def _get_model(self, model_id: str) -> Model:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set")
|
||||
|
@ -238,7 +201,7 @@ class OllamaInferenceAdapter(
|
|||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.generate(**params)
|
||||
s = await self.ollama_client.generate(**params)
|
||||
async for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
|
@ -254,7 +217,7 @@ class OllamaInferenceAdapter(
|
|||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await self.client.generate(**params)
|
||||
r = await self.ollama_client.generate(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
|
@ -308,7 +271,7 @@ class OllamaInferenceAdapter(
|
|||
|
||||
input_dict: dict[str, Any] = {}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.register_helper.get_llama_model(request.model)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present or not llama_model:
|
||||
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
|
||||
|
@ -346,9 +309,9 @@ class OllamaInferenceAdapter(
|
|||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
if "messages" in params:
|
||||
r = await self.client.chat(**params)
|
||||
r = await self.ollama_client.chat(**params)
|
||||
else:
|
||||
r = await self.client.generate(**params)
|
||||
r = await self.ollama_client.generate(**params)
|
||||
|
||||
if "message" in r:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
|
@ -372,9 +335,9 @@ class OllamaInferenceAdapter(
|
|||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
if "messages" in params:
|
||||
s = await self.client.chat(**params)
|
||||
s = await self.ollama_client.chat(**params)
|
||||
else:
|
||||
s = await self.client.generate(**params)
|
||||
s = await self.ollama_client.generate(**params)
|
||||
async for chunk in s:
|
||||
if "message" in chunk:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
|
@ -394,230 +357,17 @@ class OllamaInferenceAdapter(
|
|||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self._get_model(model_id)
|
||||
|
||||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Ollama does not support media for embeddings"
|
||||
)
|
||||
response = await self.client.embed(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
embeddings = response["embeddings"]
|
||||
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
try:
|
||||
model = await self.register_helper.register_model(model)
|
||||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
if await self.check_model_availability(model.provider_model_id):
|
||||
return model
|
||||
elif await self.check_model_availability(f"{model.provider_model_id}:latest"):
|
||||
model.provider_resource_id = f"{model.provider_model_id}:latest"
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_model_id}'"
|
||||
)
|
||||
return model
|
||||
|
||||
if model.model_type == ModelType.embedding:
|
||||
response = await self.client.list()
|
||||
if model.provider_resource_id not in [m.model for m in response.models]:
|
||||
await self.client.pull(model.provider_resource_id)
|
||||
|
||||
# we use list() here instead of ps() -
|
||||
# - ps() only lists running models, not available models
|
||||
# - models not currently running are run by the ollama server as needed
|
||||
response = await self.client.list()
|
||||
available_models = [m.model for m in response.models]
|
||||
|
||||
provider_resource_id = model.provider_resource_id
|
||||
assert provider_resource_id is not None # mypy
|
||||
if provider_resource_id not in available_models:
|
||||
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
||||
if provider_resource_id in available_models_latest:
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
)
|
||||
return model
|
||||
raise UnsupportedModelError(provider_resource_id, available_models)
|
||||
|
||||
# mutating this should be considered an anti-pattern
|
||||
model.provider_resource_id = provider_resource_id
|
||||
|
||||
return model
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
model_obj = await self._get_model(model)
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise ValueError(f"Model {model} has no provider_resource_id set")
|
||||
|
||||
# Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
|
||||
params = prepare_openai_embeddings_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
input=input,
|
||||
encoding_format=encoding_format,
|
||||
dimensions=dimensions,
|
||||
user=user,
|
||||
)
|
||||
|
||||
response = await self.openai_client.embeddings.create(**params)
|
||||
data = b64_encode_openai_embeddings_response(response.data, encoding_format)
|
||||
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
# TODO: Investigate why model_obj.identifier is used instead of response.model
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=model_obj.identifier,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
if not isinstance(prompt, str):
|
||||
raise ValueError("Ollama does not support non-string prompts for completion")
|
||||
|
||||
model_obj = await self._get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
suffix=suffix,
|
||||
)
|
||||
return await self.openai_client.completions.create(**params) # type: ignore
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self._get_model(model)
|
||||
|
||||
# Ollama does not support image urls, so we need to download the image and convert it to base64
|
||||
async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
|
||||
if isinstance(m.content, list):
|
||||
for c in m.content:
|
||||
if c.type == "image_url" and c.image_url and c.image_url.url:
|
||||
localize_result = await localize_image_content(c.image_url.url)
|
||||
if localize_result is None:
|
||||
raise ValueError(f"Failed to localize image content from {c.image_url.url}")
|
||||
|
||||
content, format = localize_result
|
||||
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
|
||||
return m
|
||||
|
||||
messages = [await _convert_message(m) for m in messages]
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
response = await self.openai_client.chat.completions.create(**params)
|
||||
return await self._adjust_ollama_chat_completion_response_ids(response)
|
||||
|
||||
async def _adjust_ollama_chat_completion_response_ids(
|
||||
self,
|
||||
response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
id = f"chatcmpl-{uuid.uuid4()}"
|
||||
if isinstance(response, AsyncIterator):
|
||||
|
||||
async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
async for chunk in response:
|
||||
chunk.id = id
|
||||
yield chunk
|
||||
|
||||
return stream_with_chunk_ids()
|
||||
else:
|
||||
response.id = id
|
||||
return response
|
||||
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
|
||||
|
||||
|
||||
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
|
||||
|
|
|
@ -4,15 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import OpenAIConfig
|
||||
|
||||
|
||||
class OpenAIProviderDataValidator(BaseModel):
|
||||
openai_api_key: str | None = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: OpenAIConfig, _deps):
|
||||
from .openai import OpenAIInferenceAdapter
|
||||
|
||||
|
|
|
@ -1,60 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"gpt-4",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4o",
|
||||
"gpt-4o-2024-08-06",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4o-audio-preview",
|
||||
"chatgpt-4o-latest",
|
||||
"o1",
|
||||
"o1-mini",
|
||||
"o3-mini",
|
||||
"o4-mini",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModelInfo:
|
||||
"""Structured representation of embedding model information."""
|
||||
|
||||
embedding_dimension: int
|
||||
context_length: int
|
||||
|
||||
|
||||
EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
|
||||
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
|
||||
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
|
||||
}
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
MODEL_ENTRIES = (
|
||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id=model_id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": model_info.embedding_dimension,
|
||||
"context_length": model_info.context_length,
|
||||
},
|
||||
)
|
||||
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
|
||||
]
|
||||
+ SAFETY_MODELS_ENTRIES
|
||||
)
|
|
@ -9,7 +9,6 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
|
|||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import OpenAIConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::openai")
|
||||
|
||||
|
@ -22,8 +21,6 @@ logger = get_logger(name=__name__, category="inference::openai")
|
|||
# | completion | LiteLLMOpenAIMixin |
|
||||
# | chat_completion | LiteLLMOpenAIMixin |
|
||||
# | embedding | LiteLLMOpenAIMixin |
|
||||
# | batch_completion | LiteLLMOpenAIMixin |
|
||||
# | batch_chat_completion | LiteLLMOpenAIMixin |
|
||||
# | openai_completion | OpenAIMixin |
|
||||
# | openai_chat_completion | OpenAIMixin |
|
||||
# | openai_embeddings | OpenAIMixin |
|
||||
|
@ -40,10 +37,14 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
|
||||
"""
|
||||
|
||||
embedding_model_metadata = {
|
||||
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
|
||||
}
|
||||
|
||||
def __init__(self, config: OpenAIConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="openai",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="openai_api_key",
|
||||
|
|
|
@ -14,8 +14,6 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -27,7 +25,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -43,7 +40,7 @@ from .config import PassthroughImplConfig
|
|||
|
||||
class PassthroughInferenceAdapter(Inference):
|
||||
def __init__(self, config: PassthroughImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, [])
|
||||
ModelRegistryHelper.__init__(self)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
@ -190,25 +187,6 @@ class PassthroughInferenceAdapter(Inference):
|
|||
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
|
||||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[InterleavedContent],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
return await client.inference.embeddings(
|
||||
model_id=model.provider_resource_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
output_dimension=output_dimension,
|
||||
task_type=task_type,
|
||||
)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -136,16 +136,6 @@ class RunpodInferenceAdapter(
|
|||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -4,12 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import SambaNovaImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SambaNovaImplConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
|
||||
from .sambanova import SambaNovaInferenceAdapter
|
||||
|
||||
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"Meta-Llama-3.1-8B-Instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"Meta-Llama-3.3-70B-Instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"Llama-4-Maverick-17B-128E-Instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
|
@ -9,7 +9,6 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
|
|||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import SambaNovaImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
|
@ -26,10 +25,9 @@ class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
|
||||
def __init__(self, config: SambaNovaImplConfig):
|
||||
self.config = config
|
||||
self.environment_available_models = []
|
||||
self.environment_available_models: list[str] = []
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="sambanova",
|
||||
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
|
||||
provider_data_api_key_field="sambanova_api_key",
|
||||
|
|
|
@ -8,17 +8,15 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from pydantic import SecretStr
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -26,13 +24,13 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
|
@ -41,16 +39,15 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_model_input_info,
|
||||
completion_request_to_prompt_model_input_info,
|
||||
|
@ -73,26 +70,49 @@ def build_hf_repo_model_entries():
|
|||
|
||||
|
||||
class _HfAdapter(
|
||||
OpenAIMixin,
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
client: AsyncInferenceClient
|
||||
url: str
|
||||
api_key: SecretStr
|
||||
|
||||
hf_client: AsyncInferenceClient
|
||||
max_tokens: int
|
||||
model_id: str
|
||||
|
||||
overwrite_completion_id = True # TGI always returns id=""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.huggingface_repo_to_llama_model_id = {
|
||||
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
|
||||
}
|
||||
|
||||
def get_api_key(self):
|
||||
return self.api_key.get_secret_value()
|
||||
|
||||
def get_base_url(self):
|
||||
return self.url
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
models = []
|
||||
async for model in self.client.models.list():
|
||||
models.append(
|
||||
Model(
|
||||
identifier=model.id,
|
||||
provider_resource_id=model.id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.register_helper.register_model(model)
|
||||
if model.provider_resource_id != self.model_id:
|
||||
raise ValueError(
|
||||
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
|
||||
|
@ -176,7 +196,7 @@ class _HfAdapter(
|
|||
params = await self._get_params_for_completion(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.text_generation(**params)
|
||||
s = await self.hf_client.text_generation(**params)
|
||||
async for chunk in s:
|
||||
token_result = chunk.token
|
||||
finish_reason = None
|
||||
|
@ -194,7 +214,7 @@ class _HfAdapter(
|
|||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params_for_completion(request)
|
||||
r = await self.client.text_generation(**params)
|
||||
r = await self.hf_client.text_generation(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r.details.finish_reason,
|
||||
|
@ -241,7 +261,7 @@ class _HfAdapter(
|
|||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await self.client.text_generation(**params)
|
||||
r = await self.hf_client.text_generation(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r.details.finish_reason,
|
||||
|
@ -256,7 +276,7 @@ class _HfAdapter(
|
|||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.text_generation(**params)
|
||||
s = await self.hf_client.text_generation(**params)
|
||||
async for chunk in s:
|
||||
token_result = chunk.token
|
||||
|
||||
|
@ -282,16 +302,6 @@ class _HfAdapter(
|
|||
**self._build_options(request.sampling_params, request.response_format),
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -308,18 +318,21 @@ class TGIAdapter(_HfAdapter):
|
|||
if not config.url:
|
||||
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
|
||||
endpoint_info = await self.hf_client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
self.url = f"{config.url.rstrip('/')}/v1"
|
||||
self.api_key = SecretStr("NO_KEY")
|
||||
|
||||
|
||||
class InferenceAPIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.hf_client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
|
||||
endpoint_info = await self.hf_client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
# TODO: how do we set url for this?
|
||||
|
||||
|
||||
class InferenceEndpointAdapter(_HfAdapter):
|
||||
|
@ -331,6 +344,7 @@ class InferenceEndpointAdapter(_HfAdapter):
|
|||
endpoint.wait(timeout=60)
|
||||
|
||||
# Initialize the adapter
|
||||
self.client = endpoint.async_client
|
||||
self.hf_client = endpoint.async_client
|
||||
self.model_id = endpoint.repository
|
||||
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
|
||||
# TODO: how do we set url for this?
|
||||
|
|
|
@ -1,77 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
]
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="togethercomputer/m2-bert-80M-32k-retrieval",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 32768,
|
||||
},
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
|
@ -4,70 +4,76 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from together import AsyncTogether
|
||||
from together.constants import BASE_URL
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::together")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
embedding_model_metadata = {
|
||||
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
|
||||
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
|
||||
"BAAI/bge-base-en-v1.5": {"embedding_dimension": 768, "context_length": 512},
|
||||
"Alibaba-NLP/gte-modernbert-base": {"embedding_dimension": 768, "context_length": 8192},
|
||||
"intfloat/multilingual-e5-large-instruct": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||
ModelRegistryHelper.__init__(self)
|
||||
self.config = config
|
||||
self.allowed_models = config.allowed_models
|
||||
self._model_cache: dict[str, Model] = {}
|
||||
|
||||
def get_api_key(self):
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
||||
def get_base_url(self):
|
||||
return BASE_URL
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
@ -235,25 +241,37 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
logger.debug(f"params to together: {params}")
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Together does not support media for embeddings"
|
||||
)
|
||||
client = self._get_client()
|
||||
r = await client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
embeddings = [item.embedding for item in r.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._model_cache = {}
|
||||
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
|
||||
for m in await self._get_client().models.list():
|
||||
if m.type == "embedding":
|
||||
if m.id not in self.embedding_model_metadata:
|
||||
logger.warning(f"Unknown embedding dimension for model {m.id}, skipping.")
|
||||
continue
|
||||
metadata = self.embedding_model_metadata[m.id]
|
||||
self._model_cache[m.id] = Model(
|
||||
provider_id=self.__provider_id__,
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata=metadata,
|
||||
)
|
||||
else:
|
||||
self._model_cache[m.id] = Model(
|
||||
provider_id=self.__provider_id__,
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
return self._model_cache.values()
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return True
|
||||
|
||||
async def check_model_availability(self, model):
|
||||
return model in self._model_cache
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
@ -263,125 +281,36 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
"""
|
||||
Together's OpenAI-compatible embeddings endpoint is not compatible with
|
||||
the standard OpenAI embeddings endpoint.
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
The endpoint -
|
||||
- not all models return usage information
|
||||
- does not support user param, returns 400 Unrecognized request arguments supplied: user
|
||||
- does not support dimensions param, returns 400 Unrecognized request arguments supplied: dimensions
|
||||
"""
|
||||
# Together support ticket #13332 -> will not fix
|
||||
if user is not None:
|
||||
raise ValueError("Together's embeddings endpoint does not support user param.")
|
||||
# Together support ticket #13333 -> escalated
|
||||
if dimensions is not None:
|
||||
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
|
||||
|
||||
response = await self.client.embeddings.create(
|
||||
model=await self._get_provider_model_id(model),
|
||||
input=input,
|
||||
encoding_format=encoding_format,
|
||||
)
|
||||
return await self._get_openai_client().completions.create(**params) # type: ignore
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if params.get("stream", False):
|
||||
return self._stream_openai_chat_completion(params)
|
||||
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||
response.model = model # return the user the same model id they provided, avoid exposing the provider model id
|
||||
|
||||
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
|
||||
# together.ai sometimes adds usage data to the stream, even if include_usage is False
|
||||
# This causes an unexpected final chunk with empty choices array to be sent
|
||||
# to clients that may not handle it gracefully.
|
||||
include_usage = False
|
||||
if params.get("stream_options", None):
|
||||
include_usage = params["stream_options"].get("include_usage", False)
|
||||
stream = await self._get_openai_client().chat.completions.create(**params)
|
||||
# Together support ticket #13330 -> escalated
|
||||
# - togethercomputer/m2-bert-80M-32k-retrieval *does not* return usage information
|
||||
if not hasattr(response, "usage") or response.usage is None:
|
||||
logger.warning(
|
||||
f"Together's embedding endpoint for {model} did not return usage information, substituting -1s."
|
||||
)
|
||||
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
|
||||
|
||||
seen_finish_reason = False
|
||||
async for chunk in stream:
|
||||
# Final usage chunk with no choices that the user didn't request, so discard
|
||||
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
|
||||
break
|
||||
yield chunk
|
||||
for choice in chunk.choices:
|
||||
if choice.finish_reason:
|
||||
seen_finish_reason = True
|
||||
break
|
||||
return response
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
|
||||
# Vertex AI model IDs with vertex_ai/ prefix as required by litellm
|
||||
LLM_MODEL_IDS = [
|
||||
"vertex_ai/gemini-2.0-flash",
|
||||
"vertex_ai/gemini-2.5-flash",
|
||||
"vertex_ai/gemini-2.5-pro",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = list[ProviderModelEntry]()
|
||||
|
||||
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + SAFETY_MODELS_ENTRIES
|
|
@ -16,14 +16,12 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
|||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import VertexAIConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: VertexAIConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="vertex_ai",
|
||||
api_key_from_config=None, # Vertex AI uses ADC, not API keys
|
||||
provider_data_api_key_field="vertex_project", # Use project for validation
|
||||
|
|
|
@ -4,9 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
|
||||
class VLLMProviderDataValidator(BaseModel):
|
||||
vllm_api_token: str | None = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
|
||||
from .vllm import VLLMInferenceAdapter
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
import json
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
from openai import APIConnectionError, AsyncOpenAI
|
||||
|
@ -15,7 +16,6 @@ from openai.types.chat.chat_completion_chunk import (
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
|
@ -30,24 +30,14 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
Inference,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ModelStore,
|
||||
OpenAIChatCompletion,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -62,6 +52,7 @@ from llama_stack.providers.datatypes import (
|
|||
HealthStatus,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
|
@ -69,17 +60,16 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
UnparseableToolCall,
|
||||
convert_message_to_openai_dict,
|
||||
convert_openai_chat_completion_stream,
|
||||
convert_tool_call,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -288,15 +278,30 @@ async def _process_vllm_chat_completion_stream_response(
|
|||
yield c
|
||||
|
||||
|
||||
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
model_store: ModelStore | None = None
|
||||
|
||||
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=build_hf_repo_model_entries(),
|
||||
litellm_provider_name="vllm",
|
||||
api_key_from_config=config.api_token,
|
||||
provider_data_api_key_field="vllm_api_token",
|
||||
openai_compat_api_base=config.url,
|
||||
)
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.config = config
|
||||
self.client = None
|
||||
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""Get the base URL from config."""
|
||||
if not self.config.url:
|
||||
raise ValueError("No base URL configured")
|
||||
return self.config.url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if not self.config.url:
|
||||
|
@ -305,11 +310,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
# Strictly respecting the refresh_models directive
|
||||
return self.config.refresh_models
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None # mypy
|
||||
models = []
|
||||
async for m in self.client.models.list():
|
||||
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
|
||||
|
@ -335,14 +339,19 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
Performs a health check by verifying connectivity to the remote vLLM server.
|
||||
This method is used by the Provider API to verify
|
||||
that the service is running correctly.
|
||||
Uses the unauthenticated /health endpoint.
|
||||
Returns:
|
||||
|
||||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
client = self._create_client() if self.client is None else self.client
|
||||
_ = [m async for m in client.models.list()] # Ensure the client is initialized
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
base_url = self.get_base_url()
|
||||
health_url = urljoin(base_url, "health")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(health_url)
|
||||
response.raise_for_status()
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
|
@ -351,21 +360,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
raise ValueError("Model store not set")
|
||||
return await self.model_store.get_model(model_id)
|
||||
|
||||
def _lazy_initialize_client(self):
|
||||
if self.client is not None:
|
||||
return
|
||||
def get_extra_client_params(self):
|
||||
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
|
||||
|
||||
log.info(f"Initializing vLLM client with base_url={self.config.url}")
|
||||
self.client = self._create_client()
|
||||
|
||||
def _create_client(self):
|
||||
return AsyncOpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token,
|
||||
http_client=httpx.AsyncClient(verify=self.config.tls_verify),
|
||||
)
|
||||
|
||||
async def completion(
|
||||
async def completion( # type: ignore[override] # Return type more specific than base class which is allows for both streaming and non-streaming responses.
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
|
@ -374,7 +372,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
self._lazy_initialize_client()
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
|
@ -406,7 +403,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
self._lazy_initialize_client()
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
|
@ -429,13 +425,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
tool_config=tool_config,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, self.client)
|
||||
return self._stream_chat_completion_with_client(request, self.client)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request, self.client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: AsyncOpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
assert self.client is not None
|
||||
params = await self._get_params(request)
|
||||
r = await client.chat.completions.create(**params)
|
||||
choice = r.choices[0]
|
||||
|
@ -449,9 +446,24 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
return result
|
||||
|
||||
async def _stream_chat_completion(
|
||||
async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
# This method is called from LiteLLMOpenAIMixin.chat_completion
|
||||
# The response parameter contains the litellm response
|
||||
# We need to convert it to our format
|
||||
async def _stream_generator():
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
|
||||
async for chunk in convert_openai_chat_completion_stream(
|
||||
_stream_generator(), enable_incremental_tool_calls=True
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def _stream_chat_completion_with_client(
|
||||
self, request: ChatCompletionRequest, client: AsyncOpenAI
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
"""Helper method for streaming with explicit client parameter."""
|
||||
assert self.client is not None
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await client.chat.completions.create(**params)
|
||||
|
@ -463,7 +475,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
yield chunk
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
assert self.client is not None
|
||||
if self.client is None:
|
||||
raise RuntimeError("Client is not initialized")
|
||||
params = await self._get_params(request)
|
||||
r = await self.client.completions.create(**params)
|
||||
return process_completion_response(r)
|
||||
|
@ -471,7 +484,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def _stream_completion(
|
||||
self, request: CompletionRequest
|
||||
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
assert self.client is not None
|
||||
if self.client is None:
|
||||
raise RuntimeError("Client is not initialized")
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await self.client.completions.create(**params)
|
||||
|
@ -479,16 +493,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
# register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
|
||||
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
|
||||
# Changing this may lead to unpredictable behavior.
|
||||
client = self._create_client() if self.client is None else self.client
|
||||
try:
|
||||
model = await self.register_helper.register_model(model)
|
||||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
try:
|
||||
res = await client.models.list()
|
||||
res = self.client.models.list()
|
||||
except APIConnectionError as e:
|
||||
raise ValueError(
|
||||
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
|
||||
|
@ -534,180 +544,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
"stream": request.stream,
|
||||
**options,
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None
|
||||
model = await self._get_model(model_id)
|
||||
|
||||
kwargs = {}
|
||||
assert model.model_type == ModelType.embedding
|
||||
assert model.metadata.get("embedding_dimension")
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
|
||||
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
|
||||
response = await self.client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None
|
||||
model_obj = await self._get_model(model)
|
||||
assert model_obj.model_type == ModelType.embedding
|
||||
|
||||
# Convert input to list if it's a string
|
||||
input_list = [input] if isinstance(input, str) else input
|
||||
|
||||
# Call vLLM embeddings endpoint with encoding_format
|
||||
response = await self.client.embeddings.create(
|
||||
model=model_obj.provider_resource_id,
|
||||
input=input_list,
|
||||
dimensions=dimensions,
|
||||
encoding_format=encoding_format,
|
||||
)
|
||||
|
||||
# Convert response to OpenAI format
|
||||
data = [
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
for i, embedding_data in enumerate(response.data)
|
||||
]
|
||||
|
||||
# Not returning actual token usage since vLLM doesn't provide it
|
||||
usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=model_obj.provider_resource_id,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
self._lazy_initialize_client()
|
||||
model_obj = await self._get_model(model)
|
||||
|
||||
extra_body: dict[str, Any] = {}
|
||||
if prompt_logprobs is not None and prompt_logprobs >= 0:
|
||||
extra_body["prompt_logprobs"] = prompt_logprobs
|
||||
if guided_choice:
|
||||
extra_body["guided_choice"] = guided_choice
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
return await self.client.completions.create(**params) # type: ignore
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
self._lazy_initialize_client()
|
||||
model_obj = await self._get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self.client.chat.completions.create(**params) # type: ignore
|
||||
|
|
|
@ -26,11 +26,11 @@ class WatsonXConfig(BaseModel):
|
|||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
|
||||
description="The watsonx API key, only needed of using the hosted service",
|
||||
description="The watsonx API key",
|
||||
)
|
||||
project_id: str | None = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
|
||||
description="The Project ID key, only needed of using the hosted service",
|
||||
description="The Project ID key",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=60,
|
||||
|
|
|
@ -11,13 +11,11 @@ from ibm_watsonx_ai.foundation_models import Model
|
|||
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GreedySamplingStrategy,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
|
@ -30,7 +28,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -38,6 +35,7 @@ from llama_stack.apis.inference import (
|
|||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
|
@ -57,14 +55,29 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from . import WatsonXConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::watsonx")
|
||||
|
||||
|
||||
# Note on structured output
|
||||
# WatsonX returns responses with a json embedded into a string.
|
||||
# Examples:
|
||||
|
||||
# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
|
||||
# "first_name": "Michael",\n "last_name": "Jordan",\n'...)
|
||||
# Not even a valid JSON, but we can still extract the JSON from the content
|
||||
|
||||
# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
|
||||
# "year_born": "1963", "year_retired": "2003"\\}}$')
|
||||
# Find the start of the boxed content
|
||||
|
||||
|
||||
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
||||
def __init__(self, config: WatsonXConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
|
||||
print(f"Initializing watsonx InferenceAdapter({config.url})...")
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
||||
logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
|
||||
self._config = config
|
||||
self._openai_client: AsyncOpenAI | None = None
|
||||
|
||||
self._project_id = self._config.project_id
|
||||
|
||||
|
@ -249,16 +262,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
}
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError("embedding is not supported for watsonx")
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -140,13 +140,11 @@ client.models.register(
|
|||
#### 2. Inference with the fine-tuned model
|
||||
|
||||
```python
|
||||
response = client.inference.completion(
|
||||
content="Complete the sentence using one word: Roses are red, violets are ",
|
||||
response = client.completions.create(
|
||||
prompt="Complete the sentence using one word: Roses are red, violets are ",
|
||||
stream=False,
|
||||
model_id="test-example-model@v1",
|
||||
sampling_params={
|
||||
"max_tokens": 50,
|
||||
},
|
||||
model="test-example-model@v1",
|
||||
max_tokens=50,
|
||||
)
|
||||
print(response.content)
|
||||
print(response.choices[0].text)
|
||||
```
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
|
@ -49,10 +50,13 @@ def convert_id(_id: str) -> str:
|
|||
Converts any string into a UUID string based on a seed.
|
||||
|
||||
Qdrant accepts UUID strings and unsigned integers as point ID.
|
||||
We use a seed to convert each string into a UUID string deterministically.
|
||||
We use a SHA-256 hash to convert each string into a UUID string deterministically.
|
||||
This allows us to overwrite the same point with the original ID.
|
||||
"""
|
||||
return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))
|
||||
hash_input = f"qdrant_id:{_id}".encode()
|
||||
sha256_hash = hashlib.sha256(hash_input).hexdigest()
|
||||
# Use the first 32 characters to create a valid UUID
|
||||
return str(uuid.UUID(sha256_hash[:32]))
|
||||
|
||||
|
||||
class QdrantIndex(EmbeddingIndex):
|
||||
|
|
5
llama_stack/providers/utils/files/__init__.py
Normal file
5
llama_stack/providers/utils/files/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
69
llama_stack/providers/utils/files/form_data.py
Normal file
69
llama_stack/providers/utils/files/form_data.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.apis.files import ExpiresAfter
|
||||
|
||||
|
||||
async def parse_pydantic_from_form[T: BaseModel](request: Request, field_name: str, model_class: type[T]) -> T | None:
|
||||
"""
|
||||
Generic parser to extract a Pydantic model from multipart form data.
|
||||
Handles both bracket notation (field[attr1], field[attr2]) and JSON string format.
|
||||
|
||||
Args:
|
||||
request: The FastAPI request object
|
||||
field_name: The name of the field in the form data (e.g., "expires_after")
|
||||
model_class: The Pydantic model class to parse into
|
||||
|
||||
Returns:
|
||||
An instance of model_class if parsing succeeds, None otherwise
|
||||
|
||||
Example:
|
||||
expires_after = await parse_pydantic_from_form(
|
||||
request, "expires_after", ExpiresAfter
|
||||
)
|
||||
"""
|
||||
form = await request.form()
|
||||
|
||||
# Check for bracket notation first (e.g., expires_after[anchor], expires_after[seconds])
|
||||
bracket_data = {}
|
||||
prefix = f"{field_name}["
|
||||
for key in form.keys():
|
||||
if key.startswith(prefix) and key.endswith("]"):
|
||||
# Extract the attribute name from field_name[attr]
|
||||
attr = key[len(prefix) : -1]
|
||||
bracket_data[attr] = form[key]
|
||||
|
||||
if bracket_data:
|
||||
try:
|
||||
return model_class(**bracket_data)
|
||||
except (ValidationError, TypeError):
|
||||
pass
|
||||
|
||||
# Check for JSON string format
|
||||
if field_name in form:
|
||||
value = form[field_name]
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
data = json.loads(value)
|
||||
return model_class(**data)
|
||||
except (json.JSONDecodeError, TypeError, ValidationError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def parse_expires_after(request: Request) -> ExpiresAfter | None:
|
||||
"""
|
||||
Dependency to parse expires_after from multipart form data.
|
||||
Handles both bracket notation (expires_after[anchor], expires_after[seconds])
|
||||
and JSON string format.
|
||||
"""
|
||||
return await parse_pydantic_from_form(request, "expires_after", ExpiresAfter)
|
|
@ -15,16 +15,11 @@ if TYPE_CHECKING:
|
|||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
InterleavedContentItem,
|
||||
ModelStore,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
TextTruncation,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
|
||||
EMBEDDING_MODELS = {}
|
||||
|
||||
|
@ -35,23 +30,6 @@ log = get_logger(name=__name__, category="providers::utils")
|
|||
class SentenceTransformerEmbeddingMixin:
|
||||
model_store: ModelStore
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
embeddings = await asyncio.to_thread(
|
||||
embedding_model.encode,
|
||||
[interleaved_content_as_str(content) for content in contents],
|
||||
show_progress_bar=False,
|
||||
)
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -6,6 +6,8 @@
|
|||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ListOpenAIChatCompletionResponse,
|
||||
OpenAIChatCompletion,
|
||||
|
@ -52,7 +54,7 @@ class InferenceStore:
|
|||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
|
||||
await self.sql_store.create_table(
|
||||
"chat_completions",
|
||||
{
|
||||
|
@ -129,16 +131,44 @@ class InferenceStore:
|
|||
raise ValueError("Inference store is not initialized")
|
||||
|
||||
data = chat_completion.model_dump()
|
||||
record_data = {
|
||||
"id": data["id"],
|
||||
"created": data["created"],
|
||||
"model": data["model"],
|
||||
"choices": data["choices"],
|
||||
"input_messages": [message.model_dump() for message in input_messages],
|
||||
}
|
||||
|
||||
await self.sql_store.insert(
|
||||
table="chat_completions",
|
||||
data={
|
||||
"id": data["id"],
|
||||
"created": data["created"],
|
||||
"model": data["model"],
|
||||
"choices": data["choices"],
|
||||
"input_messages": [message.model_dump() for message in input_messages],
|
||||
},
|
||||
try:
|
||||
await self.sql_store.insert(
|
||||
table="chat_completions",
|
||||
data=record_data,
|
||||
)
|
||||
except IntegrityError as e:
|
||||
# Duplicate chat completion IDs can be generated during tests especially if they are replaying
|
||||
# recorded responses across different tests. No need to warn or error under those circumstances.
|
||||
# In the wild, this is not likely to happen at all (no evidence) so we aren't really hiding any problem.
|
||||
|
||||
# Check if it's a unique constraint violation
|
||||
error_message = str(e.orig) if e.orig else str(e)
|
||||
if self._is_unique_constraint_error(error_message):
|
||||
# Update the existing record instead
|
||||
await self.sql_store.update(table="chat_completions", data=record_data, where={"id": data["id"]})
|
||||
else:
|
||||
# Re-raise if it's not a unique constraint error
|
||||
raise
|
||||
|
||||
def _is_unique_constraint_error(self, error_message: str) -> bool:
|
||||
"""Check if the error is specifically a unique constraint violation."""
|
||||
error_lower = error_message.lower()
|
||||
return any(
|
||||
indicator in error_lower
|
||||
for indicator in [
|
||||
"unique constraint failed", # SQLite
|
||||
"duplicate key", # PostgreSQL
|
||||
"unique violation", # PostgreSQL alternative
|
||||
"duplicate entry", # MySQL
|
||||
]
|
||||
)
|
||||
|
||||
async def list_chat_completions(
|
||||
|
@ -172,7 +202,6 @@ class InferenceStore:
|
|||
order_by=[("created", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
data = [
|
||||
|
@ -199,7 +228,6 @@ class InferenceStore:
|
|||
row = await self.sql_store.fetch_one(
|
||||
table="chat_completions",
|
||||
where={"id": completion_id},
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
if not row:
|
||||
|
|
|
@ -11,14 +11,11 @@ import litellm
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
|
@ -32,7 +29,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -40,7 +36,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
b64_encode_openai_embeddings_response,
|
||||
convert_message_to_openai_dict_new,
|
||||
|
@ -50,9 +46,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
@ -67,10 +60,10 @@ class LiteLLMOpenAIMixin(
|
|||
# when calling litellm.
|
||||
def __init__(
|
||||
self,
|
||||
model_entries,
|
||||
litellm_provider_name: str,
|
||||
api_key_from_config: str | None,
|
||||
provider_data_api_key_field: str,
|
||||
model_entries: list[ProviderModelEntry] | None = None,
|
||||
openai_compat_api_base: str | None = None,
|
||||
download_images: bool = False,
|
||||
json_schema_strict: bool = True,
|
||||
|
@ -86,7 +79,7 @@ class LiteLLMOpenAIMixin(
|
|||
:param download_images: Whether to download images and convert to base64 for message conversion.
|
||||
:param json_schema_strict: Whether to use strict mode for JSON schema validation.
|
||||
"""
|
||||
ModelRegistryHelper.__init__(self, model_entries)
|
||||
ModelRegistryHelper.__init__(self, model_entries=model_entries)
|
||||
|
||||
self.litellm_provider_name = litellm_provider_name
|
||||
self.api_key_from_config = api_key_from_config
|
||||
|
@ -269,24 +262,6 @@ class LiteLLMOpenAIMixin(
|
|||
)
|
||||
return api_key
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
response = litellm.embedding(
|
||||
model=self.get_litellm_model_name(model.provider_resource_id),
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
|
||||
embeddings = [data["embedding"] for data in response["data"]]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -399,6 +374,14 @@ class LiteLLMOpenAIMixin(
|
|||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
# Add usage tracking for streaming when telemetry is active
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
if stream and get_current_span() is not None:
|
||||
if stream_options is None:
|
||||
stream_options = {"include_usage": True}
|
||||
elif "include_usage" not in stream_options:
|
||||
stream_options = {**stream_options, "include_usage": True}
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
|
|
|
@ -11,7 +11,6 @@ from pydantic import BaseModel, Field
|
|||
from llama_stack.apis.common.errors import UnsupportedModelError
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference import (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||
|
@ -21,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")
|
|||
|
||||
|
||||
class RemoteInferenceProviderConfig(BaseModel):
|
||||
allowed_models: list[str] | None = Field(
|
||||
allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default
|
||||
default=None,
|
||||
description="List of models that should be registered with the model registry. If None, all models are allowed.",
|
||||
)
|
||||
|
@ -37,13 +36,6 @@ class ProviderModelEntry(BaseModel):
|
|||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
def get_huggingface_repo(model_descriptor: str) -> str | None:
|
||||
for model in all_registered_models():
|
||||
if model.descriptor() == model_descriptor:
|
||||
return model.huggingface_repo
|
||||
return None
|
||||
|
||||
|
||||
def build_hf_repo_model_entry(
|
||||
provider_model_id: str,
|
||||
model_descriptor: str,
|
||||
|
@ -63,25 +55,20 @@ def build_hf_repo_model_entry(
|
|||
)
|
||||
|
||||
|
||||
def build_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
|
||||
return ProviderModelEntry(
|
||||
provider_model_id=provider_model_id,
|
||||
aliases=[],
|
||||
llama_model=model_descriptor,
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
|
||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||
__provider_id__: str
|
||||
|
||||
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
|
||||
self.model_entries = model_entries
|
||||
def __init__(
|
||||
self,
|
||||
model_entries: list[ProviderModelEntry] | None = None,
|
||||
allowed_models: list[str] | None = None,
|
||||
):
|
||||
self.allowed_models = allowed_models
|
||||
|
||||
self.alias_to_provider_id_map = {}
|
||||
self.provider_id_to_llama_model_map = {}
|
||||
for entry in model_entries:
|
||||
self.model_entries = model_entries or []
|
||||
for entry in self.model_entries:
|
||||
for alias in entry.aliases:
|
||||
self.alias_to_provider_id_map[alias] = entry.provider_model_id
|
||||
|
||||
|
@ -103,7 +90,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
Model(
|
||||
identifier=id,
|
||||
provider_resource_id=entry.provider_model_id,
|
||||
model_type=ModelType.llm,
|
||||
model_type=entry.model_type,
|
||||
metadata=entry.metadata,
|
||||
provider_id=self.__provider_id__,
|
||||
)
|
||||
|
|
|
@ -805,6 +805,10 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
|||
properties[param_name].update(description=param.description)
|
||||
if param.default:
|
||||
properties[param_name].update(default=param.default)
|
||||
if param.items:
|
||||
properties[param_name].update(items=param.items)
|
||||
if param.title:
|
||||
properties[param_name].update(title=param.title)
|
||||
if param.required:
|
||||
required.append(param_name)
|
||||
|
||||
|
|
|
@ -4,11 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
from openai import NOT_GIVEN, AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
|
@ -22,13 +23,16 @@ from llama_stack.apis.inference import (
|
|||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class OpenAIMixin(ABC):
|
||||
class OpenAIMixin(ModelRegistryHelper, ABC):
|
||||
"""
|
||||
Mixin class that provides OpenAI-specific functionality for inference providers.
|
||||
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
||||
|
@ -43,6 +47,28 @@ class OpenAIMixin(ABC):
|
|||
The model_store is set in routing_tables/common.py during provider initialization.
|
||||
"""
|
||||
|
||||
# Allow subclasses to control whether to overwrite the 'id' field in OpenAI responses
|
||||
# is overwritten with a client-side generated id.
|
||||
#
|
||||
# This is useful for providers that do not return a unique id in the response.
|
||||
overwrite_completion_id: bool = False
|
||||
|
||||
# Allow subclasses to control whether to download images and convert to base64
|
||||
# for providers that require base64 encoded images instead of URLs.
|
||||
download_images: bool = False
|
||||
|
||||
# Embedding model metadata for this provider
|
||||
# Can be set by subclasses or instances to provide embedding models
|
||||
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {}
|
||||
|
||||
# Cache of available models keyed by model ID
|
||||
# This is set in list_models() and used in check_model_availability()
|
||||
_model_cache: dict[str, Model] = {}
|
||||
|
||||
# List of allowed models for this provider, if empty all models allowed
|
||||
allowed_models: list[str] = []
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
|
@ -67,6 +93,17 @@ class OpenAIMixin(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
def get_extra_client_params(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get any extra parameters to pass to the AsyncOpenAI client.
|
||||
|
||||
Child classes can override this method to provide additional parameters
|
||||
such as timeout settings, proxies, etc.
|
||||
|
||||
:return: A dictionary of extra parameters
|
||||
"""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncOpenAI:
|
||||
"""
|
||||
|
@ -78,6 +115,7 @@ class OpenAIMixin(ABC):
|
|||
return AsyncOpenAI(
|
||||
api_key=self.get_api_key(),
|
||||
base_url=self.get_base_url(),
|
||||
**self.get_extra_client_params(),
|
||||
)
|
||||
|
||||
async def _get_provider_model_id(self, model: str) -> str:
|
||||
|
@ -98,6 +136,23 @@ class OpenAIMixin(ABC):
|
|||
raise ValueError(f"Model {model} has no provider_resource_id")
|
||||
return model_obj.provider_resource_id
|
||||
|
||||
async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
|
||||
if not self.overwrite_completion_id:
|
||||
return resp
|
||||
|
||||
new_id = f"cltsd-{uuid.uuid4()}"
|
||||
if stream:
|
||||
|
||||
async def _gen():
|
||||
async for chunk in resp:
|
||||
chunk.id = new_id
|
||||
yield chunk
|
||||
|
||||
return _gen()
|
||||
else:
|
||||
resp.id = new_id
|
||||
return resp
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -124,13 +179,18 @@ class OpenAIMixin(ABC):
|
|||
"""
|
||||
Direct OpenAI completion API call.
|
||||
"""
|
||||
if guided_choice is not None:
|
||||
logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
|
||||
if prompt_logprobs is not None:
|
||||
logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
|
||||
# Handle parameters that are not supported by OpenAI API, but may be by the provider
|
||||
# prompt_logprobs is supported by vLLM
|
||||
# guided_choice is supported by vLLM
|
||||
# TODO: test coverage
|
||||
extra_body: dict[str, Any] = {}
|
||||
if prompt_logprobs is not None and prompt_logprobs >= 0:
|
||||
extra_body["prompt_logprobs"] = prompt_logprobs
|
||||
if guided_choice:
|
||||
extra_body["guided_choice"] = guided_choice
|
||||
|
||||
# TODO: fix openai_completion to return type compatible with OpenAI's API response
|
||||
return await self.client.completions.create( # type: ignore[no-any-return]
|
||||
resp = await self.client.completions.create(
|
||||
**await prepare_openai_completion_params(
|
||||
model=await self._get_provider_model_id(model),
|
||||
prompt=prompt,
|
||||
|
@ -150,9 +210,12 @@ class OpenAIMixin(ABC):
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
suffix=suffix,
|
||||
)
|
||||
),
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -182,8 +245,25 @@ class OpenAIMixin(ABC):
|
|||
"""
|
||||
Direct OpenAI chat completion API call.
|
||||
"""
|
||||
# Type ignore because return types are compatible
|
||||
return await self.client.chat.completions.create( # type: ignore[no-any-return]
|
||||
if self.download_images:
|
||||
|
||||
async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
|
||||
if isinstance(m.content, list):
|
||||
for c in m.content:
|
||||
if c.type == "image_url" and c.image_url and c.image_url.url and "http" in c.image_url.url:
|
||||
localize_result = await localize_image_content(c.image_url.url)
|
||||
if localize_result is None:
|
||||
raise ValueError(
|
||||
f"Failed to localize image content from {c.image_url.url[:42]}{'...' if len(c.image_url.url) > 42 else ''}"
|
||||
)
|
||||
content, format = localize_result
|
||||
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
|
||||
# else it's a string and we don't need to modify it
|
||||
return m
|
||||
|
||||
messages = [await _localize_image_url(m) for m in messages]
|
||||
|
||||
resp = await self.client.chat.completions.create(
|
||||
**await prepare_openai_completion_params(
|
||||
model=await self._get_provider_model_id(model),
|
||||
messages=messages,
|
||||
|
@ -211,6 +291,8 @@ class OpenAIMixin(ABC):
|
|||
)
|
||||
)
|
||||
|
||||
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -247,26 +329,53 @@ class OpenAIMixin(ABC):
|
|||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=response.model,
|
||||
model=model,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
"""
|
||||
List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
|
||||
|
||||
Also, caches the models in self._model_cache for use in check_model_availability().
|
||||
|
||||
:return: A list of Model instances representing available models.
|
||||
"""
|
||||
self._model_cache = {}
|
||||
|
||||
async for m in self.client.models.list():
|
||||
if self.allowed_models and m.id not in self.allowed_models:
|
||||
logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
|
||||
continue
|
||||
if metadata := self.embedding_model_metadata.get(m.id):
|
||||
# This is an embedding model - augment with metadata
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata=metadata,
|
||||
)
|
||||
else:
|
||||
# This is an LLM
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
self._model_cache[m.id] = model
|
||||
|
||||
return list(self._model_cache.values())
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available from OpenAI.
|
||||
Check if a specific model is available from the provider's /v1/models.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Direct model lookup - returns model or raises NotFoundError
|
||||
await self.client.models.retrieve(model)
|
||||
return True
|
||||
except openai.NotFoundError:
|
||||
# Model doesn't exist - this is expected for unavailable models
|
||||
pass
|
||||
except Exception as e:
|
||||
# All other errors (auth, rate limit, network, etc.)
|
||||
logger.warning(f"Failed to check model availability for {model}: {e}")
|
||||
if not self._model_cache:
|
||||
await self.list_models()
|
||||
|
||||
return False
|
||||
return model in self._model_cache
|
||||
|
|
|
@ -192,6 +192,14 @@ async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
|
|||
format = "png"
|
||||
|
||||
return content, format
|
||||
elif uri.startswith("data"):
|
||||
# data:image/{format};base64,{data}
|
||||
match = re.match(r"data:image/(\w+);base64,(.+)", uri)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid data URL format, {uri[:40]}...")
|
||||
fmt, image_data = match.groups()
|
||||
content = base64.b64decode(image_data)
|
||||
return content, fmt
|
||||
else:
|
||||
return None
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ class CommonConfig(BaseModel):
|
|||
|
||||
|
||||
class RedisKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.redis.value] = KVStoreType.redis.value
|
||||
type: Literal["redis"] = KVStoreType.redis.value
|
||||
host: str = "localhost"
|
||||
port: int = 6379
|
||||
|
||||
|
@ -50,7 +50,7 @@ class RedisKVStoreConfig(CommonConfig):
|
|||
|
||||
|
||||
class SqliteKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value
|
||||
type: Literal["sqlite"] = KVStoreType.sqlite.value
|
||||
db_path: str = Field(
|
||||
default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(),
|
||||
description="File path for the sqlite database",
|
||||
|
@ -69,7 +69,7 @@ class SqliteKVStoreConfig(CommonConfig):
|
|||
|
||||
|
||||
class PostgresKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
||||
type: Literal["postgres"] = KVStoreType.postgres.value
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
db: str = "llamastack"
|
||||
|
@ -113,11 +113,11 @@ class PostgresKVStoreConfig(CommonConfig):
|
|||
|
||||
|
||||
class MongoDBKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value
|
||||
type: Literal["mongodb"] = KVStoreType.mongodb.value
|
||||
host: str = "localhost"
|
||||
port: int = 27017
|
||||
db: str = "llamastack"
|
||||
user: str = None
|
||||
user: str | None = None
|
||||
password: str | None = None
|
||||
collection_name: str = "llamastack_kvstore"
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
from datetime import datetime
|
||||
|
||||
from pymongo import AsyncMongoClient
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
@ -19,8 +20,13 @@ log = get_logger(name=__name__, category="providers::utils")
|
|||
class MongoDBKVStoreImpl(KVStore):
|
||||
def __init__(self, config: MongoDBKVStoreConfig):
|
||||
self.config = config
|
||||
self.conn = None
|
||||
self.collection = None
|
||||
self.conn: AsyncMongoClient | None = None
|
||||
|
||||
@property
|
||||
def collection(self) -> AsyncCollection:
|
||||
if self.conn is None:
|
||||
raise RuntimeError("MongoDB connection is not initialized")
|
||||
return self.conn[self.config.db][self.config.collection_name]
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
|
@ -32,7 +38,6 @@ class MongoDBKVStoreImpl(KVStore):
|
|||
}
|
||||
conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
|
||||
self.conn = AsyncMongoClient(**conn_creds)
|
||||
self.collection = self.conn[self.config.db][self.config.collection_name]
|
||||
except Exception as e:
|
||||
log.exception("Could not connect to MongoDB database server")
|
||||
raise RuntimeError("Could not connect to MongoDB database server") from e
|
||||
|
|
|
@ -9,9 +9,13 @@ from datetime import datetime
|
|||
|
||||
import aiosqlite
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ..api import KVStore
|
||||
from ..config import SqliteKVStoreConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class SqliteKVStoreImpl(KVStore):
|
||||
def __init__(self, config: SqliteKVStoreConfig):
|
||||
|
@ -50,6 +54,9 @@ class SqliteKVStoreImpl(KVStore):
|
|||
if row is None:
|
||||
return None
|
||||
value, expiration = row
|
||||
if not isinstance(value, str):
|
||||
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
|
||||
return None
|
||||
return value
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
|
|
|
@ -24,11 +24,13 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreChunkingStrategyStatic,
|
||||
VectorStoreContent,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileBatchObject,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileCounts,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileLastError,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFilesListInBatchResponse,
|
||||
VectorStoreFileStatus,
|
||||
VectorStoreListFilesResponse,
|
||||
VectorStoreListResponse,
|
||||
|
@ -107,7 +109,11 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
self.openai_vector_stores.pop(store_id, None)
|
||||
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
self,
|
||||
store_id: str,
|
||||
file_id: str,
|
||||
file_info: dict[str, Any],
|
||||
file_contents: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Save vector store file metadata to persistent storage."""
|
||||
assert self.kvstore
|
||||
|
@ -301,7 +307,10 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
all_stores = all_stores[after_index + 1 :]
|
||||
|
||||
if before:
|
||||
before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
|
||||
before_index = next(
|
||||
(i for i, store in enumerate(all_stores) if store["id"] == before),
|
||||
len(all_stores),
|
||||
)
|
||||
all_stores = all_stores[:before_index]
|
||||
|
||||
# Apply limit
|
||||
|
@ -397,7 +406,9 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
max_num_results: int | None = 10,
|
||||
ranking_options: SearchRankingOptions | None = None,
|
||||
rewrite_query: bool | None = False,
|
||||
search_mode: str | None = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
|
||||
search_mode: (
|
||||
str | None
|
||||
) = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
"""Search for chunks in a vector store."""
|
||||
max_num_results = max_num_results or 10
|
||||
|
@ -685,7 +696,10 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
file_objects = file_objects[after_index + 1 :]
|
||||
|
||||
if before:
|
||||
before_index = next((i for i, file in enumerate(file_objects) if file.id == before), len(file_objects))
|
||||
before_index = next(
|
||||
(i for i, file in enumerate(file_objects) if file.id == before),
|
||||
len(file_objects),
|
||||
)
|
||||
file_objects = file_objects[:before_index]
|
||||
|
||||
# Apply limit
|
||||
|
@ -805,3 +819,42 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
id=file_id,
|
||||
deleted=True,
|
||||
)
|
||||
|
||||
async def openai_create_vector_store_file_batch(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_ids: list[str],
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileBatchObject:
|
||||
"""Create a vector store file batch."""
|
||||
raise NotImplementedError("openai_create_vector_store_file_batch is not implemented yet")
|
||||
|
||||
async def openai_list_files_in_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: str | None = None,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
) -> VectorStoreFilesListInBatchResponse:
|
||||
"""Returns a list of vector store files in a batch."""
|
||||
raise NotImplementedError("openai_list_files_in_vector_store_file_batch is not implemented yet")
|
||||
|
||||
async def openai_retrieve_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreFileBatchObject:
|
||||
"""Retrieve a vector store file batch."""
|
||||
raise NotImplementedError("openai_retrieve_vector_store_file_batch is not implemented yet")
|
||||
|
||||
async def openai_cancel_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreFileBatchObject:
|
||||
"""Cancel a vector store file batch."""
|
||||
raise NotImplementedError("openai_cancel_vector_store_file_batch is not implemented yet")
|
||||
|
|
|
@ -3,6 +3,9 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Order,
|
||||
)
|
||||
|
@ -14,25 +17,51 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectWithInput,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
|
||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ..sqlstore.api import ColumnDefinition, ColumnType
|
||||
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
|
||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="responses_store")
|
||||
|
||||
|
||||
class ResponsesStore:
|
||||
def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
|
||||
if not sql_store_config:
|
||||
sql_store_config = SqliteSqlStoreConfig(
|
||||
def __init__(
|
||||
self,
|
||||
config: ResponsesStoreConfig | SqlStoreConfig,
|
||||
policy: list[AccessRule],
|
||||
):
|
||||
# Handle backward compatibility
|
||||
if not isinstance(config, ResponsesStoreConfig):
|
||||
# Legacy: SqlStoreConfig passed directly as config
|
||||
config = ResponsesStoreConfig(
|
||||
sql_store_config=config,
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.sql_store_config = config.sql_store_config
|
||||
if not self.sql_store_config:
|
||||
self.sql_store_config = SqliteSqlStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
)
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
|
||||
self.sql_store = None
|
||||
self.policy = policy
|
||||
|
||||
# Disable write queue for SQLite to avoid concurrency issues
|
||||
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
|
||||
|
||||
# Async write queue and worker control
|
||||
self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None
|
||||
self._worker_tasks: list[asyncio.Task[Any]] = []
|
||||
self._max_write_queue_size: int = config.max_write_queue_size
|
||||
self._num_writers: int = max(1, config.num_writers)
|
||||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
|
||||
await self.sql_store.create_table(
|
||||
"openai_responses",
|
||||
{
|
||||
|
@ -43,9 +72,68 @@ class ResponsesStore:
|
|||
},
|
||||
)
|
||||
|
||||
if self.enable_write_queue:
|
||||
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
|
||||
for _ in range(self._num_writers):
|
||||
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
|
||||
else:
|
||||
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._worker_tasks:
|
||||
return
|
||||
if self._queue is not None:
|
||||
await self._queue.join()
|
||||
for t in self._worker_tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
for t in self._worker_tasks:
|
||||
try:
|
||||
await t
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._worker_tasks.clear()
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Wait for all queued writes to complete. Useful for testing."""
|
||||
if self.enable_write_queue and self._queue is not None:
|
||||
await self._queue.join()
|
||||
|
||||
async def store_response_object(
|
||||
self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
|
||||
) -> None:
|
||||
if self.enable_write_queue:
|
||||
if self._queue is None:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
try:
|
||||
self._queue.put_nowait((response_object, input))
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
|
||||
await self._queue.put((response_object, input))
|
||||
else:
|
||||
await self._write_response_object(response_object, input)
|
||||
|
||||
async def _worker_loop(self) -> None:
|
||||
assert self._queue is not None
|
||||
while True:
|
||||
try:
|
||||
item = await self._queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
response_object, input = item
|
||||
try:
|
||||
await self._write_response_object(response_object, input)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Error writing response object: {e}")
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _write_response_object(
|
||||
self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
|
||||
) -> None:
|
||||
if self.sql_store is None:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
|
||||
data = response_object.model_dump()
|
||||
data["input"] = [input_item.model_dump() for input_item in input]
|
||||
|
||||
|
@ -74,6 +162,9 @@ class ResponsesStore:
|
|||
:param model: The model to filter by.
|
||||
:param order: The order to sort the responses by.
|
||||
"""
|
||||
if not self.sql_store:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
|
@ -87,7 +178,6 @@ class ResponsesStore:
|
|||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
|
||||
|
@ -102,10 +192,12 @@ class ResponsesStore:
|
|||
"""
|
||||
Get a response object with automatic access control checking.
|
||||
"""
|
||||
if not self.sql_store:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one(
|
||||
"openai_responses",
|
||||
where={"id": response_id},
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
if not row:
|
||||
|
@ -116,7 +208,10 @@ class ResponsesStore:
|
|||
return OpenAIResponseObjectWithInput(**row["response_object"])
|
||||
|
||||
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy)
|
||||
if not self.sql_store:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
|
||||
if not row:
|
||||
raise ValueError(f"Response with id {response_id} not found")
|
||||
await self.sql_store.delete("openai_responses", where={"id": response_id})
|
||||
|
|
|
@ -53,13 +53,15 @@ class AuthorizedSqlStore:
|
|||
access control policies, user attribute capture, and SQL filtering optimization.
|
||||
"""
|
||||
|
||||
def __init__(self, sql_store: SqlStore):
|
||||
def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
|
||||
"""
|
||||
Initialize the authorization layer.
|
||||
|
||||
:param sql_store: Base SqlStore implementation to wrap
|
||||
:param policy: Access control policy to use for authorization
|
||||
"""
|
||||
self.sql_store = sql_store
|
||||
self.policy = policy
|
||||
self._detect_database_type()
|
||||
self._validate_sql_optimized_policy()
|
||||
|
||||
|
@ -117,14 +119,13 @@ class AuthorizedSqlStore:
|
|||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
policy: list[AccessRule],
|
||||
where: Mapping[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
cursor: tuple[str, str] | None = None,
|
||||
) -> PaginatedResponse:
|
||||
"""Fetch all rows with automatic access control filtering."""
|
||||
access_where = self._build_access_control_where_clause(policy)
|
||||
access_where = self._build_access_control_where_clause(self.policy)
|
||||
rows = await self.sql_store.fetch_all(
|
||||
table=table,
|
||||
where=where,
|
||||
|
@ -146,7 +147,7 @@ class AuthorizedSqlStore:
|
|||
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
|
||||
)
|
||||
|
||||
if is_action_allowed(policy, Action.READ, sql_record, current_user):
|
||||
if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
|
||||
filtered_rows.append(row)
|
||||
|
||||
return PaginatedResponse(
|
||||
|
@ -157,14 +158,12 @@ class AuthorizedSqlStore:
|
|||
async def fetch_one(
|
||||
self,
|
||||
table: str,
|
||||
policy: list[AccessRule],
|
||||
where: Mapping[str, Any] | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch one row with automatic access control checking."""
|
||||
results = await self.fetch_all(
|
||||
table=table,
|
||||
policy=policy,
|
||||
where=where,
|
||||
limit=1,
|
||||
order_by=order_by,
|
||||
|
@ -172,6 +171,20 @@ class AuthorizedSqlStore:
|
|||
|
||||
return results.data[0] if results.data else None
|
||||
|
||||
async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None:
|
||||
"""Update rows with automatic access control attribute capture."""
|
||||
enhanced_data = dict(data)
|
||||
|
||||
current_user = get_authenticated_user()
|
||||
if current_user:
|
||||
enhanced_data["owner_principal"] = current_user.principal
|
||||
enhanced_data["access_attributes"] = current_user.attributes
|
||||
else:
|
||||
enhanced_data["owner_principal"] = None
|
||||
enhanced_data["access_attributes"] = None
|
||||
|
||||
await self.sql_store.update(table, enhanced_data, where)
|
||||
|
||||
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
|
||||
"""Delete rows with automatic access control filtering."""
|
||||
await self.sql_store.delete(table, where)
|
||||
|
|
|
@ -8,7 +8,7 @@ import asyncio
|
|||
import contextvars
|
||||
import logging # allow-direct-logging
|
||||
import queue
|
||||
import random
|
||||
import secrets
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
@ -76,16 +76,16 @@ def span_id_to_str(span_id: int) -> str:
|
|||
|
||||
|
||||
def generate_span_id() -> str:
|
||||
span_id = random.getrandbits(64)
|
||||
span_id = secrets.randbits(64)
|
||||
while span_id == INVALID_SPAN_ID:
|
||||
span_id = random.getrandbits(64)
|
||||
span_id = secrets.randbits(64)
|
||||
return span_id_to_str(span_id)
|
||||
|
||||
|
||||
def generate_trace_id() -> str:
|
||||
trace_id = random.getrandbits(128)
|
||||
trace_id = secrets.randbits(128)
|
||||
while trace_id == INVALID_TRACE_ID:
|
||||
trace_id = random.getrandbits(128)
|
||||
trace_id = secrets.randbits(128)
|
||||
return trace_id_to_str(trace_id)
|
||||
|
||||
|
||||
|
|
|
@ -120,6 +120,10 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
|
|||
name=param_name,
|
||||
parameter_type=param_schema.get("type", "string"),
|
||||
description=param_schema.get("description", ""),
|
||||
required="default" not in param_schema,
|
||||
items=param_schema.get("items", None),
|
||||
title=param_schema.get("title", None),
|
||||
default=param_schema.get("default", None),
|
||||
)
|
||||
)
|
||||
tools.append(
|
||||
|
|
|
@ -12,14 +12,12 @@ import uuid
|
|||
def generate_chunk_id(document_id: str, chunk_text: str, chunk_window: str | None = None) -> str:
|
||||
"""
|
||||
Generate a unique chunk ID using a hash of the document ID and chunk text.
|
||||
|
||||
Note: MD5 is used only to calculate an identifier, not for security purposes.
|
||||
Adding usedforsecurity=False for compatibility with FIPS environments.
|
||||
Then use the first 32 characters of the hash to create a UUID.
|
||||
"""
|
||||
hash_input = f"{document_id}:{chunk_text}".encode()
|
||||
if chunk_window:
|
||||
hash_input += f":{chunk_window}".encode()
|
||||
return str(uuid.UUID(hashlib.md5(hash_input, usedforsecurity=False).hexdigest()))
|
||||
return str(uuid.UUID(hashlib.sha256(hash_input).hexdigest()[:32]))
|
||||
|
||||
|
||||
def proper_case(s: str) -> str:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue