mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 00:52:26 +00:00
Merge branch 'main' into update-api-docs
This commit is contained in:
commit
c68f53299e
99 changed files with 2783 additions and 5168 deletions
|
|
@ -518,6 +518,8 @@ register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletion
|
|||
|
||||
OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
|
||||
|
||||
OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIUserMessageParam(BaseModel):
|
||||
|
|
@ -543,7 +545,7 @@ class OpenAISystemMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["system"] = "system"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent
|
||||
name: str | None = None
|
||||
|
||||
|
||||
|
|
@ -586,7 +588,7 @@ class OpenAIAssistantMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: OpenAIChatCompletionMessageContent | None = None
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent | None = None
|
||||
name: str | None = None
|
||||
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
|
||||
|
||||
|
|
@ -602,7 +604,7 @@ class OpenAIToolMessageParam(BaseModel):
|
|||
|
||||
role: Literal["tool"] = "tool"
|
||||
tool_call_id: str
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -615,7 +617,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["developer"] = "developer"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
content: OpenAIChatCompletionTextOnlyMessageContent
|
||||
name: str | None = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from llama_stack.distribution.build import (
|
|||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BuildConfig,
|
||||
BuildProvider,
|
||||
DistributionSpec,
|
||||
Provider,
|
||||
StackRunConfig,
|
||||
|
|
@ -94,7 +95,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
elif args.providers:
|
||||
provider_list: dict[str, list[Provider]] = dict()
|
||||
provider_list: dict[str, list[BuildProvider]] = dict()
|
||||
for api_provider in args.providers.split(","):
|
||||
if "=" not in api_provider:
|
||||
cprint(
|
||||
|
|
@ -113,10 +114,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
if provider_type in providers_for_api:
|
||||
provider = Provider(
|
||||
provider = BuildProvider(
|
||||
provider_type=provider_type,
|
||||
provider_id=provider_type.split("::")[1],
|
||||
config={},
|
||||
module=None,
|
||||
)
|
||||
provider_list.setdefault(api, []).append(provider)
|
||||
|
|
@ -189,7 +188,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
|
||||
cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
|
||||
|
||||
providers: dict[str, list[Provider]] = dict()
|
||||
providers: dict[str, list[BuildProvider]] = dict()
|
||||
for api, providers_for_api in get_provider_registry().items():
|
||||
available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
|
||||
if not available_providers:
|
||||
|
|
@ -204,7 +203,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
),
|
||||
)
|
||||
|
||||
providers[api.value] = api_provider
|
||||
string_providers = api_provider.split(" ")
|
||||
|
||||
for provider in string_providers:
|
||||
providers.setdefault(api.value, []).append(BuildProvider(provider_type=provider))
|
||||
|
||||
description = prompt(
|
||||
"\n > (Optional) Enter a short description for your Llama Stack: ",
|
||||
|
|
@ -307,7 +309,7 @@ def _generate_run_config(
|
|||
providers = build_config.distribution_spec.providers[api]
|
||||
|
||||
for provider in providers:
|
||||
pid = provider.provider_id
|
||||
pid = provider.provider_type.split("::")[-1]
|
||||
|
||||
p = provider_registry[Api(api)][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
|
|
|
|||
|
|
@ -18,10 +18,6 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
|||
|
||||
# mounting is not supported by docker buildx, so we use COPY instead
|
||||
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
|
||||
|
||||
# Mount command for cache container .cache, can be overridden by the user if needed
|
||||
MOUNT_CACHE=${MOUNT_CACHE:-"--mount=type=cache,id=llama-stack-cache,target=/root/.cache"}
|
||||
|
||||
# Path to the run.yaml file in the container
|
||||
RUN_CONFIG_PATH=/app/run.yaml
|
||||
|
||||
|
|
@ -176,18 +172,13 @@ RUN pip install uv
|
|||
EOF
|
||||
fi
|
||||
|
||||
# Set the link mode to copy so that uv doesn't attempt to symlink to the cache directory
|
||||
add_to_container << EOF
|
||||
ENV UV_LINK_MODE=copy
|
||||
EOF
|
||||
|
||||
# Add pip dependencies first since llama-stack is what will change most often
|
||||
# so we can reuse layers.
|
||||
if [ -n "$normal_deps" ]; then
|
||||
read -ra pip_args <<< "$normal_deps"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container << EOF
|
||||
RUN $MOUNT_CACHE uv pip install $quoted_deps
|
||||
RUN uv pip install --no-cache $quoted_deps
|
||||
EOF
|
||||
fi
|
||||
|
||||
|
|
@ -197,7 +188,7 @@ if [ -n "$optional_deps" ]; then
|
|||
read -ra pip_args <<< "$part"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container <<EOF
|
||||
RUN $MOUNT_CACHE uv pip install $quoted_deps
|
||||
RUN uv pip install --no-cache $quoted_deps
|
||||
EOF
|
||||
done
|
||||
fi
|
||||
|
|
@ -208,10 +199,10 @@ if [ -n "$external_provider_deps" ]; then
|
|||
read -ra pip_args <<< "$part"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container <<EOF
|
||||
RUN $MOUNT_CACHE uv pip install $quoted_deps
|
||||
RUN uv pip install --no-cache $quoted_deps
|
||||
EOF
|
||||
add_to_container <<EOF
|
||||
RUN python3 - <<PYTHON | $MOUNT_CACHE uv pip install -r -
|
||||
RUN python3 - <<PYTHON | uv pip install --no-cache -r -
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
|
|
@ -293,7 +284,7 @@ COPY $dir $mount_point
|
|||
EOF
|
||||
fi
|
||||
add_to_container << EOF
|
||||
RUN $MOUNT_CACHE uv pip install -e $mount_point
|
||||
RUN uv pip install --no-cache -e $mount_point
|
||||
EOF
|
||||
}
|
||||
|
||||
|
|
@ -308,10 +299,10 @@ else
|
|||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
add_to_container << EOF
|
||||
RUN $MOUNT_CACHE uv pip install fastapi libcst
|
||||
RUN uv pip install --no-cache fastapi libcst
|
||||
EOF
|
||||
add_to_container << EOF
|
||||
RUN $MOUNT_CACHE uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
|
||||
--index-strategy unsafe-best-match \
|
||||
llama-stack==$TEST_PYPI_VERSION
|
||||
|
||||
|
|
@ -323,7 +314,7 @@ EOF
|
|||
SPEC_VERSION="llama-stack"
|
||||
fi
|
||||
add_to_container << EOF
|
||||
RUN $MOUNT_CACHE uv pip install $SPEC_VERSION
|
||||
RUN uv pip install --no-cache $SPEC_VERSION
|
||||
EOF
|
||||
fi
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -100,11 +100,12 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
break
|
||||
|
||||
logger.info(f"> Configuring provider `({provider.provider_type})`")
|
||||
pid = provider.provider_type.split("::")[-1]
|
||||
updated_providers.append(
|
||||
configure_single_provider(
|
||||
provider_registry[api],
|
||||
Provider(
|
||||
provider_id=(f"{provider.provider_id}-{i:02d}" if len(plist) > 1 else provider.provider_id),
|
||||
provider_id=(f"{pid}-{i:02d}" if len(plist) > 1 else pid),
|
||||
provider_type=provider.provider_type,
|
||||
config={},
|
||||
),
|
||||
|
|
|
|||
|
|
@ -154,13 +154,27 @@ class Provider(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class BuildProvider(BaseModel):
|
||||
provider_type: str
|
||||
module: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Fully-qualified name of the external provider module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
|
||||
Example: `module: ramalama_stack`
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class DistributionSpec(BaseModel):
|
||||
description: str | None = Field(
|
||||
default="",
|
||||
description="Description of the distribution",
|
||||
)
|
||||
container_image: str | None = None
|
||||
providers: dict[str, list[Provider]] = Field(
|
||||
providers: dict[str, list[BuildProvider]] = Field(
|
||||
default_factory=dict,
|
||||
description="""
|
||||
Provider Types for each of the APIs provided by this distribution. If you
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec
|
||||
from llama_stack.distribution.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.distribution.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
|
|
@ -249,9 +249,16 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
file=sys.stderr,
|
||||
)
|
||||
if self.config_path_or_template_name.endswith(".yaml"):
|
||||
providers: dict[str, list[BuildProvider]] = {}
|
||||
for api, run_providers in self.config.providers.items():
|
||||
for provider in run_providers:
|
||||
providers.setdefault(api, []).append(
|
||||
BuildProvider(provider_type=provider.provider_type, module=provider.module)
|
||||
)
|
||||
providers = dict(providers)
|
||||
build_config = BuildConfig(
|
||||
distribution_spec=DistributionSpec(
|
||||
providers=self.config.providers,
|
||||
providers=providers,
|
||||
),
|
||||
external_providers_dir=self.config.external_providers_dir,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
async def refresh(self) -> None:
|
||||
for provider_id, provider in self.impls_by_provider_id.items():
|
||||
refresh = await provider.should_refresh_models()
|
||||
if not (refresh or provider_id in self.listed_providers):
|
||||
refresh = refresh or provider_id not in self.listed_providers
|
||||
if not refresh:
|
||||
continue
|
||||
|
||||
try:
|
||||
|
|
@ -138,6 +139,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
# avoid overwriting a non-provider-registered model entry
|
||||
continue
|
||||
|
||||
if model.identifier == model.provider_resource_id:
|
||||
model.identifier = f"{provider_id}/{model.provider_resource_id}"
|
||||
|
||||
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
|
||||
await self.register_object(
|
||||
ModelWithOwner(
|
||||
|
|
|
|||
|
|
@ -611,11 +611,8 @@ def extract_path_params(route: str) -> list[str]:
|
|||
|
||||
def remove_disabled_providers(obj):
|
||||
if isinstance(obj, dict):
|
||||
if (
|
||||
obj.get("provider_id") == "__disabled__"
|
||||
or obj.get("shield_id") == "__disabled__"
|
||||
or obj.get("provider_model_id") == "__disabled__"
|
||||
):
|
||||
keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
|
||||
if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
|
||||
return None
|
||||
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
|
||||
elif isinstance(obj, list):
|
||||
|
|
|
|||
|
|
@ -105,23 +105,10 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
method = getattr(impls[api], register_method)
|
||||
for obj in objects:
|
||||
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
||||
# Do not register models on disabled providers
|
||||
if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
||||
continue
|
||||
# In complex templates, like our starter template, we may have dynamic model ids
|
||||
# given by environment variables. This allows those environment variables to have
|
||||
# a default value of __disabled__ to skip registration of the model if not set.
|
||||
if (
|
||||
hasattr(obj, "provider_model_id")
|
||||
and obj.provider_model_id is not None
|
||||
and "__disabled__" in obj.provider_model_id
|
||||
):
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
|
||||
continue
|
||||
|
||||
if hasattr(obj, "shield_id") and obj.shield_id is not None and obj.shield_id == "__disabled__":
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled shield.")
|
||||
# Do not register models on disabled providers
|
||||
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
||||
continue
|
||||
|
||||
# we want to maintain the type information in arguments to method.
|
||||
|
|
@ -331,8 +318,10 @@ async def construct_stack(
|
|||
|
||||
await register_resources(run_config, impls)
|
||||
|
||||
await refresh_registry_once(impls)
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry(impls))
|
||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
|
@ -368,11 +357,17 @@ async def shutdown_stack(impls: dict[Api, Any]):
|
|||
REGISTRY_REFRESH_TASK.cancel()
|
||||
|
||||
|
||||
async def refresh_registry(impls: dict[Api, Any]):
|
||||
async def refresh_registry_once(impls: dict[Api, Any]):
|
||||
logger.debug("refreshing registry")
|
||||
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
|
||||
for routing_table in routing_tables:
|
||||
await routing_table.refresh()
|
||||
|
||||
|
||||
async def refresh_registry_task(impls: dict[Api, Any]):
|
||||
logger.info("starting registry refresh task")
|
||||
while True:
|
||||
for routing_table in routing_tables:
|
||||
await routing_table.refresh()
|
||||
await refresh_registry_once(impls)
|
||||
|
||||
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
|
||||
|
||||
|
|
|
|||
|
|
@ -43,6 +43,9 @@ class ModelsProtocolPrivate(Protocol):
|
|||
-> Provider uses provider-model-id for inference
|
||||
"""
|
||||
|
||||
# this should be called `on_model_register` or something like that.
|
||||
# the provider should _not_ be able to change the object in this
|
||||
# callback
|
||||
async def register_model(self, model: Model) -> Model: ...
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None: ...
|
||||
|
|
|
|||
|
|
@ -146,9 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
# Allow any model to be registered as a shield
|
||||
# The model will be validated during runtime when making inference calls
|
||||
pass
|
||||
model_id = shield.provider_resource_id
|
||||
if not model_id:
|
||||
raise ValueError("Llama Guard shield must have a model id")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="anthropic",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="anthropic_api_key",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class AnthropicConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"anthropic/claude-3-5-sonnet-latest",
|
||||
"anthropic/claude-3-7-sonnet-latest",
|
||||
"anthropic/claude-3-5-haiku-latest",
|
||||
"claude-3-5-sonnet-latest",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-5-haiku-latest",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
|
@ -21,17 +21,17 @@ MODEL_ENTRIES = (
|
|||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3",
|
||||
provider_model_id="voyage-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3-lite",
|
||||
provider_model_id="voyage-3-lite",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 512, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-code-3",
|
||||
provider_model_id="voyage-code-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
|
|
|
|||
|
|
@ -63,18 +63,20 @@ class BedrockInferenceAdapter(
|
|||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
self._config = config
|
||||
|
||||
self._client = create_bedrock_client(config)
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
if self._client is None:
|
||||
self._client = create_bedrock_client(self._config)
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
if self._client is not None:
|
||||
self._client.close()
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ class CerebrasInferenceAdapter(
|
|||
)
|
||||
self.config = config
|
||||
|
||||
# TODO: make this use provider data, etc. like other providers
|
||||
self.client = AsyncCerebras(
|
||||
base_url=self.config.base_url,
|
||||
api_key=self.config.api_key.get_secret_value(),
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class CerebrasImplConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"base_url": DEFAULT_BASE_URL,
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ class DatabricksImplConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.DATABRICKS_URL}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN}",
|
||||
url: str = "${env.DATABRICKS_URL:=}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class GeminiConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="gemini",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="gemini_api_key",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,11 +10,11 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"gemini/gemini-1.5-flash",
|
||||
"gemini/gemini-1.5-pro",
|
||||
"gemini/gemini-2.0-flash",
|
||||
"gemini/gemini-2.5-flash",
|
||||
"gemini/gemini-2.5-pro",
|
||||
"gemini-1.5-flash",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-pro",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
|
@ -23,7 +23,7 @@ MODEL_ENTRIES = (
|
|||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="gemini/text-embedding-004",
|
||||
provider_model_id="text-embedding-004",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768, "context_length": 2048},
|
||||
),
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class GroqConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.groq.com",
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="groq",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="groq_api_key",
|
||||
)
|
||||
|
|
@ -96,7 +97,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
tool_choice = "required"
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id.replace("groq/", ""),
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
|
|
|
|||
|
|
@ -14,19 +14,19 @@ SAFETY_MODELS_ENTRIES = []
|
|||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama3-8b-8192",
|
||||
"llama3-8b-8192",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"groq/llama-3.1-8b-instant",
|
||||
"llama-3.1-8b-instant",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama3-70b-8192",
|
||||
"llama3-70b-8192",
|
||||
CoreModelId.llama3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-3.3-70b-versatile",
|
||||
"llama-3.3-70b-versatile",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
# Groq only contains a preview version for llama-3.2-3b
|
||||
|
|
@ -34,23 +34,15 @@ MODEL_ENTRIES = [
|
|||
# to pass the test fixture
|
||||
# TODO(aidand): Replace this with a stable model once Groq supports it
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-3.2-3b-preview",
|
||||
"llama-3.2-3b-preview",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-4-scout-17b-16e-instruct",
|
||||
"meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
"meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="meta_llama",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="llama_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ class OllamaInferenceAdapter(
|
|||
]
|
||||
for m in response.models:
|
||||
# kill embedding models since we don't know dimensions for them
|
||||
if m.details.family in ["bert"]:
|
||||
if "bert" in m.details.family:
|
||||
continue
|
||||
models.append(
|
||||
Model(
|
||||
|
|
@ -420,9 +420,6 @@ class OllamaInferenceAdapter(
|
|||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
|
||||
if model.provider_resource_id is None:
|
||||
raise ValueError("Model provider_resource_id cannot be None")
|
||||
|
||||
if model.model_type == ModelType.embedding:
|
||||
response = await self.client.list()
|
||||
if model.provider_resource_id not in [m.model for m in response.models]:
|
||||
|
|
@ -433,9 +430,9 @@ class OllamaInferenceAdapter(
|
|||
# - models not currently running are run by the ollama server as needed
|
||||
response = await self.client.list()
|
||||
available_models = [m.model for m in response.models]
|
||||
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
|
||||
if provider_resource_id is None:
|
||||
provider_resource_id = model.provider_resource_id
|
||||
|
||||
provider_resource_id = model.provider_resource_id
|
||||
assert provider_resource_id is not None # mypy
|
||||
if provider_resource_id not in available_models:
|
||||
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
||||
if provider_resource_id in available_models_latest:
|
||||
|
|
@ -443,7 +440,9 @@ class OllamaInferenceAdapter(
|
|||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
)
|
||||
return model
|
||||
raise UnsupportedModelError(model.provider_resource_id, available_models)
|
||||
raise UnsupportedModelError(provider_resource_id, available_models)
|
||||
|
||||
# mutating this should be considered an anti-pattern
|
||||
model.provider_resource_id = provider_resource_id
|
||||
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -24,9 +24,19 @@ class OpenAIConfig(BaseModel):
|
|||
default=None,
|
||||
description="API key for OpenAI models",
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="https://api.openai.com/v1",
|
||||
description="Base URL for OpenAI API",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(
|
||||
cls,
|
||||
api_key: str = "${env.OPENAI_API_KEY:=}",
|
||||
base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="openai",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="openai_api_key",
|
||||
)
|
||||
|
|
@ -64,9 +65,9 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
"""
|
||||
Get the OpenAI API base URL.
|
||||
|
||||
Returns the standard OpenAI API base URL for direct OpenAI API calls.
|
||||
Returns the OpenAI API base URL from the configuration.
|
||||
"""
|
||||
return "https://api.openai.com/v1"
|
||||
return self.config.base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class SambaNovaImplConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.sambanova.ai/v1",
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -9,49 +9,20 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
]
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.1-8B-Instruct",
|
||||
"Meta-Llama-3.1-8B-Instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.1-405B-Instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.2-1B-Instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.2-3B-Instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.3-70B-Instruct",
|
||||
"Meta-Llama-3.3-70B-Instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Llama-3.2-11B-Vision-Instruct",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Llama-3.2-90B-Vision-Instruct",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Llama-4-Scout-17B-16E-Instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
|
||||
"Llama-4-Maverick-17B-128E-Instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
|
|
@ -182,6 +182,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="sambanova",
|
||||
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
|
||||
provider_data_api_key_field="sambanova_api_key",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class TGIImplConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.TGI_URL}",
|
||||
url: str = "${env.TGI_URL:=}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -305,6 +305,8 @@ class _HfAdapter(
|
|||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
if not config.url:
|
||||
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.url,
|
||||
|
|
|
|||
|
|
@ -27,5 +27,5 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
|
|||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.together.xyz/v1",
|
||||
"api_key": "${env.TOGETHER_API_KEY}",
|
||||
"api_key": "${env.TOGETHER_API_KEY:=}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,15 +69,9 @@ MODEL_ENTRIES = [
|
|||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
additional_aliases=[
|
||||
"together/meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
],
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
additional_aliases=[
|
||||
"together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
],
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
|
|
@ -299,7 +299,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self.client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
if not self.config.url:
|
||||
raise ValueError(
|
||||
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
|
||||
)
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return self.config.refresh_models
|
||||
|
|
@ -337,9 +340,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
if not self.config.url:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
|
||||
|
||||
client = self._create_client() if self.client is None else self.client
|
||||
_ = [m async for m in client.models.list()] # Ensure the client is initialized
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
|
|
@ -355,11 +355,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
if self.client is not None:
|
||||
return
|
||||
|
||||
if not self.config.url:
|
||||
raise ValueError(
|
||||
"You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
|
||||
)
|
||||
|
||||
log.info(f"Initializing vLLM client with base_url={self.config.url}")
|
||||
self.client = self._create_client()
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class SambaNovaSafetyConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.sambanova.ai/v1",
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -68,11 +68,23 @@ class LiteLLMOpenAIMixin(
|
|||
def __init__(
|
||||
self,
|
||||
model_entries,
|
||||
litellm_provider_name: str,
|
||||
api_key_from_config: str | None,
|
||||
provider_data_api_key_field: str,
|
||||
openai_compat_api_base: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the LiteLLMOpenAIMixin.
|
||||
|
||||
:param model_entries: The model entries to register.
|
||||
:param api_key_from_config: The API key to use from the config.
|
||||
:param provider_data_api_key_field: The field in the provider data that contains the API key.
|
||||
:param litellm_provider_name: The name of the provider, used for model lookups.
|
||||
:param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
|
||||
"""
|
||||
ModelRegistryHelper.__init__(self, model_entries)
|
||||
|
||||
self.litellm_provider_name = litellm_provider_name
|
||||
self.api_key_from_config = api_key_from_config
|
||||
self.provider_data_api_key_field = provider_data_api_key_field
|
||||
self.api_base = openai_compat_api_base
|
||||
|
|
@ -91,7 +103,11 @@ class LiteLLMOpenAIMixin(
|
|||
def get_litellm_model_name(self, model_id: str) -> str:
|
||||
# users may be using openai/ prefix in their model names. the openai/models.py did this by default.
|
||||
# model_id.startswith("openai/") is for backwards compatibility.
|
||||
return "openai/" + model_id if self.is_openai_compat and not model_id.startswith("openai/") else model_id
|
||||
return (
|
||||
f"{self.litellm_provider_name}/{model_id}"
|
||||
if self.is_openai_compat and not model_id.startswith(self.litellm_provider_name)
|
||||
else model_id
|
||||
)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
|
@ -421,3 +437,17 @@ class LiteLLMOpenAIMixin(
|
|||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available via LiteLLM for the current
|
||||
provider (self.litellm_provider_name).
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
"""
|
||||
if self.litellm_provider_name not in litellm.models_by_provider:
|
||||
logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
|
||||
return False
|
||||
|
||||
return model in litellm.models_by_provider[self.litellm_provider_name]
|
||||
|
|
|
|||
|
|
@ -50,7 +50,8 @@ def build_hf_repo_model_entry(
|
|||
additional_aliases: list[str] | None = None,
|
||||
) -> ProviderModelEntry:
|
||||
aliases = [
|
||||
get_huggingface_repo(model_descriptor),
|
||||
# NOTE: avoid HF aliases because they _cannot_ be unique across providers
|
||||
# get_huggingface_repo(model_descriptor),
|
||||
]
|
||||
if additional_aliases:
|
||||
aliases.extend(additional_aliases)
|
||||
|
|
@ -75,7 +76,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
__provider_id__: str
|
||||
|
||||
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
|
||||
self.model_entries = model_entries
|
||||
self.allowed_models = allowed_models
|
||||
|
||||
self.alias_to_provider_id_map = {}
|
||||
self.provider_id_to_llama_model_map = {}
|
||||
for entry in model_entries:
|
||||
|
|
@ -98,7 +101,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
continue
|
||||
models.append(
|
||||
Model(
|
||||
model_id=id,
|
||||
identifier=id,
|
||||
provider_resource_id=entry.provider_model_id,
|
||||
model_type=ModelType.llm,
|
||||
metadata=entry.metadata,
|
||||
|
|
@ -185,8 +188,8 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
# TODO: should we block unregistering base supported provider model IDs?
|
||||
if model_id not in self.alias_to_provider_id_map:
|
||||
raise ValueError(f"Model id '{model_id}' is not registered.")
|
||||
|
||||
del self.alias_to_provider_id_map[model_id]
|
||||
# model_id is the identifier, not the provider_resource_id
|
||||
# unfortunately, this ID can be of the form provider_id/model_id which
|
||||
# we never registered. TODO: fix this by significantly rewriting
|
||||
# registration and registry helper
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -3,96 +3,50 @@ distribution_spec:
|
|||
description: CI tests for Llama Stack
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
|
||||
provider_type: remote::cerebras
|
||||
- provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
|
||||
provider_type: remote::ollama
|
||||
- provider_id: ${env.ENABLE_VLLM:=__disabled__}
|
||||
provider_type: remote::vllm
|
||||
- provider_id: ${env.ENABLE_TGI:=__disabled__}
|
||||
provider_type: remote::tgi
|
||||
- provider_id: ${env.ENABLE_HF_SERVERLESS:=__disabled__}
|
||||
provider_type: remote::hf::serverless
|
||||
- provider_id: ${env.ENABLE_HF_ENDPOINT:=__disabled__}
|
||||
provider_type: remote::hf::endpoint
|
||||
- provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
|
||||
provider_type: remote::fireworks
|
||||
- provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
|
||||
provider_type: remote::together
|
||||
- provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
|
||||
provider_type: remote::bedrock
|
||||
- provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
|
||||
provider_type: remote::databricks
|
||||
- provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_type: remote::nvidia
|
||||
- provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_type: remote::runpod
|
||||
- provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_type: remote::openai
|
||||
- provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__}
|
||||
provider_type: remote::anthropic
|
||||
- provider_id: ${env.ENABLE_GEMINI:=__disabled__}
|
||||
provider_type: remote::gemini
|
||||
- provider_id: ${env.ENABLE_GROQ:=__disabled__}
|
||||
provider_type: remote::groq
|
||||
- provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::llama-openai-compat
|
||||
- provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
|
||||
provider_type: remote::sambanova
|
||||
- provider_id: ${env.ENABLE_PASSTHROUGH:=__disabled__}
|
||||
provider_type: remote::passthrough
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
- provider_type: remote::cerebras
|
||||
- provider_type: remote::ollama
|
||||
- provider_type: remote::vllm
|
||||
- provider_type: remote::tgi
|
||||
- provider_type: remote::fireworks
|
||||
- provider_type: remote::together
|
||||
- provider_type: remote::bedrock
|
||||
- provider_type: remote::nvidia
|
||||
- provider_type: remote::openai
|
||||
- provider_type: remote::anthropic
|
||||
- provider_type: remote::gemini
|
||||
- provider_type: remote::groq
|
||||
- provider_type: remote::sambanova
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: ${env.ENABLE_FAISS:=faiss}
|
||||
provider_type: inline::faiss
|
||||
- provider_id: ${env.ENABLE_SQLITE_VEC:=__disabled__}
|
||||
provider_type: inline::sqlite-vec
|
||||
- provider_id: ${env.ENABLE_MILVUS:=__disabled__}
|
||||
provider_type: inline::milvus
|
||||
- provider_id: ${env.ENABLE_CHROMADB:=__disabled__}
|
||||
provider_type: remote::chromadb
|
||||
- provider_id: ${env.ENABLE_PGVECTOR:=__disabled__}
|
||||
provider_type: remote::pgvector
|
||||
- provider_type: inline::faiss
|
||||
- provider_type: inline::sqlite-vec
|
||||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
files:
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
- provider_type: inline::localfs
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
- provider_type: inline::llama-guard
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
post_training:
|
||||
- provider_id: huggingface
|
||||
provider_type: inline::huggingface
|
||||
- provider_type: inline::huggingface
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
- provider_type: remote::huggingface
|
||||
- provider_type: inline::localfs
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
- provider_type: inline::basic
|
||||
- provider_type: inline::llm-as-judge
|
||||
- provider_type: inline::braintrust
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
image_type: conda
|
||||
image_name: ci-tests
|
||||
additional_pip_packages:
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -4,48 +4,31 @@ distribution_spec:
|
|||
container
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: tgi
|
||||
provider_type: remote::tgi
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
- provider_type: remote::tgi
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
- provider_id: chromadb
|
||||
provider_type: remote::chromadb
|
||||
- provider_id: pgvector
|
||||
provider_type: remote::pgvector
|
||||
- provider_type: inline::faiss
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
- provider_type: inline::llama-guard
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
- provider_type: remote::huggingface
|
||||
- provider_type: inline::localfs
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
- provider_type: inline::basic
|
||||
- provider_type: inline::llm-as-judge
|
||||
- provider_type: inline::braintrust
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
image_type: conda
|
||||
image_name: dell
|
||||
additional_pip_packages:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BuildProvider,
|
||||
ModelInput,
|
||||
Provider,
|
||||
ShieldInput,
|
||||
|
|
@ -20,31 +21,31 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
|
|||
def get_distribution_template() -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": [
|
||||
Provider(provider_id="tgi", provider_type="remote::tgi"),
|
||||
Provider(provider_id="sentence-transformers", provider_type="inline::sentence-transformers"),
|
||||
BuildProvider(provider_type="remote::tgi"),
|
||||
BuildProvider(provider_type="inline::sentence-transformers"),
|
||||
],
|
||||
"vector_io": [
|
||||
Provider(provider_id="faiss", provider_type="inline::faiss"),
|
||||
Provider(provider_id="chromadb", provider_type="remote::chromadb"),
|
||||
Provider(provider_id="pgvector", provider_type="remote::pgvector"),
|
||||
BuildProvider(provider_type="inline::faiss"),
|
||||
BuildProvider(provider_type="remote::chromadb"),
|
||||
BuildProvider(provider_type="remote::pgvector"),
|
||||
],
|
||||
"safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")],
|
||||
"agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
|
||||
"telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
|
||||
"eval": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
|
||||
"safety": [BuildProvider(provider_type="inline::llama-guard")],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"eval": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"datasetio": [
|
||||
Provider(provider_id="huggingface", provider_type="remote::huggingface"),
|
||||
Provider(provider_id="localfs", provider_type="inline::localfs"),
|
||||
BuildProvider(provider_type="remote::huggingface"),
|
||||
BuildProvider(provider_type="inline::localfs"),
|
||||
],
|
||||
"scoring": [
|
||||
Provider(provider_id="basic", provider_type="inline::basic"),
|
||||
Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge"),
|
||||
Provider(provider_id="braintrust", provider_type="inline::braintrust"),
|
||||
BuildProvider(provider_type="inline::basic"),
|
||||
BuildProvider(provider_type="inline::llm-as-judge"),
|
||||
BuildProvider(provider_type="inline::braintrust"),
|
||||
],
|
||||
"tool_runtime": [
|
||||
Provider(provider_id="brave-search", provider_type="remote::brave-search"),
|
||||
Provider(provider_id="tavily-search", provider_type="remote::tavily-search"),
|
||||
Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
],
|
||||
}
|
||||
name = "dell"
|
||||
|
|
|
|||
|
|
@ -3,48 +3,31 @@ distribution_spec:
|
|||
description: Use Meta Reference for running LLM inference
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
- provider_id: chromadb
|
||||
provider_type: remote::chromadb
|
||||
- provider_id: pgvector
|
||||
provider_type: remote::pgvector
|
||||
- provider_type: inline::faiss
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
- provider_type: inline::llama-guard
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
- provider_type: remote::huggingface
|
||||
- provider_type: inline::localfs
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
- provider_type: inline::basic
|
||||
- provider_type: inline::llm-as-judge
|
||||
- provider_type: inline::braintrust
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
image_type: conda
|
||||
image_name: meta-reference-gpu
|
||||
additional_pip_packages:
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from pathlib import Path
|
|||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BuildProvider,
|
||||
ModelInput,
|
||||
Provider,
|
||||
ShieldInput,
|
||||
|
|
@ -25,91 +26,30 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
|
|||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="inline::meta-reference",
|
||||
)
|
||||
],
|
||||
"inference": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"vector_io": [
|
||||
Provider(
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
),
|
||||
Provider(
|
||||
provider_id="chromadb",
|
||||
provider_type="remote::chromadb",
|
||||
),
|
||||
Provider(
|
||||
provider_id="pgvector",
|
||||
provider_type="remote::pgvector",
|
||||
),
|
||||
],
|
||||
"safety": [
|
||||
Provider(
|
||||
provider_id="llama-guard",
|
||||
provider_type="inline::llama-guard",
|
||||
)
|
||||
],
|
||||
"agents": [
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="inline::meta-reference",
|
||||
)
|
||||
],
|
||||
"telemetry": [
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="inline::meta-reference",
|
||||
)
|
||||
],
|
||||
"eval": [
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="inline::meta-reference",
|
||||
)
|
||||
BuildProvider(provider_type="inline::faiss"),
|
||||
BuildProvider(provider_type="remote::chromadb"),
|
||||
BuildProvider(provider_type="remote::pgvector"),
|
||||
],
|
||||
"safety": [BuildProvider(provider_type="inline::llama-guard")],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"eval": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"datasetio": [
|
||||
Provider(
|
||||
provider_id="huggingface",
|
||||
provider_type="remote::huggingface",
|
||||
),
|
||||
Provider(
|
||||
provider_id="localfs",
|
||||
provider_type="inline::localfs",
|
||||
),
|
||||
BuildProvider(provider_type="remote::huggingface"),
|
||||
BuildProvider(provider_type="inline::localfs"),
|
||||
],
|
||||
"scoring": [
|
||||
Provider(
|
||||
provider_id="basic",
|
||||
provider_type="inline::basic",
|
||||
),
|
||||
Provider(
|
||||
provider_id="llm-as-judge",
|
||||
provider_type="inline::llm-as-judge",
|
||||
),
|
||||
Provider(
|
||||
provider_id="braintrust",
|
||||
provider_type="inline::braintrust",
|
||||
),
|
||||
BuildProvider(provider_type="inline::basic"),
|
||||
BuildProvider(provider_type="inline::llm-as-judge"),
|
||||
BuildProvider(provider_type="inline::braintrust"),
|
||||
],
|
||||
"tool_runtime": [
|
||||
Provider(
|
||||
provider_id="brave-search",
|
||||
provider_type="remote::brave-search",
|
||||
),
|
||||
Provider(
|
||||
provider_id="tavily-search",
|
||||
provider_type="remote::tavily-search",
|
||||
),
|
||||
Provider(
|
||||
provider_id="rag-runtime",
|
||||
provider_type="inline::rag-runtime",
|
||||
),
|
||||
Provider(
|
||||
provider_id="model-context-protocol",
|
||||
provider_type="remote::model-context-protocol",
|
||||
),
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
}
|
||||
name = "meta-reference-gpu"
|
||||
|
|
|
|||
|
|
@ -3,37 +3,26 @@ distribution_spec:
|
|||
description: Use NVIDIA NIM for running LLM inference, evaluation and safety
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
- provider_type: remote::nvidia
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
- provider_type: inline::faiss
|
||||
safety:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
- provider_type: remote::nvidia
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
eval:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
- provider_type: remote::nvidia
|
||||
post_training:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
- provider_type: remote::nvidia
|
||||
datasetio:
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
- provider_type: inline::localfs
|
||||
- provider_type: remote::nvidia
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_type: inline::basic
|
||||
tool_runtime:
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_type: inline::rag-runtime
|
||||
image_type: conda
|
||||
image_name: nvidia
|
||||
additional_pip_packages:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
|
||||
from llama_stack.distribution.datatypes import BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput
|
||||
from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig
|
||||
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
|
||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||
|
|
@ -17,65 +17,19 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
|
|||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
)
|
||||
],
|
||||
"vector_io": [
|
||||
Provider(
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
)
|
||||
],
|
||||
"safety": [
|
||||
Provider(
|
||||
provider_id="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
)
|
||||
],
|
||||
"agents": [
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="inline::meta-reference",
|
||||
)
|
||||
],
|
||||
"telemetry": [
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="inline::meta-reference",
|
||||
)
|
||||
],
|
||||
"eval": [
|
||||
Provider(
|
||||
provider_id="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
)
|
||||
],
|
||||
"post_training": [Provider(provider_id="nvidia", provider_type="remote::nvidia", config={})],
|
||||
"inference": [BuildProvider(provider_type="remote::nvidia")],
|
||||
"vector_io": [BuildProvider(provider_type="inline::faiss")],
|
||||
"safety": [BuildProvider(provider_type="remote::nvidia")],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"eval": [BuildProvider(provider_type="remote::nvidia")],
|
||||
"post_training": [BuildProvider(provider_type="remote::nvidia")],
|
||||
"datasetio": [
|
||||
Provider(
|
||||
provider_id="localfs",
|
||||
provider_type="inline::localfs",
|
||||
),
|
||||
Provider(
|
||||
provider_id="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
),
|
||||
],
|
||||
"scoring": [
|
||||
Provider(
|
||||
provider_id="basic",
|
||||
provider_type="inline::basic",
|
||||
)
|
||||
],
|
||||
"tool_runtime": [
|
||||
Provider(
|
||||
provider_id="rag-runtime",
|
||||
provider_type="inline::rag-runtime",
|
||||
)
|
||||
BuildProvider(provider_type="inline::localfs"),
|
||||
BuildProvider(provider_type="remote::nvidia"),
|
||||
],
|
||||
"scoring": [BuildProvider(provider_type="inline::basic")],
|
||||
"tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")],
|
||||
}
|
||||
|
||||
inference_provider = Provider(
|
||||
|
|
|
|||
|
|
@ -89,101 +89,51 @@ models:
|
|||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3-8B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama3-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3-70B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-8b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-70B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-405b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-1b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-3b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-3B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-11b-vision-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-90b-vision-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.3-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 2048
|
||||
context_length: 8192
|
||||
|
|
|
|||
|
|
@ -3,56 +3,35 @@ distribution_spec:
|
|||
description: Distribution for running open benchmarks
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: openai
|
||||
provider_type: remote::openai
|
||||
- provider_id: anthropic
|
||||
provider_type: remote::anthropic
|
||||
- provider_id: gemini
|
||||
provider_type: remote::gemini
|
||||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
- provider_type: remote::openai
|
||||
- provider_type: remote::anthropic
|
||||
- provider_type: remote::gemini
|
||||
- provider_type: remote::groq
|
||||
- provider_type: remote::together
|
||||
vector_io:
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
- provider_id: chromadb
|
||||
provider_type: remote::chromadb
|
||||
- provider_id: pgvector
|
||||
provider_type: remote::pgvector
|
||||
- provider_type: inline::sqlite-vec
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
- provider_type: inline::llama-guard
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
- provider_type: remote::huggingface
|
||||
- provider_type: inline::localfs
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
- provider_type: inline::basic
|
||||
- provider_type: inline::llm-as-judge
|
||||
- provider_type: inline::braintrust
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
image_type: conda
|
||||
image_name: open-benchmark
|
||||
additional_pip_packages:
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
|
|||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BenchmarkInput,
|
||||
BuildProvider,
|
||||
DatasetInput,
|
||||
ModelInput,
|
||||
Provider,
|
||||
|
|
@ -96,33 +97,30 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
|
|||
def get_distribution_template() -> DistributionTemplate:
|
||||
inference_providers, available_models = get_inference_providers()
|
||||
providers = {
|
||||
"inference": inference_providers,
|
||||
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in inference_providers],
|
||||
"vector_io": [
|
||||
Provider(provider_id="sqlite-vec", provider_type="inline::sqlite-vec"),
|
||||
Provider(provider_id="chromadb", provider_type="remote::chromadb"),
|
||||
Provider(provider_id="pgvector", provider_type="remote::pgvector"),
|
||||
BuildProvider(provider_type="inline::sqlite-vec"),
|
||||
BuildProvider(provider_type="remote::chromadb"),
|
||||
BuildProvider(provider_type="remote::pgvector"),
|
||||
],
|
||||
"safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")],
|
||||
"agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
|
||||
"telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
|
||||
"eval": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
|
||||
"safety": [BuildProvider(provider_type="inline::llama-guard")],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"eval": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"datasetio": [
|
||||
Provider(provider_id="huggingface", provider_type="remote::huggingface"),
|
||||
Provider(provider_id="localfs", provider_type="inline::localfs"),
|
||||
BuildProvider(provider_type="remote::huggingface"),
|
||||
BuildProvider(provider_type="inline::localfs"),
|
||||
],
|
||||
"scoring": [
|
||||
Provider(provider_id="basic", provider_type="inline::basic"),
|
||||
Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge"),
|
||||
Provider(provider_id="braintrust", provider_type="inline::braintrust"),
|
||||
BuildProvider(provider_type="inline::basic"),
|
||||
BuildProvider(provider_type="inline::llm-as-judge"),
|
||||
BuildProvider(provider_type="inline::braintrust"),
|
||||
],
|
||||
"tool_runtime": [
|
||||
Provider(provider_id="brave-search", provider_type="remote::brave-search"),
|
||||
Provider(provider_id="tavily-search", provider_type="remote::tavily-search"),
|
||||
Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"),
|
||||
Provider(
|
||||
provider_id="model-context-protocol",
|
||||
provider_type="remote::model-context-protocol",
|
||||
),
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
}
|
||||
name = "open-benchmark"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ providers:
|
|||
provider_type: remote::openai
|
||||
config:
|
||||
api_key: ${env.OPENAI_API_KEY:=}
|
||||
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
|
||||
- provider_id: anthropic
|
||||
provider_type: remote::anthropic
|
||||
config:
|
||||
|
|
@ -33,7 +34,7 @@ providers:
|
|||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY}
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
vector_io:
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
|
|
|
|||
|
|
@ -3,31 +3,21 @@ distribution_spec:
|
|||
description: Quick start template for running Llama Stack with several popular providers
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: vllm-inference
|
||||
provider_type: remote::vllm
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
- provider_type: remote::vllm
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: chromadb
|
||||
provider_type: remote::chromadb
|
||||
- provider_type: remote::chromadb
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
- provider_type: inline::llama-guard
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
image_type: conda
|
||||
image_name: postgres-demo
|
||||
additional_pip_packages:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BuildProvider,
|
||||
ModelInput,
|
||||
Provider,
|
||||
ShieldInput,
|
||||
|
|
@ -34,24 +35,19 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
),
|
||||
]
|
||||
providers = {
|
||||
"inference": inference_providers
|
||||
+ [
|
||||
Provider(provider_id="sentence-transformers", provider_type="inline::sentence-transformers"),
|
||||
"inference": [
|
||||
BuildProvider(provider_type="remote::vllm"),
|
||||
BuildProvider(provider_type="inline::sentence-transformers"),
|
||||
],
|
||||
"vector_io": [
|
||||
Provider(provider_id="chromadb", provider_type="remote::chromadb"),
|
||||
],
|
||||
"safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")],
|
||||
"agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
|
||||
"telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
|
||||
"vector_io": [BuildProvider(provider_type="remote::chromadb")],
|
||||
"safety": [BuildProvider(provider_type="inline::llama-guard")],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"tool_runtime": [
|
||||
Provider(provider_id="brave-search", provider_type="remote::brave-search"),
|
||||
Provider(provider_id="tavily-search", provider_type="remote::tavily-search"),
|
||||
Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"),
|
||||
Provider(
|
||||
provider_id="model-context-protocol",
|
||||
provider_type="remote::model-context-protocol",
|
||||
),
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
}
|
||||
name = "postgres-demo"
|
||||
|
|
|
|||
|
|
@ -3,96 +3,50 @@ distribution_spec:
|
|||
description: Quick start template for running Llama Stack with several popular providers
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
|
||||
provider_type: remote::cerebras
|
||||
- provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
|
||||
provider_type: remote::ollama
|
||||
- provider_id: ${env.ENABLE_VLLM:=__disabled__}
|
||||
provider_type: remote::vllm
|
||||
- provider_id: ${env.ENABLE_TGI:=__disabled__}
|
||||
provider_type: remote::tgi
|
||||
- provider_id: ${env.ENABLE_HF_SERVERLESS:=__disabled__}
|
||||
provider_type: remote::hf::serverless
|
||||
- provider_id: ${env.ENABLE_HF_ENDPOINT:=__disabled__}
|
||||
provider_type: remote::hf::endpoint
|
||||
- provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
|
||||
provider_type: remote::fireworks
|
||||
- provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
|
||||
provider_type: remote::together
|
||||
- provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
|
||||
provider_type: remote::bedrock
|
||||
- provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
|
||||
provider_type: remote::databricks
|
||||
- provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_type: remote::nvidia
|
||||
- provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_type: remote::runpod
|
||||
- provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_type: remote::openai
|
||||
- provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__}
|
||||
provider_type: remote::anthropic
|
||||
- provider_id: ${env.ENABLE_GEMINI:=__disabled__}
|
||||
provider_type: remote::gemini
|
||||
- provider_id: ${env.ENABLE_GROQ:=__disabled__}
|
||||
provider_type: remote::groq
|
||||
- provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::llama-openai-compat
|
||||
- provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
|
||||
provider_type: remote::sambanova
|
||||
- provider_id: ${env.ENABLE_PASSTHROUGH:=__disabled__}
|
||||
provider_type: remote::passthrough
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
- provider_type: remote::cerebras
|
||||
- provider_type: remote::ollama
|
||||
- provider_type: remote::vllm
|
||||
- provider_type: remote::tgi
|
||||
- provider_type: remote::fireworks
|
||||
- provider_type: remote::together
|
||||
- provider_type: remote::bedrock
|
||||
- provider_type: remote::nvidia
|
||||
- provider_type: remote::openai
|
||||
- provider_type: remote::anthropic
|
||||
- provider_type: remote::gemini
|
||||
- provider_type: remote::groq
|
||||
- provider_type: remote::sambanova
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: ${env.ENABLE_FAISS:=faiss}
|
||||
provider_type: inline::faiss
|
||||
- provider_id: ${env.ENABLE_SQLITE_VEC:=__disabled__}
|
||||
provider_type: inline::sqlite-vec
|
||||
- provider_id: ${env.ENABLE_MILVUS:=__disabled__}
|
||||
provider_type: inline::milvus
|
||||
- provider_id: ${env.ENABLE_CHROMADB:=__disabled__}
|
||||
provider_type: remote::chromadb
|
||||
- provider_id: ${env.ENABLE_PGVECTOR:=__disabled__}
|
||||
provider_type: remote::pgvector
|
||||
- provider_type: inline::faiss
|
||||
- provider_type: inline::sqlite-vec
|
||||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
files:
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
- provider_type: inline::localfs
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
- provider_type: inline::llama-guard
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
post_training:
|
||||
- provider_id: huggingface
|
||||
provider_type: inline::huggingface
|
||||
- provider_type: inline::huggingface
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
- provider_type: remote::huggingface
|
||||
- provider_type: inline::localfs
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
- provider_type: inline::basic
|
||||
- provider_type: inline::llm-as-judge
|
||||
- provider_type: inline::braintrust
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
image_type: conda
|
||||
image_name: starter
|
||||
additional_pip_packages:
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -7,19 +7,19 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ModelInput,
|
||||
BuildProvider,
|
||||
Provider,
|
||||
ProviderSpec,
|
||||
ShieldInput,
|
||||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.providers.datatypes import RemoteProviderSpec
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.milvus.config import (
|
||||
MilvusVectorIOConfig,
|
||||
|
|
@ -28,117 +28,17 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
|||
SQLiteVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.registry.inference import available_providers
|
||||
from llama_stack.providers.remote.inference.anthropic.models import (
|
||||
MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.bedrock.models import (
|
||||
MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.cerebras.models import (
|
||||
MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.databricks.databricks import (
|
||||
MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.fireworks.models import (
|
||||
MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.gemini.models import (
|
||||
MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.groq.models import (
|
||||
MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.nvidia.models import (
|
||||
MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.openai.models import (
|
||||
MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.runpod.runpod import (
|
||||
MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.sambanova.models import (
|
||||
MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.together.models import (
|
||||
MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||
PGVectorVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||
from llama_stack.templates.template import (
|
||||
DistributionTemplate,
|
||||
RunConfigSettings,
|
||||
get_model_registry,
|
||||
get_shield_registry,
|
||||
)
|
||||
|
||||
|
||||
def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
|
||||
"""Get model entries for a specific provider type."""
|
||||
model_entries_map = {
|
||||
"openai": OPENAI_MODEL_ENTRIES,
|
||||
"fireworks": FIREWORKS_MODEL_ENTRIES,
|
||||
"together": TOGETHER_MODEL_ENTRIES,
|
||||
"anthropic": ANTHROPIC_MODEL_ENTRIES,
|
||||
"gemini": GEMINI_MODEL_ENTRIES,
|
||||
"groq": GROQ_MODEL_ENTRIES,
|
||||
"sambanova": SAMBANOVA_MODEL_ENTRIES,
|
||||
"cerebras": CEREBRAS_MODEL_ENTRIES,
|
||||
"bedrock": BEDROCK_MODEL_ENTRIES,
|
||||
"databricks": DATABRICKS_MODEL_ENTRIES,
|
||||
"nvidia": NVIDIA_MODEL_ENTRIES,
|
||||
"runpod": RUNPOD_MODEL_ENTRIES,
|
||||
}
|
||||
|
||||
# Special handling for providers with dynamic model entries
|
||||
if provider_type == "ollama":
|
||||
return [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": "${env.OLLAMA_EMBEDDING_DIMENSION:=384}",
|
||||
},
|
||||
),
|
||||
]
|
||||
elif provider_type == "vllm":
|
||||
return [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.VLLM_INFERENCE_MODEL:=__disabled__}",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
|
||||
return model_entries_map.get(provider_type, [])
|
||||
|
||||
|
||||
def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
|
||||
"""Get model entries for a specific provider type."""
|
||||
safety_model_entries_map = {
|
||||
"ollama": [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
return safety_model_entries_map.get(provider_type, [])
|
||||
|
||||
|
||||
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
|
||||
"""Get configuration for a provider using its adapter's config class."""
|
||||
config_class = instantiate_class_type(provider_spec.config_class)
|
||||
|
|
@ -149,40 +49,48 @@ def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
|
|||
return {}
|
||||
|
||||
|
||||
def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
|
||||
all_providers = available_providers()
|
||||
ENABLED_INFERENCE_PROVIDERS = [
|
||||
"ollama",
|
||||
"vllm",
|
||||
"tgi",
|
||||
"fireworks",
|
||||
"together",
|
||||
"gemini",
|
||||
"groq",
|
||||
"sambanova",
|
||||
"anthropic",
|
||||
"openai",
|
||||
"cerebras",
|
||||
"nvidia",
|
||||
"bedrock",
|
||||
]
|
||||
|
||||
# Filter out inline providers and watsonx - the starter distro only exposes remote providers
|
||||
INFERENCE_PROVIDER_IDS = {
|
||||
"vllm": "${env.VLLM_URL:+vllm}",
|
||||
"tgi": "${env.TGI_URL:+tgi}",
|
||||
"cerebras": "${env.CEREBRAS_API_KEY:+cerebras}",
|
||||
"nvidia": "${env.NVIDIA_API_KEY:+nvidia}",
|
||||
}
|
||||
|
||||
|
||||
def get_remote_inference_providers() -> list[Provider]:
|
||||
# Filter out inline providers and some others - the starter distro only exposes remote providers
|
||||
remote_providers = [
|
||||
provider
|
||||
for provider in all_providers
|
||||
# TODO: re-add once the Python 3.13 issue is fixed
|
||||
# discussion: https://github.com/meta-llama/llama-stack/pull/2327#discussion_r2156883828
|
||||
if hasattr(provider, "adapter") and provider.adapter.adapter_type != "watsonx"
|
||||
for provider in available_providers()
|
||||
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
|
||||
]
|
||||
|
||||
providers = []
|
||||
available_models = {}
|
||||
|
||||
inference_providers = []
|
||||
for provider_spec in remote_providers:
|
||||
provider_type = provider_spec.adapter.adapter_type
|
||||
|
||||
# Build the environment variable name for enabling this provider
|
||||
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
|
||||
model_entries = _get_model_entries_for_provider(provider_type)
|
||||
if provider_type in INFERENCE_PROVIDER_IDS:
|
||||
provider_id = INFERENCE_PROVIDER_IDS[provider_type]
|
||||
else:
|
||||
provider_id = provider_type.replace("-", "_").replace("::", "_")
|
||||
config = _get_config_for_provider(provider_spec)
|
||||
providers.append(
|
||||
(
|
||||
f"${{env.{env_var}:=__disabled__}}",
|
||||
provider_type,
|
||||
model_entries,
|
||||
config,
|
||||
)
|
||||
)
|
||||
available_models[f"${{env.{env_var}:=__disabled__}}"] = model_entries
|
||||
|
||||
inference_providers = []
|
||||
for provider_id, provider_type, model_entries, config in providers:
|
||||
inference_providers.append(
|
||||
Provider(
|
||||
provider_id=provider_id,
|
||||
|
|
@ -190,154 +98,43 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
|
|||
config=config,
|
||||
)
|
||||
)
|
||||
available_models[provider_id] = model_entries
|
||||
return inference_providers, available_models
|
||||
|
||||
|
||||
# build a list of shields for all possible providers
|
||||
def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
|
||||
available_models = {}
|
||||
for provider in providers:
|
||||
provider_type = provider.provider_type.split("::")[1]
|
||||
safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
|
||||
if len(safety_model_entries) == 0:
|
||||
continue
|
||||
|
||||
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
|
||||
provider_id = f"${{env.{env_var}:=__disabled__}}"
|
||||
|
||||
available_models[provider_id] = safety_model_entries
|
||||
|
||||
return available_models
|
||||
return inference_providers
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
remote_inference_providers, available_models = get_remote_inference_providers()
|
||||
|
||||
remote_inference_providers = get_remote_inference_providers()
|
||||
name = "starter"
|
||||
|
||||
vector_io_providers = [
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_FAISS:=faiss}",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_SQLITE_VEC:=__disabled__}",
|
||||
provider_type="inline::sqlite-vec",
|
||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_MILVUS:=__disabled__}",
|
||||
provider_type="inline::milvus",
|
||||
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_CHROMADB:=__disabled__}",
|
||||
provider_type="remote::chromadb",
|
||||
config=ChromaVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}/",
|
||||
url="${env.CHROMADB_URL:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
|
||||
provider_type="remote::pgvector",
|
||||
config=PGVectorVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
db="${env.PGVECTOR_DB:=}",
|
||||
user="${env.PGVECTOR_USER:=}",
|
||||
password="${env.PGVECTOR_PASSWORD:=}",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
providers = {
|
||||
"inference": remote_inference_providers
|
||||
+ [
|
||||
Provider(
|
||||
provider_id="sentence-transformers",
|
||||
provider_type="inline::sentence-transformers",
|
||||
)
|
||||
],
|
||||
"vector_io": vector_io_providers,
|
||||
"files": [
|
||||
Provider(
|
||||
provider_id="localfs",
|
||||
provider_type="inline::localfs",
|
||||
)
|
||||
],
|
||||
"safety": [
|
||||
Provider(
|
||||
provider_id="llama-guard",
|
||||
provider_type="inline::llama-guard",
|
||||
)
|
||||
],
|
||||
"agents": [
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="inline::meta-reference",
|
||||
)
|
||||
],
|
||||
"telemetry": [
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="inline::meta-reference",
|
||||
)
|
||||
],
|
||||
"post_training": [
|
||||
Provider(
|
||||
provider_id="huggingface",
|
||||
provider_type="inline::huggingface",
|
||||
)
|
||||
],
|
||||
"eval": [
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="inline::meta-reference",
|
||||
)
|
||||
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]
|
||||
+ [BuildProvider(provider_type="inline::sentence-transformers")],
|
||||
"vector_io": [
|
||||
BuildProvider(provider_type="inline::faiss"),
|
||||
BuildProvider(provider_type="inline::sqlite-vec"),
|
||||
BuildProvider(provider_type="inline::milvus"),
|
||||
BuildProvider(provider_type="remote::chromadb"),
|
||||
BuildProvider(provider_type="remote::pgvector"),
|
||||
],
|
||||
"files": [BuildProvider(provider_type="inline::localfs")],
|
||||
"safety": [BuildProvider(provider_type="inline::llama-guard")],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"post_training": [BuildProvider(provider_type="inline::huggingface")],
|
||||
"eval": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"datasetio": [
|
||||
Provider(
|
||||
provider_id="huggingface",
|
||||
provider_type="remote::huggingface",
|
||||
),
|
||||
Provider(
|
||||
provider_id="localfs",
|
||||
provider_type="inline::localfs",
|
||||
),
|
||||
BuildProvider(provider_type="remote::huggingface"),
|
||||
BuildProvider(provider_type="inline::localfs"),
|
||||
],
|
||||
"scoring": [
|
||||
Provider(
|
||||
provider_id="basic",
|
||||
provider_type="inline::basic",
|
||||
),
|
||||
Provider(
|
||||
provider_id="llm-as-judge",
|
||||
provider_type="inline::llm-as-judge",
|
||||
),
|
||||
Provider(
|
||||
provider_id="braintrust",
|
||||
provider_type="inline::braintrust",
|
||||
),
|
||||
BuildProvider(provider_type="inline::basic"),
|
||||
BuildProvider(provider_type="inline::llm-as-judge"),
|
||||
BuildProvider(provider_type="inline::braintrust"),
|
||||
],
|
||||
"tool_runtime": [
|
||||
Provider(
|
||||
provider_id="brave-search",
|
||||
provider_type="remote::brave-search",
|
||||
),
|
||||
Provider(
|
||||
provider_id="tavily-search",
|
||||
provider_type="remote::tavily-search",
|
||||
),
|
||||
Provider(
|
||||
provider_id="rag-runtime",
|
||||
provider_type="inline::rag-runtime",
|
||||
),
|
||||
Provider(
|
||||
provider_id="model-context-protocol",
|
||||
provider_type="remote::model-context-protocol",
|
||||
),
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
}
|
||||
files_provider = Provider(
|
||||
|
|
@ -346,15 +143,10 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
embedding_provider = Provider(
|
||||
provider_id="${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}",
|
||||
provider_id="sentence-transformers",
|
||||
provider_type="inline::sentence-transformers",
|
||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||
)
|
||||
post_training_provider = Provider(
|
||||
provider_id="huggingface",
|
||||
provider_type="inline::huggingface",
|
||||
config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
|
|
@ -365,19 +157,14 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
provider_id="rag-runtime",
|
||||
),
|
||||
]
|
||||
embedding_model = ModelInput(
|
||||
model_id="all-MiniLM-L6-v2",
|
||||
provider_id=embedding_provider.provider_id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
)
|
||||
|
||||
default_models, ids_conflict_in_models = get_model_registry(available_models)
|
||||
|
||||
available_safety_models = get_safety_models_for_providers(remote_inference_providers)
|
||||
shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
|
||||
default_shields = [
|
||||
# if the
|
||||
ShieldInput(
|
||||
shield_id="llama-guard",
|
||||
provider_id="${env.SAFETY_MODEL:+llama-guard}",
|
||||
provider_shield_id="${env.SAFETY_MODEL:=}",
|
||||
),
|
||||
]
|
||||
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
|
|
@ -386,20 +173,51 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
container_image=None,
|
||||
template_path=None,
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
additional_pip_packages=PostgresSqlStoreConfig.pip_packages(),
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": remote_inference_providers + [embedding_provider],
|
||||
"vector_io": vector_io_providers,
|
||||
"vector_io": [
|
||||
Provider(
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="sqlite-vec",
|
||||
provider_type="inline::sqlite-vec",
|
||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.MILVUS_URL:+milvus}",
|
||||
provider_type="inline::milvus",
|
||||
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.CHROMADB_URL:+chromadb}",
|
||||
provider_type="remote::chromadb",
|
||||
config=ChromaVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}/",
|
||||
url="${env.CHROMADB_URL:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.PGVECTOR_DB:+pgvector}",
|
||||
provider_type="remote::pgvector",
|
||||
config=PGVectorVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
db="${env.PGVECTOR_DB:=}",
|
||||
user="${env.PGVECTOR_USER:=}",
|
||||
password="${env.PGVECTOR_PASSWORD:=}",
|
||||
),
|
||||
),
|
||||
],
|
||||
"files": [files_provider],
|
||||
"post_training": [post_training_provider],
|
||||
},
|
||||
default_models=[embedding_model] + default_models,
|
||||
default_models=[],
|
||||
default_tool_groups=default_tool_groups,
|
||||
# TODO: add a way to enable/disable shields on the fly
|
||||
default_shields=shields,
|
||||
default_shields=default_shields,
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
|
|
@ -443,17 +261,5 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"http://localhost:11434",
|
||||
"Ollama URL",
|
||||
),
|
||||
"OLLAMA_INFERENCE_MODEL": (
|
||||
"",
|
||||
"Optional Ollama Inference Model to register on startup",
|
||||
),
|
||||
"OLLAMA_EMBEDDING_MODEL": (
|
||||
"",
|
||||
"Optional Ollama Embedding Model to register on startup",
|
||||
),
|
||||
"OLLAMA_EMBEDDING_DIMENSION": (
|
||||
"384",
|
||||
"Ollama Embedding Dimension",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from llama_stack.distribution.datatypes import (
|
|||
Api,
|
||||
BenchmarkInput,
|
||||
BuildConfig,
|
||||
BuildProvider,
|
||||
DatasetInput,
|
||||
DistributionSpec,
|
||||
ModelInput,
|
||||
|
|
@ -183,7 +184,7 @@ class RunConfigSettings(BaseModel):
|
|||
def run_config(
|
||||
self,
|
||||
name: str,
|
||||
providers: dict[str, list[Provider]],
|
||||
providers: dict[str, list[BuildProvider]],
|
||||
container_image: str | None = None,
|
||||
) -> dict:
|
||||
provider_registry = get_provider_registry()
|
||||
|
|
@ -199,7 +200,7 @@ class RunConfigSettings(BaseModel):
|
|||
api = Api(api_str)
|
||||
if provider.provider_type not in provider_registry[api]:
|
||||
raise ValueError(f"Unknown provider type: {provider.provider_type} for API: {api_str}")
|
||||
|
||||
provider_id = provider.provider_type.split("::")[-1]
|
||||
config_class = provider_registry[api][provider.provider_type].config_class
|
||||
assert config_class is not None, (
|
||||
f"No config class for provider type: {provider.provider_type} for API: {api_str}"
|
||||
|
|
@ -210,10 +211,14 @@ class RunConfigSettings(BaseModel):
|
|||
config = config_class.sample_run_config(__distro_dir__=f"~/.llama/distributions/{name}")
|
||||
else:
|
||||
config = {}
|
||||
|
||||
provider.config = config
|
||||
# Convert Provider object to dict for YAML serialization
|
||||
provider_configs[api_str].append(provider.model_dump(exclude_none=True))
|
||||
# BuildProvider does not have a config attribute; skip assignment
|
||||
provider_configs[api_str].append(
|
||||
Provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=provider.provider_type,
|
||||
config=config,
|
||||
).model_dump(exclude_none=True)
|
||||
)
|
||||
# Get unique set of APIs from providers
|
||||
apis = sorted(providers.keys())
|
||||
|
||||
|
|
@ -257,7 +262,8 @@ class DistributionTemplate(BaseModel):
|
|||
description: str
|
||||
distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]
|
||||
|
||||
providers: dict[str, list[Provider]]
|
||||
# Now uses BuildProvider for build config, not Provider
|
||||
providers: dict[str, list[BuildProvider]]
|
||||
run_configs: dict[str, RunConfigSettings]
|
||||
template_path: Path | None = None
|
||||
|
||||
|
|
@ -295,11 +301,9 @@ class DistributionTemplate(BaseModel):
|
|||
for api, providers in self.providers.items():
|
||||
build_providers[api] = []
|
||||
for provider in providers:
|
||||
# Create a minimal provider object with only essential build information
|
||||
build_provider = Provider(
|
||||
provider_id=provider.provider_id,
|
||||
# Create a minimal build provider object with only essential build information
|
||||
build_provider = BuildProvider(
|
||||
provider_type=provider.provider_type,
|
||||
config={}, # Empty config for build
|
||||
module=provider.module,
|
||||
)
|
||||
build_providers[api].append(build_provider)
|
||||
|
|
@ -323,50 +327,52 @@ class DistributionTemplate(BaseModel):
|
|||
providers_str = ", ".join(f"`{p.provider_type}`" for p in providers)
|
||||
providers_table += f"| {api} | {providers_str} |\n"
|
||||
|
||||
template = self.template_path.read_text()
|
||||
comment = "<!-- This file was auto-generated by distro_codegen.py, please edit source -->\n"
|
||||
orphantext = "---\norphan: true\n---\n"
|
||||
if self.template_path is not None:
|
||||
template = self.template_path.read_text()
|
||||
comment = "<!-- This file was auto-generated by distro_codegen.py, please edit source -->\n"
|
||||
orphantext = "---\norphan: true\n---\n"
|
||||
|
||||
if template.startswith(orphantext):
|
||||
template = template.replace(orphantext, orphantext + comment)
|
||||
else:
|
||||
template = comment + template
|
||||
if template.startswith(orphantext):
|
||||
template = template.replace(orphantext, orphantext + comment)
|
||||
else:
|
||||
template = comment + template
|
||||
|
||||
# Render template with rich-generated table
|
||||
env = jinja2.Environment(
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
# NOTE: autoescape is required to prevent XSS attacks
|
||||
autoescape=True,
|
||||
)
|
||||
template = env.from_string(template)
|
||||
# Render template with rich-generated table
|
||||
env = jinja2.Environment(
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
# NOTE: autoescape is required to prevent XSS attacks
|
||||
autoescape=True,
|
||||
)
|
||||
template = env.from_string(template)
|
||||
|
||||
default_models = []
|
||||
if self.available_models_by_provider:
|
||||
has_multiple_providers = len(self.available_models_by_provider.keys()) > 1
|
||||
for provider_id, model_entries in self.available_models_by_provider.items():
|
||||
for model_entry in model_entries:
|
||||
doc_parts = []
|
||||
if model_entry.aliases:
|
||||
doc_parts.append(f"aliases: {', '.join(model_entry.aliases)}")
|
||||
if has_multiple_providers:
|
||||
doc_parts.append(f"provider: {provider_id}")
|
||||
default_models = []
|
||||
if self.available_models_by_provider:
|
||||
has_multiple_providers = len(self.available_models_by_provider.keys()) > 1
|
||||
for provider_id, model_entries in self.available_models_by_provider.items():
|
||||
for model_entry in model_entries:
|
||||
doc_parts = []
|
||||
if model_entry.aliases:
|
||||
doc_parts.append(f"aliases: {', '.join(model_entry.aliases)}")
|
||||
if has_multiple_providers:
|
||||
doc_parts.append(f"provider: {provider_id}")
|
||||
|
||||
default_models.append(
|
||||
DefaultModel(
|
||||
model_id=model_entry.provider_model_id,
|
||||
doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
|
||||
default_models.append(
|
||||
DefaultModel(
|
||||
model_id=model_entry.provider_model_id,
|
||||
doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return template.render(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
providers=self.providers,
|
||||
providers_table=providers_table,
|
||||
run_config_env_vars=self.run_config_env_vars,
|
||||
default_models=default_models,
|
||||
)
|
||||
return template.render(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
providers=self.providers,
|
||||
providers_table=providers_table,
|
||||
run_config_env_vars=self.run_config_env_vars,
|
||||
default_models=default_models,
|
||||
)
|
||||
return ""
|
||||
|
||||
def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None:
|
||||
def enum_representer(dumper, data):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
|
||||
from llama_stack.distribution.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
|
|
@ -19,86 +19,28 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
|
|||
def get_distribution_template() -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="watsonx",
|
||||
provider_type="remote::watsonx",
|
||||
),
|
||||
Provider(
|
||||
provider_id="sentence-transformers",
|
||||
provider_type="inline::sentence-transformers",
|
||||
),
|
||||
],
|
||||
"vector_io": [
|
||||
Provider(
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
)
|
||||
],
|
||||
"safety": [
|
||||
Provider(
|
||||
provider_id="llama-guard",
|
||||
provider_type="inline::llama-guard",
|
||||
)
|
||||
],
|
||||
"agents": [
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="inline::meta-reference",
|
||||
)
|
||||
],
|
||||
"telemetry": [
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="inline::meta-reference",
|
||||
)
|
||||
],
|
||||
"eval": [
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="inline::meta-reference",
|
||||
)
|
||||
BuildProvider(provider_type="remote::watsonx"),
|
||||
BuildProvider(provider_type="inline::sentence-transformers"),
|
||||
],
|
||||
"vector_io": [BuildProvider(provider_type="inline::faiss")],
|
||||
"safety": [BuildProvider(provider_type="inline::llama-guard")],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"eval": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"datasetio": [
|
||||
Provider(
|
||||
provider_id="huggingface",
|
||||
provider_type="remote::huggingface",
|
||||
),
|
||||
Provider(
|
||||
provider_id="localfs",
|
||||
provider_type="inline::localfs",
|
||||
),
|
||||
BuildProvider(provider_type="remote::huggingface"),
|
||||
BuildProvider(provider_type="inline::localfs"),
|
||||
],
|
||||
"scoring": [
|
||||
Provider(
|
||||
provider_id="basic",
|
||||
provider_type="inline::basic",
|
||||
),
|
||||
Provider(
|
||||
provider_id="llm-as-judge",
|
||||
provider_type="inline::llm-as-judge",
|
||||
),
|
||||
Provider(
|
||||
provider_id="braintrust",
|
||||
provider_type="inline::braintrust",
|
||||
),
|
||||
BuildProvider(provider_type="inline::basic"),
|
||||
BuildProvider(provider_type="inline::llm-as-judge"),
|
||||
BuildProvider(provider_type="inline::braintrust"),
|
||||
],
|
||||
"tool_runtime": [
|
||||
Provider(
|
||||
provider_id="brave-search",
|
||||
provider_type="remote::brave-search",
|
||||
),
|
||||
Provider(
|
||||
provider_id="tavily-search",
|
||||
provider_type="remote::tavily-search",
|
||||
),
|
||||
Provider(
|
||||
provider_id="rag-runtime",
|
||||
provider_type="inline::rag-runtime",
|
||||
),
|
||||
Provider(
|
||||
provider_id="model-context-protocol",
|
||||
provider_type="remote::model-context-protocol",
|
||||
),
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue