Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-07 20:50:52 +00:00)

Commit 32b87bf88a: Merge branch 'main' into remove-batch-inference
748 changed files with 127607 additions and 50032 deletions
@@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
""",
    )

    pip_packages: list[str] = Field(
        default_factory=list,
        description="The pip dependencies needed for this implementation",
    )

    provider_data_validator: str | None = Field(
        default=None,
    )

    is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")

    # used internally by the resolver; this is a hack for now
@@ -145,45 +154,8 @@ class RoutingTable(Protocol):
    async def get_provider_impl(self, routing_key: str) -> Any: ...


# TODO: this can now be inlined into RemoteProviderSpec
@json_schema_type
class AdapterSpec(BaseModel):
    adapter_type: str = Field(
        ...,
        description="Unique identifier for this adapter",
    )
    module: str = Field(
        default_factory=str,
        description="""
Fully-qualified name of the module to import. The module is expected to have:

- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
    )
    pip_packages: list[str] = Field(
        default_factory=list,
        description="The pip dependencies needed for this implementation",
    )
    config_class: str = Field(
        description="Fully-qualified classname of the config for this provider",
    )
    provider_data_validator: str | None = Field(
        default=None,
    )
    description: str | None = Field(
        default=None,
        description="""
A description of the provider. This is used to display in the documentation.
""",
    )


@json_schema_type
class InlineProviderSpec(ProviderSpec):
    pip_packages: list[str] = Field(
        default_factory=list,
        description="The pip dependencies needed for this implementation",
    )
    container_image: str | None = Field(
        default=None,
        description="""
@@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_pack
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
""",
    )
    # module field is inherited from ProviderSpec
    provider_data_validator: str | None = Field(
        default=None,
    )
    description: str | None = Field(
        default=None,
        description="""
@@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel):

@json_schema_type
class RemoteProviderSpec(ProviderSpec):
    adapter: AdapterSpec = Field(
    adapter_type: str = Field(
        ...,
        description="Unique identifier for this adapter",
    )

    description: str | None = Field(
        default=None,
        description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here.
A description of the provider. This is used to display in the documentation.
""",
    )

@@ -234,33 +207,6 @@ API responses, specify the adapter here.
    def container_image(self) -> str | None:
        return None

    # module field is inherited from ProviderSpec

    @property
    def pip_packages(self) -> list[str]:
        return self.adapter.pip_packages

    @property
    def provider_data_validator(self) -> str | None:
        return self.adapter.provider_data_validator


def remote_provider_spec(
    api: Api,
    adapter: AdapterSpec,
    api_dependencies: list[Api] | None = None,
    optional_api_dependencies: list[Api] | None = None,
) -> RemoteProviderSpec:
    return RemoteProviderSpec(
        api=api,
        provider_type=f"remote::{adapter.adapter_type}",
        config_class=adapter.config_class,
        module=adapter.module,
        adapter=adapter,
        api_dependencies=api_dependencies or [],
        optional_api_dependencies=optional_api_dependencies or [],
    )


class HealthStatus(StrEnum):
    OK = "OK"
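For orientation, here is a minimal sketch of how a remote provider spec might be declared after this change, with the adapter fields inlined on RemoteProviderSpec rather than wrapped in AdapterSpec, and without the removed remote_provider_spec() helper. The API choice, module path, config class, and package names are illustrative assumptions, not values taken from this diff.

from llama_stack.providers.datatypes import Api, RemoteProviderSpec

# Hypothetical registry entry: adapter_type/module/config_class now live directly on the spec.
example_spec = RemoteProviderSpec(
    api=Api.inference,  # illustrative API choice
    provider_type="remote::example",
    adapter_type="example",
    module="llama_stack.providers.remote.inference.example",  # assumed module path
    config_class="llama_stack.providers.remote.inference.example.ExampleConfig",  # assumed config class
    pip_packages=["httpx"],
    api_dependencies=[],
    description="Example remote inference adapter.",
)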
@@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):

        # TODO: set expiration time for garbage collection

        if endpoint not in ["/v1/chat/completions"]:
        if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
            raise ValueError(
                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
            )

        if completion_window != "24h":
@@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches):
            )
            valid = False

            for param, expected_type, type_string in [
                ("model", str, "a string"),
                # messages is specific to /v1/chat/completions
                # we could skip validating messages here and let inference fail. however,
                # that would be a very expensive way to find out messages is wrong.
                ("messages", list, "an array"),  # TODO: allow messages to be a string?
            ]:
            if batch.endpoint == "/v1/chat/completions":
                required_params = [
                    ("model", str, "a string"),
                    # messages is specific to /v1/chat/completions
                    # we could skip validating messages here and let inference fail. however,
                    # that would be a very expensive way to find out messages is wrong.
                    ("messages", list, "an array"),  # TODO: allow messages to be a string?
                ]
            else:  # /v1/completions
                required_params = [
                    ("model", str, "a string"),
                    ("prompt", str, "a string"),  # TODO: allow prompt to be a list of strings??
                ]

            for param, expected_type, type_string in required_params:
                if param not in body:
                    errors.append(
                        BatchError(
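To make the branch above concrete, these are hypothetical batch-input lines (one JSON object per line of the uploaded file) that would pass the per-endpoint checks: /v1/chat/completions requires model and messages, /v1/completions requires model and prompt. Field names follow the OpenAI-style batch format this provider mirrors; the model identifiers are placeholders.

# Hypothetical input lines, serialized as JSONL in the uploaded batch file.
chat_line = {
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "example-chat-model",  # placeholder model id
        "messages": [{"role": "user", "content": "Hello"}],
    },
}

completion_line = {
    "custom_id": "request-2",
    "method": "POST",
    "url": "/v1/completions",
    "body": {
        "model": "example-base-model",  # placeholder model id
        "prompt": "Say hello",
    },
}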
@@ -591,20 +599,37 @@ class ReferenceBatchesImpl(Batches):

        try:
            # TODO(SECURITY): review body for security issues
            request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
            chat_response = await self.inference_api.openai_chat_completion(**request.body)
            if request.url == "/v1/chat/completions":
                request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
                chat_response = await self.inference_api.openai_chat_completion(**request.body)

                # this is for mypy, we don't allow streaming so we'll get the right type
                assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
                return {
                    "id": request_id,
                    "custom_id": request.custom_id,
                    "response": {
                        "status_code": 200,
                        "request_id": request_id,  # TODO: should this be different?
                        "body": chat_response.model_dump_json(),
                    },
                }
            # this is for mypy, we don't allow streaming so we'll get the right type
            assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
            return {
                "id": request_id,
                "custom_id": request.custom_id,
                "response": {
                    "status_code": 200,
                    "request_id": request_id,  # TODO: should this be different?
                    "body": chat_response.model_dump_json(),
                },
            }
            else:  # /v1/completions
                completion_response = await self.inference_api.openai_completion(**request.body)

                # this is for mypy, we don't allow streaming so we'll get the right type
                assert hasattr(completion_response, "model_dump_json"), (
                    "Completion response must have model_dump_json method"
                )
                return {
                    "id": request_id,
                    "custom_id": request.custom_id,
                    "response": {
                        "status_code": 200,
                        "request_id": request_id,
                        "body": completion_response.model_dump_json(),
                    },
                }
        except Exception as e:
            logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
            return {
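A rough end-to-end sketch of exercising the new /v1/completions path, assuming an OpenAI-compatible client pointed at a running Llama Stack server; the base URL, API key, and input file name are assumptions rather than values from this diff.

from openai import OpenAI

# Assumed base URL for a local Llama Stack server exposing its OpenAI-compatible API.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

batch_file = client.files.create(file=open("completions.jsonl", "rb"), purpose="batch")
batch = client.batches.create(
    input_file_id=batch_file.id,
    endpoint="/v1/completions",  # newly accepted by this provider per the hunks above
    completion_window="24h",     # the only window the provider accepts
)
print(batch.id, batch.status)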
@@ -75,6 +75,13 @@ class MetaReferenceEvalImpl(
        )
        self.benchmarks[task_def.identifier] = task_def

    async def unregister_benchmark(self, benchmark_id: str) -> None:
        if benchmark_id in self.benchmarks:
            del self.benchmarks[benchmark_id]

        key = f"{EVAL_TASKS_PREFIX}{benchmark_id}"
        await self.kvstore.delete(key)

    async def run_eval(
        self,
        benchmark_id: str,
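A small usage sketch of the new unregister path, under the assumption that eval_impl is an already-initialized MetaReferenceEvalImpl; it drops the in-memory entry and deletes the record persisted under the EVAL_TASKS_PREFIX key, as shown above.

async def remove_benchmark(eval_impl, benchmark_id: str) -> None:
    # Drops the in-memory registry entry and the persisted kvstore record.
    await eval_impl.unregister_benchmark(benchmark_id)
    assert benchmark_id not in eval_impl.benchmarks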
@@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
        storage_path.mkdir(parents=True, exist_ok=True)

        # Initialize SQL store for metadata
        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
        await self.sql_store.create_table(
            "openai_files",
            {
@@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
        if not self.sql_store:
            raise RuntimeError("Files provider not initialized")

        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
        if not row:
            raise ResourceNotFoundError(file_id, "File", "client.files.list()")

@@ -86,11 +86,16 @@ class LocalfsFilesImpl(Files):
        self,
        file: Annotated[UploadFile, File()],
        purpose: Annotated[OpenAIFilePurpose, Form()],
        expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
        expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
    ) -> OpenAIFileObject:
        """Upload a file that can be used across various endpoints."""
        if not self.sql_store:
            raise RuntimeError("Files provider not initialized")

        if expires_after_anchor is not None or expires_after_seconds is not None:
            raise NotImplementedError("File expiration is not supported by this provider")

        file_id = self._generate_file_id()
        file_path = self._get_file_path(file_id)

@@ -145,7 +150,6 @@ class LocalfsFilesImpl(Files):

        paginated_result = await self.sql_store.fetch_all(
            table="openai_files",
            policy=self.policy,
            where=where_conditions if where_conditions else None,
            order_by=[("created_at", order.value)],
            cursor=("id", after) if after else None,
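The pattern behind these hunks, sketched below under assumed import paths: the access policy is now bound once when the AuthorizedSqlStore is constructed, so read calls such as fetch_one and fetch_all no longer pass policy= per call. The config, policy, and file id values are placeholders.

from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl

async def lookup_file(metadata_store_config, policy, file_id: str):
    # Policy is supplied once at construction time...
    store = AuthorizedSqlStore(sqlstore_impl(metadata_store_config), policy)
    # ...so queries no longer repeat policy= on every call.
    return await store.fetch_one("openai_files", where={"id": file_id})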
@@ -22,7 +22,6 @@ from llama_stack.providers.utils.common.data_schema_validator import (
)

from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
@@ -37,7 +36,6 @@ FIXED_FNS = [
    SubsetOfScoringFn,
    RegexParserScoringFn,
    RegexParserMathResponseScoringFn,
    BFCLScoringFn,
    IfEvalScoringFn,
    DocVQAScoringFn,
]
@@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import re
from typing import Any

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

from ..utils.bfcl.ast_parser import decode_ast
from ..utils.bfcl.checker import ast_checker, is_empty_output
from .fn_defs.bfcl import bfcl


def postprocess(x: dict[str, Any], test_category: str) -> dict[str, Any]:
    contain_func_call = False
    error = None
    error_type = None
    checker_result = {}
    try:
        prediction = decode_ast(x["generated_answer"], x["language"]) or ""
        contain_func_call = True
        # if not is_function_calling_format_output(prediction):
        if is_empty_output(prediction):
            contain_func_call = False
            error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability."
            error_type = "ast_decoder:decoder_wrong_output_format"
        else:
            checker_result = ast_checker(
                json.loads(x["function"]),
                prediction,
                json.loads(x["ground_truth"]),
                x["language"],
                test_category=test_category,
                model_name="",
            )
    except Exception as e:
        prediction = ""
        error = f"Invalid syntax. Failed to decode AST. {str(e)}"
        error_type = "ast_decoder:decoder_failed"
    return {
        "prediction": prediction,
        "contain_func_call": contain_func_call,
        "valid": checker_result.get("valid", False),
        "error": error or checker_result.get("error", ""),
        "error_type": error_type or checker_result.get("error_type", ""),
    }


def gen_valid(x: dict[str, Any]) -> dict[str, float]:
    return {"valid": x["valid"]}


def gen_relevance_acc(x: dict[str, Any]) -> dict[str, float]:
    # This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
    # If `test_category` is "irrelevance", the model is expected to output no function call.
    # No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
    # If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call.
    acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"]
    return {"valid": float(acc)}


class BFCLScoringFn(RegisteredBaseScoringFn):
    """
    A scoring_fn for BFCL
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.supported_fn_defs_registry = {
            bfcl.identifier: bfcl,
        }

    async def score_row(
        self,
        input_row: dict[str, Any],
        scoring_fn_identifier: str | None = "bfcl",
        scoring_params: ScoringFnParams | None = None,
    ) -> ScoringResultRow:
        test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
        score_result = postprocess(input_row, test_category)
        if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
            score = gen_relevance_acc(score_result)["valid"]
        else:
            score = gen_valid(score_result)["valid"]
        return {
            "score": float(score),
        }
@@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    BasicScoringFnParams,
    ScoringFn,
)

bfcl = ScoringFn(
    identifier="basic::bfcl",
    description="BFCL complex scoring",
    return_type=NumberType(),
    provider_id="basic",
    provider_resource_id="bfcl",
    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@ -1,296 +0,0 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import ast
|
||||
|
||||
from .tree_sitter import get_parser
|
||||
|
||||
|
||||
def parse_java_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("java")
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
|
||||
if root_node.has_error:
|
||||
raise Exception("Error parsing java the source code.")
|
||||
|
||||
def get_text(node):
|
||||
"""Returns the text represented by the node."""
|
||||
return source_code[node.start_byte : node.end_byte]
|
||||
|
||||
def traverse_node(node, nested=False):
|
||||
if node.type == "string_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding quotes from string literals
|
||||
return get_text(node)[1:-1]
|
||||
elif node.type == "character_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding single quotes from character literals
|
||||
return get_text(node)[1:-1]
|
||||
"""Traverse the node to collect texts for complex structures."""
|
||||
if node.type in [
|
||||
"identifier",
|
||||
"class_literal",
|
||||
"type_identifier",
|
||||
"method_invocation",
|
||||
]:
|
||||
return get_text(node)
|
||||
elif node.type == "array_creation_expression":
|
||||
# Handle array creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
value_node = node.child_by_field_name("value")
|
||||
type_text = traverse_node(type_node, True)
|
||||
value_text = traverse_node(value_node, True)
|
||||
return f"new {type_text}[]{value_text}"
|
||||
elif node.type == "object_creation_expression":
|
||||
# Handle object creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
type_text = traverse_node(type_node, True)
|
||||
if arguments_node:
|
||||
# Process each argument carefully, avoiding unnecessary punctuation
|
||||
argument_texts = []
|
||||
for child in arguments_node.children:
|
||||
if child.type not in [
|
||||
",",
|
||||
"(",
|
||||
")",
|
||||
]: # Exclude commas and parentheses
|
||||
argument_text = traverse_node(child, True)
|
||||
argument_texts.append(argument_text)
|
||||
arguments_text = ", ".join(argument_texts)
|
||||
return f"new {type_text}({arguments_text})"
|
||||
else:
|
||||
return f"new {type_text}()"
|
||||
elif node.type == "set":
|
||||
# Handling sets specifically
|
||||
items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]]
|
||||
return "{" + ", ".join(items) + "}"
|
||||
|
||||
elif node.child_count > 0:
|
||||
return "".join(traverse_node(child, True) for child in node.children)
|
||||
else:
|
||||
return get_text(node)
|
||||
|
||||
def extract_arguments(args_node):
|
||||
arguments = {}
|
||||
for child in args_node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# For named parameters
|
||||
name_node, value_node = child.children[0], child.children[2]
|
||||
name = get_text(name_node)
|
||||
value = traverse_node(value_node)
|
||||
if name in arguments:
|
||||
if not isinstance(arguments[name], list):
|
||||
arguments[name] = [arguments[name]]
|
||||
arguments[name].append(value)
|
||||
else:
|
||||
arguments[name] = value
|
||||
# arguments.append({'name': name, 'value': value})
|
||||
elif child.type in ["identifier", "class_literal", "set"]:
|
||||
# For unnamed parameters and handling sets
|
||||
value = traverse_node(child)
|
||||
if None in arguments:
|
||||
if not isinstance(arguments[None], list):
|
||||
arguments[None] = [arguments[None]]
|
||||
arguments[None].append(value)
|
||||
else:
|
||||
arguments[None] = value
|
||||
return arguments
|
||||
|
||||
def traverse(node):
|
||||
if node.type == "method_invocation":
|
||||
# Extract the function name and its arguments
|
||||
method_name = get_text(node.child_by_field_name("name"))
|
||||
class_name_node = node.child_by_field_name("object")
|
||||
if class_name_node:
|
||||
class_name = get_text(class_name_node)
|
||||
function_name = f"{class_name}.{method_name}"
|
||||
else:
|
||||
function_name = method_name
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
if arguments_node:
|
||||
arguments = extract_arguments(arguments_node)
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
return [{function_name: arguments}]
|
||||
|
||||
else:
|
||||
for child in node.children:
|
||||
result = traverse(child)
|
||||
if result:
|
||||
return result
|
||||
|
||||
result = traverse(root_node)
|
||||
return result if result else {}
|
||||
|
||||
|
||||
def parse_javascript_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("javascript")
|
||||
# Parse the source code
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
if root_node.has_error:
|
||||
raise Exception("Error js parsing the source code.")
|
||||
|
||||
# Function to recursively extract argument details
|
||||
def extract_arguments(node):
|
||||
args = {}
|
||||
for child in node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# Extract left (name) and right (value) parts of the assignment
|
||||
name = child.children[0].text.decode("utf-8")
|
||||
value = child.children[2].text.decode("utf-8")
|
||||
if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
|
||||
value = value[1:-1] # Trim the quotation marks
|
||||
if name in args:
|
||||
if not isinstance(args[name], list):
|
||||
args[name] = [args[name]]
|
||||
args[name].append(value)
|
||||
else:
|
||||
args[name] = value
|
||||
|
||||
elif child.type == "identifier" or child.type == "true":
|
||||
# Handle non-named arguments and boolean values
|
||||
value = child.text.decode("utf-8")
|
||||
if None in args:
|
||||
if not isinstance(args[None], list):
|
||||
args[None] = [args[None]]
|
||||
args[None].append(value)
|
||||
else:
|
||||
args[None] = value
|
||||
return args
|
||||
|
||||
# Find the function call and extract its name and arguments
|
||||
if root_node.type == "program":
|
||||
for child in root_node.children:
|
||||
if child.type == "expression_statement":
|
||||
for sub_child in child.children:
|
||||
if sub_child.type == "call_expression":
|
||||
function_name = sub_child.children[0].text.decode("utf8")
|
||||
arguments_node = sub_child.children[1]
|
||||
parameters = extract_arguments(arguments_node)
|
||||
for key, value in parameters.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
result = [{function_name: parameters}]
|
||||
return result
|
||||
|
||||
|
||||
def ast_parse(input_str, language="Python"):
|
||||
if language == "Python":
|
||||
cleaned_input = input_str.strip("[]'")
|
||||
parsed = ast.parse(cleaned_input, mode="eval")
|
||||
extracted = []
|
||||
if isinstance(parsed.body, ast.Call):
|
||||
extracted.append(resolve_ast_call(parsed.body))
|
||||
else:
|
||||
for elem in parsed.body.elts:
|
||||
extracted.append(resolve_ast_call(elem))
|
||||
return extracted
|
||||
elif language == "Java":
|
||||
return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string
|
||||
elif language == "JavaScript":
|
||||
return parse_javascript_function_call(input_str[1:-1])
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported language: {language}")
|
||||
|
||||
|
||||
def resolve_ast_call(elem):
|
||||
# Handle nested attributes for deeply nested module paths
|
||||
func_parts = []
|
||||
func_part = elem.func
|
||||
while isinstance(func_part, ast.Attribute):
|
||||
func_parts.append(func_part.attr)
|
||||
func_part = func_part.value
|
||||
if isinstance(func_part, ast.Name):
|
||||
func_parts.append(func_part.id)
|
||||
func_name = ".".join(reversed(func_parts))
|
||||
args_dict = {}
|
||||
# Parse when args are simply passed as an unnamed dictionary arg
|
||||
for arg in elem.args:
|
||||
if isinstance(arg, ast.Dict):
|
||||
for key, value in zip(arg.keys, arg.values):
|
||||
if isinstance(key, ast.Constant):
|
||||
arg_name = key.value
|
||||
output = resolve_ast_by_type(value)
|
||||
args_dict[arg_name] = output
|
||||
for arg in elem.keywords:
|
||||
output = resolve_ast_by_type(arg.value)
|
||||
args_dict[arg.arg] = output
|
||||
return {func_name: args_dict}
|
||||
|
||||
|
||||
def resolve_ast_by_type(value):
|
||||
if isinstance(value, ast.Constant):
|
||||
if value.value is Ellipsis:
|
||||
output = "..."
|
||||
else:
|
||||
output = value.value
|
||||
elif isinstance(value, ast.UnaryOp):
|
||||
output = -value.operand.value
|
||||
elif isinstance(value, ast.List):
|
||||
output = [resolve_ast_by_type(v) for v in value.elts]
|
||||
elif isinstance(value, ast.Dict):
|
||||
output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)}
|
||||
elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values
|
||||
output = value.value
|
||||
elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments
|
||||
output = eval(ast.unparse(value))
|
||||
elif isinstance(value, ast.Name):
|
||||
output = value.id
|
||||
elif isinstance(value, ast.Call):
|
||||
if len(value.keywords) == 0:
|
||||
output = ast.unparse(value)
|
||||
else:
|
||||
output = resolve_ast_call(value)
|
||||
elif isinstance(value, ast.Tuple):
|
||||
output = tuple(resolve_ast_by_type(v) for v in value.elts)
|
||||
elif isinstance(value, ast.Lambda):
|
||||
output = eval(ast.unparse(value.body[0].value))
|
||||
elif isinstance(value, ast.Ellipsis):
|
||||
output = "..."
|
||||
elif isinstance(value, ast.Subscript):
|
||||
try:
|
||||
output = ast.unparse(value.body[0].value)
|
||||
except:
|
||||
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
|
||||
else:
|
||||
raise Exception(f"Unsupported AST type: {type(value)}")
|
||||
return output
|
||||
|
||||
|
||||
def decode_ast(result, language="Python"):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decoded_output = ast_parse(func, language)
|
||||
return decoded_output
|
||||
|
||||
|
||||
def decode_execute(result):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decode_output = ast_parse(func)
|
||||
execution_list = []
|
||||
for function_call in decode_output:
|
||||
for key, value in function_call.items():
|
||||
execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})")
|
||||
return execution_list
|
|
@ -1,989 +0,0 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
# Comment out for now until we actually use the rest checker in evals
|
||||
# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function.
|
||||
|
||||
|
||||
class NoAPIKeyError(Exception):
|
||||
def __init__(self):
|
||||
self.message = "❗️Please fill in the API keys in the function_credential_config.json file. If you do not provide the API keys, the executable test category results will be inaccurate."
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2
|
||||
|
||||
|
||||
JAVA_TYPE_CONVERSION = {
|
||||
"byte": int,
|
||||
"short": int,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"double": float,
|
||||
"long": int,
|
||||
"boolean": bool,
|
||||
"char": str,
|
||||
"Array": list,
|
||||
"ArrayList": list,
|
||||
"Set": set,
|
||||
"HashMap": dict,
|
||||
"Hashtable": dict,
|
||||
"Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list
|
||||
"Stack": list,
|
||||
"String": str,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
JS_TYPE_CONVERSION = {
|
||||
"String": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"Bigint": int,
|
||||
"Boolean": bool,
|
||||
"dict": dict,
|
||||
"array": list,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# We switch to conditional import for the following two imports to avoid unnecessary installations.
|
||||
# User doesn't need to setup the tree-sitter packages if they are not running the test for that language.
|
||||
# from js_type_converter import js_type_converter
|
||||
# from java_type_converter import java_type_converter
|
||||
|
||||
PYTHON_TYPE_MAPPING = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"tuple": list,
|
||||
"dict": dict,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# This is the list of types that we need to recursively check its values
|
||||
PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"]
|
||||
|
||||
|
||||
NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"]
|
||||
|
||||
|
||||
#### Helper functions for AST ####
|
||||
def find_description(func_descriptions, name):
|
||||
if type(func_descriptions) == list:
|
||||
for func_description in func_descriptions:
|
||||
if func_description["name"] == name:
|
||||
return func_description
|
||||
return None
|
||||
else:
|
||||
# it is a dict, there is only one function
|
||||
return func_descriptions
|
||||
|
||||
|
||||
def get_possible_answer_type(possible_answer: list):
|
||||
for answer in possible_answer:
|
||||
if answer != "": # Optional parameter
|
||||
return type(answer)
|
||||
return None
|
||||
|
||||
|
||||
def type_checker(
|
||||
param: str,
|
||||
value,
|
||||
possible_answer: list,
|
||||
expected_type_description: str,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
):
|
||||
# NOTE: This type checker only supports nested type checking for one level deep.
|
||||
# We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex.
|
||||
|
||||
result: Any = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"is_variable": False,
|
||||
"error_type": "type_error:simple",
|
||||
}
|
||||
|
||||
is_variable = False
|
||||
# check for the case where a variable is used instead of a actual value.
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if possible_answer_type != expected_type_converted:
|
||||
is_variable = True
|
||||
|
||||
# value is the same type as in function description
|
||||
if type(value) == expected_type_converted:
|
||||
# We don't need to do recursive check for simple types
|
||||
if nested_type_converted == None:
|
||||
result["is_variable"] = is_variable
|
||||
return result
|
||||
else:
|
||||
for possible_answer_item in possible_answer:
|
||||
flag = True # Each parameter should match to at least one possible answer type.
|
||||
# Here, we assume that each item should be the same type. We could also relax it.
|
||||
if type(possible_answer_item) == list:
|
||||
for value_item in value:
|
||||
checker_result = type_checker(
|
||||
param,
|
||||
value_item,
|
||||
possible_answer_item,
|
||||
str(nested_type_converted),
|
||||
nested_type_converted,
|
||||
None,
|
||||
)
|
||||
if not checker_result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": [], "is_variable": is_variable}
|
||||
|
||||
result["valid"] = False
|
||||
result["error"] = [
|
||||
f"Nested type checking failed for parameter {repr(param)}. Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}."
|
||||
]
|
||||
result["error_type"] = "type_error:nested"
|
||||
|
||||
# value is not as expected, check for the case where a variable is used instead of a actual value
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if type(value) == possible_answer_type:
|
||||
result["is_variable"] = True
|
||||
return result
|
||||
|
||||
result["valid"] = False
|
||||
result["error"].append(
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:simple"
|
||||
return result
|
||||
|
||||
|
||||
def standardize_string(input_string: str):
|
||||
# This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase
|
||||
# It will also convert all the single quotes to double quotes
|
||||
# This is used to compare the model output with the possible answers
|
||||
# We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024
|
||||
regex_string = r"[ \,\.\/\-\_\*\^]"
|
||||
return re.sub(regex_string, "", input_string).lower().replace("'", '"')
|
||||
|
||||
|
||||
def string_checker(param: str, model_output: str, possible_answer: list):
|
||||
standardize_possible_answer = []
|
||||
standardize_model_output = standardize_string(model_output)
|
||||
for i in range(len(possible_answer)):
|
||||
if type(possible_answer[i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[i]))
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive."
|
||||
],
|
||||
"error_type": "value_error:string",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def list_checker(param: str, model_output: list, possible_answer: list):
|
||||
# Convert the tuple to a list
|
||||
|
||||
standardize_model_output = list(model_output)
|
||||
|
||||
# If the element in the list is a string, we need to standardize it
|
||||
for i in range(len(standardize_model_output)):
|
||||
if type(standardize_model_output[i]) == str:
|
||||
standardize_model_output[i] = standardize_string(model_output[i])
|
||||
|
||||
standardize_possible_answer: Any = []
|
||||
# We also need to standardize the possible answers
|
||||
for i in range(len(possible_answer)):
|
||||
standardize_possible_answer.append([])
|
||||
for j in range(len(possible_answer[i])):
|
||||
if type(possible_answer[i][j]) == str:
|
||||
standardize_possible_answer[i].append(standardize_string(possible_answer[i][j]))
|
||||
else:
|
||||
standardize_possible_answer[i].append(possible_answer[i][j])
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}."
|
||||
],
|
||||
"error_type": "value_error:list/tuple",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def dict_checker(param: str, model_output: dict, possible_answers: list):
|
||||
# This function works for simple dictionaries, but not dictionaries with nested dictionaries.
|
||||
# The current dataset only contains simple dictionaries, so this is sufficient.
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
for i in range(len(possible_answers)):
|
||||
if possible_answers[i] == "":
|
||||
continue
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
|
||||
flag = True
|
||||
|
||||
possible_answer = possible_answers[i]
|
||||
# possible_anwer is a single dictionary
|
||||
|
||||
for key, value in model_output.items():
|
||||
if key not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
standardize_value = value
|
||||
# If the value is a string, we need to standardize it
|
||||
if type(value) == str:
|
||||
standardize_value = standardize_string(value)
|
||||
|
||||
# We also need to standardize the possible answers if they are string
|
||||
standardize_possible_answer = []
|
||||
for i in range(len(possible_answer[key])):
|
||||
if type(possible_answer[key][i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[key][i]))
|
||||
else:
|
||||
standardize_possible_answer.append(possible_answer[key][i])
|
||||
|
||||
if standardize_value not in standardize_possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}."
|
||||
)
|
||||
result["error_type"] = "value_error:dict_value"
|
||||
flag = False
|
||||
break
|
||||
|
||||
for key, value in possible_answer.items():
|
||||
if key not in model_output and "" not in value:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def list_dict_checker(param: str, model_output: list, possible_answers: list):
|
||||
# This function takes in a list of dictionaries and checks if each dictionary is valid
|
||||
# The order of the dictionaries in the list must match the order of the possible answers
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"}
|
||||
|
||||
for answer_index in range(len(possible_answers)):
|
||||
flag = True # True means so far, all dictionaries are valid
|
||||
|
||||
# Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers
|
||||
if len(model_output) != len(possible_answers[answer_index]):
|
||||
result["valid"] = False
|
||||
result["error"] = ["Wrong number of dictionaries in the list."]
|
||||
result["error_type"] = "value_error:list_dict_count"
|
||||
flag = False
|
||||
continue
|
||||
|
||||
for dict_index in range(len(model_output)):
|
||||
result = dict_checker(
|
||||
param,
|
||||
model_output[dict_index],
|
||||
[possible_answers[answer_index][dict_index]],
|
||||
)
|
||||
if not result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def simple_function_checker(
|
||||
func_description: dict,
|
||||
model_output: dict,
|
||||
possible_answer: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
possible_answer = list(possible_answer.values())[0]
|
||||
# Extract function name and parameters details
|
||||
func_name = func_description["name"]
|
||||
param_details = func_description["parameters"]["properties"]
|
||||
required_params = func_description["parameters"]["required"]
|
||||
|
||||
# Initialize a result dictionary
|
||||
result = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"error_type": "simple_function_checker:unclear",
|
||||
}
|
||||
|
||||
# Check if function name matches
|
||||
if func_name not in model_output:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Function name {repr(func_name)} not found in model output."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:wrong_func_name"
|
||||
return result
|
||||
|
||||
model_params = model_output[func_name]
|
||||
|
||||
# Check for required parameters in model output
|
||||
for param in required_params:
|
||||
if param not in model_params:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing required parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:missing_required"
|
||||
return result
|
||||
|
||||
# Validate types and values for each parameter in model output
|
||||
for param, value in model_params.items():
|
||||
if param not in param_details or param not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:unexpected_param"
|
||||
return result
|
||||
|
||||
full_param_details = param_details[param]
|
||||
expected_type_description = full_param_details["type"] # This is a string
|
||||
is_variable = False
|
||||
nested_type_converted = None
|
||||
|
||||
if language == "Java":
|
||||
from evals.utils.bfcl.java_type_converter import java_type_converter
|
||||
|
||||
expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JAVA_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:java"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
|
||||
value = java_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = java_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "JavaScript":
|
||||
from evals.utils.bfcl.js_type_converter import js_type_converter
|
||||
|
||||
expected_type_converted = JS_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JS_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:js"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JS_TYPE_CONVERSION[nested_type]
|
||||
value = js_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = js_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "Python":
|
||||
expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description]
|
||||
if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = PYTHON_TYPE_MAPPING[nested_type]
|
||||
|
||||
# We convert all tuple value to list when the expected type is tuple.
|
||||
# The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load().
|
||||
# This does introduce some false positive (eg, when the model provides a list value instead of tuple). We hope to find a better solution in the future.
|
||||
if expected_type_description == "tuple" and type(value) == tuple:
|
||||
value = list(value)
|
||||
|
||||
# Allow python auto conversion from int to float
|
||||
if language == "Python" and expected_type_description == "float" and type(value) == int:
|
||||
value = float(value)
|
||||
|
||||
# Type checking
|
||||
# In fact, we only check for Python here.
|
||||
# Type check for other languages are handled by the type converter, and so their value (after conversion) is always correct.
|
||||
type_check_result = type_checker(
|
||||
param,
|
||||
value,
|
||||
possible_answer[param],
|
||||
expected_type_description,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
)
|
||||
is_variable = type_check_result["is_variable"]
|
||||
if not type_check_result["valid"]:
|
||||
return type_check_result
|
||||
|
||||
# It doesn't make sense to special handle dictionaries and list of dictionaries if the value is a variable.
|
||||
# We can just treat the variable as a string and use the normal flow.
|
||||
if not is_variable:
|
||||
# Special handle for dictionaries
|
||||
if expected_type_converted == dict:
|
||||
result = dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for list of dictionaries
|
||||
elif expected_type_converted == list and nested_type_converted == dict:
|
||||
result = list_dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for strings
|
||||
elif expected_type_converted == str:
|
||||
# We don't check for case sensitivity for string, as long as it's not a variable
|
||||
result = string_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
elif expected_type_converted == list:
|
||||
result = list_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Check if the value is within the possible answers
|
||||
if value not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}."
|
||||
)
|
||||
result["error_type"] = "value_error:others"
|
||||
return result
|
||||
|
||||
# Check for optional parameters not provided but allowed
|
||||
for param in possible_answer:
|
||||
if param not in model_params and "" not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Optional parameter {repr(param)} not provided and not marked as optional."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:missing_optional"
|
||||
return result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parallel_function_checker_enforce_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_enforce_order:wrong_count",
|
||||
}
|
||||
|
||||
func_name_list = list(possible_answers.keys())
|
||||
possible_answers_list = []
|
||||
|
||||
for key, value in possible_answers.items():
|
||||
possible_answers_list.append({key: value})
|
||||
|
||||
for i in range(len(possible_answers_list)):
|
||||
func_description = find_description(func_descriptions, func_name_list[i])
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[i],
|
||||
possible_answers_list[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
if not result["valid"]:
|
||||
return result
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def parallel_function_checker_no_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_no_order:wrong_count",
|
||||
}
|
||||
|
||||
matched_indices = []
|
||||
|
||||
# We go throught the possible answers one by one, and eliminate the model output that matches the possible answer
|
||||
# It must be this way because we need ground truth to fetch the correct function description
|
||||
for i in range(len(possible_answers)):
|
||||
# possible_answers[i] is a dictionary with only one key
|
||||
func_name_expected = list(possible_answers[i].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
|
||||
all_errors = []
|
||||
|
||||
for index in range(len(model_output)):
|
||||
if index in matched_indices:
|
||||
continue
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[index],
|
||||
possible_answers[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
if result["valid"]:
|
||||
matched_indices.append(index)
|
||||
break
|
||||
else:
|
||||
all_errors.append(
|
||||
{
|
||||
f"Model Result Index {index}": {
|
||||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_output_item": model_output[index],
|
||||
"possible_answer_item": possible_answers[i],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [i for i in range(len(model_output)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
)
|
||||
return {
|
||||
"valid": False,
|
||||
"error": all_errors,
|
||||
"error_type": "parallel_function_checker_no_order:cannot_find_match",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def multiple_function_checker(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "multiple_function_checker:wrong_count",
|
||||
}
|
||||
|
||||
# possible_answers is a list of only one dictionary with only one key
|
||||
func_name_expected = list(possible_answers[0].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
return simple_function_checker(
|
||||
func_description,
|
||||
model_output[0],
|
||||
possible_answers[0],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
|
||||
def patten_matcher(exec_output, expected_result, function_call, is_sanity_check):
|
||||
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
if type(exec_output) != type(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == dict:
|
||||
# We loose the requirement for the sanity check as the expected result used in the sanity check might not be the most up-to-date one.
|
||||
# This happens when the key is a timestamp or a random number.
|
||||
if is_sanity_check:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
else:
|
||||
return result
|
||||
|
||||
for key, value in expected_result.items():
|
||||
if key not in exec_output:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_key_not_found",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
for key, value in exec_output.items():
|
||||
if key not in expected_result:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_extra_key",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == list:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:list_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
#### Helper functions for Exec ####
def executable_checker_simple(
    function_call: str,
    expected_result,
    expected_result_type: str,
    is_sanity_check=False,
):
    result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}

    exec_dict: Any = {}

    try:
        exec(
            "from executable_python_function import *" + "\nresult=" + function_call,
            exec_dict,
        )
        exec_output = exec_dict["result"]
    except NoAPIKeyError as e:
        raise e
    except Exception as e:
        result["valid"] = False
        result["error"].append(  # type: ignore[attr-defined]
            f"Error in execution: {repr(function_call)}. Error: {str(e)}"
        )
        result["error_type"] = "executable_checker:execution_error"
        return result

    # We need to special handle the case where the execution result is a tuple and convert it to a list
    # Because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json
    if isinstance(exec_output, tuple):
        exec_output = list(exec_output)

    if expected_result_type == "exact_match":
        if exec_output != expected_result:
            result["valid"] = False
            result["error"].append(  # type: ignore[attr-defined]
                f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}."
            )
            result["error_type"] = "executable_checker:wrong_result"
            result["model_executed_output"] = exec_output
            return result

    elif expected_result_type == "real_time_match":
        # Allow for 5% difference
        if (type(expected_result) == float or type(expected_result) == int) and (
            type(exec_output) == float or type(exec_output) == int
        ):
            if not (
                expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
                <= exec_output
                <= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
            ):
                result["valid"] = False
                result["error"].append(  # type: ignore[attr-defined]
                    f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed."
                )
                result["error_type"] = "executable_checker:wrong_result_real_time"
                result["model_executed_output"] = exec_output
                return result
        else:
            result["valid"] = False
            result["error"].append(  # type: ignore[attr-defined]
                f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria."
            )
            result["error_type"] = "executable_checker:wrong_result_real_time"
            result["model_executed_output"] = exec_output
            return result

    else:
        # structural match
        pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check)
        if not pattern_match_result["valid"]:
            return pattern_match_result

    return result


def executable_checker_parallel_no_order(
    decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
):
    if len(decoded_result) != len(expected_exec_result):
        return {
            "valid": False,
            "error": [
                f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}."
            ],
            "error_type": "value_error:exec_result_count",
        }

    matched_indices = []
    for i in range(len(expected_exec_result)):
        all_errors = []
        for index in range(len(decoded_result)):
            if index in matched_indices:
                continue

            result = executable_checker_simple(
                decoded_result[index],
                expected_exec_result[i],
                expected_exec_result_type[i],
                False,
            )

            if result["valid"]:
                matched_indices.append(index)
                break
            else:
                all_errors.append(
                    {
                        f"Model Result Index {index}": {
                            "sub_error": result["error"],
                            "sub_error_type": result["error_type"],
                            "model_executed_output": (
                                result["model_executed_output"] if "model_executed_output" in result else None
                            ),
                        }
                    }
                )

        if not result["valid"]:
            considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices]
            all_errors.insert(
                0,
                f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.",  # type: ignore[arg-type]
            )
            return {
                "valid": False,
                "error": all_errors,
                "error_type": "executable_checker:cannot_find_match",
            }

    return {"valid": True, "error": [], "error_type": "executable_checker:unclear"}


#### Main function ####
def executable_checker_rest(func_call, idx):
    # Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used.
    EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl"  # Ground truth file for v5 for rest execution
    with open(EVAL_GROUND_TRUTH_PATH, "r") as f:
        EVAL_GROUND_TRUTH = f.readlines()
    if "https://geocode.maps.co" in func_call:
        time.sleep(2)
    if "requests_get" in func_call:
        func_call = func_call.replace("requests_get", "requests.get")
    try:
        response = eval(func_call)
    except Exception as e:
        return {
            "valid": False,
            "error": [f"Execution failed. {str(e)}"],
            "error_type": "executable_checker_rest:execution_error",
        }

    try:
        if response.status_code == 200:
            eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx])
            try:
                if isinstance(eval_GT_json, dict):
                    if isinstance(response.json(), dict):
                        if set(eval_GT_json.keys()) == set(response.json().keys()):
                            return {"valid": True, "error": [], "error_type": ""}
                        return {
                            "valid": False,
                            "error": ["Key inconsistency"],
                            "error_type": "executable_checker_rest:wrong_key",
                        }
                    return {
                        "valid": False,
                        "error": [f"Expected dictionary, but got {type(response.json())}"],
                        "error_type": "executable_checker_rest:wrong_type",
                    }

                elif isinstance(eval_GT_json, list):
                    if isinstance(response.json(), list):
                        if len(eval_GT_json) != len(response.json()):
                            return {
                                "valid": False,
                                "error": [f"Response list length inconsistency."],
                                "error_type": "value_error:exec_result_rest_count",
                            }

                        else:
                            for i in range(len(eval_GT_json)):
                                if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()):
                                    return {
                                        "valid": False,
                                        "error": [f"Key inconsistency"],
                                        "error_type": "executable_checker_rest:wrong_key",
                                    }

                            return {"valid": True, "error": []}
                    else:
                        return {
                            "valid": False,
                            "error": [f"Expected list, but got {type(response.json())}"],
                            "error_type": "executable_checker_rest:wrong_type",
                        }
                return {
                    "valid": False,
                    "error": [f"Expected dict or list, but got {type(response.json())}"],
                    "error_type": "executable_checker_rest:wrong_type",
                }
            except Exception as e:
                return {
                    "valid": False,
                    "error": [
                        f"Error in execution and type checking. Status code: {response.status_code}. Error: {str(e)}"
                    ],
                    "error_type": "executable_checker_rest:response_format_error",
                }
        else:
            return {
                "valid": False,
                "error": [f"Execution result status code is not 200, got {response.status_code}"],
                "error_type": "executable_checker_rest:wrong_status_code",
            }
    except Exception as e:
        return {
            "valid": False,
            "error": [f"Cannot get status code of the response. Error: {str(e)}"],
            "error_type": "executable_checker_rest:cannot_get_status_code",
        }


def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name):
    if "parallel" in test_category:
        return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name)

    elif "multiple" in test_category:
        return multiple_function_checker(func_description, model_output, possible_answer, language, model_name)

    else:
        if len(model_output) != 1:
            return {
                "valid": False,
                "error": ["Wrong number of functions."],
                "error_type": "simple_function_checker:wrong_count",
            }

        return simple_function_checker(
            func_description[0],
            model_output[0],
            possible_answer[0],
            language,
            model_name,
        )


def exec_checker(decoded_result: list, func_description: dict, test_category: str):
    if "multiple" in test_category or "parallel" in test_category:
        return executable_checker_parallel_no_order(
            decoded_result,
            func_description["execution_result"],
            func_description["execution_result_type"],
        )

    else:
        if len(decoded_result) != 1:
            return {
                "valid": False,
                "error": ["Wrong number of functions."],
                "error_type": "simple_exec_checker:wrong_count",
            }
        return executable_checker_simple(
            decoded_result[0],
            func_description["execution_result"][0],
            func_description["execution_result_type"][0],
            False,
        )


def is_empty_output(decoded_output):
    # This function is a patch to the ast decoder for relevance detection
    # Sometimes the ast decoder will parse successfully, but the input doesn't really have a function call
    # [], [{}], and anything that is not in function calling format is considered empty (and thus should be marked as correct)
    if not is_function_calling_format_output(decoded_output):
        return True
    if len(decoded_output) == 0:
        return True
    if len(decoded_output) == 1 and len(decoded_output[0]) == 0:
        return True


def is_function_calling_format_output(decoded_output):
    # Ensure the output is a list of dictionaries
    if type(decoded_output) == list:
        for item in decoded_output:
            if type(item) != dict:
                return False
        return True
    return False
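For orientation, a minimal sketch of how the checkers above are typically driven. The function name `calculate_triangle_area` and its expected result are made-up placeholders, and the call assumes the checker module (plus the `executable_python_function` module its exec() imports from) is available; it is not a verbatim example from this diff.

```python
# Hypothetical usage of exec_checker / executable_checker_simple defined above.
# "calculate_triangle_area" is a placeholder; it must exist in executable_python_function
# for the exec() inside executable_checker_simple to succeed.
func_description = {
    "execution_result": [25.0],
    "execution_result_type": ["exact_match"],
}
decoded_result = ["calculate_triangle_area(base=10, height=5)"]

outcome = exec_checker(decoded_result, func_description, test_category="simple")
# Every checker returns the same shape:
#   {"valid": bool, "error": [str, ...], "error_type": "executable_checker:..."}
print(outcome["valid"], outcome["error_type"])
```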
@@ -1,40 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Tree-sitter changes its API with unfortunate frequency. Modules that need it should
import it from here so that we can centrally manage things as necessary.
"""

# These currently work with tree-sitter 0.23.0
# NOTE: Don't import tree-sitter or any of the language modules in the main module
# because not all environments have them. Import lazily inside functions where needed.

import importlib
import typing

if typing.TYPE_CHECKING:
    import tree_sitter


def get_language(language: str) -> "tree_sitter.Language":
    import tree_sitter

    language_module_name = f"tree_sitter_{language}"
    try:
        language_module = importlib.import_module(language_module_name)
    except ModuleNotFoundError as exc:
        raise ValueError(
            f"Language {language} is not found. Please install the tree-sitter-{language} package."
        ) from exc
    return tree_sitter.Language(language_module.language())


def get_parser(language: str, **kwargs) -> "tree_sitter.Parser":
    import tree_sitter

    lang = get_language(language)
    return tree_sitter.Parser(lang, **kwargs)
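As a reference for anyone still depending on the removed helpers, a small usage sketch; it assumes the `tree-sitter` and `tree-sitter-python` packages are installed, matching the 0.23.x API noted in the module comment above.

```python
# Hypothetical usage of the removed helpers (requires tree-sitter and tree-sitter-python).
parser = get_parser("python")
tree = parser.parse(b"def add(a, b):\n    return a + b\n")
print(tree.root_node.type)  # expected: "module"

# get_language("rust") would similarly require the tree-sitter-rust package;
# otherwise it raises the ValueError defined above.
```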
@@ -63,6 +63,9 @@ class LlmAsJudgeScoringImpl(
    async def register_scoring_function(self, function_def: ScoringFn) -> None:
        self.llm_as_judge_fn.register_scoring_fn_def(function_def)

    async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
        self.llm_as_judge_fn.unregister_scoring_fn_def(scoring_fn_id)

    async def score_batch(
        self,
        dataset_id: str,
@@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
    from .memory import MemoryToolRuntimeImpl

    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
    await impl.initialize()
    return impl
@@ -8,7 +8,7 @@
from jinja2 import Template

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
    DefaultRAGQueryGeneratorConfig,
    LLMRAGQueryGeneratorConfig,
@@ -61,16 +61,16 @@ async def llm_rag_query_generator(
    messages = [interleaved_content_as_str(content)]

    template = Template(config.template)
    content = template.render({"messages": messages})
    rendered_content: str = template.render({"messages": messages})

    model = config.model
    message = UserMessage(content=content)
    response = await inference_api.chat_completion(
        model_id=model,
    message = OpenAIUserMessageParam(content=rendered_content)
    response = await inference_api.openai_chat_completion(
        model=model,
        messages=[message],
        stream=False,
    )

    query = response.completion_message.content
    query = response.choices[0].message.content

    return query
@@ -5,10 +5,15 @@
# the root directory of this source tree.

import asyncio
import base64
import io
import mimetypes
import secrets
import string
from typing import Any

import httpx
from fastapi import UploadFile
from pydantic import TypeAdapter

from llama_stack.apis.common.content_types import (
@@ -17,6 +22,7 @@ from llama_stack.apis.common.content_types import (
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
    ListToolDefsResponse,
@@ -30,14 +36,16 @@ from llama_stack.apis.tools import (
    ToolParameter,
    ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
    QueryChunksResponse,
    VectorIO,
    VectorStoreChunkingStrategyStatic,
    VectorStoreChunkingStrategyStaticConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
    content_from_doc,
    make_overlapped_chunks,
)
from llama_stack.providers.utils.memory.vector_store import parse_data_url

from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query
@@ -49,16 +57,59 @@ def make_random_string(length: int = 8):
    return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))


async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
    """Get raw binary data and mime type from a RAGDocument for file upload."""
    if isinstance(doc.content, URL):
        if doc.content.uri.startswith("data:"):
            parts = parse_data_url(doc.content.uri)
            mime_type = parts["mimetype"]
            data = parts["data"]

            if parts["is_base64"]:
                file_data = base64.b64decode(data)
            else:
                file_data = data.encode("utf-8")

            return file_data, mime_type
        else:
            async with httpx.AsyncClient() as client:
                r = await client.get(doc.content.uri)
                r.raise_for_status()
                mime_type = r.headers.get("content-type", "application/octet-stream")
                return r.content, mime_type
    else:
        if isinstance(doc.content, str):
            content_str = doc.content
        else:
            content_str = interleaved_content_as_str(doc.content)

        if content_str.startswith("data:"):
            parts = parse_data_url(content_str)
            mime_type = parts["mimetype"]
            data = parts["data"]

            if parts["is_base64"]:
                file_data = base64.b64decode(data)
            else:
                file_data = data.encode("utf-8")

            return file_data, mime_type
        else:
            return content_str.encode("utf-8"), "text/plain"


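The data: URL branch above relies on `parse_data_url`; below is a small illustrative sketch of what that path sees, using only the `mimetype` / `data` / `is_base64` keys referenced in this file. The example URL itself is invented.

```python
# Hypothetical illustration of the data: URL branch in raw_data_from_doc.
import base64

from llama_stack.providers.utils.memory.vector_store import parse_data_url

data_url = "data:text/plain;base64," + base64.b64encode(b"hello world").decode()
parts = parse_data_url(data_url)
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode("utf-8")
print(parts["mimetype"], file_data)  # expected: text/plain b'hello world'
```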
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
    def __init__(
        self,
        config: RagToolRuntimeConfig,
        vector_io_api: VectorIO,
        inference_api: Inference,
        files_api: Files,
    ):
        self.config = config
        self.vector_io_api = vector_io_api
        self.inference_api = inference_api
        self.files_api = files_api

    async def initialize(self):
        pass
@@ -78,27 +129,56 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
        vector_db_id: str,
        chunk_size_in_tokens: int = 512,
    ) -> None:
        chunks = []
        for doc in documents:
            content = await content_from_doc(doc)
            # TODO: we should add enrichment here as URLs won't be added to the metadata by default
            chunks.extend(
                make_overlapped_chunks(
                    doc.document_id,
                    content,
                    chunk_size_in_tokens,
                    chunk_size_in_tokens // 4,
                    doc.metadata,
                )
            )

        if not chunks:
        if not documents:
            return

        await self.vector_io_api.insert_chunks(
            chunks=chunks,
            vector_db_id=vector_db_id,
        )
        for doc in documents:
            try:
                try:
                    file_data, mime_type = await raw_data_from_doc(doc)
                except Exception as e:
                    log.error(f"Failed to extract content from document {doc.document_id}: {e}")
                    continue

                file_extension = mimetypes.guess_extension(mime_type) or ".txt"
                filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")

                file_obj = io.BytesIO(file_data)
                file_obj.name = filename

                upload_file = UploadFile(file=file_obj, filename=filename)

                try:
                    created_file = await self.files_api.openai_upload_file(
                        file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
                    )
                except Exception as e:
                    log.error(f"Failed to upload file for document {doc.document_id}: {e}")
                    continue

                chunking_strategy = VectorStoreChunkingStrategyStatic(
                    static=VectorStoreChunkingStrategyStaticConfig(
                        max_chunk_size_tokens=chunk_size_in_tokens,
                        chunk_overlap_tokens=chunk_size_in_tokens // 4,
                    )
                )

                try:
                    await self.vector_io_api.openai_attach_file_to_vector_store(
                        vector_store_id=vector_db_id,
                        file_id=created_file.id,
                        attributes=doc.metadata,
                        chunking_strategy=chunking_strategy,
                    )
                except Exception as e:
                    log.error(
                        f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
                    )
                    continue

            except Exception as e:
                log.error(f"Unexpected error processing document {doc.document_id}: {e}")
                continue

    async def query(
        self,
@@ -131,8 +211,18 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
            for vector_db_id in vector_db_ids
        ]
        results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
        chunks = [c for r in results for c in r.chunks]
        scores = [s for r in results for s in r.scores]

        chunks = []
        scores = []

        for vector_db_id, result in zip(vector_db_ids, results, strict=False):
            for chunk, score in zip(result.chunks, result.scores, strict=False):
                if not hasattr(chunk, "metadata") or chunk.metadata is None:
                    chunk.metadata = {}
                chunk.metadata["vector_db_id"] = vector_db_id

                chunks.append(chunk)
                scores.append(score)

        if not chunks:
            return RAGQueryResult(content=None)
@@ -167,6 +257,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
            metadata_keys_to_exclude_from_context = [
                "token_count",
                "metadata_token_count",
                "vector_db_id",
            ]
            metadata_for_context = {}
            for k in chunk_metadata_keys_to_include_from_context:
@@ -191,6 +282,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
                "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
                "chunks": [c.content for c in chunks[: len(picked)]],
                "scores": scores[: len(picked)],
                "vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
            },
        )

@@ -226,7 +318,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
        if query_config:
            query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
        else:
            # handle someone passing an empty dict
            query_config = RAGQueryConfig()

        query = kwargs["query"]
@@ -237,6 +328,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
        )

        return ToolInvocationResult(
            content=result.content,
            content=result.content or [],
            metadata=result.metadata,
        )
@@ -30,11 +30,11 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
    RERANKER_TYPE_RRF,
    RERANKER_TYPE_WEIGHTED,
    ChunkForDeletion,
    EmbeddingIndex,
    VectorDBWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator

logger = get_logger(name=__name__, category="vector_io")

@@ -66,59 +66,6 @@ def _create_sqlite_connection(db_path):
    return connection


def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
    """Normalize scores to [0,1] range using min-max normalization."""
    if not scores:
        return {}
    min_score = min(scores.values())
    max_score = max(scores.values())
    score_range = max_score - min_score
    if score_range > 0:
        return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
    return dict.fromkeys(scores, 1.0)


def _weighted_rerank(
    vector_scores: dict[str, float],
    keyword_scores: dict[str, float],
    alpha: float = 0.5,
) -> dict[str, float]:
    """ReRanker that uses weighted average of scores."""
    all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
    normalized_vector_scores = _normalize_scores(vector_scores)
    normalized_keyword_scores = _normalize_scores(keyword_scores)

    return {
        doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
        + ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
        for doc_id in all_ids
    }


def _rrf_rerank(
    vector_scores: dict[str, float],
    keyword_scores: dict[str, float],
    impact_factor: float = 60.0,
) -> dict[str, float]:
    """ReRanker that uses Reciprocal Rank Fusion."""
    # Convert scores to ranks
    vector_ranks = {
        doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
    }
    keyword_ranks = {
        doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
    }

    all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
    rrf_scores = {}
    for doc_id in all_ids:
        vector_rank = vector_ranks.get(doc_id, float("inf"))
        keyword_rank = keyword_ranks.get(doc_id, float("inf"))
        # RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
        rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
    return rrf_scores


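A small worked example of the RRF formula used by `_rrf_rerank` above (score = 1/(k + rank) with k = 60); the document IDs and scores are invented for illustration.

```python
# Invented scores; only the relative ranking matters for RRF.
vector_scores = {"doc_a": 0.92, "doc_b": 0.85}
keyword_scores = {"doc_b": 7.1, "doc_c": 5.4}

# Ranks: vector -> doc_a=1, doc_b=2; keyword -> doc_b=1, doc_c=2 (missing docs rank as inf).
# doc_a: 1/(60+1)            ~= 0.0164
# doc_b: 1/(60+2) + 1/(60+1) ~= 0.0325  (highest: it appears in both result lists)
# doc_c: 1/(60+2)            ~= 0.0161
print(_rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0))
```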
def _make_sql_identifier(name: str) -> str:
    return re.sub(r"[^a-zA-Z0-9_]", "_", name)

@@ -398,14 +345,10 @@ class SQLiteVecIndex(EmbeddingIndex):
            for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
        }

        # Combine scores using the specified reranker
        if reranker_type == RERANKER_TYPE_WEIGHTED:
            alpha = reranker_params.get("alpha", 0.5)
            combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha)
        else:
            # Default to RRF for None, RRF, or any unknown types
            impact_factor = reranker_params.get("impact_factor", 60.0)
            combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)
        # Combine scores using the reranking utility
        combined_scores = WeightedInMemoryAggregator.combine_search_results(
            vector_scores, keyword_scores, reranker_type, reranker_params
        )

        # Sort by combined score and get top k results
        sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
@@ -13,7 +13,7 @@ def available_providers() -> list[ProviderSpec]:
        InlineProviderSpec(
            api=Api.batches,
            provider_type="inline::reference",
            pip_packages=["openai"],
            pip_packages=[],
            module="llama_stack.providers.inline.batches.reference",
            config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
            api_dependencies=[
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -25,28 +24,26 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api_dependencies=[],
|
||||
description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.datasetio,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="huggingface",
|
||||
pip_packages=[
|
||||
"datasets",
|
||||
],
|
||||
module="llama_stack.providers.remote.datasetio.huggingface",
|
||||
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
|
||||
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
|
||||
),
|
||||
adapter_type="huggingface",
|
||||
provider_type="remote::huggingface",
|
||||
pip_packages=[
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
module="llama_stack.providers.remote.datasetio.huggingface",
|
||||
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
|
||||
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.datasetio,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"datasets",
|
||||
],
|
||||
module="llama_stack.providers.remote.datasetio.nvidia",
|
||||
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
|
||||
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
module="llama_stack.providers.remote.datasetio.nvidia",
|
||||
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
|
||||
pip_packages=[
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
|
@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
|
|||
],
|
||||
description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.eval,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"requests",
|
||||
],
|
||||
module="llama_stack.providers.remote.eval.nvidia",
|
||||
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
|
||||
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"requests",
|
||||
],
|
||||
provider_type="remote::nvidia",
|
||||
module="llama_stack.providers.remote.eval.nvidia",
|
||||
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
|
||||
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
|
|
|
@ -4,13 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
|
||||
|
||||
|
||||
|
@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
|
||||
description="Local filesystem-based file storage provider for managing files and documents locally.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.files,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="s3",
|
||||
pip_packages=["boto3"] + sql_store_pip_packages,
|
||||
module="llama_stack.providers.remote.files.s3",
|
||||
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
||||
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
||||
),
|
||||
provider_type="remote::s3",
|
||||
adapter_type="s3",
|
||||
pip_packages=["boto3"] + sql_store_pip_packages,
|
||||
module="llama_stack.providers.remote.files.s3",
|
||||
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
||||
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
META_REFERENCE_DEPS = [
|
||||
|
@ -49,180 +48,167 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
|
||||
description="Sentence Transformers inference provider for text embeddings and similarity search.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="cerebras",
|
||||
pip_packages=[
|
||||
"cerebras_cloud_sdk",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.cerebras",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
|
||||
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
|
||||
),
|
||||
adapter_type="cerebras",
|
||||
provider_type="remote::cerebras",
|
||||
pip_packages=[
|
||||
"cerebras_cloud_sdk",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.cerebras",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
|
||||
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="ollama",
|
||||
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
|
||||
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
|
||||
module="llama_stack.providers.remote.inference.ollama",
|
||||
description="Ollama inference provider for running local models through the Ollama runtime.",
|
||||
),
|
||||
adapter_type="ollama",
|
||||
provider_type="remote::ollama",
|
||||
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
|
||||
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
|
||||
module="llama_stack.providers.remote.inference.ollama",
|
||||
description="Ollama inference provider for running local models through the Ollama runtime.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="vllm",
|
||||
pip_packages=["openai"],
|
||||
module="llama_stack.providers.remote.inference.vllm",
|
||||
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
|
||||
description="Remote vLLM inference provider for connecting to vLLM servers.",
|
||||
),
|
||||
adapter_type="vllm",
|
||||
provider_type="remote::vllm",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.vllm",
|
||||
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
|
||||
description="Remote vLLM inference provider for connecting to vLLM servers.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="tgi",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
|
||||
description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
|
||||
),
|
||||
adapter_type="tgi",
|
||||
provider_type="remote::tgi",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
|
||||
description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="hf::serverless",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
|
||||
description="HuggingFace Inference API serverless provider for on-demand model inference.",
|
||||
),
|
||||
adapter_type="hf::serverless",
|
||||
provider_type="remote::hf::serverless",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
|
||||
description="HuggingFace Inference API serverless provider for on-demand model inference.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="hf::endpoint",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
|
||||
description="HuggingFace Inference Endpoints provider for dedicated model serving.",
|
||||
),
|
||||
provider_type="remote::hf::endpoint",
|
||||
adapter_type="hf::endpoint",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
|
||||
description="HuggingFace Inference Endpoints provider for dedicated model serving.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="fireworks",
|
||||
pip_packages=[
|
||||
"fireworks-ai",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.fireworks",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
|
||||
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
|
||||
),
|
||||
adapter_type="fireworks",
|
||||
provider_type="remote::fireworks",
|
||||
pip_packages=[
|
||||
"fireworks-ai<=0.17.16",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.fireworks",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
|
||||
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="together",
|
||||
pip_packages=[
|
||||
"together",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.together",
|
||||
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||
description="Together AI inference provider for open-source models and collaborative AI development.",
|
||||
),
|
||||
adapter_type="together",
|
||||
provider_type="remote::together",
|
||||
pip_packages=[
|
||||
"together",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.together",
|
||||
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||
description="Together AI inference provider for open-source models and collaborative AI development.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.inference.bedrock",
|
||||
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
|
||||
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
|
||||
),
|
||||
adapter_type="bedrock",
|
||||
provider_type="remote::bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.inference.bedrock",
|
||||
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
|
||||
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="databricks",
|
||||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.databricks",
|
||||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
|
||||
),
|
||||
adapter_type="databricks",
|
||||
provider_type="remote::databricks",
|
||||
pip_packages=["databricks-sdk"],
|
||||
module="llama_stack.providers.remote.inference.databricks",
|
||||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.nvidia",
|
||||
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
||||
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.nvidia",
|
||||
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
||||
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="runpod",
|
||||
pip_packages=["openai"],
|
||||
module="llama_stack.providers.remote.inference.runpod",
|
||||
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
|
||||
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
|
||||
),
|
||||
adapter_type="runpod",
|
||||
provider_type="remote::runpod",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.runpod",
|
||||
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
|
||||
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="openai",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.openai",
|
||||
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
||||
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
|
||||
),
|
||||
adapter_type="openai",
|
||||
provider_type="remote::openai",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.openai",
|
||||
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
||||
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="anthropic",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.anthropic",
|
||||
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
|
||||
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
|
||||
),
|
||||
adapter_type="anthropic",
|
||||
provider_type="remote::anthropic",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.anthropic",
|
||||
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
|
||||
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="gemini",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.gemini",
|
||||
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
|
||||
),
|
||||
adapter_type="gemini",
|
||||
provider_type="remote::gemini",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.gemini",
|
||||
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="vertexai",
|
||||
pip_packages=["litellm", "google-cloud-aiplatform"],
|
||||
module="llama_stack.providers.remote.inference.vertexai",
|
||||
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
|
||||
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
|
||||
adapter_type="vertexai",
|
||||
provider_type="remote::vertexai",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
"google-cloud-aiplatform",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.vertexai",
|
||||
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
|
||||
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
|
||||
|
||||
• Enterprise-grade security: Uses Google Cloud's security controls and IAM
|
||||
• Better integration: Seamless integration with other Google Cloud services
|
||||
|
@ -242,61 +228,73 @@ Available Models:
|
|||
- vertex_ai/gemini-2.0-flash
|
||||
- vertex_ai/gemini-2.5-flash
|
||||
- vertex_ai/gemini-2.5-pro""",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="groq",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.groq",
|
||||
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
|
||||
),
|
||||
adapter_type="groq",
|
||||
provider_type="remote::groq",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.groq",
|
||||
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="llama-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.llama_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
|
||||
),
|
||||
adapter_type="llama-openai-compat",
|
||||
provider_type="remote::llama-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.llama_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.sambanova",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
|
||||
),
|
||||
adapter_type="sambanova",
|
||||
provider_type="remote::sambanova",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.sambanova",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="passthrough",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.passthrough",
|
||||
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
||||
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
|
||||
),
|
||||
adapter_type="passthrough",
|
||||
provider_type="remote::passthrough",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.passthrough",
|
||||
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
||||
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="watsonx",
|
||||
pip_packages=["ibm_watson_machine_learning"],
|
||||
module="llama_stack.providers.remote.inference.watsonx",
|
||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
|
||||
),
|
||||
adapter_type="watsonx",
|
||||
provider_type="remote::watsonx",
|
||||
pip_packages=["ibm_watsonx_ai"],
|
||||
module="llama_stack.providers.remote.inference.watsonx",
|
||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="remote::azure",
|
||||
adapter_type="azure",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.azure",
|
||||
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
|
||||
description="""
|
||||
Azure OpenAI inference provider for accessing GPT models and other Azure services.
|
||||
Provider documentation
|
||||
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
|
||||
""",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
from typing import cast
|
||||
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
|
||||
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
|
||||
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
|
||||
|
@ -48,7 +48,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.post_training,
|
||||
provider_type="inline::huggingface-gpu",
|
||||
pip_packages=["trl", "transformers", "peft", "datasets", "torch"],
|
||||
pip_packages=["trl", "transformers", "peft", "datasets>=4.0.0", "torch"],
|
||||
module="llama_stack.providers.inline.post_training.huggingface",
|
||||
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
|
||||
api_dependencies=[
|
||||
|
@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]:
|
|||
],
|
||||
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.post_training,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=["requests", "aiohttp"],
|
||||
module="llama_stack.providers.remote.post_training.nvidia",
|
||||
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
|
||||
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
pip_packages=["requests", "aiohttp"],
|
||||
module="llama_stack.providers.remote.post_training.nvidia",
|
||||
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
|
||||
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
|
||||
description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.safety.bedrock",
|
||||
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
|
||||
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
|
||||
),
|
||||
adapter_type="bedrock",
|
||||
provider_type="remote::bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.safety.bedrock",
|
||||
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
|
||||
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=["requests"],
|
||||
module="llama_stack.providers.remote.safety.nvidia",
|
||||
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
||||
description="NVIDIA's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
pip_packages=["requests"],
|
||||
module="llama_stack.providers.remote.safety.nvidia",
|
||||
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
||||
description="NVIDIA's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova",
|
||||
pip_packages=["litellm", "requests"],
|
||||
module="llama_stack.providers.remote.safety.sambanova",
|
||||
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
adapter_type="sambanova",
|
||||
provider_type="remote::sambanova",
|
||||
pip_packages=["litellm", "requests"],
|
||||
module="llama_stack.providers.remote.safety.sambanova",
|
||||
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -38,7 +38,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.scoring,
|
||||
provider_type="inline::braintrust",
|
||||
pip_packages=["autoevals", "openai"],
|
||||
pip_packages=["autoevals"],
|
||||
module="llama_stack.providers.inline.scoring.braintrust",
|
||||
config_class="llama_stack.providers.inline.scoring.braintrust.BraintrustScoringConfig",
|
||||
api_dependencies=[
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -32,62 +31,57 @@ def available_providers() -> list[ProviderSpec]:
|
|||
],
|
||||
module="llama_stack.providers.inline.tool_runtime.rag",
|
||||
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
|
||||
api_dependencies=[Api.vector_io, Api.inference],
|
||||
api_dependencies=[Api.vector_io, Api.inference, Api.files],
|
||||
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="brave-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.brave_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
||||
description="Brave Search tool for web search capabilities with privacy-focused results.",
|
||||
),
|
||||
adapter_type="brave-search",
|
||||
provider_type="remote::brave-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.brave_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
||||
description="Brave Search tool for web search capabilities with privacy-focused results.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bing-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.bing_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
|
||||
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
|
||||
),
|
||||
adapter_type="bing-search",
|
||||
provider_type="remote::bing-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.bing_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
|
||||
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="tavily-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.tavily_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
|
||||
description="Tavily Search tool for AI-optimized web search with structured results.",
|
||||
),
|
||||
adapter_type="tavily-search",
|
||||
provider_type="remote::tavily-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.tavily_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
|
||||
description="Tavily Search tool for AI-optimized web search with structured results.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="wolfram-alpha",
|
||||
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
|
||||
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
|
||||
),
|
||||
adapter_type="wolfram-alpha",
|
||||
provider_type="remote::wolfram-alpha",
|
||||
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
|
||||
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="model-context-protocol",
|
||||
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
|
||||
pip_packages=["mcp>=1.8.1"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
|
||||
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
|
||||
),
|
||||
adapter_type="model-context-protocol",
|
||||
provider_type="remote::model-context-protocol",
|
||||
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
|
||||
pip_packages=["mcp>=1.8.1"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
|
||||
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
|
||||
),
|
||||
]
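Across these registries the nested `AdapterSpec` is gone: `adapter_type`, `provider_type`, `pip_packages`, `module`, `config_class`, and `provider_data_validator` now sit directly on `RemoteProviderSpec`. Below is a minimal sketch of a registration under the flattened shape; the `acme-search` adapter, its module path, and its config class are hypothetical, used only to illustrate the field layout.

```python
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

# Hypothetical remote tool-runtime provider declared with the flattened spec.
# Field names mirror the entries above; the concrete values are made up.
example_spec = RemoteProviderSpec(
    api=Api.tool_runtime,
    adapter_type="acme-search",               # unique adapter identifier (hypothetical)
    provider_type="remote::acme-search",      # convention seen above: "remote::<adapter_type>"
    module="acme_stack.tool_runtime.search",  # must expose get_adapter_impl(config, deps)
    config_class="acme_stack.tool_runtime.search.config.AcmeSearchToolConfig",
    pip_packages=["requests"],
    provider_data_validator="acme_stack.tool_runtime.search.AcmeSearchProviderDataValidator",
    description="Hypothetical search tool runtime, shown only to illustrate the flattened RemoteProviderSpec.",
)
```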
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -300,14 +299,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
|
|||
Please refer to the sqlite-vec provider documentation.
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="chromadb",
|
||||
pip_packages=["chromadb-client"],
|
||||
module="llama_stack.providers.remote.vector_io.chroma",
|
||||
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
|
||||
description="""
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="chromadb",
|
||||
provider_type="remote::chromadb",
|
||||
pip_packages=["chromadb-client"],
|
||||
module="llama_stack.providers.remote.vector_io.chroma",
|
||||
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[Chroma](https://www.trychroma.com/) is an inline and remote vector
|
||||
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
|
||||
That means you're not limited to storing vectors in memory or in a separate service.
|
||||
|
@ -340,9 +341,6 @@ pip install chromadb
|
|||
## Documentation
|
||||
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
@ -387,14 +385,16 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="pgvector",
|
||||
pip_packages=["psycopg2-binary"],
|
||||
module="llama_stack.providers.remote.vector_io.pgvector",
|
||||
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
|
||||
description="""
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="pgvector",
|
||||
provider_type="remote::pgvector",
|
||||
pip_packages=["psycopg2-binary"],
|
||||
module="llama_stack.providers.remote.vector_io.pgvector",
|
||||
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly within a PostgreSQL database.
|
||||
That means you'll get fast and efficient vector retrieval.
|
||||
|
@ -404,6 +404,60 @@ That means you'll get fast and efficient vector retrieval.
|
|||
- Easy to use
|
||||
- Fully integrated with Llama Stack
|
||||
|
||||
Three search implementations are available for PGVectorIndex:

1. Vector Search:
  - How it works:
    - Uses PostgreSQL's vector extension (pgvector) to perform similarity search
    - Compares query embeddings against stored embeddings using cosine distance or other distance metrics
    - E.g. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance

  - Characteristics:
    - Semantic understanding - finds documents similar in meaning even if they don't share keywords
    - Works with high-dimensional vector embeddings (typically 768, 1024, or higher dimensions)
    - Best for: Finding conceptually related content, handling synonyms, cross-language search

2. Keyword Search
  - How it works:
    - Uses PostgreSQL's full-text search capabilities with tsvector and ts_rank
    - Converts text to searchable tokens using to_tsvector('english', text). Default language is English.
    - E.g. SQL query: SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score

  - Characteristics:
    - Lexical matching - finds exact keyword matches and variations
    - Uses GIN (Generalized Inverted Index) for fast text search performance
    - Scoring: Uses PostgreSQL's ts_rank function for relevance scoring
    - Best for: Exact term matching, proper names, technical terms, Boolean-style queries

3. Hybrid Search
  - How it works:
    - Combines both vector and keyword search results
    - Runs both searches independently, then merges results using configurable reranking (a sketch of RRF-style merging follows the schema below)

  - Two reranking strategies available:
    - Reciprocal Rank Fusion (RRF) - (default: 60.0)
    - Weighted Average - (default: 0.5)

  - Characteristics:
    - Best of both worlds: semantic understanding + exact matching
    - Documents appearing in both searches get boosted scores
    - Configurable balance between semantic and lexical matching
    - Best for: General-purpose search where you want both precision and recall

4. Database Schema
The PGVector implementation stores data optimized for all three search types:
CREATE TABLE vector_store_xxx (
    id TEXT PRIMARY KEY,
    document JSONB,              -- Original document
    embedding vector(dimension), -- For vector search
    content_text TEXT,           -- Raw text content
    tokenized_content TSVECTOR   -- For keyword search
);

-- Indexes for performance
CREATE INDEX content_gin_idx ON table USING GIN(tokenized_content); -- Keyword search
-- Vector index created automatically by pgvector
|
||||
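For the hybrid strategy described above, the vector and keyword result lists are merged by a configurable reranker. Here is a self-contained sketch of Reciprocal Rank Fusion with the default k of 60, shown only to illustrate the merging idea, not the provider's actual code:

```python
def rrf_merge(vector_ranked: list[str], keyword_ranked: list[str], k: float = 60.0) -> list[str]:
    """Merge two ranked lists of chunk ids with Reciprocal Rank Fusion.

    Each occurrence contributes 1 / (k + rank); ids appearing in both lists
    accumulate both contributions, which is why overlapping hits get boosted.
    """
    scores: dict[str, float] = {}
    for ranked in (vector_ranked, keyword_ranked):
        for rank, chunk_id in enumerate(ranked, start=1):
            scores[chunk_id] = scores.get(chunk_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)


# Example: "doc-2" appears in both lists, so it outranks the single-list hits.
print(rrf_merge(["doc-1", "doc-2", "doc-3"], ["doc-2", "doc-4"]))
```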
|
||||
## Usage
|
||||
|
||||
To use PGVector in your Llama Stack project, follow these steps:
|
||||
|
@ -412,6 +466,25 @@ To use PGVector in your Llama Stack project, follow these steps:
|
|||
2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector).
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## Example: setting up your environment for PGVector
|
||||
|
||||
1. Export env vars:
|
||||
```bash
|
||||
export ENABLE_PGVECTOR=true
|
||||
export PGVECTOR_HOST=localhost
|
||||
export PGVECTOR_PORT=5432
|
||||
export PGVECTOR_DB=llamastack
|
||||
export PGVECTOR_USER=llamastack
|
||||
export PGVECTOR_PASSWORD=llamastack
|
||||
```
|
||||
|
||||
2. Create DB:
|
||||
```bash
|
||||
psql -h localhost -U postgres -c "CREATE ROLE llamastack LOGIN PASSWORD 'llamastack';"
|
||||
psql -h localhost -U postgres -c "CREATE DATABASE llamastack OWNER llamastack;"
|
||||
psql -h localhost -U llamastack -d llamastack -c "CREATE EXTENSION IF NOT EXISTS vector;"
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
You can install PGVector using docker:
|
||||
|
@ -422,19 +495,18 @@ docker pull pgvector/pgvector:pg17
|
|||
## Documentation
|
||||
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
|
||||
""",
|
||||
),
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="weaviate",
|
||||
provider_type="remote::weaviate",
|
||||
pip_packages=["weaviate-client"],
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="weaviate",
|
||||
pip_packages=["weaviate-client"],
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
description="""
|
||||
description="""
|
||||
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
|
||||
It allows you to store and query vectors directly within a Weaviate database.
|
||||
That means you're not limited to storing vectors in memory or in a separate service.
|
||||
|
@ -449,6 +521,7 @@ Weaviate supports:
|
|||
- Metadata filtering
|
||||
- Multi-modal retrieval
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
To use Weaviate in your Llama Stack project, follow these steps:
|
||||
|
@ -464,9 +537,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
|
|||
## Documentation
|
||||
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
@ -520,28 +590,29 @@ docker pull qdrant/qdrant
|
|||
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="qdrant",
|
||||
pip_packages=["qdrant-client"],
|
||||
module="llama_stack.providers.remote.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
description="""
|
||||
Please refer to the inline provider documentation.
|
||||
""",
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="qdrant",
|
||||
provider_type="remote::qdrant",
|
||||
pip_packages=["qdrant-client"],
|
||||
module="llama_stack.providers.remote.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
Please refer to the inline provider documentation.
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="milvus",
|
||||
pip_packages=["pymilvus>=2.4.10"],
|
||||
module="llama_stack.providers.remote.vector_io.milvus",
|
||||
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
|
||||
description="""
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="milvus",
|
||||
provider_type="remote::milvus",
|
||||
pip_packages=["pymilvus>=2.4.10"],
|
||||
module="llama_stack.providers.remote.vector_io.milvus",
|
||||
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly within a Milvus database.
|
||||
That means you're not limited to storing vectors in memory or in a separate service.
|
||||
|
@ -562,7 +633,13 @@ To use Milvus in your Llama Stack project, follow these steps:
|
|||
|
||||
## Installation
|
||||
|
||||
You can install Milvus using pymilvus:
|
||||
If you want to use inline Milvus, you can install:
|
||||
|
||||
```bash
|
||||
pip install pymilvus[milvus-lite]
|
||||
```
|
||||
|
||||
If you want to use remote Milvus, you can install:
|
||||
|
||||
```bash
|
||||
pip install pymilvus
|
||||
|
@ -732,14 +809,11 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m
|
|||
|
||||
For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::milvus",
|
||||
pip_packages=["pymilvus>=2.4.10"],
|
||||
pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
|
||||
module="llama_stack.providers.inline.vector_io.milvus",
|
||||
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
|
@ -14,7 +14,6 @@ from llama_stack.apis.datasets import Datasets
|
|||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.scoring import Scoring, ScoringResult
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from .....apis.common.job_types import Job, JobStatus
|
||||
|
@ -45,24 +44,29 @@ class NVIDIAEvalImpl(
|
|||
self.inference_api = inference_api
|
||||
self.agents_api = agents_api
|
||||
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self)
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def _evaluator_get(self, path):
|
||||
async def _evaluator_get(self, path: str):
|
||||
"""Helper for making GET requests to the evaluator service."""
|
||||
response = requests.get(url=f"{self.config.evaluator_url}{path}")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _evaluator_post(self, path, data):
|
||||
async def _evaluator_post(self, path: str, data: dict[str, Any]):
|
||||
"""Helper for making POST requests to the evaluator service."""
|
||||
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _evaluator_delete(self, path: str) -> None:
|
||||
"""Helper for making DELETE requests to the evaluator service."""
|
||||
response = requests.delete(url=f"{self.config.evaluator_url}{path}")
|
||||
response.raise_for_status()
|
||||
|
||||
async def register_benchmark(self, task_def: Benchmark) -> None:
|
||||
"""Register a benchmark as an evaluation configuration."""
|
||||
await self._evaluator_post(
|
||||
|
@ -75,6 +79,10 @@ class NVIDIAEvalImpl(
|
|||
},
|
||||
)
|
||||
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
|
||||
"""Unregister a benchmark evaluation configuration from NeMo Evaluator."""
|
||||
await self._evaluator_delete(f"/v1/evaluation/configs/{DEFAULT_NAMESPACE}/{benchmark_id}")
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
|
|
@ -6,15 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.core.datatypes import AccessRule, Api
|
||||
|
||||
from .config import S3FilesImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
|
||||
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None):
|
||||
from .files import S3FilesImpl
|
||||
|
||||
# TODO: authorization policies and user separation
|
||||
impl = S3FilesImpl(config)
|
||||
impl = S3FilesImpl(config, policy or [])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,9 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
|
||||
|
@ -15,14 +15,17 @@ from fastapi import File, Form, Response, UploadFile
|
|||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import (
|
||||
ExpiresAfter,
|
||||
Files,
|
||||
ListOpenAIFileResponse,
|
||||
OpenAIFileDeleteResponse,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
||||
from .config import S3FilesImplConfig
|
||||
|
||||
|
@ -83,22 +86,85 @@ async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImpl
|
|||
raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e
|
||||
|
||||
|
||||
def _make_file_object(
|
||||
*,
|
||||
id: str,
|
||||
filename: str,
|
||||
purpose: str,
|
||||
bytes: int,
|
||||
created_at: int,
|
||||
expires_at: int,
|
||||
**kwargs: Any, # here to ignore any additional fields, e.g. extra fields from AuthorizedSqlStore
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
Construct an OpenAIFileObject and normalize expires_at.
|
||||
|
||||
If expires_at is greater than the max we treat it as no-expiration and
|
||||
return None for expires_at.
|
||||
|
||||
The OpenAI spec says expires_at type is Integer, but the implementation
|
||||
will return None for no expiration.
|
||||
"""
|
||||
obj = OpenAIFileObject(
|
||||
id=id,
|
||||
filename=filename,
|
||||
purpose=OpenAIFilePurpose(purpose),
|
||||
bytes=bytes,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
if obj.expires_at is not None and obj.expires_at > (obj.created_at + ExpiresAfter.MAX):
|
||||
obj.expires_at = None # type: ignore
|
||||
|
||||
return obj
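A standalone illustration of the normalization rule above: expirations stored beyond `created_at + ExpiresAfter.MAX` are reported back as `None`. The sketch assumes `ExpiresAfter.MAX` equals 30 days in seconds (consistent with the BATCH default mentioned in the upload path below); it mirrors the rule rather than calling the provider's helper:

```python
MAX_EXPIRES_SECONDS = 30 * 24 * 60 * 60  # assumed value of ExpiresAfter.MAX (30 days)

def normalize_expires_at(created_at: int, expires_at: int | None) -> int | None:
    """Mirror of the rule above: expirations past the allowed max are hidden as None."""
    if expires_at is not None and expires_at > created_at + MAX_EXPIRES_SECONDS:
        return None
    return expires_at

created = 1_700_000_000
print(normalize_expires_at(created, created + 3600))                      # kept: 1700003600
print(normalize_expires_at(created, created + MAX_EXPIRES_SECONDS * 42))  # None -> "no expiration"
```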
|
||||
|
||||
|
||||
class S3FilesImpl(Files):
|
||||
"""S3-based implementation of the Files API."""
|
||||
|
||||
# TODO: implement expiration, for now a silly offset
|
||||
_SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60
|
||||
|
||||
def __init__(self, config: S3FilesImplConfig) -> None:
|
||||
def __init__(self, config: S3FilesImplConfig, policy: list[AccessRule]) -> None:
|
||||
self._config = config
|
||||
self.policy = policy
|
||||
self._client: boto3.client | None = None
|
||||
self._sql_store: SqlStore | None = None
|
||||
self._sql_store: AuthorizedSqlStore | None = None
|
||||
|
||||
def _now(self) -> int:
|
||||
"""Return current UTC timestamp as int seconds."""
|
||||
return int(datetime.now(UTC).timestamp())
|
||||
|
||||
async def _get_file(self, file_id: str, return_expired: bool = False) -> dict[str, Any]:
|
||||
where: dict[str, str | dict] = {"id": file_id}
|
||||
if not return_expired:
|
||||
where["expires_at"] = {">": self._now()}
|
||||
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
return row
|
||||
|
||||
async def _delete_file(self, file_id: str) -> None:
|
||||
"""Delete a file from S3 and the database."""
|
||||
try:
|
||||
self.client.delete_object(
|
||||
Bucket=self._config.bucket_name,
|
||||
Key=file_id,
|
||||
)
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] != "NoSuchKey":
|
||||
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
|
||||
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
async def _delete_if_expired(self, file_id: str) -> None:
|
||||
"""If the file exists and is expired, delete it."""
|
||||
if row := await self._get_file(file_id, return_expired=True):
|
||||
if (expires_at := row.get("expires_at")) and expires_at <= self._now():
|
||||
await self._delete_file(file_id)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self._client = _create_s3_client(self._config)
|
||||
await _create_bucket_if_not_exists(self._client, self._config)
|
||||
|
||||
self._sql_store = sqlstore_impl(self._config.metadata_store)
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
|
||||
await self._sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
|
@ -121,7 +187,7 @@ class S3FilesImpl(Files):
|
|||
return self._client
|
||||
|
||||
@property
|
||||
def sql_store(self) -> SqlStore:
|
||||
def sql_store(self) -> AuthorizedSqlStore:
|
||||
assert self._sql_store is not None, "Provider not initialized"
|
||||
return self._sql_store
|
||||
|
||||
|
@ -129,27 +195,47 @@ class S3FilesImpl(Files):
|
|||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
|
||||
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
|
||||
) -> OpenAIFileObject:
|
||||
file_id = f"file-{uuid.uuid4().hex}"
|
||||
|
||||
filename = getattr(file, "filename", None) or "uploaded_file"
|
||||
|
||||
created_at = int(time.time())
|
||||
expires_at = created_at + self._SILLY_EXPIRATION_OFFSET
|
||||
created_at = self._now()
|
||||
|
||||
expires_after = None
|
||||
if expires_after_anchor is not None or expires_after_seconds is not None:
|
||||
# we use ExpiresAfter to validate input
|
||||
expires_after = ExpiresAfter(
|
||||
anchor=expires_after_anchor, # type: ignore[arg-type]
|
||||
seconds=expires_after_seconds, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# the default is no expiration.
|
||||
# to implement no expiration we set an expiration beyond the max.
|
||||
# we'll hide this fact from users when returning the file object.
|
||||
expires_at = created_at + ExpiresAfter.MAX * 42
|
||||
# the default for BATCH files is 30 days, which happens to be the expiration max.
|
||||
if purpose == OpenAIFilePurpose.BATCH:
|
||||
expires_at = created_at + ExpiresAfter.MAX
|
||||
|
||||
if expires_after is not None:
|
||||
expires_at = created_at + expires_after.seconds
|
||||
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
await self.sql_store.insert(
|
||||
"openai_files",
|
||||
{
|
||||
"id": file_id,
|
||||
"filename": filename,
|
||||
"purpose": purpose.value,
|
||||
"bytes": file_size,
|
||||
"created_at": created_at,
|
||||
"expires_at": expires_at,
|
||||
},
|
||||
)
|
||||
entry: dict[str, Any] = {
|
||||
"id": file_id,
|
||||
"filename": filename,
|
||||
"purpose": purpose.value,
|
||||
"bytes": file_size,
|
||||
"created_at": created_at,
|
||||
"expires_at": expires_at,
|
||||
}
|
||||
|
||||
await self.sql_store.insert("openai_files", entry)
|
||||
|
||||
try:
|
||||
self.client.put_object(
|
||||
|
@ -163,14 +249,7 @@ class S3FilesImpl(Files):
|
|||
|
||||
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
|
||||
|
||||
return OpenAIFileObject(
|
||||
id=file_id,
|
||||
filename=filename,
|
||||
purpose=purpose,
|
||||
bytes=file_size,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
return _make_file_object(**entry)
|
||||
|
||||
async def openai_list_files(
|
||||
self,
|
||||
|
@ -183,29 +262,19 @@ class S3FilesImpl(Files):
|
|||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
where_conditions = {}
|
||||
where_conditions: dict[str, Any] = {"expires_at": {">": self._now()}}
|
||||
if purpose:
|
||||
where_conditions["purpose"] = purpose.value
|
||||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
where=where_conditions if where_conditions else None,
|
||||
where=where_conditions,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
files = [
|
||||
OpenAIFileObject(
|
||||
id=row["id"],
|
||||
filename=row["filename"],
|
||||
purpose=OpenAIFilePurpose(row["purpose"]),
|
||||
bytes=row["bytes"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
)
|
||||
for row in paginated_result.data
|
||||
]
|
||||
files = [_make_file_object(**row) for row in paginated_result.data]
|
||||
|
||||
return ListOpenAIFileResponse(
|
||||
data=files,
|
||||
|
@ -216,41 +285,20 @@ class S3FilesImpl(Files):
|
|||
)
|
||||
|
||||
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
|
||||
return OpenAIFileObject(
|
||||
id=row["id"],
|
||||
filename=row["filename"],
|
||||
purpose=OpenAIFilePurpose(row["purpose"]),
|
||||
bytes=row["bytes"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
)
|
||||
await self._delete_if_expired(file_id)
|
||||
row = await self._get_file(file_id)
|
||||
return _make_file_object(**row)
|
||||
|
||||
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
|
||||
try:
|
||||
self.client.delete_object(
|
||||
Bucket=self._config.bucket_name,
|
||||
Key=row["id"],
|
||||
)
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] != "NoSuchKey":
|
||||
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
|
||||
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
await self._delete_if_expired(file_id)
|
||||
_ = await self._get_file(file_id) # raises if not found
|
||||
await self._delete_file(file_id)
|
||||
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
|
||||
|
||||
async def openai_retrieve_file_content(self, file_id: str) -> Response:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
await self._delete_if_expired(file_id)
|
||||
|
||||
row = await self._get_file(file_id)
|
||||
|
||||
try:
|
||||
response = self.client.get_object(
|
||||
|
@ -261,7 +309,7 @@ class S3FilesImpl(Files):
|
|||
content = response["Body"].read()
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] == "NoSuchKey":
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
await self._delete_file(file_id)
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
|
||||
raise RuntimeError(f"Failed to download file from S3: {e}") from e
|
||||
|
||||
|
|
|
@ -4,15 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import AnthropicConfig
|
||||
|
||||
|
||||
class AnthropicProviderDataValidator(BaseModel):
|
||||
anthropic_api_key: str | None = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: AnthropicConfig, _deps):
|
||||
from .anthropic import AnthropicInferenceAdapter
|
||||
|
||||
|
|
|
@ -5,16 +5,27 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AnthropicConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
|
||||
# TODO: add support for voyageai, which is where these models are hosted
|
||||
# embedding_model_metadata = {
|
||||
# "voyage-3-large": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-3.5": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-3.5-lite": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-code-3": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-finance-2": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# "voyage-law-2": {"embedding_dimension": 1024, "context_length": 16000},
|
||||
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# }
|
||||
|
||||
def __init__(self, config: AnthropicConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="anthropic",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="anthropic_api_key",
|
||||
|
@ -26,3 +37,8 @@ class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://api.anthropic.com/v1"
|
||||
|
|
|
@ -1,40 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"claude-3-5-sonnet-latest",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-5-haiku-latest",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
MODEL_ENTRIES = (
|
||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="voyage-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="voyage-3-lite",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 512, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="voyage-code-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
]
|
||||
+ SAFETY_MODELS_ENTRIES
|
||||
)
|
15
llama_stack/providers/remote/inference/azure/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import AzureConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: AzureConfig, _deps):
|
||||
from .azure import AzureInferenceAdapter
|
||||
|
||||
impl = AzureInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
62
llama_stack/providers/remote/inference/azure/azure.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
||||
LiteLLMOpenAIMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AzureConfig
|
||||
|
||||
|
||||
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: AzureConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="azure",
|
||||
api_key_from_config=config.api_key.get_secret_value(),
|
||||
provider_data_api_key_field="azure_api_key",
|
||||
openai_compat_api_base=str(config.api_base),
|
||||
)
|
||||
self.config = config
|
||||
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the Azure API base URL.
|
||||
|
||||
Returns the Azure API base URL from the configuration.
|
||||
"""
|
||||
return urljoin(str(self.config.api_base), "/openai/v1")
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
|
||||
# Get base parameters from parent
|
||||
params = await super()._get_params(request)
|
||||
|
||||
# Add Azure specific parameters
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data:
|
||||
if getattr(provider_data, "azure_api_key", None):
|
||||
params["api_key"] = provider_data.azure_api_key
|
||||
if getattr(provider_data, "azure_api_base", None):
|
||||
params["api_base"] = provider_data.azure_api_base
|
||||
if getattr(provider_data, "azure_api_version", None):
|
||||
params["api_version"] = provider_data.azure_api_version
|
||||
if getattr(provider_data, "azure_api_type", None):
|
||||
params["api_type"] = provider_data.azure_api_type
|
||||
else:
|
||||
params["api_key"] = self.config.api_key.get_secret_value()
|
||||
params["api_base"] = str(self.config.api_base)
|
||||
params["api_version"] = self.config.api_version
|
||||
params["api_type"] = self.config.api_type
|
||||
|
||||
return params
|
63
llama_stack/providers/remote/inference/azure/config.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class AzureProviderDataValidator(BaseModel):
|
||||
azure_api_key: SecretStr = Field(
|
||||
description="Azure API key for Azure",
|
||||
)
|
||||
azure_api_base: HttpUrl = Field(
|
||||
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
|
||||
)
|
||||
azure_api_version: str | None = Field(
|
||||
default=None,
|
||||
description="Azure API version for Azure (e.g., 2024-06-01)",
|
||||
)
|
||||
azure_api_type: str | None = Field(
|
||||
default="azure",
|
||||
description="Azure API type for Azure (e.g., azure)",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AzureConfig(BaseModel):
|
||||
api_key: SecretStr = Field(
|
||||
description="Azure API key for Azure",
|
||||
)
|
||||
api_base: HttpUrl = Field(
|
||||
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default_factory=lambda: os.getenv("AZURE_API_VERSION"),
|
||||
description="Azure API version for Azure (e.g., 2024-12-01-preview)",
|
||||
)
|
||||
api_type: str | None = Field(
|
||||
default_factory=lambda: os.getenv("AZURE_API_TYPE", "azure"),
|
||||
description="Azure API type for Azure (e.g., azure)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
api_key: str = "${env.AZURE_API_KEY:=}",
|
||||
api_base: str = "${env.AZURE_API_BASE:=}",
|
||||
api_version: str = "${env.AZURE_API_VERSION:=}",
|
||||
api_type: str = "${env.AZURE_API_TYPE:=}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"api_version": api_version,
|
||||
"api_type": api_type,
|
||||
}
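`AzureConfig.sample_run_config` emits environment-variable placeholders rather than literal secrets; the stack resolves them at startup. A quick way to inspect the generated defaults, assuming the `llama_stack` package is importable and using the module path of the new file above:

```python
from llama_stack.providers.remote.inference.azure.config import AzureConfig

# Prints placeholder-based defaults, e.g. "api_key": "${env.AZURE_API_KEY:=}".
# Real values come from the environment when the run config is rendered.
print(AzureConfig.sample_run_config())
```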
|
|
@ -53,6 +53,43 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
REGION_PREFIX_MAP = {
|
||||
"us": "us.",
|
||||
"eu": "eu.",
|
||||
"ap": "ap.",
|
||||
}
|
||||
|
||||
|
||||
def _get_region_prefix(region: str | None) -> str:
|
||||
# AWS requires region prefixes for inference profiles
|
||||
if region is None:
|
||||
return "us." # default to US when we don't know
|
||||
|
||||
# Handle case insensitive region matching
|
||||
region_lower = region.lower()
|
||||
for prefix in REGION_PREFIX_MAP:
|
||||
if region_lower.startswith(f"{prefix}-"):
|
||||
return REGION_PREFIX_MAP[prefix]
|
||||
|
||||
# Fallback to US for anything we don't recognize
|
||||
return "us."
|
||||
|
||||
|
||||
def _to_inference_profile_id(model_id: str, region: str | None = None) -> str:
|
||||
# Return ARNs unchanged
|
||||
if model_id.startswith("arn:"):
|
||||
return model_id
|
||||
|
||||
# Return inference profile IDs that already have regional prefixes
|
||||
if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
|
||||
return model_id
|
||||
|
||||
# Default to US East when no region is provided
|
||||
if region is None:
|
||||
region = "us-east-1"
|
||||
|
||||
return _get_region_prefix(region) + model_id
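The behavior of `_to_inference_profile_id` on a few inputs, assuming the function above is in scope; the model IDs and ARN below are made-up examples:

```python
# Illustrative checks for the helper above (example identifiers only):
assert _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "us-east-1") == "us.meta.llama3-1-70b-instruct-v1:0"
assert _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "eu-west-1") == "eu.meta.llama3-1-70b-instruct-v1:0"
assert _to_inference_profile_id("ap.meta.llama3-1-70b-instruct-v1:0") == "ap.meta.llama3-1-70b-instruct-v1:0"  # already prefixed
assert _to_inference_profile_id("arn:aws:bedrock:eu-west-1:123456789012:inference-profile/example") == (
    "arn:aws:bedrock:eu-west-1:123456789012:inference-profile/example"  # ARNs pass through unchanged
)
```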
|
||||
|
||||
|
||||
class BedrockInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
|
@ -61,7 +98,7 @@ class BedrockInferenceAdapter(
|
|||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
self._config = config
|
||||
self._client = None
|
||||
|
||||
|
@ -166,8 +203,13 @@ class BedrockInferenceAdapter(
|
|||
options["repetition_penalty"] = sampling_params.repetition_penalty
|
||||
|
||||
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
|
||||
|
||||
# Convert foundation model ID to inference profile ID
|
||||
region_name = self.client.meta.region_name
|
||||
inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
|
||||
|
||||
return {
|
||||
"modelId": bedrock_model,
|
||||
"modelId": inference_profile_id,
|
||||
"body": json.dumps(
|
||||
{
|
||||
"prompt": prompt,
|
||||
|
@ -185,6 +227,11 @@ class BedrockInferenceAdapter(
|
|||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
# Convert foundation model ID to inference profile ID
|
||||
region_name = self.client.meta.region_name
|
||||
inference_profile_id = _to_inference_profile_id(model.provider_resource_id, region_name)
|
||||
|
||||
embeddings = []
|
||||
for content in contents:
|
||||
assert not content_has_media(content), "Bedrock does not support media for embeddings"
|
||||
|
@ -193,7 +240,7 @@ class BedrockInferenceAdapter(
|
|||
body = json.dumps(input_body)
|
||||
response = self.client.invoke_model(
|
||||
body=body,
|
||||
modelId=model.provider_resource_id,
|
||||
modelId=inference_profile_id,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from cerebras.cloud.sdk import AsyncCerebras
|
||||
|
||||
|
@ -35,42 +36,41 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import CerebrasImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class CerebrasInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
# TODO: make this use provider data, etc. like other providers
|
||||
self.client = AsyncCerebras(
|
||||
self._cerebras_client = AsyncCerebras(
|
||||
base_url=self.config.base_url,
|
||||
api_key=self.config.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return urljoin(self.config.base_url, "v1")
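Note the `urljoin` subtlety here: a relative `"v1"` is appended when the base has no trailing path segment, but it replaces a trailing segment that lacks a final slash. A quick standard-library check (the host below is just an example base URL):

```python
from urllib.parse import urljoin

print(urljoin("https://api.example.com", "v1"))        # https://api.example.com/v1
print(urljoin("https://api.example.com/", "v1"))       # https://api.example.com/v1
print(urljoin("https://api.example.com/other", "v1"))  # https://api.example.com/v1 ("other" is replaced)
```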
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
|
@ -107,14 +107,14 @@ class CerebrasInferenceAdapter(
|
|||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
|
||||
r = await self.client.completions.create(**params)
|
||||
r = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await self.client.completions.create(**params)
|
||||
stream = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
@ -156,14 +156,14 @@ class CerebrasInferenceAdapter(
|
|||
async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
|
||||
r = await self.client.completions.create(**params)
|
||||
r = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await self.client.completions.create(**params)
|
||||
stream = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
|
|
@ -20,8 +20,8 @@ class CerebrasImplConfig(BaseModel):
|
|||
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
|
||||
description="Base URL for the Cerebras API",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=os.environ.get("CEREBRAS_API_KEY"),
|
||||
api_key: SecretStr = Field(
|
||||
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
|
||||
description="Cerebras API Key",
|
||||
)
|
||||
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# https://inference-docs.cerebras.ai/models
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1-8b",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-3.3-70b",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
|
@ -5,10 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
|
||||
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
|
||||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = DatabricksInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
@ -17,16 +17,16 @@ class DatabricksImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
api_token: str = Field(
|
||||
default=None,
|
||||
api_token: SecretStr = Field(
|
||||
default=SecretStr(None),
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.DATABRICKS_URL:=}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
|
||||
url: str = "${env.DATABRICKS_HOST:=}",
|
||||
api_token: str = "${env.DATABRICKS_TOKEN:=}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@@ -4,23 +4,28 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any

from openai import OpenAI
from databricks.sdk import WorkspaceClient

from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    LogProbConfig,
    Message,
    OpenAIEmbeddingsResponse,
    Model,
    ModelType,
    OpenAICompletion,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
@@ -29,49 +34,33 @@ from llama_stack.apis.inference import (
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAIChatCompletionToLlamaStackMixin,
    OpenAICompletionToLlamaStackMixin,
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

from .config import DatabricksImplConfig

SAFETY_MODELS_ENTRIES = []

# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "databricks-meta-llama-3-1-70b-instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "databricks-meta-llama-3-1-405b-instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
] + SAFETY_MODELS_ENTRIES
logger = get_logger(name=__name__, category="inference::databricks")


class DatabricksInferenceAdapter(
    ModelRegistryHelper,
    OpenAIMixin,
    Inference,
    OpenAIChatCompletionToLlamaStackMixin,
    OpenAICompletionToLlamaStackMixin,
):
    # source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
    embedding_model_metadata = {
        "databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
        "databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
    }

    def __init__(self, config: DatabricksImplConfig) -> None:
        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
        self.config = config

    def get_api_key(self) -> str:
        return self.config.api_token.get_secret_value()

    def get_base_url(self) -> str:
        return f"{self.config.url}/serving-endpoints"

    async def initialize(self) -> None:
        return

@@ -80,72 +69,54 @@ class DatabricksInferenceAdapter(

    async def completion(
        self,
        model: str,
        model_id: str,
        content: InterleavedContent,
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
    ) -> AsyncGenerator:
    ) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
        raise NotImplementedError()

    async def openai_completion(
        self,
        model: str,
        prompt: str | list[str] | list[int] | list[list[int]],
        best_of: int | None = None,
        echo: bool | None = None,
        frequency_penalty: float | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        presence_penalty: float | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model: str,
        model_id: str,
        messages: list[Message],
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = ToolChoice.auto,
        tool_prompt_format: ToolPromptFormat | None = None,
        response_format: ResponseFormat | None = None,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
        tool_config: ToolConfig | None = None,
    ) -> AsyncGenerator:
        if sampling_params is None:
            sampling_params = SamplingParams()
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config,
        )

        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
        if stream:
            return self._stream_chat_completion(request, client)
        else:
            return await self._nonstream_chat_completion(request, client)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest, client: OpenAI
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = client.completions.create(**params)
        return process_chat_completion_response(r, request)

    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
        params = self._get_params(request)

        async def _to_async_generator():
            s = client.completions.create(**params)
            for chunk in s:
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        return {
            "model": request.model,
            "prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }
    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
        raise NotImplementedError()

    async def embeddings(
        self,
@@ -157,12 +128,31 @@ class DatabricksInferenceAdapter(
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

    async def openai_embeddings(
        self,
        model: str,
        input: str | list[str],
        encoding_format: str | None = "float",
        dimensions: int | None = None,
        user: str | None = None,
    ) -> OpenAIEmbeddingsResponse:
        raise NotImplementedError()
    async def list_models(self) -> list[Model] | None:
        self._model_cache = {}  # from OpenAIMixin
        ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key())  # TODO: this is not async
        endpoints = ws_client.serving_endpoints.list()
        for endpoint in endpoints:
            model = Model(
                provider_id=self.__provider_id__,
                provider_resource_id=endpoint.name,
                identifier=endpoint.name,
            )
            if endpoint.task == "llm/v1/chat":
                model.model_type = ModelType.llm  # this is redundant, but informative
            elif endpoint.task == "llm/v1/embeddings":
                if endpoint.name not in self.embedding_model_metadata:
                    logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
                    continue
                model.model_type = ModelType.embedding
                model.metadata = self.embedding_model_metadata[endpoint.name]
            else:
                logger.warning(f"Unknown model type, skipping: {endpoint}")
                continue

            self._model_cache[endpoint.name] = model

        return list(self._model_cache.values())

    async def should_refresh_models(self) -> bool:
        return False

@@ -4,11 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from collections.abc import AsyncGenerator

from fireworks.client import Fireworks
from openai import AsyncOpenAI

from llama_stack.apis.common.content_types import (
    InterleavedContent,
@@ -24,12 +22,6 @@ from llama_stack.apis.inference import (
    Inference,
    LogProbConfig,
    Message,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAICompletion,
    OpenAIEmbeddingsResponse,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
    ResponseFormat,
    ResponseFormatType,
    SamplingParams,
@@ -45,15 +37,14 @@ from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAIChatCompletionToLlamaStackMixin,
    convert_message_to_openai_dict,
    get_sampling_options,
    prepare_openai_completion_params,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
@@ -63,15 +54,18 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)

from .config import FireworksImplConfig
from .models import MODEL_ENTRIES

logger = get_logger(name=__name__, category="inference::fireworks")


class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
    embedding_model_metadata = {
        "nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
    }

    def __init__(self, config: FireworksImplConfig) -> None:
        ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
        self.config = config
        self.allowed_models = config.allowed_models

    async def initialize(self) -> None:
        pass
@@ -79,7 +73,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
    async def shutdown(self) -> None:
        pass

    def _get_api_key(self) -> str:
    def get_api_key(self) -> str:
        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
        if config_api_key:
            return config_api_key
@@ -91,15 +85,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
            )
        return provider_data.fireworks_api_key

    def _get_base_url(self) -> str:
    def get_base_url(self) -> str:
        return "https://api.fireworks.ai/inference/v1"

    def _get_client(self) -> Fireworks:
        fireworks_api_key = self._get_api_key()
        fireworks_api_key = self.get_api_key()
        return Fireworks(api_key=fireworks_api_key)

    def _get_openai_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
    def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
        """Remove BOS token as Fireworks automatically prepends it"""
        if prompt.startswith("<|begin_of_text|>"):
            return prompt[len("<|begin_of_text|>") :]
        return prompt

    async def completion(
        self,
@@ -285,153 +282,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv

        embeddings = [data.embedding for data in response.data]
        return EmbeddingsResponse(embeddings=embeddings)

    async def openai_embeddings(
        self,
        model: str,
        input: str | list[str],
        encoding_format: str | None = "float",
        dimensions: int | None = None,
        user: str | None = None,
    ) -> OpenAIEmbeddingsResponse:
        raise NotImplementedError()

    async def openai_completion(
        self,
        model: str,
        prompt: str | list[str] | list[int] | list[list[int]],
        best_of: int | None = None,
        echo: bool | None = None,
        frequency_penalty: float | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        presence_penalty: float | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        model_obj = await self.model_store.get_model(model)

        # Fireworks always prepends with BOS
        if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
            prompt = prompt[len("<|begin_of_text|>") :]

        params = await prepare_openai_completion_params(
            model=model_obj.provider_resource_id,
            prompt=prompt,
            best_of=best_of,
            echo=echo,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            top_p=top_p,
            user=user,
        )

        return await self._get_openai_client().completions.create(**params)

    async def openai_chat_completion(
        self,
        model: str,
        messages: list[OpenAIMessageParam],
        frequency_penalty: float | None = None,
        function_call: str | dict[str, Any] | None = None,
        functions: list[dict[str, Any]] | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_completion_tokens: int | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        parallel_tool_calls: bool | None = None,
        presence_penalty: float | None = None,
        response_format: OpenAIResponseFormatParam | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        tools: list[dict[str, Any]] | None = None,
        top_logprobs: int | None = None,
        top_p: float | None = None,
        user: str | None = None,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        model_obj = await self.model_store.get_model(model)

        # Divert Llama Models through Llama Stack inference APIs because
        # Fireworks chat completions OpenAI-compatible API does not support
        # tool calls properly.
        llama_model = self.get_llama_model(model_obj.provider_resource_id)

        if llama_model:
            return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
                self,
                model=model,
                messages=messages,
                frequency_penalty=frequency_penalty,
                function_call=function_call,
                functions=functions,
                logit_bias=logit_bias,
                logprobs=logprobs,
                max_completion_tokens=max_completion_tokens,
                max_tokens=max_tokens,
                n=n,
                parallel_tool_calls=parallel_tool_calls,
                presence_penalty=presence_penalty,
                response_format=response_format,
                seed=seed,
                stop=stop,
                stream=stream,
                stream_options=stream_options,
                temperature=temperature,
                tool_choice=tool_choice,
                tools=tools,
                top_logprobs=top_logprobs,
                top_p=top_p,
                user=user,
            )

        params = await prepare_openai_completion_params(
            messages=messages,
            frequency_penalty=frequency_penalty,
            function_call=function_call,
            functions=functions,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_completion_tokens=max_completion_tokens,
            max_tokens=max_tokens,
            n=n,
            parallel_tool_calls=parallel_tool_calls,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
            user=user,
        )

        logger.debug(f"fireworks params: {params}")
        return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)

@@ -1,70 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
    build_hf_repo_model_entry,
)

SAFETY_MODELS_ENTRIES = [
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-guard-3-8b",
        CoreModelId.llama_guard_3_8b.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-guard-3-11b-vision",
        CoreModelId.llama_guard_3_11b_vision.value,
    ),
]

MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p1-8b-instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p1-70b-instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p1-405b-instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p2-3b-instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama-v3p3-70b-instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama4-scout-instruct-basic",
        CoreModelId.llama4_scout_17b_16e_instruct.value,
    ),
    build_hf_repo_model_entry(
        "accounts/fireworks/models/llama4-maverick-instruct-basic",
        CoreModelId.llama4_maverick_17b_128e_instruct.value,
    ),
    ProviderModelEntry(
        provider_model_id="nomic-ai/nomic-embed-text-v1.5",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 768,
            "context_length": 8192,
        },
    ),
] + SAFETY_MODELS_ENTRIES

@@ -4,15 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from .config import GeminiConfig


class GeminiProviderDataValidator(BaseModel):
    gemini_api_key: str | None = None


async def get_adapter_impl(config: GeminiConfig, _deps):
    from .gemini import GeminiInferenceAdapter

@@ -5,22 +5,30 @@
# the root directory of this source tree.

from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

from .config import GeminiConfig
from .models import MODEL_ENTRIES


class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
    embedding_model_metadata = {
        "text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
    }

    def __init__(self, config: GeminiConfig) -> None:
        LiteLLMOpenAIMixin.__init__(
            self,
            MODEL_ENTRIES,
            litellm_provider_name="gemini",
            api_key_from_config=config.api_key,
            provider_data_api_key_field="gemini_api_key",
        )
        self.config = config

    get_api_key = LiteLLMOpenAIMixin.get_api_key

    def get_base_url(self):
        return "https://generativelanguage.googleapis.com/v1beta/openai/"

    async def initialize(self) -> None:
        await super().initialize()

@@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
)

LLM_MODEL_IDS = [
    "gemini-1.5-flash",
    "gemini-1.5-pro",
    "gemini-2.0-flash",
    "gemini-2.0-flash-lite",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
    "gemini-2.5-pro",
]

SAFETY_MODELS_ENTRIES = []

MODEL_ENTRIES = (
    [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
    + [
        ProviderModelEntry(
            provider_model_id="text-embedding-004",
            model_type=ModelType.embedding,
            metadata={"embedding_dimension": 768, "context_length": 2048},
        ),
    ]
    + SAFETY_MODELS_ENTRIES
)

@@ -4,158 +4,32 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator
from typing import Any

from openai import AsyncOpenAI

from llama_stack.apis.inference import (
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChoiceDelta,
    OpenAIChunkChoice,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
    OpenAISystemMessageParam,
)
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import (
    prepare_openai_completion_params,
)

from .models import MODEL_ENTRIES
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class GroqInferenceAdapter(LiteLLMOpenAIMixin):
class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
    _config: GroqConfig

    def __init__(self, config: GroqConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            litellm_provider_name="groq",
            api_key_from_config=config.api_key,
            provider_data_api_key_field="groq_api_key",
        )
        self.config = config

    # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
    get_api_key = LiteLLMOpenAIMixin.get_api_key

    def get_base_url(self) -> str:
        return f"{self.config.url}/openai/v1"

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()

    def _get_openai_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(
            base_url=f"{self.config.url}/openai/v1",
            api_key=self.get_api_key(),
        )

    async def openai_chat_completion(
        self,
        model: str,
        messages: list[OpenAIMessageParam],
        frequency_penalty: float | None = None,
        function_call: str | dict[str, Any] | None = None,
        functions: list[dict[str, Any]] | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_completion_tokens: int | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        parallel_tool_calls: bool | None = None,
        presence_penalty: float | None = None,
        response_format: OpenAIResponseFormatParam | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        tools: list[dict[str, Any]] | None = None,
        top_logprobs: int | None = None,
        top_p: float | None = None,
        user: str | None = None,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        model_obj = await self.model_store.get_model(model)

        # Groq does not support json_schema response format, so we need to convert it to json_object
        if response_format and response_format.type == "json_schema":
            response_format.type = "json_object"
            schema = response_format.json_schema.get("schema", {})
            response_format.json_schema = None
            json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
            if messages and messages[0].role == "system":
                messages[0].content = messages[0].content + json_instructions
            else:
                messages.insert(0, OpenAISystemMessageParam(content=json_instructions))

        # Groq returns a 400 error if tools are provided but none are called
        # So, set tool_choice to "required" to attempt to force a call
        if tools and (not tool_choice or tool_choice == "auto"):
            tool_choice = "required"

        params = await prepare_openai_completion_params(
            model=model_obj.provider_resource_id,
            messages=messages,
            frequency_penalty=frequency_penalty,
            function_call=function_call,
            functions=functions,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_completion_tokens=max_completion_tokens,
            max_tokens=max_tokens,
            n=n,
            parallel_tool_calls=parallel_tool_calls,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
            user=user,
        )

        # Groq does not support streaming requests that set response_format
        fake_stream = False
        if stream and response_format:
            params["stream"] = False
            fake_stream = True

        response = await self._get_openai_client().chat.completions.create(**params)

        if fake_stream:
            chunk_choices = []
            for choice in response.choices:
                delta = OpenAIChoiceDelta(
                    content=choice.message.content,
                    role=choice.message.role,
                    tool_calls=choice.message.tool_calls,
                )
                chunk_choice = OpenAIChunkChoice(
                    delta=delta,
                    finish_reason=choice.finish_reason,
                    index=choice.index,
                    logprobs=None,
                )
                chunk_choices.append(chunk_choice)
            chunk = OpenAIChatCompletionChunk(
                id=response.id,
                choices=chunk_choices,
                object="chat.completion.chunk",
                created=response.created,
                model=response.model,
            )

            async def _fake_stream_generator():
                yield chunk

            return _fake_stream_generator()
        else:
            return response

@@ -1,48 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    build_hf_repo_model_entry,
    build_model_entry,
)

SAFETY_MODELS_ENTRIES = []

MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "llama3-8b-8192",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_entry(
        "llama-3.1-8b-instant",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3-70b-8192",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama-3.3-70b-versatile",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    # Groq only contains a preview version for llama-3.2-3b
    # Preview models aren't recommended for production use, but we include this one
    # to pass the test fixture
    # TODO(aidand): Replace this with a stable model once Groq supports it
    build_hf_repo_model_entry(
        "llama-3.2-3b-preview",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/llama-4-scout-17b-16e-instruct",
        CoreModelId.llama4_scout_17b_16e_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/llama-4-maverick-17b-128e-instruct",
        CoreModelId.llama4_maverick_17b_128e_instruct.value,
    ),
] + SAFETY_MODELS_ENTRIES

@@ -8,8 +8,6 @@ from llama_stack.providers.remote.inference.llama_openai_compat.config import Ll
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

from .models import MODEL_ENTRIES

logger = get_logger(name=__name__, category="inference::llama_openai_compat")


@@ -30,7 +28,6 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
    def __init__(self, config: LlamaCompatConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            litellm_provider_name="meta_llama",
            api_key_from_config=config.api_key,
            provider_data_api_key_field="llama_api_key",

@@ -1,25 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    build_hf_repo_model_entry,
)

MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "Llama-3.3-70B-Instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Llama-4-Scout-17B-16E-Instruct-FP8",
        CoreModelId.llama4_scout_17b_16e_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Llama-4-Maverick-17B-128E-Instruct-FP8",
        CoreModelId.llama4_maverick_17b_128e_instruct.value,
    ),
]

@@ -41,10 +41,10 @@ client.initialize()

### Create Completion

> Note on Completion API
>
> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` does not support the ```completion``` method, while the locally deployed NIM does.
The following example shows how to create a completion for an NVIDIA NIM.

> [!NOTE]
> The hosted NVIDIA Llama NIMs (for example ```meta-llama/Llama-3.1-8B-Instruct```) that have ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` do not support the ```completion``` method, while locally deployed NIMs do.

```python
response = client.inference.completion(
@@ -60,6 +60,8 @@ print(f"Response: {response.content}")

### Create Chat Completion

The following example shows how to create a chat completion for an NVIDIA NIM.

```python
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
@@ -82,6 +84,9 @@ print(f"Response: {response.completion_message.content}")
```

### Tool Calling Example ###

The following example shows how to do tool calling for an NVIDIA NIM.

```python
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition

@@ -117,6 +122,9 @@ if tool_response.completion_message.tool_calls:
```

### Structured Output Example

The following example shows how to do structured output for an NVIDIA NIM.

```python
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType

@@ -149,8 +157,10 @@ print(f"Structured Response: {structured_response.completion_message.content}")
```

### Create Embeddings
> Note on OpenAI embeddings compatibility
>

The following example shows how to create embeddings for an NVIDIA NIM.

> [!NOTE]
> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.

```python
@@ -160,4 +170,42 @@ response = client.inference.embeddings(
    task_type="query",
)
print(f"Embeddings: {response.embeddings}")
```
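
For the passage side of the asymmetric flow mentioned in the note above, here is a minimal sketch (not part of the upstream README); it assumes the same `client`, the `contents` parameter from the Llama Stack `embeddings` API, and the hypothetical sample text shown in the comments.

```python
# Hedged sketch: embed passages/documents with an asymmetric NVIDIA model by
# passing task_type="document" instead of task_type="query".
response = client.inference.embeddings(
    model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
    contents=["Paris is the capital of France."],  # example passage, not from the README
    task_type="document",
)
print(f"Document embeddings: {response.embeddings}")
```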

### Vision Language Models Example

The following example shows how to run vision inference by using an NVIDIA NIM.

```python
def load_image_as_base64(image_path):
    with open(image_path, "rb") as image_file:
        img_bytes = image_file.read()
        return base64.b64encode(img_bytes).decode("utf-8")


image_path = {path_to_the_image}
demo_image_b64 = load_image_as_base64(image_path)

vlm_response = client.inference.chat_completion(
    model_id="nvidia/vila",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": {
                        "data": demo_image_b64,
                    },
                },
                {
                    "type": "text",
                    "text": "Please describe what you see in this image in detail.",
                },
            ],
        }
    ],
)

print(f"VLM Response: {vlm_response.completion_message.content}")
```

@@ -1,105 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
    build_hf_repo_model_entry,
)

SAFETY_MODELS_ENTRIES = []

# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "meta/llama3-8b-instruct",
        CoreModelId.llama3_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama3-70b-instruct",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.1-8b-instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.1-70b-instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.1-405b-instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.2-1b-instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.2-3b-instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.2-11b-vision-instruct",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.2-90b-vision-instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.3-70b-instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    # NeMo Retriever Text Embedding models -
    #
    # https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
    #
    # +-----------------------------------+--------+-----------+-----------+------------+
    # | Model ID                          | Max    | Publisher | Embedding | Dynamic    |
    # |                                   | Tokens |           | Dimension | Embeddings |
    # +-----------------------------------+--------+-----------+-----------+------------+
    # | nvidia/llama-3.2-nv-embedqa-1b-v2 | 8192   | NVIDIA    | 2048      | Yes        |
    # | nvidia/nv-embedqa-e5-v5           | 512    | NVIDIA    | 1024      | No         |
    # | nvidia/nv-embedqa-mistral-7b-v2   | 512    | NVIDIA    | 4096      | No         |
    # | snowflake/arctic-embed-l          | 512    | Snowflake | 1024      | No         |
    # +-----------------------------------+--------+-----------+-----------+------------+
    ProviderModelEntry(
        provider_model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 2048,
            "context_length": 8192,
        },
    ),
    ProviderModelEntry(
        provider_model_id="nvidia/nv-embedqa-e5-v5",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 1024,
            "context_length": 512,
        },
    ),
    ProviderModelEntry(
        provider_model_id="nvidia/nv-embedqa-mistral-7b-v2",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 4096,
            "context_length": 512,
        },
    ),
    ProviderModelEntry(
        provider_model_id="snowflake/arctic-embed-l",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 1024,
            "context_length": 512,
        },
    ),
    # TODO(mf): how do we handle Nemotron models?
    # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
] + SAFETY_MODELS_ENTRIES

@@ -37,9 +37,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    convert_openai_chat_completion_choice,
    convert_openai_chat_completion_stream,
@@ -48,7 +45,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media

from . import NVIDIAConfig
from .models import MODEL_ENTRIES
from .openai_utils import (
    convert_chat_completion_request,
    convert_completion_request,
@@ -60,7 +56,7 @@ from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference::nvidia")


class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
    """
    NVIDIA Inference Adapter for Llama Stack.

@@ -74,10 +70,15 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
    - ModelRegistryHelper.check_model_availability() just returns False and shows a warning
    """

    def __init__(self, config: NVIDIAConfig) -> None:
        # TODO(mf): filter by available models
        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
    # source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
    embedding_model_metadata = {
        "nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
        "nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
        "nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
        "snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
    }

    def __init__(self, config: NVIDIAConfig) -> None:
        logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")

        if _is_nvidia_hosted(config):

@@ -1,106 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
    build_hf_repo_model_entry,
    build_model_entry,
)

SAFETY_MODELS_ENTRIES = [
    # The Llama Guard models don't have their full fp16 versions
    # so we are going to alias their default version to the canonical SKU
    build_hf_repo_model_entry(
        "llama-guard3:8b",
        CoreModelId.llama_guard_3_8b.value,
    ),
    build_hf_repo_model_entry(
        "llama-guard3:1b",
        CoreModelId.llama_guard_3_1b.value,
    ),
]

MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "llama3.1:8b-instruct-fp16",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_entry(
        "llama3.1:8b",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.1:70b-instruct-fp16",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_entry(
        "llama3.1:70b",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.1:405b-instruct-fp16",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_model_entry(
        "llama3.1:405b",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.2:1b-instruct-fp16",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_model_entry(
        "llama3.2:1b",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.2:3b-instruct-fp16",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_model_entry(
        "llama3.2:3b",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.2-vision:11b-instruct-fp16",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_model_entry(
        "llama3.2-vision:latest",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.2-vision:90b-instruct-fp16",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_model_entry(
        "llama3.2-vision:90b",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.3:70b",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    ProviderModelEntry(
        provider_model_id="all-minilm:l6-v2",
        aliases=["all-minilm"],
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 384,
            "context_length": 512,
        },
    ),
    ProviderModelEntry(
        provider_model_id="nomic-embed-text",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 768,
            "context_length": 8192,
        },
    ),
] + SAFETY_MODELS_ENTRIES

@@ -7,12 +7,10 @@

import asyncio
import base64
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any

from ollama import AsyncClient  # type: ignore[attr-defined]
from openai import AsyncOpenAI
from ollama import AsyncClient as AsyncOllamaClient

from llama_stack.apis.common.content_types import (
    ImageContentItem,
@@ -37,9 +35,6 @@ from llama_stack.apis.inference import (
    Message,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAICompletion,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
    ResponseFormat,
@@ -50,8 +45,9 @@ from llama_stack.apis.inference import (
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
@@ -60,19 +56,19 @@ from llama_stack.providers.datatypes import (
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    b64_encode_openai_embeddings_response,
    get_sampling_options,
    prepare_openai_completion_params,
    prepare_openai_embeddings_params,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
@@ -83,103 +79,83 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
    request_has_media,
)

from .models import MODEL_ENTRIES

logger = get_logger(name=__name__, category="inference::ollama")


class OllamaInferenceAdapter(
    OpenAIMixin,
    InferenceProvider,
    ModelsProtocolPrivate,
):
    # automatically set by the resolver when instantiating the provider
    __provider_id__: str

    embedding_model_metadata = {
        "all-minilm:l6-v2": {
            "embedding_dimension": 384,
            "context_length": 512,
        },
        "nomic-embed-text:latest": {
            "embedding_dimension": 768,
            "context_length": 8192,
        },
        "nomic-embed-text:v1.5": {
            "embedding_dimension": 768,
            "context_length": 8192,
        },
        "nomic-embed-text:137m-v1.5-fp16": {
            "embedding_dimension": 768,
            "context_length": 8192,
        },
    }

    def __init__(self, config: OllamaImplConfig) -> None:
        self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
        # TODO: remove ModelRegistryHelper.__init__ when completion and
        #       chat_completion are. this exists to satisfy the input /
        #       output processing for llama models. specifically,
        #       tool_calling is handled by raw template processing,
        #       instead of using the /api/chat endpoint w/ tools=...
        ModelRegistryHelper.__init__(
            self,
            model_entries=[
                build_hf_repo_model_entry(
                    "llama3.2:3b-instruct-fp16",
                    CoreModelId.llama3_2_3b_instruct.value,
                ),
                build_hf_repo_model_entry(
                    "llama-guard3:1b",
                    CoreModelId.llama_guard_3_1b.value,
                ),
            ],
        )
        self.config = config
        self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
        self._openai_client = None
        self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}

    @property
    def client(self) -> AsyncClient:
    def ollama_client(self) -> AsyncOllamaClient:
        # ollama client attaches itself to the current event loop (sadly?)
        loop = asyncio.get_running_loop()
        if loop not in self._clients:
            self._clients[loop] = AsyncClient(host=self.config.url)
            self._clients[loop] = AsyncOllamaClient(host=self.config.url)
        return self._clients[loop]

    @property
    def openai_client(self) -> AsyncOpenAI:
        if self._openai_client is None:
            url = self.config.url.rstrip("/")
            self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
        return self._openai_client
    def get_api_key(self):
        return "NO_KEY"

    def get_base_url(self):
        return self.config.url.rstrip("/") + "/v1"

    async def initialize(self) -> None:
        logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
        health_response = await self.health()
        if health_response["status"] == HealthStatus.ERROR:
        r = await self.health()
        if r["status"] == HealthStatus.ERROR:
            logger.warning(
                "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
                f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
            )

    async def should_refresh_models(self) -> bool:
        return self.config.refresh_models

    async def list_models(self) -> list[Model] | None:
        provider_id = self.__provider_id__
        response = await self.client.list()

        # always add the two embedding models which can be pulled on demand
        models = [
            Model(
                identifier="all-minilm:l6-v2",
                provider_resource_id="all-minilm:l6-v2",
                provider_id=provider_id,
                metadata={
                    "embedding_dimension": 384,
                    "context_length": 512,
                },
                model_type=ModelType.embedding,
            ),
            # add all-minilm alias
            Model(
                identifier="all-minilm",
                provider_resource_id="all-minilm:l6-v2",
                provider_id=provider_id,
                metadata={
                    "embedding_dimension": 384,
                    "context_length": 512,
                },
                model_type=ModelType.embedding,
            ),
            Model(
                identifier="nomic-embed-text",
                provider_resource_id="nomic-embed-text",
                provider_id=provider_id,
                metadata={
                    "embedding_dimension": 768,
                    "context_length": 8192,
                },
                model_type=ModelType.embedding,
            ),
        ]
        for m in response.models:
            # kill embedding models since we don't know dimensions for them
            if "bert" in m.details.family:
                continue
            models.append(
                Model(
                    identifier=m.model,
                    provider_resource_id=m.model,
                    provider_id=provider_id,
                    metadata={},
                    model_type=ModelType.llm,
                )
            )
        return models

    async def health(self) -> HealthResponse:
        """
        Performs a health check by verifying connectivity to the Ollama server.

@ -189,7 +165,7 @@ class OllamaInferenceAdapter(
|
|||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
await self.client.ps()
|
||||
await self.ollama_client.ps()
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
@ -238,7 +214,7 @@ class OllamaInferenceAdapter(
|
|||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.generate(**params)
|
||||
s = await self.ollama_client.generate(**params)
|
||||
async for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
|
@ -254,7 +230,7 @@ class OllamaInferenceAdapter(
|
|||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await self.client.generate(**params)
|
||||
r = await self.ollama_client.generate(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
|
@ -308,7 +284,7 @@ class OllamaInferenceAdapter(
|
|||
|
||||
input_dict: dict[str, Any] = {}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.register_helper.get_llama_model(request.model)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present or not llama_model:
|
||||
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
|
||||
|
@ -346,9 +322,9 @@ class OllamaInferenceAdapter(
|
|||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
if "messages" in params:
|
||||
r = await self.client.chat(**params)
|
||||
r = await self.ollama_client.chat(**params)
|
||||
else:
|
||||
r = await self.client.generate(**params)
|
||||
r = await self.ollama_client.generate(**params)
|
||||
|
||||
if "message" in r:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
|
@ -372,9 +348,9 @@ class OllamaInferenceAdapter(
|
|||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
if "messages" in params:
|
||||
s = await self.client.chat(**params)
|
||||
s = await self.ollama_client.chat(**params)
|
||||
else:
|
||||
s = await self.client.generate(**params)
|
||||
s = await self.ollama_client.generate(**params)
|
||||
async for chunk in s:
|
||||
if "message" in chunk:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
|
@ -407,7 +383,7 @@ class OllamaInferenceAdapter(
|
|||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Ollama does not support media for embeddings"
|
||||
)
|
||||
response = await self.client.embed(
|
||||
response = await self.ollama_client.embed(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
|
@ -416,121 +392,16 @@ class OllamaInferenceAdapter(
|
|||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
try:
|
||||
model = await self.register_helper.register_model(model)
|
||||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
if await self.check_model_availability(model.provider_model_id):
|
||||
return model
|
||||
elif await self.check_model_availability(f"{model.provider_model_id}:latest"):
|
||||
model.provider_resource_id = f"{model.provider_model_id}:latest"
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_model_id}'"
|
||||
)
|
||||
return model
|
||||
|
||||
if model.model_type == ModelType.embedding:
|
||||
response = await self.client.list()
|
||||
if model.provider_resource_id not in [m.model for m in response.models]:
|
||||
await self.client.pull(model.provider_resource_id)
|
||||
|
||||
# we use list() here instead of ps() -
|
||||
# - ps() only lists running models, not available models
|
||||
# - models not currently running are run by the ollama server as needed
|
||||
response = await self.client.list()
|
||||
available_models = [m.model for m in response.models]
|
||||
|
||||
provider_resource_id = model.provider_resource_id
|
||||
assert provider_resource_id is not None # mypy
|
||||
if provider_resource_id not in available_models:
|
||||
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
||||
if provider_resource_id in available_models_latest:
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
)
|
||||
return model
|
||||
raise UnsupportedModelError(provider_resource_id, available_models)
|
||||
|
||||
# mutating this should be considered an anti-pattern
|
||||
model.provider_resource_id = provider_resource_id
|
||||
|
||||
return model
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
model_obj = await self._get_model(model)
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise ValueError(f"Model {model} has no provider_resource_id set")
|
||||
|
||||
# Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
|
||||
params = prepare_openai_embeddings_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
input=input,
|
||||
encoding_format=encoding_format,
|
||||
dimensions=dimensions,
|
||||
user=user,
|
||||
)
|
||||
|
||||
response = await self.openai_client.embeddings.create(**params)
|
||||
data = b64_encode_openai_embeddings_response(response.data, encoding_format)
|
||||
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
# TODO: Investigate why model_obj.identifier is used instead of response.model
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=model_obj.identifier,
|
||||
usage=usage,
|
||||
)
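For context on the b64_encode_openai_embeddings_response helper used above: OpenAI-compatible servers that honor encoding_format="base64" conventionally return each embedding as the base64 encoding of its little-endian float32 bytes. A minimal sketch of that convention follows; it is an assumption about what the helper does, not its actual source.

import base64
import struct

def floats_to_base64(vector: list[float]) -> str:
    # Pack the vector as little-endian float32 values, then base64-encode the raw bytes.
    packed = struct.pack(f"<{len(vector)}f", *vector)
    return base64.b64encode(packed).decode("ascii")

def base64_to_floats(encoded: str) -> list[float]:
    raw = base64.b64decode(encoded)
    return list(struct.unpack(f"<{len(raw) // 4}f", raw))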
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
if not isinstance(prompt, str):
|
||||
raise ValueError("Ollama does not support non-string prompts for completion")
|
||||
|
||||
model_obj = await self._get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
suffix=suffix,
|
||||
)
|
||||
return await self.openai_client.completions.create(**params) # type: ignore
|
||||
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
@@ -599,25 +470,7 @@ class OllamaInferenceAdapter(
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
response = await self.openai_client.chat.completions.create(**params)
|
||||
return await self._adjust_ollama_chat_completion_response_ids(response)
|
||||
|
||||
async def _adjust_ollama_chat_completion_response_ids(
|
||||
self,
|
||||
response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
id = f"chatcmpl-{uuid.uuid4()}"
|
||||
if isinstance(response, AsyncIterator):
|
||||
|
||||
async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
async for chunk in response:
|
||||
chunk.id = id
|
||||
yield chunk
|
||||
|
||||
return stream_with_chunk_ids()
|
||||
else:
|
||||
response.id = id
|
||||
return response
|
||||
return await OpenAIMixin.openai_chat_completion(self, **params)
|
||||
|
||||
|
||||
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
|
||||
|
|
|
@@ -4,15 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import OpenAIConfig
|
||||
|
||||
|
||||
class OpenAIProviderDataValidator(BaseModel):
|
||||
openai_api_key: str | None = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: OpenAIConfig, _deps):
|
||||
from .openai import OpenAIInferenceAdapter
|
||||
|
||||
|
|
|
@@ -1,60 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"gpt-4",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4o",
|
||||
"gpt-4o-2024-08-06",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4o-audio-preview",
|
||||
"chatgpt-4o-latest",
|
||||
"o1",
|
||||
"o1-mini",
|
||||
"o3-mini",
|
||||
"o4-mini",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModelInfo:
|
||||
"""Structured representation of embedding model information."""
|
||||
|
||||
embedding_dimension: int
|
||||
context_length: int
|
||||
|
||||
|
||||
EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
|
||||
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
|
||||
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
|
||||
}
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
MODEL_ENTRIES = (
|
||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id=model_id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": model_info.embedding_dimension,
|
||||
"context_length": model_info.context_length,
|
||||
},
|
||||
)
|
||||
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
|
||||
]
|
||||
+ SAFETY_MODELS_ENTRIES
|
||||
)
|
|
@@ -9,7 +9,6 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
|
|||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import OpenAIConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::openai")
|
||||
|
||||
|
@@ -38,10 +37,14 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
|
||||
"""
|
||||
|
||||
embedding_model_metadata = {
|
||||
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
|
||||
}
|
||||
|
||||
def __init__(self, config: OpenAIConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="openai",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="openai_api_key",
|
||||
|
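The embedding_model_metadata mapping in the hunk above takes over from the deleted models.py entries: a mixin that lists models from the provider can consult it to decide whether an id is an embedding model and which metadata to attach. A rough, hypothetical sketch of that lookup (not the mixin's actual code):

from typing import Any

def classify_model(
    model_id: str, embedding_model_metadata: dict[str, dict[str, Any]]
) -> tuple[str, dict[str, Any]]:
    # Embedding models carry dimension/context metadata; everything else is treated as an LLM.
    if model_id in embedding_model_metadata:
        return "embedding", embedding_model_metadata[model_id]
    return "llm", {}

# classify_model("text-embedding-3-small", {"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192}})
# -> ("embedding", {"embedding_dimension": 1536, "context_length": 8192})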
|
|
@@ -43,7 +43,7 @@ from .config import PassthroughImplConfig
|
|||
|
||||
class PassthroughInferenceAdapter(Inference):
|
||||
def __init__(self, config: PassthroughImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, [])
|
||||
ModelRegistryHelper.__init__(self)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
@@ -1,28 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"Meta-Llama-3.1-8B-Instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"Meta-Llama-3.3-70B-Instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"Llama-4-Maverick-17B-128E-Instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
|
@@ -4,19 +4,30 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import SambaNovaImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
"""
|
||||
SambaNova Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of LiteLLMOpenAIMixin.check_model_availability().
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the /v1/models to check if a model exists
|
||||
- LiteLLMOpenAIMixin.check_model_availability() checks the static registry within LiteLLM
|
||||
"""
|
||||
|
||||
def __init__(self, config: SambaNovaImplConfig):
|
||||
self.config = config
|
||||
self.environment_available_models = []
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="sambanova",
|
||||
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
|
||||
provider_data_api_key_field="sambanova_api_key",
|
||||
|
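The docstring in the hunk above hinges on Python's method resolution order: with class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin), a method defined on both mixins resolves to the one listed first. A generic illustration with toy classes (not from this repository):

class RemoteCatalog:
    def check_model_availability(self, model: str) -> bool:
        return True  # stand-in for querying the provider's /v1/models endpoint

class StaticCatalog:
    def check_model_availability(self, model: str) -> bool:
        return False  # stand-in for consulting a bundled static registry

class Adapter(RemoteCatalog, StaticCatalog):
    pass

# MRO is Adapter -> RemoteCatalog -> StaticCatalog, so the remote check wins.
assert Adapter().check_model_availability("any-model") is True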
@@ -24,3 +35,14 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
download_images=True, # SambaNova requires base64 image encoding
|
||||
json_schema_strict=False, # SambaNova doesn't support strict=True yet
|
||||
)
|
||||
|
||||
# Delegate get_api_key to LiteLLMOpenAIMixin, which resolves the key from per-request provider data
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
||||
:return: The SambaNova base URL
|
||||
"""
|
||||
return self.config.url
|
||||
|
|
|
@@ -8,6 +8,7 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from pydantic import SecretStr
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
@@ -33,6 +34,7 @@ from llama_stack.apis.inference (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
|
@@ -41,16 +43,15 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_model_input_info,
|
||||
completion_request_to_prompt_model_input_info,
|
||||
|
@@ -73,26 +74,49 @@ def build_hf_repo_model_entries():
|
|||
|
||||
|
||||
class _HfAdapter(
|
||||
OpenAIMixin,
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
client: AsyncInferenceClient
|
||||
url: str
|
||||
api_key: SecretStr
|
||||
|
||||
hf_client: AsyncInferenceClient
|
||||
max_tokens: int
|
||||
model_id: str
|
||||
|
||||
overwrite_completion_id = True # TGI always returns id=""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.huggingface_repo_to_llama_model_id = {
|
||||
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
|
||||
}
|
||||
|
||||
def get_api_key(self):
|
||||
return self.api_key.get_secret_value()
|
||||
|
||||
def get_base_url(self):
|
||||
return self.url
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
models = []
|
||||
async for model in self.client.models.list():
|
||||
models.append(
|
||||
Model(
|
||||
identifier=model.id,
|
||||
provider_resource_id=model.id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.register_helper.register_model(model)
|
||||
if model.provider_resource_id != self.model_id:
|
||||
raise ValueError(
|
||||
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
|
||||
|
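One detail in the hunk above worth spelling out: overwrite_completion_id = True exists because TGI reports an empty completion id on its OpenAI-compatible endpoint. A hedged sketch of the substitution the mixin presumably performs (the helper name is made up):

import uuid

def ensure_completion_id(server_id: str | None) -> str:
    # Fall back to a locally generated identifier when the server returns id="".
    return server_id if server_id else f"chatcmpl-{uuid.uuid4()}"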
@@ -176,7 +200,7 @@ class _HfAdapter(
|
|||
params = await self._get_params_for_completion(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.text_generation(**params)
|
||||
s = await self.hf_client.text_generation(**params)
|
||||
async for chunk in s:
|
||||
token_result = chunk.token
|
||||
finish_reason = None
|
||||
|
@@ -194,7 +218,7 @@ class _HfAdapter(
|
|||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params_for_completion(request)
|
||||
r = await self.client.text_generation(**params)
|
||||
r = await self.hf_client.text_generation(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r.details.finish_reason,
|
||||
|
@@ -241,7 +265,7 @@ class _HfAdapter(
|
|||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await self.client.text_generation(**params)
|
||||
r = await self.hf_client.text_generation(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r.details.finish_reason,
|
||||
|
@@ -256,7 +280,7 @@ class _HfAdapter(
|
|||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.text_generation(**params)
|
||||
s = await self.hf_client.text_generation(**params)
|
||||
async for chunk in s:
|
||||
token_result = chunk.token
|
||||
|
||||
|
@@ -308,18 +332,21 @@ class TGIAdapter(_HfAdapter):
|
|||
if not config.url:
|
||||
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
|
||||
endpoint_info = await self.hf_client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
self.url = f"{config.url.rstrip('/')}/v1"
|
||||
self.api_key = SecretStr("NO_KEY")
|
||||
|
||||
|
||||
class InferenceAPIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.hf_client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
|
||||
endpoint_info = await self.hf_client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
# TODO: how do we set url for this?
|
||||
|
||||
|
||||
class InferenceEndpointAdapter(_HfAdapter):
|
||||
|
@@ -331,6 +358,7 @@ class InferenceEndpointAdapter(_HfAdapter):
|
|||
endpoint.wait(timeout=60)
|
||||
|
||||
# Initialize the adapter
|
||||
self.client = endpoint.async_client
|
||||
self.hf_client = endpoint.async_client
|
||||
self.model_id = endpoint.repository
|
||||
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
|
||||
# TODO: how do we set url for this?
|
||||
|
|
|
@@ -1,77 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
]
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="togethercomputer/m2-bert-80M-32k-retrieval",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 32768,
|
||||
},
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
|
@@ -4,11 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from together import AsyncTogether
|
||||
from together.constants import BASE_URL
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
@@ -23,12 +23,7 @@ from llama_stack.apis.inference (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
|
@@ -38,18 +33,20 @@ from llama_stack.apis.inference (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
|
@@ -59,15 +56,29 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::together")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
embedding_model_metadata = {
|
||||
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
|
||||
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
|
||||
"BAAI/bge-base-en-v1.5": {"embedding_dimension": 768, "context_length": 512},
|
||||
"Alibaba-NLP/gte-modernbert-base": {"embedding_dimension": 768, "context_length": 8192},
|
||||
"intfloat/multilingual-e5-large-instruct": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||
self.config = config
|
||||
self.allowed_models = config.allowed_models
|
||||
self._model_cache: dict[str, Model] = {}
|
||||
|
||||
def get_api_key(self):
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
||||
def get_base_url(self):
|
||||
return BASE_URL
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
@@ -255,6 +266,38 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
embeddings = [item.embedding for item in r.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._model_cache = {}
|
||||
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
|
||||
for m in await self._get_client().models.list():
|
||||
if m.type == "embedding":
|
||||
if m.id not in self.embedding_model_metadata:
|
||||
logger.warning(f"Unknown embedding dimension for model {m.id}, skipping.")
|
||||
continue
|
||||
metadata = self.embedding_model_metadata[m.id]
|
||||
self._model_cache[m.id] = Model(
|
||||
provider_id=self.__provider_id__,
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata=metadata,
|
||||
)
|
||||
else:
|
||||
self._model_cache[m.id] = Model(
|
||||
provider_id=self.__provider_id__,
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
return self._model_cache.values()
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return True
|
||||
|
||||
async def check_model_availability(self, model):
|
||||
return model in self._model_cache
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
@@ -263,125 +306,36 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
"""
|
||||
Together's OpenAI-compatible embeddings endpoint is not compatible with
|
||||
the standard OpenAI embeddings endpoint.
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
The endpoint -
|
||||
- not all models return usage information
|
||||
- does not support user param, returns 400 Unrecognized request arguments supplied: user
|
||||
- does not support dimensions param, returns 400 Unrecognized request arguments supplied: dimensions
|
||||
"""
|
||||
# Together support ticket #13332 -> will not fix
|
||||
if user is not None:
|
||||
raise ValueError("Together's embeddings endpoint does not support user param.")
|
||||
# Together support ticket #13333 -> escalated
|
||||
if dimensions is not None:
|
||||
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
|
||||
|
||||
response = await self.client.embeddings.create(
|
||||
model=await self._get_provider_model_id(model),
|
||||
input=input,
|
||||
encoding_format=encoding_format,
|
||||
)
|
||||
return await self._get_openai_client().completions.create(**params) # type: ignore
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if params.get("stream", False):
|
||||
return self._stream_openai_chat_completion(params)
|
||||
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||
response.model = model # return to the user the same model id they provided, avoid exposing the provider model id
|
||||
|
||||
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
|
||||
# together.ai sometimes adds usage data to the stream, even if include_usage is False
|
||||
# This causes an unexpected final chunk with empty choices array to be sent
|
||||
# to clients that may not handle it gracefully.
|
||||
include_usage = False
|
||||
if params.get("stream_options", None):
|
||||
include_usage = params["stream_options"].get("include_usage", False)
|
||||
stream = await self._get_openai_client().chat.completions.create(**params)
|
||||
# Together support ticket #13330 -> escalated
|
||||
# - togethercomputer/m2-bert-80M-32k-retrieval *does not* return usage information
|
||||
if not hasattr(response, "usage") or response.usage is None:
|
||||
logger.warning(
|
||||
f"Together's embedding endpoint for {model} did not return usage information, substituting -1s."
|
||||
)
|
||||
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
|
||||
|
||||
seen_finish_reason = False
|
||||
async for chunk in stream:
|
||||
# Final usage chunk with no choices that the user didn't request, so discard
|
||||
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
|
||||
break
|
||||
yield chunk
|
||||
for choice in chunk.choices:
|
||||
if choice.finish_reason:
|
||||
seen_finish_reason = True
|
||||
break
|
||||
return response
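The streaming helper above exists because together.ai can append a usage chunk even when include_usage was not requested, which shows up as a final chunk with an empty choices list. The same filtering idea, restated as a standalone generator (a sketch consistent with the code above, not a drop-in replacement):

from collections.abc import AsyncIterator

async def drop_unrequested_usage_chunk(stream: AsyncIterator, include_usage: bool) -> AsyncIterator:
    seen_finish_reason = False
    async for chunk in stream:
        # Once a finish_reason has been seen, a trailing chunk with no choices is the
        # unrequested usage chunk; stop instead of forwarding it.
        if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
            break
        yield chunk
        if any(choice.finish_reason for choice in chunk.choices):
            seen_finish_reason = True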
|
||||
|
|
|
@@ -1,20 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
|
||||
# Vertex AI model IDs with vertex_ai/ prefix as required by litellm
|
||||
LLM_MODEL_IDS = [
|
||||
"vertex_ai/gemini-2.0-flash",
|
||||
"vertex_ai/gemini-2.5-flash",
|
||||
"vertex_ai/gemini-2.5-pro",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = list[ProviderModelEntry]()
|
||||
|
||||
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + SAFETY_MODELS_ENTRIES
|
|
@@ -6,20 +6,22 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
import google.auth.transport.requests
|
||||
from google.auth import default
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
||||
LiteLLMOpenAIMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import VertexAIConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: VertexAIConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="vertex_ai",
|
||||
api_key_from_config=None, # Vertex AI uses ADC, not API keys
|
||||
provider_data_api_key_field="vertex_project", # Use project for validation
|
||||
|
@@ -27,9 +29,30 @@ class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
self.config = config
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
# Vertex AI doesn't use API keys, it uses Application Default Credentials
|
||||
# Return empty string to let litellm handle authentication via ADC
|
||||
return ""
|
||||
"""
|
||||
Get an access token for Vertex AI using Application Default Credentials.
|
||||
|
||||
Vertex AI uses ADC instead of API keys. This method obtains an access token
|
||||
from the default credentials and returns it for use with the OpenAI-compatible client.
|
||||
"""
|
||||
try:
|
||||
# Get default credentials - will read from GOOGLE_APPLICATION_CREDENTIALS
|
||||
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
|
||||
credentials.refresh(google.auth.transport.requests.Request())
|
||||
return str(credentials.token)
|
||||
except Exception:
|
||||
# If we can't get credentials, return empty string to let LiteLLM handle it
|
||||
# This allows the LiteLLM mixin to work with ADC directly
|
||||
return ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the Vertex AI OpenAI-compatible API base URL.
|
||||
|
||||
Returns the Vertex AI OpenAI-compatible endpoint URL.
|
||||
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
|
||||
"""
|
||||
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
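For concreteness, this is the URL get_base_url() produces for a hypothetical project and location (both values below are made up):

location = "us-central1"    # hypothetical
project = "my-gcp-project"  # hypothetical
base_url = (
    f"https://{location}-aiplatform.googleapis.com/v1"
    f"/projects/{project}/locations/{location}/endpoints/openapi"
)
# -> https://us-central1-aiplatform.googleapis.com/v1/projects/my-gcp-project/locations/us-central1/endpoints/openapi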
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
|
||||
# Get base parameters from parent
|
||||
|
|
|
@@ -4,9 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
|
||||
class VLLMProviderDataValidator(BaseModel):
|
||||
vllm_api_token: str | None = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
|
||||
from .vllm import VLLMInferenceAdapter
|
||||
|
||||
|
|
|
@@ -6,6 +6,7 @@
|
|||
import json
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
from openai import APIConnectionError, AsyncOpenAI
|
||||
|
@@ -38,13 +39,6 @@ from llama_stack.apis.inference (
|
|||
LogProbConfig,
|
||||
Message,
|
||||
ModelStore,
|
||||
OpenAIChatCompletion,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@@ -62,6 +56,7 @@ from llama_stack.providers.datatypes (
|
|||
HealthStatus,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
|
@@ -69,13 +64,14 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
UnparseableToolCall,
|
||||
convert_message_to_openai_dict,
|
||||
convert_openai_chat_completion_stream,
|
||||
convert_tool_call,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
|
@@ -288,15 +284,30 @@ async def _process_vllm_chat_completion_stream_response(
|
|||
yield c
|
||||
|
||||
|
||||
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
model_store: ModelStore | None = None
|
||||
|
||||
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=build_hf_repo_model_entries(),
|
||||
litellm_provider_name="vllm",
|
||||
api_key_from_config=config.api_token,
|
||||
provider_data_api_key_field="vllm_api_token",
|
||||
openai_compat_api_base=config.url,
|
||||
)
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.config = config
|
||||
self.client = None
|
||||
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""Get the base URL from config."""
|
||||
if not self.config.url:
|
||||
raise ValueError("No base URL configured")
|
||||
return self.config.url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if not self.config.url:
|
||||
|
@@ -305,11 +316,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
# Strictly respecting the refresh_models directive
|
||||
return self.config.refresh_models
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None # mypy
|
||||
models = []
|
||||
async for m in self.client.models.list():
|
||||
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
|
||||
|
@@ -335,14 +345,19 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
Performs a health check by verifying connectivity to the remote vLLM server.
|
||||
This method is used by the Provider API to verify
|
||||
that the service is running correctly.
|
||||
Uses the unauthenticated /health endpoint.
|
||||
Returns:
|
||||
|
||||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
client = self._create_client() if self.client is None else self.client
|
||||
_ = [m async for m in client.models.list()] # Ensure the client is initialized
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
base_url = self.get_base_url()
|
||||
health_url = urljoin(base_url, "health")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(health_url)
|
||||
response.raise_for_status()
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
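A small but easy-to-miss point in the health check above: urljoin replaces the last path segment of a base URL that has no trailing slash, so a typical vLLM base URL ending in /v1 resolves to the server-root /health endpoint:

from urllib.parse import urljoin

assert urljoin("http://localhost:8000/v1", "health") == "http://localhost:8000/health"
assert urljoin("http://localhost:8000/v1/", "health") == "http://localhost:8000/v1/health"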
|
||||
|
||||
|
@@ -351,21 +366,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
raise ValueError("Model store not set")
|
||||
return await self.model_store.get_model(model_id)
|
||||
|
||||
def _lazy_initialize_client(self):
|
||||
if self.client is not None:
|
||||
return
|
||||
def get_extra_client_params(self):
|
||||
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
|
||||
|
||||
log.info(f"Initializing vLLM client with base_url={self.config.url}")
|
||||
self.client = self._create_client()
|
||||
|
||||
def _create_client(self):
|
||||
return AsyncOpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token,
|
||||
http_client=httpx.AsyncClient(verify=self.config.tls_verify),
|
||||
)
|
||||
|
||||
async def completion(
|
||||
async def completion( # type: ignore[override] # Return type more specific than base class which is allows for both streaming and non-streaming responses.
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
|
@@ -374,7 +378,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
self._lazy_initialize_client()
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
|
@@ -406,7 +409,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
self._lazy_initialize_client()
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
|
@@ -429,13 +431,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
tool_config=tool_config,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, self.client)
|
||||
return self._stream_chat_completion_with_client(request, self.client)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request, self.client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: AsyncOpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
assert self.client is not None
|
||||
params = await self._get_params(request)
|
||||
r = await client.chat.completions.create(**params)
|
||||
choice = r.choices[0]
|
||||
|
@@ -449,9 +452,24 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
return result
|
||||
|
||||
async def _stream_chat_completion(
|
||||
async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
# This method is called from LiteLLMOpenAIMixin.chat_completion
|
||||
# The response parameter contains the litellm response
|
||||
# We need to convert it to our format
|
||||
async def _stream_generator():
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
|
||||
async for chunk in convert_openai_chat_completion_stream(
|
||||
_stream_generator(), enable_incremental_tool_calls=True
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def _stream_chat_completion_with_client(
|
||||
self, request: ChatCompletionRequest, client: AsyncOpenAI
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
"""Helper method for streaming with explicit client parameter."""
|
||||
assert self.client is not None
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await client.chat.completions.create(**params)
|
||||
|
@@ -463,7 +481,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
yield chunk
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
assert self.client is not None
|
||||
if self.client is None:
|
||||
raise RuntimeError("Client is not initialized")
|
||||
params = await self._get_params(request)
|
||||
r = await self.client.completions.create(**params)
|
||||
return process_completion_response(r)
|
||||
|
@@ -471,7 +490,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def _stream_completion(
|
||||
self, request: CompletionRequest
|
||||
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
assert self.client is not None
|
||||
if self.client is None:
|
||||
raise RuntimeError("Client is not initialized")
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await self.client.completions.create(**params)
|
||||
|
@@ -479,16 +499,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
# register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
|
||||
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
|
||||
# Changing this may lead to unpredictable behavior.
|
||||
client = self._create_client() if self.client is None else self.client
|
||||
try:
|
||||
model = await self.register_helper.register_model(model)
|
||||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
try:
|
||||
res = await client.models.list()
|
||||
res = self.client.models.list()
|
||||
except APIConnectionError as e:
|
||||
raise ValueError(
|
||||
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
|
||||
|
@@ -543,8 +559,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None
|
||||
model = await self._get_model(model_id)
|
||||
|
||||
kwargs = {}
|
||||
|
@@ -560,154 +574,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None
|
||||
model_obj = await self._get_model(model)
|
||||
assert model_obj.model_type == ModelType.embedding
|
||||
|
||||
# Convert input to list if it's a string
|
||||
input_list = [input] if isinstance(input, str) else input
|
||||
|
||||
# Call vLLM embeddings endpoint with encoding_format
|
||||
response = await self.client.embeddings.create(
|
||||
model=model_obj.provider_resource_id,
|
||||
input=input_list,
|
||||
dimensions=dimensions,
|
||||
encoding_format=encoding_format,
|
||||
)
|
||||
|
||||
# Convert response to OpenAI format
|
||||
data = [
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
for i, embedding_data in enumerate(response.data)
|
||||
]
|
||||
|
||||
# Not returning actual token usage since vLLM doesn't provide it
|
||||
usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=model_obj.provider_resource_id,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
self._lazy_initialize_client()
|
||||
model_obj = await self._get_model(model)
|
||||
|
||||
extra_body: dict[str, Any] = {}
|
||||
if prompt_logprobs is not None and prompt_logprobs >= 0:
|
||||
extra_body["prompt_logprobs"] = prompt_logprobs
|
||||
if guided_choice:
|
||||
extra_body["guided_choice"] = guided_choice
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
return await self.client.completions.create(**params) # type: ignore
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
self._lazy_initialize_client()
|
||||
model_obj = await self._get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self.client.chat.completions.create(**params) # type: ignore
|
||||
|
|
|
@@ -26,11 +26,11 @@ class WatsonXConfig(BaseModel):
|
|||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
|
||||
description="The watsonx API key, only needed of using the hosted service",
|
||||
description="The watsonx API key",
|
||||
)
|
||||
project_id: str | None = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
|
||||
description="The Project ID key, only needed of using the hosted service",
|
||||
description="The Project ID key",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=60,
|
||||
|
|
|
@@ -7,8 +7,8 @@
|
|||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from ibm_watson_machine_learning.foundation_models import Model
|
||||
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
|
||||
from ibm_watsonx_ai.foundation_models import Model
|
||||
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
|
@@ -38,6 +38,7 @@ from llama_stack.apis.inference (
|
|||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
|
@@ -57,14 +58,29 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from . import WatsonXConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::watsonx")
|
||||
|
||||
|
||||
# Note on structured output
|
||||
# WatsonX returns responses with a JSON object embedded in a string.
|
||||
# Examples:
|
||||
|
||||
# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
|
||||
# "first_name": "Michael",\n "last_name": "Jordan",\n'...)
|
||||
# Not even a valid JSON, but we can still extract the JSON from the content
|
||||
|
||||
# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
|
||||
# "year_born": "1963", "year_retired": "2003"\\}}$')
|
||||
# Find the start of the boxed content
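A hedged sketch of the extraction idea these comments describe: pull the first brace-delimited span out of the wrapped content and try to parse it. This is a simplified illustration, not the adapter's actual parsing code, and it ignores the escaped-brace case shown in the second example.

import json
import re

def extract_embedded_json(content: str) -> dict | None:
    # Grab everything between the first '{' and the last '}' and attempt to parse it.
    match = re.search(r"\{.*\}", content, flags=re.DOTALL)
    if match is None:
        return None
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return None

# extract_embedded_json('```json\n{"first_name": "Michael", "last_name": "Jordan"}\n```')
# -> {"first_name": "Michael", "last_name": "Jordan"}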
|
||||
|
||||
|
||||
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
||||
def __init__(self, config: WatsonXConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
|
||||
print(f"Initializing watsonx InferenceAdapter({config.url})...")
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
||||
logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
|
||||
self._config = config
|
||||
self._openai_client: AsyncOpenAI | None = None
|
||||
|
||||
self._project_id = self._config.project_id
|
||||
|
||||
|
|
|
@@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import heapq
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
|
@@ -23,6 +24,9 @@ from llama_stack.apis.vector_io (
|
|||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
|
@@ -31,6 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
|
||||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
||||
|
@@ -72,25 +77,63 @@ def load_models(cur, cls):
|
|||
|
||||
|
||||
class PGVectorIndex(EmbeddingIndex):
|
||||
def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None):
|
||||
self.conn = conn
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
# Sanitize the table name by replacing hyphens with underscores
|
||||
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens
|
||||
# when created with patterns like "test-vector-db-{uuid4()}"
|
||||
sanitized_identifier = vector_db.identifier.replace("-", "_")
|
||||
self.table_name = f"vector_store_{sanitized_identifier}"
|
||||
self.kvstore = kvstore
|
||||
# reference: https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
|
||||
PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION: dict[str, str] = {
|
||||
"L2": "<->",
|
||||
"L1": "<+>",
|
||||
"COSINE": "<=>",
|
||||
"INNER_PRODUCT": "<#>",
|
||||
"HAMMING": "<~>",
|
||||
"JACCARD": "<%>",
|
||||
}
|
||||
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id TEXT PRIMARY KEY,
|
||||
document JSONB,
|
||||
embedding vector({dimension})
|
||||
def __init__(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
dimension: int,
|
||||
conn: psycopg2.extensions.connection,
|
||||
kvstore: KVStore | None = None,
|
||||
distance_metric: str = "COSINE",
|
||||
):
|
||||
self.vector_db = vector_db
|
||||
self.dimension = dimension
|
||||
self.conn = conn
|
||||
self.kvstore = kvstore
|
||||
self.check_distance_metric_availability(distance_metric)
|
||||
self.distance_metric = distance_metric
|
||||
self.table_name = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
# Sanitize the table name by replacing hyphens with underscores
|
||||
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens
|
||||
# when created with patterns like "test-vector-db-{uuid4()}"
|
||||
sanitized_identifier = sanitize_collection_name(self.vector_db.identifier)
|
||||
self.table_name = f"vs_{sanitized_identifier}"
|
||||
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id TEXT PRIMARY KEY,
|
||||
document JSONB,
|
||||
embedding vector({self.dimension}),
|
||||
content_text TEXT,
|
||||
tokenized_content TSVECTOR
|
||||
)
|
||||
"""
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Create GIN index for full-text search performance
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE INDEX IF NOT EXISTS {self.table_name}_content_gin_idx
|
||||
ON {self.table_name} USING GIN(tokenized_content)
|
||||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}")
|
||||
raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") from e
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
|
@ -99,29 +142,49 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
values = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
content_text = interleaved_content_as_str(chunk.content)
|
||||
values.append(
|
||||
(
|
||||
f"{chunk.chunk_id}",
|
||||
Json(chunk.model_dump()),
|
||||
embeddings[i].tolist(),
|
||||
content_text,
|
||||
content_text,  # passed twice: once for the content_text column, once for the to_tsvector() call that fills tokenized_content
|
||||
)
|
||||
)
|
||||
|
||||
query = sql.SQL(
|
||||
f"""
|
||||
INSERT INTO {self.table_name} (id, document, embedding)
|
||||
INSERT INTO {self.table_name} (id, document, embedding, content_text, tokenized_content)
|
||||
VALUES %s
|
||||
ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
embedding = EXCLUDED.embedding,
|
||||
document = EXCLUDED.document,
|
||||
content_text = EXCLUDED.content_text,
|
||||
tokenized_content = EXCLUDED.tokenized_content
|
||||
"""
|
||||
)
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
execute_values(cur, query, values, template="(%s, %s, %s::vector)")
|
||||
execute_values(cur, query, values, template="(%s, %s, %s::vector, %s, to_tsvector('english', %s))")
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs vector similarity search using PostgreSQL's search function. Default distance metric is COSINE.
|
||||
|
||||
Args:
|
||||
embedding: The query embedding vector
|
||||
k: Number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
|
||||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
pgvector_search_function = self.get_pgvector_search_function()
|
||||
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT document, embedding <-> %s::vector AS distance
|
||||
SELECT document, embedding {pgvector_search_function} %s::vector AS distance
|
||||
FROM {self.table_name}
|
||||
ORDER BY distance
|
||||
LIMIT %s
|
||||
|
@ -147,7 +210,40 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in PGVector")
|
||||
"""
|
||||
Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
|
||||
|
||||
Args:
|
||||
query_string: The text query for keyword search
|
||||
k: Number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
|
||||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
# Use plainto_tsquery to handle user input safely and ts_rank for relevance scoring
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE tokenized_content @@ plainto_tsquery('english', %s)
|
||||
ORDER BY score DESC
|
||||
LIMIT %s
|
||||
""",
|
||||
(query_string, query_string, k),
|
||||
)
|
||||
results = cur.fetchall()
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc, score in results:
|
||||
if score < score_threshold:
|
||||
continue
|
||||
chunks.append(Chunk(**doc))
|
||||
scores.append(float(score))
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
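For orientation, here is a minimal sketch of how the full-text query above can be issued with psycopg2-style parameter binding; the table name vs_demo and the query text are placeholder values, not part of this patch.

# Hedged sketch, not the provider's code: issuing the ts_rank query with
# psycopg2 placeholders. "vs_demo" is an assumed table name for illustration.
keyword_sql = """
    SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score
    FROM vs_demo
    WHERE tokenized_content @@ plainto_tsquery('english', %s)
    ORDER BY score DESC
    LIMIT %s
"""
params = ("retrieval augmented generation", "retrieval augmented generation", 5)
# With a live pgvector-enabled connection:
#   with conn.cursor() as cur:
#       cur.execute(keyword_sql, params)
#       rows = cur.fetchall()
print(keyword_sql.strip())
print(params)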
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
|
@ -158,7 +254,59 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in PGVector")
|
||||
"""
|
||||
Hybrid search combining vector similarity and keyword search using configurable reranking.
|
||||
|
||||
Args:
|
||||
embedding: The query embedding vector
|
||||
query_string: The text query for keyword search
|
||||
k: Number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
reranker_type: Type of reranker to use ("rrf" or "weighted")
|
||||
reranker_params: Parameters for the reranker
|
||||
|
||||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
if reranker_params is None:
|
||||
reranker_params = {}
|
||||
|
||||
# Get results from both search methods
|
||||
vector_response = await self.query_vector(embedding, k, score_threshold)
|
||||
keyword_response = await self.query_keyword(query_string, k, score_threshold)
|
||||
|
||||
# Convert responses to score dictionaries using chunk_id
|
||||
vector_scores = {
|
||||
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
|
||||
}
|
||||
keyword_scores = {
|
||||
chunk.chunk_id: score
|
||||
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
|
||||
}
|
||||
|
||||
# Combine scores using the reranking utility
|
||||
combined_scores = WeightedInMemoryAggregator.combine_search_results(
|
||||
vector_scores, keyword_scores, reranker_type, reranker_params
|
||||
)
|
||||
|
||||
# Efficient top-k selection because it only tracks the k best candidates it's seen so far
|
||||
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
|
||||
|
||||
# Filter by score threshold
|
||||
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
|
||||
|
||||
# Create a map of chunk_id to chunk for both responses
|
||||
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
|
||||
|
||||
# Use the map to look up chunks by their IDs
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc_id, score in filtered_items:
|
||||
if doc_id in chunk_map:
|
||||
chunks.append(chunk_map[doc_id])
|
||||
scores.append(score)
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
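A small self-contained sketch of the top-k selection step used above: heapq.nlargest keeps only the k best combined scores before the threshold filter is applied. The scores below are made-up numbers.

# Hedged sketch of the hybrid top-k step: combine per-document scores, take the
# k best with heapq.nlargest, then drop anything under the score threshold.
import heapq

combined_scores = {"doc-a": 0.91, "doc-b": 0.42, "doc-c": 0.77, "doc-d": 0.12}
k, score_threshold = 3, 0.5

top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
print(filtered_items)  # [('doc-a', 0.91), ('doc-c', 0.77)]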
|
||||
|
||||
async def delete(self):
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
|
@ -170,6 +318,25 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
|
||||
|
||||
def get_pgvector_search_function(self) -> str:
|
||||
return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]
|
||||
|
||||
def check_distance_metric_availability(self, distance_metric: str) -> None:
|
||||
"""Check if the distance metric is supported by PGVector.
|
||||
|
||||
Args:
|
||||
distance_metric: The distance metric to check
|
||||
|
||||
Raises:
|
||||
ValueError: If the distance metric is not supported
|
||||
"""
|
||||
if distance_metric not in self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION:
|
||||
supported_metrics = list(self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION.keys())
|
||||
raise ValueError(
|
||||
f"Distance metric '{distance_metric}' is not supported by PGVector. "
|
||||
f"Supported metrics are: {', '.join(supported_metrics)}"
|
||||
)
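To illustrate the mapping above, a minimal sketch of how a configured distance metric selects the pgvector operator inside the query template; vs_demo is an assumed table name and only a subset of the metrics is shown.

# Hedged sketch: how a distance metric turns into a pgvector operator in the
# vector-search query. Not the provider's code, just the same convention.
PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION = {
    "L2": "<->",
    "COSINE": "<=>",
    "INNER_PRODUCT": "<#>",
}

def build_vector_query(table_name: str, distance_metric: str = "COSINE") -> str:
    op = PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[distance_metric]
    return (
        f"SELECT document, embedding {op} %s::vector AS distance "
        f"FROM {table_name} ORDER BY distance LIMIT %s"
    )

print(build_vector_query("vs_demo"))
# SELECT document, embedding <=> %s::vector AS distance FROM vs_demo ORDER BY distance LIMIT %s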
|
||||
|
||||
|
||||
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
|
@ -185,8 +352,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
self.files_api = files_api
|
||||
self.kvstore: KVStore | None = None
|
||||
self.vector_db_store = None
|
||||
self.openai_vector_store: dict[str, dict[str, Any]] = {}
|
||||
self.metadatadata_collection_name = "openai_vector_stores_metadata"
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
|
@ -233,9 +400,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
|
||||
|
||||
# Create and cache the PGVector index table for the vector DB
|
||||
pgvector_index = PGVectorIndex(
|
||||
vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
|
||||
)
|
||||
await pgvector_index.initialize()
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
index=PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore),
|
||||
index=pgvector_index,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
@ -272,8 +443,15 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
|
||||
await index.initialize()
|
||||
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
|
@ -49,10 +50,13 @@ def convert_id(_id: str) -> str:
|
|||
Converts any string into a UUID string based on a seed.
|
||||
|
||||
Qdrant accepts UUID strings and unsigned integers as point ID.
|
||||
We use a seed to convert each string into a UUID string deterministically.
|
||||
We use a SHA-256 hash to convert each string into a UUID string deterministically.
|
||||
This allows us to overwrite the same point with the original ID.
|
||||
"""
|
||||
return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))
|
||||
hash_input = f"qdrant_id:{_id}".encode()
|
||||
sha256_hash = hashlib.sha256(hash_input).hexdigest()
|
||||
# Use the first 32 characters to create a valid UUID
|
||||
return str(uuid.UUID(sha256_hash[:32]))
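A hedged usage sketch of the deterministic conversion above: because the hash input is stable, the same source id always maps to the same UUID, so re-upserting a chunk overwrites the same Qdrant point. The chunk id below is made up.

# Minimal sketch of the SHA-256 based id conversion described in the diff.
import hashlib
import uuid

def convert_id(_id: str) -> str:
    hash_input = f"qdrant_id:{_id}".encode()
    sha256_hash = hashlib.sha256(hash_input).hexdigest()
    return str(uuid.UUID(sha256_hash[:32]))

assert convert_id("chunk-42") == convert_id("chunk-42")  # stable across calls
print(convert_id("chunk-42"))  # same UUID every run for the same input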
|
||||
|
||||
|
||||
class QdrantIndex(EmbeddingIndex):
|
||||
|
|
|
@ -4,53 +4,55 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BedrockBaseConfig(BaseModel):
|
||||
aws_access_key_id: str | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||
)
|
||||
aws_secret_access_key: str | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
||||
)
|
||||
aws_session_token: str | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"),
|
||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
||||
)
|
||||
region_name: str | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: os.getenv("AWS_DEFAULT_REGION"),
|
||||
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
||||
"Default use environment variable: AWS_DEFAULT_REGION",
|
||||
)
|
||||
profile_name: str | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: os.getenv("AWS_PROFILE"),
|
||||
description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE",
|
||||
)
|
||||
total_max_attempts: int | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: int(val) if (val := os.getenv("AWS_MAX_ATTEMPTS")) else None,
|
||||
description="An integer representing the maximum number of attempts that will be made for a single request, "
|
||||
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
|
||||
)
|
||||
retry_mode: str | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: os.getenv("AWS_RETRY_MODE"),
|
||||
description="A string representing the type of retries Boto3 will perform."
|
||||
"Default use environment variable: AWS_RETRY_MODE",
|
||||
)
|
||||
connect_timeout: float | None = Field(
|
||||
default=60,
|
||||
default_factory=lambda: float(os.getenv("AWS_CONNECT_TIMEOUT", "60")),
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
|
||||
"The default is 60 seconds.",
|
||||
)
|
||||
read_timeout: float | None = Field(
|
||||
default=60,
|
||||
default_factory=lambda: float(os.getenv("AWS_READ_TIMEOUT", "60")),
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
||||
"The default is 60 seconds.",
|
||||
)
|
||||
session_ttl: int | None = Field(
|
||||
default=3600,
|
||||
default_factory=lambda: int(os.getenv("AWS_SESSION_TTL", "3600")),
|
||||
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
|
||||
)
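A minimal sketch of the default_factory pattern adopted above, assuming only pydantic: the environment is consulted when the config object is instantiated rather than when the module is imported, so each run or test can set its own variables.

# Hedged sketch; ExampleAWSConfig is an invented stand-in for the config class.
import os
from pydantic import BaseModel, Field

class ExampleAWSConfig(BaseModel):
    region_name: str | None = Field(
        default_factory=lambda: os.getenv("AWS_DEFAULT_REGION"),
        description="Falls back to the AWS_DEFAULT_REGION environment variable.",
    )
    connect_timeout: float | None = Field(
        default_factory=lambda: float(os.getenv("AWS_CONNECT_TIMEOUT", "60")),
    )

os.environ["AWS_DEFAULT_REGION"] = "us-west-2"
print(ExampleAWSConfig().region_name, ExampleAWSConfig().connect_timeout)  # us-west-2 60.0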
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import struct
|
||||
from typing import TYPE_CHECKING
|
||||
|
@ -43,9 +44,11 @@ class SentenceTransformerEmbeddingMixin:
|
|||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
embeddings = embedding_model.encode(
|
||||
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False
|
||||
embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
embeddings = await asyncio.to_thread(
|
||||
embedding_model.encode,
|
||||
[interleaved_content_as_str(content) for content in contents],
|
||||
show_progress_bar=False,
|
||||
)
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
|
@ -64,8 +67,8 @@ class SentenceTransformerEmbeddingMixin:
|
|||
|
||||
# Get the model and generate embeddings
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
|
||||
embeddings = embedding_model.encode(input_list, show_progress_bar=False)
|
||||
embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
|
||||
embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
|
||||
|
||||
# Convert embeddings to the requested format
|
||||
data = []
|
||||
|
@ -93,7 +96,7 @@ class SentenceTransformerEmbeddingMixin:
|
|||
usage=usage,
|
||||
)
|
||||
|
||||
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
|
||||
async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
|
||||
global EMBEDDING_MODELS
|
||||
|
||||
loaded_model = EMBEDDING_MODELS.get(model)
|
||||
|
@ -101,8 +104,12 @@ class SentenceTransformerEmbeddingMixin:
|
|||
return loaded_model
|
||||
|
||||
log.info(f"Loading sentence transformer for {model}...")
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
loaded_model = SentenceTransformer(model)
|
||||
def _load_model():
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
return SentenceTransformer(model)
|
||||
|
||||
loaded_model = await asyncio.to_thread(_load_model)
|
||||
EMBEDDING_MODELS[model] = loaded_model
|
||||
return loaded_model
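A minimal sketch of the asyncio.to_thread pattern used above: the blocking encode call runs in a worker thread so the event loop stays responsive. blocking_encode is an invented stand-in for SentenceTransformer.encode.

# Hedged sketch: offload a blocking, CPU-bound call without blocking the loop.
import asyncio
import time

def blocking_encode(texts: list[str]) -> list[int]:
    time.sleep(0.1)  # stand-in for SentenceTransformer.encode()
    return [len(t) for t in texts]

async def main() -> None:
    embeddings = await asyncio.to_thread(blocking_encode, ["hello", "world"])
    print(embeddings)  # [5, 5]

asyncio.run(main())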
|
||||
|
|
|
@ -3,6 +3,11 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ListOpenAIChatCompletionResponse,
|
||||
OpenAIChatCompletion,
|
||||
|
@ -10,27 +15,46 @@ from llama_stack.apis.inference import (
|
|||
OpenAIMessageParam,
|
||||
Order,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ..sqlstore.api import ColumnDefinition, ColumnType
|
||||
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
|
||||
from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="inference_store")
|
||||
|
||||
|
||||
class InferenceStore:
|
||||
def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
|
||||
if not sql_store_config:
|
||||
sql_store_config = SqliteSqlStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
def __init__(
|
||||
self,
|
||||
config: InferenceStoreConfig | SqlStoreConfig,
|
||||
policy: list[AccessRule],
|
||||
):
|
||||
# Handle backward compatibility
|
||||
if not isinstance(config, InferenceStoreConfig):
|
||||
# Legacy: SqlStoreConfig passed directly as config
|
||||
config = InferenceStoreConfig(
|
||||
sql_store_config=config,
|
||||
)
|
||||
self.sql_store_config = sql_store_config
|
||||
|
||||
self.config = config
|
||||
self.sql_store_config = config.sql_store_config
|
||||
self.sql_store = None
|
||||
self.policy = policy
|
||||
|
||||
# Disable write queue for SQLite to avoid concurrency issues
|
||||
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
|
||||
|
||||
# Async write queue and worker control
|
||||
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
|
||||
self._worker_tasks: list[asyncio.Task[Any]] = []
|
||||
self._max_write_queue_size: int = config.max_write_queue_size
|
||||
self._num_writers: int = max(1, config.num_writers)
|
||||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
|
||||
await self.sql_store.create_table(
|
||||
"chat_completions",
|
||||
{
|
||||
|
@ -42,23 +66,109 @@ class InferenceStore:
|
|||
},
|
||||
)
|
||||
|
||||
if self.enable_write_queue:
|
||||
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
|
||||
for _ in range(self._num_writers):
|
||||
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
|
||||
else:
|
||||
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._worker_tasks:
|
||||
return
|
||||
if self._queue is not None:
|
||||
await self._queue.join()
|
||||
for t in self._worker_tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
for t in self._worker_tasks:
|
||||
try:
|
||||
await t
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._worker_tasks.clear()
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Wait for all queued writes to complete. Useful for testing."""
|
||||
if self.enable_write_queue and self._queue is not None:
|
||||
await self._queue.join()
|
||||
|
||||
async def store_chat_completion(
|
||||
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
|
||||
) -> None:
|
||||
if not self.sql_store:
|
||||
if self.enable_write_queue:
|
||||
if self._queue is None:
|
||||
raise ValueError("Inference store is not initialized")
|
||||
try:
|
||||
self._queue.put_nowait((chat_completion, input_messages))
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
|
||||
)
|
||||
await self._queue.put((chat_completion, input_messages))
|
||||
else:
|
||||
await self._write_chat_completion(chat_completion, input_messages)
|
||||
|
||||
async def _worker_loop(self) -> None:
|
||||
assert self._queue is not None
|
||||
while True:
|
||||
try:
|
||||
item = await self._queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
chat_completion, input_messages = item
|
||||
try:
|
||||
await self._write_chat_completion(chat_completion, input_messages)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Error writing chat completion: {e}")
|
||||
finally:
|
||||
self._queue.task_done()
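A minimal, self-contained sketch of the write-queue pattern above, with asyncio.sleep standing in for the database insert: producers enqueue records without awaiting the write, worker tasks drain the queue, and join() plays the role of flush().

# Hedged sketch, not the InferenceStore itself.
import asyncio

async def worker(queue: asyncio.Queue) -> None:
    while True:
        item = await queue.get()
        try:
            await asyncio.sleep(0)  # stand-in for the actual database insert
            print("wrote", item)
        finally:
            queue.task_done()

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue(maxsize=100)
    workers = [asyncio.create_task(worker(queue)) for _ in range(2)]
    for i in range(5):
        queue.put_nowait({"id": f"chatcmpl-{i}"})  # non-blocking enqueue
    await queue.join()  # same role as flush(): wait for queued writes to land
    for t in workers:
        t.cancel()
    await asyncio.gather(*workers, return_exceptions=True)

asyncio.run(main())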
|
||||
|
||||
async def _write_chat_completion(
|
||||
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
|
||||
) -> None:
|
||||
if self.sql_store is None:
|
||||
raise ValueError("Inference store is not initialized")
|
||||
|
||||
data = chat_completion.model_dump()
|
||||
record_data = {
|
||||
"id": data["id"],
|
||||
"created": data["created"],
|
||||
"model": data["model"],
|
||||
"choices": data["choices"],
|
||||
"input_messages": [message.model_dump() for message in input_messages],
|
||||
}
|
||||
|
||||
await self.sql_store.insert(
|
||||
table="chat_completions",
|
||||
data={
|
||||
"id": data["id"],
|
||||
"created": data["created"],
|
||||
"model": data["model"],
|
||||
"choices": data["choices"],
|
||||
"input_messages": [message.model_dump() for message in input_messages],
|
||||
},
|
||||
try:
|
||||
await self.sql_store.insert(
|
||||
table="chat_completions",
|
||||
data=record_data,
|
||||
)
|
||||
except IntegrityError as e:
|
||||
# Duplicate chat completion IDs can be generated during tests especially if they are replaying
|
||||
# recorded responses across different tests. No need to warn or error under those circumstances.
|
||||
# In the wild, this is not likely to happen at all (no evidence) so we aren't really hiding any problem.
|
||||
|
||||
# Check if it's a unique constraint violation
|
||||
error_message = str(e.orig) if e.orig else str(e)
|
||||
if self._is_unique_constraint_error(error_message):
|
||||
# Update the existing record instead
|
||||
await self.sql_store.update(table="chat_completions", data=record_data, where={"id": data["id"]})
|
||||
else:
|
||||
# Re-raise if it's not a unique constraint error
|
||||
raise
|
||||
|
||||
def _is_unique_constraint_error(self, error_message: str) -> bool:
|
||||
"""Check if the error is specifically a unique constraint violation."""
|
||||
error_lower = error_message.lower()
|
||||
return any(
|
||||
indicator in error_lower
|
||||
for indicator in [
|
||||
"unique constraint failed", # SQLite
|
||||
"duplicate key", # PostgreSQL
|
||||
"unique violation", # PostgreSQL alternative
|
||||
"duplicate entry", # MySQL
|
||||
]
|
||||
)
|
||||
|
||||
async def list_chat_completions(
|
||||
|
@ -92,7 +202,6 @@ class InferenceStore:
|
|||
order_by=[("created", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
data = [
|
||||
|
@ -119,7 +228,6 @@ class InferenceStore:
|
|||
row = await self.sql_store.fetch_one(
|
||||
table="chat_completions",
|
||||
where={"id": completion_id},
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
if not row:
|
||||
|
|
|
@ -40,7 +40,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
b64_encode_openai_embeddings_response,
|
||||
convert_message_to_openai_dict_new,
|
||||
|
@ -67,10 +67,10 @@ class LiteLLMOpenAIMixin(
|
|||
# when calling litellm.
|
||||
def __init__(
|
||||
self,
|
||||
model_entries,
|
||||
litellm_provider_name: str,
|
||||
api_key_from_config: str | None,
|
||||
provider_data_api_key_field: str,
|
||||
model_entries: list[ProviderModelEntry] | None = None,
|
||||
openai_compat_api_base: str | None = None,
|
||||
download_images: bool = False,
|
||||
json_schema_strict: bool = True,
|
||||
|
@ -86,7 +86,7 @@ class LiteLLMOpenAIMixin(
|
|||
:param download_images: Whether to download images and convert to base64 for message conversion.
|
||||
:param json_schema_strict: Whether to use strict mode for JSON schema validation.
|
||||
"""
|
||||
ModelRegistryHelper.__init__(self, model_entries)
|
||||
ModelRegistryHelper.__init__(self, model_entries=model_entries)
|
||||
|
||||
self.litellm_provider_name = litellm_provider_name
|
||||
self.api_key_from_config = api_key_from_config
|
||||
|
|
|
@ -11,7 +11,6 @@ from pydantic import BaseModel, Field
|
|||
from llama_stack.apis.common.errors import UnsupportedModelError
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference import (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||
|
@ -37,13 +36,6 @@ class ProviderModelEntry(BaseModel):
|
|||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
def get_huggingface_repo(model_descriptor: str) -> str | None:
|
||||
for model in all_registered_models():
|
||||
if model.descriptor() == model_descriptor:
|
||||
return model.huggingface_repo
|
||||
return None
|
||||
|
||||
|
||||
def build_hf_repo_model_entry(
|
||||
provider_model_id: str,
|
||||
model_descriptor: str,
|
||||
|
@ -63,25 +55,20 @@ def build_hf_repo_model_entry(
|
|||
)
|
||||
|
||||
|
||||
def build_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
|
||||
return ProviderModelEntry(
|
||||
provider_model_id=provider_model_id,
|
||||
aliases=[],
|
||||
llama_model=model_descriptor,
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
|
||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||
__provider_id__: str
|
||||
|
||||
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
|
||||
self.model_entries = model_entries
|
||||
def __init__(
|
||||
self,
|
||||
model_entries: list[ProviderModelEntry] | None = None,
|
||||
allowed_models: list[str] | None = None,
|
||||
):
|
||||
self.allowed_models = allowed_models
|
||||
|
||||
self.alias_to_provider_id_map = {}
|
||||
self.provider_id_to_llama_model_map = {}
|
||||
for entry in model_entries:
|
||||
self.model_entries = model_entries or []
|
||||
for entry in self.model_entries:
|
||||
for alias in entry.aliases:
|
||||
self.alias_to_provider_id_map[alias] = entry.provider_model_id
|
||||
|
||||
|
@ -103,7 +90,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
Model(
|
||||
identifier=id,
|
||||
provider_resource_id=entry.provider_model_id,
|
||||
model_type=ModelType.llm,
|
||||
model_type=entry.model_type,
|
||||
metadata=entry.metadata,
|
||||
provider_id=self.__provider_id__,
|
||||
)
|
||||
|
|
|
@ -4,11 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
from openai import NOT_GIVEN, AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
|
@ -22,13 +22,15 @@ from llama_stack.apis.inference import (
|
|||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class OpenAIMixin(ABC):
|
||||
class OpenAIMixin(ModelRegistryHelper, ABC):
|
||||
"""
|
||||
Mixin class that provides OpenAI-specific functionality for inference providers.
|
||||
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
||||
|
@ -43,6 +45,24 @@ class OpenAIMixin(ABC):
|
|||
The model_store is set in routing_tables/common.py during provider initialization.
|
||||
"""
|
||||
|
||||
# Allow subclasses to control whether the 'id' field in OpenAI responses
# is overwritten with a client-side generated id.
|
||||
#
|
||||
# This is useful for providers that do not return a unique id in the response.
|
||||
overwrite_completion_id: bool = False
|
||||
|
||||
# Embedding model metadata for this provider
|
||||
# Can be set by subclasses or instances to provide embedding models
|
||||
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {}
|
||||
|
||||
# Cache of available models keyed by model ID
|
||||
# This is set in list_models() and used in check_model_availability()
|
||||
_model_cache: dict[str, Model] = {}
|
||||
|
||||
# List of allowed models for this provider, if empty all models allowed
|
||||
allowed_models: list[str] = []
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
|
@ -67,6 +87,17 @@ class OpenAIMixin(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
def get_extra_client_params(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get any extra parameters to pass to the AsyncOpenAI client.
|
||||
|
||||
Child classes can override this method to provide additional parameters
|
||||
such as timeout settings, proxies, etc.
|
||||
|
||||
:return: A dictionary of extra parameters
|
||||
"""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncOpenAI:
|
||||
"""
|
||||
|
@ -78,6 +109,7 @@ class OpenAIMixin(ABC):
|
|||
return AsyncOpenAI(
|
||||
api_key=self.get_api_key(),
|
||||
base_url=self.get_base_url(),
|
||||
**self.get_extra_client_params(),
|
||||
)
|
||||
|
||||
async def _get_provider_model_id(self, model: str) -> str:
|
||||
|
@ -98,6 +130,23 @@ class OpenAIMixin(ABC):
|
|||
raise ValueError(f"Model {model} has no provider_resource_id")
|
||||
return model_obj.provider_resource_id
|
||||
|
||||
async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
|
||||
if not self.overwrite_completion_id:
|
||||
return resp
|
||||
|
||||
new_id = f"cltsd-{uuid.uuid4()}"
|
||||
if stream:
|
||||
|
||||
async def _gen():
|
||||
async for chunk in resp:
|
||||
chunk.id = new_id
|
||||
yield chunk
|
||||
|
||||
return _gen()
|
||||
else:
|
||||
resp.id = new_id
|
||||
return resp
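A hedged sketch of the id-overwrite behavior above for the streaming case; fake_stream and its chunk objects are invented stand-ins for a provider response, and the cltsd- prefix mirrors the one used in the diff.

# Minimal sketch: every streamed chunk gets the same client-generated id.
import asyncio
import uuid
from types import SimpleNamespace

async def fake_stream():
    for i in range(3):
        yield SimpleNamespace(id=f"server-{i}", index=i)

async def overwrite_ids(stream):
    new_id = f"cltsd-{uuid.uuid4()}"
    async for chunk in stream:
        chunk.id = new_id
        yield chunk

async def main() -> None:
    async for chunk in overwrite_ids(fake_stream()):
        print(chunk.id, chunk.index)  # identical id on every chunk

asyncio.run(main())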
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -124,13 +173,18 @@ class OpenAIMixin(ABC):
|
|||
"""
|
||||
Direct OpenAI completion API call.
|
||||
"""
|
||||
if guided_choice is not None:
|
||||
logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
|
||||
if prompt_logprobs is not None:
|
||||
logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
|
||||
# Handle parameters that are not supported by OpenAI API, but may be by the provider
|
||||
# prompt_logprobs is supported by vLLM
|
||||
# guided_choice is supported by vLLM
|
||||
# TODO: test coverage
|
||||
extra_body: dict[str, Any] = {}
|
||||
if prompt_logprobs is not None and prompt_logprobs >= 0:
|
||||
extra_body["prompt_logprobs"] = prompt_logprobs
|
||||
if guided_choice:
|
||||
extra_body["guided_choice"] = guided_choice
|
||||
|
||||
# TODO: fix openai_completion to return type compatible with OpenAI's API response
|
||||
return await self.client.completions.create( # type: ignore[no-any-return]
|
||||
resp = await self.client.completions.create(
|
||||
**await prepare_openai_completion_params(
|
||||
model=await self._get_provider_model_id(model),
|
||||
prompt=prompt,
|
||||
|
@ -150,9 +204,12 @@ class OpenAIMixin(ABC):
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
suffix=suffix,
|
||||
)
|
||||
),
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -182,8 +239,7 @@ class OpenAIMixin(ABC):
|
|||
"""
|
||||
Direct OpenAI chat completion API call.
|
||||
"""
|
||||
# Type ignore because return types are compatible
|
||||
return await self.client.chat.completions.create( # type: ignore[no-any-return]
|
||||
resp = await self.client.chat.completions.create(
|
||||
**await prepare_openai_completion_params(
|
||||
model=await self._get_provider_model_id(model),
|
||||
messages=messages,
|
||||
|
@ -211,6 +267,8 @@ class OpenAIMixin(ABC):
|
|||
)
|
||||
)
|
||||
|
||||
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -247,26 +305,53 @@ class OpenAIMixin(ABC):
|
|||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=response.model,
|
||||
model=model,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
"""
|
||||
List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
|
||||
|
||||
Also, caches the models in self._model_cache for use in check_model_availability().
|
||||
|
||||
:return: A list of Model instances representing available models.
|
||||
"""
|
||||
self._model_cache = {}
|
||||
|
||||
async for m in self.client.models.list():
|
||||
if self.allowed_models and m.id not in self.allowed_models:
|
||||
logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
|
||||
continue
|
||||
if metadata := self.embedding_model_metadata.get(m.id):
|
||||
# This is an embedding model - augment with metadata
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata=metadata,
|
||||
)
|
||||
else:
|
||||
# This is an LLM
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
self._model_cache[m.id] = model
|
||||
|
||||
return list(self._model_cache.values())
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available from OpenAI.
|
||||
Check if a specific model is available from the provider's /v1/models.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Direct model lookup - returns model or raises NotFoundError
|
||||
await self.client.models.retrieve(model)
|
||||
return True
|
||||
except openai.NotFoundError:
|
||||
# Model doesn't exist - this is expected for unavailable models
|
||||
pass
|
||||
except Exception as e:
|
||||
# All other errors (auth, rate limit, network, etc.)
|
||||
logger.warning(f"Failed to check model availability for {model}: {e}")
|
||||
if not self._model_cache:
|
||||
await self.list_models()
|
||||
|
||||
return False
|
||||
return model in self._model_cache
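A simplified, synchronous sketch of the cache-then-check flow above; ModelCache and its inputs are invented for illustration and stand in for the /v1/models listing.

# Hedged sketch: list once, cache by id, answer availability checks from the cache.
class ModelCache:
    def __init__(self, provider_models: list[str]) -> None:
        self._provider_models = provider_models
        self._model_cache: dict[str, str] = {}

    def list_models(self) -> list[str]:
        # stand-in for iterating the provider's /v1/models endpoint
        self._model_cache = {m: m for m in self._provider_models}
        return list(self._model_cache.values())

    def check_model_availability(self, model: str) -> bool:
        if not self._model_cache:
            self.list_models()
        return model in self._model_cache

cache = ModelCache(["gpt-4o-mini", "text-embedding-3-small"])
print(cache.check_model_availability("gpt-4o-mini"))   # True
print(cache.check_model_availability("not-a-model"))   # False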
|
||||
|
|
|
@ -294,12 +294,12 @@ class VectorDBWithIndex:
|
|||
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
|
||||
|
||||
if chunks_to_embed:
|
||||
resp = await self.inference_api.embeddings(
|
||||
resp = await self.inference_api.openai_embeddings(
|
||||
self.vector_db.embedding_model,
|
||||
[c.content for c in chunks_to_embed],
|
||||
)
|
||||
for c, embedding in zip(chunks_to_embed, resp.embeddings, strict=False):
|
||||
c.embedding = embedding
|
||||
for c, data in zip(chunks_to_embed, resp.data, strict=False):
|
||||
c.embedding = data.embedding
|
||||
|
||||
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
|
||||
await self.index.add_chunks(chunks, embeddings)
|
||||
|
@ -334,8 +334,8 @@ class VectorDBWithIndex:
|
|||
if mode == "keyword":
|
||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||
|
||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||
embeddings_response = await self.inference_api.openai_embeddings(self.vector_db.embedding_model, [query_string])
|
||||
query_vector = np.array(embeddings_response.data[0].embedding, dtype=np.float32)
|
||||
if mode == "hybrid":
|
||||
return await self.index.query_hybrid(
|
||||
query_vector, query_string, k, score_threshold, reranker_type, reranker_params
|
||||
|
|
|
@ -28,8 +28,7 @@ class ResponsesStore:
|
|||
sql_store_config = SqliteSqlStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
)
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
|
||||
self.policy = policy
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
|
||||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
|
@ -87,7 +86,6 @@ class ResponsesStore:
|
|||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
|
||||
|
@ -105,7 +103,6 @@ class ResponsesStore:
|
|||
row = await self.sql_store.fetch_one(
|
||||
"openai_responses",
|
||||
where={"id": response_id},
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
if not row:
|
||||
|
@ -116,7 +113,7 @@ class ResponsesStore:
|
|||
return OpenAIResponseObjectWithInput(**row["response_object"])
|
||||
|
||||
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy)
|
||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
|
||||
if not row:
|
||||
raise ValueError(f"Response with id {response_id} not found")
|
||||
await self.sql_store.delete("openai_responses", where={"id": response_id})
|
||||
|
|
|
@ -53,13 +53,15 @@ class AuthorizedSqlStore:
|
|||
access control policies, user attribute capture, and SQL filtering optimization.
|
||||
"""
|
||||
|
||||
def __init__(self, sql_store: SqlStore):
|
||||
def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
|
||||
"""
|
||||
Initialize the authorization layer.
|
||||
|
||||
:param sql_store: Base SqlStore implementation to wrap
|
||||
:param policy: Access control policy to use for authorization
|
||||
"""
|
||||
self.sql_store = sql_store
|
||||
self.policy = policy
|
||||
self._detect_database_type()
|
||||
self._validate_sql_optimized_policy()
|
||||
|
||||
|
@ -117,14 +119,13 @@ class AuthorizedSqlStore:
|
|||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
policy: list[AccessRule],
|
||||
where: Mapping[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
cursor: tuple[str, str] | None = None,
|
||||
) -> PaginatedResponse:
|
||||
"""Fetch all rows with automatic access control filtering."""
|
||||
access_where = self._build_access_control_where_clause(policy)
|
||||
access_where = self._build_access_control_where_clause(self.policy)
|
||||
rows = await self.sql_store.fetch_all(
|
||||
table=table,
|
||||
where=where,
|
||||
|
@ -146,7 +147,7 @@ class AuthorizedSqlStore:
|
|||
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
|
||||
)
|
||||
|
||||
if is_action_allowed(policy, Action.READ, sql_record, current_user):
|
||||
if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
|
||||
filtered_rows.append(row)
|
||||
|
||||
return PaginatedResponse(
|
||||
|
@ -157,14 +158,12 @@ class AuthorizedSqlStore:
|
|||
async def fetch_one(
|
||||
self,
|
||||
table: str,
|
||||
policy: list[AccessRule],
|
||||
where: Mapping[str, Any] | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch one row with automatic access control checking."""
|
||||
results = await self.fetch_all(
|
||||
table=table,
|
||||
policy=policy,
|
||||
where=where,
|
||||
limit=1,
|
||||
order_by=order_by,
|
||||
|
@ -172,6 +171,20 @@ class AuthorizedSqlStore:
|
|||
|
||||
return results.data[0] if results.data else None
|
||||
|
||||
async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None:
|
||||
"""Update rows with automatic access control attribute capture."""
|
||||
enhanced_data = dict(data)
|
||||
|
||||
current_user = get_authenticated_user()
|
||||
if current_user:
|
||||
enhanced_data["owner_principal"] = current_user.principal
|
||||
enhanced_data["access_attributes"] = current_user.attributes
|
||||
else:
|
||||
enhanced_data["owner_principal"] = None
|
||||
enhanced_data["access_attributes"] = None
|
||||
|
||||
await self.sql_store.update(table, enhanced_data, where)
|
||||
|
||||
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
|
||||
"""Delete rows with automatic access control filtering."""
|
||||
await self.sql_store.delete(table, where)
|
||||
|
|
|
@ -23,6 +23,7 @@ from sqlalchemy import (
|
|||
)
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -43,6 +44,30 @@ TYPE_MAPPING: dict[ColumnType, Any] = {
|
|||
}
|
||||
|
||||
|
||||
def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
|
||||
"""Return a SQLAlchemy expression for a where condition.
|
||||
|
||||
`value` may be a simple scalar (equality) or a mapping like {">": 123}.
|
||||
The returned expression is a SQLAlchemy ColumnElement usable in query.where(...).
|
||||
"""
|
||||
if isinstance(value, Mapping):
|
||||
if len(value) != 1:
|
||||
raise ValueError(f"Operator mapping must have a single operator, got: {value}")
|
||||
op, operand = next(iter(value.items()))
|
||||
if op == "==" or op == "=":
|
||||
return column == operand
|
||||
if op == ">":
|
||||
return column > operand
|
||||
if op == "<":
|
||||
return column < operand
|
||||
if op == ">=":
|
||||
return column >= operand
|
||||
if op == "<=":
|
||||
return column <= operand
|
||||
raise ValueError(f"Unsupported operator '{op}' in where mapping")
|
||||
return column == value
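A hedged usage sketch of the operator-mapping convention introduced above, assuming only SQLAlchemy; the items table and its columns are invented for illustration. A plain value means equality, while a one-entry mapping such as {">": 123} selects a comparison operator.

# Minimal sketch, not the store implementation itself.
import sqlalchemy as sa

metadata = sa.MetaData()
items = sa.Table(
    "items", metadata,
    sa.Column("id", sa.String),
    sa.Column("created", sa.Integer),
)

def build_where_expr(column, value):
    if isinstance(value, dict):
        op, operand = next(iter(value.items()))
        ops = {
            "==": column == operand,
            ">": column > operand,
            "<": column < operand,
            ">=": column >= operand,
            "<=": column <= operand,
        }
        return ops[op]
    return column == value

query = sa.select(items).where(build_where_expr(items.c.created, {">": 1700000000}))
print(query)  # ... WHERE items.created > :created_1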
|
||||
|
||||
|
||||
class SqlAlchemySqlStoreImpl(SqlStore):
|
||||
def __init__(self, config: SqlAlchemySqlStoreConfig):
|
||||
self.config = config
|
||||
|
@ -111,7 +136,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
|
||||
if where:
|
||||
for key, value in where.items():
|
||||
query = query.where(table_obj.c[key] == value)
|
||||
query = query.where(_build_where_expr(table_obj.c[key], value))
|
||||
|
||||
if where_sql:
|
||||
query = query.where(text(where_sql))
|
||||
|
@ -222,7 +247,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
async with self.async_session() as session:
|
||||
stmt = self.metadata.tables[table].update()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
|
||||
await session.execute(stmt, data)
|
||||
await session.commit()
|
||||
|
||||
|
@ -233,7 +258,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
async with self.async_session() as session:
|
||||
stmt = self.metadata.tables[table].delete()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import asyncio
|
|||
import contextvars
|
||||
import logging # allow-direct-logging
|
||||
import queue
|
||||
import random
|
||||
import secrets
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
@ -18,6 +18,7 @@ from functools import wraps
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.telemetry import (
|
||||
Event,
|
||||
LogSeverity,
|
||||
Span,
|
||||
SpanEndPayload,
|
||||
|
@ -75,16 +76,16 @@ def span_id_to_str(span_id: int) -> str:
|
|||
|
||||
|
||||
def generate_span_id() -> str:
|
||||
span_id = random.getrandbits(64)
|
||||
span_id = secrets.randbits(64)
|
||||
while span_id == INVALID_SPAN_ID:
|
||||
span_id = random.getrandbits(64)
|
||||
span_id = secrets.randbits(64)
|
||||
return span_id_to_str(span_id)
|
||||
|
||||
|
||||
def generate_trace_id() -> str:
|
||||
trace_id = random.getrandbits(128)
|
||||
trace_id = secrets.randbits(128)
|
||||
while trace_id == INVALID_TRACE_ID:
|
||||
trace_id = random.getrandbits(128)
|
||||
trace_id = secrets.randbits(128)
|
||||
return trace_id_to_str(trace_id)
|
||||
|
||||
|
||||
|
@ -98,7 +99,7 @@ class BackgroundLogger:
|
|||
def __init__(self, api: Telemetry, capacity: int = 100000):
|
||||
self.api = api
|
||||
self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
|
||||
self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
|
||||
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
|
||||
self.worker_thread.start()
|
||||
self._last_queue_full_log_time: float = 0.0
|
||||
self._dropped_since_last_notice: int = 0
|
||||
|
@ -118,12 +119,16 @@ class BackgroundLogger:
|
|||
self._last_queue_full_log_time = current_time
|
||||
self._dropped_since_last_notice = 0
|
||||
|
||||
def _process_logs(self):
|
||||
def _worker(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(self._process_logs())
|
||||
|
||||
async def _process_logs(self):
|
||||
while True:
|
||||
try:
|
||||
event = self.log_queue.get()
|
||||
# figure out how to use a thread's native loop
|
||||
asyncio.run(self.api.log_event(event))
|
||||
await self.api.log_event(event)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
|
@ -136,6 +141,19 @@ class BackgroundLogger:
|
|||
self.log_queue.join()
|
||||
|
||||
|
||||
def enqueue_event(event: Event) -> None:
|
||||
"""Enqueue a telemetry event to the background logger if available.
|
||||
|
||||
This provides a non-blocking path for routers and other hot paths to
|
||||
submit telemetry without awaiting the Telemetry API, reducing contention
|
||||
with the main event loop.
|
||||
"""
|
||||
global BACKGROUND_LOGGER
|
||||
if BACKGROUND_LOGGER is None:
|
||||
raise RuntimeError("Telemetry API not initialized")
|
||||
BACKGROUND_LOGGER.log_event(event)
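A minimal sketch of the worker-thread pattern above: a daemon thread owns its own event loop and awaits an async logging call for each queued event, while callers enqueue without blocking on telemetry. The log_event coroutine below is an invented stand-in for the Telemetry API.

# Hedged sketch of a background logger thread with its own event loop.
import asyncio
import queue
import threading

log_queue: "queue.Queue[str]" = queue.Queue(maxsize=100)

async def log_event(event: str) -> None:
    await asyncio.sleep(0)  # stand-in for the async Telemetry API call
    print("logged", event)

async def process_logs() -> None:
    while True:
        event = log_queue.get()  # blocks this thread's loop between events
        try:
            await log_event(event)
        finally:
            log_queue.task_done()

def worker() -> None:
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(process_logs())

threading.Thread(target=worker, daemon=True).start()
log_queue.put("span-started")  # non-blocking for the caller
log_queue.join()               # wait for the worker to drain the queue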
|
||||
|
||||
|
||||
class TraceContext:
|
||||
spans: list[Span] = []
|
||||
|
||||
|
@ -256,11 +274,7 @@ class TelemetryHandler(logging.Handler):
|
|||
if record.module in ("asyncio", "selector_events"):
|
||||
return
|
||||
|
||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
raise RuntimeError("Telemetry API not initialized")
|
||||
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if context is None:
|
||||
return
|
||||
|
@ -269,7 +283,7 @@ class TelemetryHandler(logging.Handler):
|
|||
if span is None:
|
||||
return
|
||||
|
||||
BACKGROUND_LOGGER.log_event(
|
||||
enqueue_event(
|
||||
UnstructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
|
|
|
@ -67,6 +67,38 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
|
|||
raise AuthenticationRequiredError(exc) from exc
|
||||
if i == len(connection_strategies) - 1:
|
||||
raise
|
||||
except* httpx.ConnectError as eg:
|
||||
# Connection refused, server down, network unreachable
|
||||
if i == len(connection_strategies) - 1:
|
||||
error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused"
|
||||
logger.error(f"MCP connection error: {error_msg}")
|
||||
raise ConnectionError(error_msg) from eg
|
||||
else:
|
||||
logger.warning(
|
||||
f"failed to connect to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
||||
)
|
||||
except* httpx.TimeoutException as eg:
|
||||
# Request timeout, server too slow
|
||||
if i == len(connection_strategies) - 1:
|
||||
error_msg = f"MCP server at {endpoint} timed out"
|
||||
logger.error(f"MCP timeout error: {error_msg}")
|
||||
raise TimeoutError(error_msg) from eg
|
||||
else:
|
||||
logger.warning(
|
||||
f"MCP server at {endpoint} timed out via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
||||
)
|
||||
except* httpx.RequestError as eg:
|
||||
# DNS resolution failures, network errors, invalid URLs
|
||||
if i == len(connection_strategies) - 1:
|
||||
# Get the first exception's message for the error string
|
||||
exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error"
|
||||
error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}"
|
||||
logger.error(f"MCP network error: {error_msg}")
|
||||
raise ConnectionError(error_msg) from eg
|
||||
else:
|
||||
logger.warning(
|
||||
f"network error connecting to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
||||
)
|
||||
except* McpError:
|
||||
if i < len(connection_strategies) - 1:
|
||||
logger.warning(
|
||||
|
|
|
@ -12,14 +12,12 @@ import uuid
|
|||
def generate_chunk_id(document_id: str, chunk_text: str, chunk_window: str | None = None) -> str:
|
||||
"""
|
||||
Generate a unique chunk ID using a hash of the document ID and chunk text.
|
||||
|
||||
Note: MD5 is used only to calculate an identifier, not for security purposes.
|
||||
Adding usedforsecurity=False for compatibility with FIPS environments.
|
||||
Then use the first 32 characters of the hash to create a UUID.
|
||||
"""
|
||||
hash_input = f"{document_id}:{chunk_text}".encode()
|
||||
if chunk_window:
|
||||
hash_input += f":{chunk_window}".encode()
|
||||
return str(uuid.UUID(hashlib.md5(hash_input, usedforsecurity=False).hexdigest()))
|
||||
return str(uuid.UUID(hashlib.sha256(hash_input).hexdigest()[:32]))
|
||||
|
||||
|
||||
def proper_case(s: str) -> str:
|
||||
|
@ -37,3 +35,122 @@ def sanitize_collection_name(name: str, weaviate_format=False) -> str:
|
|||
else:
|
||||
s = proper_case(re.sub(r"[^a-zA-Z0-9]", "", name))
|
||||
return s
|
||||
|
||||
|
||||
class WeightedInMemoryAggregator:
|
||||
@staticmethod
|
||||
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
|
||||
"""
|
||||
Normalize scores to 0-1 range using min-max normalization.
|
||||
|
||||
Args:
|
||||
scores: dictionary of scores with document IDs as keys and scores as values
|
||||
|
||||
Returns:
|
||||
Normalized scores with document IDs as keys and normalized scores as values
|
||||
"""
|
||||
if not scores:
|
||||
return {}
|
||||
min_score, max_score = min(scores.values()), max(scores.values())
|
||||
score_range = max_score - min_score
|
||||
if score_range > 0:
|
||||
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
|
||||
return dict.fromkeys(scores, 1.0)
|
||||
|
||||
@staticmethod
|
||||
def weighted_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
alpha: float = 0.5,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Rerank via weighted average of scores.
|
||||
|
||||
Args:
|
||||
vector_scores: scores from vector search
|
||||
keyword_scores: scores from keyword search
|
||||
alpha: weight factor between 0 and 1 (default: 0.5)
|
||||
0 = keyword only, 1 = vector only, 0.5 = equal weight
|
||||
|
||||
Returns:
|
||||
All unique document IDs with weighted combined scores
|
||||
"""
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
normalized_vector_scores = WeightedInMemoryAggregator._normalize_scores(vector_scores)
|
||||
normalized_keyword_scores = WeightedInMemoryAggregator._normalize_scores(keyword_scores)
|
||||
|
||||
# Weighted formula: score = (1-alpha) * keyword_score + alpha * vector_score
|
||||
# alpha=0 means keyword only, alpha=1 means vector only
|
||||
return {
|
||||
doc_id: ((1 - alpha) * normalized_keyword_scores.get(doc_id, 0.0))
|
||||
+ (alpha * normalized_vector_scores.get(doc_id, 0.0))
|
||||
for doc_id in all_ids
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def rrf_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
impact_factor: float = 60.0,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Rerank via Reciprocal Rank Fusion.
|
||||
|
||||
Args:
|
||||
vector_scores: scores from vector search
|
||||
keyword_scores: scores from keyword search
|
||||
impact_factor: impact factor for RRF (default: 60.0)
|
||||
|
||||
Returns:
|
||||
All unique document IDs with RRF combined scores
|
||||
"""
|
||||
|
||||
# Convert scores to ranks
|
||||
vector_ranks = {
|
||||
doc_id: i + 1
|
||||
for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
keyword_ranks = {
|
||||
doc_id: i + 1
|
||||
for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
rrf_scores = {}
|
||||
for doc_id in all_ids:
|
||||
vector_rank = vector_ranks.get(doc_id, float("inf"))
|
||||
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
|
||||
|
||||
# RRF formula: score = 1/(k + r) where k is impact_factor (default: 60.0) and r is the rank
|
||||
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
|
||||
return rrf_scores
|
||||
|
||||
@staticmethod
|
||||
def combine_search_results(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
reranker_type: str = "rrf",
|
||||
reranker_params: dict[str, float] | None = None,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Combine vector and keyword search results using specified reranking strategy.
|
||||
|
||||
Args:
|
||||
vector_scores: scores from vector search
|
||||
keyword_scores: scores from keyword search
|
||||
reranker_type: type of reranker to use (default: RERANKER_TYPE_RRF)
|
||||
reranker_params: parameters for the reranker
|
||||
|
||||
Returns:
|
||||
All unique document IDs with combined scores
|
||||
"""
|
||||
if reranker_params is None:
|
||||
reranker_params = {}
|
||||
|
||||
if reranker_type == "weighted":
|
||||
alpha = reranker_params.get("alpha", 0.5)
|
||||
return WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha)
|
||||
else:
|
||||
# Default to RRF for None, RRF, or any unknown types
|
||||
impact_factor = reranker_params.get("impact_factor", 60.0)
|
||||
return WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor)
|