mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 01:01:13 +00:00 
			
		
		
		
	Merge remote-tracking branch 'origin/main' into stores
This commit is contained in:
		
						commit
						b72154ce5e
					
				
					 1161 changed files with 609896 additions and 42960 deletions
				
			
		|  | @ -324,14 +324,14 @@ fi | |||
| RUN pip uninstall -y uv | ||||
| EOF | ||||
| 
 | ||||
| # If a run config is provided, we use the --config flag | ||||
| # If a run config is provided, we use the llama stack CLI | ||||
| if [[ -n "$run_config" ]]; then | ||||
|   add_to_container << EOF | ||||
| ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$RUN_CONFIG_PATH"] | ||||
| ENTRYPOINT ["llama", "stack", "run", "$RUN_CONFIG_PATH"] | ||||
| EOF | ||||
| elif [[ "$distro_or_config" != *.yaml ]]; then | ||||
|   add_to_container << EOF | ||||
| ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$distro_or_config"] | ||||
| ENTRYPOINT ["llama", "stack", "run", "$distro_or_config"] | ||||
| EOF | ||||
| fi | ||||
| 
 | ||||
|  |  | |||
|  | @ -32,7 +32,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import ( | |||
|     sqlstore_impl, | ||||
| ) | ||||
| 
 | ||||
| logger = get_logger(name=__name__, category="openai::conversations") | ||||
| logger = get_logger(name=__name__, category="openai_conversations") | ||||
| 
 | ||||
| 
 | ||||
| class ConversationServiceConfig(BaseModel): | ||||
|  | @ -196,12 +196,15 @@ class ConversationServiceImpl(Conversations): | |||
|         await self._get_validated_conversation(conversation_id) | ||||
| 
 | ||||
|         created_items = [] | ||||
|         created_at = int(time.time()) | ||||
|         base_time = int(time.time()) | ||||
| 
 | ||||
|         for item in items: | ||||
|         for i, item in enumerate(items): | ||||
|             item_dict = item.model_dump() | ||||
|             item_id = self._get_or_generate_item_id(item, item_dict) | ||||
| 
 | ||||
|             # make each timestamp unique to maintain order | ||||
|             created_at = base_time + i | ||||
| 
 | ||||
|             item_record = { | ||||
|                 "id": item_id, | ||||
|                 "conversation_id": conversation_id, | ||||
|  |  | |||
|  | @ -47,10 +47,6 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]: | |||
|             routing_table_api=Api.shields, | ||||
|             router_api=Api.safety, | ||||
|         ), | ||||
|         AutoRoutedApiInfo( | ||||
|             routing_table_api=Api.vector_dbs, | ||||
|             router_api=Api.vector_io, | ||||
|         ), | ||||
|         AutoRoutedApiInfo( | ||||
|             routing_table_api=Api.datasets, | ||||
|             router_api=Api.datasetio, | ||||
|  | @ -243,6 +239,7 @@ def get_external_providers_from_module( | |||
|                     spec = module.get_provider_spec() | ||||
|                 else: | ||||
|                     # pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run | ||||
|                     # in the case we are building we CANNOT import this module of course because it has not been installed. | ||||
|                     spec = ProviderSpec( | ||||
|                         api=Api(provider_api), | ||||
|                         provider_type=provider.provider_type, | ||||
|  | @ -251,9 +248,20 @@ def get_external_providers_from_module( | |||
|                         config_class="", | ||||
|                     ) | ||||
|                 provider_type = provider.provider_type | ||||
|                 # in the case we are building we CANNOT import this module of course because it has not been installed. | ||||
|                 # return a partially filled out spec that the build script will populate. | ||||
|                 registry[Api(provider_api)][provider_type] = spec | ||||
|                 if isinstance(spec, list): | ||||
|                     # optionally allow people to pass inline and remote provider specs as a returned list. | ||||
|                     # with the old method, users could pass in directories of specs using overlapping code | ||||
|                     # we want to ensure we preserve that flexibility in this method. | ||||
|                     logger.info( | ||||
|                         f"Detected a list of external provider specs from {provider.module} adding all to the registry" | ||||
|                     ) | ||||
|                     for provider_spec in spec: | ||||
|                         if provider_spec.provider_type != provider.provider_type: | ||||
|                             continue | ||||
|                         logger.info(f"Adding {provider.provider_type} to registry") | ||||
|                         registry[Api(provider_api)][provider.provider_type] = provider_spec | ||||
|                 else: | ||||
|                     registry[Api(provider_api)][provider_type] = spec | ||||
|             except ModuleNotFoundError as exc: | ||||
|                 raise ValueError( | ||||
|                     "get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available" | ||||
|  |  | |||
							
								
								
									
										42
									
								
								llama_stack/core/id_generation.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								llama_stack/core/id_generation.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,42 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| from collections.abc import Callable | ||||
| 
 | ||||
| IdFactory = Callable[[], str] | ||||
| IdOverride = Callable[[str, IdFactory], str] | ||||
| 
 | ||||
| _id_override: IdOverride | None = None | ||||
| 
 | ||||
| 
 | ||||
| def generate_object_id(kind: str, factory: IdFactory) -> str: | ||||
|     """Generate an identifier for the given kind using the provided factory. | ||||
| 
 | ||||
|     Allows tests to override ID generation deterministically by installing an | ||||
|     override callback via :func:`set_id_override`. | ||||
| 
 | ||||
|     :param kind: Label for the kind of object being identified. It is only | ||||
|         forwarded to the override; without an override it is unused. | ||||
|     :param factory: Zero-argument callable that produces the default ID. | ||||
|     :returns: ``override(kind, factory)`` when an override is installed, | ||||
|         otherwise ``factory()``. | ||||
|     """ | ||||
| 
 | ||||
|     # Read the module-level override once so the ``None`` check and the call | ||||
|     # below operate on the same callable. | ||||
|     override = _id_override | ||||
|     if override is not None: | ||||
|         return override(kind, factory) | ||||
|     return factory() | ||||
| 
 | ||||
| 
 | ||||
| def set_id_override(override: IdOverride) -> IdOverride | None: | ||||
|     """Install an override used to generate deterministic identifiers. | ||||
| 
 | ||||
|     The new override replaces whatever was installed before. | ||||
| 
 | ||||
|     :param override: Callable invoked as ``override(kind, factory)`` by | ||||
|         :func:`generate_object_id` in place of calling ``factory()`` directly. | ||||
|     :returns: The override that was installed before this call, or ``None`` | ||||
|         if there was none. Pass it to :func:`reset_id_override` to restore it. | ||||
|     """ | ||||
| 
 | ||||
|     global _id_override | ||||
| 
 | ||||
|     # Capture the current value before overwriting it so it can be returned. | ||||
|     previous = _id_override | ||||
|     _id_override = override | ||||
|     return previous | ||||
| 
 | ||||
| 
 | ||||
| def reset_id_override(previous: IdOverride | None) -> None: | ||||
|     """Restore the previous override returned by :func:`set_id_override`. | ||||
| 
 | ||||
|     :param previous: Value to install as the module-level override. It is | ||||
|         assigned as-is; ``None`` leaves no override installed, so | ||||
|         :func:`generate_object_id` calls the factory directly. | ||||
|     """ | ||||
| 
 | ||||
|     global _id_override | ||||
|     _id_override = previous | ||||
|  | @ -54,6 +54,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( | |||
|     setup_logger, | ||||
|     start_trace, | ||||
| ) | ||||
| from llama_stack.strong_typing.inspection import is_unwrapped_body_param | ||||
| 
 | ||||
| logger = get_logger(name=__name__, category="core") | ||||
| 
 | ||||
|  | @ -383,7 +384,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): | |||
| 
 | ||||
|         body, field_names = self._handle_file_uploads(options, body) | ||||
| 
 | ||||
|         body = self._convert_body(path, options.method, body, exclude_params=set(field_names)) | ||||
|         body = self._convert_body(matched_func, body, exclude_params=set(field_names)) | ||||
| 
 | ||||
|         trace_path = webmethod.descriptive_name or route_path | ||||
|         await start_trace(trace_path, {"__location__": "library_client"}) | ||||
|  | @ -446,7 +447,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): | |||
|         func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls) | ||||
|         body |= path_params | ||||
| 
 | ||||
|         body = self._convert_body(path, options.method, body) | ||||
|         # Prepare body for the function call (handles both Pydantic and traditional params) | ||||
|         body = self._convert_body(func, body) | ||||
| 
 | ||||
|         trace_path = webmethod.descriptive_name or route_path | ||||
|         await start_trace(trace_path, {"__location__": "library_client"}) | ||||
|  | @ -493,21 +495,32 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): | |||
|         ) | ||||
|         return await response.parse() | ||||
| 
 | ||||
|     def _convert_body( | ||||
|         self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None | ||||
|     ) -> dict: | ||||
|     def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict: | ||||
|         if not body: | ||||
|             return {} | ||||
| 
 | ||||
|         assert self.route_impls is not None  # Should be guaranteed by request() method, assertion for mypy | ||||
|         exclude_params = exclude_params or set() | ||||
| 
 | ||||
|         func, _, _, _ = find_matching_route(method, path, self.route_impls) | ||||
|         sig = inspect.signature(func) | ||||
|         params_list = [p for p in sig.parameters.values() if p.name != "self"] | ||||
|         # Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)]) | ||||
|         if len(params_list) == 1: | ||||
|             param = params_list[0] | ||||
|             param_type = param.annotation | ||||
|             if is_unwrapped_body_param(param_type): | ||||
|                 base_type = get_args(param_type)[0] | ||||
|                 return {param.name: base_type(**body)} | ||||
| 
 | ||||
|         # Strip NOT_GIVENs to use the defaults in signature | ||||
|         body = {k: v for k, v in body.items() if v is not NOT_GIVEN} | ||||
| 
 | ||||
|         # Check if there's an unwrapped body parameter among multiple parameters | ||||
|         # (e.g., path param + body param like: vector_store_id: str, params: Annotated[Model, Body(...)]) | ||||
|         unwrapped_body_param = None | ||||
|         for param in params_list: | ||||
|             if is_unwrapped_body_param(param.annotation): | ||||
|                 unwrapped_body_param = param | ||||
|                 break | ||||
| 
 | ||||
|         # Convert parameters to Pydantic models where needed | ||||
|         converted_body = {} | ||||
|         for param_name, param in sig.parameters.items(): | ||||
|  | @ -517,5 +530,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): | |||
|                     converted_body[param_name] = value | ||||
|                 else: | ||||
|                     converted_body[param_name] = convert_to_pydantic(param.annotation, value) | ||||
|             elif unwrapped_body_param and param.name == unwrapped_body_param.name: | ||||
|                 # This is the unwrapped body param - construct it from remaining body keys | ||||
|                 base_type = get_args(param.annotation)[0] | ||||
|                 # Extract only the keys that aren't already used by other params | ||||
|                 remaining_keys = {k: v for k, v in body.items() if k not in converted_body} | ||||
|                 converted_body[param.name] = base_type(**remaining_keys) | ||||
| 
 | ||||
|         return converted_body | ||||
|  |  | |||
|  | @ -28,7 +28,6 @@ from llama_stack.apis.scoring_functions import ScoringFunctions | |||
| from llama_stack.apis.shields import Shields | ||||
| from llama_stack.apis.telemetry import Telemetry | ||||
| from llama_stack.apis.tools import ToolGroups, ToolRuntime | ||||
| from llama_stack.apis.vector_dbs import VectorDBs | ||||
| from llama_stack.apis.vector_io import VectorIO | ||||
| from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA | ||||
| from llama_stack.core.client import get_client_impl | ||||
|  | @ -55,7 +54,6 @@ from llama_stack.providers.datatypes import ( | |||
|     ScoringFunctionsProtocolPrivate, | ||||
|     ShieldsProtocolPrivate, | ||||
|     ToolGroupsProtocolPrivate, | ||||
|     VectorDBsProtocolPrivate, | ||||
| ) | ||||
| 
 | ||||
| logger = get_logger(name=__name__, category="core") | ||||
|  | @ -81,7 +79,6 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> | |||
|         Api.inspect: Inspect, | ||||
|         Api.batches: Batches, | ||||
|         Api.vector_io: VectorIO, | ||||
|         Api.vector_dbs: VectorDBs, | ||||
|         Api.models: Models, | ||||
|         Api.safety: Safety, | ||||
|         Api.shields: Shields, | ||||
|  | @ -125,7 +122,6 @@ def additional_protocols_map() -> dict[Api, Any]: | |||
|     return { | ||||
|         Api.inference: (ModelsProtocolPrivate, Models, Api.models), | ||||
|         Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups), | ||||
|         Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs), | ||||
|         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields), | ||||
|         Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets), | ||||
|         Api.scoring: ( | ||||
|  | @ -150,6 +146,7 @@ async def resolve_impls( | |||
|     provider_registry: ProviderRegistry, | ||||
|     dist_registry: DistributionRegistry, | ||||
|     policy: list[AccessRule], | ||||
|     internal_impls: dict[Api, Any] | None = None, | ||||
| ) -> dict[Api, Any]: | ||||
|     """ | ||||
|     Resolves provider implementations by: | ||||
|  | @ -172,7 +169,7 @@ async def resolve_impls( | |||
| 
 | ||||
|     sorted_providers = sort_providers_by_deps(providers_with_specs, run_config) | ||||
| 
 | ||||
|     return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy) | ||||
|     return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy, internal_impls) | ||||
| 
 | ||||
| 
 | ||||
| def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]: | ||||
|  | @ -280,9 +277,10 @@ async def instantiate_providers( | |||
|     dist_registry: DistributionRegistry, | ||||
|     run_config: StackRunConfig, | ||||
|     policy: list[AccessRule], | ||||
|     internal_impls: dict[Api, Any] | None = None, | ||||
| ) -> dict[Api, Any]: | ||||
|     """Instantiates providers asynchronously while managing dependencies.""" | ||||
|     impls: dict[Api, Any] = {} | ||||
|     impls: dict[Api, Any] = internal_impls.copy() if internal_impls else {} | ||||
|     inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis} | ||||
|     for api_str, provider in sorted_providers: | ||||
|         # Skip providers that are not enabled | ||||
|  |  | |||
|  | @ -31,10 +31,8 @@ async def get_routing_table_impl( | |||
|     from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable | ||||
|     from ..routing_tables.shields import ShieldsRoutingTable | ||||
|     from ..routing_tables.toolgroups import ToolGroupsRoutingTable | ||||
|     from ..routing_tables.vector_dbs import VectorDBsRoutingTable | ||||
| 
 | ||||
|     api_to_tables = { | ||||
|         "vector_dbs": VectorDBsRoutingTable, | ||||
|         "models": ModelsRoutingTable, | ||||
|         "shields": ShieldsRoutingTable, | ||||
|         "datasets": DatasetsRoutingTable, | ||||
|  |  | |||
|  | @ -10,9 +10,10 @@ from collections.abc import AsyncGenerator, AsyncIterator | |||
| from datetime import UTC, datetime | ||||
| from typing import Annotated, Any | ||||
| 
 | ||||
| from fastapi import Body | ||||
| from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam | ||||
| from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam | ||||
| from pydantic import Field, TypeAdapter | ||||
| from pydantic import TypeAdapter | ||||
| 
 | ||||
| from llama_stack.apis.common.content_types import ( | ||||
|     InterleavedContent, | ||||
|  | @ -31,15 +32,17 @@ from llama_stack.apis.inference import ( | |||
|     OpenAIAssistantMessageParam, | ||||
|     OpenAIChatCompletion, | ||||
|     OpenAIChatCompletionChunk, | ||||
|     OpenAIChatCompletionRequestWithExtraBody, | ||||
|     OpenAIChatCompletionToolCall, | ||||
|     OpenAIChatCompletionToolCallFunction, | ||||
|     OpenAIChoice, | ||||
|     OpenAIChoiceLogprobs, | ||||
|     OpenAICompletion, | ||||
|     OpenAICompletionRequestWithExtraBody, | ||||
|     OpenAICompletionWithInputMessages, | ||||
|     OpenAIEmbeddingsRequestWithExtraBody, | ||||
|     OpenAIEmbeddingsResponse, | ||||
|     OpenAIMessageParam, | ||||
|     OpenAIResponseFormatParam, | ||||
|     Order, | ||||
|     StopReason, | ||||
|     ToolPromptFormat, | ||||
|  | @ -181,61 +184,23 @@ class InferenceRouter(Inference): | |||
| 
 | ||||
|     async def openai_completion( | ||||
|         self, | ||||
|         model: str, | ||||
|         prompt: str | list[str] | list[int] | list[list[int]], | ||||
|         best_of: int | None = None, | ||||
|         echo: bool | None = None, | ||||
|         frequency_penalty: float | None = None, | ||||
|         logit_bias: dict[str, float] | None = None, | ||||
|         logprobs: bool | None = None, | ||||
|         max_tokens: int | None = None, | ||||
|         n: int | None = None, | ||||
|         presence_penalty: float | None = None, | ||||
|         seed: int | None = None, | ||||
|         stop: str | list[str] | None = None, | ||||
|         stream: bool | None = None, | ||||
|         stream_options: dict[str, Any] | None = None, | ||||
|         temperature: float | None = None, | ||||
|         top_p: float | None = None, | ||||
|         user: str | None = None, | ||||
|         guided_choice: list[str] | None = None, | ||||
|         prompt_logprobs: int | None = None, | ||||
|         suffix: str | None = None, | ||||
|         params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)], | ||||
|     ) -> OpenAICompletion: | ||||
|         logger.debug( | ||||
|             f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}", | ||||
|         ) | ||||
|         model_obj = await self._get_model(model, ModelType.llm) | ||||
|         params = dict( | ||||
|             model=model_obj.identifier, | ||||
|             prompt=prompt, | ||||
|             best_of=best_of, | ||||
|             echo=echo, | ||||
|             frequency_penalty=frequency_penalty, | ||||
|             logit_bias=logit_bias, | ||||
|             logprobs=logprobs, | ||||
|             max_tokens=max_tokens, | ||||
|             n=n, | ||||
|             presence_penalty=presence_penalty, | ||||
|             seed=seed, | ||||
|             stop=stop, | ||||
|             stream=stream, | ||||
|             stream_options=stream_options, | ||||
|             temperature=temperature, | ||||
|             top_p=top_p, | ||||
|             user=user, | ||||
|             guided_choice=guided_choice, | ||||
|             prompt_logprobs=prompt_logprobs, | ||||
|             suffix=suffix, | ||||
|             f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}", | ||||
|         ) | ||||
|         model_obj = await self._get_model(params.model, ModelType.llm) | ||||
| 
 | ||||
|         # Update params with the resolved model identifier | ||||
|         params.model = model_obj.identifier | ||||
| 
 | ||||
|         provider = await self.routing_table.get_provider_impl(model_obj.identifier) | ||||
|         if stream: | ||||
|             return await provider.openai_completion(**params) | ||||
|         if params.stream: | ||||
|             return await provider.openai_completion(params) | ||||
|             # TODO: Metrics do NOT work with openai_completion stream=True due to the fact | ||||
|             # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently. | ||||
|             # response_stream = await provider.openai_completion(**params) | ||||
| 
 | ||||
|         response = await provider.openai_completion(**params) | ||||
|         response = await provider.openai_completion(params) | ||||
|         if self.telemetry: | ||||
|             metrics = self._construct_metrics( | ||||
|                 prompt_tokens=response.usage.prompt_tokens, | ||||
|  | @ -254,93 +219,49 @@ class InferenceRouter(Inference): | |||
| 
 | ||||
|     async def openai_chat_completion( | ||||
|         self, | ||||
|         model: str, | ||||
|         messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)], | ||||
|         frequency_penalty: float | None = None, | ||||
|         function_call: str | dict[str, Any] | None = None, | ||||
|         functions: list[dict[str, Any]] | None = None, | ||||
|         logit_bias: dict[str, float] | None = None, | ||||
|         logprobs: bool | None = None, | ||||
|         max_completion_tokens: int | None = None, | ||||
|         max_tokens: int | None = None, | ||||
|         n: int | None = None, | ||||
|         parallel_tool_calls: bool | None = None, | ||||
|         presence_penalty: float | None = None, | ||||
|         response_format: OpenAIResponseFormatParam | None = None, | ||||
|         seed: int | None = None, | ||||
|         stop: str | list[str] | None = None, | ||||
|         stream: bool | None = None, | ||||
|         stream_options: dict[str, Any] | None = None, | ||||
|         temperature: float | None = None, | ||||
|         tool_choice: str | dict[str, Any] | None = None, | ||||
|         tools: list[dict[str, Any]] | None = None, | ||||
|         top_logprobs: int | None = None, | ||||
|         top_p: float | None = None, | ||||
|         user: str | None = None, | ||||
|         params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)], | ||||
|     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: | ||||
|         logger.debug( | ||||
|             f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}", | ||||
|             f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}", | ||||
|         ) | ||||
|         model_obj = await self._get_model(model, ModelType.llm) | ||||
|         model_obj = await self._get_model(params.model, ModelType.llm) | ||||
| 
 | ||||
|         # Use the OpenAI client for a bit of extra input validation without | ||||
|         # exposing the OpenAI client itself as part of our API surface | ||||
|         if tool_choice: | ||||
|             TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice) | ||||
|             if tools is None: | ||||
|         if params.tool_choice: | ||||
|             TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice) | ||||
|             if params.tools is None: | ||||
|                 raise ValueError("'tool_choice' is only allowed when 'tools' is also provided") | ||||
|         if tools: | ||||
|             for tool in tools: | ||||
|         if params.tools: | ||||
|             for tool in params.tools: | ||||
|                 TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool) | ||||
| 
 | ||||
|         # Some providers make tool calls even when tool_choice is "none" | ||||
|         # so just clear them both out to avoid unexpected tool calls | ||||
|         if tool_choice == "none" and tools is not None: | ||||
|             tool_choice = None | ||||
|             tools = None | ||||
|         if params.tool_choice == "none" and params.tools is not None: | ||||
|             params.tool_choice = None | ||||
|             params.tools = None | ||||
| 
 | ||||
|         # Update params with the resolved model identifier | ||||
|         params.model = model_obj.identifier | ||||
| 
 | ||||
|         params = dict( | ||||
|             model=model_obj.identifier, | ||||
|             messages=messages, | ||||
|             frequency_penalty=frequency_penalty, | ||||
|             function_call=function_call, | ||||
|             functions=functions, | ||||
|             logit_bias=logit_bias, | ||||
|             logprobs=logprobs, | ||||
|             max_completion_tokens=max_completion_tokens, | ||||
|             max_tokens=max_tokens, | ||||
|             n=n, | ||||
|             parallel_tool_calls=parallel_tool_calls, | ||||
|             presence_penalty=presence_penalty, | ||||
|             response_format=response_format, | ||||
|             seed=seed, | ||||
|             stop=stop, | ||||
|             stream=stream, | ||||
|             stream_options=stream_options, | ||||
|             temperature=temperature, | ||||
|             tool_choice=tool_choice, | ||||
|             tools=tools, | ||||
|             top_logprobs=top_logprobs, | ||||
|             top_p=top_p, | ||||
|             user=user, | ||||
|         ) | ||||
|         provider = await self.routing_table.get_provider_impl(model_obj.identifier) | ||||
|         if stream: | ||||
|             response_stream = await provider.openai_chat_completion(**params) | ||||
|         if params.stream: | ||||
|             response_stream = await provider.openai_chat_completion(params) | ||||
| 
 | ||||
|             # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk] | ||||
|             # We need to add metrics to each chunk and store the final completion | ||||
|             return self.stream_tokens_and_compute_metrics_openai_chat( | ||||
|                 response=response_stream, | ||||
|                 model=model_obj, | ||||
|                 messages=messages, | ||||
|                 messages=params.messages, | ||||
|             ) | ||||
| 
 | ||||
|         response = await self._nonstream_openai_chat_completion(provider, params) | ||||
| 
 | ||||
|         # Store the response with the ID that will be returned to the client | ||||
|         if self.store: | ||||
|             asyncio.create_task(self.store.store_chat_completion(response, messages)) | ||||
|             asyncio.create_task(self.store.store_chat_completion(response, params.messages)) | ||||
| 
 | ||||
|         if self.telemetry: | ||||
|             metrics = self._construct_metrics( | ||||
|  | @ -359,26 +280,18 @@ class InferenceRouter(Inference): | |||
| 
 | ||||
|     async def openai_embeddings( | ||||
|         self, | ||||
|         model: str, | ||||
|         input: str | list[str], | ||||
|         encoding_format: str | None = "float", | ||||
|         dimensions: int | None = None, | ||||
|         user: str | None = None, | ||||
|         params: Annotated[OpenAIEmbeddingsRequestWithExtraBody, Body(...)], | ||||
|     ) -> OpenAIEmbeddingsResponse: | ||||
|         logger.debug( | ||||
|             f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}", | ||||
|         ) | ||||
|         model_obj = await self._get_model(model, ModelType.embedding) | ||||
|         params = dict( | ||||
|             model=model_obj.identifier, | ||||
|             input=input, | ||||
|             encoding_format=encoding_format, | ||||
|             dimensions=dimensions, | ||||
|             user=user, | ||||
|             f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}", | ||||
|         ) | ||||
|         model_obj = await self._get_model(params.model, ModelType.embedding) | ||||
| 
 | ||||
|         # Update model to use resolved identifier | ||||
|         params.model = model_obj.identifier | ||||
| 
 | ||||
|         provider = await self.routing_table.get_provider_impl(model_obj.identifier) | ||||
|         return await provider.openai_embeddings(**params) | ||||
|         return await provider.openai_embeddings(params) | ||||
| 
 | ||||
|     async def list_chat_completions( | ||||
|         self, | ||||
|  | @ -396,8 +309,10 @@ class InferenceRouter(Inference): | |||
|             return await self.store.get_chat_completion(completion_id) | ||||
|         raise NotImplementedError("Get chat completion is not supported: inference store is not configured.") | ||||
| 
 | ||||
|     async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion: | ||||
|         response = await provider.openai_chat_completion(**params) | ||||
|     async def _nonstream_openai_chat_completion( | ||||
|         self, provider: Inference, params: OpenAIChatCompletionRequestWithExtraBody | ||||
|     ) -> OpenAIChatCompletion: | ||||
|         response = await provider.openai_chat_completion(params) | ||||
|         for choice in response.choices: | ||||
|             # some providers return an empty list for no tool calls in non-streaming responses | ||||
|             # but the OpenAI API returns None. So, set tool_calls to None if it's empty | ||||
|  | @ -611,7 +526,7 @@ class InferenceRouter(Inference): | |||
|                         completion_text += "".join(choice_data["content_parts"]) | ||||
| 
 | ||||
|                     # Add metrics to the chunk | ||||
|                     if self.telemetry and chunk.usage: | ||||
|                     if self.telemetry and hasattr(chunk, "usage") and chunk.usage: | ||||
|                         metrics = self._construct_metrics( | ||||
|                             prompt_tokens=chunk.usage.prompt_tokens, | ||||
|                             completion_tokens=chunk.usage.completion_tokens, | ||||
|  |  | |||
|  | @ -6,12 +6,16 @@ | |||
| 
 | ||||
| import asyncio | ||||
| import uuid | ||||
| from typing import Any | ||||
| from typing import Annotated, Any | ||||
| 
 | ||||
| from fastapi import Body | ||||
| 
 | ||||
| from llama_stack.apis.common.content_types import InterleavedContent | ||||
| from llama_stack.apis.models import ModelType | ||||
| from llama_stack.apis.vector_io import ( | ||||
|     Chunk, | ||||
|     OpenAICreateVectorStoreFileBatchRequestWithExtraBody, | ||||
|     OpenAICreateVectorStoreRequestWithExtraBody, | ||||
|     QueryChunksResponse, | ||||
|     SearchRankingOptions, | ||||
|     VectorIO, | ||||
|  | @ -51,30 +55,18 @@ class VectorIORouter(VectorIO): | |||
|         logger.debug("VectorIORouter.shutdown") | ||||
|         pass | ||||
| 
 | ||||
|     async def _get_first_embedding_model(self) -> tuple[str, int] | None: | ||||
|         """Get the first available embedding model identifier.""" | ||||
|         try: | ||||
|             # Get all models from the routing table | ||||
|             all_models = await self.routing_table.get_all_with_type("model") | ||||
|     async def _get_embedding_model_dimension(self, embedding_model_id: str) -> int: | ||||
|         """Get the embedding dimension for a specific embedding model.""" | ||||
|         all_models = await self.routing_table.get_all_with_type("model") | ||||
| 
 | ||||
|             # Filter for embedding models | ||||
|             embedding_models = [ | ||||
|                 model | ||||
|                 for model in all_models | ||||
|                 if hasattr(model, "model_type") and model.model_type == ModelType.embedding | ||||
|             ] | ||||
| 
 | ||||
|             if embedding_models: | ||||
|                 dimension = embedding_models[0].metadata.get("embedding_dimension", None) | ||||
|         for model in all_models: | ||||
|             if model.identifier == embedding_model_id and model.model_type == ModelType.embedding: | ||||
|                 dimension = model.metadata.get("embedding_dimension") | ||||
|                 if dimension is None: | ||||
|                     raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension") | ||||
|                 return embedding_models[0].identifier, dimension | ||||
|             else: | ||||
|                 logger.warning("No embedding models found in the routing table") | ||||
|                 return None | ||||
|         except Exception as e: | ||||
|             logger.error(f"Error getting embedding models: {e}") | ||||
|             return None | ||||
|                     raise ValueError(f"Embedding model '{embedding_model_id}' has no embedding_dimension in metadata") | ||||
|                 return int(dimension) | ||||
| 
 | ||||
|         raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model") | ||||
| 
 | ||||
|     async def register_vector_db( | ||||
|         self, | ||||
|  | @ -120,24 +112,35 @@ class VectorIORouter(VectorIO): | |||
|     # OpenAI Vector Stores API endpoints | ||||
|     async def openai_create_vector_store( | ||||
|         self, | ||||
|         name: str, | ||||
|         file_ids: list[str] | None = None, | ||||
|         expires_after: dict[str, Any] | None = None, | ||||
|         chunking_strategy: dict[str, Any] | None = None, | ||||
|         metadata: dict[str, Any] | None = None, | ||||
|         embedding_model: str | None = None, | ||||
|         embedding_dimension: int | None = None, | ||||
|         provider_id: str | None = None, | ||||
|         params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)], | ||||
|     ) -> VectorStoreObject: | ||||
|         logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}") | ||||
|         # Extract llama-stack-specific parameters from extra_body | ||||
|         extra = params.model_extra or {} | ||||
|         embedding_model = extra.get("embedding_model") | ||||
|         embedding_dimension = extra.get("embedding_dimension") | ||||
|         provider_id = extra.get("provider_id") | ||||
| 
 | ||||
|         # If no embedding model is provided, use the first available one | ||||
|         logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={provider_id}") | ||||
| 
 | ||||
|         # Require explicit embedding model specification | ||||
|         if embedding_model is None: | ||||
|             embedding_model_info = await self._get_first_embedding_model() | ||||
|             if embedding_model_info is None: | ||||
|                 raise ValueError("No embedding model provided and no embedding models available in the system") | ||||
|             embedding_model, embedding_dimension = embedding_model_info | ||||
|             logger.info(f"No embedding model specified, using first available: {embedding_model}") | ||||
|             raise ValueError("embedding_model is required in extra_body when creating a vector store") | ||||
| 
 | ||||
|         if embedding_dimension is None: | ||||
|             embedding_dimension = await self._get_embedding_model_dimension(embedding_model) | ||||
| 
 | ||||
|         # Auto-select provider if not specified | ||||
|         if provider_id is None: | ||||
|             num_providers = len(self.routing_table.impls_by_provider_id) | ||||
|             if num_providers == 0: | ||||
|                 raise ValueError("No vector_io providers available") | ||||
|             if num_providers > 1: | ||||
|                 available_providers = list(self.routing_table.impls_by_provider_id.keys()) | ||||
|                 raise ValueError( | ||||
|                     f"Multiple vector_io providers available. Please specify provider_id in extra_body. " | ||||
|                     f"Available providers: {available_providers}" | ||||
|                 ) | ||||
|             provider_id = list(self.routing_table.impls_by_provider_id.keys())[0] | ||||
| 
 | ||||
|         vector_db_id = f"vs_{uuid.uuid4()}" | ||||
|         registered_vector_db = await self.routing_table.register_vector_db( | ||||
|  | @ -146,20 +149,19 @@ class VectorIORouter(VectorIO): | |||
|             embedding_dimension=embedding_dimension, | ||||
|             provider_id=provider_id, | ||||
|             provider_vector_db_id=vector_db_id, | ||||
|             vector_db_name=name, | ||||
|             vector_db_name=params.name, | ||||
|         ) | ||||
|         provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier) | ||||
|         return await provider.openai_create_vector_store( | ||||
|             name=name, | ||||
|             file_ids=file_ids, | ||||
|             expires_after=expires_after, | ||||
|             chunking_strategy=chunking_strategy, | ||||
|             metadata=metadata, | ||||
|             embedding_model=embedding_model, | ||||
|             embedding_dimension=embedding_dimension, | ||||
|             provider_id=registered_vector_db.provider_id, | ||||
|             provider_vector_db_id=registered_vector_db.provider_resource_id, | ||||
|         ) | ||||
| 
 | ||||
|         # Update model_extra with registered values so provider uses the already-registered vector_db | ||||
|         if params.model_extra is None: | ||||
|             params.model_extra = {} | ||||
|         params.model_extra["provider_vector_db_id"] = registered_vector_db.provider_resource_id | ||||
|         params.model_extra["provider_id"] = registered_vector_db.provider_id | ||||
|         params.model_extra["embedding_model"] = embedding_model | ||||
|         params.model_extra["embedding_dimension"] = embedding_dimension | ||||
| 
 | ||||
|         return await provider.openai_create_vector_store(params) | ||||
| 
 | ||||
|     async def openai_list_vector_stores( | ||||
|         self, | ||||
|  | @ -219,7 +221,8 @@ class VectorIORouter(VectorIO): | |||
|         vector_store_id: str, | ||||
|     ) -> VectorStoreObject: | ||||
|         logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}") | ||||
|         return await self.routing_table.openai_retrieve_vector_store(vector_store_id) | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_retrieve_vector_store(vector_store_id) | ||||
| 
 | ||||
|     async def openai_update_vector_store( | ||||
|         self, | ||||
|  | @ -229,7 +232,8 @@ class VectorIORouter(VectorIO): | |||
|         metadata: dict[str, Any] | None = None, | ||||
|     ) -> VectorStoreObject: | ||||
|         logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}") | ||||
|         return await self.routing_table.openai_update_vector_store( | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_update_vector_store( | ||||
|             vector_store_id=vector_store_id, | ||||
|             name=name, | ||||
|             expires_after=expires_after, | ||||
|  | @ -241,7 +245,8 @@ class VectorIORouter(VectorIO): | |||
|         vector_store_id: str, | ||||
|     ) -> VectorStoreDeleteResponse: | ||||
|         logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}") | ||||
|         return await self.routing_table.openai_delete_vector_store(vector_store_id) | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_delete_vector_store(vector_store_id) | ||||
| 
 | ||||
|     async def openai_search_vector_store( | ||||
|         self, | ||||
|  | @ -254,7 +259,8 @@ class VectorIORouter(VectorIO): | |||
|         search_mode: str | None = "vector", | ||||
|     ) -> VectorStoreSearchResponsePage: | ||||
|         logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}") | ||||
|         return await self.routing_table.openai_search_vector_store( | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_search_vector_store( | ||||
|             vector_store_id=vector_store_id, | ||||
|             query=query, | ||||
|             filters=filters, | ||||
|  | @ -272,7 +278,8 @@ class VectorIORouter(VectorIO): | |||
|         chunking_strategy: VectorStoreChunkingStrategy | None = None, | ||||
|     ) -> VectorStoreFileObject: | ||||
|         logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}") | ||||
|         return await self.routing_table.openai_attach_file_to_vector_store( | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_attach_file_to_vector_store( | ||||
|             vector_store_id=vector_store_id, | ||||
|             file_id=file_id, | ||||
|             attributes=attributes, | ||||
|  | @ -289,7 +296,8 @@ class VectorIORouter(VectorIO): | |||
|         filter: VectorStoreFileStatus | None = None, | ||||
|     ) -> list[VectorStoreFileObject]: | ||||
|         logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}") | ||||
|         return await self.routing_table.openai_list_files_in_vector_store( | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_list_files_in_vector_store( | ||||
|             vector_store_id=vector_store_id, | ||||
|             limit=limit, | ||||
|             order=order, | ||||
|  | @ -304,7 +312,8 @@ class VectorIORouter(VectorIO): | |||
|         file_id: str, | ||||
|     ) -> VectorStoreFileObject: | ||||
|         logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}") | ||||
|         return await self.routing_table.openai_retrieve_vector_store_file( | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_retrieve_vector_store_file( | ||||
|             vector_store_id=vector_store_id, | ||||
|             file_id=file_id, | ||||
|         ) | ||||
|  | @ -315,7 +324,8 @@ class VectorIORouter(VectorIO): | |||
|         file_id: str, | ||||
|     ) -> VectorStoreFileContentsResponse: | ||||
|         logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") | ||||
|         return await self.routing_table.openai_retrieve_vector_store_file_contents( | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_retrieve_vector_store_file_contents( | ||||
|             vector_store_id=vector_store_id, | ||||
|             file_id=file_id, | ||||
|         ) | ||||
|  | @ -327,7 +337,8 @@ class VectorIORouter(VectorIO): | |||
|         attributes: dict[str, Any], | ||||
|     ) -> VectorStoreFileObject: | ||||
|         logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}") | ||||
|         return await self.routing_table.openai_update_vector_store_file( | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_update_vector_store_file( | ||||
|             vector_store_id=vector_store_id, | ||||
|             file_id=file_id, | ||||
|             attributes=attributes, | ||||
|  | @ -339,7 +350,8 @@ class VectorIORouter(VectorIO): | |||
|         file_id: str, | ||||
|     ) -> VectorStoreFileDeleteResponse: | ||||
|         logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}") | ||||
|         return await self.routing_table.openai_delete_vector_store_file( | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_delete_vector_store_file( | ||||
|             vector_store_id=vector_store_id, | ||||
|             file_id=file_id, | ||||
|         ) | ||||
|  | @ -370,17 +382,13 @@ class VectorIORouter(VectorIO): | |||
|     async def openai_create_vector_store_file_batch( | ||||
|         self, | ||||
|         vector_store_id: str, | ||||
|         file_ids: list[str], | ||||
|         attributes: dict[str, Any] | None = None, | ||||
|         chunking_strategy: VectorStoreChunkingStrategy | None = None, | ||||
|         params: Annotated[OpenAICreateVectorStoreFileBatchRequestWithExtraBody, Body(...)], | ||||
|     ) -> VectorStoreFileBatchObject: | ||||
|         logger.debug(f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(file_ids)} files") | ||||
|         return await self.routing_table.openai_create_vector_store_file_batch( | ||||
|             vector_store_id=vector_store_id, | ||||
|             file_ids=file_ids, | ||||
|             attributes=attributes, | ||||
|             chunking_strategy=chunking_strategy, | ||||
|         logger.debug( | ||||
|             f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(params.file_ids)} files" | ||||
|         ) | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_create_vector_store_file_batch(vector_store_id, params) | ||||
| 
 | ||||
|     async def openai_retrieve_vector_store_file_batch( | ||||
|         self, | ||||
|  | @ -388,7 +396,8 @@ class VectorIORouter(VectorIO): | |||
|         vector_store_id: str, | ||||
|     ) -> VectorStoreFileBatchObject: | ||||
|         logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}") | ||||
|         return await self.routing_table.openai_retrieve_vector_store_file_batch( | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_retrieve_vector_store_file_batch( | ||||
|             batch_id=batch_id, | ||||
|             vector_store_id=vector_store_id, | ||||
|         ) | ||||
|  | @ -404,7 +413,8 @@ class VectorIORouter(VectorIO): | |||
|         order: str | None = "desc", | ||||
|     ) -> VectorStoreFilesListInBatchResponse: | ||||
|         logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}") | ||||
|         return await self.routing_table.openai_list_files_in_vector_store_file_batch( | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_list_files_in_vector_store_file_batch( | ||||
|             batch_id=batch_id, | ||||
|             vector_store_id=vector_store_id, | ||||
|             after=after, | ||||
|  | @ -420,7 +430,8 @@ class VectorIORouter(VectorIO): | |||
|         vector_store_id: str, | ||||
|     ) -> VectorStoreFileBatchObject: | ||||
|         logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}") | ||||
|         return await self.routing_table.openai_cancel_vector_store_file_batch( | ||||
|         provider = await self.routing_table.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_cancel_vector_store_file_batch( | ||||
|             batch_id=batch_id, | ||||
|             vector_store_id=vector_store_id, | ||||
|         ) | ||||
|  |  | |||
|  | @ -9,7 +9,6 @@ from typing import Any | |||
| from llama_stack.apis.common.errors import ModelNotFoundError | ||||
| from llama_stack.apis.models import Model | ||||
| from llama_stack.apis.resource import ResourceType | ||||
| from llama_stack.apis.scoring_functions import ScoringFn | ||||
| from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed | ||||
| from llama_stack.core.access_control.datatypes import Action | ||||
| from llama_stack.core.datatypes import ( | ||||
|  | @ -17,6 +16,7 @@ from llama_stack.core.datatypes import ( | |||
|     RoutableObject, | ||||
|     RoutableObjectWithProvider, | ||||
|     RoutedProtocol, | ||||
|     ScoringFnWithOwner, | ||||
| ) | ||||
| from llama_stack.core.request_headers import get_authenticated_user | ||||
| from llama_stack.core.store import DistributionRegistry | ||||
|  | @ -114,7 +114,7 @@ class CommonRoutingTableImpl(RoutingTable): | |||
|             elif api == Api.scoring: | ||||
|                 p.scoring_function_store = self | ||||
|                 scoring_functions = await p.list_scoring_functions() | ||||
|                 await add_objects(scoring_functions, pid, ScoringFn) | ||||
|                 await add_objects(scoring_functions, pid, ScoringFnWithOwner) | ||||
|             elif api == Api.eval: | ||||
|                 p.benchmark_store = self | ||||
|             elif api == Api.tool_runtime: | ||||
|  | @ -134,15 +134,12 @@ class CommonRoutingTableImpl(RoutingTable): | |||
|         from .scoring_functions import ScoringFunctionsRoutingTable | ||||
|         from .shields import ShieldsRoutingTable | ||||
|         from .toolgroups import ToolGroupsRoutingTable | ||||
|         from .vector_dbs import VectorDBsRoutingTable | ||||
| 
 | ||||
|         def apiname_object(): | ||||
|             if isinstance(self, ModelsRoutingTable): | ||||
|                 return ("Inference", "model") | ||||
|             elif isinstance(self, ShieldsRoutingTable): | ||||
|                 return ("Safety", "shield") | ||||
|             elif isinstance(self, VectorDBsRoutingTable): | ||||
|                 return ("VectorIO", "vector_db") | ||||
|             elif isinstance(self, DatasetsRoutingTable): | ||||
|                 return ("DatasetIO", "dataset") | ||||
|             elif isinstance(self, ScoringFunctionsRoutingTable): | ||||
|  |  | |||
|  | @ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): | |||
|             try: | ||||
|                 models = await provider.list_models() | ||||
|             except Exception as e: | ||||
|                 logger.warning(f"Model refresh failed for provider {provider_id}: {e}") | ||||
|                 logger.debug(f"Model refresh failed for provider {provider_id}: {e}") | ||||
|                 continue | ||||
| 
 | ||||
|             self.listed_providers.add(provider_id) | ||||
|  | @ -67,6 +67,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): | |||
|             raise ValueError(f"Provider {model.provider_id} not found in the routing table") | ||||
|         return self.impls_by_provider_id[model.provider_id] | ||||
| 
 | ||||
|     async def has_model(self, model_id: str) -> bool: | ||||
|         """ | ||||
|         Check if a model exists in the routing table. | ||||
| 
 | ||||
|         :param model_id: The model identifier to check | ||||
|         :return: True if the model exists, False otherwise | ||||
|         """ | ||||
|         try: | ||||
|             await lookup_model(self, model_id) | ||||
|             return True | ||||
|         except ModelNotFoundError: | ||||
|             return False | ||||
| 
 | ||||
|     async def register_model( | ||||
|         self, | ||||
|         model_id: str, | ||||
|  |  | |||
|  | @ -1,247 +0,0 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| from typing import Any | ||||
| 
 | ||||
| from pydantic import TypeAdapter | ||||
| 
 | ||||
| from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError | ||||
| from llama_stack.apis.models import ModelType | ||||
| from llama_stack.apis.resource import ResourceType | ||||
| from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs | ||||
| from llama_stack.apis.vector_io.vector_io import ( | ||||
|     SearchRankingOptions, | ||||
|     VectorStoreChunkingStrategy, | ||||
|     VectorStoreDeleteResponse, | ||||
|     VectorStoreFileContentsResponse, | ||||
|     VectorStoreFileDeleteResponse, | ||||
|     VectorStoreFileObject, | ||||
|     VectorStoreFileStatus, | ||||
|     VectorStoreObject, | ||||
|     VectorStoreSearchResponsePage, | ||||
| ) | ||||
| from llama_stack.core.datatypes import ( | ||||
|     VectorDBWithOwner, | ||||
| ) | ||||
| from llama_stack.log import get_logger | ||||
| 
 | ||||
| from .common import CommonRoutingTableImpl, lookup_model | ||||
| 
 | ||||
| logger = get_logger(name=__name__, category="core::routing_tables") | ||||
| 
 | ||||
| 
 | ||||
| class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): | ||||
|     async def list_vector_dbs(self) -> ListVectorDBsResponse: | ||||
|         return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) | ||||
| 
 | ||||
|     async def get_vector_db(self, vector_db_id: str) -> VectorDB: | ||||
|         vector_db = await self.get_object_by_identifier("vector_db", vector_db_id) | ||||
|         if vector_db is None: | ||||
|             raise VectorStoreNotFoundError(vector_db_id) | ||||
|         return vector_db | ||||
| 
 | ||||
|     async def register_vector_db( | ||||
|         self, | ||||
|         vector_db_id: str, | ||||
|         embedding_model: str, | ||||
|         embedding_dimension: int | None = 384, | ||||
|         provider_id: str | None = None, | ||||
|         provider_vector_db_id: str | None = None, | ||||
|         vector_db_name: str | None = None, | ||||
|     ) -> VectorDB: | ||||
|         if provider_id is None: | ||||
|             if len(self.impls_by_provider_id) > 0: | ||||
|                 provider_id = list(self.impls_by_provider_id.keys())[0] | ||||
|                 if len(self.impls_by_provider_id) > 1: | ||||
|                     logger.warning( | ||||
|                         f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." | ||||
|                     ) | ||||
|             else: | ||||
|                 raise ValueError("No provider available. Please configure a vector_io provider.") | ||||
|         model = await lookup_model(self, embedding_model) | ||||
|         if model is None: | ||||
|             raise ModelNotFoundError(embedding_model) | ||||
|         if model.model_type != ModelType.embedding: | ||||
|             raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding) | ||||
|         if "embedding_dimension" not in model.metadata: | ||||
|             raise ValueError(f"Model {embedding_model} does not have an embedding dimension") | ||||
| 
 | ||||
|         provider = self.impls_by_provider_id[provider_id] | ||||
|         logger.warning( | ||||
|             "VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly." | ||||
|         ) | ||||
|         vector_store = await provider.openai_create_vector_store( | ||||
|             name=vector_db_name or vector_db_id, | ||||
|             embedding_model=embedding_model, | ||||
|             embedding_dimension=model.metadata["embedding_dimension"], | ||||
|             provider_id=provider_id, | ||||
|             provider_vector_db_id=provider_vector_db_id, | ||||
|         ) | ||||
| 
 | ||||
|         vector_store_id = vector_store.id | ||||
|         actual_provider_vector_db_id = provider_vector_db_id or vector_store_id | ||||
|         logger.warning( | ||||
|             f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name" | ||||
|         ) | ||||
| 
 | ||||
|         vector_db_data = { | ||||
|             "identifier": vector_store_id, | ||||
|             "type": ResourceType.vector_db.value, | ||||
|             "provider_id": provider_id, | ||||
|             "provider_resource_id": actual_provider_vector_db_id, | ||||
|             "embedding_model": embedding_model, | ||||
|             "embedding_dimension": model.metadata["embedding_dimension"], | ||||
|             "vector_db_name": vector_store.name, | ||||
|         } | ||||
|         vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data) | ||||
|         await self.register_object(vector_db) | ||||
|         return vector_db | ||||
| 
 | ||||
|     async def unregister_vector_db(self, vector_db_id: str) -> None: | ||||
|         existing_vector_db = await self.get_vector_db(vector_db_id) | ||||
|         await self.unregister_object(existing_vector_db) | ||||
| 
 | ||||
|     async def openai_retrieve_vector_store( | ||||
|         self, | ||||
|         vector_store_id: str, | ||||
|     ) -> VectorStoreObject: | ||||
|         await self.assert_action_allowed("read", "vector_db", vector_store_id) | ||||
|         provider = await self.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_retrieve_vector_store(vector_store_id) | ||||
| 
 | ||||
|     async def openai_update_vector_store( | ||||
|         self, | ||||
|         vector_store_id: str, | ||||
|         name: str | None = None, | ||||
|         expires_after: dict[str, Any] | None = None, | ||||
|         metadata: dict[str, Any] | None = None, | ||||
|     ) -> VectorStoreObject: | ||||
|         await self.assert_action_allowed("update", "vector_db", vector_store_id) | ||||
|         provider = await self.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_update_vector_store( | ||||
|             vector_store_id=vector_store_id, | ||||
|             name=name, | ||||
|             expires_after=expires_after, | ||||
|             metadata=metadata, | ||||
|         ) | ||||
| 
 | ||||
|     async def openai_delete_vector_store( | ||||
|         self, | ||||
|         vector_store_id: str, | ||||
|     ) -> VectorStoreDeleteResponse: | ||||
|         await self.assert_action_allowed("delete", "vector_db", vector_store_id) | ||||
|         provider = await self.get_provider_impl(vector_store_id) | ||||
|         result = await provider.openai_delete_vector_store(vector_store_id) | ||||
|         await self.unregister_vector_db(vector_store_id) | ||||
|         return result | ||||
| 
 | ||||
|     async def openai_search_vector_store( | ||||
|         self, | ||||
|         vector_store_id: str, | ||||
|         query: str | list[str], | ||||
|         filters: dict[str, Any] | None = None, | ||||
|         max_num_results: int | None = 10, | ||||
|         ranking_options: SearchRankingOptions | None = None, | ||||
|         rewrite_query: bool | None = False, | ||||
|         search_mode: str | None = "vector", | ||||
|     ) -> VectorStoreSearchResponsePage: | ||||
|         await self.assert_action_allowed("read", "vector_db", vector_store_id) | ||||
|         provider = await self.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_search_vector_store( | ||||
|             vector_store_id=vector_store_id, | ||||
|             query=query, | ||||
|             filters=filters, | ||||
|             max_num_results=max_num_results, | ||||
|             ranking_options=ranking_options, | ||||
|             rewrite_query=rewrite_query, | ||||
|             search_mode=search_mode, | ||||
|         ) | ||||
| 
 | ||||
|     async def openai_attach_file_to_vector_store( | ||||
|         self, | ||||
|         vector_store_id: str, | ||||
|         file_id: str, | ||||
|         attributes: dict[str, Any] | None = None, | ||||
|         chunking_strategy: VectorStoreChunkingStrategy | None = None, | ||||
|     ) -> VectorStoreFileObject: | ||||
|         await self.assert_action_allowed("update", "vector_db", vector_store_id) | ||||
|         provider = await self.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_attach_file_to_vector_store( | ||||
|             vector_store_id=vector_store_id, | ||||
|             file_id=file_id, | ||||
|             attributes=attributes, | ||||
|             chunking_strategy=chunking_strategy, | ||||
|         ) | ||||
| 
 | ||||
|     async def openai_list_files_in_vector_store( | ||||
|         self, | ||||
|         vector_store_id: str, | ||||
|         limit: int | None = 20, | ||||
|         order: str | None = "desc", | ||||
|         after: str | None = None, | ||||
|         before: str | None = None, | ||||
|         filter: VectorStoreFileStatus | None = None, | ||||
|     ) -> list[VectorStoreFileObject]: | ||||
|         await self.assert_action_allowed("read", "vector_db", vector_store_id) | ||||
|         provider = await self.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_list_files_in_vector_store( | ||||
|             vector_store_id=vector_store_id, | ||||
|             limit=limit, | ||||
|             order=order, | ||||
|             after=after, | ||||
|             before=before, | ||||
|             filter=filter, | ||||
|         ) | ||||
| 
 | ||||
|     async def openai_retrieve_vector_store_file( | ||||
|         self, | ||||
|         vector_store_id: str, | ||||
|         file_id: str, | ||||
|     ) -> VectorStoreFileObject: | ||||
|         await self.assert_action_allowed("read", "vector_db", vector_store_id) | ||||
|         provider = await self.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_retrieve_vector_store_file( | ||||
|             vector_store_id=vector_store_id, | ||||
|             file_id=file_id, | ||||
|         ) | ||||
| 
 | ||||
|     async def openai_retrieve_vector_store_file_contents( | ||||
|         self, | ||||
|         vector_store_id: str, | ||||
|         file_id: str, | ||||
|     ) -> VectorStoreFileContentsResponse: | ||||
|         await self.assert_action_allowed("read", "vector_db", vector_store_id) | ||||
|         provider = await self.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_retrieve_vector_store_file_contents( | ||||
|             vector_store_id=vector_store_id, | ||||
|             file_id=file_id, | ||||
|         ) | ||||
| 
 | ||||
|     async def openai_update_vector_store_file( | ||||
|         self, | ||||
|         vector_store_id: str, | ||||
|         file_id: str, | ||||
|         attributes: dict[str, Any], | ||||
|     ) -> VectorStoreFileObject: | ||||
|         await self.assert_action_allowed("update", "vector_db", vector_store_id) | ||||
|         provider = await self.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_update_vector_store_file( | ||||
|             vector_store_id=vector_store_id, | ||||
|             file_id=file_id, | ||||
|             attributes=attributes, | ||||
|         ) | ||||
| 
 | ||||
|     async def openai_delete_vector_store_file( | ||||
|         self, | ||||
|         vector_store_id: str, | ||||
|         file_id: str, | ||||
|     ) -> VectorStoreFileDeleteResponse: | ||||
|         await self.assert_action_allowed("delete", "vector_db", vector_store_id) | ||||
|         provider = await self.get_provider_impl(vector_store_id) | ||||
|         return await provider.openai_delete_vector_store_file( | ||||
|             vector_store_id=vector_store_id, | ||||
|             file_id=file_id, | ||||
|         ) | ||||
|  | @ -27,6 +27,11 @@ class AuthenticationMiddleware: | |||
|     3. Extracts user attributes from the provider's response | ||||
|     4. Makes these attributes available to the route handlers for access control | ||||
| 
 | ||||
|     Unauthenticated Access: | ||||
|     Endpoints can opt out of authentication by setting require_authentication=False | ||||
|     in their @webmethod decorator. This is typically used for operational endpoints | ||||
|     like /health and /version to support monitoring, load balancers, and observability tools. | ||||
| 
 | ||||
|     The middleware supports multiple authentication providers through the AuthProvider interface: | ||||
|     - Kubernetes: Validates tokens against the Kubernetes API server | ||||
|     - Custom: Validates tokens against a custom endpoint | ||||
|  | @ -88,7 +93,26 @@ class AuthenticationMiddleware: | |||
| 
 | ||||
|     async def __call__(self, scope, receive, send): | ||||
|         if scope["type"] == "http": | ||||
|             # First, handle authentication | ||||
|             # Find the route and check if authentication is required | ||||
|             path = scope.get("path", "") | ||||
|             method = scope.get("method", hdrs.METH_GET) | ||||
| 
 | ||||
|             if not hasattr(self, "route_impls"): | ||||
|                 self.route_impls = initialize_route_impls(self.impls) | ||||
| 
 | ||||
|             webmethod = None | ||||
|             try: | ||||
|                 _, _, _, webmethod = find_matching_route(method, path, self.route_impls) | ||||
|             except ValueError: | ||||
|                 # If no matching endpoint is found, pass here to run auth anyways | ||||
|                 pass | ||||
| 
 | ||||
|             # If webmethod explicitly sets require_authentication=False, allow without auth | ||||
|             if webmethod and webmethod.require_authentication is False: | ||||
|                 logger.debug(f"Allowing unauthenticated access to endpoint: {path}") | ||||
|                 return await self.app(scope, receive, send) | ||||
| 
 | ||||
|             # Handle authentication | ||||
|             headers = dict(scope.get("headers", [])) | ||||
|             auth_header = headers.get(b"authorization", b"").decode() | ||||
| 
 | ||||
|  | @ -127,19 +151,7 @@ class AuthenticationMiddleware: | |||
|             ) | ||||
| 
 | ||||
|             # Scope-based API access control | ||||
|             path = scope.get("path", "") | ||||
|             method = scope.get("method", hdrs.METH_GET) | ||||
| 
 | ||||
|             if not hasattr(self, "route_impls"): | ||||
|                 self.route_impls = initialize_route_impls(self.impls) | ||||
| 
 | ||||
|             try: | ||||
|                 _, _, _, webmethod = find_matching_route(method, path, self.route_impls) | ||||
|             except ValueError: | ||||
|                 # If no matching endpoint is found, pass through to FastAPI | ||||
|                 return await self.app(scope, receive, send) | ||||
| 
 | ||||
|             if webmethod.required_scope: | ||||
|             if webmethod and webmethod.required_scope: | ||||
|                 user = user_from_scope(scope) | ||||
|                 if not _has_required_scope(webmethod.required_scope, user): | ||||
|                     return await self._send_auth_error( | ||||
|  |  | |||
|  | @ -4,7 +4,6 @@ | |||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| import argparse | ||||
| import asyncio | ||||
| import concurrent.futures | ||||
| import functools | ||||
|  | @ -12,7 +11,6 @@ import inspect | |||
| import json | ||||
| import logging  # allow-direct-logging | ||||
| import os | ||||
| import ssl | ||||
| import sys | ||||
| import traceback | ||||
| import warnings | ||||
|  | @ -35,7 +33,6 @@ from pydantic import BaseModel, ValidationError | |||
| 
 | ||||
| from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError | ||||
| from llama_stack.apis.common.responses import PaginatedResponse | ||||
| from llama_stack.cli.utils import add_config_distro_args, get_config_from_args | ||||
| from llama_stack.core.access_control.access_control import AccessDeniedError | ||||
| from llama_stack.core.datatypes import ( | ||||
|     AuthenticationRequiredError, | ||||
|  | @ -55,7 +52,6 @@ from llama_stack.core.stack import ( | |||
|     Stack, | ||||
|     cast_image_name_to_string, | ||||
|     replace_env_vars, | ||||
|     validate_env_pair, | ||||
| ) | ||||
| from llama_stack.core.utils.config import redact_sensitive_fields | ||||
| from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro | ||||
|  | @ -142,6 +138,13 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro | |||
|         return HTTPException(status_code=httpx.codes.NOT_IMPLEMENTED, detail=f"Not implemented: {str(exc)}") | ||||
|     elif isinstance(exc, AuthenticationRequiredError): | ||||
|         return HTTPException(status_code=httpx.codes.UNAUTHORIZED, detail=f"Authentication required: {str(exc)}") | ||||
|     elif hasattr(exc, "status_code") and isinstance(getattr(exc, "status_code", None), int): | ||||
|         # Handle provider SDK exceptions (e.g., OpenAI's APIStatusError and subclasses) | ||||
|         # These include AuthenticationError (401), PermissionDeniedError (403), etc. | ||||
|         # This preserves the actual HTTP status code from the provider | ||||
|         status_code = exc.status_code | ||||
|         detail = str(exc) | ||||
|         return HTTPException(status_code=status_code, detail=detail) | ||||
|     else: | ||||
|         return HTTPException( | ||||
|             status_code=httpx.codes.INTERNAL_SERVER_ERROR, | ||||
|  | @ -181,7 +184,17 @@ async def lifespan(app: StackApp): | |||
| 
 | ||||
| def is_streaming_request(func_name: str, request: Request, **kwargs): | ||||
|     # TODO: pass the api method and punt it to the Protocol definition directly | ||||
|     return kwargs.get("stream", False) | ||||
|     # If there's a stream parameter at top level, use it | ||||
|     if "stream" in kwargs: | ||||
|         return kwargs["stream"] | ||||
| 
 | ||||
|     # If there's a stream parameter inside a "params" parameter, e.g. openai_chat_completion() use it | ||||
|     if "params" in kwargs: | ||||
|         params = kwargs["params"] | ||||
|         if hasattr(params, "stream"): | ||||
|             return params.stream | ||||
| 
 | ||||
|     return False | ||||
| 
 | ||||
| 
 | ||||
| async def maybe_await(value): | ||||
|  | @ -236,15 +249,31 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: | |||
| 
 | ||||
|         await log_request_pre_validation(request) | ||||
| 
 | ||||
|         test_context_token = None | ||||
|         test_context_var = None | ||||
|         reset_test_context_fn = None | ||||
| 
 | ||||
|         # Use context manager with both provider data and auth attributes | ||||
|         with request_provider_data_context(request.headers, user): | ||||
|             if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"): | ||||
|                 from llama_stack.core.testing_context import ( | ||||
|                     TEST_CONTEXT, | ||||
|                     reset_test_context, | ||||
|                     sync_test_context_from_provider_data, | ||||
|                 ) | ||||
| 
 | ||||
|                 test_context_token = sync_test_context_from_provider_data() | ||||
|                 test_context_var = TEST_CONTEXT | ||||
|                 reset_test_context_fn = reset_test_context | ||||
| 
 | ||||
|             is_streaming = is_streaming_request(func.__name__, request, **kwargs) | ||||
| 
 | ||||
|             try: | ||||
|                 if is_streaming: | ||||
|                     gen = preserve_contexts_async_generator( | ||||
|                         sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR] | ||||
|                     ) | ||||
|                     context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR] | ||||
|                     if test_context_var is not None: | ||||
|                         context_vars.append(test_context_var) | ||||
|                     gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars) | ||||
|                     return StreamingResponse(gen, media_type="text/event-stream") | ||||
|                 else: | ||||
|                     value = func(**kwargs) | ||||
|  | @ -262,6 +291,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: | |||
|                 else: | ||||
|                     logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}") | ||||
|                 raise translate_exception(e) from e | ||||
|             finally: | ||||
|                 if test_context_token is not None and reset_test_context_fn is not None: | ||||
|                     reset_test_context_fn(test_context_token) | ||||
| 
 | ||||
|     sig = inspect.signature(func) | ||||
| 
 | ||||
|  | @ -333,23 +365,18 @@ class ClientVersionMiddleware: | |||
|         return await self.app(scope, receive, send) | ||||
| 
 | ||||
| 
 | ||||
| def create_app( | ||||
|     config_file: str | None = None, | ||||
|     env_vars: list[str] | None = None, | ||||
| ) -> StackApp: | ||||
| def create_app() -> StackApp: | ||||
|     """Create and configure the FastAPI application. | ||||
| 
 | ||||
|     Args: | ||||
|         config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution. | ||||
|         env_vars: List of environment variables in KEY=value format. | ||||
|         disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var. | ||||
|     This factory function reads configuration from environment variables: | ||||
|     - LLAMA_STACK_CONFIG: Path to config file (required) | ||||
| 
 | ||||
|     Returns: | ||||
|         Configured StackApp instance. | ||||
|     """ | ||||
|     config_file = config_file or os.getenv("LLAMA_STACK_CONFIG") | ||||
|     config_file = os.getenv("LLAMA_STACK_CONFIG") | ||||
|     if config_file is None: | ||||
|         raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set") | ||||
|         raise ValueError("LLAMA_STACK_CONFIG environment variable is required") | ||||
| 
 | ||||
|     config_file = resolve_config_or_distro(config_file, Mode.RUN) | ||||
| 
 | ||||
|  | @ -361,16 +388,6 @@ def create_app( | |||
|             logger_config = LoggingConfig(**cfg) | ||||
|         logger = get_logger(name=__name__, category="core::server", config=logger_config) | ||||
| 
 | ||||
|         if env_vars: | ||||
|             for env_pair in env_vars: | ||||
|                 try: | ||||
|                     key, value = validate_env_pair(env_pair) | ||||
|                     logger.info(f"Setting environment variable {key} => {value}") | ||||
|                     os.environ[key] = value | ||||
|                 except ValueError as e: | ||||
|                     logger.error(f"Error: {str(e)}") | ||||
|                     raise ValueError(f"Invalid environment variable format: {env_pair}") from e | ||||
| 
 | ||||
|         config = replace_env_vars(config_contents) | ||||
|         config = StackRunConfig(**cast_image_name_to_string(config)) | ||||
| 
 | ||||
|  | @ -494,101 +511,6 @@ def create_app( | |||
|     return app | ||||
| 
 | ||||
| 
 | ||||
| def main(args: argparse.Namespace | None = None): | ||||
|     """Start the LlamaStack server.""" | ||||
|     parser = argparse.ArgumentParser(description="Start the LlamaStack server.") | ||||
| 
 | ||||
|     add_config_distro_args(parser) | ||||
|     parser.add_argument( | ||||
|         "--port", | ||||
|         type=int, | ||||
|         default=int(os.getenv("LLAMA_STACK_PORT", 8321)), | ||||
|         help="Port to listen on", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env", | ||||
|         action="append", | ||||
|         help="Environment variables in KEY=value format. Can be specified multiple times.", | ||||
|     ) | ||||
| 
 | ||||
|     # Determine whether the server args are being passed by the "run" command, if this is the case | ||||
|     # the args will be passed as a Namespace object to the main function, otherwise they will be | ||||
|     # parsed from the command line | ||||
|     if args is None: | ||||
|         args = parser.parse_args() | ||||
| 
 | ||||
|     config_or_distro = get_config_from_args(args) | ||||
| 
 | ||||
|     try: | ||||
|         app = create_app( | ||||
|             config_file=config_or_distro, | ||||
|             env_vars=args.env, | ||||
|         ) | ||||
|     except Exception as e: | ||||
|         logger.error(f"Error creating app: {str(e)}") | ||||
|         sys.exit(1) | ||||
| 
 | ||||
|     config_file = resolve_config_or_distro(config_or_distro, Mode.RUN) | ||||
|     with open(config_file) as fp: | ||||
|         config_contents = yaml.safe_load(fp) | ||||
|         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): | ||||
|             logger_config = LoggingConfig(**cfg) | ||||
|         else: | ||||
|             logger_config = None | ||||
|         config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents))) | ||||
| 
 | ||||
|     import uvicorn | ||||
| 
 | ||||
|     # Configure SSL if certificates are provided | ||||
|     port = args.port or config.server.port | ||||
| 
 | ||||
|     ssl_config = None | ||||
|     keyfile = config.server.tls_keyfile | ||||
|     certfile = config.server.tls_certfile | ||||
| 
 | ||||
|     if keyfile and certfile: | ||||
|         ssl_config = { | ||||
|             "ssl_keyfile": keyfile, | ||||
|             "ssl_certfile": certfile, | ||||
|         } | ||||
|         if config.server.tls_cafile: | ||||
|             ssl_config["ssl_ca_certs"] = config.server.tls_cafile | ||||
|             ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED | ||||
|             logger.info( | ||||
|                 f"HTTPS enabled with certificates:\n  Key: {keyfile}\n  Cert: {certfile}\n  CA: {config.server.tls_cafile}" | ||||
|             ) | ||||
|         else: | ||||
|             logger.info(f"HTTPS enabled with certificates:\n  Key: {keyfile}\n  Cert: {certfile}") | ||||
| 
 | ||||
|     listen_host = config.server.host or ["::", "0.0.0.0"] | ||||
|     logger.info(f"Listening on {listen_host}:{port}") | ||||
| 
 | ||||
|     uvicorn_config = { | ||||
|         "app": app, | ||||
|         "host": listen_host, | ||||
|         "port": port, | ||||
|         "lifespan": "on", | ||||
|         "log_level": logger.getEffectiveLevel(), | ||||
|         "log_config": logger_config, | ||||
|     } | ||||
|     if ssl_config: | ||||
|         uvicorn_config.update(ssl_config) | ||||
| 
 | ||||
|     # We need to catch KeyboardInterrupt because uvicorn's signal handling | ||||
|     # re-raises SIGINT signals using signal.raise_signal(), which Python | ||||
|     # converts to KeyboardInterrupt. Without this catch, we'd get a confusing | ||||
|     # stack trace when using Ctrl+C or kill -2 (SIGINT). | ||||
|     # SIGTERM (kill -15) works fine without this because Python doesn't | ||||
|     # have a default handler for it. | ||||
|     # | ||||
|     # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own | ||||
|     # signal handling but this is quite intrusive and not worth the effort. | ||||
|     try: | ||||
|         asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) | ||||
|     except (KeyboardInterrupt, SystemExit): | ||||
|         logger.info("Received interrupt signal, shutting down gracefully...") | ||||
| 
 | ||||
| 
 | ||||
| def _log_run_config(run_config: StackRunConfig): | ||||
|     """Logs the run config with redacted fields and disabled providers removed.""" | ||||
|     logger.info("Run configuration:") | ||||
|  | @ -615,7 +537,3 @@ def remove_disabled_providers(obj): | |||
|         return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None] | ||||
|     else: | ||||
|         return obj | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
|  |  | |||
|  | @ -33,7 +33,6 @@ from llama_stack.apis.shields import Shields | |||
| from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration | ||||
| from llama_stack.apis.telemetry import Telemetry | ||||
| from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime | ||||
| from llama_stack.apis.vector_dbs import VectorDBs | ||||
| from llama_stack.apis.vector_io import VectorIO | ||||
| from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl | ||||
| from llama_stack.core.datatypes import Provider, StackRunConfig | ||||
|  | @ -53,7 +52,6 @@ logger = get_logger(name=__name__, category="core") | |||
| 
 | ||||
| class LlamaStack( | ||||
|     Providers, | ||||
|     VectorDBs, | ||||
|     Inference, | ||||
|     Agents, | ||||
|     Safety, | ||||
|  | @ -83,7 +81,6 @@ class LlamaStack( | |||
| RESOURCES = [ | ||||
|     ("models", Api.models, "register_model", "list_models"), | ||||
|     ("shields", Api.shields, "register_shield", "list_shields"), | ||||
|     ("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"), | ||||
|     ("datasets", Api.datasets, "register_dataset", "list_datasets"), | ||||
|     ( | ||||
|         "scoring_fns", | ||||
|  | @ -274,22 +271,6 @@ def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]: | |||
|     return config_dict | ||||
| 
 | ||||
| 
 | ||||
| def validate_env_pair(env_pair: str) -> tuple[str, str]: | ||||
|     """Validate and split an environment variable key-value pair.""" | ||||
|     try: | ||||
|         key, value = env_pair.split("=", 1) | ||||
|         key = key.strip() | ||||
|         if not key: | ||||
|             raise ValueError(f"Empty key in environment variable pair: {env_pair}") | ||||
|         if not all(c.isalnum() or c == "_" for c in key): | ||||
|             raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}") | ||||
|         return key, value | ||||
|     except ValueError as e: | ||||
|         raise ValueError( | ||||
|             f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value" | ||||
|         ) from e | ||||
| 
 | ||||
| 
 | ||||
| def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None: | ||||
|     """Add internal implementations (inspect and providers) to the implementations dictionary. | ||||
| 
 | ||||
|  | @ -332,22 +313,27 @@ class Stack: | |||
|     # asked for in the run config. | ||||
|     async def initialize(self): | ||||
|         if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ: | ||||
|             from llama_stack.testing.inference_recorder import setup_inference_recording | ||||
|             from llama_stack.testing.api_recorder import setup_api_recording | ||||
| 
 | ||||
|             global TEST_RECORDING_CONTEXT | ||||
|             TEST_RECORDING_CONTEXT = setup_inference_recording() | ||||
|             TEST_RECORDING_CONTEXT = setup_api_recording() | ||||
|             if TEST_RECORDING_CONTEXT: | ||||
|                 TEST_RECORDING_CONTEXT.__enter__() | ||||
|                 logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") | ||||
|                 logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") | ||||
| 
 | ||||
|         dist_registry, _ = await create_dist_registry(self.run_config.persistence, self.run_config.image_name) | ||||
|         policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else [] | ||||
|         impls = await resolve_impls( | ||||
|             self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy | ||||
|         ) | ||||
| 
 | ||||
|         # Add internal implementations after all other providers are resolved | ||||
|         add_internal_implementations(impls, self.run_config) | ||||
|         internal_impls = {} | ||||
|         add_internal_implementations(internal_impls, self.run_config) | ||||
| 
 | ||||
|         impls = await resolve_impls( | ||||
|             self.run_config, | ||||
|             self.provider_registry or get_provider_registry(self.run_config), | ||||
|             dist_registry, | ||||
|             policy, | ||||
|             internal_impls, | ||||
|         ) | ||||
| 
 | ||||
|         if Api.prompts in impls: | ||||
|             await impls[Api.prompts].initialize() | ||||
|  | @ -397,7 +383,7 @@ class Stack: | |||
|             try: | ||||
|                 TEST_RECORDING_CONTEXT.__exit__(None, None, None) | ||||
|             except Exception as e: | ||||
|                 logger.error(f"Error during inference recording cleanup: {e}") | ||||
|                 logger.error(f"Error during API recording cleanup: {e}") | ||||
| 
 | ||||
|         global REGISTRY_REFRESH_TASK | ||||
|         if REGISTRY_REFRESH_TASK: | ||||
|  |  | |||
|  | @ -25,7 +25,7 @@ error_handler() { | |||
| trap 'error_handler ${LINENO}' ERR | ||||
| 
 | ||||
| if [ $# -lt 3 ]; then | ||||
|   echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..." | ||||
|   echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>]" | ||||
|   exit 1 | ||||
| fi | ||||
| 
 | ||||
|  | @ -43,7 +43,6 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")") | |||
| 
 | ||||
| # Initialize variables | ||||
| yaml_config="" | ||||
| env_vars="" | ||||
| other_args="" | ||||
| 
 | ||||
| # Process remaining arguments | ||||
|  | @ -58,15 +57,6 @@ while [[ $# -gt 0 ]]; do | |||
|         exit 1 | ||||
|       fi | ||||
|       ;; | ||||
|     --env) | ||||
|       if [[ -n "$2" ]]; then | ||||
|         env_vars="$env_vars --env $2" | ||||
|         shift 2 | ||||
|       else | ||||
|         echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2 | ||||
|         exit 1 | ||||
|       fi | ||||
|       ;; | ||||
|     *) | ||||
|       other_args="$other_args $1" | ||||
|       shift | ||||
|  | @ -116,10 +106,9 @@ if [[ "$env_type" == "venv" ]]; then | |||
|         yaml_config_arg="" | ||||
|     fi | ||||
| 
 | ||||
|     $PYTHON_BINARY -m llama_stack.core.server.server \ | ||||
|     llama stack run \ | ||||
|     $yaml_config_arg \ | ||||
|     --port "$port" \ | ||||
|     $env_vars \ | ||||
|     $other_args | ||||
| elif [[ "$env_type" == "container" ]]; then | ||||
|     echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}" | ||||
|  |  | |||
|  | @ -95,9 +95,11 @@ class DiskDistributionRegistry(DistributionRegistry): | |||
| 
 | ||||
|     async def register(self, obj: RoutableObjectWithProvider) -> bool: | ||||
|         existing_obj = await self.get(obj.type, obj.identifier) | ||||
|         # dont register if the object's providerid already exists | ||||
|         if existing_obj and existing_obj.provider_id == obj.provider_id: | ||||
|             return False | ||||
|         if existing_obj and existing_obj != obj: | ||||
|             raise ValueError( | ||||
|                 f"Object of type '{obj.type}' and identifier '{obj.identifier}' already exists. " | ||||
|                 "Unregister it first if you want to replace it." | ||||
|             ) | ||||
| 
 | ||||
|         await self.kvstore.set( | ||||
|             KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), | ||||
|  |  | |||
							
								
								
									
										44
									
								
								llama_stack/core/testing_context.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								llama_stack/core/testing_context.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,44 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| import os | ||||
| from contextvars import ContextVar | ||||
| 
 | ||||
| from llama_stack.core.request_headers import PROVIDER_DATA_VAR | ||||
| 
 | ||||
| TEST_CONTEXT: ContextVar[str | None] = ContextVar("llama_stack_test_context", default=None) | ||||
| 
 | ||||
| 
 | ||||
| def get_test_context() -> str | None: | ||||
|     return TEST_CONTEXT.get() | ||||
| 
 | ||||
| 
 | ||||
| def set_test_context(value: str | None): | ||||
|     return TEST_CONTEXT.set(value) | ||||
| 
 | ||||
| 
 | ||||
| def reset_test_context(token) -> None: | ||||
|     TEST_CONTEXT.reset(token) | ||||
| 
 | ||||
| 
 | ||||
| def sync_test_context_from_provider_data(): | ||||
|     """Sync test context from provider data when running in server test mode.""" | ||||
|     if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ: | ||||
|         return None | ||||
| 
 | ||||
|     stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client") | ||||
|     if stack_config_type != "server": | ||||
|         return None | ||||
| 
 | ||||
|     try: | ||||
|         provider_data = PROVIDER_DATA_VAR.get() | ||||
|     except LookupError: | ||||
|         provider_data = None | ||||
| 
 | ||||
|     if provider_data and "__test_id" in provider_data: | ||||
|         return TEST_CONTEXT.set(provider_data["__test_id"]) | ||||
| 
 | ||||
|     return None | ||||
|  | @ -11,19 +11,17 @@ from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks | |||
| from llama_stack.core.ui.page.distribution.models import models | ||||
| from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions | ||||
| from llama_stack.core.ui.page.distribution.shields import shields | ||||
| from llama_stack.core.ui.page.distribution.vector_dbs import vector_dbs | ||||
| 
 | ||||
| 
 | ||||
| def resources_page(): | ||||
|     options = [ | ||||
|         "Models", | ||||
|         "Vector Databases", | ||||
|         "Shields", | ||||
|         "Scoring Functions", | ||||
|         "Datasets", | ||||
|         "Benchmarks", | ||||
|     ] | ||||
|     icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"] | ||||
|     icons = ["magic", "shield", "file-bar-graph", "database", "list-task"] | ||||
|     selected_resource = option_menu( | ||||
|         None, | ||||
|         options, | ||||
|  | @ -37,8 +35,6 @@ def resources_page(): | |||
|     ) | ||||
|     if selected_resource == "Benchmarks": | ||||
|         benchmarks() | ||||
|     elif selected_resource == "Vector Databases": | ||||
|         vector_dbs() | ||||
|     elif selected_resource == "Datasets": | ||||
|         datasets() | ||||
|     elif selected_resource == "Models": | ||||
|  |  | |||
|  | @ -1,20 +0,0 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| import streamlit as st | ||||
| 
 | ||||
| from llama_stack.core.ui.modules.api import llama_stack_api | ||||
| 
 | ||||
| 
 | ||||
| def vector_dbs(): | ||||
|     st.header("Vector Databases") | ||||
|     vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()} | ||||
| 
 | ||||
