mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 18:50:44 +00:00
todos
This commit is contained in:
parent
011fd59a29
commit
8a576d7d72
24 changed files with 297 additions and 2525 deletions
1420
docs/_static/llama-stack-spec.html
vendored
1420
docs/_static/llama-stack-spec.html
vendored
File diff suppressed because it is too large
Load diff
946
docs/_static/llama-stack-spec.yaml
vendored
946
docs/_static/llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
|
@ -187,9 +187,7 @@ a default SQLite store will be used.""",
|
||||||
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
|
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
|
||||||
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
||||||
|
|
||||||
logging: Optional[LoggingConfig] = Field(
|
logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging")
|
||||||
default=None, description="Configuration for Llama Stack Logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
server: ServerConfig = Field(
|
server: ServerConfig = Field(
|
||||||
default_factory=ServerConfig,
|
default_factory=ServerConfig,
|
||||||
|
@ -200,9 +198,7 @@ a default SQLite store will be used.""",
|
||||||
class BuildConfig(BaseModel):
|
class BuildConfig(BaseModel):
|
||||||
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||||
|
|
||||||
distribution_spec: DistributionSpec = Field(
|
distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
|
||||||
description="The distribution spec to build including API providers. "
|
|
||||||
)
|
|
||||||
image_type: str = Field(
|
image_type: str = Field(
|
||||||
default="conda",
|
default="conda",
|
||||||
description="Type of package to build (conda | container | venv)",
|
description="Type of package to build (conda | container | venv)",
|
||||||
|
|
|
@ -47,14 +47,9 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
||||||
|
|
||||||
|
|
||||||
def providable_apis() -> List[Api]:
|
def providable_apis() -> List[Api]:
|
||||||
routing_table_apis = {
|
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
|
||||||
}
|
|
||||||
return [
|
return [
|
||||||
api
|
api for api in Api if api not in routing_table_apis and api not in [Api.inspect, Api.providers, Api.benchmarks]
|
||||||
for api in Api
|
|
||||||
if api not in routing_table_apis
|
|
||||||
and api not in [Api.inspect, Api.providers, Api.benchmarks]
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -103,9 +103,7 @@ async def resolve_impls(
|
||||||
2. Sorting them in dependency order.
|
2. Sorting them in dependency order.
|
||||||
3. Instantiating them with required dependencies.
|
3. Instantiating them with required dependencies.
|
||||||
"""
|
"""
|
||||||
routing_table_apis = {
|
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
|
||||||
}
|
|
||||||
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
|
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
|
||||||
|
|
||||||
providers_with_specs = validate_and_prepare_providers(
|
providers_with_specs = validate_and_prepare_providers(
|
||||||
|
@ -113,9 +111,7 @@ async def resolve_impls(
|
||||||
)
|
)
|
||||||
|
|
||||||
apis_to_serve = run_config.apis or set(
|
apis_to_serve = run_config.apis or set(
|
||||||
list(providers_with_specs.keys())
|
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
|
||||||
+ [x.value for x in routing_table_apis]
|
|
||||||
+ [x.value for x in router_apis]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
|
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
|
||||||
|
@ -180,23 +176,17 @@ def validate_and_prepare_providers(
|
||||||
for api_str, providers in run_config.providers.items():
|
for api_str, providers in run_config.providers.items():
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
if api in routing_table_apis:
|
if api in routing_table_apis:
|
||||||
raise ValueError(
|
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
|
||||||
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
|
|
||||||
)
|
|
||||||
|
|
||||||
specs = {}
|
specs = {}
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
if not provider.provider_id or provider.provider_id == "__disabled__":
|
if not provider.provider_id or provider.provider_id == "__disabled__":
|
||||||
logger.warning(
|
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
||||||
f"Provider `{provider.provider_type}` for API `{api}` is disabled"
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
validate_provider(provider, api, provider_registry)
|
validate_provider(provider, api, provider_registry)
|
||||||
p = provider_registry[api][provider.provider_type]
|
p = provider_registry[api][provider.provider_type]
|
||||||
p.deps__ = [a.value for a in p.api_dependencies] + [
|
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
|
||||||
a.value for a in p.optional_api_dependencies
|
|
||||||
]
|
|
||||||
spec = ProviderWithSpec(spec=p, **provider.model_dump())
|
spec = ProviderWithSpec(spec=p, **provider.model_dump())
|
||||||
specs[provider.provider_id] = spec
|
specs[provider.provider_id] = spec
|
||||||
|
|
||||||
|
@ -206,14 +196,10 @@ def validate_and_prepare_providers(
|
||||||
return providers_with_specs
|
return providers_with_specs
|
||||||
|
|
||||||
|
|
||||||
def validate_provider(
|
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
|
||||||
provider: Provider, api: Api, provider_registry: ProviderRegistry
|
|
||||||
):
|
|
||||||
"""Validates if the provider is allowed and handles deprecations."""
|
"""Validates if the provider is allowed and handles deprecations."""
|
||||||
if provider.provider_type not in provider_registry[api]:
|
if provider.provider_type not in provider_registry[api]:
|
||||||
raise ValueError(
|
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
||||||
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
|
||||||
)
|
|
||||||
|
|
||||||
p = provider_registry[api][provider.provider_type]
|
p = provider_registry[api][provider.provider_type]
|
||||||
if p.deprecation_error:
|
if p.deprecation_error:
|
||||||
|
@ -288,9 +274,7 @@ async def instantiate_providers(
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Instantiates providers asynchronously while managing dependencies."""
|
"""Instantiates providers asynchronously while managing dependencies."""
|
||||||
impls: Dict[Api, Any] = {}
|
impls: Dict[Api, Any] = {}
|
||||||
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {
|
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||||
f"inner-{x.value}": {} for x in router_apis
|
|
||||||
}
|
|
||||||
for api_str, provider in sorted_providers:
|
for api_str, provider in sorted_providers:
|
||||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||||
for a in provider.spec.optional_api_dependencies:
|
for a in provider.spec.optional_api_dependencies:
|
||||||
|
@ -299,9 +283,7 @@ async def instantiate_providers(
|
||||||
|
|
||||||
inner_impls = {}
|
inner_impls = {}
|
||||||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||||
inner_impls = inner_impls_by_provider_id[
|
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
||||||
f"inner-{provider.spec.router_api.value}"
|
|
||||||
]
|
|
||||||
|
|
||||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
|
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
|
||||||
|
|
||||||
|
@ -359,9 +341,7 @@ async def instantiate_provider(
|
||||||
|
|
||||||
provider_spec = provider.spec
|
provider_spec = provider.spec
|
||||||
if not hasattr(provider_spec, "module"):
|
if not hasattr(provider_spec, "module"):
|
||||||
raise AttributeError(
|
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||||
f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute"
|
|
||||||
)
|
|
||||||
|
|
||||||
module = importlib.import_module(provider_spec.module)
|
module = importlib.import_module(provider_spec.module)
|
||||||
args = []
|
args = []
|
||||||
|
@ -398,10 +378,7 @@ async def instantiate_provider(
|
||||||
# TODO: check compliance for special tool groups
|
# TODO: check compliance for special tool groups
|
||||||
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
||||||
check_protocol_compliance(impl, protocols[provider_spec.api])
|
check_protocol_compliance(impl, protocols[provider_spec.api])
|
||||||
if (
|
if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
|
||||||
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
|
||||||
and provider_spec.api in additional_protocols
|
|
||||||
):
|
|
||||||
additional_api, _, _ = additional_protocols[provider_spec.api]
|
additional_api, _, _ = additional_protocols[provider_spec.api]
|
||||||
check_protocol_compliance(impl, additional_api)
|
check_protocol_compliance(impl, additional_api)
|
||||||
|
|
||||||
|
@ -429,19 +406,12 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
||||||
obj_params = set(obj_sig.parameters)
|
obj_params = set(obj_sig.parameters)
|
||||||
obj_params.discard("self")
|
obj_params.discard("self")
|
||||||
if not (proto_params <= obj_params):
|
if not (proto_params <= obj_params):
|
||||||
logger.error(
|
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||||
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
|
|
||||||
)
|
|
||||||
missing_methods.append((name, "signature_mismatch"))
|
missing_methods.append((name, "signature_mismatch"))
|
||||||
else:
|
else:
|
||||||
# Check if the method is actually implemented in the class
|
# Check if the method is actually implemented in the class
|
||||||
method_owner = next(
|
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
|
||||||
(cls for cls in mro if name in cls.__dict__), None
|
if method_owner is None or method_owner.__name__ == protocol.__name__:
|
||||||
)
|
|
||||||
if (
|
|
||||||
method_owner is None
|
|
||||||
or method_owner.__name__ == protocol.__name__
|
|
||||||
):
|
|
||||||
missing_methods.append((name, "not_actually_implemented"))
|
missing_methods.append((name, "not_actually_implemented"))
|
||||||
|
|
||||||
if missing_methods:
|
if missing_methods:
|
||||||
|
|
|
@ -44,9 +44,7 @@ async def get_routing_table_impl(
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
||||||
async def get_auto_router_impl(
|
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
|
||||||
api: Api, routing_table: RoutingTable, deps: Dict[str, Any]
|
|
||||||
) -> Any:
|
|
||||||
from .routers import (
|
from .routers import (
|
||||||
DatasetIORouter,
|
DatasetIORouter,
|
||||||
EvalRouter,
|
EvalRouter,
|
||||||
|
|
|
@ -8,19 +8,12 @@ import time
|
||||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
|
URL,
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
InterleavedContentItem,
|
InterleavedContentItem,
|
||||||
URL,
|
|
||||||
)
|
)
|
||||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||||
from llama_stack.apis.eval import (
|
|
||||||
BenchmarkConfig,
|
|
||||||
Eval,
|
|
||||||
EvaluateResponse,
|
|
||||||
Job,
|
|
||||||
JobStatus,
|
|
||||||
)
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
ChatCompletionResponseEventType,
|
ChatCompletionResponseEventType,
|
||||||
|
@ -42,12 +35,6 @@ from llama_stack.apis.inference import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model, ModelType
|
from llama_stack.apis.models import Model, ModelType
|
||||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||||
from llama_stack.apis.scoring import (
|
|
||||||
ScoreBatchResponse,
|
|
||||||
ScoreResponse,
|
|
||||||
Scoring,
|
|
||||||
ScoringFnParams,
|
|
||||||
)
|
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
@ -94,9 +81,7 @@ class VectorIORouter(VectorIO):
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
provider_vector_db_id: Optional[str] = None,
|
provider_vector_db_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug(
|
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||||
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
|
|
||||||
)
|
|
||||||
await self.routing_table.register_vector_db(
|
await self.routing_table.register_vector_db(
|
||||||
vector_db_id,
|
vector_db_id,
|
||||||
embedding_model,
|
embedding_model,
|
||||||
|
@ -114,9 +99,7 @@ class VectorIORouter(VectorIO):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||||
)
|
)
|
||||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
|
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||||
vector_db_id, chunks, ttl_seconds
|
|
||||||
)
|
|
||||||
|
|
||||||
async def query_chunks(
|
async def query_chunks(
|
||||||
self,
|
self,
|
||||||
|
@ -125,9 +108,7 @@ class VectorIORouter(VectorIO):
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
|
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||||
vector_db_id, query, params
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceRouter(Inference):
|
class InferenceRouter(Inference):
|
||||||
|
@ -164,9 +145,7 @@ class InferenceRouter(Inference):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
||||||
)
|
)
|
||||||
await self.routing_table.register_model(
|
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||||
model_id, provider_model_id, provider_id, metadata, model_type
|
|
||||||
)
|
|
||||||
|
|
||||||
def _construct_metrics(
|
def _construct_metrics(
|
||||||
self,
|
self,
|
||||||
|
@ -220,16 +199,11 @@ class InferenceRouter(Inference):
|
||||||
total_tokens: int,
|
total_tokens: int,
|
||||||
model: Model,
|
model: Model,
|
||||||
) -> List[MetricInResponse]:
|
) -> List[MetricInResponse]:
|
||||||
metrics = self._construct_metrics(
|
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||||
prompt_tokens, completion_tokens, total_tokens, model
|
|
||||||
)
|
|
||||||
if self.telemetry:
|
if self.telemetry:
|
||||||
for metric in metrics:
|
for metric in metrics:
|
||||||
await self.telemetry.log_event(metric)
|
await self.telemetry.log_event(metric)
|
||||||
return [
|
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||||
MetricInResponse(metric=metric.metric, value=metric.value)
|
|
||||||
for metric in metrics
|
|
||||||
]
|
|
||||||
|
|
||||||
async def _count_tokens(
|
async def _count_tokens(
|
||||||
self,
|
self,
|
||||||
|
@ -254,9 +228,7 @@ class InferenceRouter(Inference):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
) -> Union[
|
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||||
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
|
||||||
]:
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||||
)
|
)
|
||||||
|
@ -266,19 +238,12 @@ class InferenceRouter(Inference):
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
if model.model_type == ModelType.embedding:
|
if model.model_type == ModelType.embedding:
|
||||||
raise ValueError(
|
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
|
||||||
)
|
|
||||||
if tool_config:
|
if tool_config:
|
||||||
if tool_choice and tool_choice != tool_config.tool_choice:
|
if tool_choice and tool_choice != tool_config.tool_choice:
|
||||||
raise ValueError("tool_choice and tool_config.tool_choice must match")
|
raise ValueError("tool_choice and tool_config.tool_choice must match")
|
||||||
if (
|
if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
|
||||||
tool_prompt_format
|
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
|
||||||
and tool_prompt_format != tool_config.tool_prompt_format
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"tool_prompt_format and tool_config.tool_prompt_format must match"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
params = {}
|
params = {}
|
||||||
if tool_choice:
|
if tool_choice:
|
||||||
|
@ -296,14 +261,9 @@ class InferenceRouter(Inference):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# verify tool_choice is one of the tools
|
# verify tool_choice is one of the tools
|
||||||
tool_names = [
|
tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
|
||||||
t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
|
|
||||||
for t in tools
|
|
||||||
]
|
|
||||||
if tool_config.tool_choice not in tool_names:
|
if tool_config.tool_choice not in tool_names:
|
||||||
raise ValueError(
|
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
|
||||||
f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
|
|
||||||
)
|
|
||||||
|
|
||||||
params = dict(
|
params = dict(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
@ -318,25 +278,17 @@ class InferenceRouter(Inference):
|
||||||
tool_config=tool_config,
|
tool_config=tool_config,
|
||||||
)
|
)
|
||||||
provider = self.routing_table.get_provider_impl(model_id)
|
provider = self.routing_table.get_provider_impl(model_id)
|
||||||
prompt_tokens = await self._count_tokens(
|
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||||
messages, tool_config.tool_prompt_format
|
|
||||||
)
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
|
||||||
async def stream_generator():
|
async def stream_generator():
|
||||||
completion_text = ""
|
completion_text = ""
|
||||||
async for chunk in await provider.chat_completion(**params):
|
async for chunk in await provider.chat_completion(**params):
|
||||||
if (
|
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||||
chunk.event.event_type
|
|
||||||
== ChatCompletionResponseEventType.progress
|
|
||||||
):
|
|
||||||
if chunk.event.delta.type == "text":
|
if chunk.event.delta.type == "text":
|
||||||
completion_text += chunk.event.delta.text
|
completion_text += chunk.event.delta.text
|
||||||
if (
|
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||||
chunk.event.event_type
|
|
||||||
== ChatCompletionResponseEventType.complete
|
|
||||||
):
|
|
||||||
completion_tokens = await self._count_tokens(
|
completion_tokens = await self._count_tokens(
|
||||||
[
|
[
|
||||||
CompletionMessage(
|
CompletionMessage(
|
||||||
|
@ -353,11 +305,7 @@ class InferenceRouter(Inference):
|
||||||
total_tokens,
|
total_tokens,
|
||||||
model,
|
model,
|
||||||
)
|
)
|
||||||
chunk.metrics = (
|
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||||
metrics
|
|
||||||
if chunk.metrics is None
|
|
||||||
else chunk.metrics + metrics
|
|
||||||
)
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
return stream_generator()
|
return stream_generator()
|
||||||
|
@ -374,9 +322,7 @@ class InferenceRouter(Inference):
|
||||||
total_tokens,
|
total_tokens,
|
||||||
model,
|
model,
|
||||||
)
|
)
|
||||||
response.metrics = (
|
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||||
metrics if response.metrics is None else response.metrics + metrics
|
|
||||||
)
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
|
@ -397,9 +343,7 @@ class InferenceRouter(Inference):
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
if model.model_type == ModelType.embedding:
|
if model.model_type == ModelType.embedding:
|
||||||
raise ValueError(
|
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
|
||||||
)
|
|
||||||
provider = self.routing_table.get_provider_impl(model_id)
|
provider = self.routing_table.get_provider_impl(model_id)
|
||||||
params = dict(
|
params = dict(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
@ -419,11 +363,7 @@ class InferenceRouter(Inference):
|
||||||
async for chunk in await provider.completion(**params):
|
async for chunk in await provider.completion(**params):
|
||||||
if hasattr(chunk, "delta"):
|
if hasattr(chunk, "delta"):
|
||||||
completion_text += chunk.delta
|
completion_text += chunk.delta
|
||||||
if (
|
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
||||||
hasattr(chunk, "stop_reason")
|
|
||||||
and chunk.stop_reason
|
|
||||||
and self.telemetry
|
|
||||||
):
|
|
||||||
completion_tokens = await self._count_tokens(completion_text)
|
completion_tokens = await self._count_tokens(completion_text)
|
||||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||||
metrics = await self._compute_and_log_token_usage(
|
metrics = await self._compute_and_log_token_usage(
|
||||||
|
@ -432,11 +372,7 @@ class InferenceRouter(Inference):
|
||||||
total_tokens,
|
total_tokens,
|
||||||
model,
|
model,
|
||||||
)
|
)
|
||||||
chunk.metrics = (
|
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||||
metrics
|
|
||||||
if chunk.metrics is None
|
|
||||||
else chunk.metrics + metrics
|
|
||||||
)
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
return stream_generator()
|
return stream_generator()
|
||||||
|
@ -450,9 +386,7 @@ class InferenceRouter(Inference):
|
||||||
total_tokens,
|
total_tokens,
|
||||||
model,
|
model,
|
||||||
)
|
)
|
||||||
response.metrics = (
|
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||||
metrics if response.metrics is None else response.metrics + metrics
|
|
||||||
)
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
|
@ -468,9 +402,7 @@ class InferenceRouter(Inference):
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
if model.model_type == ModelType.llm:
|
if model.model_type == ModelType.llm:
|
||||||
raise ValueError(
|
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
|
||||||
f"Model '{model_id}' is an LLM model and does not support embeddings"
|
|
||||||
)
|
|
||||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
|
@ -504,9 +436,7 @@ class SafetyRouter(Safety):
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> Shield:
|
) -> Shield:
|
||||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||||
return await self.routing_table.register_shield(
|
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||||
shield_id, provider_shield_id, provider_id, params
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
|
@ -607,9 +537,9 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||||
)
|
)
|
||||||
return await self.routing_table.get_provider_impl(
|
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||||
"insert_into_memory"
|
documents, vector_db_id, chunk_size_in_tokens
|
||||||
).insert(documents, vector_db_id, chunk_size_in_tokens)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -642,6 +572,4 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||||
) -> List[ToolDef]:
|
) -> List[ToolDef]:
|
||||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
|
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
||||||
tool_group_id, mcp_endpoint
|
|
||||||
)
|
|
||||||
|
|
|
@ -12,7 +12,6 @@ from pydantic import TypeAdapter
|
||||||
|
|
||||||
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.common.type_system import ParamType
|
|
||||||
from llama_stack.apis.datasets import (
|
from llama_stack.apis.datasets import (
|
||||||
Dataset,
|
Dataset,
|
||||||
DatasetPurpose,
|
DatasetPurpose,
|
||||||
|
@ -95,9 +94,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
self.dist_registry = dist_registry
|
self.dist_registry = dist_registry
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
async def add_objects(
|
async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
||||||
objs: List[RoutableObjectWithProvider], provider_id: str, cls
|
|
||||||
) -> None:
|
|
||||||
for obj in objs:
|
for obj in objs:
|
||||||
if cls is None:
|
if cls is None:
|
||||||
obj.provider_id = provider_id
|
obj.provider_id = provider_id
|
||||||
|
@ -126,9 +123,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
for p in self.impls_by_provider_id.values():
|
for p in self.impls_by_provider_id.values():
|
||||||
await p.shutdown()
|
await p.shutdown()
|
||||||
|
|
||||||
def get_provider_impl(
|
def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
|
||||||
self, routing_key: str, provider_id: Optional[str] = None
|
|
||||||
) -> Any:
|
|
||||||
def apiname_object():
|
def apiname_object():
|
||||||
if isinstance(self, ModelsRoutingTable):
|
if isinstance(self, ModelsRoutingTable):
|
||||||
return ("Inference", "model")
|
return ("Inference", "model")
|
||||||
|
@ -164,9 +159,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
|
|
||||||
raise ValueError(f"Provider not found for `{routing_key}`")
|
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||||
|
|
||||||
async def get_object_by_identifier(
|
async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
||||||
self, type: str, identifier: str
|
|
||||||
) -> Optional[RoutableObjectWithProvider]:
|
|
||||||
# Get from disk registry
|
# Get from disk registry
|
||||||
obj = await self.dist_registry.get(type, identifier)
|
obj = await self.dist_registry.get(type, identifier)
|
||||||
if not obj:
|
if not obj:
|
||||||
|
@ -176,13 +169,9 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
|
|
||||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||||
await unregister_object_from_provider(
|
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
|
||||||
obj, self.impls_by_provider_id[obj.provider_id]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def register_object(
|
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
|
||||||
self, obj: RoutableObjectWithProvider
|
|
||||||
) -> RoutableObjectWithProvider:
|
|
||||||
# if provider_id is not specified, pick an arbitrary one from existing entries
|
# if provider_id is not specified, pick an arbitrary one from existing entries
|
||||||
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||||
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
|
@ -240,9 +229,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
if model_type is None:
|
if model_type is None:
|
||||||
model_type = ModelType.llm
|
model_type = ModelType.llm
|
||||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||||
raise ValueError(
|
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||||
"Embedding model must have an embedding dimension in its metadata"
|
|
||||||
)
|
|
||||||
model = Model(
|
model = Model(
|
||||||
identifier=model_id,
|
identifier=model_id,
|
||||||
provider_resource_id=provider_model_id,
|
provider_resource_id=provider_model_id,
|
||||||
|
@ -262,9 +249,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
|
|
||||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
async def list_shields(self) -> ListShieldsResponse:
|
async def list_shields(self) -> ListShieldsResponse:
|
||||||
return ListShieldsResponse(
|
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
|
||||||
data=await self.get_all_with_type(ResourceType.shield.value)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_shield(self, identifier: str) -> Shield:
|
async def get_shield(self, identifier: str) -> Shield:
|
||||||
shield = await self.get_object_by_identifier("shield", identifier)
|
shield = await self.get_object_by_identifier("shield", identifier)
|
||||||
|
@ -329,18 +314,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
|
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("No provider available. Please configure a vector_io provider.")
|
||||||
"No provider available. Please configure a vector_io provider."
|
|
||||||
)
|
|
||||||
model = await self.get_object_by_identifier("model", embedding_model)
|
model = await self.get_object_by_identifier("model", embedding_model)
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"Model {embedding_model} not found")
|
raise ValueError(f"Model {embedding_model} not found")
|
||||||
if model.model_type != ModelType.embedding:
|
if model.model_type != ModelType.embedding:
|
||||||
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
||||||
if "embedding_dimension" not in model.metadata:
|
if "embedding_dimension" not in model.metadata:
|
||||||
raise ValueError(
|
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||||
f"Model {embedding_model} does not have an embedding dimension"
|
|
||||||
)
|
|
||||||
vector_db_data = {
|
vector_db_data = {
|
||||||
"identifier": vector_db_id,
|
"identifier": vector_db_id,
|
||||||
"type": ResourceType.vector_db.value,
|
"type": ResourceType.vector_db.value,
|
||||||
|
@ -362,9 +343,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
|
|
||||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
async def list_datasets(self) -> ListDatasetsResponse:
|
async def list_datasets(self) -> ListDatasetsResponse:
|
||||||
return ListDatasetsResponse(
|
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
|
||||||
data=await self.get_all_with_type(ResourceType.dataset.value)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_dataset(self, dataset_id: str) -> Dataset:
|
async def get_dataset(self, dataset_id: str) -> Dataset:
|
||||||
dataset = await self.get_object_by_identifier("dataset", dataset_id)
|
dataset = await self.get_object_by_identifier("dataset", dataset_id)
|
||||||
|
@ -447,9 +426,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
||||||
# TODO (xiyan): we will need a way to infer provider_id for evaluation
|
# TODO (xiyan): we will need a way to infer provider_id for evaluation
|
||||||
# keep it as meta-reference for now
|
# keep it as meta-reference for now
|
||||||
if len(self.impls_by_provider_id) == 0:
|
if len(self.impls_by_provider_id) == 0:
|
||||||
raise ValueError(
|
raise ValueError("No evaluation providers available. Please configure an evaluation provider.")
|
||||||
"No evaluation providers available. Please configure an evaluation provider."
|
|
||||||
)
|
|
||||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
|
|
||||||
benchmark = Benchmark(
|
benchmark = Benchmark(
|
||||||
|
@ -491,12 +468,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
args: Optional[Dict[str, Any]] = None,
|
args: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
tools = []
|
tools = []
|
||||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
|
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
|
||||||
toolgroup_id, mcp_endpoint
|
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
||||||
)
|
|
||||||
tool_host = (
|
|
||||||
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
|
||||||
)
|
|
||||||
|
|
||||||
for tool_def in tool_defs:
|
for tool_def in tool_defs:
|
||||||
tools.append(
|
tools.append(
|
||||||
|
|
|
@ -105,9 +105,7 @@ class EnvVarError(Exception):
|
||||||
def __init__(self, var_name: str, path: str = ""):
|
def __init__(self, var_name: str, path: str = ""):
|
||||||
self.var_name = var_name
|
self.var_name = var_name
|
||||||
self.path = path
|
self.path = path
|
||||||
super().__init__(
|
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
|
||||||
f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
@ -198,9 +196,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
||||||
if not key:
|
if not key:
|
||||||
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
|
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
|
||||||
if not all(c.isalnum() or c == "_" for c in key):
|
if not all(c.isalnum() or c == "_" for c in key):
|
||||||
raise ValueError(
|
raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
|
||||||
f"Key must contain only alphanumeric characters and underscores: {key}"
|
|
||||||
)
|
|
||||||
return key, value
|
return key, value
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -213,20 +209,14 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
||||||
async def construct_stack(
|
async def construct_stack(
|
||||||
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
|
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
|
||||||
) -> Dict[Api, Any]:
|
) -> Dict[Api, Any]:
|
||||||
dist_registry, _ = await create_dist_registry(
|
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||||
run_config.metadata_store, run_config.image_name
|
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
|
||||||
)
|
|
||||||
impls = await resolve_impls(
|
|
||||||
run_config, provider_registry or get_provider_registry(), dist_registry
|
|
||||||
)
|
|
||||||
await register_resources(run_config, impls)
|
await register_resources(run_config, impls)
|
||||||
return impls
|
return impls
|
||||||
|
|
||||||
|
|
||||||
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
|
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
|
||||||
template_path = (
|
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
|
||||||
importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
|
|
||||||
)
|
|
||||||
|
|
||||||
with importlib.resources.as_file(template_path) as path:
|
with importlib.resources.as_file(template_path) as path:
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
|
@ -269,9 +259,7 @@ def run_config_from_adhoc_config_spec(
|
||||||
|
|
||||||
# call method "sample_run_config" on the provider spec config class
|
# call method "sample_run_config" on the provider spec config class
|
||||||
provider_config_type = instantiate_class_type(provider_spec.config_class)
|
provider_config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
provider_config = replace_env_vars(
|
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
|
||||||
provider_config_type.sample_run_config(__distro_dir__=distro_dir)
|
|
||||||
)
|
|
||||||
|
|
||||||
provider_configs_by_api[api_str] = [
|
provider_configs_by_api[api_str] = [
|
||||||
Provider(
|
Provider(
|
||||||
|
|
|
@ -22,9 +22,7 @@ class LlamaStackApi:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_scoring(
|
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]):
|
||||||
self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]
|
|
||||||
):
|
|
||||||
"""Run scoring on a single row"""
|
"""Run scoring on a single row"""
|
||||||
if not scoring_params:
|
if not scoring_params:
|
||||||
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
|
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
|
||||||
|
|
|
@ -4,12 +4,13 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from streamlit_option_menu import option_menu
|
||||||
|
|
||||||
from llama_stack.distribution.ui.page.distribution.datasets import datasets
|
from llama_stack.distribution.ui.page.distribution.datasets import datasets
|
||||||
from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
|
from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
|
||||||
from llama_stack.distribution.ui.page.distribution.models import models
|
from llama_stack.distribution.ui.page.distribution.models import models
|
||||||
from llama_stack.distribution.ui.page.distribution.shields import shields
|
from llama_stack.distribution.ui.page.distribution.shields import shields
|
||||||
from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
|
from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
|
||||||
from streamlit_option_menu import option_menu
|
|
||||||
|
|
||||||
|
|
||||||
def resources_page():
|
def resources_page():
|
||||||
|
|
|
@ -20,7 +20,5 @@ context_entity_recall_fn_def = ScoringFn(
|
||||||
provider_id="braintrust",
|
provider_id="braintrust",
|
||||||
provider_resource_id="context-entity-recall",
|
provider_resource_id="context-entity-recall",
|
||||||
return_type=NumberType(),
|
return_type=NumberType(),
|
||||||
params=BasicScoringFnParams(
|
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
||||||
aggregation_functions=[AggregationFunctionType.average]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,28 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
|
||||||
|
|
||||||
|
|
||||||
def available_providers() -> List[ProviderSpec]:
|
|
||||||
return [
|
|
||||||
InlineProviderSpec(
|
|
||||||
api=Api.eval,
|
|
||||||
provider_type="inline::meta-reference",
|
|
||||||
pip_packages=["tree_sitter"],
|
|
||||||
module="llama_stack.providers.inline.eval.meta_reference",
|
|
||||||
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
|
|
||||||
api_dependencies=[
|
|
||||||
Api.datasetio,
|
|
||||||
Api.datasets,
|
|
||||||
Api.scoring,
|
|
||||||
Api.inference,
|
|
||||||
Api.agents,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
]
|
|
|
@ -1,49 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
|
||||||
|
|
||||||
|
|
||||||
def available_providers() -> List[ProviderSpec]:
|
|
||||||
return [
|
|
||||||
InlineProviderSpec(
|
|
||||||
api=Api.scoring,
|
|
||||||
provider_type="inline::basic",
|
|
||||||
pip_packages=[],
|
|
||||||
module="llama_stack.providers.inline.scoring.basic",
|
|
||||||
config_class="llama_stack.providers.inline.scoring.basic.BasicScoringConfig",
|
|
||||||
api_dependencies=[
|
|
||||||
Api.datasetio,
|
|
||||||
Api.datasets,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
InlineProviderSpec(
|
|
||||||
api=Api.scoring,
|
|
||||||
provider_type="inline::llm-as-judge",
|
|
||||||
pip_packages=[],
|
|
||||||
module="llama_stack.providers.inline.scoring.llm_as_judge",
|
|
||||||
config_class="llama_stack.providers.inline.scoring.llm_as_judge.LlmAsJudgeScoringConfig",
|
|
||||||
api_dependencies=[
|
|
||||||
Api.datasetio,
|
|
||||||
Api.datasets,
|
|
||||||
Api.inference,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
InlineProviderSpec(
|
|
||||||
api=Api.scoring,
|
|
||||||
provider_type="inline::braintrust",
|
|
||||||
pip_packages=["autoevals", "openai"],
|
|
||||||
module="llama_stack.providers.inline.scoring.braintrust",
|
|
||||||
config_class="llama_stack.providers.inline.scoring.braintrust.BraintrustScoringConfig",
|
|
||||||
api_dependencies=[
|
|
||||||
Api.datasetio,
|
|
||||||
Api.datasets,
|
|
||||||
],
|
|
||||||
provider_data_validator="llama_stack.providers.inline.scoring.braintrust.BraintrustProviderDataValidator",
|
|
||||||
),
|
|
||||||
]
|
|
|
@ -75,29 +75,31 @@ VALID_SCHEMAS_FOR_EVAL = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_valid_schemas(api_str: str):
|
# TODO(xiyan): add this back
|
||||||
if api_str == Api.scoring.value:
|
|
||||||
return VALID_SCHEMAS_FOR_SCORING
|
# def get_valid_schemas(api_str: str):
|
||||||
elif api_str == Api.eval.value:
|
# if api_str == Api.scoring.value:
|
||||||
return VALID_SCHEMAS_FOR_EVAL
|
# return VALID_SCHEMAS_FOR_SCORING
|
||||||
else:
|
# elif api_str == Api.eval.value:
|
||||||
raise ValueError(f"Invalid API string: {api_str}")
|
# return VALID_SCHEMAS_FOR_EVAL
|
||||||
|
# else:
|
||||||
|
# raise ValueError(f"Invalid API string: {api_str}")
|
||||||
|
|
||||||
|
|
||||||
def validate_dataset_schema(
|
# def validate_dataset_schema(
|
||||||
dataset_schema: Dict[str, Any],
|
# dataset_schema: Dict[str, Any],
|
||||||
expected_schemas: List[Dict[str, Any]],
|
# expected_schemas: List[Dict[str, Any]],
|
||||||
):
|
# ):
|
||||||
if dataset_schema not in expected_schemas:
|
# if dataset_schema not in expected_schemas:
|
||||||
raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")
|
# raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")
|
||||||
|
|
||||||
|
|
||||||
def validate_row_schema(
|
# def validate_row_schema(
|
||||||
input_row: Dict[str, Any],
|
# input_row: Dict[str, Any],
|
||||||
expected_schemas: List[Dict[str, Any]],
|
# expected_schemas: List[Dict[str, Any]],
|
||||||
):
|
# ):
|
||||||
for schema in expected_schemas:
|
# for schema in expected_schemas:
|
||||||
if all(key in input_row for key in schema):
|
# if all(key in input_row for key in schema):
|
||||||
return
|
# return
|
||||||
|
|
||||||
raise ValueError(f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}")
|
# raise ValueError(f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}")
|
||||||
|
|
|
@ -11,8 +11,8 @@ from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOCon
|
||||||
from llama_stack.providers.remote.inference.bedrock.models import MODEL_ENTRIES
|
from llama_stack.providers.remote.inference.bedrock.models import MODEL_ENTRIES
|
||||||
from llama_stack.templates.template import (
|
from llama_stack.templates.template import (
|
||||||
DistributionTemplate,
|
DistributionTemplate,
|
||||||
get_model_registry,
|
|
||||||
RunConfigSettings,
|
RunConfigSettings,
|
||||||
|
get_model_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,8 +16,8 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
|
||||||
from llama_stack.providers.remote.inference.cerebras.models import MODEL_ENTRIES
|
from llama_stack.providers.remote.inference.cerebras.models import MODEL_ENTRIES
|
||||||
from llama_stack.templates.template import (
|
from llama_stack.templates.template import (
|
||||||
DistributionTemplate,
|
DistributionTemplate,
|
||||||
get_model_registry,
|
|
||||||
RunConfigSettings,
|
RunConfigSettings,
|
||||||
|
get_model_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,8 +22,8 @@ from llama_stack.providers.remote.inference.fireworks.config import FireworksImp
|
||||||
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
|
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
|
||||||
from llama_stack.templates.template import (
|
from llama_stack.templates.template import (
|
||||||
DistributionTemplate,
|
DistributionTemplate,
|
||||||
get_model_registry,
|
|
||||||
RunConfigSettings,
|
RunConfigSettings,
|
||||||
|
get_model_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -45,8 +45,8 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||||
)
|
)
|
||||||
from llama_stack.templates.template import (
|
from llama_stack.templates.template import (
|
||||||
DistributionTemplate,
|
DistributionTemplate,
|
||||||
get_model_registry,
|
|
||||||
RunConfigSettings,
|
RunConfigSettings,
|
||||||
|
get_model_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -96,10 +96,7 @@ def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
|
||||||
def get_distribution_template() -> DistributionTemplate:
|
def get_distribution_template() -> DistributionTemplate:
|
||||||
inference_providers, available_models = get_inference_providers()
|
inference_providers, available_models = get_inference_providers()
|
||||||
providers = {
|
providers = {
|
||||||
"inference": (
|
"inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]),
|
||||||
[p.provider_type for p in inference_providers]
|
|
||||||
+ ["inline::sentence-transformers"]
|
|
||||||
),
|
|
||||||
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
|
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
|
||||||
"safety": ["inline::llama-guard"],
|
"safety": ["inline::llama-guard"],
|
||||||
"agents": ["inline::meta-reference"],
|
"agents": ["inline::meta-reference"],
|
||||||
|
@ -119,9 +116,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="sqlite-vec",
|
provider_id="sqlite-vec",
|
||||||
provider_type="inline::sqlite-vec",
|
provider_type="inline::sqlite-vec",
|
||||||
config=SQLiteVectorIOConfig.sample_run_config(
|
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
f"~/.llama/distributions/{name}"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="${env.ENABLE_CHROMADB+chromadb}",
|
provider_id="${env.ENABLE_CHROMADB+chromadb}",
|
||||||
|
|
|
@ -21,8 +21,8 @@ from llama_stack.providers.remote.inference.fireworks.config import FireworksImp
|
||||||
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
|
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
|
||||||
from llama_stack.templates.template import (
|
from llama_stack.templates.template import (
|
||||||
DistributionTemplate,
|
DistributionTemplate,
|
||||||
get_model_registry,
|
|
||||||
RunConfigSettings,
|
RunConfigSettings,
|
||||||
|
get_model_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@ from llama_stack.providers.remote.inference.groq import GroqConfig
|
||||||
from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES
|
from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES
|
||||||
from llama_stack.templates.template import (
|
from llama_stack.templates.template import (
|
||||||
DistributionTemplate,
|
DistributionTemplate,
|
||||||
get_model_registry,
|
|
||||||
RunConfigSettings,
|
RunConfigSettings,
|
||||||
|
get_model_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,8 +17,8 @@ from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
||||||
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
|
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
|
||||||
from llama_stack.templates.template import (
|
from llama_stack.templates.template import (
|
||||||
DistributionTemplate,
|
DistributionTemplate,
|
||||||
get_model_registry,
|
|
||||||
RunConfigSettings,
|
RunConfigSettings,
|
||||||
|
get_model_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -87,9 +87,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
default_models=[inference_model, safety_model],
|
default_models=[inference_model, safety_model],
|
||||||
default_shields=[
|
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
|
||||||
ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")
|
|
||||||
],
|
|
||||||
default_tool_groups=default_tool_groups,
|
default_tool_groups=default_tool_groups,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
|
|
@ -9,7 +9,6 @@ from typing import Dict, List, Tuple
|
||||||
from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
|
from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
|
||||||
from llama_stack.apis.models.models import ModelType
|
from llama_stack.apis.models.models import ModelType
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
BenchmarkInput,
|
|
||||||
DatasetInput,
|
DatasetInput,
|
||||||
ModelInput,
|
ModelInput,
|
||||||
Provider,
|
Provider,
|
||||||
|
@ -31,14 +30,12 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||||
from llama_stack.templates.template import (
|
from llama_stack.templates.template import (
|
||||||
DistributionTemplate,
|
DistributionTemplate,
|
||||||
get_model_registry,
|
|
||||||
RunConfigSettings,
|
RunConfigSettings,
|
||||||
|
get_model_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_inference_providers() -> (
|
def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
|
||||||
Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]
|
|
||||||
):
|
|
||||||
# in this template, we allow each API key to be optional
|
# in this template, we allow each API key to be optional
|
||||||
providers = [
|
providers = [
|
||||||
(
|
(
|
||||||
|
@ -119,9 +116,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="sqlite-vec",
|
provider_id="sqlite-vec",
|
||||||
provider_type="inline::sqlite-vec",
|
provider_type="inline::sqlite-vec",
|
||||||
config=SQLiteVectorIOConfig.sample_run_config(
|
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
f"~/.llama/distributions/{name}"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="${env.ENABLE_CHROMADB+chromadb}",
|
provider_id="${env.ENABLE_CHROMADB+chromadb}",
|
||||||
|
|
|
@ -21,8 +21,8 @@ from llama_stack.providers.remote.inference.together import TogetherImplConfig
|
||||||
from llama_stack.providers.remote.inference.together.models import MODEL_ENTRIES
|
from llama_stack.providers.remote.inference.together.models import MODEL_ENTRIES
|
||||||
from llama_stack.templates.template import (
|
from llama_stack.templates.template import (
|
||||||
DistributionTemplate,
|
DistributionTemplate,
|
||||||
get_model_registry,
|
|
||||||
RunConfigSettings,
|
RunConfigSettings,
|
||||||
|
get_model_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue