From 82778ecbb015f11ff34300ad97871cd3dfff464a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Mon, 19 May 2025 22:02:23 +0200 Subject: [PATCH 01/61] fix: remove wrong deprecated warning (#2202) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? `--yaml-config` is gone now with https://github.com/meta-llama/llama-stack/pull/2196. Signed-off-by: Sébastien Han --- llama_stack/distribution/server/server.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 15a8058ae..8cc028769 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -370,14 +370,6 @@ def main(args: argparse.Namespace | None = None): if args is None: args = parser.parse_args() - # Check for deprecated argument usage - if "--config" in sys.argv: - warnings.warn( - "The '--config' argument is deprecated and will be removed in a future version. Use '--config' instead.", - DeprecationWarning, - stacklevel=2, - ) - log_line = "" if args.config: # if the user provided a config file, use it, even if template was specified From 6d20b720b872d120ef4ade58b66b56bafb2f7302 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Mon, 19 May 2025 21:56:54 -0400 Subject: [PATCH 02/61] feat: Propagate W3C trace context headers from clients (#2153) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This extracts the W3C trace context headers (traceparent and tracestate) from incoming requests, stuffs them as attributes on the spans we create, and uses them within the tracing provider implementation to actually wrap our spans in the proper context. What this means in practice is that when a client (such as an OpenAI client) is instrumented to create these traces, we'll continue that distributed trace within Llama Stack as opposed to creating our own root span that breaks the distributed trace between client and server. It's slightly awkward to do this in Llama Stack because our Tracing API knows nothing about opentelemetry, W3C trace headers, etc - that's only knowledge the specific provider implementation has. So, that's why the trace headers get extracted by in the server code but not actually used until the provider implementation to form the proper context. This also centralizes how we were adding the `__root__` and `__root_span__` attributes, as those two were being added in different parts of the code instead of from a single place. Closes #2097 ## Test Plan This was tested manually using the helpful scripts from #2097. I verified that Llama Stack properly joined the client's span when the client was instrumented for distributed tracing, and that Llama Stack properly started its own root span when the incoming request was not part of an existing trace. 
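For reference, a minimal client-side sketch of how such headers could be produced (the endpoint URL, span name, and API key below are illustrative placeholders and are not taken from the #2097 scripts):

```python
from openai import OpenAI
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

# Configure a tracer provider so spans are recording; exporter setup omitted for brevity.
trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)

with tracer.start_as_current_span("client-workflow"):
    carrier: dict[str, str] = {}
    # Writes the W3C "traceparent" (and "tracestate", if present) for the active span.
    TraceContextTextMapPropagator().inject(carrier)
    # Any HTTP client works; the OpenAI client is just one example. Placeholder URL and key.
    client = OpenAI(
        base_url="http://localhost:8321/v1/openai/v1",
        api_key="dummy",
        default_headers=carrier,
    )
```

Requests made with such a client carry the trace context that the `TracingMiddleware` change below extracts and hands to the telemetry provider.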
Here's an example of the joined spans: ![Screenshot 2025-05-13 at 8 46 09 AM](https://github.com/user-attachments/assets/dbefda28-9faa-4339-a08d-1441efefc149) Signed-off-by: Ben Browning --- llama_stack/distribution/server/server.py | 13 ++++++++++++- .../telemetry/meta_reference/telemetry.py | 19 +++++++++++++++++-- .../providers/utils/telemetry/tracing.py | 5 ++++- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 8cc028769..e25bf0817 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -280,7 +280,18 @@ class TracingMiddleware: logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI") return await self.app(scope, receive, send) - trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path}) + trace_attributes = {"__location__": "server", "raw_path": path} + + # Extract W3C trace context headers and store as trace attributes + headers = dict(scope.get("headers", [])) + traceparent = headers.get(b"traceparent", b"").decode() + if traceparent: + trace_attributes["traceparent"] = traceparent + tracestate = headers.get(b"tracestate", b"").decode() + if tracestate: + trace_attributes["tracestate"] = tracestate + + trace_context = await start_trace(trace_path, trace_attributes) async def send_with_trace_id(message): if message["type"] == "http.response.start": diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 67362dd36..1bc979894 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from llama_stack.apis.telemetry import ( Event, @@ -44,6 +45,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor ) from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore +from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS from .config import TelemetryConfig, TelemetrySink @@ -206,6 +208,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): event.attributes = {} event.attributes["__ttl__"] = ttl_seconds + # Extract these W3C trace context attributes so they are not written to + # underlying storage, as we just need them to propagate the trace context. + traceparent = event.attributes.pop("traceparent", None) + tracestate = event.attributes.pop("tracestate", None) + if traceparent: + # If we have a traceparent header value, we're not the root span. 
+ for root_attribute in ROOT_SPAN_MARKERS: + event.attributes.pop(root_attribute, None) + if isinstance(event.payload, SpanStartPayload): # Check if span already exists to prevent duplicates if span_id in _GLOBAL_STORAGE["active_spans"]: @@ -216,8 +227,12 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): parent_span_id = int(event.payload.parent_span_id, 16) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) context = trace.set_span_in_context(parent_span) - else: - event.attributes["__root_span__"] = "true" + elif traceparent: + carrier = { + "traceparent": traceparent, + "tracestate": tracestate, + } + context = TraceContextTextMapPropagator().extract(carrier=carrier) span = tracer.start_span( name=event.payload.name, diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 0f4fdd0d8..4edfa6516 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -34,6 +34,8 @@ logger = get_logger(__name__, category="core") INVALID_SPAN_ID = 0x0000000000000000 INVALID_TRACE_ID = 0x00000000000000000000000000000000 +ROOT_SPAN_MARKERS = ["__root__", "__root_span__"] + def trace_id_to_str(trace_id: int) -> str: """Convenience trace ID formatting method @@ -178,7 +180,8 @@ async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceCont trace_id = generate_trace_id() context = TraceContext(BACKGROUND_LOGGER, trace_id) - context.push_span(name, {"__root__": True, **(attributes or {})}) + attributes = {marker: True for marker in ROOT_SPAN_MARKERS} | (attributes or {}) + context.push_span(name, attributes) CURRENT_TRACE_CONTEXT.set(context) return context From ed7b4731aab8ecc544ae39b0b822223f204629be Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Tue, 20 May 2025 06:03:22 -0600 Subject: [PATCH 03/61] fix: Setting default value for `metadata_token_count` in case the key is not found (#2199) # What does this PR do? If a user has previously serialized data into their vector store without the `metadata_token_count` in the chunk, the `query` method will fail in a server error. This fixes that edge case by returning 0 when the key is not detected. This solution is suboptimal but I think it's better to understate the token size rather than recalculate it and add unnecessary complexity to the retrieval code. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. 
*Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) Signed-off-by: Francisco Javier Arceo --- llama_stack/providers/inline/tool_runtime/rag/memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 39f752297..c46960f75 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -146,7 +146,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): for i, chunk in enumerate(chunks): metadata = chunk.metadata tokens += metadata["token_count"] - tokens += metadata["metadata_token_count"] + tokens += metadata.get("metadata_token_count", 0) if tokens > query_config.max_tokens_in_context: log.error( From 90d7612f5f8e6ab1206e1e681cce2ad29a695cc2 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Tue, 20 May 2025 09:06:20 -0600 Subject: [PATCH 04/61] chore: Updated readme (#2219) # What does this PR do? chore: Updated readme [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) Signed-off-by: Francisco Javier Arceo --- llama_stack/ui/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/ui/README.md b/llama_stack/ui/README.md index e3e21bf0b..665619bf1 100644 --- a/llama_stack/ui/README.md +++ b/llama_stack/ui/README.md @@ -8,7 +8,7 @@ We use shadcdn/ui [Shadcn UI](https://ui.shadcn.com/) for the UI components. First, install dependencies: ```bash -npm install next react react-dom +npm install ``` Then, run the development server: From 3f6368d56cb7620653369e5930d72980fff829bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 20 May 2025 18:04:03 +0200 Subject: [PATCH 05/61] ci: enable ruff output format for github (#2214) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Update output format to enable automatic inline annotations. 
![Screenshot 2025-05-20 at 10 55 38](https://github.com/user-attachments/assets/f943aa00-9b60-4cdb-b434-67b2de8b79f2) Signed-off-by: Sébastien Han --- .github/workflows/pre-commit.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 4df04fbb0..2bbd52c53 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -29,6 +29,7 @@ jobs: - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 env: SKIP: no-commit-to-branch + RUFF_OUTPUT_FORMAT: github - name: Verify if there are any diff files after pre-commit run: | From 2eae8568e1f5b5f5524ab1eb6d14f7b681d71ab7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 20 May 2025 18:51:09 +0200 Subject: [PATCH 06/61] chore: collapse all local hook under the same repo (#2217) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sébastien Han --- .pre-commit-config.yaml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e78fcd158..c88ca1ca9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -94,9 +94,6 @@ repos: pass_filenames: false require_serial: true files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$ - -- repo: local - hooks: - id: openapi-codegen name: API Spec Codegen additional_dependencies: From 1a770cf8aca31ebf3e28ae63092fca3b54cbf4cd Mon Sep 17 00:00:00 2001 From: Jash Gulabrai <37194352+JashG@users.noreply.github.com> Date: Tue, 20 May 2025 12:51:39 -0400 Subject: [PATCH 07/61] fix: Pass model parameter as config name to NeMo Customizer (#2218) # What does this PR do? When launching a fine-tuning job, an upcoming version of NeMo Customizer will expect the `config` name to be formatted as `namespace/name@version`. Here, `config` is a reference to a model + additional metadata. There could be multiple `config`s that reference the same base model. This PR updates NVIDIA's `supervised_fine_tune` to simply pass the `model` param as-is to NeMo Customizer. Currently, it expects a specific, allowlisted llama model (i.e. `meta/Llama3.1-8B-Instruct`) and converts it to the provider format (`meta/llama-3.1-8b-instruct`). [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan From a notebook, I built an image with my changes: ``` !llama stack build --template nvidia --image-type venv from llama_stack.distribution.library_client import LlamaStackAsLibraryClient client = LlamaStackAsLibraryClient("nvidia") client.initialize() ``` And could successfully launch a job: ``` response = client.post_training.supervised_fine_tune( job_uuid="", model="meta/llama-3.2-1b-instruct@v1.0.0+A100", # Model passed as-is to Customimzer ... 
) job_id = response.job_uuid print(f"Created job with ID: {job_id}") Output: Created job with ID: cust-Jm4oGmbwcvoufaLU4XkrRU ``` [//]: # (## Documentation) --------- Co-authored-by: Jash Gulabrai --- .../providers/remote/post_training/nvidia/post_training.py | 7 ++----- tests/unit/providers/nvidia/test_parameters.py | 4 ++-- tests/unit/providers/nvidia/test_supervised_fine_tuning.py | 6 +++--- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py index 409818cb3..d839ffd6f 100644 --- a/llama_stack/providers/remote/post_training/nvidia/post_training.py +++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py @@ -224,7 +224,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): Parameters: training_config: TrainingConfig - Configuration for training - model: str - Model identifier + model: str - NeMo Customizer configuration name algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm job_uuid: str - Unique identifier for the job, ignored atm @@ -299,9 +299,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): User is informed about unsupported parameters via warnings. """ - # Map model to nvidia model name - # See `_MODEL_ENTRIES` for supported models - nvidia_model = self.get_provider_model_id(model) # Check for unsupported method parameters unsupported_method_params = [] @@ -347,7 +344,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): # Prepare base job configuration job_config = { - "config": nvidia_model, + "config": model, "dataset": { "name": training_config["data_config"]["dataset_id"], "namespace": self.config.dataset_namespace, diff --git a/tests/unit/providers/nvidia/test_parameters.py b/tests/unit/providers/nvidia/test_parameters.py index ea12122a0..cc33f7609 100644 --- a/tests/unit/providers/nvidia/test_parameters.py +++ b/tests/unit/providers/nvidia/test_parameters.py @@ -131,7 +131,7 @@ class TestNvidiaParameters(unittest.TestCase): def test_required_parameters_passed(self): """Test scenario 2: When required parameters are passed.""" - required_model = "meta-llama/Llama-3.1-8B-Instruct" + required_model = "meta/llama-3.2-1b-instruct@v1.0.0+L40" required_dataset_id = "required-dataset" required_job_uuid = "required-job" @@ -190,7 +190,7 @@ class TestNvidiaParameters(unittest.TestCase): self.mock_make_request.assert_called_once() call_args = self.mock_make_request.call_args - assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct" + assert call_args[1]["json"]["config"] == required_model assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id def test_unsupported_parameters_warning(self): diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py index 319011be3..97ca02fba 100644 --- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py +++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py @@ -165,7 +165,7 @@ class TestNvidiaPostTraining(unittest.TestCase): training_job = self.run_async( self.adapter.supervised_fine_tune( job_uuid="1234", - model="meta-llama/Llama-3.1-8B-Instruct", + model="meta/llama-3.2-1b-instruct@v1.0.0+L40", checkpoint_dir="", algorithm_config=algorithm_config, training_config=convert_pydantic_to_json_value(training_config), @@ -184,7 +184,7 @@ class 
TestNvidiaPostTraining(unittest.TestCase): "POST", "/v1/customization/jobs", expected_json={ - "config": "meta/llama-3.1-8b-instruct", + "config": "meta/llama-3.2-1b-instruct@v1.0.0+L40", "dataset": {"name": "sample-basic-test", "namespace": "default"}, "hyperparameters": { "training_type": "sft", @@ -219,7 +219,7 @@ class TestNvidiaPostTraining(unittest.TestCase): self.run_async( self.adapter.supervised_fine_tune( job_uuid="1234", - model="meta-llama/Llama-3.1-8B-Instruct", + model="meta/llama-3.2-1b-instruct@v1.0.0+L40", checkpoint_dir="", algorithm_config=algorithm_config, training_config=convert_pydantic_to_json_value(training_config), From 3339844fda052ee4ad0449de924d28f39f641ff7 Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Tue, 20 May 2025 17:52:10 +0100 Subject: [PATCH 08/61] feat: Add "instructions" support to responses API (#2205) # What does this PR do? Add support for "instructions" to the responses API. Instructions provide a way to swap out system (or developer) messages in new responses. ## Test Plan unit tests added Signed-off-by: Derek Higgins --- docs/_static/llama-stack-spec.html | 3 + docs/_static/llama-stack-spec.yaml | 2 + llama_stack/apis/agents/agents.py | 1 + .../inline/agents/meta_reference/agents.py | 3 +- .../agents/meta_reference/openai_responses.py | 7 + .../meta_reference/test_openai_responses.py | 138 ++++++++++++++++++ 6 files changed, 153 insertions(+), 1 deletion(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 6378a5ced..6adfe9b2b 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -7027,6 +7027,9 @@ "type": "string", "description": "The underlying LLM used for completions." }, + "instructions": { + "type": "string" + }, "previous_response_id": { "type": "string", "description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses." diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 012610d02..31ca3f52a 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -4952,6 +4952,8 @@ components: model: type: string description: The underlying LLM used for completions. 
+ instructions: + type: string previous_response_id: type: string description: >- diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index b2f85336c..8ecafdf26 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -596,6 +596,7 @@ class Agents(Protocol): self, input: str | list[OpenAIResponseInput], model: str, + instructions: str | None = None, previous_response_id: str | None = None, store: bool | None = True, stream: bool | None = False, diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 86780fd61..8f54cc737 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -313,6 +313,7 @@ class MetaReferenceAgentsImpl(Agents): self, input: str | list[OpenAIResponseInput], model: str, + instructions: str | None = None, previous_response_id: str | None = None, store: bool | None = True, stream: bool | None = False, @@ -320,5 +321,5 @@ class MetaReferenceAgentsImpl(Agents): tools: list[OpenAIResponseInputTool] | None = None, ) -> OpenAIResponseObject: return await self.openai_responses_impl.create_openai_response( - input, model, previous_response_id, store, stream, temperature, tools + input, model, instructions, previous_response_id, store, stream, temperature, tools ) diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 6d9d06109..f5b0d8c31 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -208,6 +208,10 @@ class OpenAIResponsesImpl: return input + async def _prepend_instructions(self, messages, instructions): + if instructions: + messages.insert(0, OpenAISystemMessageParam(content=instructions)) + async def get_openai_response( self, id: str, @@ -219,6 +223,7 @@ class OpenAIResponsesImpl: self, input: str | list[OpenAIResponseInput], model: str, + instructions: str | None = None, previous_response_id: str | None = None, store: bool | None = True, stream: bool | None = False, @@ -229,7 +234,9 @@ class OpenAIResponsesImpl: input = await self._prepend_previous_response(input, previous_response_id) messages = await _convert_response_input_to_chat_messages(input) + await self._prepend_instructions(messages, instructions) chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None + chat_response = await self.inference_api.openai_chat_completion( model=model, messages=messages, diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index ed5f13a58..0a8d59306 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -384,3 +384,141 @@ async def test_prepend_previous_response_web_search(get_previous_response_with_i # Check for new input assert isinstance(input[3], OpenAIResponseMessage) assert input[3].content == "fake_input" + + +@pytest.mark.asyncio +async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api): + # Setup + input_text = "What is the capital of Ireland?" + model = "meta-llama/Llama-3.1-8B-Instruct" + instructions = "You are a geography expert. 
Provide concise answers." + + # Load the chat completion fixture + mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml") + mock_inference_api.openai_chat_completion.return_value = mock_chat_completion + + # Execute + await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + instructions=instructions, + ) + + # Verify + mock_inference_api.openai_chat_completion.assert_called_once() + call_args = mock_inference_api.openai_chat_completion.call_args + sent_messages = call_args.kwargs["messages"] + + # Check that instructions were prepended as a system message + assert len(sent_messages) == 2 + assert sent_messages[0].role == "system" + assert sent_messages[0].content == instructions + assert sent_messages[1].role == "user" + assert sent_messages[1].content == input_text + + +@pytest.mark.asyncio +async def test_create_openai_response_with_instructions_and_multiple_messages( + openai_responses_impl, mock_inference_api +): + # Setup + input_messages = [ + OpenAIResponseMessage(role="user", content="Name some towns in Ireland", name=None), + OpenAIResponseMessage( + role="assistant", + content="Galway, Longford, Sligo", + name=None, + ), + OpenAIResponseMessage(role="user", content="Which is the largest?", name=None), + ] + model = "meta-llama/Llama-3.1-8B-Instruct" + instructions = "You are a geography expert. Provide concise answers." + + mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml") + mock_inference_api.openai_chat_completion.return_value = mock_chat_completion + + # Execute + await openai_responses_impl.create_openai_response( + input=input_messages, + model=model, + instructions=instructions, + ) + + # Verify + mock_inference_api.openai_chat_completion.assert_called_once() + call_args = mock_inference_api.openai_chat_completion.call_args + sent_messages = call_args.kwargs["messages"] + + # Check that instructions were prepended as a system message + assert len(sent_messages) == 4 # 1 system + 3 input messages + assert sent_messages[0].role == "system" + assert sent_messages[0].content == instructions + + # Check the rest of the messages were converted correctly + assert sent_messages[1].role == "user" + assert sent_messages[1].content == "Name some towns in Ireland" + assert sent_messages[2].role == "assistant" + assert sent_messages[2].content == "Galway, Longford, Sligo" + assert sent_messages[3].role == "user" + assert sent_messages[3].content == "Which is the largest?" 
+ + +@pytest.mark.asyncio +@patch.object(OpenAIResponsesImpl, "_get_previous_response_with_input") +async def test_create_openai_response_with_instructions_and_previous_response( + get_previous_response_with_input, openai_responses_impl, mock_inference_api +): + """Test prepending both instructions and previous response.""" + + input_item_message = OpenAIResponseMessage( + id="123", + content="Name some towns in Ireland", + role="user", + ) + input_items = OpenAIResponseInputItemList(data=[input_item_message]) + response_output_message = OpenAIResponseMessage( + id="123", + content="Galway, Longford, Sligo", + status="completed", + role="assistant", + ) + response = OpenAIResponseObject( + created_at=1, + id="resp_123", + model="fake_model", + output=[response_output_message], + status="completed", + ) + previous_response = OpenAIResponsePreviousResponseWithInputItems( + input_items=input_items, + response=response, + ) + get_previous_response_with_input.return_value = previous_response + + model = "meta-llama/Llama-3.1-8B-Instruct" + instructions = "You are a geography expert. Provide concise answers." + mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml") + mock_inference_api.openai_chat_completion.return_value = mock_chat_completion + + # Execute + await openai_responses_impl.create_openai_response( + input="Which is the largest?", model=model, instructions=instructions, previous_response_id="123" + ) + + # Verify + mock_inference_api.openai_chat_completion.assert_called_once() + call_args = mock_inference_api.openai_chat_completion.call_args + sent_messages = call_args.kwargs["messages"] + + # Check that instructions were prepended as a system message + assert len(sent_messages) == 4 + assert sent_messages[0].role == "system" + assert sent_messages[0].content == instructions + + # Check the rest of the messages were converted correctly + assert sent_messages[1].role == "user" + assert sent_messages[1].content == "Name some towns in Ireland" + assert sent_messages[2].role == "assistant" + assert sent_messages[2].content == "Galway, Longford, Sligo" + assert sent_messages[3].role == "user" + assert sent_messages[3].content == "Which is the largest?" From 87a4b9cb28f8e9f94c40e79a2a8a8738e24aebe1 Mon Sep 17 00:00:00 2001 From: grs Date: Tue, 20 May 2025 13:00:44 -0400 Subject: [PATCH 09/61] fix: synchronize concurrent coroutines checking & updating key set (#2215) # What does this PR do? This PR adds a lock to coordinate concurrent coroutines passing through the jwt verification. As _refresh_jwks() was setting _jwks to an empty dict then repopulating it, having multiple coroutines doing this concurrently risks losing keys. The PR also builds the updated dict as a separate object and assigns it to _jwks once completed. This avoids impacting any coroutines using the key set as it is being updated. 
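As a standalone illustration of the pattern (the class and function names here are hypothetical, not part of the Llama Stack code):

```python
import asyncio


class KeyCache:
    """Minimal sketch: guard the refresh with a lock and publish the rebuilt key set
    with a single assignment, so concurrent readers never observe a half-populated dict."""

    def __init__(self) -> None:
        self._keys: dict[str, dict] = {}
        self._lock = asyncio.Lock()

    async def refresh(self, fetch_keys) -> None:
        async with self._lock:  # only one coroutine rebuilds the cache at a time
            updated = {key["kid"]: key for key in await fetch_keys()}
            self._keys = updated  # single assignment; readers keep the old dict until the swap
```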
Signed-off-by: Gordon Sim --- .../distribution/server/auth_providers.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 4065a65f3..b73fded58 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -7,6 +7,7 @@ import json import time from abc import ABC, abstractmethod +from asyncio import Lock from enum import Enum from urllib.parse import parse_qs @@ -236,6 +237,7 @@ class OAuth2TokenAuthProvider(AuthProvider): self.config = config self._jwks_at: float = 0.0 self._jwks: dict[str, str] = {} + self._jwks_lock = Lock() async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: """Validate a token using the JWT token.""" @@ -271,17 +273,19 @@ class OAuth2TokenAuthProvider(AuthProvider): """Close the HTTP client.""" async def _refresh_jwks(self) -> None: - if time.time() - self._jwks_at > self.config.cache_ttl: - async with httpx.AsyncClient() as client: - res = await client.get(self.config.jwks_uri, timeout=5) - res.raise_for_status() - jwks_data = res.json()["keys"] - self._jwks = {} - for k in jwks_data: - kid = k["kid"] - # Store the entire key object as it may be needed for different algorithms - self._jwks[kid] = k - self._jwks_at = time.time() + async with self._jwks_lock: + if time.time() - self._jwks_at > self.config.cache_ttl: + async with httpx.AsyncClient() as client: + res = await client.get(self.config.jwks_uri, timeout=5) + res.raise_for_status() + jwks_data = res.json()["keys"] + updated = {} + for k in jwks_data: + kid = k["kid"] + # Store the entire key object as it may be needed for different algorithms + updated[kid] = k + self._jwks = updated + self._jwks_at = time.time() class CustomAuthProviderConfig(BaseModel): From 091d8c48f217b413fa267a3c0412c2967be601cd Mon Sep 17 00:00:00 2001 From: grs Date: Tue, 20 May 2025 22:45:11 -0400 Subject: [PATCH 10/61] feat: add additional auth provider that uses oauth token introspection (#2187) # What does this PR do? This adds an alternative option to the oauth_token auth provider that can be used with existing authorization services which support token introspection as defined in RFC 7662. This could be useful where token revocation needs to be handled or where opaque tokens (or other non jwt formatted tokens) are used ## Test Plan Tested against keycloak Signed-off-by: Gordon Sim --- llama_stack/distribution/datatypes.py | 2 +- .../distribution/server/auth_providers.py | 100 +++++++++-- tests/unit/server/test_auth.py | 162 +++++++++++++++++- 3 files changed, 251 insertions(+), 13 deletions(-) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 446a88ca0..be5629ba1 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -229,7 +229,7 @@ class AuthenticationConfig(BaseModel): ..., description="Type of authentication provider (e.g., 'kubernetes', 'custom')", ) - config: dict[str, str] = Field( + config: dict[str, Any] = Field( ..., description="Provider-specific configuration", ) diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index b73fded58..baab75eca 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -5,15 +5,18 @@ # the root directory of this source tree. 
import json +import ssl import time from abc import ABC, abstractmethod from asyncio import Lock from enum import Enum +from typing import Any from urllib.parse import parse_qs import httpx from jose import jwt -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator +from typing_extensions import Self from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.log import get_logger @@ -85,7 +88,7 @@ class AuthProviderConfig(BaseModel): """Base configuration for authentication providers.""" provider_type: AuthProviderType = Field(..., description="Type of authentication provider") - config: dict[str, str] = Field(..., description="Provider-specific configuration") + config: dict[str, Any] = Field(..., description="Provider-specific configuration") class AuthProvider(ABC): @@ -198,10 +201,21 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) return attributes -class OAuth2TokenAuthProviderConfig(BaseModel): +class OAuth2JWKSConfig(BaseModel): # The JWKS URI for collecting public keys - jwks_uri: str + uri: str cache_ttl: int = 3600 + + +class OAuth2IntrospectionConfig(BaseModel): + url: str + client_id: str + client_secret: str + send_secret_in_body: bool = False + tls_cafile: str | None = None + + +class OAuth2TokenAuthProviderConfig(BaseModel): audience: str = "llama-stack" claims_mapping: dict[str, str] = Field( default_factory=lambda: { @@ -214,6 +228,8 @@ class OAuth2TokenAuthProviderConfig(BaseModel): "namespace": "namespaces", }, ) + jwks: OAuth2JWKSConfig | None + introspection: OAuth2IntrospectionConfig | None = None @classmethod @field_validator("claims_mapping") @@ -225,6 +241,14 @@ class OAuth2TokenAuthProviderConfig(BaseModel): raise ValueError(f"claims_mapping value is not a valid attribute: {value}") return v + @model_validator(mode="after") + def validate_mode(self) -> Self: + if not self.jwks and not self.introspection: + raise ValueError("One of jwks or introspection must be configured") + if self.jwks and self.introspection: + raise ValueError("At present only one of jwks or introspection should be configured") + return self + class OAuth2TokenAuthProvider(AuthProvider): """ @@ -240,8 +264,17 @@ class OAuth2TokenAuthProvider(AuthProvider): self._jwks_lock = Lock() async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + if self.config.jwks: + return await self.validate_jwt_token(token, self.config.jwks, scope) + if self.config.introspection: + return await self.introspect_token(token, self.config.introspection, scope) + raise ValueError("One of jwks or introspection must be configured") + + async def validate_jwt_token( + self, token: str, config: OAuth2JWKSConfig, scope: dict | None = None + ) -> TokenValidationResult: """Validate a token using the JWT token.""" - await self._refresh_jwks() + await self._refresh_jwks(config) try: header = jwt.get_unverified_header(token) @@ -269,14 +302,61 @@ class OAuth2TokenAuthProvider(AuthProvider): access_attributes=access_attributes, ) - async def close(self): - """Close the HTTP client.""" + async def introspect_token( + self, token: str, config: OAuth2IntrospectionConfig, scope: dict | None = None + ) -> TokenValidationResult: + """Validate a token using token introspection as defined by RFC 7662.""" + form = { + "token": token, + } + if config.send_secret_in_body: + form["client_id"] = config.client_id + form["client_secret"] = config.client_secret + auth = None + 
else: + auth = (config.client_id, config.client_secret) + ssl_ctxt = None + if config.tls_cafile: + ssl_ctxt = ssl.create_default_context(cafile=config.tls_cafile) + try: + async with httpx.AsyncClient(verify=ssl_ctxt) as client: + response = await client.post( + config.url, + data=form, + auth=auth, + timeout=10.0, # Add a reasonable timeout + ) + if response.status_code != 200: + logger.warning(f"Token introspection failed with status code: {response.status_code}") + raise ValueError(f"Token introspection failed: {response.status_code}") - async def _refresh_jwks(self) -> None: + fields = response.json() + if not fields["active"]: + raise ValueError("Token not active") + principal = fields["sub"] or fields["username"] + access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping) + return TokenValidationResult( + principal=principal, + access_attributes=access_attributes, + ) + except httpx.TimeoutException: + logger.exception("Token introspection request timed out") + raise + except ValueError: + # Re-raise ValueError exceptions to preserve their message + raise + except Exception as e: + logger.exception("Error during token introspection") + raise ValueError("Token introspection error") from e + + async def close(self): + pass + + async def _refresh_jwks(self, config: OAuth2JWKSConfig) -> None: async with self._jwks_lock: - if time.time() - self._jwks_at > self.config.cache_ttl: + if time.time() - self._jwks_at > config.cache_ttl: async with httpx.AsyncClient() as client: - res = await client.get(self.config.jwks_uri, timeout=5) + res = await client.get(config.uri, timeout=5) res.raise_for_status() jwks_data = res.json()["keys"] updated = {} diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index f15ca9de4..56458c0e7 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -396,8 +396,10 @@ def oauth2_app(): auth_config = AuthProviderConfig( provider_type=AuthProviderType.OAUTH2_TOKEN, config={ - "jwks_uri": "http://mock-authz-service/token/introspect", - "cache_ttl": "3600", + "jwks": { + "uri": "http://mock-authz-service/token/introspect", + "cache_ttl": "3600", + }, "audience": "llama-stack", }, ) @@ -517,3 +519,159 @@ def test_get_attributes_from_claims(): # TODO: add more tests for oauth2 token provider + + +# oauth token introspection tests +@pytest.fixture +def mock_introspection_endpoint(): + return "http://mock-authz-service/token/introspect" + + +@pytest.fixture +def introspection_app(mock_introspection_endpoint): + app = FastAPI() + auth_config = AuthProviderConfig( + provider_type=AuthProviderType.OAUTH2_TOKEN, + config={ + "jwks": None, + "introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"}, + }, + ) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + + @app.get("/test") + def test_endpoint(): + return {"message": "Authentication successful"} + + return app + + +@pytest.fixture +def introspection_app_with_custom_mapping(mock_introspection_endpoint): + app = FastAPI() + auth_config = AuthProviderConfig( + provider_type=AuthProviderType.OAUTH2_TOKEN, + config={ + "jwks": None, + "introspection": { + "url": mock_introspection_endpoint, + "client_id": "myclient", + "client_secret": "abcdefg", + "send_secret_in_body": "true", + }, + "claims_mapping": { + "sub": "roles", + "scope": "roles", + "groups": "teams", + "aud": "namespaces", + }, + }, + ) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + + 
@app.get("/test") + def test_endpoint(): + return {"message": "Authentication successful"} + + return app + + +@pytest.fixture +def introspection_client(introspection_app): + return TestClient(introspection_app) + + +@pytest.fixture +def introspection_client_with_custom_mapping(introspection_app_with_custom_mapping): + return TestClient(introspection_app_with_custom_mapping) + + +def test_missing_auth_header_introspection(introspection_client): + response = introspection_client.get("/test") + assert response.status_code == 401 + assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + + +def test_invalid_auth_header_format_introspection(introspection_client): + response = introspection_client.get("/test", headers={"Authorization": "InvalidFormat token123"}) + assert response.status_code == 401 + assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + + +async def mock_introspection_active(*args, **kwargs): + return MockResponse( + 200, + { + "active": True, + "sub": "my-user", + "groups": ["group1", "group2"], + "scope": "foo bar", + "aud": ["set1", "set2"], + }, + ) + + +async def mock_introspection_inactive(*args, **kwargs): + return MockResponse( + 200, + { + "active": False, + }, + ) + + +async def mock_introspection_invalid(*args, **kwargs): + class InvalidResponse: + def __init__(self, status_code): + self.status_code = status_code + + def json(self): + raise ValueError("Not JSON") + + return InvalidResponse(200) + + +async def mock_introspection_failed(*args, **kwargs): + return MockResponse( + 500, + {}, + ) + + +@patch("httpx.AsyncClient.post", new=mock_introspection_active) +def test_valid_introspection_authentication(introspection_client, valid_api_key): + response = introspection_client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"}) + assert response.status_code == 200 + assert response.json() == {"message": "Authentication successful"} + + +@patch("httpx.AsyncClient.post", new=mock_introspection_inactive) +def test_inactive_introspection_authentication(introspection_client, invalid_api_key): + response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"}) + assert response.status_code == 401 + assert "Token not active" in response.json()["error"]["message"] + + +@patch("httpx.AsyncClient.post", new=mock_introspection_invalid) +def test_invalid_introspection_authentication(introspection_client, invalid_api_key): + response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"}) + assert response.status_code == 401 + assert "Not JSON" in response.json()["error"]["message"] + + +@patch("httpx.AsyncClient.post", new=mock_introspection_failed) +def test_failed_introspection_authentication(introspection_client, invalid_api_key): + response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"}) + assert response.status_code == 401 + assert "Token introspection failed: 500" in response.json()["error"]["message"] + + +@patch("httpx.AsyncClient.post", new=mock_introspection_active) +def test_valid_introspection_with_custom_mapping_authentication( + introspection_client_with_custom_mapping, valid_api_key +): + response = introspection_client_with_custom_mapping.get( + "/test", headers={"Authorization": f"Bearer {valid_api_key}"} + ) + assert response.status_code == 200 + assert response.json() == {"message": "Authentication successful"} From 5a3d777b20ea19870cc4ffec70af31055f1aacbc Mon Sep 17 
00:00:00 2001 From: Abhishek koserwal Date: Wed, 21 May 2025 13:55:51 +0530 Subject: [PATCH 11/61] feat: add llama stack rm command (#2127) # What does this PR do? [Provide a short summary of what this PR does and why. Link to relevant issues if applicable.] ``` llama stack rm llamastack-test ``` [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) #225 ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) --- docs/source/distributions/building_distro.md | 42 +++++++ llama_stack/cli/stack/list_stacks.py | 56 +++++++++ llama_stack/cli/stack/remove.py | 116 +++++++++++++++++++ llama_stack/cli/stack/stack.py | 5 +- 4 files changed, 218 insertions(+), 1 deletion(-) create mode 100644 llama_stack/cli/stack/list_stacks.py create mode 100644 llama_stack/cli/stack/remove.py diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md index d9b73c910..0dbabf8aa 100644 --- a/docs/source/distributions/building_distro.md +++ b/docs/source/distributions/building_distro.md @@ -338,6 +338,48 @@ INFO: Application startup complete. INFO: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit) INFO: 2401:db00:35c:2d2b:face:0:c9:0:54678 - "GET /models/list HTTP/1.1" 200 OK ``` +### Listing Distributions +Using the list command, you can view all existing Llama Stack distributions, including stacks built from templates, from scratch, or using custom configuration files. + +``` +llama stack list -h +usage: llama stack list [-h] + +list the build stacks + +options: + -h, --help show this help message and exit +``` + +Example Usage + +``` +llama stack list +``` + +### Removing a Distribution +Use the remove command to delete a distribution you've previously built. + +``` +llama stack rm -h +usage: llama stack rm [-h] [--all] [name] + +Remove the build stack + +positional arguments: + name Name of the stack to delete (default: None) + +options: + -h, --help show this help message and exit + --all, -a Delete all stacks (use with caution) (default: False) +``` + +Example +``` +llama stack rm llamastack-test +``` + +To keep your environment organized and avoid clutter, consider using `llama stack list` to review old or unused distributions and `llama stack rm ` to delete them when they’re no longer needed. ### Troubleshooting diff --git a/llama_stack/cli/stack/list_stacks.py b/llama_stack/cli/stack/list_stacks.py new file mode 100644 index 000000000..2ea0fdeea --- /dev/null +++ b/llama_stack/cli/stack/list_stacks.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import argparse +from pathlib import Path + +from llama_stack.cli.subcommand import Subcommand +from llama_stack.cli.table import print_table + + +class StackListBuilds(Subcommand): + """List built stacks in .llama/distributions directory""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "list", + prog="llama stack list", + description="list the build stacks", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._list_stack_command) + + def _get_distribution_dirs(self) -> dict[str, Path]: + """Return a dictionary of distribution names and their paths""" + distributions = {} + dist_dir = Path.home() / ".llama" / "distributions" + + if dist_dir.exists(): + for stack_dir in dist_dir.iterdir(): + if stack_dir.is_dir(): + distributions[stack_dir.name] = stack_dir + return distributions + + def _list_stack_command(self, args: argparse.Namespace) -> None: + distributions = self._get_distribution_dirs() + + if not distributions: + print("No stacks found in ~/.llama/distributions") + return + + headers = ["Stack Name", "Path"] + headers.extend(["Build Config", "Run Config"]) + rows = [] + for name, path in distributions.items(): + row = [name, str(path)] + # Check for build and run config files + build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No" + run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No" + row.extend([build_config, run_config]) + rows.append(row) + print_table(rows, headers, separate_rows=True) diff --git a/llama_stack/cli/stack/remove.py b/llama_stack/cli/stack/remove.py new file mode 100644 index 000000000..be7c49a5d --- /dev/null +++ b/llama_stack/cli/stack/remove.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import argparse +import shutil +import sys +from pathlib import Path + +from termcolor import cprint + +from llama_stack.cli.subcommand import Subcommand +from llama_stack.cli.table import print_table + + +class StackRemove(Subcommand): + """Remove the build stack""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "rm", + prog="llama stack rm", + description="Remove the build stack", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._remove_stack_build_command) + + def _add_arguments(self) -> None: + self.parser.add_argument( + "name", + type=str, + nargs="?", + help="Name of the stack to delete", + ) + self.parser.add_argument( + "--all", + "-a", + action="store_true", + help="Delete all stacks (use with caution)", + ) + + def _get_distribution_dirs(self) -> dict[str, Path]: + """Return a dictionary of distribution names and their paths""" + distributions = {} + dist_dir = Path.home() / ".llama" / "distributions" + + if dist_dir.exists(): + for stack_dir in dist_dir.iterdir(): + if stack_dir.is_dir(): + distributions[stack_dir.name] = stack_dir + return distributions + + def _list_stacks(self) -> None: + """Display available stacks in a table""" + distributions = self._get_distribution_dirs() + if not distributions: + print("No stacks found in ~/.llama/distributions") + return + + headers = ["Stack Name", "Path"] + rows = [[name, str(path)] for name, path in distributions.items()] + print_table(rows, headers, separate_rows=True) + + def _remove_stack_build_command(self, args: argparse.Namespace) -> None: + distributions = self._get_distribution_dirs() + + if args.all: + confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower() + if confirm != "yes-i-really-want": + print("Deletion cancelled.") + return + + for name, path in distributions.items(): + try: + shutil.rmtree(path) + print(f"Deleted stack: {name}") + except Exception as e: + cprint( + f"Failed to delete stack {name}: {e}", + color="red", + ) + sys.exit(2) + + if not args.name: + self._list_stacks() + if not args.name: + return + + if args.name not in distributions: + self._list_stacks() + cprint( + f"Stack not found: {args.name}", + color="red", + ) + return + + stack_path = distributions[args.name] + + confirm = input(f"Are you sure you want to delete stack '{args.name}'? 
[y/N] ").lower() + if confirm != "y": + print("Deletion cancelled.") + return + + try: + shutil.rmtree(stack_path) + print(f"Successfully deleted stack: {args.name}") + except Exception as e: + cprint( + f"Failed to delete stack {args.name}: {e}", + color="red", + ) + sys.exit(2) diff --git a/llama_stack/cli/stack/stack.py b/llama_stack/cli/stack/stack.py index ccf1a5ffc..3aff78e23 100644 --- a/llama_stack/cli/stack/stack.py +++ b/llama_stack/cli/stack/stack.py @@ -7,12 +7,14 @@ import argparse from importlib.metadata import version +from llama_stack.cli.stack.list_stacks import StackListBuilds from llama_stack.cli.stack.utils import print_subcommand_description from llama_stack.cli.subcommand import Subcommand from .build import StackBuild from .list_apis import StackListApis from .list_providers import StackListProviders +from .remove import StackRemove from .run import StackRun @@ -41,5 +43,6 @@ class StackParser(Subcommand): StackListApis.create(subparsers) StackListProviders.create(subparsers) StackRun.create(subparsers) - + StackRemove.create(subparsers) + StackListBuilds.create(subparsers) print_subcommand_description(self.parser, subparsers) From 2890243107c74a7a88b82595db49e9540d0a0561 Mon Sep 17 00:00:00 2001 From: liangwen12year <36004580+liangwen12year@users.noreply.github.com> Date: Wed, 21 May 2025 04:58:45 -0400 Subject: [PATCH 12/61] =?UTF-8?q?feat(quota):=20add=20server=E2=80=91side?= =?UTF-8?q?=20per=E2=80=91client=20request=20quotas=20(requires=20auth)=20?= =?UTF-8?q?(#2096)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? feat(quota): add server‑side per‑client request quotas (requires auth) Unrestricted usage can lead to runaway costs and fragmented client-side workarounds. This commit introduces a native quota mechanism to the server, giving operators a unified, centrally managed throttle for per-client requests—without needing extra proxies or custom client logic. This helps contain cloud-compute expenses, enables fine-grained usage control, and simplifies deployment and monitoring of Llama Stack services. Quotas are fully opt-in and have no effect unless explicitly configured. Notice that Quotas are fully opt-in and require authentication to be enabled. The 'sqlite' is the only supported quota `type` at this time, any other `type` will be rejected. And the only supported `period` is 'day'. Highlights: - Adds `QuotaMiddleware` to enforce per-client request quotas: - Uses `Authorization: Bearer ` (from AuthenticationMiddleware) - Tracks usage via a SQLite-based KV store - Returns 429 when the quota is exceeded - Extends `ServerConfig` with a `quota` section (type + config) - Enforces strict coupling: quotas require authentication or the server will fail to start Behavior changes: - Quotas are disabled by default unless explicitly configured - SQLite defaults to `./quotas.db` if no DB path is set - The server requires authentication when quotas are enabled To enable per-client request quotas in `run.yaml`, add: ``` server: port: 8321 auth: provider_type: "custom" config: endpoint: "https://auth.example.com/validate" quota: type: sqlite config: db_path: ./quotas.db limit: max_requests: 1000 period: day [//]: # (If resolving an issue, uncomment and update the line below) Closes #2093 ## Test Plan [Describe the tests you ran to verify your changes with result summaries. 
*Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) Signed-off-by: Wen Liang Co-authored-by: Wen Liang --- docs/source/distributions/configuration.md | 74 ++++++++++++ llama_stack/distribution/datatypes.py | 19 ++- llama_stack/distribution/server/auth.py | 4 + llama_stack/distribution/server/quota.py | 110 ++++++++++++++++++ llama_stack/distribution/server/server.py | 30 +++++ tests/unit/server/test_quota.py | 127 +++++++++++++++++++++ 6 files changed, 363 insertions(+), 1 deletion(-) create mode 100644 llama_stack/distribution/server/quota.py create mode 100644 tests/unit/server/test_quota.py diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index b62227a84..7a42f503a 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -208,6 +208,80 @@ And must respond with: If no access attributes are returned, the token is used as a namespace. +### Quota Configuration + +The `quota` section allows you to enable server-side request throttling for both +authenticated and anonymous clients. This is useful for preventing abuse, enforcing +fairness across tenants, and controlling infrastructure costs without requiring +client-side rate limiting or external proxies. + +Quotas are disabled by default. When enabled, each client is tracked using either: + +* Their authenticated `client_id` (derived from the Bearer token), or +* Their IP address (fallback for anonymous requests) + +Quota state is stored in a SQLite-backed key-value store, and rate limits are applied +within a configurable time window (currently only `day` is supported). + +#### Example + +```yaml +server: + quota: + kvstore: + type: sqlite + db_path: ./quotas.db + anonymous_max_requests: 100 + authenticated_max_requests: 1000 + period: day +``` + +#### Configuration Options + +| Field | Description | +| ---------------------------- | -------------------------------------------------------------------------- | +| `kvstore` | Required. Backend storage config for tracking request counts. | +| `kvstore.type` | Must be `"sqlite"` for now. Other backends may be supported in the future. | +| `kvstore.db_path` | File path to the SQLite database. | +| `anonymous_max_requests` | Max requests per period for unauthenticated clients. | +| `authenticated_max_requests` | Max requests per period for authenticated clients. | +| `period` | Time window for quota enforcement. Only `"day"` is supported. | + +> Note: if `authenticated_max_requests` is set but no authentication provider is +configured, the server will fall back to applying `anonymous_max_requests` to all +clients. + +#### Example with Authentication Enabled + +```yaml +server: + port: 8321 + auth: + provider_type: custom + config: + endpoint: https://auth.example.com/validate + quota: + kvstore: + type: sqlite + db_path: ./quotas.db + anonymous_max_requests: 100 + authenticated_max_requests: 1000 + period: day +``` + +If a client exceeds their limit, the server responds with: + +```http +HTTP/1.1 429 Too Many Requests +Content-Type: application/json + +{ + "error": { + "message": "Quota exceeded" + } +} +``` + ## Extending to handle Safety Configuring Safety can be a little involved so it is instructive to go through an example. 
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index be5629ba1..ca3664828 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -25,7 +25,7 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.datatypes import Api, ProviderSpec -from llama_stack.providers.utils.kvstore.config import KVStoreConfig +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig LLAMA_STACK_BUILD_CONFIG_VERSION = "2" LLAMA_STACK_RUN_CONFIG_VERSION = "2" @@ -235,6 +235,19 @@ class AuthenticationConfig(BaseModel): ) +class QuotaPeriod(str, Enum): + DAY = "day" + + +class QuotaConfig(BaseModel): + kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)") + anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period") + authenticated_max_requests: int = Field( + default=1000, description="Max requests for authenticated clients per period" + ) + period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") + + class ServerConfig(BaseModel): port: int = Field( default=8321, @@ -262,6 +275,10 @@ class ServerConfig(BaseModel): default=None, description="The host the server should listen on", ) + quota: QuotaConfig | None = Field( + default=None, + description="Per client quota request configuration", + ) class StackRunConfig(BaseModel): diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index 83436c51f..67acffe3e 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -113,6 +113,10 @@ class AuthenticationMiddleware: "roles": [token], } + # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware) + # can identify the requester and enforce per-client rate limits. + scope["authenticated_client_id"] = token + # Store attributes in request scope scope["user_attributes"] = user_attributes scope["principal"] = validation_result.principal diff --git a/llama_stack/distribution/server/quota.py b/llama_stack/distribution/server/quota.py new file mode 100644 index 000000000..ddbffae64 --- /dev/null +++ b/llama_stack/distribution/server/quota.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import time +from datetime import datetime, timedelta, timezone + +from starlette.types import ASGIApp, Receive, Scope, Send + +from llama_stack.log import get_logger +from llama_stack.providers.utils.kvstore.api import KVStore +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl + +logger = get_logger(name=__name__, category="quota") + + +class QuotaMiddleware: + """ + ASGI middleware that enforces separate quotas for authenticated and anonymous clients + within a configurable time window. + + - For authenticated requests, it reads the client ID from the + `Authorization: Bearer ` header. + - For anonymous requests, it falls back to the IP address of the client. 
+ Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned + once a client exceeds its quota. + """ + + def __init__( + self, + app: ASGIApp, + kv_config: KVStoreConfig, + anonymous_max_requests: int, + authenticated_max_requests: int, + window_seconds: int = 86400, + ): + self.app = app + self.kv_config = kv_config + self.kv: KVStore | None = None + self.anonymous_max_requests = anonymous_max_requests + self.authenticated_max_requests = authenticated_max_requests + self.window_seconds = window_seconds + + if isinstance(self.kv_config, SqliteKVStoreConfig): + logger.warning( + "QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. " + f"window_seconds={self.window_seconds}" + ) + + async def _get_kv(self) -> KVStore: + if self.kv is None: + self.kv = await kvstore_impl(self.kv_config) + return self.kv + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + if scope["type"] == "http": + # pick key & limit based on auth + auth_id = scope.get("authenticated_client_id") + if auth_id: + key_id = auth_id + limit = self.authenticated_max_requests + else: + # fallback to IP + client = scope.get("client") + key_id = client[0] if client else "anonymous" + limit = self.anonymous_max_requests + + current_window = int(time.time() // self.window_seconds) + key = f"quota:{key_id}:{current_window}" + + try: + kv = await self._get_kv() + prev = await kv.get(key) or "0" + count = int(prev) + 1 + + if int(prev) == 0: + # Set with expiration datetime when it is the first request in the window. + expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds) + await kv.set(key, str(count), expiration=expiration) + else: + await kv.set(key, str(count)) + except Exception: + logger.exception("Failed to access KV store for quota") + return await self._send_error(send, 500, "Quota service error") + + if count > limit: + logger.warning( + "Quota exceeded for client %s: %d/%d", + key_id, + count, + limit, + ) + return await self._send_error(send, 429, "Quota exceeded") + + return await self.app(scope, receive, send) + + async def _send_error(self, send: Send, status: int, message: str): + await send( + { + "type": "http.response.start", + "status": status, + "headers": [[b"content-type", b"application/json"]], + } + ) + body = json.dumps({"error": {"message": message}}).encode() + await send({"type": "http.response.body", "body": body}) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index e25bf0817..52f2b71b0 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -60,6 +60,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( from .auth import AuthenticationMiddleware from .endpoints import get_all_api_endpoints +from .quota import QuotaMiddleware REPO_ROOT = Path(__file__).parent.parent.parent.parent @@ -434,6 +435,35 @@ def main(args: argparse.Namespace | None = None): if config.server.auth: logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}") app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth) + else: + if config.server.quota: + quota = config.server.quota + logger.warning( + "Configured authenticated_max_requests (%d) but no auth is enabled; " + "falling back to anonymous_max_requests (%d) for all the requests", + quota.authenticated_max_requests, + quota.anonymous_max_requests, + ) + + if config.server.quota: + logger.info("Enabling quota 
middleware for authenticated and anonymous clients") + + quota = config.server.quota + anonymous_max_requests = quota.anonymous_max_requests + # if auth is disabled, use the anonymous max requests + authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests + + kv_config = quota.kvstore + window_map = {"day": 86400} + window_seconds = window_map[quota.period.value] + + app.add_middleware( + QuotaMiddleware, + kv_config=kv_config, + anonymous_max_requests=anonymous_max_requests, + authenticated_max_requests=authenticated_max_requests, + window_seconds=window_seconds, + ) try: impls = asyncio.run(construct_stack(config)) diff --git a/tests/unit/server/test_quota.py b/tests/unit/server/test_quota.py new file mode 100644 index 000000000..763bf8e94 --- /dev/null +++ b/tests/unit/server/test_quota.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient +from starlette.middleware.base import BaseHTTPMiddleware + +from llama_stack.distribution.datatypes import QuotaConfig, QuotaPeriod +from llama_stack.distribution.server.quota import QuotaMiddleware +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig + + +class InjectClientIDMiddleware(BaseHTTPMiddleware): + """ + Middleware that injects 'authenticated_client_id' to mimic AuthenticationMiddleware. + """ + + def __init__(self, app, client_id="client1"): + super().__init__(app) + self.client_id = client_id + + async def dispatch(self, request: Request, call_next): + request.scope["authenticated_client_id"] = self.client_id + return await call_next(request) + + +def build_quota_config(db_path) -> QuotaConfig: + return QuotaConfig( + kvstore=SqliteKVStoreConfig(db_path=str(db_path)), + anonymous_max_requests=1, + authenticated_max_requests=2, + period=QuotaPeriod.DAY, + ) + + +@pytest.fixture +def auth_app(tmp_path, request): + """ + FastAPI app with InjectClientIDMiddleware and QuotaMiddleware for authenticated testing. + Each test gets its own DB file. 
+ """ + inner_app = FastAPI() + + @inner_app.get("/test") + async def test_endpoint(): + return {"message": "ok"} + + db_path = tmp_path / f"quota_{request.node.name}.db" + quota = build_quota_config(db_path) + + app = InjectClientIDMiddleware( + QuotaMiddleware( + inner_app, + kv_config=quota.kvstore, + anonymous_max_requests=quota.anonymous_max_requests, + authenticated_max_requests=quota.authenticated_max_requests, + window_seconds=86400, + ), + client_id=f"client_{request.node.name}", + ) + return app + + +def test_authenticated_quota_allows_up_to_limit(auth_app): + client = TestClient(auth_app) + assert client.get("/test").status_code == 200 + assert client.get("/test").status_code == 200 + + +def test_authenticated_quota_blocks_after_limit(auth_app): + client = TestClient(auth_app) + client.get("/test") + client.get("/test") + resp = client.get("/test") + assert resp.status_code == 429 + assert resp.json()["error"]["message"] == "Quota exceeded" + + +def test_anonymous_quota_allows_up_to_limit(tmp_path, request): + inner_app = FastAPI() + + @inner_app.get("/test") + async def test_endpoint(): + return {"message": "ok"} + + db_path = tmp_path / f"quota_anon_{request.node.name}.db" + quota = build_quota_config(db_path) + + app = QuotaMiddleware( + inner_app, + kv_config=quota.kvstore, + anonymous_max_requests=quota.anonymous_max_requests, + authenticated_max_requests=quota.authenticated_max_requests, + window_seconds=86400, + ) + + client = TestClient(app) + assert client.get("/test").status_code == 200 + + +def test_anonymous_quota_blocks_after_limit(tmp_path, request): + inner_app = FastAPI() + + @inner_app.get("/test") + async def test_endpoint(): + return {"message": "ok"} + + db_path = tmp_path / f"quota_anon_{request.node.name}.db" + quota = build_quota_config(db_path) + + app = QuotaMiddleware( + inner_app, + kv_config=quota.kvstore, + anonymous_max_requests=quota.anonymous_max_requests, + authenticated_max_requests=quota.authenticated_max_requests, + window_seconds=86400, + ) + + client = TestClient(app) + client.get("/test") + resp = client.get("/test") + assert resp.status_code == 429 + assert resp.json()["error"]["message"] == "Quota exceeded" From c25acedbcd910c9643269f655b058906ac53a0b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Wed, 21 May 2025 16:23:54 +0200 Subject: [PATCH 13/61] chore: remove k8s auth in favor of k8s jwks endpoint (#2216) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Kubernetes since 1.20 exposes a JWKS endpoint that we can use with our recent oauth2 recent implementation. The CI test has been kept intact for validation. 
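For context on how the JWKS path works end to end: validation boils down to fetching the cluster's published keys and verifying the JWT against them, which is what the `oauth2_token` provider does below (plus caching, TLS handling, and claim mapping). A minimal sketch, with placeholder URLs and audience standing in for whatever the cluster's `/.well-known/openid-configuration` reports:

```python
import httpx
from jose import jwt

# Placeholders; real values come from the cluster's OIDC discovery document.
JWKS_URI = "https://kubernetes.default.svc/openid/v1/jwks"
ISSUER = "https://kubernetes.default.svc"
AUDIENCE = "https://kubernetes.default.svc"


def validate(token: str) -> dict:
    # Index the published keys by key id, then verify the token against the matching key.
    keys = {k["kid"]: k for k in httpx.get(JWKS_URI).json()["keys"]}
    header = jwt.get_unverified_header(token)
    key = keys[header["kid"]]
    # python-jose checks signature, expiry, audience and issuer in one call.
    return jwt.decode(
        token,
        key,
        algorithms=[header.get("alg", "RS256")],
        audience=AUDIENCE,
        issuer=ISSUER,
    )
```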
Signed-off-by: Sébastien Han --- .github/workflows/integration-auth-tests.yml | 39 ++++- docs/source/distributions/configuration.md | 68 ++++++-- llama_stack/distribution/datatypes.py | 4 +- llama_stack/distribution/server/auth.py | 5 +- .../distribution/server/auth_providers.py | 162 +++++------------- pyproject.toml | 1 - requirements.txt | 8 - tests/unit/server/test_auth.py | 121 +------------ uv.lock | 98 +---------- 9 files changed, 147 insertions(+), 359 deletions(-) diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index 82a76ad32..994bd1dec 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - auth-provider: [kubernetes] + auth-provider: [oauth2_token] fail-fast: false # we want to run all tests regardless of failure steps: @@ -47,29 +47,53 @@ jobs: uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19 - name: Start minikube - if: ${{ matrix.auth-provider == 'kubernetes' }} + if: ${{ matrix.auth-provider == 'oauth2_token' }} run: | minikube start kubectl get pods -A - name: Configure Kube Auth - if: ${{ matrix.auth-provider == 'kubernetes' }} + if: ${{ matrix.auth-provider == 'oauth2_token' }} run: | kubectl create namespace llama-stack kubectl create serviceaccount llama-stack-auth -n llama-stack kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token + cat <> $GITHUB_ENV + echo "KUBERNETES_API_SERVER_URL=$(kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri)" >> $GITHUB_ENV echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV + echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV + echo "KUBERNETES_AUDIENCE=$(kubectl create token default --duration=1h | cut -d. 
-f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV - name: Set Kube Auth Config and run server env: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" - if: ${{ matrix.auth-provider == 'kubernetes' }} + if: ${{ matrix.auth-provider == 'oauth2_token' }} run: | run_dir=$(mktemp -d) cat <<'EOF' > $run_dir/run.yaml @@ -81,7 +105,8 @@ jobs: port: 8321 EOF yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml - yq eval '.server.auth.config = {"api_server_url": "${{ env.KUBERNETES_API_SERVER_URL }}", "ca_cert_path": "${{ env.KUBERNETES_CA_CERT_PATH }}"}' -i $run_dir/run.yaml + yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml + yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml cat $run_dir/run.yaml source .venv/bin/activate diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 7a42f503a..77b52a621 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -118,11 +118,6 @@ server: port: 8321 # Port to listen on (default: 8321) tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS - auth: # Optional: Authentication configuration - provider_type: "kubernetes" # Type of auth provider - config: # Provider-specific configuration - api_server_url: "https://kubernetes.default.svc" - ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate ``` ### Authentication Configuration @@ -135,7 +130,7 @@ Authorization: Bearer The server supports multiple authentication providers: -#### Kubernetes Provider +#### OAuth 2.0/OpenID Connect Provider with Kubernetes The Kubernetes cluster must be configured to use a service account for authentication. @@ -146,14 +141,67 @@ kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --se kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token ``` -Validates tokens against the Kubernetes API server: +Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests +and that the correct RoleBinding is created to allow the service account to access the necessary +resources. 
If that is not the case, you can create a RoleBinding for the service account to access +the necessary resources: + +```yaml +# allow-anonymous-openid.yaml +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: allow-anonymous-openid +rules: +- nonResourceURLs: ["/openid/v1/jwks"] + verbs: ["get"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: allow-anonymous-openid +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: allow-anonymous-openid +subjects: +- kind: User + name: system:anonymous + apiGroup: rbac.authorization.k8s.io +``` + +And then apply the configuration: +```bash +kubectl apply -f allow-anonymous-openid.yaml +``` + +Validates tokens against the Kubernetes API server through the OIDC provider: ```yaml server: auth: - provider_type: "kubernetes" + provider_type: "oauth2_token" config: - api_server_url: "https://kubernetes.default.svc" # URL of the Kubernetes API server - ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate + jwks: + uri: "https://kubernetes.default.svc" + cache_ttl: 3600 + tls_cafile: "/path/to/ca.crt" + issuer: "https://kubernetes.default.svc" + audience: "https://kubernetes.default.svc" +``` + +To find your cluster's audience, run: +```bash +kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud +``` + +For the issuer, you can use the OIDC provider's URL: +```bash +kubectl get --raw /.well-known/openid-configuration| jq .issuer +``` + +For the tls_cafile, you can use the CA certificate of the OIDC provider: +```bash +kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}' ``` The provider extracts user information from the JWT token: diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index ca3664828..eb790ad93 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -220,14 +220,14 @@ class LoggingConfig(BaseModel): class AuthProviderType(str, Enum): """Supported authentication provider types.""" - KUBERNETES = "kubernetes" + OAUTH2_TOKEN = "oauth2_token" CUSTOM = "custom" class AuthenticationConfig(BaseModel): provider_type: AuthProviderType = Field( ..., - description="Type of authentication provider (e.g., 'kubernetes', 'custom')", + description="Type of authentication provider", ) config: dict[str, Any] = Field( ..., diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index 67acffe3e..fb26b49a7 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -8,7 +8,8 @@ import json import httpx -from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider +from llama_stack.distribution.datatypes import AuthenticationConfig +from llama_stack.distribution.server.auth_providers import create_auth_provider from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") @@ -77,7 +78,7 @@ class AuthenticationMiddleware: access resources that don't have access_attributes defined. 
""" - def __init__(self, app, auth_config: AuthProviderConfig): + def __init__(self, app, auth_config: AuthenticationConfig): self.app = app self.auth_provider = create_auth_provider(auth_config) diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index baab75eca..39f258c3b 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -4,13 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json import ssl import time from abc import ABC, abstractmethod from asyncio import Lock -from enum import Enum -from typing import Any +from pathlib import Path from urllib.parse import parse_qs import httpx @@ -18,7 +16,7 @@ from jose import jwt from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self -from llama_stack.distribution.datatypes import AccessAttributes +from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") @@ -76,21 +74,6 @@ class AuthRequest(BaseModel): request: AuthRequestContext = Field(description="Context information about the request being authenticated") -class AuthProviderType(str, Enum): - """Supported authentication provider types.""" - - KUBERNETES = "kubernetes" - CUSTOM = "custom" - OAUTH2_TOKEN = "oauth2_token" - - -class AuthProviderConfig(BaseModel): - """Base configuration for authentication providers.""" - - provider_type: AuthProviderType = Field(..., description="Type of authentication provider") - config: dict[str, Any] = Field(..., description="Provider-specific configuration") - - class AuthProvider(ABC): """Abstract base class for authentication providers.""" @@ -105,83 +88,6 @@ class AuthProvider(ABC): pass -class KubernetesAuthProviderConfig(BaseModel): - api_server_url: str - ca_cert_path: str | None = None - - -class KubernetesAuthProvider(AuthProvider): - """Kubernetes authentication provider that validates tokens against the Kubernetes API server.""" - - def __init__(self, config: KubernetesAuthProviderConfig): - self.config = config - self._client = None - - async def _get_client(self): - """Get or create a Kubernetes client.""" - if self._client is None: - # kubernetes-client has not async support, see: - # https://github.com/kubernetes-client/python/issues/323 - from kubernetes import client - from kubernetes.client import ApiClient - - # Configure the client - configuration = client.Configuration() - configuration.host = self.config.api_server_url - if self.config.ca_cert_path: - configuration.ssl_ca_cert = self.config.ca_cert_path - configuration.verify_ssl = bool(self.config.ca_cert_path) - - # Create API client - self._client = ApiClient(configuration) - return self._client - - async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: - """Validate a Kubernetes token and return access attributes.""" - try: - client = await self._get_client() - - # Set the token in the client - client.set_default_header("Authorization", f"Bearer {token}") - - # Make a request to validate the token - # We use the /api endpoint which requires authentication - from kubernetes.client import CoreV1Api - - api = CoreV1Api(client) - api.get_api_resources(_request_timeout=3.0) # Set timeout for this specific request - - # If we get here, the 
token is valid - # Extract user info from the token claims - import base64 - - # Decode the token (without verification since we've already validated it) - token_parts = token.split(".") - payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4))) - - # Extract user information from the token - username = payload.get("sub", "") - groups = payload.get("groups", []) - - return TokenValidationResult( - principal=username, - access_attributes=AccessAttributes( - roles=[username], # Use username as a role - teams=groups, # Use Kubernetes groups as teams - ), - ) - - except Exception as e: - logger.exception("Failed to validate Kubernetes token") - raise ValueError("Invalid or expired token") from e - - async def close(self): - """Close the HTTP client.""" - if self._client: - self._client.close() - self._client = None - - def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes: attributes = AccessAttributes() for claim_key, attribute_key in mapping.items(): @@ -212,11 +118,13 @@ class OAuth2IntrospectionConfig(BaseModel): client_id: str client_secret: str send_secret_in_body: bool = False - tls_cafile: str | None = None class OAuth2TokenAuthProviderConfig(BaseModel): audience: str = "llama-stack" + verify_tls: bool = True + tls_cafile: Path | None = None + issuer: str | None = Field(default=None, description="The OIDC issuer URL.") claims_mapping: dict[str, str] = Field( default_factory=lambda: { "sub": "roles", @@ -265,16 +173,14 @@ class OAuth2TokenAuthProvider(AuthProvider): async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: if self.config.jwks: - return await self.validate_jwt_token(token, self.config.jwks, scope) + return await self.validate_jwt_token(token, scope) if self.config.introspection: - return await self.introspect_token(token, self.config.introspection, scope) + return await self.introspect_token(token, scope) raise ValueError("One of jwks or introspection must be configured") - async def validate_jwt_token( - self, token: str, config: OAuth2JWKSConfig, scope: dict | None = None - ) -> TokenValidationResult: + async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: """Validate a token using the JWT token.""" - await self._refresh_jwks(config) + await self._refresh_jwks() try: header = jwt.get_unverified_header(token) @@ -288,7 +194,7 @@ class OAuth2TokenAuthProvider(AuthProvider): key_data, algorithms=[algorithm], audience=self.config.audience, - options={"verify_exp": True}, + issuer=self.config.issuer, ) except Exception as exc: raise ValueError(f"Invalid JWT token: {token}") from exc @@ -302,26 +208,27 @@ class OAuth2TokenAuthProvider(AuthProvider): access_attributes=access_attributes, ) - async def introspect_token( - self, token: str, config: OAuth2IntrospectionConfig, scope: dict | None = None - ) -> TokenValidationResult: + async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: """Validate a token using token introspection as defined by RFC 7662.""" form = { "token": token, } - if config.send_secret_in_body: - form["client_id"] = config.client_id - form["client_secret"] = config.client_secret + if self.config.introspection is None: + raise ValueError("Introspection is not configured") + + if self.config.introspection.send_secret_in_body: + form["client_id"] = self.config.introspection.client_id + form["client_secret"] = self.config.introspection.client_secret auth = 
None else: - auth = (config.client_id, config.client_secret) + auth = (self.config.introspection.client_id, self.config.introspection.client_secret) ssl_ctxt = None - if config.tls_cafile: - ssl_ctxt = ssl.create_default_context(cafile=config.tls_cafile) + if self.config.tls_cafile: + ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix()) try: async with httpx.AsyncClient(verify=ssl_ctxt) as client: response = await client.post( - config.url, + self.config.introspection.url, data=form, auth=auth, timeout=10.0, # Add a reasonable timeout @@ -352,11 +259,24 @@ class OAuth2TokenAuthProvider(AuthProvider): async def close(self): pass - async def _refresh_jwks(self, config: OAuth2JWKSConfig) -> None: + async def _refresh_jwks(self) -> None: + """ + Refresh the JWKS cache. + + This is a simple cache that expires after a certain amount of time (defined by `cache_ttl`). + If the cache is expired, we refresh the JWKS from the JWKS URI. + + Notes: for Kubernetes which doesn't fully implement the OIDC protocol: + * It doesn't have user authentication flows + * It doesn't have refresh tokens + """ async with self._jwks_lock: - if time.time() - self._jwks_at > config.cache_ttl: - async with httpx.AsyncClient() as client: - res = await client.get(config.uri, timeout=5) + if self.config.jwks is None: + raise ValueError("JWKS is not configured") + if time.time() - self._jwks_at > self.config.jwks.cache_ttl: + verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls + async with httpx.AsyncClient(verify=verify) as client: + res = await client.get(self.config.jwks.uri, timeout=5) res.raise_for_status() jwks_data = res.json()["keys"] updated = {} @@ -443,13 +363,11 @@ class CustomAuthProvider(AuthProvider): self._client = None -def create_auth_provider(config: AuthProviderConfig) -> AuthProvider: +def create_auth_provider(config: AuthenticationConfig) -> AuthProvider: """Factory function to create the appropriate auth provider.""" provider_type = config.provider_type.lower() - if provider_type == "kubernetes": - return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config)) - elif provider_type == "custom": + if provider_type == "custom": return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config)) elif provider_type == "oauth2_token": return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config)) diff --git a/pyproject.toml b/pyproject.toml index a41830e64..8b922bafb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ dependencies = [ "tiktoken", "pillow", "h11>=0.16.0", - "kubernetes", ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index 6dfcc1024..2fe72c803 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,19 +4,16 @@ annotated-types==0.7.0 anyio==4.8.0 attrs==25.1.0 blobfile==3.0.0 -cachetools==5.5.2 certifi==2025.1.31 charset-normalizer==3.4.1 click==8.1.8 colorama==0.4.6 ; sys_platform == 'win32' distro==1.9.0 -durationpy==0.9 ecdsa==0.19.1 exceptiongroup==1.2.2 ; python_full_version < '3.11' filelock==3.17.0 fire==0.7.0 fsspec==2024.12.0 -google-auth==2.38.0 h11==0.16.0 httpcore==1.0.9 httpx==0.28.1 @@ -26,14 +23,12 @@ jinja2==3.1.6 jiter==0.8.2 jsonschema==4.23.0 jsonschema-specifications==2024.10.1 -kubernetes==32.0.1 llama-stack-client==0.2.7 lxml==5.3.1 markdown-it-py==3.0.0 markupsafe==3.0.2 mdurl==0.1.2 numpy==2.2.3 -oauthlib==3.2.2 openai==1.71.0 packaging==24.2 pandas==2.2.3 @@ -41,7 +36,6 @@ 
pillow==11.1.0 prompt-toolkit==3.0.50 pyaml==25.1.0 pyasn1==0.4.8 -pyasn1-modules==0.4.1 pycryptodomex==3.21.0 pydantic==2.10.6 pydantic-core==2.27.2 @@ -54,7 +48,6 @@ pyyaml==6.0.2 referencing==0.36.2 regex==2024.11.6 requests==2.32.3 -requests-oauthlib==2.0.0 rich==13.9.4 rpds-py==0.22.3 rsa==4.9 @@ -68,4 +61,3 @@ typing-extensions==4.12.2 tzdata==2025.1 urllib3==2.3.0 wcwidth==0.2.13 -websocket-client==1.8.0 diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 56458c0e7..94c486f18 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -11,12 +11,10 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from llama_stack.distribution.datatypes import AccessAttributes +from llama_stack.distribution.datatypes import AuthenticationConfig from llama_stack.distribution.server.auth import AuthenticationMiddleware from llama_stack.distribution.server.auth_providers import ( - AuthProviderConfig, AuthProviderType, - TokenValidationResult, get_attributes_from_claims, ) @@ -62,7 +60,7 @@ def invalid_token(): @pytest.fixture def http_app(mock_auth_endpoint): app = FastAPI() - auth_config = AuthProviderConfig( + auth_config = AuthenticationConfig( provider_type=AuthProviderType.CUSTOM, config={"endpoint": mock_auth_endpoint}, ) @@ -78,7 +76,7 @@ def http_app(mock_auth_endpoint): @pytest.fixture def k8s_app(): app = FastAPI() - auth_config = AuthProviderConfig( + auth_config = AuthenticationConfig( provider_type=AuthProviderType.KUBERNETES, config={"api_server_url": "https://kubernetes.default.svc"}, ) @@ -118,7 +116,7 @@ def mock_scope(): @pytest.fixture def mock_http_middleware(mock_auth_endpoint): mock_app = AsyncMock() - auth_config = AuthProviderConfig( + auth_config = AuthenticationConfig( provider_type=AuthProviderType.CUSTOM, config={"endpoint": mock_auth_endpoint}, ) @@ -128,7 +126,7 @@ def mock_http_middleware(mock_auth_endpoint): @pytest.fixture def mock_k8s_middleware(): mock_app = AsyncMock() - auth_config = AuthProviderConfig( + auth_config = AuthenticationConfig( provider_type=AuthProviderType.KUBERNETES, config={"api_server_url": "https://kubernetes.default.svc"}, ) @@ -284,116 +282,13 @@ async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope): assert attributes["roles"] == ["test.jwt.token"] -# Kubernetes Tests -def test_missing_auth_header_k8s(k8s_client): - response = k8s_client.get("/test") - assert response.status_code == 401 - assert "Missing or invalid Authorization header" in response.json()["error"]["message"] - - -def test_invalid_auth_header_format_k8s(k8s_client): - response = k8s_client.get("/test", headers={"Authorization": "InvalidFormat token123"}) - assert response.status_code == 401 - assert "Missing or invalid Authorization header" in response.json()["error"]["message"] - - -@patch("kubernetes.client.ApiClient") -def test_valid_k8s_authentication(mock_api_client, k8s_client, valid_token): - # Mock the Kubernetes client - mock_client = AsyncMock() - mock_api_client.return_value = mock_client - - # Mock successful token validation - mock_client.set_default_header = AsyncMock() - - # Mock the token validation to return valid access attributes - with patch("llama_stack.distribution.server.auth_providers.KubernetesAuthProvider.validate_token") as mock_validate: - mock_validate.return_value = TokenValidationResult( - principal="test-principal", - access_attributes=AccessAttributes( - roles=["admin"], teams=["ml-team"], projects=["llama-3"], 
namespaces=["research"] - ), - ) - response = k8s_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"}) - assert response.status_code == 200 - assert response.json() == {"message": "Authentication successful"} - - -@patch("kubernetes.client.ApiClient") -def test_invalid_k8s_authentication(mock_api_client, k8s_client, invalid_token): - # Mock the Kubernetes client - mock_client = AsyncMock() - mock_api_client.return_value = mock_client - - # Mock failed token validation by raising an exception - with patch("llama_stack.distribution.server.auth_providers.KubernetesAuthProvider.validate_token") as mock_validate: - mock_validate.side_effect = ValueError("Invalid or expired token") - response = k8s_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"}) - assert response.status_code == 401 - assert "Invalid or expired token" in response.json()["error"]["message"] - - -@pytest.mark.asyncio -async def test_k8s_middleware_with_access_attributes(mock_k8s_middleware, mock_scope): - middleware, mock_app = mock_k8s_middleware - mock_receive = AsyncMock() - mock_send = AsyncMock() - - with patch("kubernetes.client.ApiClient") as mock_api_client: - mock_client = AsyncMock() - mock_api_client.return_value = mock_client - - # Mock successful token validation - mock_client.set_default_header = AsyncMock() - - # Mock token payload with access attributes - mock_token_parts = ["header", "eyJzdWIiOiJhZG1pbiIsImdyb3VwcyI6WyJtbC10ZWFtIl19", "signature"] - mock_scope["headers"][1] = (b"authorization", f"Bearer {'.'.join(mock_token_parts)}".encode()) - - await middleware(mock_scope, mock_receive, mock_send) - - assert "user_attributes" in mock_scope - assert mock_scope["user_attributes"]["roles"] == ["admin"] - assert mock_scope["user_attributes"]["teams"] == ["ml-team"] - - mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send) - - -@pytest.mark.asyncio -async def test_k8s_middleware_no_attributes(mock_k8s_middleware, mock_scope): - """Test middleware behavior with no access attributes""" - middleware, mock_app = mock_k8s_middleware - mock_receive = AsyncMock() - mock_send = AsyncMock() - - with patch("kubernetes.client.ApiClient") as mock_api_client: - mock_client = AsyncMock() - mock_api_client.return_value = mock_client - - # Mock successful token validation - mock_client.set_default_header = AsyncMock() - - # Mock token payload without access attributes - mock_token_parts = ["header", "eyJzdWIiOiJhZG1pbiJ9", "signature"] - mock_scope["headers"][1] = (b"authorization", f"Bearer {'.'.join(mock_token_parts)}".encode()) - - await middleware(mock_scope, mock_receive, mock_send) - - assert "user_attributes" in mock_scope - attributes = mock_scope["user_attributes"] - assert "roles" in attributes - assert attributes["roles"] == ["admin"] - - mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send) - - # oauth2 token provider tests @pytest.fixture def oauth2_app(): app = FastAPI() - auth_config = AuthProviderConfig( + auth_config = AuthenticationConfig( provider_type=AuthProviderType.OAUTH2_TOKEN, config={ "jwks": { @@ -530,7 +425,7 @@ def mock_introspection_endpoint(): @pytest.fixture def introspection_app(mock_introspection_endpoint): app = FastAPI() - auth_config = AuthProviderConfig( + auth_config = AuthenticationConfig( provider_type=AuthProviderType.OAUTH2_TOKEN, config={ "jwks": None, @@ -549,7 +444,7 @@ def introspection_app(mock_introspection_endpoint): @pytest.fixture def introspection_app_with_custom_mapping(mock_introspection_endpoint): app = 
FastAPI() - auth_config = AuthProviderConfig( + auth_config = AuthenticationConfig( provider_type=AuthProviderType.OAUTH2_TOKEN, config={ "jwks": None, diff --git a/uv.lock b/uv.lock index c30e2c4c1..a987ddc9e 100644 --- a/uv.lock +++ b/uv.lock @@ -676,15 +676,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2", size = 587408 }, ] -[[package]] -name = "durationpy" -version = "0.9" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/31/e9/f49c4e7fccb77fa5c43c2480e09a857a78b41e7331a75e128ed5df45c56b/durationpy-0.9.tar.gz", hash = "sha256:fd3feb0a69a0057d582ef643c355c40d2fa1c942191f914d12203b1a01ac722a", size = 3186 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4c/a3/ac312faeceffd2d8f86bc6dcb5c401188ba5a01bc88e69bed97578a0dfcd/durationpy-0.9-py3-none-any.whl", hash = "sha256:e65359a7af5cedad07fb77a2dd3f390f8eb0b74cb845589fa6c057086834dd38", size = 3461 }, -] - [[package]] name = "ecdsa" version = "0.19.1" @@ -863,20 +854,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1d/9a/4114a9057db2f1462d5c8f8390ab7383925fe1ac012eaa42402ad65c2963/GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110", size = 207599 }, ] -[[package]] -name = "google-auth" -version = "2.38.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cachetools" }, - { name = "pyasn1-modules" }, - { name = "rsa" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c6/eb/d504ba1daf190af6b204a9d4714d457462b486043744901a6eeea711f913/google_auth-2.38.0.tar.gz", hash = "sha256:8285113607d3b80a3f1543b75962447ba8a09fe85783432a784fdeef6ac094c4", size = 270866 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9d/47/603554949a37bca5b7f894d51896a9c534b9eab808e2520a748e081669d0/google_auth-2.38.0-py2.py3-none-any.whl", hash = "sha256:e7dae6694313f434a2727bf2906f27ad259bae090d7aa896590d86feec3d9d4a", size = 210770 }, -] - [[package]] name = "googleapis-common-protos" version = "1.67.0" @@ -1324,28 +1301,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/fb/108ecd1fe961941959ad0ee4e12ee7b8b1477247f30b1fdfd83ceaf017f0/jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409", size = 28965 }, ] -[[package]] -name = "kubernetes" -version = "32.0.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "durationpy" }, - { name = "google-auth" }, - { name = "oauthlib" }, - { name = "python-dateutil" }, - { name = "pyyaml" }, - { name = "requests" }, - { name = "requests-oauthlib" }, - { name = "six" }, - { name = "urllib3" }, - { name = "websocket-client" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b7/e8/0598f0e8b4af37cd9b10d8b87386cf3173cb8045d834ab5f6ec347a758b3/kubernetes-32.0.1.tar.gz", hash = "sha256:42f43d49abd437ada79a79a16bd48a604d3471a117a8347e87db693f2ba0ba28", size = 946691 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/08/10/9f8af3e6f569685ce3af7faab51c8dd9d93b9c38eba339ca31c746119447/kubernetes-32.0.1-py2.py3-none-any.whl", hash = "sha256:35282ab8493b938b08ab5526c7ce66588232df00ef5e1dbe88a419107dc10998", size = 1988070 }, -] - [[package]] name = "levenshtein" version = 
"0.27.1" @@ -1441,7 +1396,6 @@ dependencies = [ { name = "huggingface-hub" }, { name = "jinja2" }, { name = "jsonschema" }, - { name = "kubernetes" }, { name = "llama-stack-client" }, { name = "openai" }, { name = "pillow" }, @@ -1546,7 +1500,6 @@ requires-dist = [ { name = "jinja2", specifier = ">=3.1.6" }, { name = "jinja2", marker = "extra == 'codegen'", specifier = ">=3.1.6" }, { name = "jsonschema" }, - { name = "kubernetes" }, { name = "llama-stack-client", specifier = ">=0.2.7" }, { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.7" }, { name = "mcp", marker = "extra == 'test'" }, @@ -1624,9 +1577,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cd/6b/31c07396c5b3010668e4eb38061a96ffacb47ec4b14d8aeb64c13856c485/llama_stack_client-0.2.7.tar.gz", hash = "sha256:11aee11fdd5e0e8caad07c0cce9c4d88640938844372e7e3453a91ea0757fcb3", size = 259273, upload-time = "2025-05-16T20:31:39.221Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/6b/31c07396c5b3010668e4eb38061a96ffacb47ec4b14d8aeb64c13856c485/llama_stack_client-0.2.7.tar.gz", hash = "sha256:11aee11fdd5e0e8caad07c0cce9c4d88640938844372e7e3453a91ea0757fcb3", size = 259273 } wheels = [ - { url = "https://files.pythonhosted.org/packages/ac/69/6a5f4683afe355500df4376fdcbfb2fc1e6a0c3bcea5ff8f6114773a9acf/llama_stack_client-0.2.7-py3-none-any.whl", hash = "sha256:78b3f2abdb1770c7b1270a9c0ef58402a988401c564d2e6c83588779ac6fc38d", size = 292727, upload-time = "2025-05-16T20:31:37.587Z" }, + { url = "https://files.pythonhosted.org/packages/ac/69/6a5f4683afe355500df4376fdcbfb2fc1e6a0c3bcea5ff8f6114773a9acf/llama_stack_client-0.2.7-py3-none-any.whl", hash = "sha256:78b3f2abdb1770c7b1270a9c0ef58402a988401c564d2e6c83588779ac6fc38d", size = 292727 }, ] [[package]] @@ -2087,15 +2040,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/17/7f/d322a4125405920401450118dbdc52e0384026bd669939484670ce8b2ab9/numpy-2.2.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:783145835458e60fa97afac25d511d00a1eca94d4a8f3ace9fe2043003c678e4", size = 12839607 }, ] -[[package]] -name = "oauthlib" -version = "3.2.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6d/fa/fbf4001037904031639e6bfbfc02badfc7e12f137a8afa254df6c4c8a670/oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918", size = 177352 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/80/cab10959dc1faead58dc8384a781dfbf93cb4d33d50988f7a69f1b7c9bbe/oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca", size = 151688 }, -] - [[package]] name = "openai" version = "1.71.0" @@ -2608,18 +2552,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/1e/a94a8d635fa3ce4cfc7f506003548d0a2447ae76fd5ca53932970fe3053f/pyasn1-0.4.8-py2.py3-none-any.whl", hash = "sha256:39c7e2ec30515947ff4e87fb6f456dfc6e84857d34be479c9d4a4ba4bf46aa5d", size = 77145 }, ] -[[package]] -name = "pyasn1-modules" -version = "0.4.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyasn1" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1d/67/6afbf0d507f73c32d21084a79946bfcfca5fbc62a72057e9c23797a737c9/pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c", size = 310028 } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd", size = 181537 }, -] - [[package]] name = "pycparser" version = "2.22" @@ -2875,9 +2807,9 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973 } wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" }, + { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382 }, ] [[package]] @@ -3256,19 +3188,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, ] -[[package]] -name = "requests-oauthlib" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "oauthlib" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/42/f2/05f29bc3913aea15eb670be136045bf5c5bbf4b99ecb839da9b422bb2c85/requests-oauthlib-2.0.0.tar.gz", hash = "sha256:b3dffaebd884d8cd778494369603a9e7b58d29111bf6b41bdc2dcd87203af4e9", size = 55650 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/5d/63d4ae3b9daea098d5d6f5da83984853c1bbacd5dc826764b249fe119d24/requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36", size = 24179 }, -] - [[package]] name = "rich" version = "13.9.4" @@ -4323,15 +4242,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 }, ] -[[package]] -name = "websocket-client" -version = "1.8.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e6/30/fba0d96b4b5fbf5948ed3f4681f7da2f9f64512e1d303f94b4cc174c24a5/websocket_client-1.8.0.tar.gz", hash = "sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da", size = 54648 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5a/84/44687a29792a70e111c5c477230a72c4b957d88d16141199bf9acb7537a3/websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526", size = 58826 }, -] - [[package]] name = "websockets" version = 
"15.0" From 1862de4be51fa3697d54525c65aebe9edc6c8514 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Wed, 21 May 2025 17:30:23 +0200 Subject: [PATCH 14/61] chore: clarify cache_ttl to be key_recheck_period (#2220) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? The cache_ttl config value is not in fact tied to the lifetime of any of the keys, it represents the time interval between for our key cache refresher. Signed-off-by: Sébastien Han --- docs/source/distributions/configuration.md | 2 +- llama_stack/distribution/server/auth_providers.py | 6 +++--- tests/unit/server/test_auth.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 77b52a621..de99b6576 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -183,7 +183,7 @@ server: config: jwks: uri: "https://kubernetes.default.svc" - cache_ttl: 3600 + key_recheck_period: 3600 tls_cafile: "/path/to/ca.crt" issuer: "https://kubernetes.default.svc" audience: "https://kubernetes.default.svc" diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 39f258c3b..723a65b77 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -110,7 +110,7 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) class OAuth2JWKSConfig(BaseModel): # The JWKS URI for collecting public keys uri: str - cache_ttl: int = 3600 + key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates") class OAuth2IntrospectionConfig(BaseModel): @@ -263,7 +263,7 @@ class OAuth2TokenAuthProvider(AuthProvider): """ Refresh the JWKS cache. - This is a simple cache that expires after a certain amount of time (defined by `cache_ttl`). + This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`). If the cache is expired, we refresh the JWKS from the JWKS URI. Notes: for Kubernetes which doesn't fully implement the OIDC protocol: @@ -273,7 +273,7 @@ class OAuth2TokenAuthProvider(AuthProvider): async with self._jwks_lock: if self.config.jwks is None: raise ValueError("JWKS is not configured") - if time.time() - self._jwks_at > self.config.jwks.cache_ttl: + if time.time() - self._jwks_at > self.config.jwks.key_recheck_period: verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls async with httpx.AsyncClient(verify=verify) as client: res = await client.get(self.config.jwks.uri, timeout=5) diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 94c486f18..408acb88a 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -293,7 +293,7 @@ def oauth2_app(): config={ "jwks": { "uri": "http://mock-authz-service/token/introspect", - "cache_ttl": "3600", + "key_recheck_period": "3600", }, "audience": "llama-stack", }, From 6a62e783b905e57c15be351ade856c33752c0dd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Wed, 21 May 2025 17:31:14 +0200 Subject: [PATCH 15/61] chore: refactor workflow writting (#2225) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? 
Use a composite action to avoid similar steps repetitions and centralization of the defaults. Signed-off-by: Sébastien Han --- .github/actions/setup-runner/action.yml | 22 ++++++ .github/workflows/integration-auth-tests.yml | 12 +--- .github/workflows/integration-tests.yml | 18 ++--- .github/workflows/providers-build.yml | 69 +++---------------- .github/workflows/test-external-providers.yml | 12 +--- .github/workflows/unit-tests.yml | 14 ++-- .github/workflows/update-readthedocs.yml | 12 +--- 7 files changed, 45 insertions(+), 114 deletions(-) create mode 100644 .github/actions/setup-runner/action.yml diff --git a/.github/actions/setup-runner/action.yml b/.github/actions/setup-runner/action.yml new file mode 100644 index 000000000..972dcbdae --- /dev/null +++ b/.github/actions/setup-runner/action.yml @@ -0,0 +1,22 @@ +name: Setup runner +description: Prepare a runner for the tests (install uv, python, project dependencies, etc.) +runs: + using: "composite" + steps: + - name: Install uv + uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 + with: + python-version: "3.10" + activate-environment: true + version: 0.7.6 + + - name: Install dependencies + shell: bash + run: | + uv sync --all-extras + uv pip install ollama faiss-cpu + # always test against the latest version of the client + # TODO: this is not necessarily a good idea. we need to test against both published and latest + # to find out backwards compatibility issues. + uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main + uv pip install -e . diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index 994bd1dec..25f696c9e 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -30,16 +30,11 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - activate-environment: true + - name: Install dependencies + uses: ./.github/actions/setup-runner - - name: Set Up Environment and Install Dependencies + - name: Build Llama Stack run: | - uv sync --extra dev --extra test - uv pip install -e . 
llama stack build --template ollama --image-type venv - name: Install minikube @@ -109,7 +104,6 @@ jobs: yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml cat $run_dir/run.yaml - source .venv/bin/activate nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 & - name: Wait for Llama Stack server to be ready diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index da41e2185..2414522a7 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -32,24 +32,14 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - activate-environment: true + - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Setup ollama uses: ./.github/actions/setup-ollama - - name: Set Up Environment and Install Dependencies + - name: Build Llama Stack run: | - uv sync --extra dev --extra test - uv pip install ollama faiss-cpu - # always test against the latest version of the client - # TODO: this is not necessarily a good idea. we need to test against both published and latest - # to find out backwards compatibility issues. - uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main - uv pip install -e . llama stack build --template ollama --image-type venv - name: Start Llama Stack server in background @@ -57,7 +47,6 @@ jobs: env: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" run: | - source .venv/bin/activate LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv & - name: Wait for Llama Stack server to be ready @@ -85,6 +74,7 @@ jobs: echo "Ollama health check failed" exit 1 fi + - name: Check Storage and Memory Available Before Tests if: ${{ always() }} run: | diff --git a/.github/workflows/providers-build.yml b/.github/workflows/providers-build.yml index 3c1682833..cf53459b9 100644 --- a/.github/workflows/providers-build.yml +++ b/.github/workflows/providers-build.yml @@ -50,21 +50,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: '3.10' - - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - - - name: Install LlamaStack - run: | - uv venv - source .venv/bin/activate - uv pip install -e . + - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Print build dependencies run: | @@ -79,7 +66,6 @@ jobs: - name: Print dependencies in the image if: matrix.image-type == 'venv' run: | - source test/bin/activate uv pip list build-single-provider: @@ -88,21 +74,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: '3.10' - - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - - - name: Install LlamaStack - run: | - uv venv - source .venv/bin/activate - uv pip install -e . 
+ - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Build a single provider run: | @@ -114,21 +87,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: '3.10' - - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - - - name: Install LlamaStack - run: | - uv venv - source .venv/bin/activate - uv pip install -e . + - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Build a single provider run: | @@ -152,21 +112,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: '3.10' - - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - - - name: Install LlamaStack - run: | - uv venv - source .venv/bin/activate - uv pip install -e . + - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Pin template to UBI9 base run: | diff --git a/.github/workflows/test-external-providers.yml b/.github/workflows/test-external-providers.yml index 2e18fc5eb..06ab7cf3c 100644 --- a/.github/workflows/test-external-providers.yml +++ b/.github/workflows/test-external-providers.yml @@ -25,15 +25,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - - - name: Set Up Environment and Install Dependencies - run: | - uv sync --extra dev --extra test - uv pip install -e . 
+ - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Apply image type to config file run: | @@ -59,7 +52,6 @@ jobs: env: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" run: | - source ci-test/bin/activate uv run pip list nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 & diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index d2dd34e05..fc0459f0f 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -30,17 +30,11 @@ jobs: - "3.12" - "3.13" steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: ${{ matrix.python }} - - - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: ${{ matrix.python }} - enable-cache: false + - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Run unit tests run: | diff --git a/.github/workflows/update-readthedocs.yml b/.github/workflows/update-readthedocs.yml index 04e23bca9..981332a77 100644 --- a/.github/workflows/update-readthedocs.yml +++ b/.github/workflows/update-readthedocs.yml @@ -37,16 +37,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: '3.11' - - - name: Install the latest version of uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - - - name: Sync with uv - run: uv sync --extra docs + - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Build HTML run: | From 85b5f3172b0cf3eb7febcd20cd4df4a60c3c39ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Wed, 21 May 2025 17:35:27 +0200 Subject: [PATCH 16/61] docs: misc cleanup (#2223) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? * remove requirements.txt to use pyproject.toml as the source of truth * update relevant docs Signed-off-by: Sébastien Han --- CONTRIBUTING.md | 7 +--- docs/readme.md | 6 +-- docs/requirements.txt | 16 -------- docs/source/conf.py | 8 ---- pyproject.toml | 3 ++ uv.lock | 88 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 96 insertions(+), 32 deletions(-) delete mode 100644 docs/requirements.txt diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d7c3e3e2f..8f71a6ba1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -167,14 +167,11 @@ If you have made changes to a provider's configuration in any form (introducing If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme. ```bash -cd docs -uv sync --extra docs - # This rebuilds the documentation pages. 
-uv run make html +uv run --with ".[docs]" make -C docs/ html # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation. -uv run sphinx-autobuild source build/html --write-all +uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all ``` ### Update API Documentation diff --git a/docs/readme.md b/docs/readme.md index b88a4738d..d84dbe6eb 100644 --- a/docs/readme.md +++ b/docs/readme.md @@ -3,10 +3,10 @@ Here's a collection of comprehensive guides, examples, and resources for building AI applications with Llama Stack. For the complete documentation, visit our [ReadTheDocs page](https://llama-stack.readthedocs.io/en/latest/index.html). ## Render locally + +From the llama-stack root directory, run the following command to render the docs locally: ```bash -pip install -r requirements.txt -cd docs -python -m sphinx_autobuild source _build +uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all ``` You can open up the docs in your browser at http://localhost:8000 diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index 6cd45c33b..000000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,16 +0,0 @@ -linkify -myst-parser --e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme -sphinx==8.1.3 -sphinx-copybutton -sphinx-design -sphinx-pdj-theme -sphinx-rtd-theme>=1.0.0 -sphinx-tabs -sphinx_autobuild -sphinx_rtd_dark_mode -sphinxcontrib-mermaid -sphinxcontrib-openapi -sphinxcontrib-redoc -sphinxcontrib-video -tomli diff --git a/docs/source/conf.py b/docs/source/conf.py index 501a923dd..43e8dbdd5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -53,14 +53,6 @@ myst_enable_extensions = ["colon_fence"] html_theme = "sphinx_rtd_theme" html_use_relative_paths = True - -# html_theme = "sphinx_pdj_theme" -# html_theme_path = [sphinx_pdj_theme.get_html_theme_path()] - -# html_theme = "pytorch_sphinx_theme" -# html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] - - templates_path = ["_templates"] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] diff --git a/pyproject.toml b/pyproject.toml index 8b922bafb..ce44479ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ test = [ docs = [ "sphinx-autobuild", "myst-parser", + "sphinx", "sphinx-rtd-theme", "sphinx_rtd_dark_mode", "sphinx-copybutton", @@ -102,6 +103,8 @@ docs = [ "sphinxcontrib.video", "sphinxcontrib.mermaid", "tomli", + "linkify", + "sphinxcontrib.openapi", ] codegen = ["rich", "pydantic", "jinja2>=3.1.6"] ui = [ diff --git a/uv.lock b/uv.lock index a987ddc9e..6d091193b 100644 --- a/uv.lock +++ b/uv.lock @@ -628,6 +628,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/50/83c593b07763e1161326b3b8c6686f0f4b0f24d5526546bee538c89837d6/decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186", size = 9073 }, ] +[[package]] +name = "deepmerge" +version = "2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a8/3a/b0ba594708f1ad0bc735884b3ad854d3ca3bdc1d741e56e40bbda6263499/deepmerge-2.0.tar.gz", hash = "sha256:5c3d86081fbebd04dd5de03626a0607b809a98fb6ccba5770b62466fe940ff20", size = 19890 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/82/e5d2c1c67d19841e9edc74954c827444ae826978499bde3dfc1d007c8c11/deepmerge-2.0-py3-none-any.whl", hash = 
"sha256:6de9ce507115cff0bed95ff0ce9ecc31088ef50cbdf09bc90a09349a318b3d00", size = 13475 }, +] + [[package]] name = "deprecated" version = "1.2.18" @@ -1384,6 +1393,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dc/1e/408fd10217eac0e43aea0604be22b4851a09e03d761d44d4ea12089dd70e/levenshtein-0.27.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:7987ef006a3cf56a4532bd4c90c2d3b7b4ca9ad3bf8ae1ee5713c4a3bdfda913", size = 98045 }, ] +[[package]] +name = "linkify" +version = "1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/65/c6/246100fa3967074d9725b3716913bd495823547bde5047050d4c3462f994/linkify-1.4.tar.gz", hash = "sha256:9ba276ba179525f7262820d90f009604e51cd4f1466c1112b882ef7eda243d5e", size = 1749 } + [[package]] name = "llama-stack" version = "0.2.7" @@ -1434,7 +1449,9 @@ dev = [ { name = "uvicorn" }, ] docs = [ + { name = "linkify" }, { name = "myst-parser" }, + { name = "sphinx" }, { name = "sphinx-autobuild" }, { name = "sphinx-copybutton" }, { name = "sphinx-design" }, @@ -1442,6 +1459,7 @@ docs = [ { name = "sphinx-rtd-theme" }, { name = "sphinx-tabs" }, { name = "sphinxcontrib-mermaid" }, + { name = "sphinxcontrib-openapi" }, { name = "sphinxcontrib-redoc" }, { name = "sphinxcontrib-video" }, { name = "tomli" }, @@ -1500,6 +1518,7 @@ requires-dist = [ { name = "jinja2", specifier = ">=3.1.6" }, { name = "jinja2", marker = "extra == 'codegen'", specifier = ">=3.1.6" }, { name = "jsonschema" }, + { name = "linkify", marker = "extra == 'docs'" }, { name = "llama-stack-client", specifier = ">=0.2.7" }, { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.7" }, { name = "mcp", marker = "extra == 'test'" }, @@ -1534,6 +1553,7 @@ requires-dist = [ { name = "ruamel-yaml", marker = "extra == 'dev'" }, { name = "ruff", marker = "extra == 'dev'" }, { name = "setuptools" }, + { name = "sphinx", marker = "extra == 'docs'" }, { name = "sphinx-autobuild", marker = "extra == 'docs'" }, { name = "sphinx-copybutton", marker = "extra == 'docs'" }, { name = "sphinx-design", marker = "extra == 'docs'" }, @@ -1541,6 +1561,7 @@ requires-dist = [ { name = "sphinx-rtd-theme", marker = "extra == 'docs'" }, { name = "sphinx-tabs", marker = "extra == 'docs'" }, { name = "sphinxcontrib-mermaid", marker = "extra == 'docs'" }, + { name = "sphinxcontrib-openapi", marker = "extra == 'docs'" }, { name = "sphinxcontrib-redoc", marker = "extra == 'docs'" }, { name = "sphinxcontrib-video", marker = "extra == 'docs'" }, { name = "sqlite-vec", marker = "extra == 'unit'" }, @@ -1786,6 +1807,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 }, ] +[[package]] +name = "mistune" +version = "3.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c4/79/bda47f7dd7c3c55770478d6d02c9960c430b0cf1773b72366ff89126ea31/mistune-3.1.3.tar.gz", hash = "sha256:a7035c21782b2becb6be62f8f25d3df81ccb4d6fa477a6525b15af06539f02a0", size = 94347 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/01/4d/23c4e4f09da849e127e9f123241946c23c1e30f45a88366879e064211815/mistune-3.1.3-py3-none-any.whl", hash = 
"sha256:1a32314113cff28aa6432e99e522677c8587fd83e3d51c29b82a52409c842bd9", size = 53410 }, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -2228,6 +2261,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772 }, ] +[[package]] +name = "picobox" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/b1/830714dd6778c1cb45826722b4e9bd21c94b33cca5df9ef2cc0b80c81b25/picobox-4.0.0.tar.gz", hash = "sha256:114da1b5606b2f615e8b0eb68d04198ad9de75af5adbcf5b36fe4f664ab927b6", size = 22666 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/c6/fd64ffd75d47c4fcf6c65808cc5c5c75e5d4357c197d3741ee1339e91257/picobox-4.0.0-py3-none-any.whl", hash = "sha256:4c27eb689fe45dabd9e64c382e04418147d0b746d155b4e80057dbb7ff82027e", size = 11641 }, +] + [[package]] name = "pillow" version = "11.1.0" @@ -3516,6 +3558,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/43/65c0acbd8cc6f50195a3a1fc195c404988b15c67090e73c7a41a9f57d6bd/sphinx_design-0.6.1-py3-none-any.whl", hash = "sha256:b11f37db1a802a183d61b159d9a202314d4d2fe29c163437001324fe2f19549c", size = 2215338 }, ] +[[package]] +name = "sphinx-mdinclude" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docutils" }, + { name = "mistune" }, + { name = "pygments" }, + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b6/a7/c9a7888bb2187fdb06955d71e75f6f266b7e179b356ac76138d160a5b7eb/sphinx_mdinclude-0.6.2.tar.gz", hash = "sha256:447462e82cb8be61404a2204227f920769eb923d2f57608e3325f3bb88286b4c", size = 65257 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/3d/6b41fe1637cd53c4b10d56e0e6f396546f837973dabf9c4b2a1de44620ac/sphinx_mdinclude-0.6.2-py3-none-any.whl", hash = "sha256:648e78edb067c0e4bffc22943278d49d54a0714494743592032fa3ad82a86984", size = 16911 }, +] + [[package]] name = "sphinx-rtd-dark-mode" version = "1.3.0" @@ -3583,6 +3640,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/7b/18a8c0bcec9182c05a0b3ec2a776bba4ead82750a55ff798e8d406dae604/sphinxcontrib_htmlhelp-2.1.0-py3-none-any.whl", hash = "sha256:166759820b47002d22914d64a075ce08f4c46818e17cfc9470a9786b759b19f8", size = 98705 }, ] +[[package]] +name = "sphinxcontrib-httpdomain" +version = "1.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/be/ef/82d3cfafb7febce4f7df8dcf3cde9d072350b41066e05a4f559b4e9105d0/sphinxcontrib-httpdomain-1.8.1.tar.gz", hash = "sha256:6c2dfe6ca282d75f66df333869bb0ce7331c01b475db6809ff9d107b7cdfe04b", size = 19266 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/49/aad47b8cf27a0d7703f1311aad8c368bb22866ddee1a2d2cd3f69bc45e0c/sphinxcontrib_httpdomain-1.8.1-py2.py3-none-any.whl", hash = "sha256:21eefe1270e4d9de8d717cc89ee92cc4871b8736774393bafc5e38a6bb77b1d5", size = 25513 }, +] + [[package]] name = "sphinxcontrib-jquery" version = "4.1" @@ -3617,6 +3687,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cd/c8/784b9ac6ea08aa594c1a4becbd0dbe77186785362e31fd633b8c6ae0197a/sphinxcontrib_mermaid-1.0.0-py3-none-any.whl", hash = 
"sha256:60b72710ea02087f212028feb09711225fbc2e343a10d34822fe787510e1caa3", size = 9597 }, ] +[[package]] +name = "sphinxcontrib-openapi" +version = "0.8.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "deepmerge" }, + { name = "jsonschema" }, + { name = "picobox" }, + { name = "pyyaml" }, + { name = "sphinx" }, + { name = "sphinx-mdinclude" }, + { name = "sphinxcontrib-httpdomain" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/a7/66a5c9aba7dbbb0c2b050f60e71402818cbf5f127ace13ed971029cc745e/sphinxcontrib-openapi-0.8.4.tar.gz", hash = "sha256:df883808a5b5e4b4113ad697185c43a3f42df3dce70453af78ba7076907e9a20", size = 71848 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/c3/ee00486f38d78309a60ee0d6031b2545b22ac5f0007d841dd174abc68774/sphinxcontrib_openapi-0.8.4-py3-none-any.whl", hash = "sha256:50911c18d452d9390ee3a384ef8dc8bde6135f542ba55691f81e1fbc0b71014e", size = 34510 }, +] + [[package]] name = "sphinxcontrib-qthelp" version = "2.0.0" From e92301f2d7e2645c69f7d829caa98acb3774683b Mon Sep 17 00:00:00 2001 From: Varsha Date: Wed, 21 May 2025 12:24:24 -0700 Subject: [PATCH 17/61] feat(sqlite-vec): enable keyword search for sqlite-vec (#1439) # What does this PR do? This PR introduces support for keyword based FTS5 search with BM25 relevance scoring. It makes changes to the existing EmbeddingIndex base class in order to support a search_mode and query_str parameter, that can be used for keyword based search implementations. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan run ``` pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto ``` Output: ``` pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto /Users/vnarsing/miniconda3/envs/stack-client/lib/python3.10/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. 
Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ====================================================== test session starts ======================================================= platform darwin -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /Users/vnarsing/miniconda3/envs/stack-client/bin/python cachedir: .pytest_cache metadata: {'Python': '3.10.16', 'Platform': 'macOS-14.7.4-arm64-arm-64bit', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0'}} rootdir: /Users/vnarsing/go/src/github/meta-llama/llama-stack configfile: pyproject.toml plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0 asyncio: mode=auto, asyncio_default_fixture_loop_scope=None collected 7 items llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_add_chunks PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_vector PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_fts PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_register_vector_db PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_unregister_vector_db PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED ``` For reference, with the implementation, the fts table looks like below: ``` Chunk ID: 9fbc39ce-c729-64a2-260f-c5ec9bb2a33e, Content: Sentence 0 from document 0 Chunk ID: 94062914-3e23-44cf-1e50-9e25821ba882, Content: Sentence 1 from document 0 Chunk ID: e6cfd559-4641-33ba-6ce1-7038226495eb, Content: Sentence 2 from document 0 Chunk ID: 1383af9b-f1f0-f417-4de5-65fe9456cc20, Content: Sentence 3 from document 0 Chunk ID: 2db19b1a-de14-353b-f4e1-085e8463361c, Content: Sentence 4 from document 0 Chunk ID: 9faf986a-f028-7714-068a-1c795e8f2598, Content: Sentence 5 from document 0 Chunk ID: ef593ead-5a4a-392f-7ad8-471a50f033e8, Content: Sentence 6 from document 0 Chunk ID: e161950f-021f-7300-4d05-3166738b94cf, Content: Sentence 7 from document 0 Chunk ID: 90610fc4-67c1-e740-f043-709c5978867a, Content: Sentence 8 from document 0 Chunk ID: 97712879-6fff-98ad-0558-e9f42e6b81d3, Content: Sentence 9 from document 0 Chunk ID: aea70411-51df-61ba-d2f0-cb2b5972c210, Content: Sentence 0 from document 1 Chunk ID: b678a463-7b84-92b8-abb2-27e9a1977e3c, Content: Sentence 1 from document 1 Chunk ID: 27bd63da-909c-1606-a109-75bdb9479882, Content: Sentence 2 from document 1 Chunk ID: a2ad49ad-f9be-5372-e0c7-7b0221d0b53e, Content: Sentence 3 from document 1 Chunk ID: cac53bcd-1965-082a-c0f4-ceee7323fc70, Content: Sentence 4 from document 1 ``` Query results: Result 1: Sentence 5 from document 0 Result 2: Sentence 5 from document 1 Result 3: Sentence 5 from document 2 [//]: # (## Documentation) --------- Signed-off-by: Varsha Prasad Narsing --- docs/_static/llama-stack-spec.html | 4 + docs/_static/llama-stack-spec.yaml | 4 + docs/source/providers/vector_io/sqlite-vec.md | 19 +++ llama_stack/apis/tools/rag_tool.py | 2 + .../inline/tool_runtime/rag/memory.py | 1 + .../providers/inline/vector_io/faiss/faiss.py | 16 ++- .../inline/vector_io/sqlite_vec/sqlite_vec.py | 133 +++++++++++++++--- .../remote/vector_io/chroma/chroma.py | 8 ++ .../remote/vector_io/milvus/milvus.py | 10 +- .../remote/vector_io/pgvector/pgvector.py | 10 +- 
.../remote/vector_io/qdrant/qdrant.py | 10 +- .../remote/vector_io/weaviate/weaviate.py | 10 +- .../providers/utils/memory/vector_store.py | 19 ++- tests/unit/providers/vector_io/test_qdrant.py | 2 +- .../providers/vector_io/test_sqlite_vec.py | 36 ++++- 15 files changed, 247 insertions(+), 37 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 6adfe9b2b..33befc95e 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -11608,6 +11608,10 @@ "type": "string", "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n", "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\"" + }, + "mode": { + "type": "string", + "description": "Search mode for retrieval—either \"vector\" or \"keyword\". Default \"vector\"." } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 31ca3f52a..cae6331b0 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -8086,6 +8086,10 @@ components: placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" + mode: + type: string + description: >- + Search mode for retrieval—either "vector" or "keyword". Default "vector". additionalProperties: false required: - query_generator_config diff --git a/docs/source/providers/vector_io/sqlite-vec.md b/docs/source/providers/vector_io/sqlite-vec.md index 43d10c751..49ba659f7 100644 --- a/docs/source/providers/vector_io/sqlite-vec.md +++ b/docs/source/providers/vector_io/sqlite-vec.md @@ -66,6 +66,25 @@ To use sqlite-vec in your Llama Stack project, follow these steps: 2. Configure your Llama Stack project to use SQLite-Vec. 3. Start storing and querying vectors. +## Supported Search Modes + +The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes. + +When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in +`RAGQueryConfig`. For example: + +```python +from llama_stack.apis.tool_runtime.rag import RAGQueryConfig + +query_config = RAGQueryConfig(max_chunks=6, mode="vector") + +results = client.tool_runtime.rag_tool.query( + vector_db_ids=[vector_db_id], + content="what is torchtune", + query_config=query_config, +) +``` + ## Installation You can install SQLite-Vec using pip: diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index de3e4c62c..1e3542f74 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -76,6 +76,7 @@ class RAGQueryConfig(BaseModel): :param chunk_template: Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n" + :param mode: Search mode for retrieval—either "vector" or "keyword". Default "vector". 
""" # This config defines how a query is generated using the messages @@ -84,6 +85,7 @@ class RAGQueryConfig(BaseModel): max_tokens_in_context: int = 4096 max_chunks: int = 5 chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" + mode: str | None = None @field_validator("chunk_template") def validate_chunk_template(cls, v: str) -> str: diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index c46960f75..fe16c76b8 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -122,6 +122,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): query=query, params={ "max_chunks": query_config.max_chunks, + "mode": query_config.mode, }, ) for vector_db_id in vector_db_ids diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index d3dc7e694..47256d88d 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -99,9 +99,13 @@ class FaissIndex(EmbeddingIndex): # Save updated index await self._save_index() - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector( + self, + embedding: NDArray, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k) - chunks = [] scores = [] for d, i in zip(distances[0], indices[0], strict=False): @@ -112,6 +116,14 @@ class FaissIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in FAISS") + class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None: diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index ab4384021..fc1a8ddb0 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -24,6 +24,11 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect logger = logging.getLogger(__name__) +# Specifying search mode is dependent on the VectorIO provider. +VECTOR_SEARCH = "vector" +KEYWORD_SEARCH = "keyword" +SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH} + def serialize_vector(vector: list[float]) -> bytes: """Serialize a list of floats into a compact binary representation.""" @@ -45,6 +50,7 @@ class SQLiteVecIndex(EmbeddingIndex): Two tables are used: - A metadata table (chunks_{bank_id}) that holds the chunk JSON. - A virtual table (vec_chunks_{bank_id}) that holds the serialized vector. + - An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search. 
""" def __init__(self, dimension: int, db_path: str, bank_id: str): @@ -53,6 +59,7 @@ class SQLiteVecIndex(EmbeddingIndex): self.bank_id = bank_id self.metadata_table = f"chunks_{bank_id}".replace("-", "_") self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_") + self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_") @classmethod async def create(cls, dimension: int, db_path: str, bank_id: str): @@ -78,6 +85,14 @@ class SQLiteVecIndex(EmbeddingIndex): USING vec0(embedding FLOAT[{self.dimension}], id TEXT); """) connection.commit() + # FTS5 table (for keyword search) - creating both the tables by default. Will use the relevant one + # based on query. Implementation of the change on client side will allow passing the search_mode option + # during initialization to make it easier to create the table that is required. + cur.execute(f""" + CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table} + USING fts5(id, content); + """) + connection.commit() finally: cur.close() connection.close() @@ -91,6 +106,7 @@ class SQLiteVecIndex(EmbeddingIndex): try: cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};") cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};") + cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};") connection.commit() finally: cur.close() @@ -104,6 +120,7 @@ class SQLiteVecIndex(EmbeddingIndex): For each chunk, we insert its JSON into the metadata table and then insert its embedding (serialized to raw bytes) into the virtual table using the assigned rowid. If any insert fails, the transaction is rolled back to maintain consistency. + Also inserts chunk content into FTS table for keyword search support. """ assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks" @@ -112,18 +129,16 @@ class SQLiteVecIndex(EmbeddingIndex): cur = connection.cursor() try: - # Start transaction a single transcation for all batches cur.execute("BEGIN TRANSACTION") for i in range(0, len(chunks), batch_size): batch_chunks = chunks[i : i + batch_size] batch_embeddings = embeddings[i : i + batch_size] - # Prepare metadata inserts + + # Insert metadata metadata_data = [ (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json()) for chunk in batch_chunks - if isinstance(chunk.content, str) ] - # Insert metadata (ON CONFLICT to avoid duplicates) cur.executemany( f""" INSERT INTO {self.metadata_table} (id, chunk) @@ -132,21 +147,43 @@ class SQLiteVecIndex(EmbeddingIndex): """, metadata_data, ) - # Prepare embeddings inserts + + # Insert vector embeddings embedding_data = [ ( - generate_chunk_id(chunk.metadata["document_id"], chunk.content), - serialize_vector(emb.tolist()), + ( + generate_chunk_id(chunk.metadata["document_id"], chunk.content), + serialize_vector(emb.tolist()), + ) ) for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True) - if isinstance(chunk.content, str) ] - # Insert embeddings in batch - cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data) + cur.executemany( + f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", + embedding_data, + ) + + # Insert FTS content + fts_data = [ + (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.content) + for chunk in batch_chunks + ] + # DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT) + cur.executemany( + f"DELETE FROM {self.fts_table} WHERE id = ?;", + [(row[0],) for row in fts_data], + ) + + # INSERT new entries + cur.executemany( + f"INSERT INTO 
{self.fts_table} (id, content) VALUES (?, ?);", + fts_data, + ) + connection.commit() except sqlite3.Error as e: - connection.rollback() # Rollback on failure + connection.rollback() logger.error(f"Error inserting into {self.vector_table}: {e}") raise @@ -154,22 +191,25 @@ class SQLiteVecIndex(EmbeddingIndex): cur.close() connection.close() - # Process all batches in a single thread + # Run batch insertion in a background thread await asyncio.to_thread(_execute_all_batch_inserts) - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector( + self, + embedding: NDArray, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: """ - Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query - against the virtual table. The SQL joins the metadata table to recover the chunk JSON. + Performs vector-based search using a virtual table for vector similarity. """ - emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding) - emb_blob = serialize_vector(emb_list) def _execute_query(): connection = _create_sqlite_connection(self.db_path) cur = connection.cursor() - try: + emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding) + emb_blob = serialize_vector(emb_list) query_sql = f""" SELECT m.id, m.chunk, v.distance FROM {self.vector_table} AS v @@ -184,17 +224,66 @@ class SQLiteVecIndex(EmbeddingIndex): connection.close() rows = await asyncio.to_thread(_execute_query) - chunks, scores = [], [] - for _id, chunk_json, distance in rows: + for row in rows: + _id, chunk_json, distance = row + score = 1.0 / distance if distance != 0 else float("inf") + if score < score_threshold: + continue + try: + chunk = Chunk.model_validate_json(chunk_json) + except Exception as e: + logger.error(f"Error parsing chunk JSON for id {_id}: {e}") + continue + chunks.append(chunk) + scores.append(score) + return QueryChunksResponse(chunks=chunks, scores=scores) + + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + """ + Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search. + """ + if query_string is None: + raise ValueError("query_string is required for keyword search.") + + def _execute_query(): + connection = _create_sqlite_connection(self.db_path) + cur = connection.cursor() + try: + query_sql = f""" + SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score + FROM {self.fts_table} AS f + JOIN {self.metadata_table} AS m ON m.id = f.id + WHERE f.content MATCH ? + ORDER BY score ASC + LIMIT ?; + """ + cur.execute(query_sql, (query_string, k)) + return cur.fetchall() + finally: + cur.close() + connection.close() + + rows = await asyncio.to_thread(_execute_query) + chunks, scores = [], [] + for row in rows: + _id, chunk_json, score = row + # BM25 scores returned by sqlite-vec are NEGATED (i.e., more relevant = more negative). + # This design is intentional to simplify sorting by ascending score. 
+ # Reference: https://alexgarcia.xyz/blog/2024/sqlite-vec-hybrid-search/index.html + if score > -score_threshold: + continue try: chunk = Chunk.model_validate_json(chunk_json) except Exception as e: logger.error(f"Error parsing chunk JSON for id {_id}: {e}") continue chunks.append(chunk) - # Mimic the Faiss scoring: score = 1/distance (avoid division by zero) - score = 1.0 / distance if distance != 0 else float("inf") scores.append(score) return QueryChunksResponse(chunks=chunks, scores=scores) diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index a919963ab..a59a38573 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -84,6 +84,14 @@ class ChromaIndex(EmbeddingIndex): async def delete(self): await maybe_await(self.client.delete_collection(self.collection.name)) + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in Chroma") + class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__( diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index c98417b56..6628292db 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -73,7 +73,7 @@ class MilvusIndex(EmbeddingIndex): logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}") raise e - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: search_res = await asyncio.to_thread( self.client.search, collection_name=self.collection_name, @@ -86,6 +86,14 @@ class MilvusIndex(EmbeddingIndex): scores = [res["distance"] for res in search_res[0]] return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in Milvus") + class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__( diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index 94546c6cf..ea918c552 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: execute_values(cur, query, values, template="(%s, %s, %s::vector)") - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute( f""" @@ -120,6 +120,14 @@ class PGVectorIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in PGVector") + async def delete(self): with 
self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 514a6c70d..ff0690083 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -68,7 +68,7 @@ class QdrantIndex(EmbeddingIndex): await self.client.upsert(collection_name=self.collection_name, points=points) - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: results = ( await self.client.query_points( collection_name=self.collection_name, @@ -95,6 +95,14 @@ class QdrantIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in Qdrant") + async def delete(self): await self.client.delete_collection(collection_name=self.collection_name) diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 308d2eb3d..e6fe8ccd3 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -55,7 +55,7 @@ class WeaviateIndex(EmbeddingIndex): # TODO: make this async friendly collection.data.insert_many(data_objects) - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: collection = self.client.collections.get(self.collection_name) results = collection.query.near_vector( @@ -84,6 +84,14 @@ class WeaviateIndex(EmbeddingIndex): collection = self.client.collections.get(self.collection_name) collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids)) + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in Weaviate") + class WeaviateVectorIOAdapter( VectorIO, diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index e0e9d0679..3655c7049 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -177,7 +177,11 @@ class EmbeddingIndex(ABC): raise NotImplementedError() @abstractmethod - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + raise NotImplementedError() + + @abstractmethod + async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse: raise NotImplementedError() @abstractmethod @@ -210,9 +214,12 @@ class VectorDBWithIndex: if params is None: params = {} k = params.get("max_chunks", 3) + mode = params.get("mode") score_threshold = params.get("score_threshold", 0.0) - - query_str = interleaved_content_as_str(query) - embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_str]) - query_vector = 
np.array(embeddings_response.embeddings[0], dtype=np.float32) - return await self.index.query(query_vector, k, score_threshold) + query_string = interleaved_content_as_str(query) + if mode == "keyword": + return await self.index.query_keyword(query_string, k, score_threshold) + else: + embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string]) + query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) + return await self.index.query_vector(query_vector, k, score_threshold) diff --git a/tests/unit/providers/vector_io/test_qdrant.py b/tests/unit/providers/vector_io/test_qdrant.py index bc97719c0..34df9b52f 100644 --- a/tests/unit/providers/vector_io/test_qdrant.py +++ b/tests/unit/providers/vector_io/test_qdrant.py @@ -98,7 +98,7 @@ async def test_qdrant_adapter_returns_expected_chunks( response = await qdrant_adapter.query_chunks( query=__QUERY, vector_db_id=vector_db_id, - params={"max_chunks": max_query_chunks}, + params={"max_chunks": max_query_chunks, "mode": "vector"}, ) assert isinstance(response, QueryChunksResponse) assert len(response.chunks) == expected_chunks diff --git a/tests/unit/providers/vector_io/test_sqlite_vec.py b/tests/unit/providers/vector_io/test_sqlite_vec.py index 32b60ffa5..010a0ca42 100644 --- a/tests/unit/providers/vector_io/test_sqlite_vec.py +++ b/tests/unit/providers/vector_io/test_sqlite_vec.py @@ -57,14 +57,46 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings): @pytest.mark.asyncio -async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension): +async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension): await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) query_embedding = np.random.rand(embedding_dimension).astype(np.float32) - response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0) + response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0) assert isinstance(response, QueryChunksResponse) assert len(response.chunks) == 2 +@pytest.mark.asyncio +async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings): + await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) + + query_string = "Sentence 5" + response = await sqlite_vec_index.query_keyword(k=3, score_threshold=0.0, query_string=query_string) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 3, f"Expected three chunks, but got {len(response.chunks)}" + + non_existent_query_str = "blablabla" + response_no_results = await sqlite_vec_index.query_keyword( + query_string=non_existent_query_str, k=1, score_threshold=0.0 + ) + + assert isinstance(response_no_results, QueryChunksResponse) + assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}" + + +@pytest.mark.asyncio +async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings): + # Re-initialize with a clean index + await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) + + query_str = "Sentence 1 from document 0" # Should match only one chunk + response = await sqlite_vec_index.query_keyword(k=5, score_threshold=0.0, query_string=query_str) + + assert isinstance(response, QueryChunksResponse) + assert 0 < len(response.chunks) < 5, f"Expected results between [1, 4], got {len(response.chunks)}" + 
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found" + + @pytest.mark.asyncio async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension): """Test that chunk IDs do not conflict across batches when inserting chunks.""" From 37f1e8a7f7a6b930b77278a7d0e27fd735ddd8a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 22 May 2025 00:28:21 +0200 Subject: [PATCH 18/61] fix: use proper service account for kube auth (#2227) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Not sure why it passed CI earlier... Strange only 24 workflows run here https://github.com/meta-llama/llama-stack/pull/2216 so the test never ran... Signed-off-by: Sébastien Han --- .github/workflows/integration-auth-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index 25f696c9e..a3a746246 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -83,7 +83,7 @@ jobs: echo "KUBERNETES_API_SERVER_URL=$(kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri)" >> $GITHUB_ENV echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV - echo "KUBERNETES_AUDIENCE=$(kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV + echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV - name: Set Kube Auth Config and run server env: From 02e5e8a633921792a8e5bf8ebf9e3e5dcd058964 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 22 May 2025 00:30:29 +0200 Subject: [PATCH 19/61] fix: only print routes that match the runtime config (#2226) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? We now only print the 'active' routes, not all the possible routes. This is based on the distribution server config by looking at enabled APIs and their respective providers. 
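For illustration, here is a minimal, self-contained sketch of the filtering rule (the endpoint map and provider names below are made up for the example and are not the actual stack registries): internal APIs such as `inspect` and `providers` are always listed, while every other API is listed only when the run config enables at least one provider for it.

```python
# Illustrative sketch of the route-filtering rule, using made-up data.
all_endpoints = {
    "inference": ["/v1/inference/chat-completion"],
    "safety": ["/v1/safety/run-shield"],
    "inspect": ["/v1/inspect/routes"],
}
run_config_providers = {"inference": ["remote::ollama"]}  # no safety provider configured

active_routes = []
for api, routes in all_endpoints.items():
    if api in ("providers", "inspect"):
        # internal APIs are always listed, with no "real" provider attached
        active_routes.extend(routes)
    elif run_config_providers.get(api):
        # other APIs are listed only if the run config enables a provider for them
        active_routes.extend(routes)

print(active_routes)
# ['/v1/inference/chat-completion', '/v1/inspect/routes']
```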
Signed-off-by: Sébastien Han --- llama_stack/distribution/inspect.py | 38 ++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index 23f644ec6..3321ec291 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -31,7 +31,7 @@ async def get_provider_impl(config, deps): class DistributionInspectImpl(Inspect): - def __init__(self, config, deps): + def __init__(self, config: DistributionInspectConfig, deps): self.config = config self.deps = deps @@ -39,22 +39,36 @@ class DistributionInspectImpl(Inspect): pass async def list_routes(self) -> ListRoutesResponse: - run_config = self.config.run_config + run_config: StackRunConfig = self.config.run_config ret = [] all_endpoints = get_all_api_endpoints() for api, endpoints in all_endpoints.items(): - providers = run_config.providers.get(api.value, []) - ret.extend( - [ - RouteInfo( - route=e.route, - method=e.method, - provider_types=[p.provider_type for p in providers], + # Always include provider and inspect APIs, filter others based on run config + if api.value in ["providers", "inspect"]: + ret.extend( + [ + RouteInfo( + route=e.route, + method=e.method, + provider_types=[], # These APIs don't have "real" providers - they're internal to the stack + ) + for e in endpoints + ] + ) + else: + providers = run_config.providers.get(api.value, []) + if providers: # Only process if there are providers for this API + ret.extend( + [ + RouteInfo( + route=e.route, + method=e.method, + provider_types=[p.provider_type for p in providers], + ) + for e in endpoints + ] ) - for e in endpoints - ] - ) return ListRoutesResponse(data=ret) From 633bb9c5b3273762ccf0d91845243bb11166066d Mon Sep 17 00:00:00 2001 From: Jorge Piedrahita Ortiz Date: Wed, 21 May 2025 17:33:02 -0500 Subject: [PATCH 20/61] feat(providers): sambanova safety provider (#2221) # What does this PR do? 
Includes SambaNova safety adaptor to use the sambanova cloud served Meta-Llama-Guard-3-8B minor updates in sambanova docs ## Test Plan pytest -s -v tests/integration/safety/test_safety.py --stack-config=sambanova --safety-shield=sambanova/Meta-Llama-Guard-3-8B --- README.md | 2 +- .../self_hosted_distro/sambanova.md | 25 +++-- llama_stack/providers/registry/safety.py | 10 ++ .../remote/safety/sambanova/__init__.py | 18 ++++ .../remote/safety/sambanova/config.py | 37 +++++++ .../remote/safety/sambanova/sambanova.py | 100 ++++++++++++++++++ llama_stack/templates/sambanova/build.yaml | 4 +- .../templates/sambanova/doc_template.md | 23 ++-- llama_stack/templates/sambanova/run.yaml | 10 +- llama_stack/templates/sambanova/sambanova.py | 14 ++- pyproject.toml | 1 + 11 files changed, 222 insertions(+), 22 deletions(-) create mode 100644 llama_stack/providers/remote/safety/sambanova/__init__.py create mode 100644 llama_stack/providers/remote/safety/sambanova/config.py create mode 100644 llama_stack/providers/remote/safety/sambanova/sambanova.py diff --git a/README.md b/README.md index 5dfe3577a..e54b505cf 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ Here is a list of the various API providers and available distributions that can | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | |:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:| | Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | -| SambaNova | Hosted | | ✅ | | | | +| SambaNova | Hosted | | ✅ | | ✅ | | | Cerebras | Hosted | | ✅ | | | | | Fireworks | Hosted | ✅ | ✅ | ✅ | | | | AWS Bedrock | Hosted | | ✅ | | ✅ | | diff --git a/docs/source/distributions/self_hosted_distro/sambanova.md b/docs/source/distributions/self_hosted_distro/sambanova.md index aaa8fd3cc..bb4842362 100644 --- a/docs/source/distributions/self_hosted_distro/sambanova.md +++ b/docs/source/distributions/self_hosted_distro/sambanova.md @@ -17,7 +17,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p |-----|-------------| | agents | `inline::meta-reference` | | inference | `remote::sambanova`, `inline::sentence-transformers` | -| safety | `inline::llama-guard` | +| safety | `remote::sambanova` | | telemetry | `inline::meta-reference` | | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` | | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | @@ -48,33 +48,44 @@ The following models are available by default: ### Prerequisite: API Keys -Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/). +Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup). ## Running Llama Stack with SambaNova You can do this via Conda (build code) or Docker which has a pre-built image. -### Via Docker -This method allows you to get started quickly without having to build the distribution code. 
+### Via Docker ```bash LLAMA_STACK_PORT=8321 +llama stack build --template sambanova --image-type container docker run \ -it \ - --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - llamastack/distribution-sambanova \ + -v ~/.llama:/root/.llama \ + distribution-sambanova \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` + +### Via Venv + +```bash +llama stack build --template sambanova --image-type venv +llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY +``` + + ### Via Conda ```bash llama stack build --template sambanova --image-type conda -llama stack run ./run.yaml \ +llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index c209da092..e0a04be48 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -63,4 +63,14 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", ), ), + remote_provider_spec( + api=Api.safety, + adapter=AdapterSpec( + adapter_type="sambanova", + pip_packages=["litellm"], + module="llama_stack.providers.remote.safety.sambanova", + config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", + provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator", + ), + ), ] diff --git a/llama_stack/providers/remote/safety/sambanova/__init__.py b/llama_stack/providers/remote/safety/sambanova/__init__.py new file mode 100644 index 000000000..bb9d15374 --- /dev/null +++ b/llama_stack/providers/remote/safety/sambanova/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from typing import Any + +from .config import SambaNovaSafetyConfig + + +async def get_adapter_impl(config: SambaNovaSafetyConfig, _deps) -> Any: + from .sambanova import SambaNovaSafetyAdapter + + impl = SambaNovaSafetyAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/safety/sambanova/config.py b/llama_stack/providers/remote/safety/sambanova/config.py new file mode 100644 index 000000000..383cea244 --- /dev/null +++ b/llama_stack/providers/remote/safety/sambanova/config.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from pydantic import BaseModel, Field, SecretStr + +from llama_stack.schema_utils import json_schema_type + + +class SambaNovaProviderDataValidator(BaseModel): + sambanova_api_key: str | None = Field( + default=None, + description="Sambanova Cloud API key", + ) + + +@json_schema_type +class SambaNovaSafetyConfig(BaseModel): + url: str = Field( + default="https://api.sambanova.ai/v1", + description="The URL for the SambaNova AI server", + ) + api_key: SecretStr | None = Field( + default=None, + description="The SambaNova cloud API Key", + ) + + @classmethod + def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]: + return { + "url": "https://api.sambanova.ai/v1", + "api_key": api_key, + } diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py new file mode 100644 index 000000000..84c8267ae --- /dev/null +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import logging +from typing import Any + +import litellm +import requests + +from llama_stack.apis.inference import Message +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) +from llama_stack.apis.shields import Shield +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.datatypes import ShieldsProtocolPrivate +from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new + +from .config import SambaNovaSafetyConfig + +logger = logging.getLogger(__name__) + +CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" 
+ + +class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData): + def __init__(self, config: SambaNovaSafetyConfig) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + def _get_api_key(self) -> str: + config_api_key = self.config.api_key if self.config.api_key else None + if config_api_key: + return config_api_key.get_secret_value() + else: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.sambanova_api_key: + raise ValueError( + 'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": }' + ) + return provider_data.sambanova_api_key + + async def register_shield(self, shield: Shield) -> None: + list_models_url = self.config.url + "/models" + try: + response = requests.get(list_models_url) + response.raise_for_status() + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Request to {list_models_url} failed") from e + available_models = [model.get("id") for model in response.json().get("data", {})] + if ( + len(available_models) == 0 + or "guard" not in shield.provider_resource_id.lower() + or shield.provider_resource_id.split("sambanova/")[-1] not in available_models + ): + raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova") + + async def run_shield( + self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None + ) -> RunShieldResponse: + shield = await self.shield_store.get_shield(shield_id) + if not shield: + raise ValueError(f"Shield {shield_id} not found") + + shield_params = shield.params + logger.debug(f"run_shield::{shield_params}::messages={messages}") + content_messages = [await convert_message_to_openai_dict_new(m) for m in messages] + logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") + + response = litellm.completion( + model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key() + ) + shield_message = response.choices[0].message.content + + if "unsafe" in shield_message.lower(): + user_message = CANNED_RESPONSE_TEXT + violation_type = shield_message.split("\n")[-1] + metadata = {"violation_type": violation_type} + + return RunShieldResponse( + violation=SafetyViolation( + user_message=user_message, + violation_level=ViolationLevel.ERROR, + metadata=metadata, + ) + ) + + return RunShieldResponse() diff --git a/llama_stack/templates/sambanova/build.yaml b/llama_stack/templates/sambanova/build.yaml index 81d90f420..79bb68c68 100644 --- a/llama_stack/templates/sambanova/build.yaml +++ b/llama_stack/templates/sambanova/build.yaml @@ -1,6 +1,6 @@ version: '2' distribution_spec: - description: Use SambaNova for running LLM inference + description: Use SambaNova for running LLM inference and safety providers: inference: - remote::sambanova @@ -10,7 +10,7 @@ distribution_spec: - remote::chromadb - remote::pgvector safety: - - inline::llama-guard + - remote::sambanova agents: - inline::meta-reference telemetry: diff --git a/llama_stack/templates/sambanova/doc_template.md b/llama_stack/templates/sambanova/doc_template.md index 42d9efb66..1dc76fd3f 100644 --- a/llama_stack/templates/sambanova/doc_template.md +++ b/llama_stack/templates/sambanova/doc_template.md @@ -37,33 +37,44 @@ The following models are available by default: ### Prerequisite: API Keys -Make sure you have access to a SambaNova API Key. 
You can get one by visiting [SambaNova.ai](https://sambanova.ai/). +Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup). ## Running Llama Stack with SambaNova You can do this via Conda (build code) or Docker which has a pre-built image. -### Via Docker -This method allows you to get started quickly without having to build the distribution code. +### Via Docker ```bash LLAMA_STACK_PORT=8321 +llama stack build --template sambanova --image-type container docker run \ -it \ - --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - llamastack/distribution-{{ name }} \ + -v ~/.llama:/root/.llama \ + distribution-{{ name }} \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` + +### Via Venv + +```bash +llama stack build --template sambanova --image-type venv +llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY +``` + + ### Via Conda ```bash llama stack build --template sambanova --image-type conda -llama stack run ./run.yaml \ +llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index 620d50307..fa8735002 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -38,10 +38,11 @@ providers: user: ${env.PGVECTOR_USER:} password: ${env.PGVECTOR_PASSWORD:} safety: - - provider_id: llama-guard - provider_type: inline::llama-guard + - provider_id: sambanova + provider_type: remote::sambanova config: - excluded_categories: [] + url: https://api.sambanova.ai/v1 + api_key: ${env.SAMBANOVA_API_KEY} agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -189,6 +190,9 @@ models: model_type: embedding shields: - shield_id: meta-llama/Llama-Guard-3-8B + provider_shield_id: sambanova/Meta-Llama-Guard-3-8B +- shield_id: sambanova/Meta-Llama-Guard-3-8B + provider_shield_id: sambanova/Meta-Llama-Guard-3-8B vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py index 2f8a0b08a..54a49423d 100644 --- a/llama_stack/templates/sambanova/sambanova.py +++ b/llama_stack/templates/sambanova/sambanova.py @@ -34,7 +34,7 @@ def get_distribution_template() -> DistributionTemplate: providers = { "inference": ["remote::sambanova", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], - "safety": ["inline::llama-guard"], + "safety": ["remote::sambanova"], "agents": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"], "tool_runtime": [ @@ -110,7 +110,7 @@ def get_distribution_template() -> DistributionTemplate: return DistributionTemplate( name=name, distro_type="self_hosted", - description="Use SambaNova for running LLM inference", + description="Use SambaNova for running LLM inference and safety", container_image=None, template_path=Path(__file__).parent / "doc_template.md", providers=providers, @@ -122,7 +122,15 @@ def get_distribution_template() -> DistributionTemplate: "vector_io": vector_io_providers, }, default_models=default_models + [embedding_model], - default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], 
+ default_shields=[ + ShieldInput( + shield_id="meta-llama/Llama-Guard-3-8B", provider_shield_id="sambanova/Meta-Llama-Guard-3-8B" + ), + ShieldInput( + shield_id="sambanova/Meta-Llama-Guard-3-8B", + provider_shield_id="sambanova/Meta-Llama-Guard-3-8B", + ), + ], default_tool_groups=default_tool_groups, ), }, diff --git a/pyproject.toml b/pyproject.toml index ce44479ca..6b873968a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -270,6 +270,7 @@ exclude = [ "^llama_stack/providers/remote/inference/watsonx/", "^llama_stack/providers/remote/safety/bedrock/", "^llama_stack/providers/remote/safety/nvidia/", + "^llama_stack/providers/remote/safety/sambanova/", "^llama_stack/providers/remote/safety/sample/", "^llama_stack/providers/remote/tool_runtime/bing_search/", "^llama_stack/providers/remote/tool_runtime/brave_search/", From 549812f51e793566df0d730160ecf87bf26d9e7d Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 21 May 2025 22:21:52 -0700 Subject: [PATCH 21/61] feat: implement get chat completions APIs (#2200) # What does this PR do? * Provide sqlite implementation of the APIs introduced in https://github.com/meta-llama/llama-stack/pull/2145. * Introduced a SqlStore API: llama_stack/providers/utils/sqlstore/api.py and the first Sqlite implementation * Pagination support will be added in a future PR. ## Test Plan Unit test on sql store: image Integration test: ``` INFERENCE_MODEL="llama3.2:3b-instruct-fp16" llama stack build --template ollama --image-type conda --run ``` ``` LLAMA_STACK_CONFIG=http://localhost:5001 INFERENCE_MODEL="llama3.2:3b-instruct-fp16" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-fp16" -k 'inference_store and openai' ``` --- llama_stack/distribution/build.py | 15 ++ llama_stack/distribution/datatypes.py | 12 ++ llama_stack/distribution/resolver.py | 12 +- llama_stack/distribution/routers/__init__.py | 12 +- llama_stack/distribution/routers/routers.py | 33 +++- .../utils/inference/inference_store.py | 123 +++++++++++++ .../providers/utils/inference/stream_utils.py | 129 ++++++++++++++ llama_stack/providers/utils/sqlstore/api.py | 90 ++++++++++ .../providers/utils/sqlstore/sqlite/sqlite.py | 161 ++++++++++++++++++ .../providers/utils/sqlstore/sqlstore.py | 72 ++++++++ llama_stack/templates/bedrock/build.yaml | 2 + llama_stack/templates/bedrock/run.yaml | 3 + llama_stack/templates/cerebras/build.yaml | 2 + llama_stack/templates/cerebras/run.yaml | 3 + llama_stack/templates/ci-tests/build.yaml | 2 + llama_stack/templates/ci-tests/run.yaml | 3 + llama_stack/templates/dell/build.yaml | 3 + .../templates/dell/run-with-safety.yaml | 3 + llama_stack/templates/dell/run.yaml | 3 + llama_stack/templates/dependencies.json | 22 +++ llama_stack/templates/fireworks/build.yaml | 3 + .../templates/fireworks/run-with-safety.yaml | 3 + llama_stack/templates/fireworks/run.yaml | 3 + llama_stack/templates/groq/build.yaml | 2 + llama_stack/templates/groq/run.yaml | 3 + llama_stack/templates/hf-endpoint/build.yaml | 3 + .../hf-endpoint/run-with-safety.yaml | 3 + llama_stack/templates/hf-endpoint/run.yaml | 3 + .../templates/hf-serverless/build.yaml | 3 + .../hf-serverless/run-with-safety.yaml | 3 + llama_stack/templates/hf-serverless/run.yaml | 3 + llama_stack/templates/llama_api/build.yaml | 2 + llama_stack/templates/llama_api/run.yaml | 3 + .../templates/meta-reference-gpu/build.yaml | 3 + .../meta-reference-gpu/run-with-safety.yaml | 3 + .../templates/meta-reference-gpu/run.yaml | 3 + 
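For orientation, the end-to-end flow this patch enables looks roughly like the following. This is an illustrative sketch only: the port, the OpenAI-compatible base URL path, and the model name are assumptions rather than values taken from the patch.

```python
# Illustrative sketch (assumptions: server on port 8321, OpenAI-compatible route
# at /v1/openai/v1, and an Ollama-served model as in the test plan above).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
resp = client.chat.completions.create(
    model="llama3.2:3b-instruct-fp16",
    messages=[{"role": "user", "content": "Hello, world!"}],
)
# With an inference_store configured, this completion and its input messages are
# persisted server-side and can later be read back through the new list/get APIs.
print(resp.id)
```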
llama_stack/templates/nvidia/build.yaml | 3 + .../templates/nvidia/run-with-safety.yaml | 3 + llama_stack/templates/nvidia/run.yaml | 3 + llama_stack/templates/ollama/build.yaml | 3 + .../templates/ollama/run-with-safety.yaml | 3 + llama_stack/templates/ollama/run.yaml | 3 + .../templates/open-benchmark/build.yaml | 2 + llama_stack/templates/open-benchmark/run.yaml | 3 + llama_stack/templates/passthrough/build.yaml | 3 + .../passthrough/run-with-safety.yaml | 3 + llama_stack/templates/passthrough/run.yaml | 3 + llama_stack/templates/remote-vllm/build.yaml | 3 + .../remote-vllm/run-with-safety.yaml | 3 + llama_stack/templates/remote-vllm/run.yaml | 3 + llama_stack/templates/sambanova/build.yaml | 2 + llama_stack/templates/sambanova/run.yaml | 3 + llama_stack/templates/starter/build.yaml | 2 + llama_stack/templates/starter/run.yaml | 3 + llama_stack/templates/template.py | 13 +- llama_stack/templates/tgi/build.yaml | 3 + .../templates/tgi/run-with-safety.yaml | 3 + llama_stack/templates/tgi/run.yaml | 3 + llama_stack/templates/together/build.yaml | 3 + .../templates/together/run-with-safety.yaml | 3 + llama_stack/templates/together/run.yaml | 3 + llama_stack/templates/verification/build.yaml | 2 + llama_stack/templates/verification/run.yaml | 3 + llama_stack/templates/vllm-gpu/build.yaml | 2 + llama_stack/templates/vllm-gpu/run.yaml | 3 + llama_stack/templates/watsonx/build.yaml | 2 + llama_stack/templates/watsonx/run.yaml | 3 + pyproject.toml | 2 + .../inference/test_openai_completion.py | 102 +++++++++++ tests/unit/utils/test_sqlstore.py | 62 +++++++ uv.lock | 107 +++++++++++- 71 files changed, 1111 insertions(+), 10 deletions(-) create mode 100644 llama_stack/providers/utils/inference/inference_store.py create mode 100644 llama_stack/providers/utils/inference/stream_utils.py create mode 100644 llama_stack/providers/utils/sqlstore/api.py create mode 100644 llama_stack/providers/utils/sqlstore/sqlite/sqlite.py create mode 100644 llama_stack/providers/utils/sqlstore/sqlstore.py create mode 100644 tests/unit/utils/test_sqlstore.py diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index 1d39063f0..3e9dc2028 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -43,8 +43,20 @@ def get_provider_dependencies( # Extract providers based on config type if isinstance(config, DistributionTemplate): providers = config.providers + + # TODO: This is a hack to get the dependencies for internal APIs into build + # We should have a better way to do this by formalizing the concept of "internal" APIs + # and providers, with a way to specify dependencies for them. 
+ run_configs = config.run_configs + additional_pip_packages: list[str] = [] + if run_configs: + for run_config in run_configs.values(): + run_config_ = run_config.run_config(name="", providers={}, container_image=None) + if run_config_.inference_store: + additional_pip_packages.extend(run_config_.inference_store.pip_packages) elif isinstance(config, BuildConfig): providers = config.distribution_spec.providers + additional_pip_packages = config.additional_pip_packages deps = [] registry = get_provider_registry(config) for api_str, provider_or_providers in providers.items(): @@ -72,6 +84,9 @@ def get_provider_dependencies( else: normal_deps.append(package) + if additional_pip_packages: + normal_deps.extend(additional_pip_packages) + return list(set(normal_deps)), list(set(special_deps)) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index eb790ad93..aeb2b997a 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -26,6 +26,7 @@ from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig +from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig LLAMA_STACK_BUILD_CONFIG_VERSION = "2" LLAMA_STACK_RUN_CONFIG_VERSION = "2" @@ -314,6 +315,13 @@ Configuration for the persistence store used by the distribution registry. If no a default SQLite store will be used.""", ) + inference_store: SqlStoreConfig | None = Field( + default=None, + description=""" +Configuration for the persistence store used by the inference API. If not specified, +a default SQLite store will be used.""", + ) + # registry of "resources" in the distribution models: list[ModelInput] = Field(default_factory=list) shields: list[ShieldInput] = Field(default_factory=list) @@ -362,6 +370,10 @@ class BuildConfig(BaseModel): description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. " "pip_packages MUST contain the provider package name.", ) + additional_pip_packages: list[str] = Field( + default_factory=list, + description="Additional pip packages to install in the distribution. 
These packages will be installed in the distribution environment.", + ) @field_validator("external_providers_dir") @classmethod diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 257c495c3..8b846d051 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -140,7 +140,7 @@ async def resolve_impls( sorted_providers = sort_providers_by_deps(providers_with_specs, run_config) - return await instantiate_providers(sorted_providers, router_apis, dist_registry) + return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config) def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]: @@ -243,7 +243,10 @@ def sort_providers_by_deps( async def instantiate_providers( - sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry + sorted_providers: list[tuple[str, ProviderWithSpec]], + router_apis: set[Api], + dist_registry: DistributionRegistry, + run_config: StackRunConfig, ) -> dict: """Instantiates providers asynchronously while managing dependencies.""" impls: dict[Api, Any] = {} @@ -258,7 +261,7 @@ async def instantiate_providers( if isinstance(provider.spec, RoutingTableProviderSpec): inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"] - impl = await instantiate_provider(provider, deps, inner_impls, dist_registry) + impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config) if api_str.startswith("inner-"): inner_impls_by_provider_id[api_str][provider.provider_id] = impl @@ -308,6 +311,7 @@ async def instantiate_provider( deps: dict[Api, Any], inner_impls: dict[str, Any], dist_registry: DistributionRegistry, + run_config: StackRunConfig, ): provider_spec = provider.spec if not hasattr(provider_spec, "module"): @@ -327,7 +331,7 @@ async def instantiate_provider( method = "get_auto_router_impl" config = None - args = [provider_spec.api, deps[provider_spec.routing_table_api], deps] + args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config] elif isinstance(provider_spec, RoutingTableProviderSpec): method = "get_routing_table_impl" diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index cd2a296f2..84560b355 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -7,8 +7,10 @@ from typing import Any from llama_stack.distribution.datatypes import RoutedProtocol +from llama_stack.distribution.stack import StackRunConfig from llama_stack.distribution.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable +from llama_stack.providers.utils.inference.inference_store import InferenceStore from .routing_tables import ( BenchmarksRoutingTable, @@ -45,7 +47,9 @@ async def get_routing_table_impl( return impl -async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any: +async def get_auto_router_impl( + api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig +) -> Any: from .routers import ( DatasetIORouter, EvalRouter, @@ -76,6 +80,12 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict if dep_api in deps: api_to_dep_impl[dep_name] = deps[dep_api] + # TODO: move pass configs to routers instead + if api == Api.inference and run_config.inference_store: + 
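The `inference_store` field added above is just a `SqlStoreConfig`; a minimal sketch of what the SQLite variant resolves to (class and property names are the ones introduced later in this patch, the path is illustrative):

```python
# Minimal sketch: the SQLite flavour of the new inference_store config.
# db_path is illustrative; engine_str/pip_packages come from SqliteSqlStoreConfig.
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

cfg = SqliteSqlStoreConfig(db_path="~/.llama/distributions/ollama/inference_store.db")
print(cfg.engine_str)    # sqlite+aiosqlite:///<expanded path>
print(cfg.pip_packages)  # ['sqlalchemy[asyncio]']
```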
inference_store = InferenceStore(run_config.inference_store) + await inference_store.initialize() + api_to_dep_impl["store"] = inference_store + impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 371d34904..0515b19f8 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -32,8 +32,11 @@ from llama_stack.apis.inference import ( EmbeddingsResponse, EmbeddingTaskType, Inference, + ListOpenAIChatCompletionResponse, LogProbConfig, Message, + OpenAICompletionWithInputMessages, + Order, ResponseFormat, SamplingParams, StopReason, @@ -73,6 +76,8 @@ from llama_stack.log import get_logger from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable +from llama_stack.providers.utils.inference.inference_store import InferenceStore +from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion from llama_stack.providers.utils.telemetry.tracing import get_current_span logger = get_logger(name=__name__, category="core") @@ -141,10 +146,12 @@ class InferenceRouter(Inference): self, routing_table: RoutingTable, telemetry: Telemetry | None = None, + store: InferenceStore | None = None, ) -> None: logger.debug("Initializing InferenceRouter") self.routing_table = routing_table self.telemetry = telemetry + self.store = store if self.telemetry: self.tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(self.tokenizer) @@ -607,9 +614,31 @@ class InferenceRouter(Inference): provider = self.routing_table.get_provider_impl(model_obj.identifier) if stream: - return await provider.openai_chat_completion(**params) + response_stream = await provider.openai_chat_completion(**params) + if self.store: + return stream_and_store_openai_completion(response_stream, model, self.store, messages) + return response_stream else: - return await self._nonstream_openai_chat_completion(provider, params) + response = await self._nonstream_openai_chat_completion(provider, params) + if self.store: + await self.store.store_chat_completion(response, messages) + return response + + async def list_chat_completions( + self, + after: str | None = None, + limit: int | None = 20, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIChatCompletionResponse: + if self.store: + return await self.store.list_chat_completions(after, limit, model, order) + raise NotImplementedError("List chat completions is not supported: inference store is not configured.") + + async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: + if self.store: + return await self.store.get_chat_completion(completion_id) + raise NotImplementedError("Get chat completion is not supported: inference store is not configured.") async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion: response = await provider.openai_chat_completion(**params) diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py new file mode 100644 index 000000000..7b6bc2e3d --- /dev/null +++ b/llama_stack/providers/utils/inference/inference_store.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from llama_stack.apis.inference import ( + ListOpenAIChatCompletionResponse, + OpenAIChatCompletion, + OpenAICompletionWithInputMessages, + OpenAIMessageParam, + Order, +) +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR + +from ..sqlstore.api import ColumnDefinition, ColumnType +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl + + +class InferenceStore: + def __init__(self, sql_store_config: SqlStoreConfig): + if not sql_store_config: + sql_store_config = SqliteSqlStoreConfig( + db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), + ) + self.sql_store_config = sql_store_config + self.sql_store = None + + async def initialize(self): + """Create the necessary tables if they don't exist.""" + self.sql_store = sqlstore_impl(self.sql_store_config) + await self.sql_store.create_table( + "chat_completions", + { + "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), + "created": ColumnType.INTEGER, + "model": ColumnType.STRING, + "choices": ColumnType.JSON, + "input_messages": ColumnType.JSON, + }, + ) + + async def store_chat_completion( + self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam] + ) -> None: + if not self.sql_store: + raise ValueError("Inference store is not initialized") + + data = chat_completion.model_dump() + + await self.sql_store.insert( + "chat_completions", + { + "id": data["id"], + "created": data["created"], + "model": data["model"], + "choices": data["choices"], + "input_messages": [message.model_dump() for message in input_messages], + }, + ) + + async def list_chat_completions( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIChatCompletionResponse: + """ + List chat completions from the database. + + :param after: The ID of the last chat completion to return. + :param limit: The maximum number of chat completions to return. + :param model: The model to filter by. + :param order: The order to sort the chat completions by. 
+ """ + if not self.sql_store: + raise ValueError("Inference store is not initialized") + + # TODO: support after + if after: + raise NotImplementedError("After is not supported for SQLite") + if not order: + order = Order.desc + + rows = await self.sql_store.fetch_all( + "chat_completions", + where={"model": model} if model else None, + order_by=[("created", order.value)], + limit=limit, + ) + + data = [ + OpenAICompletionWithInputMessages( + id=row["id"], + created=row["created"], + model=row["model"], + choices=row["choices"], + input_messages=row["input_messages"], + ) + for row in rows + ] + return ListOpenAIChatCompletionResponse( + data=data, + # TODO: implement has_more + has_more=False, + first_id=data[0].id if data else "", + last_id=data[-1].id if data else "", + ) + + async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: + if not self.sql_store: + raise ValueError("Inference store is not initialized") + + row = await self.sql_store.fetch_one("chat_completions", where={"id": completion_id}) + if not row: + raise ValueError(f"Chat completion with id {completion_id} not found") from None + return OpenAICompletionWithInputMessages( + id=row["id"], + created=row["created"], + model=row["model"], + choices=row["choices"], + input_messages=row["input_messages"], + ) diff --git a/llama_stack/providers/utils/inference/stream_utils.py b/llama_stack/providers/utils/inference/stream_utils.py new file mode 100644 index 000000000..a2edbb9c8 --- /dev/null +++ b/llama_stack/providers/utils/inference/stream_utils.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from collections.abc import AsyncIterator +from datetime import datetime, timezone +from typing import Any + +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoice, + OpenAIChoiceLogprobs, + OpenAIMessageParam, +) +from llama_stack.providers.utils.inference.inference_store import InferenceStore + + +async def stream_and_store_openai_completion( + provider_stream: AsyncIterator[OpenAIChatCompletionChunk], + model: str, + store: InferenceStore, + input_messages: list[OpenAIMessageParam], +) -> AsyncIterator[OpenAIChatCompletionChunk]: + """ + Wraps a provider's stream, yields chunks, and stores the full completion at the end. 
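The store defined above can also be driven directly, outside the inference router; a hedged sketch follows (the database path and the surrounding event loop are assumptions, the method names are the ones defined in this file):

```python
# Hedged sketch: driving InferenceStore directly outside the router.
# Assumes sqlalchemy[asyncio] and aiosqlite are installed; path is illustrative.
import asyncio

from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


async def main() -> None:
    store = InferenceStore(SqliteSqlStoreConfig(db_path="/tmp/inference_store.db"))
    await store.initialize()
    page = await store.list_chat_completions(limit=5)
    for completion in page.data:
        print(completion.id, completion.model)


asyncio.run(main())
```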
+ """ + id = None + created = None + choices_data: dict[int, dict[str, Any]] = {} + + try: + async for chunk in provider_stream: + if id is None and chunk.id: + id = chunk.id + if created is None and chunk.created: + created = chunk.created + + if chunk.choices: + for choice_delta in chunk.choices: + idx = choice_delta.index + if idx not in choices_data: + choices_data[idx] = { + "content_parts": [], + "tool_calls_builder": {}, + "finish_reason": None, + "logprobs_content_parts": [], + } + current_choice_data = choices_data[idx] + + if choice_delta.delta: + delta = choice_delta.delta + if delta.content: + current_choice_data["content_parts"].append(delta.content) + if delta.tool_calls: + for tool_call_delta in delta.tool_calls: + tc_idx = tool_call_delta.index + if tc_idx not in current_choice_data["tool_calls_builder"]: + # Initialize with correct structure for _ToolCallBuilderData + current_choice_data["tool_calls_builder"][tc_idx] = { + "id": None, + "type": "function", + "function_name_parts": [], + "function_arguments_parts": [], + } + builder = current_choice_data["tool_calls_builder"][tc_idx] + if tool_call_delta.id: + builder["id"] = tool_call_delta.id + if tool_call_delta.type: + builder["type"] = tool_call_delta.type + if tool_call_delta.function: + if tool_call_delta.function.name: + builder["function_name_parts"].append(tool_call_delta.function.name) + if tool_call_delta.function.arguments: + builder["function_arguments_parts"].append(tool_call_delta.function.arguments) + if choice_delta.finish_reason: + current_choice_data["finish_reason"] = choice_delta.finish_reason + if choice_delta.logprobs and choice_delta.logprobs.content: + # Ensure that we are extending with the correct type + current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content) + yield chunk + finally: + if id: + assembled_choices: list[OpenAIChoice] = [] + for choice_idx, choice_data in choices_data.items(): + content_str = "".join(choice_data["content_parts"]) + assembled_tool_calls: list[OpenAIChatCompletionToolCall] = [] + if choice_data["tool_calls_builder"]: + for tc_build_data in choice_data["tool_calls_builder"].values(): + if tc_build_data["id"]: + func_name = "".join(tc_build_data["function_name_parts"]) + func_args = "".join(tc_build_data["function_arguments_parts"]) + assembled_tool_calls.append( + OpenAIChatCompletionToolCall( + id=tc_build_data["id"], + type=tc_build_data["type"], # No or "function" needed, already set + function=OpenAIChatCompletionToolCallFunction(name=func_name, arguments=func_args), + ) + ) + message = OpenAIAssistantMessageParam( + role="assistant", + content=content_str if content_str else None, + tool_calls=assembled_tool_calls if assembled_tool_calls else None, + ) + logprobs_content = choice_data["logprobs_content_parts"] + final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None + + assembled_choices.append( + OpenAIChoice( + finish_reason=choice_data["finish_reason"], + index=choice_idx, + message=message, + logprobs=final_logprobs, + ) + ) + + final_response = OpenAIChatCompletion( + id=id, + choices=assembled_choices, + created=created or int(datetime.now(timezone.utc).timestamp()), + model=model, + object="chat.completion", + ) + await store.store_chat_completion(final_response, input_messages) diff --git a/llama_stack/providers/utils/sqlstore/api.py b/llama_stack/providers/utils/sqlstore/api.py new file mode 100644 index 000000000..ace40e4c4 --- /dev/null +++ b/llama_stack/providers/utils/sqlstore/api.py 
@@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from collections.abc import Mapping +from enum import Enum +from typing import Any, Literal, Protocol + +from pydantic import BaseModel + + +class ColumnType(Enum): + INTEGER = "INTEGER" + STRING = "STRING" + TEXT = "TEXT" + FLOAT = "FLOAT" + BOOLEAN = "BOOLEAN" + JSON = "JSON" + DATETIME = "DATETIME" + + +class ColumnDefinition(BaseModel): + type: ColumnType + primary_key: bool = False + nullable: bool = True + default: Any = None + + +class SqlStore(Protocol): + """ + A protocol for a SQL store. + """ + + async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None: + """ + Create a table. + """ + pass + + async def insert(self, table: str, data: Mapping[str, Any]) -> None: + """ + Insert a row into a table. + """ + pass + + async def fetch_all( + self, + table: str, + where: Mapping[str, Any] | None = None, + limit: int | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + ) -> list[dict[str, Any]]: + """ + Fetch all rows from a table. + """ + pass + + async def fetch_one( + self, + table: str, + where: Mapping[str, Any] | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + ) -> dict[str, Any] | None: + """ + Fetch one row from a table. + """ + pass + + async def update( + self, + table: str, + data: Mapping[str, Any], + where: Mapping[str, Any], + ) -> None: + """ + Update a row in a table. + """ + pass + + async def delete( + self, + table: str, + where: Mapping[str, Any], + ) -> None: + """ + Delete a row from a table. + """ + pass diff --git a/llama_stack/providers/utils/sqlstore/sqlite/sqlite.py b/llama_stack/providers/utils/sqlstore/sqlite/sqlite.py new file mode 100644 index 000000000..0ef5f0fa1 --- /dev/null +++ b/llama_stack/providers/utils/sqlstore/sqlite/sqlite.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
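The `SqlStore` protocol above is intentionally small; a sketch of how its SQLite implementation (defined next) might be used on its own, with an illustrative table name and local database path:

```python
# Sketch of the SqlStore surface using the SQLite implementation below.
# Table name and db_path are illustrative; assumes aiosqlite is installed.
import asyncio

from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlite.sqlite import SqliteSqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


async def main() -> None:
    store = SqliteSqlStoreImpl(SqliteSqlStoreConfig(db_path="/tmp/example_sqlstore.db"))
    await store.create_table(
        "notes",
        {
            "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
            "body": ColumnType.TEXT,
        },
    )
    await store.insert("notes", {"id": "n1", "body": "hello"})
    print(await store.fetch_one("notes", where={"id": "n1"}))


asyncio.run(main())
```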
+from collections.abc import Mapping +from typing import Any, Literal + +from sqlalchemy import ( + JSON, + Boolean, + Column, + DateTime, + Float, + Integer, + MetaData, + String, + Table, + Text, + select, +) +from sqlalchemy.ext.asyncio import create_async_engine + +from ..api import ColumnDefinition, ColumnType, SqlStore +from ..sqlstore import SqliteSqlStoreConfig + +TYPE_MAPPING: dict[ColumnType, Any] = { + ColumnType.INTEGER: Integer, + ColumnType.STRING: String, + ColumnType.FLOAT: Float, + ColumnType.BOOLEAN: Boolean, + ColumnType.DATETIME: DateTime, + ColumnType.TEXT: Text, + ColumnType.JSON: JSON, +} + + +class SqliteSqlStoreImpl(SqlStore): + def __init__(self, config: SqliteSqlStoreConfig): + self.engine = create_async_engine(config.engine_str) + self.metadata = MetaData() + + async def create_table( + self, + table: str, + schema: Mapping[str, ColumnType | ColumnDefinition], + ) -> None: + if not schema: + raise ValueError(f"No columns defined for table '{table}'.") + + sqlalchemy_columns: list[Column] = [] + + for col_name, col_props in schema.items(): + col_type = None + is_primary_key = False + is_nullable = True # Default to nullable + + if isinstance(col_props, ColumnType): + col_type = col_props + elif isinstance(col_props, ColumnDefinition): + col_type = col_props.type + is_primary_key = col_props.primary_key + is_nullable = col_props.nullable + + sqlalchemy_type = TYPE_MAPPING.get(col_type) + if not sqlalchemy_type: + raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.") + + sqlalchemy_columns.append( + Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable) + ) + + # Check if table already exists in metadata, otherwise define it + if table not in self.metadata.tables: + sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns) + else: + sqlalchemy_table = self.metadata.tables[table] + + # Create the table in the database if it doesn't exist + # checkfirst=True ensures it doesn't try to recreate if it's already there + async with self.engine.begin() as conn: + await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) + + async def insert(self, table: str, data: Mapping[str, Any]) -> None: + async with self.engine.begin() as conn: + await conn.execute(self.metadata.tables[table].insert(), data) + await conn.commit() + + async def fetch_all( + self, + table: str, + where: Mapping[str, Any] | None = None, + limit: int | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + ) -> list[dict[str, Any]]: + async with self.engine.begin() as conn: + query = select(self.metadata.tables[table]) + if where: + for key, value in where.items(): + query = query.where(self.metadata.tables[table].c[key] == value) + if limit: + query = query.limit(limit) + if order_by: + if not isinstance(order_by, list): + raise ValueError( + f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}" + ) + for order in order_by: + if not isinstance(order, tuple): + raise ValueError( + f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}" + ) + name, order_type = order + if order_type == "asc": + query = query.order_by(self.metadata.tables[table].c[name].asc()) + elif order_type == "desc": + query = query.order_by(self.metadata.tables[table].c[name].desc()) + else: + raise ValueError(f"Invalid order '{order_type}' for column '{name}'") + result = await conn.execute(query) + if result.rowcount == 0: + return [] + return 
[dict(row._mapping) for row in result] + + async def fetch_one( + self, + table: str, + where: Mapping[str, Any] | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + ) -> dict[str, Any] | None: + rows = await self.fetch_all(table, where, limit=1, order_by=order_by) + if not rows: + return None + return rows[0] + + async def update( + self, + table: str, + data: Mapping[str, Any], + where: Mapping[str, Any], + ) -> None: + if not where: + raise ValueError("where is required for update") + + async with self.engine.begin() as conn: + stmt = self.metadata.tables[table].update() + for key, value in where.items(): + stmt = stmt.where(self.metadata.tables[table].c[key] == value) + await conn.execute(stmt, data) + await conn.commit() + + async def delete(self, table: str, where: Mapping[str, Any]) -> None: + if not where: + raise ValueError("where is required for delete") + + async with self.engine.begin() as conn: + stmt = self.metadata.tables[table].delete() + for key, value in where.items(): + stmt = stmt.where(self.metadata.tables[table].c[key] == value) + await conn.execute(stmt) + await conn.commit() diff --git a/llama_stack/providers/utils/sqlstore/sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlstore.py new file mode 100644 index 000000000..99f64805f --- /dev/null +++ b/llama_stack/providers/utils/sqlstore/sqlstore.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from enum import Enum +from pathlib import Path +from typing import Annotated, Literal + +from pydantic import BaseModel, Field + +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR + +from .api import SqlStore + + +class SqlStoreType(Enum): + sqlite = "sqlite" + postgres = "postgres" + + +class SqliteSqlStoreConfig(BaseModel): + type: Literal["sqlite"] = SqlStoreType.sqlite.value + db_path: str = Field( + default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), + description="Database path, e.g. 
~/.llama/distributions/ollama/sqlstore.db", + ) + + @property + def engine_str(self) -> str: + return "sqlite+aiosqlite:///" + Path(self.db_path).expanduser().as_posix() + + @classmethod + def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"): + return cls( + type="sqlite", + db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name, + ) + + # TODO: move this when we have a better way to specify dependencies with internal APIs + @property + def pip_packages(self) -> list[str]: + return ["sqlalchemy[asyncio]"] + + +class PostgresSqlStoreConfig(BaseModel): + type: Literal["postgres"] = SqlStoreType.postgres.value + + @property + def pip_packages(self) -> list[str]: + raise NotImplementedError("Postgres is not implemented yet") + + +SqlStoreConfig = Annotated[ + SqliteSqlStoreConfig | PostgresSqlStoreConfig, + Field(discriminator="type", default=SqlStoreType.sqlite.value), +] + + +def sqlstore_impl(config: SqlStoreConfig) -> SqlStore: + if config.type == SqlStoreType.sqlite.value: + from .sqlite.sqlite import SqliteSqlStoreImpl + + impl = SqliteSqlStoreImpl(config) + elif config.type == SqlStoreType.postgres.value: + raise NotImplementedError("Postgres is not implemented yet") + else: + raise ValueError(f"Unknown sqlstore type {config.type}") + + return impl diff --git a/llama_stack/templates/bedrock/build.yaml b/llama_stack/templates/bedrock/build.yaml index 46d5b9c69..09fbf307d 100644 --- a/llama_stack/templates/bedrock/build.yaml +++ b/llama_stack/templates/bedrock/build.yaml @@ -29,3 +29,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/bedrock/run.yaml b/llama_stack/templates/bedrock/run.yaml index 30599a6c0..c39b08ff9 100644 --- a/llama_stack/templates/bedrock/run.yaml +++ b/llama_stack/templates/bedrock/run.yaml @@ -96,6 +96,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/inference_store.db models: - metadata: {} model_id: meta.llama3-1-8b-instruct-v1:0 diff --git a/llama_stack/templates/cerebras/build.yaml b/llama_stack/templates/cerebras/build.yaml index 0498da1cd..95b0302f2 100644 --- a/llama_stack/templates/cerebras/build.yaml +++ b/llama_stack/templates/cerebras/build.yaml @@ -29,3 +29,5 @@ distribution_spec: - remote::tavily-search - inline::rag-runtime image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/cerebras/run.yaml b/llama_stack/templates/cerebras/run.yaml index 0731b1df9..025033f59 100644 --- a/llama_stack/templates/cerebras/run.yaml +++ b/llama_stack/templates/cerebras/run.yaml @@ -99,6 +99,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/inference_store.db models: - metadata: {} model_id: llama3.1-8b diff --git a/llama_stack/templates/ci-tests/build.yaml b/llama_stack/templates/ci-tests/build.yaml index a4c5893c4..6fe96c603 100644 --- a/llama_stack/templates/ci-tests/build.yaml +++ b/llama_stack/templates/ci-tests/build.yaml @@ -30,3 +30,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git 
a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml index d9ee5b3cf..342388b78 100644 --- a/llama_stack/templates/ci-tests/run.yaml +++ b/llama_stack/templates/ci-tests/run.yaml @@ -99,6 +99,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/inference_store.db models: - metadata: {} model_id: accounts/fireworks/models/llama-v3p1-8b-instruct diff --git a/llama_stack/templates/dell/build.yaml b/llama_stack/templates/dell/build.yaml index f5beb6c2f..d37215f35 100644 --- a/llama_stack/templates/dell/build.yaml +++ b/llama_stack/templates/dell/build.yaml @@ -30,3 +30,6 @@ distribution_spec: - remote::tavily-search - inline::rag-runtime image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/dell/run-with-safety.yaml b/llama_stack/templates/dell/run-with-safety.yaml index 24c515112..77843858c 100644 --- a/llama_stack/templates/dell/run-with-safety.yaml +++ b/llama_stack/templates/dell/run-with-safety.yaml @@ -99,6 +99,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/dell/run.yaml b/llama_stack/templates/dell/run.yaml index fdece894f..fd0d4a1f6 100644 --- a/llama_stack/templates/dell/run.yaml +++ b/llama_stack/templates/dell/run.yaml @@ -95,6 +95,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json index fb4ab9fda..78da0603b 100644 --- a/llama_stack/templates/dependencies.json +++ b/llama_stack/templates/dependencies.json @@ -31,6 +31,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -67,6 +68,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -105,6 +107,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "sqlite-vec", "tqdm", "transformers", @@ -145,6 +148,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -184,6 +188,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -221,6 +226,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -259,6 +265,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -297,6 +304,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -335,6 +343,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "sqlite-vec", "tqdm", "transformers", @@ -379,6 +388,7 @@ "scipy", "sentence-transformers", "sentencepiece", + "sqlalchemy[asyncio]", "torch", "torchao==0.8.0", "torchvision", @@ -414,6 +424,7 @@ "scikit-learn", "scipy", "sentencepiece", + 
"sqlalchemy[asyncio]", "tqdm", "transformers", "uvicorn" @@ -452,6 +463,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "torch", "tqdm", "transformers", @@ -490,6 +502,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "sqlite-vec", "together", "tqdm", @@ -528,6 +541,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -566,6 +580,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -599,6 +614,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "tqdm", "transformers", "uvicorn", @@ -637,6 +653,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "sqlite-vec", "tqdm", "transformers", @@ -678,6 +695,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -716,6 +734,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "together", "tqdm", "transformers", @@ -755,6 +774,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "sqlite-vec", "tqdm", "transformers", @@ -794,6 +814,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", @@ -833,6 +854,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlalchemy[asyncio]", "tqdm", "transformers", "tree_sitter", diff --git a/llama_stack/templates/fireworks/build.yaml b/llama_stack/templates/fireworks/build.yaml index 7c74157ee..f162d9b43 100644 --- a/llama_stack/templates/fireworks/build.yaml +++ b/llama_stack/templates/fireworks/build.yaml @@ -31,3 +31,6 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/fireworks/run-with-safety.yaml b/llama_stack/templates/fireworks/run-with-safety.yaml index 0ab07613e..1f66983f4 100644 --- a/llama_stack/templates/fireworks/run-with-safety.yaml +++ b/llama_stack/templates/fireworks/run-with-safety.yaml @@ -111,6 +111,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/inference_store.db models: - metadata: {} model_id: accounts/fireworks/models/llama-v3p1-8b-instruct diff --git a/llama_stack/templates/fireworks/run.yaml b/llama_stack/templates/fireworks/run.yaml index 81c293a46..1fbf4be6e 100644 --- a/llama_stack/templates/fireworks/run.yaml +++ b/llama_stack/templates/fireworks/run.yaml @@ -106,6 +106,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/inference_store.db models: - metadata: {} model_id: accounts/fireworks/models/llama-v3p1-8b-instruct diff --git a/llama_stack/templates/groq/build.yaml b/llama_stack/templates/groq/build.yaml index 800c3e3ae..92b46ce66 100644 --- a/llama_stack/templates/groq/build.yaml +++ b/llama_stack/templates/groq/build.yaml @@ -26,3 +26,5 @@ distribution_spec: - remote::tavily-search - inline::rag-runtime image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/groq/run.yaml b/llama_stack/templates/groq/run.yaml index 79c350c73..7d257d379 100644 --- a/llama_stack/templates/groq/run.yaml +++ 
b/llama_stack/templates/groq/run.yaml @@ -99,6 +99,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/inference_store.db models: - metadata: {} model_id: groq/llama3-8b-8192 diff --git a/llama_stack/templates/hf-endpoint/build.yaml b/llama_stack/templates/hf-endpoint/build.yaml index 2a40c3909..4d09cc33e 100644 --- a/llama_stack/templates/hf-endpoint/build.yaml +++ b/llama_stack/templates/hf-endpoint/build.yaml @@ -29,3 +29,6 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/hf-endpoint/run-with-safety.yaml b/llama_stack/templates/hf-endpoint/run-with-safety.yaml index 82bcaa3cf..b3938bf93 100644 --- a/llama_stack/templates/hf-endpoint/run-with-safety.yaml +++ b/llama_stack/templates/hf-endpoint/run-with-safety.yaml @@ -107,6 +107,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/hf-endpoint/run.yaml b/llama_stack/templates/hf-endpoint/run.yaml index ec7c55032..1e60dd25c 100644 --- a/llama_stack/templates/hf-endpoint/run.yaml +++ b/llama_stack/templates/hf-endpoint/run.yaml @@ -102,6 +102,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/hf-serverless/build.yaml b/llama_stack/templates/hf-serverless/build.yaml index f77f8773b..d06c628ac 100644 --- a/llama_stack/templates/hf-serverless/build.yaml +++ b/llama_stack/templates/hf-serverless/build.yaml @@ -30,3 +30,6 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/hf-serverless/run-with-safety.yaml b/llama_stack/templates/hf-serverless/run-with-safety.yaml index 320976e2c..640506632 100644 --- a/llama_stack/templates/hf-serverless/run-with-safety.yaml +++ b/llama_stack/templates/hf-serverless/run-with-safety.yaml @@ -107,6 +107,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/hf-serverless/run.yaml b/llama_stack/templates/hf-serverless/run.yaml index 2b22b20c6..a8b46a0aa 100644 --- a/llama_stack/templates/hf-serverless/run.yaml +++ b/llama_stack/templates/hf-serverless/run.yaml @@ -102,6 +102,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git 
a/llama_stack/templates/llama_api/build.yaml b/llama_stack/templates/llama_api/build.yaml index f97ee4091..d0dc08923 100644 --- a/llama_stack/templates/llama_api/build.yaml +++ b/llama_stack/templates/llama_api/build.yaml @@ -30,3 +30,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/llama_api/run.yaml b/llama_stack/templates/llama_api/run.yaml index a879482d7..1d5739fe2 100644 --- a/llama_stack/templates/llama_api/run.yaml +++ b/llama_stack/templates/llama_api/run.yaml @@ -111,6 +111,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/inference_store.db models: - metadata: {} model_id: Llama-3.3-70B-Instruct diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml index a9d03490b..e0ac87e47 100644 --- a/llama_stack/templates/meta-reference-gpu/build.yaml +++ b/llama_stack/templates/meta-reference-gpu/build.yaml @@ -29,3 +29,6 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml index 180d44e0f..bbf7ad767 100644 --- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml @@ -117,6 +117,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml index d879667e0..9ce69c209 100644 --- a/llama_stack/templates/meta-reference-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-gpu/run.yaml @@ -107,6 +107,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml index a05cf97ad..e1e6fb3d8 100644 --- a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/templates/nvidia/build.yaml @@ -24,3 +24,6 @@ distribution_spec: tool_runtime: - inline::rag-runtime image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index 3cdb8e3d2..32359b805 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -92,6 +92,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/inference_store.db models: - metadata: {} model_id: 
${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index 3337b7942..d4e935727 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -80,6 +80,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/inference_store.db models: - metadata: {} model_id: meta/llama3-8b-instruct diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index 7d5363575..9d8ba3a1e 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -32,3 +32,6 @@ distribution_spec: - remote::model-context-protocol - remote::wolfram-alpha image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index 74d0822ca..a19ac73c6 100644 --- a/llama_stack/templates/ollama/run-with-safety.yaml +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -112,6 +112,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index 71229be97..551af3a99 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -110,6 +110,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/open-benchmark/build.yaml b/llama_stack/templates/open-benchmark/build.yaml index b14e96435..aa6d876fe 100644 --- a/llama_stack/templates/open-benchmark/build.yaml +++ b/llama_stack/templates/open-benchmark/build.yaml @@ -33,3 +33,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 30a27cbd8..7b43ce6e7 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -125,6 +125,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/inference_store.db models: - metadata: {} model_id: openai/gpt-4o diff --git a/llama_stack/templates/passthrough/build.yaml b/llama_stack/templates/passthrough/build.yaml index f8d099070..7560f1032 100644 --- a/llama_stack/templates/passthrough/build.yaml +++ b/llama_stack/templates/passthrough/build.yaml @@ -31,3 +31,6 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/passthrough/run-with-safety.yaml 
b/llama_stack/templates/passthrough/run-with-safety.yaml index a91b9fc92..cddda39fa 100644 --- a/llama_stack/templates/passthrough/run-with-safety.yaml +++ b/llama_stack/templates/passthrough/run-with-safety.yaml @@ -111,6 +111,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/inference_store.db models: - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct diff --git a/llama_stack/templates/passthrough/run.yaml b/llama_stack/templates/passthrough/run.yaml index d1dd3b885..1fc3914a6 100644 --- a/llama_stack/templates/passthrough/run.yaml +++ b/llama_stack/templates/passthrough/run.yaml @@ -106,6 +106,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/inference_store.db models: - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct diff --git a/llama_stack/templates/remote-vllm/build.yaml b/llama_stack/templates/remote-vllm/build.yaml index 4baaaf9c8..fcd4deeff 100644 --- a/llama_stack/templates/remote-vllm/build.yaml +++ b/llama_stack/templates/remote-vllm/build.yaml @@ -31,3 +31,6 @@ distribution_spec: - remote::model-context-protocol - remote::wolfram-alpha image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/remote-vllm/run-with-safety.yaml b/llama_stack/templates/remote-vllm/run-with-safety.yaml index 6931d4ba9..89f3aa082 100644 --- a/llama_stack/templates/remote-vllm/run-with-safety.yaml +++ b/llama_stack/templates/remote-vllm/run-with-safety.yaml @@ -115,6 +115,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml index 05671165d..4d4395fd7 100644 --- a/llama_stack/templates/remote-vllm/run.yaml +++ b/llama_stack/templates/remote-vllm/run.yaml @@ -108,6 +108,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/sambanova/build.yaml b/llama_stack/templates/sambanova/build.yaml index 79bb68c68..b644dcfdc 100644 --- a/llama_stack/templates/sambanova/build.yaml +++ b/llama_stack/templates/sambanova/build.yaml @@ -22,3 +22,5 @@ distribution_spec: - remote::model-context-protocol - remote::wolfram-alpha image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index fa8735002..907bc013e 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -82,6 +82,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/inference_store.db 
models: - metadata: {} model_id: sambanova/Meta-Llama-3.1-8B-Instruct diff --git a/llama_stack/templates/starter/build.yaml b/llama_stack/templates/starter/build.yaml index 35bd0c713..652814ffd 100644 --- a/llama_stack/templates/starter/build.yaml +++ b/llama_stack/templates/starter/build.yaml @@ -35,3 +35,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index 402695850..3327e576c 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -133,6 +133,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/inference_store.db models: - metadata: {} model_id: openai/gpt-4o diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index e4d28d904..ec5cd38ea 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -29,6 +29,7 @@ from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig def get_model_registry( @@ -117,6 +118,10 @@ class RunConfigSettings(BaseModel): __distro_dir__=f"~/.llama/distributions/{name}", db_name="registry.db", ), + inference_store=SqliteSqlStoreConfig.sample_run_config( + __distro_dir__=f"~/.llama/distributions/{name}", + db_name="inference_store.db", + ), models=self.default_models or [], shields=self.default_shields or [], tool_groups=self.default_tool_groups or [], @@ -146,14 +151,20 @@ class DistributionTemplate(BaseModel): available_models_by_provider: dict[str, list[ProviderModelEntry]] | None = None def build_config(self) -> BuildConfig: + additional_pip_packages: list[str] = [] + for run_config in self.run_configs.values(): + run_config_ = run_config.run_config(self.name, self.providers, self.container_image) + if run_config_.inference_store: + additional_pip_packages.extend(run_config_.inference_store.pip_packages) + return BuildConfig( - name=self.name, distribution_spec=DistributionSpec( description=self.description, container_image=self.container_image, providers=self.providers, ), image_type="conda", # default to conda, can be overridden + additional_pip_packages=additional_pip_packages, ) def generate_markdown_docs(self) -> str: diff --git a/llama_stack/templates/tgi/build.yaml b/llama_stack/templates/tgi/build.yaml index d2ba1c3e9..652900c84 100644 --- a/llama_stack/templates/tgi/build.yaml +++ b/llama_stack/templates/tgi/build.yaml @@ -30,3 +30,6 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/tgi/run-with-safety.yaml b/llama_stack/templates/tgi/run-with-safety.yaml index 3255e9c0b..bd197b93f 100644 --- a/llama_stack/templates/tgi/run-with-safety.yaml +++ b/llama_stack/templates/tgi/run-with-safety.yaml @@ -102,6 +102,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db 
+inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/tgi/run.yaml b/llama_stack/templates/tgi/run.yaml index 179087258..230fe9a5a 100644 --- a/llama_stack/templates/tgi/run.yaml +++ b/llama_stack/templates/tgi/run.yaml @@ -101,6 +101,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml index b7338795c..4a556a66f 100644 --- a/llama_stack/templates/together/build.yaml +++ b/llama_stack/templates/together/build.yaml @@ -31,3 +31,6 @@ distribution_spec: - remote::model-context-protocol - remote::wolfram-alpha image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/together/run-with-safety.yaml b/llama_stack/templates/together/run-with-safety.yaml index fe8c8e397..1c05e5e42 100644 --- a/llama_stack/templates/together/run-with-safety.yaml +++ b/llama_stack/templates/together/run-with-safety.yaml @@ -111,6 +111,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/inference_store.db models: - metadata: {} model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml index b903fc659..aebf4e1a2 100644 --- a/llama_stack/templates/together/run.yaml +++ b/llama_stack/templates/together/run.yaml @@ -106,6 +106,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/inference_store.db models: - metadata: {} model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo diff --git a/llama_stack/templates/verification/build.yaml b/llama_stack/templates/verification/build.yaml index aae24c3ca..cb7ab4798 100644 --- a/llama_stack/templates/verification/build.yaml +++ b/llama_stack/templates/verification/build.yaml @@ -35,3 +35,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/verification/run.yaml b/llama_stack/templates/verification/run.yaml index 11af41da9..de8b0d850 100644 --- a/llama_stack/templates/verification/run.yaml +++ b/llama_stack/templates/verification/run.yaml @@ -135,6 +135,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/inference_store.db models: - metadata: {} model_id: openai/gpt-4o diff --git a/llama_stack/templates/vllm-gpu/build.yaml b/llama_stack/templates/vllm-gpu/build.yaml index 53e257f22..5a9d003cb 100644 --- a/llama_stack/templates/vllm-gpu/build.yaml +++ b/llama_stack/templates/vllm-gpu/build.yaml @@ -30,3 +30,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda 
+additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/vllm-gpu/run.yaml b/llama_stack/templates/vllm-gpu/run.yaml index 5d3482528..a0257f704 100644 --- a/llama_stack/templates/vllm-gpu/run.yaml +++ b/llama_stack/templates/vllm-gpu/run.yaml @@ -106,6 +106,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/watsonx/build.yaml b/llama_stack/templates/watsonx/build.yaml index 638b16029..87233fb26 100644 --- a/llama_stack/templates/watsonx/build.yaml +++ b/llama_stack/templates/watsonx/build.yaml @@ -28,3 +28,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/templates/watsonx/run.yaml index 8de6a2b6c..86ec01953 100644 --- a/llama_stack/templates/watsonx/run.yaml +++ b/llama_stack/templates/watsonx/run.yaml @@ -103,6 +103,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/inference_store.db models: - metadata: {} model_id: meta-llama/llama-3-3-70b-instruct diff --git a/pyproject.toml b/pyproject.toml index 6b873968a..d7ea43ca6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,8 @@ unit = [ "chardet", "qdrant-client", "opentelemetry-exporter-otlp-proto-http", + "sqlalchemy", + "sqlalchemy[asyncio]>=2.0.41", ] # These are the core dependencies required for running integration tests. They are shared across all # providers. If a provider requires additional dependencies, please add them to your environment diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 46ec03d2e..6f8a05a45 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -222,3 +222,105 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex streamed_content.append(chunk.choices[0].delta.content.lower().strip()) assert len(streamed_content) > 0 assert expected.lower() in "".join(streamed_content) + + +@pytest.mark.parametrize( + "stream", + [ + True, + False, + ], +) +def test_inference_store(openai_client, client_with_models, text_model_id, stream): + skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) + client = openai_client + # make a chat completion + message = "Hello, world!" 
+ response = client.chat.completions.create( + model=text_model_id, + messages=[ + { + "role": "user", + "content": message, + } + ], + stream=stream, + ) + if stream: + # accumulate the streamed content + content = "" + response_id = None + for chunk in response: + if response_id is None: + response_id = chunk.id + content += chunk.choices[0].delta.content + else: + response_id = response.id + content = response.choices[0].message.content + + responses = client.chat.completions.list() + assert response_id in [r.id for r in responses.data] + + retrieved_response = client.chat.completions.retrieve(response_id) + assert retrieved_response.id == response_id + assert retrieved_response.input_messages[0]["content"] == message + assert retrieved_response.choices[0].message.content == content + + +@pytest.mark.parametrize( + "stream", + [ + True, + False, + ], +) +def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream): + skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) + client = openai_client + # make a chat completion + message = "What's the weather in Tokyo? Use the get_weather function to get the weather." + response = client.chat.completions.create( + model=text_model_id, + messages=[ + { + "role": "user", + "content": message, + } + ], + stream=stream, + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather in a given city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city to get the weather for"}, + }, + }, + }, + } + ], + ) + if stream: + # accumulate the streamed content + content = "" + response_id = None + for chunk in response: + if response_id is None: + response_id = chunk.id + content += chunk.choices[0].delta.content + else: + response_id = response.id + content = response.choices[0].message.content + + responses = client.chat.completions.list() + assert response_id in [r.id for r in responses.data] + + retrieved_response = client.chat.completions.retrieve(response_id) + assert retrieved_response.id == response_id + assert retrieved_response.input_messages[0]["content"] == message + assert retrieved_response.choices[0].message.tool_calls[0].function.name == "get_weather" + assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city":"Tokyo"}' diff --git a/tests/unit/utils/test_sqlstore.py b/tests/unit/utils/test_sqlstore.py new file mode 100644 index 000000000..8ded760ef --- /dev/null +++ b/tests/unit/utils/test_sqlstore.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from tempfile import TemporaryDirectory + +import pytest + +from llama_stack.providers.utils.sqlstore.api import ColumnType +from llama_stack.providers.utils.sqlstore.sqlite.sqlite import SqliteSqlStoreImpl +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig + + +@pytest.mark.asyncio +async def test_sqlite_sqlstore(): + with TemporaryDirectory() as tmp_dir: + db_name = "test.db" + sqlstore = SqliteSqlStoreImpl( + SqliteSqlStoreConfig( + db_path=tmp_dir + "/" + db_name, + ) + ) + await sqlstore.create_table( + table="test", + schema={ + "id": ColumnType.INTEGER, + "name": ColumnType.STRING, + }, + ) + await sqlstore.insert("test", {"id": 1, "name": "test"}) + await sqlstore.insert("test", {"id": 12, "name": "test12"}) + rows = await sqlstore.fetch_all("test") + assert rows == [{"id": 1, "name": "test"}, {"id": 12, "name": "test12"}] + + row = await sqlstore.fetch_one("test", {"id": 1}) + assert row == {"id": 1, "name": "test"} + + row = await sqlstore.fetch_one("test", {"name": "test12"}) + assert row == {"id": 12, "name": "test12"} + + # order by + rows = await sqlstore.fetch_all("test", order_by=[("id", "asc")]) + assert rows == [{"id": 1, "name": "test"}, {"id": 12, "name": "test12"}] + + rows = await sqlstore.fetch_all("test", order_by=[("id", "desc")]) + assert rows == [{"id": 12, "name": "test12"}, {"id": 1, "name": "test"}] + + # limit + rows = await sqlstore.fetch_all("test", limit=1) + assert rows == [{"id": 1, "name": "test"}] + + # update + await sqlstore.update("test", {"name": "test123"}, {"id": 1}) + row = await sqlstore.fetch_one("test", {"id": 1}) + assert row == {"id": 1, "name": "test123"} + + # delete + await sqlstore.delete("test", {"id": 1}) + rows = await sqlstore.fetch_all("test") + assert rows == [{"id": 12, "name": "test12"}] diff --git a/uv.lock b/uv.lock index 6d091193b..1a3657567 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.10" resolution-markers = [ "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", @@ -875,6 +874,58 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/30/2bd0eb03a7dee7727cd2ec643d1e992979e62d5e7443507381cce0455132/googleapis_common_protos-1.67.0-py2.py3-none-any.whl", hash = "sha256:579de760800d13616f51cf8be00c876f00a9f146d3e6510e19d1f4111758b741", size = 164985 }, ] +[[package]] +name = "greenlet" +version = "3.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/34/c1/a82edae11d46c0d83481aacaa1e578fea21d94a1ef400afd734d47ad95ad/greenlet-3.2.2.tar.gz", hash = "sha256:ad053d34421a2debba45aa3cc39acf454acbcd025b3fc1a9f8a0dee237abd485", size = 185797 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/66/910217271189cc3f32f670040235f4bf026ded8ca07270667d69c06e7324/greenlet-3.2.2-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:c49e9f7c6f625507ed83a7485366b46cbe325717c60837f7244fc99ba16ba9d6", size = 267395 }, + { url = "https://files.pythonhosted.org/packages/a8/36/8d812402ca21017c82880f399309afadb78a0aa300a9b45d741e4df5d954/greenlet-3.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3cc1a3ed00ecfea8932477f729a9f616ad7347a5e55d50929efa50a86cb7be7", size = 625742 }, + { url = 
"https://files.pythonhosted.org/packages/7b/77/66d7b59dfb7cc1102b2f880bc61cb165ee8998c9ec13c96606ba37e54c77/greenlet-3.2.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7c9896249fbef2c615853b890ee854f22c671560226c9221cfd27c995db97e5c", size = 637014 }, + { url = "https://files.pythonhosted.org/packages/36/a7/ff0d408f8086a0d9a5aac47fa1b33a040a9fca89bd5a3f7b54d1cd6e2793/greenlet-3.2.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7409796591d879425997a518138889d8d17e63ada7c99edc0d7a1c22007d4907", size = 632874 }, + { url = "https://files.pythonhosted.org/packages/a1/75/1dc2603bf8184da9ebe69200849c53c3c1dca5b3a3d44d9f5ca06a930550/greenlet-3.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7791dcb496ec53d60c7f1c78eaa156c21f402dda38542a00afc3e20cae0f480f", size = 631652 }, + { url = "https://files.pythonhosted.org/packages/7b/74/ddc8c3bd4c2c20548e5bf2b1d2e312a717d44e2eca3eadcfc207b5f5ad80/greenlet-3.2.2-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d8009ae46259e31bc73dc183e402f548e980c96f33a6ef58cc2e7865db012e13", size = 580619 }, + { url = "https://files.pythonhosted.org/packages/7e/f2/40f26d7b3077b1c7ae7318a4de1f8ffc1d8ccbad8f1d8979bf5080250fd6/greenlet-3.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:fd9fb7c941280e2c837b603850efc93c999ae58aae2b40765ed682a6907ebbc5", size = 1109809 }, + { url = "https://files.pythonhosted.org/packages/c5/21/9329e8c276746b0d2318b696606753f5e7b72d478adcf4ad9a975521ea5f/greenlet-3.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:00cd814b8959b95a546e47e8d589610534cfb71f19802ea8a2ad99d95d702057", size = 1133455 }, + { url = "https://files.pythonhosted.org/packages/bb/1e/0dca9619dbd736d6981f12f946a497ec21a0ea27262f563bca5729662d4d/greenlet-3.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:d0cb7d47199001de7658c213419358aa8937df767936506db0db7ce1a71f4a2f", size = 294991 }, + { url = "https://files.pythonhosted.org/packages/a3/9f/a47e19261747b562ce88219e5ed8c859d42c6e01e73da6fbfa3f08a7be13/greenlet-3.2.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:dcb9cebbf3f62cb1e5afacae90761ccce0effb3adaa32339a0670fe7805d8068", size = 268635 }, + { url = "https://files.pythonhosted.org/packages/11/80/a0042b91b66975f82a914d515e81c1944a3023f2ce1ed7a9b22e10b46919/greenlet-3.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf3fc9145141250907730886b031681dfcc0de1c158f3cc51c092223c0f381ce", size = 628786 }, + { url = "https://files.pythonhosted.org/packages/38/a2/8336bf1e691013f72a6ebab55da04db81a11f68e82bb691f434909fa1327/greenlet-3.2.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:efcdfb9df109e8a3b475c016f60438fcd4be68cd13a365d42b35914cdab4bb2b", size = 640866 }, + { url = "https://files.pythonhosted.org/packages/f8/7e/f2a3a13e424670a5d08826dab7468fa5e403e0fbe0b5f951ff1bc4425b45/greenlet-3.2.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4bd139e4943547ce3a56ef4b8b1b9479f9e40bb47e72cc906f0f66b9d0d5cab3", size = 636752 }, + { url = "https://files.pythonhosted.org/packages/fd/5d/ce4a03a36d956dcc29b761283f084eb4a3863401c7cb505f113f73af8774/greenlet-3.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71566302219b17ca354eb274dfd29b8da3c268e41b646f330e324e3967546a74", size = 636028 }, + { url = 
"https://files.pythonhosted.org/packages/4b/29/b130946b57e3ceb039238413790dd3793c5e7b8e14a54968de1fe449a7cf/greenlet-3.2.2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3091bc45e6b0c73f225374fefa1536cd91b1e987377b12ef5b19129b07d93ebe", size = 583869 }, + { url = "https://files.pythonhosted.org/packages/ac/30/9f538dfe7f87b90ecc75e589d20cbd71635531a617a336c386d775725a8b/greenlet-3.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:44671c29da26539a5f142257eaba5110f71887c24d40df3ac87f1117df589e0e", size = 1112886 }, + { url = "https://files.pythonhosted.org/packages/be/92/4b7deeb1a1e9c32c1b59fdca1cac3175731c23311ddca2ea28a8b6ada91c/greenlet-3.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c23ea227847c9dbe0b3910f5c0dd95658b607137614eb821e6cbaecd60d81cc6", size = 1138355 }, + { url = "https://files.pythonhosted.org/packages/c5/eb/7551c751a2ea6498907b2fcbe31d7a54b602ba5e8eb9550a9695ca25d25c/greenlet-3.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:0a16fb934fcabfdfacf21d79e6fed81809d8cd97bc1be9d9c89f0e4567143d7b", size = 295437 }, + { url = "https://files.pythonhosted.org/packages/2c/a1/88fdc6ce0df6ad361a30ed78d24c86ea32acb2b563f33e39e927b1da9ea0/greenlet-3.2.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:df4d1509efd4977e6a844ac96d8be0b9e5aa5d5c77aa27ca9f4d3f92d3fcf330", size = 270413 }, + { url = "https://files.pythonhosted.org/packages/a6/2e/6c1caffd65490c68cd9bcec8cb7feb8ac7b27d38ba1fea121fdc1f2331dc/greenlet-3.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da956d534a6d1b9841f95ad0f18ace637668f680b1339ca4dcfb2c1837880a0b", size = 637242 }, + { url = "https://files.pythonhosted.org/packages/98/28/088af2cedf8823b6b7ab029a5626302af4ca1037cf8b998bed3a8d3cb9e2/greenlet-3.2.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9c7b15fb9b88d9ee07e076f5a683027bc3befd5bb5d25954bb633c385d8b737e", size = 651444 }, + { url = "https://files.pythonhosted.org/packages/4a/9f/0116ab876bb0bc7a81eadc21c3f02cd6100dcd25a1cf2a085a130a63a26a/greenlet-3.2.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:752f0e79785e11180ebd2e726c8a88109ded3e2301d40abced2543aa5d164275", size = 646067 }, + { url = "https://files.pythonhosted.org/packages/35/17/bb8f9c9580e28a94a9575da847c257953d5eb6e39ca888239183320c1c28/greenlet-3.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ae572c996ae4b5e122331e12bbb971ea49c08cc7c232d1bd43150800a2d6c65", size = 648153 }, + { url = "https://files.pythonhosted.org/packages/2c/ee/7f31b6f7021b8df6f7203b53b9cc741b939a2591dcc6d899d8042fcf66f2/greenlet-3.2.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:02f5972ff02c9cf615357c17ab713737cccfd0eaf69b951084a9fd43f39833d3", size = 603865 }, + { url = "https://files.pythonhosted.org/packages/b5/2d/759fa59323b521c6f223276a4fc3d3719475dc9ae4c44c2fe7fc750f8de0/greenlet-3.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4fefc7aa68b34b9224490dfda2e70ccf2131368493add64b4ef2d372955c207e", size = 1119575 }, + { url = "https://files.pythonhosted.org/packages/30/05/356813470060bce0e81c3df63ab8cd1967c1ff6f5189760c1a4734d405ba/greenlet-3.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a31ead8411a027c2c4759113cf2bd473690517494f3d6e4bf67064589afcd3c5", size = 1147460 }, + { url = "https://files.pythonhosted.org/packages/07/f4/b2a26a309a04fb844c7406a4501331b9400e1dd7dd64d3450472fd47d2e1/greenlet-3.2.2-cp312-cp312-win_amd64.whl", hash = 
"sha256:b24c7844c0a0afc3ccbeb0b807adeefb7eff2b5599229ecedddcfeb0ef333bec", size = 296239 }, + { url = "https://files.pythonhosted.org/packages/89/30/97b49779fff8601af20972a62cc4af0c497c1504dfbb3e93be218e093f21/greenlet-3.2.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:3ab7194ee290302ca15449f601036007873028712e92ca15fc76597a0aeb4c59", size = 269150 }, + { url = "https://files.pythonhosted.org/packages/21/30/877245def4220f684bc2e01df1c2e782c164e84b32e07373992f14a2d107/greenlet-3.2.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dc5c43bb65ec3669452af0ab10729e8fdc17f87a1f2ad7ec65d4aaaefabf6bf", size = 637381 }, + { url = "https://files.pythonhosted.org/packages/8e/16/adf937908e1f913856b5371c1d8bdaef5f58f251d714085abeea73ecc471/greenlet-3.2.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:decb0658ec19e5c1f519faa9a160c0fc85a41a7e6654b3ce1b44b939f8bf1325", size = 651427 }, + { url = "https://files.pythonhosted.org/packages/ad/49/6d79f58fa695b618654adac64e56aff2eeb13344dc28259af8f505662bb1/greenlet-3.2.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6fadd183186db360b61cb34e81117a096bff91c072929cd1b529eb20dd46e6c5", size = 645795 }, + { url = "https://files.pythonhosted.org/packages/5a/e6/28ed5cb929c6b2f001e96b1d0698c622976cd8f1e41fe7ebc047fa7c6dd4/greenlet-3.2.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1919cbdc1c53ef739c94cf2985056bcc0838c1f217b57647cbf4578576c63825", size = 648398 }, + { url = "https://files.pythonhosted.org/packages/9d/70/b200194e25ae86bc57077f695b6cc47ee3118becf54130c5514456cf8dac/greenlet-3.2.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3885f85b61798f4192d544aac7b25a04ece5fe2704670b4ab73c2d2c14ab740d", size = 606795 }, + { url = "https://files.pythonhosted.org/packages/f8/c8/ba1def67513a941154ed8f9477ae6e5a03f645be6b507d3930f72ed508d3/greenlet-3.2.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:85f3e248507125bf4af607a26fd6cb8578776197bd4b66e35229cdf5acf1dfbf", size = 1117976 }, + { url = "https://files.pythonhosted.org/packages/c3/30/d0e88c1cfcc1b3331d63c2b54a0a3a4a950ef202fb8b92e772ca714a9221/greenlet-3.2.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1e76106b6fc55fa3d6fe1c527f95ee65e324a13b62e243f77b48317346559708", size = 1145509 }, + { url = "https://files.pythonhosted.org/packages/90/2e/59d6491834b6e289051b252cf4776d16da51c7c6ca6a87ff97e3a50aa0cd/greenlet-3.2.2-cp313-cp313-win_amd64.whl", hash = "sha256:fe46d4f8e94e637634d54477b0cfabcf93c53f29eedcbdeecaf2af32029b4421", size = 296023 }, + { url = "https://files.pythonhosted.org/packages/65/66/8a73aace5a5335a1cba56d0da71b7bd93e450f17d372c5b7c5fa547557e9/greenlet-3.2.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba30e88607fb6990544d84caf3c706c4b48f629e18853fc6a646f82db9629418", size = 629911 }, + { url = "https://files.pythonhosted.org/packages/48/08/c8b8ebac4e0c95dcc68ec99198842e7db53eda4ab3fb0a4e785690883991/greenlet-3.2.2-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:055916fafad3e3388d27dd68517478933a97edc2fc54ae79d3bec827de2c64c4", size = 635251 }, + { url = "https://files.pythonhosted.org/packages/37/26/7db30868f73e86b9125264d2959acabea132b444b88185ba5c462cb8e571/greenlet-3.2.2-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2593283bf81ca37d27d110956b79e8723f9aa50c4bcdc29d3c0543d4743d2763", size = 632620 }, + { url = 
"https://files.pythonhosted.org/packages/10/ec/718a3bd56249e729016b0b69bee4adea0dfccf6ca43d147ef3b21edbca16/greenlet-3.2.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89c69e9a10670eb7a66b8cef6354c24671ba241f46152dd3eed447f79c29fb5b", size = 628851 }, + { url = "https://files.pythonhosted.org/packages/9b/9d/d1c79286a76bc62ccdc1387291464af16a4204ea717f24e77b0acd623b99/greenlet-3.2.2-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:02a98600899ca1ca5d3a2590974c9e3ec259503b2d6ba6527605fcd74e08e207", size = 593718 }, + { url = "https://files.pythonhosted.org/packages/cd/41/96ba2bf948f67b245784cd294b84e3d17933597dffd3acdb367a210d1949/greenlet-3.2.2-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:b50a8c5c162469c3209e5ec92ee4f95c8231b11db6a04db09bbe338176723bb8", size = 1105752 }, + { url = "https://files.pythonhosted.org/packages/68/3b/3b97f9d33c1f2eb081759da62bd6162159db260f602f048bc2f36b4c453e/greenlet-3.2.2-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:45f9f4853fb4cc46783085261c9ec4706628f3b57de3e68bae03e8f8b3c0de51", size = 1125170 }, + { url = "https://files.pythonhosted.org/packages/31/df/b7d17d66c8d0f578d2885a3d8f565e9e4725eacc9d3fdc946d0031c055c4/greenlet-3.2.2-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:9ea5231428af34226c05f927e16fc7f6fa5e39e3ad3cd24ffa48ba53a47f4240", size = 269899 }, +] + [[package]] name = "grpcio" version = "1.71.0" @@ -1495,6 +1546,7 @@ unit = [ { name = "opentelemetry-exporter-otlp-proto-http" }, { name = "pypdf" }, { name = "qdrant-client" }, + { name = "sqlalchemy", extra = ["asyncio"] }, { name = "sqlite-vec" }, ] @@ -1564,6 +1616,8 @@ requires-dist = [ { name = "sphinxcontrib-openapi", marker = "extra == 'docs'" }, { name = "sphinxcontrib-redoc", marker = "extra == 'docs'" }, { name = "sphinxcontrib-video", marker = "extra == 'docs'" }, + { name = "sqlalchemy", marker = "extra == 'unit'" }, + { name = "sqlalchemy", extras = ["asyncio"], marker = "extra == 'unit'", specifier = ">=2.0.41" }, { name = "sqlite-vec", marker = "extra == 'unit'" }, { name = "streamlit", marker = "extra == 'ui'" }, { name = "streamlit-option-menu", marker = "extra == 'ui'" }, @@ -1577,7 +1631,6 @@ requires-dist = [ { name = "types-setuptools", marker = "extra == 'dev'" }, { name = "uvicorn", marker = "extra == 'dev'" }, ] -provides-extras = ["dev", "unit", "test", "docs", "codegen", "ui"] [[package]] name = "llama-stack-client" @@ -3748,6 +3801,56 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/8b/a0271fe65357860ccc52168181891e9fc9d354bfdc9be273e6a77b84f905/sphinxcontrib_video-0.4.1-py3-none-any.whl", hash = "sha256:d63ec68983dac36960557973281a616b5d9e68838369763313fc80533b1ad774", size = 10066 }, ] +[[package]] +name = "sqlalchemy" +version = "2.0.41" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "greenlet", marker = "(python_full_version < '3.14' and platform_machine == 'AMD64') or (python_full_version < '3.14' and platform_machine == 'WIN32') or (python_full_version < '3.14' and platform_machine == 'aarch64') or (python_full_version < '3.14' and platform_machine == 'amd64') or (python_full_version < '3.14' and platform_machine == 'ppc64le') or (python_full_version < '3.14' and platform_machine == 'win32') or (python_full_version < '3.14' and platform_machine == 'x86_64')" }, + { name = "typing-extensions" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/63/66/45b165c595ec89aa7dcc2c1cd222ab269bc753f1fc7a1e68f8481bd957bf/sqlalchemy-2.0.41.tar.gz", hash = "sha256:edba70118c4be3c2b1f90754d308d0b79c6fe2c0fdc52d8ddf603916f83f4db9", size = 9689424 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/12/d7c445b1940276a828efce7331cb0cb09d6e5f049651db22f4ebb0922b77/sqlalchemy-2.0.41-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b1f09b6821406ea1f94053f346f28f8215e293344209129a9c0fcc3578598d7b", size = 2117967 }, + { url = "https://files.pythonhosted.org/packages/6f/b8/cb90f23157e28946b27eb01ef401af80a1fab7553762e87df51507eaed61/sqlalchemy-2.0.41-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1936af879e3db023601196a1684d28e12f19ccf93af01bf3280a3262c4b6b4e5", size = 2107583 }, + { url = "https://files.pythonhosted.org/packages/9e/c2/eef84283a1c8164a207d898e063edf193d36a24fb6a5bb3ce0634b92a1e8/sqlalchemy-2.0.41-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2ac41acfc8d965fb0c464eb8f44995770239668956dc4cdf502d1b1ffe0d747", size = 3186025 }, + { url = "https://files.pythonhosted.org/packages/bd/72/49d52bd3c5e63a1d458fd6d289a1523a8015adedbddf2c07408ff556e772/sqlalchemy-2.0.41-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81c24e0c0fde47a9723c81d5806569cddef103aebbf79dbc9fcbb617153dea30", size = 3186259 }, + { url = "https://files.pythonhosted.org/packages/4f/9e/e3ffc37d29a3679a50b6bbbba94b115f90e565a2b4545abb17924b94c52d/sqlalchemy-2.0.41-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:23a8825495d8b195c4aa9ff1c430c28f2c821e8c5e2d98089228af887e5d7e29", size = 3126803 }, + { url = "https://files.pythonhosted.org/packages/8a/76/56b21e363f6039978ae0b72690237b38383e4657281285a09456f313dd77/sqlalchemy-2.0.41-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:60c578c45c949f909a4026b7807044e7e564adf793537fc762b2489d522f3d11", size = 3148566 }, + { url = "https://files.pythonhosted.org/packages/3b/92/11b8e1b69bf191bc69e300a99badbbb5f2f1102f2b08b39d9eee2e21f565/sqlalchemy-2.0.41-cp310-cp310-win32.whl", hash = "sha256:118c16cd3f1b00c76d69343e38602006c9cfb9998fa4f798606d28d63f23beda", size = 2086696 }, + { url = "https://files.pythonhosted.org/packages/5c/88/2d706c9cc4502654860f4576cd54f7db70487b66c3b619ba98e0be1a4642/sqlalchemy-2.0.41-cp310-cp310-win_amd64.whl", hash = "sha256:7492967c3386df69f80cf67efd665c0f667cee67032090fe01d7d74b0e19bb08", size = 2110200 }, + { url = "https://files.pythonhosted.org/packages/37/4e/b00e3ffae32b74b5180e15d2ab4040531ee1bef4c19755fe7926622dc958/sqlalchemy-2.0.41-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6375cd674fe82d7aa9816d1cb96ec592bac1726c11e0cafbf40eeee9a4516b5f", size = 2121232 }, + { url = "https://files.pythonhosted.org/packages/ef/30/6547ebb10875302074a37e1970a5dce7985240665778cfdee2323709f749/sqlalchemy-2.0.41-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9f8c9fdd15a55d9465e590a402f42082705d66b05afc3ffd2d2eb3c6ba919560", size = 2110897 }, + { url = "https://files.pythonhosted.org/packages/9e/21/59df2b41b0f6c62da55cd64798232d7349a9378befa7f1bb18cf1dfd510a/sqlalchemy-2.0.41-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32f9dc8c44acdee06c8fc6440db9eae8b4af8b01e4b1aee7bdd7241c22edff4f", size = 3273313 }, + { url = "https://files.pythonhosted.org/packages/62/e4/b9a7a0e5c6f79d49bcd6efb6e90d7536dc604dab64582a9dec220dab54b6/sqlalchemy-2.0.41-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:90c11ceb9a1f482c752a71f203a81858625d8df5746d787a4786bca4ffdf71c6", size = 3273807 }, + { url = "https://files.pythonhosted.org/packages/39/d8/79f2427251b44ddee18676c04eab038d043cff0e764d2d8bb08261d6135d/sqlalchemy-2.0.41-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:911cc493ebd60de5f285bcae0491a60b4f2a9f0f5c270edd1c4dbaef7a38fc04", size = 3209632 }, + { url = "https://files.pythonhosted.org/packages/d4/16/730a82dda30765f63e0454918c982fb7193f6b398b31d63c7c3bd3652ae5/sqlalchemy-2.0.41-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03968a349db483936c249f4d9cd14ff2c296adfa1290b660ba6516f973139582", size = 3233642 }, + { url = "https://files.pythonhosted.org/packages/04/61/c0d4607f7799efa8b8ea3c49b4621e861c8f5c41fd4b5b636c534fcb7d73/sqlalchemy-2.0.41-cp311-cp311-win32.whl", hash = "sha256:293cd444d82b18da48c9f71cd7005844dbbd06ca19be1ccf6779154439eec0b8", size = 2086475 }, + { url = "https://files.pythonhosted.org/packages/9d/8e/8344f8ae1cb6a479d0741c02cd4f666925b2bf02e2468ddaf5ce44111f30/sqlalchemy-2.0.41-cp311-cp311-win_amd64.whl", hash = "sha256:3d3549fc3e40667ec7199033a4e40a2f669898a00a7b18a931d3efb4c7900504", size = 2110903 }, + { url = "https://files.pythonhosted.org/packages/3e/2a/f1f4e068b371154740dd10fb81afb5240d5af4aa0087b88d8b308b5429c2/sqlalchemy-2.0.41-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:81f413674d85cfd0dfcd6512e10e0f33c19c21860342a4890c3a2b59479929f9", size = 2119645 }, + { url = "https://files.pythonhosted.org/packages/9b/e8/c664a7e73d36fbfc4730f8cf2bf930444ea87270f2825efbe17bf808b998/sqlalchemy-2.0.41-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:598d9ebc1e796431bbd068e41e4de4dc34312b7aa3292571bb3674a0cb415dd1", size = 2107399 }, + { url = "https://files.pythonhosted.org/packages/5c/78/8a9cf6c5e7135540cb682128d091d6afa1b9e48bd049b0d691bf54114f70/sqlalchemy-2.0.41-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a104c5694dfd2d864a6f91b0956eb5d5883234119cb40010115fd45a16da5e70", size = 3293269 }, + { url = "https://files.pythonhosted.org/packages/3c/35/f74add3978c20de6323fb11cb5162702670cc7a9420033befb43d8d5b7a4/sqlalchemy-2.0.41-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6145afea51ff0af7f2564a05fa95eb46f542919e6523729663a5d285ecb3cf5e", size = 3303364 }, + { url = "https://files.pythonhosted.org/packages/6a/d4/c990f37f52c3f7748ebe98883e2a0f7d038108c2c5a82468d1ff3eec50b7/sqlalchemy-2.0.41-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b46fa6eae1cd1c20e6e6f44e19984d438b6b2d8616d21d783d150df714f44078", size = 3229072 }, + { url = "https://files.pythonhosted.org/packages/15/69/cab11fecc7eb64bc561011be2bd03d065b762d87add52a4ca0aca2e12904/sqlalchemy-2.0.41-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41836fe661cc98abfae476e14ba1906220f92c4e528771a8a3ae6a151242d2ae", size = 3268074 }, + { url = "https://files.pythonhosted.org/packages/5c/ca/0c19ec16858585d37767b167fc9602593f98998a68a798450558239fb04a/sqlalchemy-2.0.41-cp312-cp312-win32.whl", hash = "sha256:a8808d5cf866c781150d36a3c8eb3adccfa41a8105d031bf27e92c251e3969d6", size = 2084514 }, + { url = "https://files.pythonhosted.org/packages/7f/23/4c2833d78ff3010a4e17f984c734f52b531a8c9060a50429c9d4b0211be6/sqlalchemy-2.0.41-cp312-cp312-win_amd64.whl", hash = "sha256:5b14e97886199c1f52c14629c11d90c11fbb09e9334fa7bb5f6d068d9ced0ce0", size = 2111557 }, + { url = 
"https://files.pythonhosted.org/packages/d3/ad/2e1c6d4f235a97eeef52d0200d8ddda16f6c4dd70ae5ad88c46963440480/sqlalchemy-2.0.41-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4eeb195cdedaf17aab6b247894ff2734dcead6c08f748e617bfe05bd5a218443", size = 2115491 }, + { url = "https://files.pythonhosted.org/packages/cf/8d/be490e5db8400dacc89056f78a52d44b04fbf75e8439569d5b879623a53b/sqlalchemy-2.0.41-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d4ae769b9c1c7757e4ccce94b0641bc203bbdf43ba7a2413ab2523d8d047d8dc", size = 2102827 }, + { url = "https://files.pythonhosted.org/packages/a0/72/c97ad430f0b0e78efaf2791342e13ffeafcbb3c06242f01a3bb8fe44f65d/sqlalchemy-2.0.41-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a62448526dd9ed3e3beedc93df9bb6b55a436ed1474db31a2af13b313a70a7e1", size = 3225224 }, + { url = "https://files.pythonhosted.org/packages/5e/51/5ba9ea3246ea068630acf35a6ba0d181e99f1af1afd17e159eac7e8bc2b8/sqlalchemy-2.0.41-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc56c9788617b8964ad02e8fcfeed4001c1f8ba91a9e1f31483c0dffb207002a", size = 3230045 }, + { url = "https://files.pythonhosted.org/packages/78/2f/8c14443b2acea700c62f9b4a8bad9e49fc1b65cfb260edead71fd38e9f19/sqlalchemy-2.0.41-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c153265408d18de4cc5ded1941dcd8315894572cddd3c58df5d5b5705b3fa28d", size = 3159357 }, + { url = "https://files.pythonhosted.org/packages/fc/b2/43eacbf6ccc5276d76cea18cb7c3d73e294d6fb21f9ff8b4eef9b42bbfd5/sqlalchemy-2.0.41-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f67766965996e63bb46cfbf2ce5355fc32d9dd3b8ad7e536a920ff9ee422e23", size = 3197511 }, + { url = "https://files.pythonhosted.org/packages/fa/2e/677c17c5d6a004c3c45334ab1dbe7b7deb834430b282b8a0f75ae220c8eb/sqlalchemy-2.0.41-cp313-cp313-win32.whl", hash = "sha256:bfc9064f6658a3d1cadeaa0ba07570b83ce6801a1314985bf98ec9b95d74e15f", size = 2082420 }, + { url = "https://files.pythonhosted.org/packages/e9/61/e8c1b9b6307c57157d328dd8b8348ddc4c47ffdf1279365a13b2b98b8049/sqlalchemy-2.0.41-cp313-cp313-win_amd64.whl", hash = "sha256:82ca366a844eb551daff9d2e6e7a9e5e76d2612c8564f58db6c19a726869c1df", size = 2108329 }, + { url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224 }, +] + +[package.optional-dependencies] +asyncio = [ + { name = "greenlet" }, +] + [[package]] name = "sqlite-vec" version = "0.1.6" From 8feb1827c880b5fe65d30d11ee1c4b75c1ecc0b2 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 22 May 2025 14:51:01 -0700 Subject: [PATCH 22/61] fix: openai provider model id (#2229) # What does this PR do? 
Since https://github.com/meta-llama/llama-stack/pull/2193 switched to openai sdk, we need to strip 'openai/' from the model_id ## Test Plan start server with openai provider and send a chat completion call --- .../providers/remote/inference/openai/openai.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 9a1ec7ee0..c3c25edd3 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -92,8 +92,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): if prompt_logprobs is not None: logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.") + model_id = (await self.model_store.get_model(model)).provider_resource_id + if model_id.startswith("openai/"): + model_id = model_id[len("openai/") :] params = await prepare_openai_completion_params( - model=(await self.model_store.get_model(model)).provider_resource_id, + model=model_id, prompt=prompt, best_of=best_of, echo=echo, @@ -139,8 +142,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): top_p: float | None = None, user: str | None = None, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + model_id = (await self.model_store.get_model(model)).provider_resource_id + if model_id.startswith("openai/"): + model_id = model_id[len("openai/") :] params = await prepare_openai_completion_params( - model=(await self.model_store.get_model(model)).provider_resource_id, + model=model_id, messages=messages, frequency_penalty=frequency_penalty, function_call=function_call, From d8c6ab9bfc97c7b20bba55345bab3ac7261b197f Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 22 May 2025 16:43:08 -0700 Subject: [PATCH 23/61] feat: add MCP tool signature to Responses API (#2232) --- docs/_static/llama-stack-spec.html | 110 +++++++++++++++++++- docs/_static/llama-stack-spec.yaml | 62 +++++++++++ llama_stack/apis/agents/openai_responses.py | 25 ++++- 3 files changed, 195 insertions(+), 2 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 33befc95e..cdbba5dd1 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -6742,6 +6742,9 @@ }, { "$ref": "#/components/schemas/OpenAIResponseInputToolFunction" + }, + { + "$ref": "#/components/schemas/OpenAIResponseInputToolMCP" } ], "discriminator": { @@ -6749,7 +6752,8 @@ "mapping": { "web_search": "#/components/schemas/OpenAIResponseInputToolWebSearch", "file_search": "#/components/schemas/OpenAIResponseInputToolFileSearch", - "function": "#/components/schemas/OpenAIResponseInputToolFunction" + "function": "#/components/schemas/OpenAIResponseInputToolFunction", + "mcp": "#/components/schemas/OpenAIResponseInputToolMCP" } } }, @@ -6839,6 +6843,110 @@ ], "title": "OpenAIResponseInputToolFunction" }, + "OpenAIResponseInputToolMCP": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "mcp", + "default": "mcp" + }, + "server_label": { + "type": "string" + }, + "server_url": { + "type": "string" + }, + "headers": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "require_approval": { + "oneOf": [ + { + "type": "string", + "const": "always" + }, + { + "type": 
"string", + "const": "never" + }, + { + "type": "object", + "properties": { + "always": { + "type": "array", + "items": { + "type": "string" + } + }, + "never": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false, + "title": "ApprovalFilter" + } + ], + "default": "never" + }, + "allowed_tools": { + "oneOf": [ + { + "type": "array", + "items": { + "type": "string" + } + }, + { + "type": "object", + "properties": { + "tool_names": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false, + "title": "AllowedToolsFilter" + } + ] + } + }, + "additionalProperties": false, + "required": [ + "type", + "server_label", + "server_url", + "require_approval" + ], + "title": "OpenAIResponseInputToolMCP" + }, "OpenAIResponseInputToolWebSearch": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index cae6331b0..1b368762f 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -4762,12 +4762,14 @@ components: - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' - $ref: '#/components/schemas/OpenAIResponseInputToolFileSearch' - $ref: '#/components/schemas/OpenAIResponseInputToolFunction' + - $ref: '#/components/schemas/OpenAIResponseInputToolMCP' discriminator: propertyName: type mapping: web_search: '#/components/schemas/OpenAIResponseInputToolWebSearch' file_search: '#/components/schemas/OpenAIResponseInputToolFileSearch' function: '#/components/schemas/OpenAIResponseInputToolFunction' + mcp: '#/components/schemas/OpenAIResponseInputToolMCP' OpenAIResponseInputToolFileSearch: type: object properties: @@ -4822,6 +4824,66 @@ components: - type - name title: OpenAIResponseInputToolFunction + OpenAIResponseInputToolMCP: + type: object + properties: + type: + type: string + const: mcp + default: mcp + server_label: + type: string + server_url: + type: string + headers: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + require_approval: + oneOf: + - type: string + const: always + - type: string + const: never + - type: object + properties: + always: + type: array + items: + type: string + never: + type: array + items: + type: string + additionalProperties: false + title: ApprovalFilter + default: never + allowed_tools: + oneOf: + - type: array + items: + type: string + - type: object + properties: + tool_names: + type: array + items: + type: string + additionalProperties: false + title: AllowedToolsFilter + additionalProperties: false + required: + - type + - server_label + - server_url + - require_approval + title: OpenAIResponseInputToolMCP OpenAIResponseInputToolWebSearch: type: object properties: diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index dcf0c7f9c..bb463bd57 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -186,8 +186,31 @@ class OpenAIResponseInputToolFileSearch(BaseModel): # TODO: add filters +class ApprovalFilter(BaseModel): + always: list[str] | None = None + never: list[str] | None = None + + +class AllowedToolsFilter(BaseModel): + tool_names: list[str] | None = None + + +@json_schema_type +class OpenAIResponseInputToolMCP(BaseModel): + type: Literal["mcp"] = "mcp" + server_label: str + server_url: str + headers: dict[str, Any] | None = None + + require_approval: 
Literal["always"] | Literal["never"] | ApprovalFilter = "never" + allowed_tools: list[str] | AllowedToolsFilter | None = None + + OpenAIResponseInputTool = Annotated[ - OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction, + OpenAIResponseInputToolWebSearch + | OpenAIResponseInputToolFileSearch + | OpenAIResponseInputToolFunction + | OpenAIResponseInputToolMCP, Field(discriminator="type"), ] register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool") From 2708312168e8182e4aa3ffb2ee8959a37458fb29 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 22 May 2025 22:05:54 -0700 Subject: [PATCH 24/61] feat(ui): implement chat completion views (#2201) # What does this PR do? Implements table and detail views for chat completions image image ## Test Plan npm run test --- llama_stack/distribution/server/server.py | 12 + llama_stack/ui/.prettierignore | 3 + llama_stack/ui/.prettierrc | 1 + llama_stack/ui/README.md | 3 +- llama_stack/ui/app/layout.tsx | 2 +- .../app/logs/chat-completions/[id]/page.tsx | 62 + .../ui/app/logs/chat-completions/layout.tsx | 45 + .../ui/app/logs/chat-completions/page.tsx | 55 +- .../chat-completion-detail.test.tsx | 193 + .../chat-completion-detail.tsx | 198 + .../chat-completion-table.test.tsx | 340 ++ .../chat-completion-table.tsx | 120 + .../chat-completions/chat-messasge-item.tsx | 107 + .../components/{ => layout}/app-sidebar.tsx | 41 +- .../ui/components/layout/page-breadcrumb.tsx | 49 + llama_stack/ui/components/ui/breadcrumb.tsx | 109 + llama_stack/ui/components/ui/card.tsx | 92 + llama_stack/ui/components/ui/table.tsx | 116 + llama_stack/ui/jest.config.ts | 210 + .../ui/lib/format-message-content.test.ts | 193 + llama_stack/ui/lib/format-message-content.ts | 61 + llama_stack/ui/lib/format-tool-call.tsx | 33 + llama_stack/ui/lib/truncate-text.ts | 8 + llama_stack/ui/lib/types.ts | 44 + llama_stack/ui/lib/{utils.ts => utils.tsx} | 0 llama_stack/ui/package-lock.json | 4655 ++++++++++++++++- llama_stack/ui/package.json | 15 +- 27 files changed, 6729 insertions(+), 38 deletions(-) create mode 100644 llama_stack/ui/.prettierignore create mode 100644 llama_stack/ui/.prettierrc create mode 100644 llama_stack/ui/app/logs/chat-completions/[id]/page.tsx create mode 100644 llama_stack/ui/app/logs/chat-completions/layout.tsx create mode 100644 llama_stack/ui/components/chat-completions/chat-completion-detail.test.tsx create mode 100644 llama_stack/ui/components/chat-completions/chat-completion-detail.tsx create mode 100644 llama_stack/ui/components/chat-completions/chat-completion-table.test.tsx create mode 100644 llama_stack/ui/components/chat-completions/chat-completion-table.tsx create mode 100644 llama_stack/ui/components/chat-completions/chat-messasge-item.tsx rename llama_stack/ui/components/{ => layout}/app-sidebar.tsx (50%) create mode 100644 llama_stack/ui/components/layout/page-breadcrumb.tsx create mode 100644 llama_stack/ui/components/ui/breadcrumb.tsx create mode 100644 llama_stack/ui/components/ui/card.tsx create mode 100644 llama_stack/ui/components/ui/table.tsx create mode 100644 llama_stack/ui/jest.config.ts create mode 100644 llama_stack/ui/lib/format-message-content.test.ts create mode 100644 llama_stack/ui/lib/format-message-content.ts create mode 100644 llama_stack/ui/lib/format-tool-call.tsx create mode 100644 llama_stack/ui/lib/truncate-text.ts create mode 100644 llama_stack/ui/lib/types.ts rename llama_stack/ui/lib/{utils.ts => utils.tsx} (100%) diff --git 
a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 52f2b71b0..7069390cf 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -23,6 +23,7 @@ import yaml from fastapi import Body, FastAPI, HTTPException, Request from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError @@ -465,6 +466,17 @@ def main(args: argparse.Namespace | None = None): window_seconds=window_seconds, ) + # --- CORS middleware for local development --- + # TODO: move to reverse proxy + ui_port = os.environ.get("LLAMA_STACK_UI_PORT", 8322) + app.add_middleware( + CORSMiddleware, + allow_origins=[f"http://localhost:{ui_port}"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + try: impls = asyncio.run(construct_stack(config)) except InvalidProviderError as e: diff --git a/llama_stack/ui/.prettierignore b/llama_stack/ui/.prettierignore new file mode 100644 index 000000000..1b8ac8894 --- /dev/null +++ b/llama_stack/ui/.prettierignore @@ -0,0 +1,3 @@ +# Ignore artifacts: +build +coverage diff --git a/llama_stack/ui/.prettierrc b/llama_stack/ui/.prettierrc new file mode 100644 index 000000000..0967ef424 --- /dev/null +++ b/llama_stack/ui/.prettierrc @@ -0,0 +1 @@ +{} diff --git a/llama_stack/ui/README.md b/llama_stack/ui/README.md index 665619bf1..b6f803509 100644 --- a/llama_stack/ui/README.md +++ b/llama_stack/ui/README.md @@ -1,6 +1,5 @@ ## This is WIP. - We use shadcdn/ui [Shadcn UI](https://ui.shadcn.com/) for the UI components. ## Getting Started @@ -23,4 +22,4 @@ pnpm dev bun dev ``` -Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. +Open [http://localhost:8322](http://localhost:8322) with your browser to see the result. 
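The CORS middleware added above only admits the UI's own origin (`http://localhost:${LLAMA_STACK_UI_PORT}`, 8322 by default), while the pages added below construct a `LlamaStackClient` from `NEXT_PUBLIC_LLAMA_STACK_BASE_URL` and drive both views from the chat-completions list/retrieve calls. A minimal sketch of that flow outside of React follows; the fallback base URL (`http://localhost:8321`) and the standalone `main()` wrapper are illustrative assumptions, not part of this diff.

```ts
// Sketch of the data flow used by the new chat-completions views.
// Assumption: the stack server listens on http://localhost:8321; in the real UI
// the value comes from NEXT_PUBLIC_LLAMA_STACK_BASE_URL (see the pages below).
// When called from the browser, the request origin must match the CORS
// allow-origin configured above (http://localhost:<LLAMA_STACK_UI_PORT>).
import LlamaStackClient from "llama-stack-client";

async function main() {
  const client = new LlamaStackClient({
    baseURL:
      process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL ?? "http://localhost:8321",
  });

  // Table view: list stored chat completions (backed by the new inference_store).
  const listResponse = await client.chat.completions.list();
  const completions = Array.isArray(listResponse)
    ? listResponse
    : (listResponse as any).data;

  // Detail view: fetch a single completion by id.
  if (completions?.length) {
    const detail = await client.chat.completions.retrieve(completions[0].id);
    console.log(detail.id, detail.model);
  }
}

main().catch(console.error);
```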
diff --git a/llama_stack/ui/app/layout.tsx b/llama_stack/ui/app/layout.tsx index f029002dd..ed8a6cd5d 100644 --- a/llama_stack/ui/app/layout.tsx +++ b/llama_stack/ui/app/layout.tsx @@ -20,7 +20,7 @@ export const metadata: Metadata = { }; import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; -import { AppSidebar } from "@/components/app-sidebar"; +import { AppSidebar } from "@/components/layout/app-sidebar"; export default function Layout({ children }: { children: React.ReactNode }) { return ( diff --git a/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx b/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx new file mode 100644 index 000000000..f7c2580da --- /dev/null +++ b/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx @@ -0,0 +1,62 @@ +"use client"; + +import { useEffect, useState } from "react"; +import { useParams } from "next/navigation"; +import LlamaStackClient from "llama-stack-client"; +import { ChatCompletion } from "@/lib/types"; +import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail"; + +export default function ChatCompletionDetailPage() { + const params = useParams(); + const id = params.id as string; + + const [completionDetail, setCompletionDetail] = + useState(null); + const [isLoading, setIsLoading] = useState(true); + const [error, setError] = useState(null); + + useEffect(() => { + if (!id) { + setError(new Error("Completion ID is missing.")); + setIsLoading(false); + return; + } + + const client = new LlamaStackClient({ + baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL, + }); + + const fetchCompletionDetail = async () => { + setIsLoading(true); + setError(null); + setCompletionDetail(null); + try { + const response = await client.chat.completions.retrieve(id); + setCompletionDetail(response as ChatCompletion); + } catch (err) { + console.error( + `Error fetching chat completion detail for ID ${id}:`, + err, + ); + setError( + err instanceof Error + ? err + : new Error("Failed to fetch completion detail"), + ); + } finally { + setIsLoading(false); + } + }; + + fetchCompletionDetail(); + }, [id]); + + return ( + + ); +} diff --git a/llama_stack/ui/app/logs/chat-completions/layout.tsx b/llama_stack/ui/app/logs/chat-completions/layout.tsx new file mode 100644 index 000000000..3dd8c1222 --- /dev/null +++ b/llama_stack/ui/app/logs/chat-completions/layout.tsx @@ -0,0 +1,45 @@ +"use client"; + +import React from "react"; +import { usePathname, useParams } from "next/navigation"; +import { + PageBreadcrumb, + BreadcrumbSegment, +} from "@/components/layout/page-breadcrumb"; +import { truncateText } from "@/lib/truncate-text"; + +export default function ChatCompletionsLayout({ + children, +}: { + children: React.ReactNode; +}) { + const pathname = usePathname(); + const params = useParams(); + + let segments: BreadcrumbSegment[] = []; + + // Default for /logs/chat-completions + if (pathname === "/logs/chat-completions") { + segments = [{ label: "Chat Completions" }]; + } + + // For /logs/chat-completions/[id] + const idParam = params?.id; + if (idParam && typeof idParam === "string") { + segments = [ + { label: "Chat Completions", href: "/logs/chat-completions" }, + { label: `Details (${truncateText(idParam, 20)})` }, + ]; + } + + return ( +
+ <> + {segments.length > 0 && ( + + )} + {children} + +
+ ); +} diff --git a/llama_stack/ui/app/logs/chat-completions/page.tsx b/llama_stack/ui/app/logs/chat-completions/page.tsx index 84cceb8b7..3de77a042 100644 --- a/llama_stack/ui/app/logs/chat-completions/page.tsx +++ b/llama_stack/ui/app/logs/chat-completions/page.tsx @@ -1,7 +1,54 @@ -export default function ChatCompletions() { +"use client"; + +import { useEffect, useState } from "react"; +import LlamaStackClient from "llama-stack-client"; +import { ChatCompletion } from "@/lib/types"; +import { ChatCompletionsTable } from "@/components/chat-completions/chat-completion-table"; + +export default function ChatCompletionsPage() { + const [completions, setCompletions] = useState([]); + const [isLoading, setIsLoading] = useState(true); + const [error, setError] = useState(null); + + useEffect(() => { + const client = new LlamaStackClient({ + baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL, + }); + const fetchCompletions = async () => { + setIsLoading(true); + setError(null); + try { + const response = await client.chat.completions.list(); + const data = Array.isArray(response) + ? response + : (response as any).data; + + if (Array.isArray(data)) { + setCompletions(data); + } else { + console.error("Unexpected response structure:", response); + setError(new Error("Unexpected response structure")); + setCompletions([]); + } + } catch (err) { + console.error("Error fetching chat completions:", err); + setError( + err instanceof Error ? err : new Error("Failed to fetch completions"), + ); + setCompletions([]); + } finally { + setIsLoading(false); + } + }; + + fetchCompletions(); + }, []); + return ( -
-

Under Construction

-
+ ); } diff --git a/llama_stack/ui/components/chat-completions/chat-completion-detail.test.tsx b/llama_stack/ui/components/chat-completions/chat-completion-detail.test.tsx new file mode 100644 index 000000000..33247ed26 --- /dev/null +++ b/llama_stack/ui/components/chat-completions/chat-completion-detail.test.tsx @@ -0,0 +1,193 @@ +import React from "react"; +import { render, screen } from "@testing-library/react"; +import "@testing-library/jest-dom"; +import { ChatCompletionDetailView } from "./chat-completion-detail"; +import { ChatCompletion } from "@/lib/types"; + +// Initial test file setup for ChatCompletionDetailView + +describe("ChatCompletionDetailView", () => { + test("renders skeleton UI when isLoading is true", () => { + const { container } = render( + , + ); + // Use the data-slot attribute for Skeletons + const skeletons = container.querySelectorAll('[data-slot="skeleton"]'); + expect(skeletons.length).toBeGreaterThan(0); + }); + + test("renders error message when error prop is provided", () => { + render( + , + ); + expect( + screen.getByText(/Error loading details for ID err-id: Network Error/), + ).toBeInTheDocument(); + }); + + test("renders default error message when error.message is empty", () => { + render( + , + ); + // Use regex to match the error message regardless of whitespace + expect( + screen.getByText(/Error loading details for ID\s*err-id\s*:/), + ).toBeInTheDocument(); + }); + + test("renders error message when error prop is an object without message", () => { + render( + , + ); + // Use regex to match the error message regardless of whitespace + expect( + screen.getByText(/Error loading details for ID\s*err-id\s*:/), + ).toBeInTheDocument(); + }); + + test("renders not found message when completion is null and not loading/error", () => { + render( + , + ); + expect( + screen.getByText("No details found for completion ID: notfound-id."), + ).toBeInTheDocument(); + }); + + test("renders input, output, and properties for valid completion", () => { + const mockCompletion: ChatCompletion = { + id: "comp_123", + object: "chat.completion", + created: 1710000000, + model: "llama-test-model", + choices: [ + { + index: 0, + message: { role: "assistant", content: "Test output" }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: "Test input" }], + }; + render( + , + ); + // Input + expect(screen.getByText("Input")).toBeInTheDocument(); + expect(screen.getByText("Test input")).toBeInTheDocument(); + // Output + expect(screen.getByText("Output")).toBeInTheDocument(); + expect(screen.getByText("Test output")).toBeInTheDocument(); + // Properties + expect(screen.getByText("Properties")).toBeInTheDocument(); + expect(screen.getByText("Created:")).toBeInTheDocument(); + expect( + screen.getByText(new Date(1710000000 * 1000).toLocaleString()), + ).toBeInTheDocument(); + expect(screen.getByText("ID:")).toBeInTheDocument(); + expect(screen.getByText("comp_123")).toBeInTheDocument(); + expect(screen.getByText("Model:")).toBeInTheDocument(); + expect(screen.getByText("llama-test-model")).toBeInTheDocument(); + expect(screen.getByText("Finish Reason:")).toBeInTheDocument(); + expect(screen.getByText("stop")).toBeInTheDocument(); + }); + + test("renders tool call in output and properties when present", () => { + const toolCall = { + function: { name: "search", arguments: '{"query":"llama"}' }, + }; + const mockCompletion: ChatCompletion = { + id: "comp_tool", + object: "chat.completion", + created: 1710001000, + model: "llama-tool-model", + choices: [ 
+ { + index: 0, + message: { + role: "assistant", + content: "Tool output", + tool_calls: [toolCall], + }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: "Tool input" }], + }; + render( + , + ); + // Output should include the tool call block (should be present twice: input and output) + const toolCallLabels = screen.getAllByText("Tool Call"); + expect(toolCallLabels.length).toBeGreaterThanOrEqual(1); // At least one, but could be two + // The tool call block should contain the formatted tool call string in both input and output + const toolCallBlocks = screen.getAllByText('search({"query":"llama"})'); + expect(toolCallBlocks.length).toBe(2); + // Properties should include the tool call name + expect(screen.getByText("Functions/Tools Called:")).toBeInTheDocument(); + expect(screen.getByText("search")).toBeInTheDocument(); + }); + + test("handles missing/empty fields gracefully", () => { + const mockCompletion: ChatCompletion = { + id: "comp_edge", + object: "chat.completion", + created: 1710002000, + model: "llama-edge-model", + choices: [], // No choices + input_messages: [], // No input messages + }; + render( + , + ); + // Input section should be present but empty + expect(screen.getByText("Input")).toBeInTheDocument(); + // Output section should show fallback message + expect( + screen.getByText("No message found in assistant's choice."), + ).toBeInTheDocument(); + // Properties should show N/A for finish reason + expect(screen.getByText("Finish Reason:")).toBeInTheDocument(); + expect(screen.getByText("N/A")).toBeInTheDocument(); + }); +}); diff --git a/llama_stack/ui/components/chat-completions/chat-completion-detail.tsx b/llama_stack/ui/components/chat-completions/chat-completion-detail.tsx new file mode 100644 index 000000000..e76418d1a --- /dev/null +++ b/llama_stack/ui/components/chat-completions/chat-completion-detail.tsx @@ -0,0 +1,198 @@ +"use client"; + +import { ChatMessage, ChatCompletion } from "@/lib/types"; +import { ChatMessageItem } from "@/components/chat-completions/chat-messasge-item"; +import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; +import { Skeleton } from "@/components/ui/skeleton"; + +function ChatCompletionDetailLoadingView() { + return ( + <> + {/* Title Skeleton */} +
+
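+        {/* Skeleton placeholders standing in for the Input and Output cards while loading */}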
+ {[...Array(2)].map((_, i) => ( + + + + + + + + + + + + + ))} +
+
+
+ {" "} + {/* Properties Title Skeleton */} + {[...Array(5)].map((_, i) => ( +
+ + +
+ ))} +
+
+
+ + ); +} + +interface ChatCompletionDetailViewProps { + completion: ChatCompletion | null; + isLoading: boolean; + error: Error | null; + id: string; +} + +export function ChatCompletionDetailView({ + completion, + isLoading, + error, + id, +}: ChatCompletionDetailViewProps) { + if (error) { + return ( + <> + {/* We still want a title for consistency on error pages */} +

Chat Completion Details

+

+ Error loading details for ID {id}: {error.message} +

+ + ); + } + + if (isLoading) { + return ; + } + + if (!completion) { + // This state means: not loading, no error, but no completion data + return ( + <> + {/* We still want a title for consistency on not-found pages */} +

Chat Completion Details

+

No details found for completion ID: {id}.

+ + ); + } + + // If no error, not loading, and completion exists, render the details: + return ( + <> +

Chat Completion Details

+
+
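+      {/* Input card lists input_messages; tool calls from the first choice are echoed here
+          when input_messages does not already contain an assistant message with tool calls.
+          Output card shows the assistant message from the first choice. */}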
+ + + Input + + + {completion.input_messages?.map((msg, index) => ( + + ))} + {completion.choices?.[0]?.message?.tool_calls && + !completion.input_messages?.some( + (im) => + im.role === "assistant" && + im.tool_calls && + im.tool_calls.length > 0, + ) && + completion.choices[0].message.tool_calls.map( + (toolCall: any, index: number) => { + const assistantToolCallMessage: ChatMessage = { + role: "assistant", + tool_calls: [toolCall], + content: "", // Ensure content is defined, even if empty + }; + return ( + + ); + }, + )} + + + + + + Output + + + {completion.choices?.[0]?.message ? ( + + ) : ( +

+ No message found in assistant's choice. +

+ )} +
+
+
+ +
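+      {/* Properties panel: creation time, ID, model, finish reason, and any functions/tools called */}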
+ + + Properties + + +
    +
  • + Created:{" "} + + {new Date(completion.created * 1000).toLocaleString()} + +
  • +
  • + ID:{" "} + + {completion.id} + +
  • +
  • + Model:{" "} + + {completion.model} + +
  • +
  • + Finish Reason:{" "} + + {completion.choices?.[0]?.finish_reason || "N/A"} + +
  • + {completion.choices?.[0]?.message?.tool_calls && + completion.choices[0].message.tool_calls.length > 0 && ( +
  • + Functions/Tools Called: +
      + {completion.choices[0].message.tool_calls.map( + (toolCall: any, index: number) => ( +
    • + + {toolCall.function?.name || "N/A"} + +
    • + ), + )} +
    +
  • + )} +
+
+
+
+
+ + ); +} diff --git a/llama_stack/ui/components/chat-completions/chat-completion-table.test.tsx b/llama_stack/ui/components/chat-completions/chat-completion-table.test.tsx new file mode 100644 index 000000000..e71ef3d43 --- /dev/null +++ b/llama_stack/ui/components/chat-completions/chat-completion-table.test.tsx @@ -0,0 +1,340 @@ +import React from "react"; +import { render, screen, fireEvent } from "@testing-library/react"; +import "@testing-library/jest-dom"; +import { ChatCompletionsTable } from "./chat-completion-table"; +import { ChatCompletion } from "@/lib/types"; // Assuming this path is correct + +// Mock next/navigation +const mockPush = jest.fn(); +jest.mock("next/navigation", () => ({ + useRouter: () => ({ + push: mockPush, + }), +})); + +// Mock helper functions +// These are hoisted, so their mocks are available throughout the file +jest.mock("@/lib/truncate-text"); +jest.mock("@/lib/format-tool-call"); + +// Import the mocked functions to set up default or specific implementations +import { truncateText as originalTruncateText } from "@/lib/truncate-text"; +import { formatToolCallToString as originalFormatToolCallToString } from "@/lib/format-tool-call"; + +// Cast to jest.Mock for typings +const truncateText = originalTruncateText as jest.Mock; +const formatToolCallToString = originalFormatToolCallToString as jest.Mock; + +describe("ChatCompletionsTable", () => { + const defaultProps = { + completions: [] as ChatCompletion[], + isLoading: false, + error: null, + }; + + beforeEach(() => { + // Reset all mocks before each test + mockPush.mockClear(); + truncateText.mockClear(); + formatToolCallToString.mockClear(); + + // Default pass-through implementation for tests not focusing on truncation/formatting + truncateText.mockImplementation((text: string | undefined) => text); + formatToolCallToString.mockImplementation((toolCall: any) => + toolCall && typeof toolCall === "object" && toolCall.name + ? 
`[DefaultToolCall:${toolCall.name}]` + : "[InvalidToolCall]", + ); + }); + + test("renders without crashing with default props", () => { + render(); + // Check for a unique element that should be present in the non-empty, non-loading, non-error state + // For now, as per Task 1, we will test the empty state message + expect(screen.getByText("No chat completions found.")).toBeInTheDocument(); + }); + + test("click on a row navigates to the correct URL", () => { + const { rerender } = render(); + + // Simulate a scenario where a completion exists and is clicked + const mockCompletion: ChatCompletion = { + id: "comp_123", + object: "chat.completion", + created: Math.floor(Date.now() / 1000), + model: "llama-test-model", + choices: [ + { + index: 0, + message: { role: "assistant", content: "Test output" }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: "Test input" }], + }; + + rerender( + , + ); + const row = screen.getByText("Test input").closest("tr"); + if (row) { + fireEvent.click(row); + expect(mockPush).toHaveBeenCalledWith("/logs/chat-completions/comp_123"); + } else { + throw new Error('Row with "Test input" not found for router mock test.'); + } + }); + + describe("Loading State", () => { + test("renders skeleton UI when isLoading is true", () => { + const { container } = render( + , + ); + + // The Skeleton component uses data-slot="skeleton" + const skeletonSelector = '[data-slot="skeleton"]'; + + // Check for skeleton in the table caption + const tableCaption = container.querySelector("caption"); + expect(tableCaption).toBeInTheDocument(); + if (tableCaption) { + const captionSkeleton = tableCaption.querySelector(skeletonSelector); + expect(captionSkeleton).toBeInTheDocument(); + } + + // Check for skeletons in the table body cells + const tableBody = container.querySelector("tbody"); + expect(tableBody).toBeInTheDocument(); + if (tableBody) { + const bodySkeletons = tableBody.querySelectorAll( + `td ${skeletonSelector}`, + ); + expect(bodySkeletons.length).toBeGreaterThan(0); // Ensure at least one skeleton cell exists + } + + // General check: ensure multiple skeleton elements are present in the table overall + const allSkeletonsInTable = container.querySelectorAll( + `table ${skeletonSelector}`, + ); + expect(allSkeletonsInTable.length).toBeGreaterThan(3); // e.g., caption + at least one row of 3 cells, or just a few + }); + }); + + describe("Error State", () => { + test("renders error message when error prop is provided", () => { + const errorMessage = "Network Error"; + render( + , + ); + expect( + screen.getByText(`Error fetching data: ${errorMessage}`), + ).toBeInTheDocument(); + }); + + test("renders default error message when error.message is not available", () => { + render( + , + ); // Error with empty message + expect( + screen.getByText("Error fetching data: An unknown error occurred"), + ).toBeInTheDocument(); + }); + + test("renders default error message when error prop is an object without message", () => { + render(); // Empty error object + expect( + screen.getByText("Error fetching data: An unknown error occurred"), + ).toBeInTheDocument(); + }); + }); + + describe("Empty State", () => { + test('renders "No chat completions found." 
and no table when completions array is empty', () => { + render( + , + ); + expect( + screen.getByText("No chat completions found."), + ).toBeInTheDocument(); + + // Ensure that the table structure is NOT rendered in the empty state + const table = screen.queryByRole("table"); + expect(table).not.toBeInTheDocument(); + }); + }); + + describe("Data Rendering", () => { + test("renders table caption, headers, and completion data correctly", () => { + const mockCompletions = [ + { + id: "comp_1", + object: "chat.completion", + created: 1710000000, // Fixed timestamp for test + model: "llama-test-model", + choices: [ + { + index: 0, + message: { role: "assistant", content: "Test output" }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: "Test input" }], + }, + { + id: "comp_2", + object: "chat.completion", + created: 1710001000, + model: "llama-another-model", + choices: [ + { + index: 0, + message: { role: "assistant", content: "Another output" }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: "Another input" }], + }, + ]; + + render( + , + ); + + // Table caption + expect( + screen.getByText("A list of your recent chat completions."), + ).toBeInTheDocument(); + + // Table headers + expect(screen.getByText("Input")).toBeInTheDocument(); + expect(screen.getByText("Output")).toBeInTheDocument(); + expect(screen.getByText("Model")).toBeInTheDocument(); + expect(screen.getByText("Created")).toBeInTheDocument(); + + // Data rows + expect(screen.getByText("Test input")).toBeInTheDocument(); + expect(screen.getByText("Test output")).toBeInTheDocument(); + expect(screen.getByText("llama-test-model")).toBeInTheDocument(); + expect( + screen.getByText(new Date(1710000000 * 1000).toLocaleString()), + ).toBeInTheDocument(); + + expect(screen.getByText("Another input")).toBeInTheDocument(); + expect(screen.getByText("Another output")).toBeInTheDocument(); + expect(screen.getByText("llama-another-model")).toBeInTheDocument(); + expect( + screen.getByText(new Date(1710001000 * 1000).toLocaleString()), + ).toBeInTheDocument(); + }); + }); + + describe("Text Truncation and Tool Call Formatting", () => { + test("truncates long input and output text", () => { + // Specific mock implementation for this test + truncateText.mockImplementation( + (text: string | undefined, maxLength?: number) => { + const defaultTestMaxLength = 10; + const effectiveMaxLength = maxLength ?? defaultTestMaxLength; + return typeof text === "string" && text.length > effectiveMaxLength + ? text.slice(0, effectiveMaxLength) + "..." 
+ : text; + }, + ); + + const longInput = + "This is a very long input message that should be truncated."; + const longOutput = + "This is a very long output message that should also be truncated."; + const mockCompletions = [ + { + id: "comp_trunc", + object: "chat.completion", + created: 1710002000, + model: "llama-trunc-model", + choices: [ + { + index: 0, + message: { role: "assistant", content: longOutput }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: longInput }], + }, + ]; + + render( + , + ); + + // The truncated text should be present for both input and output + const truncatedTexts = screen.getAllByText( + longInput.slice(0, 10) + "...", + ); + expect(truncatedTexts.length).toBe(2); // one for input, one for output + // Optionally, verify each one is in the document if getAllByText doesn't throw on not found + truncatedTexts.forEach((textElement) => + expect(textElement).toBeInTheDocument(), + ); + }); + + test("formats tool call output using formatToolCallToString", () => { + // Specific mock implementation for this test + formatToolCallToString.mockImplementation( + (toolCall: any) => `[TOOL:${toolCall.name}]`, + ); + // Ensure no truncation interferes for this specific test for clarity of tool call format + truncateText.mockImplementation((text: string | undefined) => text); + + const toolCall = { name: "search", args: { query: "llama" } }; + const mockCompletions = [ + { + id: "comp_tool", + object: "chat.completion", + created: 1710003000, + model: "llama-tool-model", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: "Tool output", // Content that will be prepended + tool_calls: [toolCall], + }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: "Tool input" }], + }, + ]; + + render( + , + ); + + // The component concatenates message.content and the formatted tool call + expect(screen.getByText("Tool output [TOOL:search]")).toBeInTheDocument(); + }); + }); +}); diff --git a/llama_stack/ui/components/chat-completions/chat-completion-table.tsx b/llama_stack/ui/components/chat-completions/chat-completion-table.tsx new file mode 100644 index 000000000..e11acf376 --- /dev/null +++ b/llama_stack/ui/components/chat-completions/chat-completion-table.tsx @@ -0,0 +1,120 @@ +"use client"; + +import { useRouter } from "next/navigation"; +import { ChatCompletion } from "@/lib/types"; +import { truncateText } from "@/lib/truncate-text"; +import { + extractTextFromContentPart, + extractDisplayableText, +} from "@/lib/format-message-content"; +import { + Table, + TableBody, + TableCaption, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui/table"; +import { Skeleton } from "@/components/ui/skeleton"; + +interface ChatCompletionsTableProps { + completions: ChatCompletion[]; + isLoading: boolean; + error: Error | null; +} + +export function ChatCompletionsTable({ + completions, + isLoading, + error, +}: ChatCompletionsTableProps) { + const router = useRouter(); + + const tableHeader = ( + + + Input + Output + Model + Created + + + ); + + if (isLoading) { + return ( + + + + + {tableHeader} + + {[...Array(3)].map((_, i) => ( + + + + + + + + + + + + + + + ))} + +
+ ); + } + + if (error) { + return ( +

Error fetching data: {error.message || "An unknown error occurred"}

+ ); + } + + if (completions.length === 0) { + return

No chat completions found.

; + } + + return ( + + A list of your recent chat completions. + {tableHeader} + + {completions.map((completion) => ( + + router.push(`/logs/chat-completions/${completion.id}`) + } + className="cursor-pointer hover:bg-muted/50" + > + + {truncateText( + extractTextFromContentPart( + completion.input_messages?.[0]?.content, + ), + )} + + + {(() => { + const message = completion.choices?.[0]?.message; + const outputText = extractDisplayableText(message); + return truncateText(outputText); + })()} + + {completion.model} + + {new Date(completion.created * 1000).toLocaleString()} + + + ))} + +
+ ); +} diff --git a/llama_stack/ui/components/chat-completions/chat-messasge-item.tsx b/llama_stack/ui/components/chat-completions/chat-messasge-item.tsx new file mode 100644 index 000000000..58a009aed --- /dev/null +++ b/llama_stack/ui/components/chat-completions/chat-messasge-item.tsx @@ -0,0 +1,107 @@ +"use client"; + +import { ChatMessage } from "@/lib/types"; +import React from "react"; +import { formatToolCallToString } from "@/lib/format-tool-call"; +import { extractTextFromContentPart } from "@/lib/format-message-content"; + +// Sub-component or helper for the common label + content structure +const MessageBlock: React.FC<{ + label: string; + labelDetail?: string; + content: React.ReactNode; +}> = ({ label, labelDetail, content }) => { + return ( +
+

+ {label} + {labelDetail && ( + + {labelDetail} + + )} +

+
{content}
+
+
+  );
+};
+
+interface ToolCallBlockProps {
+  children: React.ReactNode;
+  className?: string;
+}
+
+const ToolCallBlock = ({ children, className }: ToolCallBlockProps) => {
+  // Shared styling for function-call argument blocks and tool output blocks;
+  // the slate-50 background suits code-like content.
+  const baseClassName =
+    "p-3 bg-slate-50 border border-slate-200 rounded-md text-sm";
+
+  return (
+    
+
{children}
+
+ ); +}; + +interface ChatMessageItemProps { + message: ChatMessage; +} +export function ChatMessageItem({ message }: ChatMessageItemProps) { + switch (message.role) { + case "system": + return ( + + ); + case "user": + return ( + + ); + + case "assistant": + if (message.tool_calls && message.tool_calls.length > 0) { + return ( + <> + {message.tool_calls.map((toolCall: any, index: number) => { + const formattedToolCall = formatToolCallToString(toolCall); + const toolCallContent = ( + + {formattedToolCall || "Error: Could not display tool call"} + + ); + return ( + + ); + })} + + ); + } else { + return ( + + ); + } + case "tool": + const toolOutputContent = ( + + {extractTextFromContentPart(message.content)} + + ); + return ( + + ); + } + return null; +} diff --git a/llama_stack/ui/components/app-sidebar.tsx b/llama_stack/ui/components/layout/app-sidebar.tsx similarity index 50% rename from llama_stack/ui/components/app-sidebar.tsx rename to llama_stack/ui/components/layout/app-sidebar.tsx index 3d541856f..1c53d6cc5 100644 --- a/llama_stack/ui/components/app-sidebar.tsx +++ b/llama_stack/ui/components/layout/app-sidebar.tsx @@ -1,5 +1,9 @@ +"use client"; + import { MessageSquareText, MessagesSquare, MoveUpRight } from "lucide-react"; import Link from "next/link"; +import { usePathname } from "next/navigation"; +import { cn } from "@/lib/utils"; import { Sidebar, @@ -32,6 +36,8 @@ const logItems = [ ]; export function AppSidebar() { + const pathname = usePathname(); + return ( @@ -42,16 +48,31 @@ export function AppSidebar() { Logs - {logItems.map((item) => ( - - - - - {item.title} - - - - ))} + {logItems.map((item) => { + const isActive = pathname.startsWith(item.url); + return ( + + + + + {item.title} + + + + ); + })} diff --git a/llama_stack/ui/components/layout/page-breadcrumb.tsx b/llama_stack/ui/components/layout/page-breadcrumb.tsx new file mode 100644 index 000000000..fdb561d68 --- /dev/null +++ b/llama_stack/ui/components/layout/page-breadcrumb.tsx @@ -0,0 +1,49 @@ +"use client"; + +import Link from "next/link"; +import React from "react"; +import { + Breadcrumb, + BreadcrumbItem, + BreadcrumbLink, + BreadcrumbList, + BreadcrumbPage, + BreadcrumbSeparator, +} from "@/components/ui/breadcrumb"; + +export interface BreadcrumbSegment { + label: string; + href?: string; +} + +interface PageBreadcrumbProps { + segments: BreadcrumbSegment[]; + className?: string; +} + +export function PageBreadcrumb({ segments, className }: PageBreadcrumbProps) { + if (!segments || segments.length === 0) { + return null; + } + + return ( + + + {segments.map((segment, index) => ( + + + {segment.href ? ( + + {segment.label} + + ) : ( + {segment.label} + )} + + {index < segments.length - 1 && } + + ))} + + + ); +} diff --git a/llama_stack/ui/components/ui/breadcrumb.tsx b/llama_stack/ui/components/ui/breadcrumb.tsx new file mode 100644 index 000000000..f63ae19af --- /dev/null +++ b/llama_stack/ui/components/ui/breadcrumb.tsx @@ -0,0 +1,109 @@ +import * as React from "react"; +import { Slot } from "@radix-ui/react-slot"; +import { ChevronRight, MoreHorizontal } from "lucide-react"; + +import { cn } from "@/lib/utils"; + +function Breadcrumb({ ...props }: React.ComponentProps<"nav">) { + return