|     if len(vector_dbs_info) > 0: | ||||
|         selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys())) | ||||
|         st.json(vector_dbs_info[selected_vector_db]) | ||||
|     else: | ||||
|         st.info("No vector databases found") | ||||
|  | @ -1,301 +0,0 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| import uuid | ||||
| 
 | ||||
| import streamlit as st | ||||
| from llama_stack_client import Agent, AgentEventLogger, RAGDocument | ||||
| 
 | ||||
| from llama_stack.apis.common.content_types import ToolCallDelta | ||||
| from llama_stack.core.ui.modules.api import llama_stack_api | ||||
| from llama_stack.core.ui.modules.utils import data_url_from_file | ||||
| 
 | ||||
| 
 | ||||
| def rag_chat_page(): | ||||
|     st.title("🦙 RAG") | ||||
| 
 | ||||
|     def reset_agent_and_chat(): | ||||
|         st.session_state.clear() | ||||
|         st.cache_resource.clear() | ||||
| 
 | ||||
|     def should_disable_input(): | ||||
|         return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0 | ||||
| 
 | ||||
|     def log_message(message): | ||||
|         with st.chat_message(message["role"]): | ||||
|             if "tool_output" in message and message["tool_output"]: | ||||
|                 with st.expander(label="Tool Output", expanded=False, icon="🛠"): | ||||
|                     st.write(message["tool_output"]) | ||||
|             st.markdown(message["content"]) | ||||
| 
 | ||||
