mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-15 06:22:36 +00:00
Merge remote-tracking branch 'origin/main' into agent_rewrite
This commit is contained in:
commit
57b3d14895
30 changed files with 869 additions and 408 deletions
|
|
@ -496,12 +496,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
return await response.parse()
|
||||
|
||||
def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
|
||||
if not body:
|
||||
return {}
|
||||
|
||||
body = body or {}
|
||||
exclude_params = exclude_params or set()
|
||||
sig = inspect.signature(func)
|
||||
params_list = [p for p in sig.parameters.values() if p.name != "self"]
|
||||
|
||||
# Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)])
|
||||
if len(params_list) == 1:
|
||||
param = params_list[0]
|
||||
|
|
@ -530,11 +529,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
converted_body[param_name] = value
|
||||
else:
|
||||
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
|
||||
elif unwrapped_body_param and param.name == unwrapped_body_param.name:
|
||||
# This is the unwrapped body param - construct it from remaining body keys
|
||||
base_type = get_args(param.annotation)[0]
|
||||
# Extract only the keys that aren't already used by other params
|
||||
remaining_keys = {k: v for k, v in body.items() if k not in converted_body}
|
||||
converted_body[param.name] = base_type(**remaining_keys)
|
||||
|
||||
# handle unwrapped body parameter after processing all named parameters
|
||||
if unwrapped_body_param:
|
||||
base_type = get_args(unwrapped_body_param.annotation)[0]
|
||||
# extract only keys not already used by other params
|
||||
remaining_keys = {k: v for k, v in body.items() if k not in converted_body}
|
||||
converted_body[unwrapped_body_param.name] = base_type(**remaining_keys)
|
||||
|
||||
return converted_body
|
||||
|
|
|
|||
|
|
@ -120,13 +120,7 @@ class VectorIORouter(VectorIO):
|
|||
embedding_dimension = extra.get("embedding_dimension")
|
||||
provider_id = extra.get("provider_id")
|
||||
|
||||
logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={provider_id}")
|
||||
|
||||
# Require explicit embedding model specification
|
||||
if embedding_model is None:
|
||||
raise ValueError("embedding_model is required in extra_body when creating a vector store")
|
||||
|
||||
if embedding_dimension is None:
|
||||
if embedding_model is not None and embedding_dimension is None:
|
||||
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
|
||||
|
||||
# Auto-select provider if not specified
|
||||
|
|
@ -158,8 +152,10 @@ class VectorIORouter(VectorIO):
|
|||
params.model_extra = {}
|
||||
params.model_extra["provider_vector_db_id"] = registered_vector_db.provider_resource_id
|
||||
params.model_extra["provider_id"] = registered_vector_db.provider_id
|
||||
params.model_extra["embedding_model"] = embedding_model
|
||||
params.model_extra["embedding_dimension"] = embedding_dimension
|
||||
if embedding_model is not None:
|
||||
params.model_extra["embedding_model"] = embedding_model
|
||||
if embedding_dimension is not None:
|
||||
params.model_extra["embedding_dimension"] = embedding_dimension
|
||||
|
||||
return await provider.openai_create_vector_store(params)
|
||||
|
||||
|
|
|
|||
|
|
@ -98,6 +98,30 @@ REGISTRY_REFRESH_TASK = None
|
|||
TEST_RECORDING_CONTEXT = None
|
||||
|
||||
|
||||
async def validate_default_embedding_model(impls: dict[Api, Any]):
|
||||
"""Validate that at most one embedding model is marked as default."""
|
||||
if Api.models not in impls:
|
||||
return
|
||||
|
||||
models_impl = impls[Api.models]
|
||||
response = await models_impl.list_models()
|
||||
models_list = response.data if hasattr(response, "data") else response
|
||||
|
||||
default_embedding_models = []
|
||||
for model in models_list:
|
||||
if model.model_type == "embedding" and model.metadata.get("default_configured") is True:
|
||||
default_embedding_models.append(model.identifier)
|
||||
|
||||
if len(default_embedding_models) > 1:
|
||||
raise ValueError(
|
||||
f"Multiple embedding models marked as default_configured=True: {default_embedding_models}. "
|
||||
"Only one embedding model can be marked as default."
|
||||
)
|
||||
|
||||
if default_embedding_models:
|
||||
logger.info(f"Default embedding model configured: {default_embedding_models[0]}")
|
||||
|
||||
|
||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||
for rsrc, api, register_method, list_method in RESOURCES:
|
||||
objects = getattr(run_config, rsrc)
|
||||
|
|
@ -128,6 +152,8 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
||||
)
|
||||
|
||||
await validate_default_embedding_model(impls)
|
||||
|
||||
|
||||
class EnvVarError(Exception):
|
||||
def __init__(self, var_name: str, path: str = ""):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue