Fix precommit check after moving to ruff (#927)
The lint check on the main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file, as well as fix and ignore some additional checks. Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Parent: 4773092dd1
This commit: 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
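The diff below is almost entirely mechanical reformatting of wrapped calls onto single lines, consistent with running the ruff formatter at a longer line length. For orientation, here is a minimal sketch of what a standalone `ruff.toml` might look like; the line length, rule families, and ignored check shown are illustrative assumptions, not the exact file added by this commit:

```toml
# Hypothetical ruff.toml sketch -- values are illustrative assumptions,
# not the configuration actually committed in #927.
line-length = 120
exclude = [".git", ".venv"]

[lint]
select = ["E", "F", "I", "B", "UP"]  # assumed rule families
ignore = ["E501"]                    # example of an additionally ignored check
```

With a file like this in place, the pre-commit hook would typically run `ruff check --fix .` and `ruff format .` so that local checks and CI agree.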
@@ -58,22 +58,14 @@ def get_provider_dependencies(
     for api_str, provider_or_providers in config_providers.items():
         providers_for_api = all_providers[Api(api_str)]

-        providers = (
-            provider_or_providers
-            if isinstance(provider_or_providers, list)
-            else [provider_or_providers]
-        )
+        providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]

         for provider in providers:
             # Providers from BuildConfig and RunConfig are subtly different – not great
-            provider_type = (
-                provider if isinstance(provider, str) else provider.provider_type
-            )
+            provider_type = provider if isinstance(provider, str) else provider.provider_type

             if provider_type not in providers_for_api:
-                raise ValueError(
-                    f"Provider `{provider}` is not available for API `{api_str}`"
-                )
+                raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")

             provider_spec = providers_for_api[provider_type]
             deps.extend(provider_spec.pip_packages)

@@ -109,19 +101,13 @@ def build_image(
     image_name: str,
     template_or_config: str,
 ):
-    container_base = (
-        build_config.distribution_spec.container_image or "python:3.10-slim"
-    )
+    container_base = build_config.distribution_spec.container_image or "python:3.10-slim"

-    normal_deps, special_deps = get_provider_dependencies(
-        build_config.distribution_spec.providers
-    )
+    normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
     normal_deps += SERVER_DEPENDENCIES

     if build_config.image_type == ImageType.container.value:
-        script = str(
-            importlib.resources.files("llama_stack") / "distribution/build_container.sh"
-        )
+        script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
         args = [
             script,
             template_or_config,

@@ -132,9 +118,7 @@ def build_image(
             " ".join(normal_deps),
         ]
     elif build_config.image_type == ImageType.conda.value:
-        script = str(
-            importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh"
-        )
+        script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
         args = [
             script,
             str(image_name),

@@ -142,9 +126,7 @@ def build_image(
             " ".join(normal_deps),
         ]
     elif build_config.image_type == ImageType.venv.value:
-        script = str(
-            importlib.resources.files("llama_stack") / "distribution/build_venv.sh"
-        )
+        script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
         args = [
             script,
             str(image_name),
@@ -68,9 +68,7 @@ def create_api_client_class(protocol) -> Type:
                 return_type = None
             else:
                 return_type = extract_non_async_iterator_type(sig.return_annotation)
-                assert return_type, (
-                    f"Could not extract return type for {sig.return_annotation}"
-                )
+                assert return_type, f"Could not extract return type for {sig.return_annotation}"

             async with httpx.AsyncClient() as client:
                 params = self.httpx_request_params(method_name, *args, **kwargs)

@@ -87,9 +85,7 @@ def create_api_client_class(protocol) -> Type:
             webmethod, sig = self.routes[method_name]

             return_type = extract_async_iterator_type(sig.return_annotation)
-            assert return_type, (
-                f"Could not extract return type for {sig.return_annotation}"
-            )
+            assert return_type, f"Could not extract return type for {sig.return_annotation}"

             async with httpx.AsyncClient() as client:
                 params = self.httpx_request_params(method_name, *args, **kwargs)

@@ -204,9 +200,7 @@ async def example(model: str = None):
     if not model:
         model = "Llama3.2-3B-Instruct"

-    message = UserMessage(
-        content="hello world, write me a 2 sentence poem about the moon"
-    )
+    message = UserMessage(content="hello world, write me a 2 sentence poem about the moon")
     cprint(f"User>{message.content}", "green")

     stream = True
@@ -26,9 +26,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
 logger = logging.getLogger(__name__)


-def configure_single_provider(
-    registry: Dict[str, ProviderSpec], provider: Provider
-) -> Provider:
+def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provider) -> Provider:
     provider_spec = registry[provider.provider_type]
     config_type = instantiate_class_type(provider_spec.config_class)
     try:

@@ -47,9 +45,7 @@ def configure_single_provider(
     )


-def configure_api_providers(
-    config: StackRunConfig, build_spec: DistributionSpec
-) -> StackRunConfig:
+def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec) -> StackRunConfig:
     is_nux = len(config.providers) == 0

     if is_nux:

@@ -87,9 +83,7 @@ def configure_api_providers(
             updated_providers = []
             for p in existing_providers:
                 logger.info(f"> Configuring provider `({p.provider_type})`")
-                updated_providers.append(
-                    configure_single_provider(provider_registry[api], p)
-                )
+                updated_providers.append(configure_single_provider(provider_registry[api], p))
                 logger.info("")
         else:
             # we are newly configuring this API

@@ -114,11 +108,7 @@ def configure_api_providers(
                     configure_single_provider(
                         provider_registry[api],
                         Provider(
-                            provider_id=(
-                                f"{provider_type}-{i:02d}"
-                                if len(plist) > 1
-                                else provider_type
-                            ),
+                            provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type),
                             provider_type=provider_type,
                             config={},
                         ),

@@ -137,11 +127,7 @@ def upgrade_from_routing_table(
     def get_providers(entries):
         return [
             Provider(
-                provider_id=(
-                    f"{entry['provider_type']}-{i:02d}"
-                    if len(entries) > 1
-                    else entry["provider_type"]
-                ),
+                provider_id=(f"{entry['provider_type']}-{i:02d}" if len(entries) > 1 else entry["provider_type"]),
                 provider_type=entry["provider_type"],
                 config=entry["config"],
             )
@@ -163,9 +163,7 @@ a default SQLite store will be used.""",
 class BuildConfig(BaseModel):
     version: str = LLAMA_STACK_BUILD_CONFIG_VERSION

-    distribution_spec: DistributionSpec = Field(
-        description="The distribution spec to build including API providers. "
-    )
+    distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
     image_type: str = Field(
         default="conda",
         description="Type of package to build (conda | container | venv)",
@@ -55,9 +55,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:


 def providable_apis() -> List[Api]:
-    routing_table_apis = set(
-        x.routing_table_api for x in builtin_automatically_routed_apis()
-    )
+    routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
     return [api for api in Api if api not in routing_table_apis and api != Api.inspect]

@@ -154,9 +154,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):

         def sync_generator():
             try:
-                async_stream = loop.run_until_complete(
-                    self.async_client.request(*args, **kwargs)
-                )
+                async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
                 while True:
                     chunk = loop.run_until_complete(async_stream.__anext__())
                     yield chunk

@@ -181,9 +179,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
        # when using the library client, we should not log to console since many
        # of our logs are intended for server-side usage
         current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
-        os.environ["TELEMETRY_SINKS"] = ",".join(
-            sink for sink in current_sinks if sink != "console"
-        )
+        os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")

         if config_path_or_template_name.endswith(".yaml"):
             config_path = Path(config_path_or_template_name)

@@ -202,9 +198,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):

     async def initialize(self):
         try:
-            self.impls = await construct_stack(
-                self.config, self.custom_provider_registry
-            )
+            self.impls = await construct_stack(self.config, self.custom_provider_registry)
         except ModuleNotFoundError as _e:
             cprint(_e.msg, "red")
             cprint(

@@ -247,9 +241,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             func = getattr(impl, endpoint.name)
             if endpoint.method not in endpoint_impls:
                 endpoint_impls[endpoint.method] = {}
-            endpoint_impls[endpoint.method][
-                _convert_path_to_regex(endpoint.route)
-            ] = func
+            endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = func

         self.endpoint_impls = endpoint_impls
         return True

@@ -266,9 +258,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             raise ValueError("Client not initialized")

         if self.provider_data:
-            set_request_provider_data(
-                {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
-            )
+            set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})

         if stream:
             response = await self._call_streaming(

@@ -408,9 +398,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return await response.parse()

-    def _convert_body(
-        self, path: str, method: str, body: Optional[dict] = None
-    ) -> dict:
+    def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
         if not body:
             return {}

@@ -425,7 +413,5 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         for param_name, param in sig.parameters.items():
             if param_name in body:
                 value = body.get(param_name)
-                converted_body[param_name] = convert_to_pydantic(
-                    param.annotation, value
-                )
+                converted_body[param_name] = convert_to_pydantic(param.annotation, value)
         return converted_body
@@ -115,9 +115,7 @@ async def resolve_impls(
     - flatmaps, sorts and resolves the providers in dependency order
     - for each API, produces either a (local, passthrough or router) implementation
     """
-    routing_table_apis = set(
-        x.routing_table_api for x in builtin_automatically_routed_apis()
-    )
+    routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
     router_apis = set(x.router_api for x in builtin_automatically_routed_apis())

     providers_with_specs = {}

@@ -125,16 +123,12 @@ async def resolve_impls(
     for api_str, providers in run_config.providers.items():
         api = Api(api_str)
         if api in routing_table_apis:
-            raise ValueError(
-                f"Provider for `{api_str}` is automatically provided and cannot be overridden"
-            )
+            raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")

         specs = {}
         for provider in providers:
             if provider.provider_type not in provider_registry[api]:
-                raise ValueError(
-                    f"Provider `{provider.provider_type}` is not available for API `{api}`"
-                )
+                raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")

             p = provider_registry[api][provider.provider_type]
             if p.deprecation_error:

@@ -145,9 +139,7 @@ async def resolve_impls(
                 log.warning(
                     f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
                 )
-            p.deps__ = [a.value for a in p.api_dependencies] + [
-                a.value for a in p.optional_api_dependencies
-            ]
+            p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
             spec = ProviderWithSpec(
                 spec=p,
                 **(provider.model_dump()),

@@ -158,9 +150,7 @@ async def resolve_impls(
         providers_with_specs[key] = specs

     apis_to_serve = run_config.apis or set(
-        list(providers_with_specs.keys())
-        + [x.value for x in routing_table_apis]
-        + [x.value for x in router_apis]
+        list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
     )

     for info in builtin_automatically_routed_apis():

@@ -197,9 +187,7 @@ async def resolve_impls(
         )
     }

-    sorted_providers = topological_sort(
-        {k: v.values() for k, v in providers_with_specs.items()}
-    )
+    sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
     apis = [x[1].spec.api for x in sorted_providers]
     sorted_providers.append(
         (

@@ -237,9 +225,7 @@ async def resolve_impls(

         inner_impls = {}
         if isinstance(provider.spec, RoutingTableProviderSpec):
-            inner_impls = inner_impls_by_provider_id[
-                f"inner-{provider.spec.router_api.value}"
-            ]
+            inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]

         impl = await instantiate_provider(
             provider,

@@ -336,10 +322,7 @@ async def instantiate_provider(
     # TODO: check compliance for special tool groups
     # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
     check_protocol_compliance(impl, protocols[provider_spec.api])
-    if (
-        not isinstance(provider_spec, AutoRoutedProviderSpec)
-        and provider_spec.api in additional_protocols
-    ):
+    if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
         additional_api, _, _ = additional_protocols[provider_spec.api]
         check_protocol_compliance(impl, additional_api)

@@ -367,19 +350,12 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
             obj_params = set(obj_sig.parameters)
             obj_params.discard("self")
             if not (proto_params <= obj_params):
-                log.error(
-                    f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
-                )
+                log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
                 missing_methods.append((name, "signature_mismatch"))
             else:
                 # Check if the method is actually implemented in the class
-                method_owner = next(
-                    (cls for cls in mro if name in cls.__dict__), None
-                )
-                if (
-                    method_owner is None
-                    or method_owner.__name__ == protocol.__name__
-                ):
+                method_owner = next((cls for cls in mro if name in cls.__dict__), None)
+                if method_owner is None or method_owner.__name__ == protocol.__name__:
                     missing_methods.append((name, "not_actually_implemented"))

     if missing_methods:
@ -85,9 +85,7 @@ class VectorIORouter(VectorIO):
|
|||
chunks: List[Chunk],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
|
||||
vector_db_id, chunks, ttl_seconds
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
|
@ -95,9 +93,7 @@ class VectorIORouter(VectorIO):
|
|||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryChunksResponse:
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
|
||||
vector_db_id, query, params
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||
|
||||
|
||||
class InferenceRouter(Inference):
|
||||
|
@ -123,9 +119,7 @@ class InferenceRouter(Inference):
|
|||
metadata: Optional[Dict[str, Any]] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> None:
|
||||
await self.routing_table.register_model(
|
||||
model_id, provider_model_id, provider_id, metadata, model_type
|
||||
)
|
||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
|
@ -143,9 +137,7 @@ class InferenceRouter(Inference):
|
|||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||
)
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
messages=messages,
|
||||
|
@ -176,9 +168,7 @@ class InferenceRouter(Inference):
|
|||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||
)
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
|
@ -202,9 +192,7 @@ class InferenceRouter(Inference):
|
|||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.llm:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an LLM model and does not support embeddings"
|
||||
)
|
||||
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
|
||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
|
@ -231,9 +219,7 @@ class SafetyRouter(Safety):
|
|||
provider_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Shield:
|
||||
return await self.routing_table.register_shield(
|
||||
shield_id, provider_shield_id, provider_id, params
|
||||
)
|
||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
@ -268,9 +254,7 @@ class DatasetIORouter(DatasetIO):
|
|||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
return await self.routing_table.get_provider_impl(
|
||||
dataset_id
|
||||
).get_rows_paginated(
|
||||
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=rows_in_page,
|
||||
page_token=page_token,
|
||||
|
@ -305,9 +289,7 @@ class ScoringRouter(Scoring):
|
|||
) -> ScoreBatchResponse:
|
||||
res = {}
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(
|
||||
fn_identifier
|
||||
).score_batch(
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
|
@ -328,9 +310,7 @@ class ScoringRouter(Scoring):
|
|||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(
|
||||
fn_identifier
|
||||
).score(
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
|
||||
input_rows=input_rows,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
|
@ -381,9 +361,7 @@ class EvalRouter(Eval):
|
|||
task_id: str,
|
||||
job_id: str,
|
||||
) -> Optional[JobStatus]:
|
||||
return await self.routing_table.get_provider_impl(task_id).job_status(
|
||||
task_id, job_id
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(task_id).job_status(task_id, job_id)
|
||||
|
||||
async def job_cancel(
|
||||
self,
|
||||
|
@ -420,9 +398,9 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_ids: List[str],
|
||||
query_config: Optional[RAGQueryConfig] = None,
|
||||
) -> RAGQueryResult:
|
||||
return await self.routing_table.get_provider_impl(
|
||||
"query_from_memory"
|
||||
).query(content, vector_db_ids, query_config)
|
||||
return await self.routing_table.get_provider_impl("query_from_memory").query(
|
||||
content, vector_db_ids, query_config
|
||||
)
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
|
@ -430,9 +408,9 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
return await self.routing_table.get_provider_impl(
|
||||
"insert_into_memory"
|
||||
).insert(documents, vector_db_id, chunk_size_in_tokens)
|
||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||
documents, vector_db_id, chunk_size_in_tokens
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -460,6 +438,4 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
|
||||
tool_group_id, mcp_endpoint
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
||||
|
|
|
@ -94,9 +94,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
self.dist_registry = dist_registry
|
||||
|
||||
async def initialize(self) -> None:
|
||||
async def add_objects(
|
||||
objs: List[RoutableObjectWithProvider], provider_id: str, cls
|
||||
) -> None:
|
||||
async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
||||
for obj in objs:
|
||||
if cls is None:
|
||||
obj.provider_id = provider_id
|
||||
|
@ -131,9 +129,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(
|
||||
self, routing_key: str, provider_id: Optional[str] = None
|
||||
) -> Any:
|
||||
def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
|
||||
def apiname_object():
|
||||
if isinstance(self, ModelsRoutingTable):
|
||||
return ("Inference", "model")
|
||||
|
@ -171,9 +167,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||
|
||||
async def get_object_by_identifier(
|
||||
self, type: str, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
||||
# Get from disk registry
|
||||
obj = await self.dist_registry.get(type, identifier)
|
||||
if not obj:
|
||||
|
@ -183,13 +177,9 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||
await unregister_object_from_provider(
|
||||
obj, self.impls_by_provider_id[obj.provider_id]
|
||||
)
|
||||
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
|
||||
|
||||
async def register_object(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
) -> RoutableObjectWithProvider:
|
||||
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
|
||||
# if provider_id is not specified, pick an arbitrary one from existing entries
|
||||
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
@ -244,9 +234,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
if model_type is None:
|
||||
model_type = ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError(
|
||||
"Embedding model must have an embedding dimension in its metadata"
|
||||
)
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
model = Model(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
|
@ -266,9 +254,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
async def list_shields(self) -> ListShieldsResponse:
|
||||
return ListShieldsResponse(
|
||||
data=await self.get_all_with_type(ResourceType.shield.value)
|
||||
)
|
||||
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
|
||||
|
||||
async def get_shield(self, identifier: str) -> Optional[Shield]:
|
||||
return await self.get_object_by_identifier("shield", identifier)
|
||||
|
@ -340,9 +326,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
if model.model_type != ModelType.embedding:
|
||||
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(
|
||||
f"Model {embedding_model} does not have an embedding dimension"
|
||||
)
|
||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||
vector_db_data = {
|
||||
"identifier": vector_db_id,
|
||||
"type": ResourceType.vector_db.value,
|
||||
|
@ -364,9 +348,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
return ListDatasetsResponse(
|
||||
data=await self.get_all_with_type(ResourceType.dataset.value)
|
||||
)
|
||||
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
|
||||
|
||||
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
|
||||
return await self.get_object_by_identifier("dataset", dataset_id)
|
||||
|
@ -411,9 +393,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
||||
return ListScoringFunctionsResponse(
|
||||
data=await self.get_all_with_type(ResourceType.scoring_function.value)
|
||||
)
|
||||
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
|
||||
|
||||
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
||||
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
||||
|
@ -510,12 +490,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
args: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
tools = []
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
|
||||
toolgroup_id, mcp_endpoint
|
||||
)
|
||||
tool_host = (
|
||||
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
||||
)
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
|
||||
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
||||
|
||||
for tool_def in tool_defs:
|
||||
tools.append(
|
||||
|
|
|
@@ -43,9 +43,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         if api == Api.tool_runtime:
             for tool_group in SpecialToolGroup:
                 sub_protocol = toolgroup_protocols[tool_group]
-                sub_protocol_methods = inspect.getmembers(
-                    sub_protocol, predicate=inspect.isfunction
-                )
+                sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
                 for name, method in sub_protocol_methods:
                     if not hasattr(method, "__webmethod__"):
                         continue
|
|||
traceback.print_exception(exc)
|
||||
http_exc = translate_exception(exc)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
|
||||
)
|
||||
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
|
||||
|
||||
|
||||
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
|
||||
|
@ -178,9 +176,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
try:
|
||||
if is_streaming:
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
|
@ -190,11 +186,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
new_params = [
|
||||
inspect.Parameter(
|
||||
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
|
||||
)
|
||||
]
|
||||
new_params = [inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)]
|
||||
new_params.extend(sig.parameters.values())
|
||||
|
||||
path_params = extract_path_params(route)
|
||||
|
@ -202,15 +194,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
# Annotate parameters that are in the path with Path(...) and others with Body(...)
|
||||
new_params = [new_params[0]] + [
|
||||
(
|
||||
param.replace(
|
||||
annotation=Annotated[
|
||||
param.annotation, FastapiPath(..., title=param.name)
|
||||
]
|
||||
)
|
||||
param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
|
||||
if param.name in path_params
|
||||
else param.replace(
|
||||
annotation=Annotated[param.annotation, Body(..., embed=True)]
|
||||
)
|
||||
else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
|
||||
)
|
||||
for param in new_params[1:]
|
||||
]
|
||||
|
@ -244,12 +230,8 @@ class ClientVersionMiddleware:
|
|||
client_version = headers.get(b"x-llamastack-client-version", b"").decode()
|
||||
if client_version:
|
||||
try:
|
||||
client_version_parts = tuple(
|
||||
map(int, client_version.split(".")[:2])
|
||||
)
|
||||
server_version_parts = tuple(
|
||||
map(int, self.server_version.split(".")[:2])
|
||||
)
|
||||
client_version_parts = tuple(map(int, client_version.split(".")[:2]))
|
||||
server_version_parts = tuple(map(int, self.server_version.split(".")[:2]))
|
||||
if client_version_parts != server_version_parts:
|
||||
|
||||
async def send_version_error(send):
|
||||
|
@ -267,9 +249,7 @@ class ClientVersionMiddleware:
|
|||
}
|
||||
}
|
||||
).encode()
|
||||
await send(
|
||||
{"type": "http.response.body", "body": error_msg}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": error_msg})
|
||||
|
||||
return await send_version_error(send)
|
||||
except (ValueError, IndexError):
|
||||
|
@ -296,9 +276,7 @@ def main():
|
|||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
|
||||
)
|
||||
parser.add_argument("--disable-ipv6", action="store_true", help="Whether to disable IPv6 support")
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
|
@ -323,9 +301,7 @@ def main():
|
|||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
print(f"Using config file: {config_file}")
|
||||
elif args.template:
|
||||
config_file = (
|
||||
Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
)
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
print(f"Using template {args.template} config file: {config_file}")
|
||||
|
@ -383,9 +359,7 @@ def main():
|
|||
impl_method = getattr(impl, endpoint.name)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", category=UserWarning, module="pydantic._internal._fields"
|
||||
)
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
|
@ -416,9 +390,7 @@ def main():
|
|||
|
||||
def extract_path_params(route: str) -> List[str]:
|
||||
segments = route.split("/")
|
||||
params = [
|
||||
seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")
|
||||
]
|
||||
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
|
||||
return params
|
||||
|
||||
|
||||
|
|
|
@ -110,9 +110,7 @@ class EnvVarError(Exception):
|
|||
def __init__(self, var_name: str, path: str = ""):
|
||||
self.var_name = var_name
|
||||
self.path = path
|
||||
super().__init__(
|
||||
f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
|
||||
)
|
||||
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
|
||||
|
||||
|
||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
@ -187,9 +185,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
|||
if not key:
|
||||
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
|
||||
if not all(c.isalnum() or c == "_" for c in key):
|
||||
raise ValueError(
|
||||
f"Key must contain only alphanumeric characters and underscores: {key}"
|
||||
)
|
||||
raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
|
||||
return key, value
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
|
@ -202,20 +198,14 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
|||
async def construct_stack(
|
||||
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
|
||||
) -> Dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(
|
||||
run_config.metadata_store, run_config.image_name
|
||||
)
|
||||
impls = await resolve_impls(
|
||||
run_config, provider_registry or get_provider_registry(), dist_registry
|
||||
)
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
|
||||
await register_resources(run_config, impls)
|
||||
return impls
|
||||
|
||||
|
||||
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
|
||||
template_path = (
|
||||
importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
|
||||
)
|
||||
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
|
||||
|
||||
with importlib.resources.as_file(template_path) as path:
|
||||
if not path.exists():
|
||||
|
|
|
@ -25,9 +25,7 @@ class DistributionRegistry(Protocol):
|
|||
|
||||
def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
|
||||
|
||||
async def update(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
) -> RoutableObjectWithProvider: ...
|
||||
async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...
|
||||
|
||||
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
|
||||
|
||||
|
@ -61,9 +59,7 @@ class DiskDistributionRegistry(DistributionRegistry):
|
|||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
def get_cached(
|
||||
self, type: str, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
||||
# Disk registry does not have a cache
|
||||
raise NotImplementedError("Disk registry does not have a cache")
|
||||
|
||||
|
@ -72,12 +68,8 @@ class DiskDistributionRegistry(DistributionRegistry):
|
|||
values = await self.kvstore.range(start_key, end_key)
|
||||
return _parse_registry_values(values)
|
||||
|
||||
async def get(
|
||||
self, type: str, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
json_str = await self.kvstore.get(
|
||||
KEY_FORMAT.format(type=type, identifier=identifier)
|
||||
)
|
||||
async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
||||
json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier))
|
||||
if not json_str:
|
||||
return None
|
||||
|
||||
|
@ -143,9 +135,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
|||
async def initialize(self) -> None:
|
||||
await self._ensure_initialized()
|
||||
|
||||
def get_cached(
|
||||
self, type: str, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
||||
return self.cache.get((type, identifier), None)
|
||||
|
||||
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
||||
|
@ -153,9 +143,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
|||
async with self._locked_cache() as cache:
|
||||
return list(cache.values())
|
||||
|
||||
async def get(
|
||||
self, type: str, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
||||
await self._ensure_initialized()
|
||||
cache_key = (type, identifier)
|
||||
|
||||
|
@ -197,9 +185,7 @@ async def create_dist_registry(
|
|||
dist_kvstore = await kvstore_impl(metadata_store)
|
||||
else:
|
||||
dist_kvstore = await kvstore_impl(
|
||||
SqliteKVStoreConfig(
|
||||
db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
|
||||
)
|
||||
SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix())
|
||||
)
|
||||
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
|
||||
await dist_registry.initialize()
|
||||
|
|
|
@ -161,9 +161,7 @@ async def test_duplicate_provider_registration(config):
|
|||
|
||||
result = await cached_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result is not None
|
||||
assert (
|
||||
result.embedding_model == original_vector_db.embedding_model
|
||||
) # Original values preserved
|
||||
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -193,14 +191,9 @@ async def test_get_all_objects(config):
|
|||
|
||||
# Verify each vector_db was stored correctly
|
||||
for original_vector_db in test_vector_dbs:
|
||||
matching_vector_dbs = [
|
||||
v for v in all_results if v.identifier == original_vector_db.identifier
|
||||
]
|
||||
matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
|
||||
assert len(matching_vector_dbs) == 1
|
||||
stored_vector_db = matching_vector_dbs[0]
|
||||
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
|
||||
assert stored_vector_db.provider_id == original_vector_db.provider_id
|
||||
assert (
|
||||
stored_vector_db.embedding_dimension
|
||||
== original_vector_db.embedding_dimension
|
||||
)
|
||||
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
||||
|
|
|
@ -22,15 +22,11 @@ def main():
|
|||
)
|
||||
|
||||
# Playground pages
|
||||
chat_page = st.Page(
|
||||
"page/playground/chat.py", title="Chat", icon="💬", default=True
|
||||
)
|
||||
chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
|
||||
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
|
||||
|
||||
# Distribution pages
|
||||
resources_page = st.Page(
|
||||
"page/distribution/resources.py", title="Resources", icon="🔍", default=False
|
||||
)
|
||||
resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
|
||||
provider_page = st.Page(
|
||||
"page/distribution/providers.py",
|
||||
title="API Providers",
|
||||
|
|
|
@ -23,15 +23,11 @@ class LlamaStackApi:
|
|||
},
|
||||
)
|
||||
|
||||
def run_scoring(
|
||||
self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]
|
||||
):
|
||||
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]):
|
||||
"""Run scoring on a single row"""
|
||||
if not scoring_params:
|
||||
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
|
||||
return self.client.scoring.score(
|
||||
input_rows=[row], scoring_functions=scoring_params
|
||||
)
|
||||
return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)
|
||||
|
||||
|
||||
llama_stack_api = LlamaStackApi()
|
||||
|
|
|
@@ -11,9 +11,7 @@ from modules.api import llama_stack_api
 def datasets():
     st.header("Datasets")

-    datasets_info = {
-        d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()
-    }
+    datasets_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()}
     if len(datasets_info) > 0:
         selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
         st.json(datasets_info[selected_dataset], expanded=True)
@ -12,12 +12,8 @@ def eval_tasks():
|
|||
# Eval Tasks Section
|
||||
st.header("Eval Tasks")
|
||||
|
||||
eval_tasks_info = {
|
||||
d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()
|
||||
}
|
||||
eval_tasks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()}
|
||||
|
||||
if len(eval_tasks_info) > 0:
|
||||
selected_eval_task = st.selectbox(
|
||||
"Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
|
||||
)
|
||||
selected_eval_task = st.selectbox("Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect")
|
||||
st.json(eval_tasks_info[selected_eval_task], expanded=True)
|
||||
|
|
|
@@ -11,9 +11,7 @@ from modules.api import llama_stack_api
 def models():
     # Models Section
     st.header("Models")
-    models_info = {
-        m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()
-    }
+    models_info = {m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()}

     selected_model = st.selectbox("Select a model", list(models_info.keys()))
     st.json(models_info[selected_model])
@ -11,12 +11,7 @@ from modules.api import llama_stack_api
|
|||
def scoring_functions():
|
||||
st.header("Scoring Functions")
|
||||
|
||||
scoring_functions_info = {
|
||||
s.identifier: s.to_dict()
|
||||
for s in llama_stack_api.client.scoring_functions.list()
|
||||
}
|
||||
scoring_functions_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.scoring_functions.list()}
|
||||
|
||||
selected_scoring_function = st.selectbox(
|
||||
"Select a scoring function", list(scoring_functions_info.keys())
|
||||
)
|
||||
selected_scoring_function = st.selectbox("Select a scoring function", list(scoring_functions_info.keys()))
|
||||
st.json(scoring_functions_info[selected_scoring_function], expanded=True)
|
||||
|
|
|
@@ -12,9 +12,7 @@ def shields():
     # Shields Section
     st.header("Shields")

-    shields_info = {
-        s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()
-    }
+    shields_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()}

     selected_shield = st.selectbox("Select a shield", list(shields_info.keys()))
     st.json(shields_info[selected_shield])
@ -10,14 +10,10 @@ from modules.api import llama_stack_api
|
|||
|
||||
def vector_dbs():
|
||||
st.header("Vector Databases")
|
||||
vector_dbs_info = {
|
||||
v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()
|
||||
}
|
||||
vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}
|
||||
|
||||
if len(vector_dbs_info) > 0:
|
||||
selected_vector_db = st.selectbox(
|
||||
"Select a vector database", list(vector_dbs_info.keys())
|
||||
)
|
||||
selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
|
||||
st.json(vector_dbs_info[selected_vector_db])
|
||||
else:
|
||||
st.info("No vector databases found")
|
||||
|
|
|
@ -14,7 +14,6 @@ from modules.utils import process_dataset
|
|||
|
||||
|
||||
def application_evaluation_page():
|
||||
|
||||
st.set_page_config(page_title="Evaluations (Scoring)", page_icon="🦙")
|
||||
st.title("📊 Evaluations (Scoring)")
|
||||
|
||||
|
@ -83,9 +82,7 @@ def application_evaluation_page():
|
|||
try:
|
||||
new_params[param_name] = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
st.error(
|
||||
f"Invalid JSON for **{param_name}** in {scoring_fn_id}"
|
||||
)
|
||||
st.error(f"Invalid JSON for **{param_name}** in {scoring_fn_id}")
|
||||
|
||||
st.json(new_params)
|
||||
scoring_params[scoring_fn_id] = new_params
|
||||
|
@ -128,9 +125,7 @@ def application_evaluation_page():
|
|||
output_res[fn_id].append(score_res.results[fn_id].score_rows[0])
|
||||
|
||||
# Display current row results using separate containers
|
||||
progress_text_container.write(
|
||||
f"Expand to see current processed result ({i + 1} / {len(rows)})"
|
||||
)
|
||||
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
|
||||
results_container.json(
|
||||
score_res.to_json(),
|
||||
expanded=2,
|
||||
|
|
|
@ -195,7 +195,6 @@ def run_evaluation_3():
|
|||
|
||||
# Add run button and handle evaluation
|
||||
if st.button("Run Evaluation"):
|
||||
|
||||
progress_text = "Running evaluation..."
|
||||
progress_bar = st.progress(0, text=progress_text)
|
||||
rows = rows.rows
|
||||
|
@ -233,9 +232,7 @@ def run_evaluation_3():
|
|||
output_res[scoring_fn] = []
|
||||
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
|
||||
|
||||
progress_text_container.write(
|
||||
f"Expand to see current processed result ({i + 1} / {len(rows)})"
|
||||
)
|
||||
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
|
||||
results_container.json(eval_res, expanded=2)
|
||||
|
||||
progress_bar.progress(1.0, text="Evaluation complete!")
|
||||
|
@ -247,7 +244,6 @@ def run_evaluation_3():
|
|||
|
||||
|
||||
def native_evaluation_page():
|
||||
|
||||
st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
|
||||
st.title("📊 Evaluations (Generation + Scoring)")
|
||||
|
||||
|
|
|
@ -11,9 +11,7 @@ from modules.api import llama_stack_api
|
|||
with st.sidebar:
|
||||
st.header("Configuration")
|
||||
available_models = llama_stack_api.client.models.list()
|
||||
available_models = [
|
||||
model.identifier for model in available_models if model.model_type == "llm"
|
||||
]
|
||||
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
|
||||
selected_model = st.selectbox(
|
||||
"Choose a model",
|
||||
available_models,
|
||||
|
@ -128,6 +126,4 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
|
|||
full_response = response
|
||||
message_placeholder.markdown(full_response.completion_message.content)
|
||||
|
||||
st.session_state.messages.append(
|
||||
{"role": "assistant", "content": full_response}
|
||||
)
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
|
|
|
@ -74,9 +74,7 @@ def rag_chat_page():
|
|||
)
|
||||
|
||||
available_models = llama_stack_api.client.models.list()
|
||||
available_models = [
|
||||
model.identifier for model in available_models if model.model_type == "llm"
|
||||
]
|
||||
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
|
||||
selected_model = st.selectbox(
|
||||
"Choose a model",
|
||||
available_models,
|
||||
|
@ -137,9 +135,7 @@ def rag_chat_page():
|
|||
dict(
|
||||
name="builtin::rag",
|
||||
args={
|
||||
"vector_db_ids": [
|
||||
vector_db_id for vector_db_id in selected_vector_dbs
|
||||
],
|
||||
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
|
||||
},
|
||||
)
|
||||
],
|
||||
|
@ -186,9 +182,7 @@ def rag_chat_page():
|
|||
message_placeholder.markdown(full_response + "▌")
|
||||
message_placeholder.markdown(full_response)
|
||||
|
||||
st.session_state.messages.append(
|
||||
{"role": "assistant", "content": full_response}
|
||||
)
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
|
||||
|
||||
rag_chat_page()
|
||||
|
|
|
@@ -8,9 +8,7 @@ import os
 from pathlib import Path


-LLAMA_STACK_CONFIG_DIR = Path(
-    os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/"))
-)
+LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))

 DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"

@ -31,15 +31,11 @@ def is_list_of_primitives(field_type):
|
|||
|
||||
|
||||
def is_basemodel_without_fields(typ):
|
||||
return (
|
||||
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
|
||||
)
|
||||
return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
|
||||
|
||||
|
||||
def can_recurse(typ):
|
||||
return (
|
||||
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
|
||||
)
|
||||
return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
|
||||
|
||||
|
||||
def get_literal_values(field):
|
||||
|
@ -72,7 +68,7 @@ def is_discriminated_union(typ) -> bool:
|
|||
if isinstance(typ, FieldInfo):
|
||||
return typ.discriminator
|
||||
else:
|
||||
if not (get_origin(typ) is Annotated):
|
||||
if get_origin(typ) is not Annotated:
|
||||
return False
|
||||
args = get_args(typ)
|
||||
return len(args) >= 2 and args[1].discriminator
|
||||
|
@ -116,9 +112,7 @@ def prompt_for_discriminated_union(
|
|||
chosen_type = type_map[discriminator_value]
|
||||
log.info(f"\nConfiguring {chosen_type.__name__}:")
|
||||
|
||||
if existing_value and (
|
||||
getattr(existing_value, discriminator) != discriminator_value
|
||||
):
|
||||
if existing_value and (getattr(existing_value, discriminator) != discriminator_value):
|
||||
existing_value = None
|
||||
|
||||
sub_config = prompt_for_config(chosen_type, existing_value)
|
||||
|
@ -134,9 +128,7 @@ def prompt_for_discriminated_union(
|
|||
#
|
||||
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
|
||||
# unit tests for coverage.
|
||||
def prompt_for_config(
|
||||
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
|
||||
) -> BaseModel:
|
||||
def prompt_for_config(config_type: type[BaseModel], existing_config: Optional[BaseModel] = None) -> BaseModel:
|
||||
"""
|
||||
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
|
||||
|
||||
|
@ -150,17 +142,11 @@ def prompt_for_config(
|
|||
|
||||
for field_name, field in config_type.__fields__.items():
|
||||
field_type = field.annotation
|
||||
existing_value = (
|
||||
getattr(existing_config, field_name) if existing_config else None
|
||||
)
|
||||
existing_value = getattr(existing_config, field_name) if existing_config else None
|
||||
if existing_value:
|
||||
default_value = existing_value
|
||||
else:
|
||||
default_value = (
|
||||
field.default
|
||||
if not isinstance(field.default, PydanticUndefinedType)
|
||||
else None
|
||||
)
|
||||
default_value = field.default if not isinstance(field.default, PydanticUndefinedType) else None
|
||||
is_required = field.is_required
|
||||
|
||||
# Skip fields with Literal type
|
||||
|
@ -183,15 +169,11 @@ def prompt_for_config(
|
|||
config_data[field_name] = validated_value
|
||||
break
|
||||
except KeyError:
|
||||
log.error(
|
||||
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
|
||||
)
|
||||
log.error(f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}")
|
||||
continue
|
||||
|
||||
if is_discriminated_union(field):
|
||||
config_data[field_name] = prompt_for_discriminated_union(
|
||||
field_name, field, existing_value
|
||||
)
|
||||
config_data[field_name] = prompt_for_discriminated_union(field_name, field, existing_value)
|
||||
continue
|
||||
|
||||
if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):
|
||||
|
@ -202,9 +184,7 @@ def prompt_for_config(
|
|||
nested_type = get_non_none_type(field_type)
|
||||
log.info(f"Entering sub-configuration for {field_name}:")
|
||||
config_data[field_name] = prompt_for_config(nested_type, existing_value)
|
||||
elif is_optional(field_type) and is_discriminated_union(
|
||||
get_non_none_type(field_type)
|
||||
):
|
||||
elif is_optional(field_type) and is_discriminated_union(get_non_none_type(field_type)):
|
||||
prompt = f"Do you want to configure {field_name}? (y/n): "
|
||||
if input(prompt).lower() == "n":
|
||||
config_data[field_name] = None
|
||||
|
@ -260,16 +240,12 @@ def prompt_for_config(
|
|||
try:
|
||||
value = json.loads(user_input)
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(
|
||||
"Input must be a JSON-encoded list"
|
||||
)
|
||||
raise ValueError("Input must be a JSON-encoded list")
|
||||
element_type = get_args(field_type)[0]
|
||||
value = [element_type(item) for item in value]
|
||||
|
||||
except json.JSONDecodeError:
|
||||
log.error(
|
||||
'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]'
|
||||
)
|
||||
log.error('Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]')
|
||||
continue
|
||||
except ValueError as e:
|
||||
log.error(f"{str(e)}")
|
||||
|
@ -279,20 +255,14 @@ def prompt_for_config(
|
|||
try:
|
||||
value = json.loads(user_input)
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(
|
||||
"Input must be a JSON-encoded dictionary"
|
||||
)
|
||||
raise ValueError("Input must be a JSON-encoded dictionary")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
log.error(
|
||||
"Invalid JSON. Please enter a valid JSON-encoded dict."
|
||||
)
|
||||
log.error("Invalid JSON. Please enter a valid JSON-encoded dict.")
|
||||
continue
|
||||
|
||||
# Convert the input to the correct type
|
||||
elif inspect.isclass(field_type) and issubclass(
|
||||
field_type, BaseModel
|
||||
):
|
||||
elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
|
||||
# For nested BaseModels, we assume a dictionary-like string input
|
||||
import ast
|
||||
|
||||
|
@ -301,16 +271,12 @@ def prompt_for_config(
|
|||
value = field_type(user_input)
|
||||
|
||||
except ValueError:
|
||||
log.error(
|
||||
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
|
||||
)
|
||||
log.error(f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}")
|
||||
continue
|
||||
|
||||
try:
|
||||
# Validate the field using our manual validation function
|
||||
validated_value = manually_validate_field(
|
||||
config_type, field_name, value
|
||||
)
|
||||
validated_value = manually_validate_field(config_type, field_name, value)
|
||||
config_data[field_name] = validated_value
|
||||
break
|
||||
except ValueError as e:
|
||||
|
|