|     with st.sidebar: | ||||
|         # File/Directory Upload Section | ||||
|         st.subheader("Upload Documents", divider=True) | ||||
|         uploaded_files = st.file_uploader( | ||||
|             "Upload file(s) or directory", | ||||
|             accept_multiple_files=True, | ||||
|             type=["txt", "pdf", "doc", "docx"],  # Add more file types as needed | ||||
|         ) | ||||
|         # Process uploaded files | ||||
|         if uploaded_files: | ||||
|             st.success(f"Successfully uploaded {len(uploaded_files)} files") | ||||
|             # Add memory bank name input field | ||||
|             vector_db_name = st.text_input( | ||||
|                 "Document Collection Name", | ||||
|                 value="rag_vector_db", | ||||
|                 help="Enter a unique identifier for this document collection", | ||||
|             ) | ||||
|             if st.button("Create Document Collection"): | ||||
|                 documents = [ | ||||
|                     RAGDocument( | ||||
|                         document_id=uploaded_file.name, | ||||
|                         content=data_url_from_file(uploaded_file), | ||||
|                     ) | ||||
|                     for i, uploaded_file in enumerate(uploaded_files) | ||||
|                 ] | ||||
| 
 | ||||
|                 providers = llama_stack_api.client.providers.list() | ||||
|                 vector_io_provider = None | ||||
| 
 | ||||
