Xi Yan 2025-03-18 21:49:11 -07:00
parent 011fd59a29
commit 8a576d7d72
24 changed files with 297 additions and 2525 deletions

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -187,9 +187,7 @@ a default SQLite store will be used.""",
     benchmarks: List[BenchmarkInput] = Field(default_factory=list)
     tool_groups: List[ToolGroupInput] = Field(default_factory=list)
-    logging: Optional[LoggingConfig] = Field(
-        default=None, description="Configuration for Llama Stack Logging"
-    )
+    logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging")
 
     server: ServerConfig = Field(
         default_factory=ServerConfig,
@@ -200,9 +198,7 @@ a default SQLite store will be used.""",
 class BuildConfig(BaseModel):
     version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
 
-    distribution_spec: DistributionSpec = Field(
-        description="The distribution spec to build including API providers. "
-    )
+    distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
     image_type: str = Field(
         default="conda",
         description="Type of package to build (conda | container | venv)",


@@ -47,14 +47,9 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
 def providable_apis() -> List[Api]:
-    routing_table_apis = {
-        x.routing_table_api for x in builtin_automatically_routed_apis()
-    }
+    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
     return [
-        api
-        for api in Api
-        if api not in routing_table_apis
-        and api not in [Api.inspect, Api.providers, Api.benchmarks]
+        api for api in Api if api not in routing_table_apis and api not in [Api.inspect, Api.providers, Api.benchmarks]
     ]


@@ -103,9 +103,7 @@ async def resolve_impls(
     2. Sorting them in dependency order.
     3. Instantiating them with required dependencies.
     """
-    routing_table_apis = {
-        x.routing_table_api for x in builtin_automatically_routed_apis()
-    }
+    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
     router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
 
     providers_with_specs = validate_and_prepare_providers(
@@ -113,9 +111,7 @@ async def resolve_impls(
     )
 
     apis_to_serve = run_config.apis or set(
-        list(providers_with_specs.keys())
-        + [x.value for x in routing_table_apis]
-        + [x.value for x in router_apis]
+        list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
     )
 
     providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
@@ -180,23 +176,17 @@ def validate_and_prepare_providers(
     for api_str, providers in run_config.providers.items():
         api = Api(api_str)
         if api in routing_table_apis:
-            raise ValueError(
-                f"Provider for `{api_str}` is automatically provided and cannot be overridden"
-            )
+            raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
 
         specs = {}
         for provider in providers:
             if not provider.provider_id or provider.provider_id == "__disabled__":
-                logger.warning(
-                    f"Provider `{provider.provider_type}` for API `{api}` is disabled"
-                )
+                logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
                 continue
 
             validate_provider(provider, api, provider_registry)
             p = provider_registry[api][provider.provider_type]
-            p.deps__ = [a.value for a in p.api_dependencies] + [
-                a.value for a in p.optional_api_dependencies
-            ]
+            p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
 
             spec = ProviderWithSpec(spec=p, **provider.model_dump())
             specs[provider.provider_id] = spec
@@ -206,14 +196,10 @@ def validate_and_prepare_providers(
     return providers_with_specs
 
 
-def validate_provider(
-    provider: Provider, api: Api, provider_registry: ProviderRegistry
-):
+def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
     """Validates if the provider is allowed and handles deprecations."""
     if provider.provider_type not in provider_registry[api]:
-        raise ValueError(
-            f"Provider `{provider.provider_type}` is not available for API `{api}`"
-        )
+        raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
 
     p = provider_registry[api][provider.provider_type]
     if p.deprecation_error:
@@ -288,9 +274,7 @@ async def instantiate_providers(
 ) -> Dict:
     """Instantiates providers asynchronously while managing dependencies."""
     impls: Dict[Api, Any] = {}
-    inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {
-        f"inner-{x.value}": {} for x in router_apis
-    }
+    inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
     for api_str, provider in sorted_providers:
         deps = {a: impls[a] for a in provider.spec.api_dependencies}
         for a in provider.spec.optional_api_dependencies:
@@ -299,9 +283,7 @@ async def instantiate_providers(
         inner_impls = {}
         if isinstance(provider.spec, RoutingTableProviderSpec):
-            inner_impls = inner_impls_by_provider_id[
-                f"inner-{provider.spec.router_api.value}"
-            ]
+            inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
 
         impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
@@ -359,9 +341,7 @@ async def instantiate_provider(
     provider_spec = provider.spec
     if not hasattr(provider_spec, "module"):
-        raise AttributeError(
-            f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute"
-        )
+        raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
 
     module = importlib.import_module(provider_spec.module)
     args = []
@@ -398,10 +378,7 @@ async def instantiate_provider(
     # TODO: check compliance for special tool groups
     # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
     check_protocol_compliance(impl, protocols[provider_spec.api])
-    if (
-        not isinstance(provider_spec, AutoRoutedProviderSpec)
-        and provider_spec.api in additional_protocols
-    ):
+    if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
         additional_api, _, _ = additional_protocols[provider_spec.api]
         check_protocol_compliance(impl, additional_api)
@@ -429,19 +406,12 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
             obj_params = set(obj_sig.parameters)
             obj_params.discard("self")
             if not (proto_params <= obj_params):
-                logger.error(
-                    f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
-                )
+                logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
                 missing_methods.append((name, "signature_mismatch"))
             else:
                 # Check if the method is actually implemented in the class
-                method_owner = next(
-                    (cls for cls in mro if name in cls.__dict__), None
-                )
-                if (
-                    method_owner is None
-                    or method_owner.__name__ == protocol.__name__
-                ):
+                method_owner = next((cls for cls in mro if name in cls.__dict__), None)
+                if method_owner is None or method_owner.__name__ == protocol.__name__:
                     missing_methods.append((name, "not_actually_implemented"))
 
     if missing_methods:


@@ -44,9 +44,7 @@ async def get_routing_table_impl(
     return impl
 
 
-async def get_auto_router_impl(
-    api: Api, routing_table: RoutingTable, deps: Dict[str, Any]
-) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,


@@ -8,19 +8,12 @@ import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from llama_stack.apis.common.content_types import (
+    URL,
     InterleavedContent,
     InterleavedContentItem,
-    URL,
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
-from llama_stack.apis.eval import (
-    BenchmarkConfig,
-    Eval,
-    EvaluateResponse,
-    Job,
-    JobStatus,
-)
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
@@ -42,12 +35,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
-from llama_stack.apis.scoring import (
-    ScoreBatchResponse,
-    ScoreResponse,
-    Scoring,
-    ScoringFnParams,
-)
 from llama_stack.apis.shields import Shield
 from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
 from llama_stack.apis.tools import (
@@ -94,9 +81,7 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logger.debug(
-            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
-        )
+        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -114,9 +99,7 @@ class VectorIORouter(VectorIO):
         logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
-            vector_db_id, chunks, ttl_seconds
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
 
     async def query_chunks(
         self,
@@ -125,9 +108,7 @@ class VectorIORouter(VectorIO):
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
-            vector_db_id, query, params
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
 
 
 class InferenceRouter(Inference):
@@ -164,9 +145,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, metadata, model_type
-        )
+        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
 
     def _construct_metrics(
         self,
@@ -220,16 +199,11 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> List[MetricInResponse]:
-        metrics = self._construct_metrics(
-            prompt_tokens, completion_tokens, total_tokens, model
-        )
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
-        return [
-            MetricInResponse(metric=metric.metric, value=metric.value)
-            for metric in metrics
-        ]
+        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
 
     async def _count_tokens(
         self,
@@ -254,9 +228,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
@@ -266,19 +238,12 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if (
-                tool_prompt_format
-                and tool_prompt_format != tool_config.tool_prompt_format
-            ):
-                raise ValueError(
-                    "tool_prompt_format and tool_config.tool_prompt_format must match"
-                )
+            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
+                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
         else:
             params = {}
             if tool_choice:
@@ -296,14 +261,9 @@ class InferenceRouter(Inference):
             pass
         else:
             # verify tool_choice is one of the tools
-            tool_names = [
-                t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
-                for t in tools
-            ]
+            tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
             if tool_config.tool_choice not in tool_names:
-                raise ValueError(
-                    f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
-                )
+                raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
 
         params = dict(
             model_id=model_id,
@@ -318,25 +278,17 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(
-            messages, tool_config.tool_prompt_format
-        )
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
 
         if stream:
 
             async def stream_generator():
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.progress
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.complete
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
                         completion_tokens = await self._count_tokens(
                             [
                                 CompletionMessage(
@@ -353,11 +305,7 @@ class InferenceRouter(Inference):
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = (
-                            metrics
-                            if chunk.metrics is None
-                            else chunk.metrics + metrics
-                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk
 
             return stream_generator()
@@ -374,9 +322,7 @@ class InferenceRouter(Inference):
                 total_tokens,
                 model,
             )
-            response.metrics = (
-                metrics if response.metrics is None else response.metrics + metrics
-            )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
             return response
 
     async def completion(
@@ -397,9 +343,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -419,11 +363,7 @@ class InferenceRouter(Inference):
                 async for chunk in await provider.completion(**params):
                     if hasattr(chunk, "delta"):
                         completion_text += chunk.delta
-                    if (
-                        hasattr(chunk, "stop_reason")
-                        and chunk.stop_reason
-                        and self.telemetry
-                    ):
+                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                         completion_tokens = await self._count_tokens(completion_text)
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                         metrics = await self._compute_and_log_token_usage(
@@ -432,11 +372,7 @@ class InferenceRouter(Inference):
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = (
-                            metrics
-                            if chunk.metrics is None
-                            else chunk.metrics + metrics
-                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk
 
             return stream_generator()
@@ -450,9 +386,7 @@ class InferenceRouter(Inference):
                 total_tokens,
                 model,
             )
-            response.metrics = (
-                metrics if response.metrics is None else response.metrics + metrics
-            )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
             return response
 
     async def embeddings(
@@ -468,9 +402,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
-            raise ValueError(
-                f"Model '{model_id}' is an LLM model and does not support embeddings"
-            )
+            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
@@ -504,9 +436,7 @@ class SafetyRouter(Safety):
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
         logger.debug(f"SafetyRouter.register_shield: {shield_id}")
-        return await self.routing_table.register_shield(
-            shield_id, provider_shield_id, provider_id, params
-        )
+        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
 
     async def run_shield(
         self,
@@ -607,9 +537,9 @@ class ToolRuntimeRouter(ToolRuntime):
             logger.debug(
                 f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
             )
-            return await self.routing_table.get_provider_impl(
-                "insert_into_memory"
-            ).insert(documents, vector_db_id, chunk_size_in_tokens)
+            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
+                documents, vector_db_id, chunk_size_in_tokens
+            )
 
     def __init__(
         self,
@@ -642,6 +572,4 @@ class ToolRuntimeRouter(ToolRuntime):
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
-        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
-            tool_group_id, mcp_endpoint
-        )
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)


@@ -12,7 +12,6 @@ from pydantic import TypeAdapter
 
 from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
 from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.datasets import (
     Dataset,
     DatasetPurpose,
@@ -95,9 +94,7 @@ class CommonRoutingTableImpl(RoutingTable):
         self.dist_registry = dist_registry
 
     async def initialize(self) -> None:
-        async def add_objects(
-            objs: List[RoutableObjectWithProvider], provider_id: str, cls
-        ) -> None:
+        async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
             for obj in objs:
                 if cls is None:
                     obj.provider_id = provider_id
@@ -126,9 +123,7 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()
 
-    def get_provider_impl(
-        self, routing_key: str, provider_id: Optional[str] = None
-    ) -> Any:
+    def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
         def apiname_object():
             if isinstance(self, ModelsRoutingTable):
                 return ("Inference", "model")
@@ -164,9 +159,7 @@ class CommonRoutingTableImpl(RoutingTable):
 
         raise ValueError(f"Provider not found for `{routing_key}`")
 
-    async def get_object_by_identifier(
-        self, type: str, identifier: str
-    ) -> Optional[RoutableObjectWithProvider]:
+    async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
         # Get from disk registry
         obj = await self.dist_registry.get(type, identifier)
         if not obj:
@@ -176,13 +169,9 @@ class CommonRoutingTableImpl(RoutingTable):
 
     async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
         await self.dist_registry.delete(obj.type, obj.identifier)
-        await unregister_object_from_provider(
-            obj, self.impls_by_provider_id[obj.provider_id]
-        )
+        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
 
-    async def register_object(
-        self, obj: RoutableObjectWithProvider
-    ) -> RoutableObjectWithProvider:
+    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
         # if provider_id is not specified, pick an arbitrary one from existing entries
         if not obj.provider_id and len(self.impls_by_provider_id) > 0:
             obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -240,9 +229,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         if model_type is None:
             model_type = ModelType.llm
         if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
-            raise ValueError(
-                "Embedding model must have an embedding dimension in its metadata"
-            )
+            raise ValueError("Embedding model must have an embedding dimension in its metadata")
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
@@ -262,9 +249,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
 
 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
     async def list_shields(self) -> ListShieldsResponse:
-        return ListShieldsResponse(
-            data=await self.get_all_with_type(ResourceType.shield.value)
-        )
+        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
 
     async def get_shield(self, identifier: str) -> Shield:
         shield = await self.get_object_by_identifier("shield", identifier)
@@ -329,18 +314,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
                     f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                 )
             else:
-                raise ValueError(
-                    "No provider available. Please configure a vector_io provider."
-                )
+                raise ValueError("No provider available. Please configure a vector_io provider.")
         model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
             raise ValueError(f"Model {embedding_model} not found")
         if model.model_type != ModelType.embedding:
             raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:
-            raise ValueError(
-                f"Model {embedding_model} does not have an embedding dimension"
-            )
+            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
         vector_db_data = {
             "identifier": vector_db_id,
             "type": ResourceType.vector_db.value,
@@ -362,9 +343,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
 
 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
     async def list_datasets(self) -> ListDatasetsResponse:
-        return ListDatasetsResponse(
-            data=await self.get_all_with_type(ResourceType.dataset.value)
-        )
+        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
 
     async def get_dataset(self, dataset_id: str) -> Dataset:
         dataset = await self.get_object_by_identifier("dataset", dataset_id)
@@ -447,9 +426,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
             # TODO (xiyan): we will need a way to infer provider_id for evaluation
             # keep it as meta-reference for now
             if len(self.impls_by_provider_id) == 0:
-                raise ValueError(
-                    "No evaluation providers available. Please configure an evaluation provider."
-                )
+                raise ValueError("No evaluation providers available. Please configure an evaluation provider.")
             provider_id = list(self.impls_by_provider_id.keys())[0]
 
         benchmark = Benchmark(
@@ -491,12 +468,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         args: Optional[Dict[str, Any]] = None,
     ) -> None:
         tools = []
-        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
-            toolgroup_id, mcp_endpoint
-        )
-        tool_host = (
-            ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
-        )
+        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
+        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
 
         for tool_def in tool_defs:
             tools.append(


@@ -105,9 +105,7 @@ class EnvVarError(Exception):
     def __init__(self, var_name: str, path: str = ""):
         self.var_name = var_name
         self.path = path
-        super().__init__(
-            f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
-        )
+        super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
 
 
 def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
@@ -198,9 +196,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
         if not key:
             raise ValueError(f"Empty key in environment variable pair: {env_pair}")
         if not all(c.isalnum() or c == "_" for c in key):
-            raise ValueError(
-                f"Key must contain only alphanumeric characters and underscores: {key}"
-            )
+            raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
         return key, value
     except ValueError as e:
         raise ValueError(
@@ -213,20 +209,14 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
 async def construct_stack(
     run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
 ) -> Dict[Api, Any]:
-    dist_registry, _ = await create_dist_registry(
-        run_config.metadata_store, run_config.image_name
-    )
-    impls = await resolve_impls(
-        run_config, provider_registry or get_provider_registry(), dist_registry
-    )
+    dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
+    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
     await register_resources(run_config, impls)
     return impls
 
 
 def get_stack_run_config_from_template(template: str) -> StackRunConfig:
-    template_path = (
-        importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
-    )
+    template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
 
     with importlib.resources.as_file(template_path) as path:
         if not path.exists():
@@ -269,9 +259,7 @@ def run_config_from_adhoc_config_spec(
 
         # call method "sample_run_config" on the provider spec config class
         provider_config_type = instantiate_class_type(provider_spec.config_class)
-        provider_config = replace_env_vars(
-            provider_config_type.sample_run_config(__distro_dir__=distro_dir)
-        )
+        provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
 
         provider_configs_by_api[api_str] = [
             Provider(


@@ -22,9 +22,7 @@ class LlamaStackApi:
             },
         )
 
-    def run_scoring(
-        self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]
-    ):
+    def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]):
         """Run scoring on a single row"""
         if not scoring_params:
             scoring_params = {fn_id: None for fn_id in scoring_function_ids}


@@ -4,12 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from streamlit_option_menu import option_menu
+
 from llama_stack.distribution.ui.page.distribution.datasets import datasets
 from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
 from llama_stack.distribution.ui.page.distribution.models import models
 from llama_stack.distribution.ui.page.distribution.shields import shields
 from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
-from streamlit_option_menu import option_menu
 
 
 def resources_page():


@@ -20,7 +20,5 @@ context_entity_recall_fn_def = ScoringFn(
     provider_id="braintrust",
     provider_resource_id="context-entity-recall",
     return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
 )


@@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.eval,
provider_type="inline::meta-reference",
pip_packages=["tree_sitter"],
module="llama_stack.providers.inline.eval.meta_reference",
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
Api.scoring,
Api.inference,
Api.agents,
],
),
]


@@ -1,49 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::basic",
pip_packages=[],
module="llama_stack.providers.inline.scoring.basic",
config_class="llama_stack.providers.inline.scoring.basic.BasicScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
),
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::llm-as-judge",
pip_packages=[],
module="llama_stack.providers.inline.scoring.llm_as_judge",
config_class="llama_stack.providers.inline.scoring.llm_as_judge.LlmAsJudgeScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
Api.inference,
],
),
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::braintrust",
pip_packages=["autoevals", "openai"],
module="llama_stack.providers.inline.scoring.braintrust",
config_class="llama_stack.providers.inline.scoring.braintrust.BraintrustScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
provider_data_validator="llama_stack.providers.inline.scoring.braintrust.BraintrustProviderDataValidator",
),
]


@@ -75,29 +75,31 @@ VALID_SCHEMAS_FOR_EVAL = [
 ]
 
 
-def get_valid_schemas(api_str: str):
-    if api_str == Api.scoring.value:
-        return VALID_SCHEMAS_FOR_SCORING
-    elif api_str == Api.eval.value:
-        return VALID_SCHEMAS_FOR_EVAL
-    else:
-        raise ValueError(f"Invalid API string: {api_str}")
+# TODO(xiyan): add this back
+
+# def get_valid_schemas(api_str: str):
+#     if api_str == Api.scoring.value:
+#         return VALID_SCHEMAS_FOR_SCORING
+#     elif api_str == Api.eval.value:
+#         return VALID_SCHEMAS_FOR_EVAL
+#     else:
+#         raise ValueError(f"Invalid API string: {api_str}")
 
 
-def validate_dataset_schema(
-    dataset_schema: Dict[str, Any],
-    expected_schemas: List[Dict[str, Any]],
-):
-    if dataset_schema not in expected_schemas:
-        raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")
+# def validate_dataset_schema(
+#     dataset_schema: Dict[str, Any],
+#     expected_schemas: List[Dict[str, Any]],
+# ):
+#     if dataset_schema not in expected_schemas:
+#         raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")
 
 
-def validate_row_schema(
-    input_row: Dict[str, Any],
-    expected_schemas: List[Dict[str, Any]],
-):
-    for schema in expected_schemas:
-        if all(key in input_row for key in schema):
-            return
-    raise ValueError(f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}")
+# def validate_row_schema(
+#     input_row: Dict[str, Any],
+#     expected_schemas: List[Dict[str, Any]],
+# ):
+#     for schema in expected_schemas:
+#         if all(key in input_row for key in schema):
+#             return
+#     raise ValueError(f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}")


@@ -11,8 +11,8 @@ from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOCon
 from llama_stack.providers.remote.inference.bedrock.models import MODEL_ENTRIES
 from llama_stack.templates.template import (
     DistributionTemplate,
-    get_model_registry,
     RunConfigSettings,
+    get_model_registry,
 )


@@ -16,8 +16,8 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
 from llama_stack.providers.remote.inference.cerebras.models import MODEL_ENTRIES
 from llama_stack.templates.template import (
     DistributionTemplate,
-    get_model_registry,
     RunConfigSettings,
+    get_model_registry,
 )


@@ -22,8 +22,8 @@ from llama_stack.providers.remote.inference.fireworks.config import FireworksImp
 from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
 from llama_stack.templates.template import (
     DistributionTemplate,
-    get_model_registry,
     RunConfigSettings,
+    get_model_registry,
 )


@@ -45,8 +45,8 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
 )
 from llama_stack.templates.template import (
     DistributionTemplate,
-    get_model_registry,
     RunConfigSettings,
+    get_model_registry,
 )
@@ -96,10 +96,7 @@ def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
 
 def get_distribution_template() -> DistributionTemplate:
     inference_providers, available_models = get_inference_providers()
     providers = {
-        "inference": (
-            [p.provider_type for p in inference_providers]
-            + ["inline::sentence-transformers"]
-        ),
+        "inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],
@@ -119,9 +116,7 @@ def get_distribution_template() -> DistributionTemplate:
         Provider(
             provider_id="sqlite-vec",
             provider_type="inline::sqlite-vec",
-            config=SQLiteVectorIOConfig.sample_run_config(
-                f"~/.llama/distributions/{name}"
-            ),
+            config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
         ),
         Provider(
             provider_id="${env.ENABLE_CHROMADB+chromadb}",


@@ -21,8 +21,8 @@ from llama_stack.providers.remote.inference.fireworks.config import FireworksImp
 from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
 from llama_stack.templates.template import (
     DistributionTemplate,
-    get_model_registry,
     RunConfigSettings,
+    get_model_registry,
 )


@@ -15,8 +15,8 @@ from llama_stack.providers.remote.inference.groq import GroqConfig
 from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES
 from llama_stack.templates.template import (
     DistributionTemplate,
-    get_model_registry,
     RunConfigSettings,
+    get_model_registry,
 )


@@ -17,8 +17,8 @@ from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
 from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
 from llama_stack.templates.template import (
     DistributionTemplate,
-    get_model_registry,
     RunConfigSettings,
+    get_model_registry,
 )
@@ -87,9 +87,7 @@ def get_distribution_template() -> DistributionTemplate:
                 ]
             },
             default_models=[inference_model, safety_model],
-            default_shields=[
-                ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")
-            ],
+            default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
             default_tool_groups=default_tool_groups,
         ),
     },


@@ -9,7 +9,6 @@ from typing import Dict, List, Tuple
 from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
 from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import (
-    BenchmarkInput,
     DatasetInput,
     ModelInput,
     Provider,
@@ -31,14 +30,12 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
 from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
 from llama_stack.templates.template import (
     DistributionTemplate,
-    get_model_registry,
     RunConfigSettings,
+    get_model_registry,
 )
 
 
-def get_inference_providers() -> (
-    Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]
-):
+def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
     # in this template, we allow each API key to be optional
     providers = [
         (
@@ -119,9 +116,7 @@ def get_distribution_template() -> DistributionTemplate:
         Provider(
             provider_id="sqlite-vec",
             provider_type="inline::sqlite-vec",
-            config=SQLiteVectorIOConfig.sample_run_config(
-                f"~/.llama/distributions/{name}"
-            ),
+            config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
         ),
         Provider(
             provider_id="${env.ENABLE_CHROMADB+chromadb}",


@@ -21,8 +21,8 @@ from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.together.models import MODEL_ENTRIES
 from llama_stack.templates.template import (
     DistributionTemplate,
-    get_model_registry,
     RunConfigSettings,
+    get_model_registry,
 )