|                 for x in providers: | ||||
|                     if x.api == "vector_io": | ||||
|                         vector_io_provider = x.provider_id | ||||
| 
 | ||||
|                 llama_stack_api.client.vector_dbs.register( | ||||
|                     vector_db_id=vector_db_name,  # Use the user-provided name | ||||
|                     embedding_dimension=384, | ||||
|                     embedding_model="all-MiniLM-L6-v2", | ||||
|                     provider_id=vector_io_provider, | ||||
|                 ) | ||||
| 
 | ||||
|                 # insert documents using the custom vector db name | ||||
|                 llama_stack_api.client.tool_runtime.rag_tool.insert( | ||||
|                     vector_db_id=vector_db_name,  # Use the user-provided name | ||||
|                     documents=documents, | ||||
|                     chunk_size_in_tokens=512, | ||||
|                 ) | ||||
|                 st.success("Vector database created successfully!") | ||||
| 
 | ||||
|         st.subheader("RAG Parameters", divider=True) | ||||
| 
 | ||||
|         rag_mode = st.radio( | ||||
|             "RAG mode", | ||||
|             ["Direct", "Agent-based"], | ||||
|             captions=[ | ||||
|                 "RAG is performed by directly retrieving the information and augmenting the user query", | ||||
|                 "RAG is performed by an agent activating a dedicated knowledge search tool.", | ||||
|             ], | ||||
|             on_change=reset_agent_and_chat, | ||||
|             disabled=should_disable_input(), | ||||
|         ) | ||||
| 
 | ||||
|         # Select the vector DBs (document collections) to use for retrieval | ||||
|         vector_dbs = llama_stack_api.client.vector_dbs.list() | ||||
|         vector_dbs = [vector_db.identifier for vector_db in vector_dbs] | ||||
|         selected_vector_dbs = st.multiselect( | ||||
|             label="Select Document Collections to use in RAG queries", | ||||
|             options=vector_dbs, | ||||
|             on_change=reset_agent_and_chat, | ||||
|             disabled=should_disable_input(), | ||||
|         ) | ||||
| 
 | ||||
|         st.subheader("Inference Parameters", divider=True) | ||||
|         available_models = llama_stack_api.client.models.list() | ||||
|         available_models = [model.identifier for model in available_models if model.model_type == "llm"] | ||||
|         selected_model = st.selectbox( | ||||
|             label="Choose a model", | ||||
|             options=available_models, | ||||
|             index=0, | ||||
|             on_change=reset_agent_and_chat, | ||||
|             disabled=should_disable_input(), | ||||
|         ) | ||||
|         system_prompt = st.text_area( | ||||
|             "System Prompt", | ||||
|             value="You are a helpful assistant. ", | ||||
|             help="Initial instructions given to the AI to set its behavior and context", | ||||
|             on_change=reset_agent_and_chat, | ||||
|             disabled=should_disable_input(), | ||||
|         ) | ||||
|         temperature = st.slider( | ||||
|             "Temperature", | ||||
|             min_value=0.0, | ||||
|             max_value=1.0, | ||||
|             value=0.0, | ||||
|             step=0.1, | ||||
|             help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable", | ||||
|             on_change=reset_agent_and_chat, | ||||
|             disabled=should_disable_input(), | ||||
|         ) | ||||
| 
 | ||||
|         top_p = st.slider( | ||||
|             "Top P", | ||||
|             min_value=0.0, | ||||
|             max_value=1.0, | ||||
|             value=0.95, | ||||
|             step=0.1, | ||||
|             on_change=reset_agent_and_chat, | ||||
|             disabled=should_disable_input(), | ||||
|         ) | ||||
| 
 | ||||
|         # Add clear chat button to sidebar | ||||
|         if st.button("Clear Chat", use_container_width=True): | ||||
|             reset_agent_and_chat() | ||||
|             st.rerun() | ||||
| 
 | ||||
|     # Chat Interface | ||||
|     if "messages" not in st.session_state: | ||||
|         st.session_state.messages = [] | ||||
|     if "displayed_messages" not in st.session_state: | ||||
|         st.session_state.displayed_messages = [] | ||||
| 
 | ||||
|     # Display chat history | ||||
|     for message in st.session_state.displayed_messages: | ||||
|         log_message(message) | ||||
| 
 | ||||
|     if temperature > 0.0: | ||||
|         strategy = { | ||||
|             "type": "top_p", | ||||
|             "temperature": temperature, | ||||
|             "top_p": top_p, | ||||
|         } | ||||
|     else: | ||||
|         strategy = {"type": "greedy"} | ||||
| 
 | ||||
|     @st.cache_resource | ||||
|     def create_agent(): | ||||
|         return Agent( | ||||
|             llama_stack_api.client, | ||||
|             model=selected_model, | ||||
|             instructions=system_prompt, | ||||
|             sampling_params={ | ||||
|                 "strategy": strategy, | ||||
|             }, | ||||
|             tools=[ | ||||
|                 dict( | ||||
|                     name="builtin::rag/knowledge_search", | ||||
|                     args={ | ||||
|                         "vector_db_ids": list(selected_vector_dbs), | ||||
|                     }, | ||||
|                 ) | ||||
|             ], | ||||
|         ) | ||||
| 
 | ||||
|     if rag_mode == "Agent-based": | ||||
|         agent = create_agent() | ||||
|         if "agent_session_id" not in st.session_state: | ||||
|             st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}") | ||||
| 
 | ||||
|         session_id = st.session_state["agent_session_id"] | ||||
| 
 | ||||
|     def agent_process_prompt(prompt): | ||||
|         # Add user message to chat history | ||||
|         st.session_state.messages.append({"role": "user", "content": prompt}) | ||||
| 
 | ||||
|         # Send the prompt to the agent | ||||
|         response = agent.create_turn( | ||||
|             messages=[ | ||||
|                 { | ||||
|                     "role": "user", | ||||
|                     "content": prompt, | ||||
|                 } | ||||
|             ], | ||||
|             session_id=session_id, | ||||
|         ) | ||||
| 
 | ||||
|         # Display assistant response | ||||
|         with st.chat_message("assistant"): | ||||
|             retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠") | ||||
|             message_placeholder = st.empty() | ||||
|             full_response = "" | ||||
|             retrieval_response = "" | ||||
|             for log in AgentEventLogger().log(response): | ||||
|                 log.print() | ||||
|                 if log.role == "tool_execution": | ||||
|                     retrieval_response += log.content.replace("====", "").strip() | ||||
|                     retrieval_message_placeholder.write(retrieval_response) | ||||
|                 else: | ||||
|                     full_response += log.content | ||||
|                     message_placeholder.markdown(full_response + "▌") | ||||
|             message_placeholder.markdown(full_response) | ||||
| 
 | ||||
|             st.session_state.messages.append({"role": "assistant", "content": full_response}) | ||||
|             st.session_state.displayed_messages.append( | ||||
|                 {"role": "assistant", "content": full_response, "tool_output": retrieval_response} | ||||
|             ) | ||||
| 
 | ||||
|     def direct_process_prompt(prompt): | ||||
|         # Add the system prompt in the beginning of the conversation | ||||
|         if len(st.session_state.messages) == 0: | ||||
|             st.session_state.messages.append({"role": "system", "content": system_prompt}) | ||||
| 
 | ||||
|         # Query the vector DB | ||||
|         rag_response = llama_stack_api.client.tool_runtime.rag_tool.query( | ||||
|             content=prompt, vector_db_ids=list(selected_vector_dbs) | ||||
|         ) | ||||
|         prompt_context = rag_response.content | ||||
| 
 | ||||
|         with st.chat_message("assistant"): | ||||
|             with st.expander(label="Retrieval Output", expanded=False): | ||||
|                 st.write(prompt_context) | ||||
| 
 | ||||
|             retrieval_message_placeholder = st.empty() | ||||
|             message_placeholder = st.empty() | ||||
|             full_response = "" | ||||
|             retrieval_response = "" | ||||
| 
 | ||||
|             # Construct the extended prompt | ||||
|             extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}" | ||||
| 
 | ||||
|             # Run inference directly | ||||
|             st.session_state.messages.append({"role": "user", "content": extended_prompt}) | ||||
|             response = llama_stack_api.client.inference.chat_completion( | ||||
|                 messages=st.session_state.messages, | ||||
|                 model_id=selected_model, | ||||
|                 sampling_params={ | ||||
|                     "strategy": strategy, | ||||
|                 }, | ||||
|                 stream=True, | ||||
|             ) | ||||
| 
 | ||||
|             # Display assistant response | ||||
|             for chunk in response: | ||||
|                 response_delta = chunk.event.delta | ||||
|                 if isinstance(response_delta, ToolCallDelta): | ||||
|                     retrieval_response += response_delta.tool_call.replace("====", "").strip() | ||||
|                     retrieval_message_placeholder.info(retrieval_response) | ||||
|                 else: | ||||
|                     full_response += chunk.event.delta.text | ||||
|                     message_placeholder.markdown(full_response + "▌") | ||||
|             message_placeholder.markdown(full_response) | ||||
| 
 | ||||
|         response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"} | ||||
|         st.session_state.messages.append(response_dict) | ||||
|         st.session_state.displayed_messages.append(response_dict) | ||||
| 
 | ||||
|     # Chat input | ||||
|     if prompt := st.chat_input("Ask a question about your documents"): | ||||
|         # Add user message to chat history | ||||
|         st.session_state.displayed_messages.append({"role": "user", "content": prompt}) | ||||
| 
 | ||||
|         # Display user message | ||||
|         with st.chat_message("user"): | ||||
|             st.markdown(prompt) | ||||
| 
 | ||||
|         # store the prompt to process it after page refresh | ||||
|         st.session_state.prompt = prompt | ||||
| 
 | ||||
|         # force page refresh to disable the settings widgets | ||||
|         st.rerun() | ||||
| 
 | ||||
|     if "prompt" in st.session_state and st.session_state.prompt is not None: | ||||
|         if rag_mode == "Agent-based": | ||||
|             agent_process_prompt(st.session_state.prompt) | ||||
|         else:  # rag_mode == "Direct" | ||||
|             direct_process_prompt(st.session_state.prompt) | ||||
|         st.session_state.prompt = None | ||||
| 
 | ||||
| 
 | ||||
| rag_chat_page() | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue