From 825ce39879f55acf77681e1a38b1e32366884c4b Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Tue, 22 Apr 2025 11:47:53 -0400 Subject: [PATCH 01/36] fix: Together provider shutdown and default to non-streaming (#2001) # What does this PR do? The together inference provider was throwing a stack trace every time it shut down, as it was trying to call a non-existent `close` method on the AsyncTogether client. While fixing that, I also adjusted its shutdown logic to close the OpenAI client if we've created one of those, as that client does have a `close` method. In testing that, I also realized we were defaulting to treating all requests as streaming requests instead of defaulting to non-streaming. So, this flips that default to non-streaming to match how the other providers work. ## Test Plan I tested this by ensuring the together inference provider no longer spits out a long stack trace when shutting it down and by running the OpenAI API chat completion verification suite to ensure the change in default streaming logic didn't mess anything else up. Signed-off-by: Ben Browning --- .../providers/remote/inference/together/together.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 001e6aac4..48e41f5b0 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -76,8 +76,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi async def shutdown(self) -> None: if self._client: - await self._client.close() + # Together client has no close method, so just set to None self._client = None + if self._openai_client: + await self._openai_client.close() + self._openai_client = None async def completion( self, @@ -359,7 +362,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi top_p=top_p, user=user, ) - if params.get("stream", True): + if params.get("stream", False): return self._stream_openai_chat_completion(params) return await self._get_openai_client().chat.completions.create(**params) # type: ignore From d6e88e0bc67bfdb16186b4e2e896283fd2930986 Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Wed, 23 Apr 2025 03:44:18 -0400 Subject: [PATCH 02/36] docs: add RamaLama to list of known external providers (#2004) The RamaLama project now has an external provider offering for Llama Stack: https://github.com/containers/llama-stack-provider-ramalama See also: https://github.com/meta-llama/llama-stack/pull/1676 Signed-off-by: Nathan Weinberg --- docs/source/providers/external.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/providers/external.md b/docs/source/providers/external.md index 90fc77979..345b6e71d 100644 --- a/docs/source/providers/external.md +++ b/docs/source/providers/external.md @@ -53,6 +53,7 @@ Here's a list of known external providers that you can use with Llama Stack: | Type | Name | Description | Repository | |------|------|-------------|------------| | Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) | +| Remote | RamaLama | Inference models with RamaLama | [llama-stack-provider-ramalama](https://github.com/containers/llama-stack-provider-ramalama) | ### Remote Provider Specification From d39462d073dee76ce1e568e49452922b3af2a205 Mon Sep 17 00:00:00 2001 From: Ilya Kolchinsky <58424190+ilya-kolchinsky@users.noreply.github.com> Date: Wed, 23 Apr 2025 15:32:12 +0200 Subject: [PATCH 03/36] feat: Hide tool output under an expander in Playground UI (#2003) # What does this PR do? Now, tool outputs and retrieved chunks from the vector DB (i.e., everything except for the actual model reply) are hidden under an expander form when presented to the user. # Test Plan Navigate to the RAG page in the Playground UI. --- .../distribution/ui/page/playground/rag.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py index 392c9afe2..696d89bc2 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -24,6 +24,13 @@ def rag_chat_page(): def should_disable_input(): return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0 + def log_message(message): + with st.chat_message(message["role"]): + if "tool_output" in message and message["tool_output"]: + with st.expander(label="Tool Output", expanded=False, icon="πŸ› "): + st.write(message["tool_output"]) + st.markdown(message["content"]) + with st.sidebar: # File/Directory Upload Section st.subheader("Upload Documents", divider=True) @@ -146,8 +153,7 @@ def rag_chat_page(): # Display chat history for message in st.session_state.displayed_messages: - with st.chat_message(message["role"]): - st.markdown(message["content"]) + log_message(message) if temperature > 0.0: strategy = { @@ -201,7 +207,7 @@ def rag_chat_page(): # Display assistant response with st.chat_message("assistant"): - retrieval_message_placeholder = st.empty() + retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="πŸ› ") message_placeholder = st.empty() full_response = "" retrieval_response = "" @@ -209,14 +215,16 @@ def rag_chat_page(): log.print() if log.role == "tool_execution": retrieval_response += log.content.replace("====", "").strip() - retrieval_message_placeholder.info(retrieval_response) + retrieval_message_placeholder.write(retrieval_response) else: full_response += log.content message_placeholder.markdown(full_response + "β–Œ") message_placeholder.markdown(full_response) st.session_state.messages.append({"role": "assistant", "content": full_response}) - st.session_state.displayed_messages.append({"role": "assistant", "content": full_response}) + st.session_state.displayed_messages.append( + {"role": "assistant", "content": full_response, "tool_output": retrieval_response} + ) def direct_process_prompt(prompt): # Add the system prompt in the beginning of the conversation @@ -230,15 +238,14 @@ def rag_chat_page(): prompt_context = rag_response.content with st.chat_message("assistant"): + with st.expander(label="Retrieval Output", expanded=False): + st.write(prompt_context) + retrieval_message_placeholder = st.empty() message_placeholder = st.empty() full_response = "" retrieval_response = "" - # Display the retrieved content - retrieval_response += str(prompt_context) - retrieval_message_placeholder.info(retrieval_response) - # Construct the extended prompt extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}" From deee355952594d230b8ed060a69eaf5d8d45a194 Mon Sep 17 00:00:00 2001 From: Ilya Kolchinsky <58424190+ilya-kolchinsky@users.noreply.github.com> Date: Wed, 23 Apr 2025 15:33:19 +0200 Subject: [PATCH 04/36] fix: Added lazy initialization of the remote vLLM client to avoid issues with expired asyncio event loop (#1969) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Closes #1968. The asynchronous client in `VLLMInferenceAdapter` is now initialized directly before first use and not in `VLLMInferenceAdapter.initialize`. This prevents issues arising due to accessing an expired event loop from a completed `asyncio.run`. ## Test Plan Ran unit tests, including `test_remote_vllm.py`. Ran the code snippet mentioned in #1968. --------- Co-authored-by: SΓ©bastien Han --- .../providers/remote/inference/vllm/vllm.py | 33 ++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index d141afa86..8cfef2ee0 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -231,12 +231,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): self.client = None async def initialize(self) -> None: - log.info(f"Initializing VLLM client with base_url={self.config.url}") - self.client = AsyncOpenAI( - base_url=self.config.url, - api_key=self.config.api_token, - http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False), - ) + pass async def shutdown(self) -> None: pass @@ -249,6 +244,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): raise ValueError("Model store not set") return await self.model_store.get_model(model_id) + def _lazy_initialize_client(self): + if self.client is not None: + return + + log.info(f"Initializing vLLM client with base_url={self.config.url}") + self.client = self._create_client() + + def _create_client(self): + return AsyncOpenAI( + base_url=self.config.url, + api_key=self.config.api_token, + http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False), + ) + async def completion( self, model_id: str, @@ -258,6 +267,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]: + self._lazy_initialize_client() if sampling_params is None: sampling_params = SamplingParams() model = await self._get_model(model_id) @@ -287,6 +297,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]: + self._lazy_initialize_client() if sampling_params is None: sampling_params = SamplingParams() model = await self._get_model(model_id) @@ -357,9 +368,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): yield chunk async def register_model(self, model: Model) -> Model: - assert self.client is not None + # register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet. + # self.client should only be created after the initialization is complete to avoid asyncio cross-context errors. + # Changing this may lead to unpredictable behavior. + client = self._create_client() if self.client is None else self.client model = await self.register_helper.register_model(model) - res = await self.client.models.list() + res = await client.models.list() available_models = [m.id async for m in res] if model.provider_resource_id not in available_models: raise ValueError( @@ -410,6 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): output_dimension: Optional[int] = None, task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: + self._lazy_initialize_client() assert self.client is not None model = await self._get_model(model_id) @@ -449,6 +464,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): guided_choice: Optional[List[str]] = None, prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: + self._lazy_initialize_client() model_obj = await self._get_model(model) extra_body: Dict[str, Any] = {} @@ -505,6 +521,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): top_p: Optional[float] = None, user: Optional[str] = None, ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + self._lazy_initialize_client() model_obj = await self._get_model(model) params = await prepare_openai_completion_params( model=model_obj.provider_resource_id, From e0fa67c81c7bfc00d366acfe6c8447cbcfbdd747 Mon Sep 17 00:00:00 2001 From: Kevin Postlethwait Date: Wed, 23 Apr 2025 09:39:18 -0400 Subject: [PATCH 05/36] docs: add examples for how to define RAG docs (#1981) # What does this PR do? Add examples for how to define RAGDocuments. Not sure if this is the best place for these docs. @raghotham Please advise ## Test Plan None, documentation [//]: # (## Documentation) Signed-off-by: Kevin --- docs/source/building_applications/rag.md | 38 +++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index 39d1ba333..db6303209 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -68,7 +68,8 @@ chunks_response = client.vector_io.query( ### Using the RAG Tool A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. -and automatically chunks them into smaller pieces. +and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the +[appendix](#more-ragdocument-examples). ```python from llama_stack_client import RAGDocument @@ -178,3 +179,38 @@ for vector_db_id in client.vector_dbs.list(): print(f"Unregistering vector database: {vector_db_id.identifier}") client.vector_dbs.unregister(vector_db_id=vector_db_id.identifier) ``` + +### Appendix + +#### More RAGDocument Examples +```python +from llama_stack_client import RAGDocument +import base64 + +RAGDocument(document_id="num-0", content={"uri": "file://path/to/file"}) +RAGDocument(document_id="num-1", content="plain text") +RAGDocument( + document_id="num-2", + content={ + "type": "text", + "text": "plain text input", + }, # for inputs that should be treated as text explicitly +) +RAGDocument( + document_id="num-3", + content={ + "type": "image", + "image": {"url": {"uri": "https://mywebsite.com/image.jpg"}}, + }, +) +B64_ENCODED_IMAGE = base64.b64encode( + requests.get( + "https://raw.githubusercontent.com/meta-llama/llama-stack/refs/heads/main/docs/_static/llama-stack.png" + ).content +) +RAGDocuemnt( + document_id="num-4", + content={"type": "image", "image": {"data": B64_ENCODED_IMAGE}}, +) +``` +for more strongly typed interaction use the typed dicts found [here](https://github.com/meta-llama/llama-stack-client-python/blob/38cd91c9e396f2be0bec1ee96a19771582ba6f17/src/llama_stack_client/types/shared_params/document.py). From dc46725f56d6a404f24793c1f7242c6fcdea8e5b Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 23 Apr 2025 09:44:28 -0400 Subject: [PATCH 06/36] fix: properly handle streaming client disconnects (#2000) # What does this PR do? Previously, when a streaming client would disconnect before we were finished streaming the entire response, an error like the below would get raised from the `sse_generator` function in `llama_stack/distribution/server/server.py`: ``` AttributeError: 'coroutine' object has no attribute 'aclose'. Did you mean: 'close'? ``` This was because we were calling `aclose` on a coroutine instead of the awaited value from that coroutine. This change fixes that, so that we save off the awaited value and then can call `aclose` on it if we encounter an `asyncio.CancelledError`, like we see when a client disconnects before we're finished streaming. The other changes in here are to add a simple set of tests for the happy path of our SSE streaming and this client disconnect path. That unfortunately requires adding one more dependency into our unit test section of pyproject.toml since `server.py` requires loading some of the telemetry code for me to test this functionality. ## Test Plan I wrote the tests in `tests/unit/server/test_sse.py` first, verified the client disconnected test failed before my change, and that it passed afterwards. ``` python -m pytest -s -v tests/unit/server/test_sse.py ``` Signed-off-by: Ben Browning --- llama_stack/distribution/server/server.py | 5 ++- pyproject.toml | 11 ++++- tests/unit/server/test_sse.py | 55 +++++++++++++++++++++++ uv.lock | 2 + 4 files changed, 70 insertions(+), 3 deletions(-) create mode 100644 tests/unit/server/test_sse.py diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 6c5e2506c..50cf44ec9 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -162,9 +162,10 @@ async def maybe_await(value): return value -async def sse_generator(event_gen): +async def sse_generator(event_gen_coroutine): + event_gen = await event_gen_coroutine try: - async for item in await event_gen: + async for item in event_gen: yield create_sse_event(item) await asyncio.sleep(0.01) except asyncio.CancelledError: diff --git a/pyproject.toml b/pyproject.toml index 47d845c30..209367c4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,16 @@ dev = [ "ruamel.yaml", # needed for openapi generator ] # These are the dependencies required for running unit tests. -unit = ["sqlite-vec", "openai", "aiosqlite", "aiohttp", "pypdf", "chardet", "qdrant-client"] +unit = [ + "sqlite-vec", + "openai", + "aiosqlite", + "aiohttp", + "pypdf", + "chardet", + "qdrant-client", + "opentelemetry-exporter-otlp-proto-http" +] # These are the core dependencies required for running integration tests. They are shared across all # providers. If a provider requires additional dependencies, please add them to your environment # separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra diff --git a/tests/unit/server/test_sse.py b/tests/unit/server/test_sse.py new file mode 100644 index 000000000..4a76bdc9b --- /dev/null +++ b/tests/unit/server/test_sse.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio + +import pytest + +from llama_stack.distribution.server.server import create_sse_event, sse_generator + + +@pytest.mark.asyncio +async def test_sse_generator_basic(): + # An AsyncIterator wrapped in an Awaitable, just like our web methods + async def async_event_gen(): + async def event_gen(): + yield "Test event 1" + yield "Test event 2" + + return event_gen() + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + # Test that the events are streamed correctly + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + assert len(seen_events) == 2 + assert seen_events[0] == create_sse_event("Test event 1") + assert seen_events[1] == create_sse_event("Test event 2") + + +@pytest.mark.asyncio +async def test_sse_generator_client_disconnected(): + # An AsyncIterator wrapped in an Awaitable, just like our web methods + async def async_event_gen(): + async def event_gen(): + yield "Test event 1" + # Simulate a client disconnect before emitting event 2 + raise asyncio.CancelledError() + + return event_gen() + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + # Start reading the events, ensuring this doesn't raise an exception + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + assert len(seen_events) == 1 + assert seen_events[0] == create_sse_event("Test event 1") diff --git a/uv.lock b/uv.lock index cd82a016c..e6368f131 100644 --- a/uv.lock +++ b/uv.lock @@ -1458,6 +1458,7 @@ unit = [ { name = "aiosqlite" }, { name = "chardet" }, { name = "openai" }, + { name = "opentelemetry-exporter-otlp-proto-http" }, { name = "pypdf" }, { name = "qdrant-client" }, { name = "sqlite-vec" }, @@ -1491,6 +1492,7 @@ requires-dist = [ { name = "openai", marker = "extra == 'test'" }, { name = "openai", marker = "extra == 'unit'" }, { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" }, + { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'unit'" }, { name = "opentelemetry-sdk", marker = "extra == 'test'" }, { name = "pandas", marker = "extra == 'ui'" }, { name = "pillow" }, From 64f747fe095570923a331cf29cb6b92d5588512a Mon Sep 17 00:00:00 2001 From: Michael Clifford Date: Wed, 23 Apr 2025 09:57:54 -0400 Subject: [PATCH 07/36] feat: add tool name to chat output in playground (#1996) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This PR adds the name of the tool that is used by the agent on the "tools" page of the playground. See image below for an example. ![Screenshot 2025-04-18 at 3 14 18β€―PM](https://github.com/user-attachments/assets/04e97783-4003-4121-9446-9e0ad7209256) ## Test Plan Run the playground and navigate to the tools page. There users can see that this additional text is present when tools are invoked and absent when they are not. ``` streamlit run llama_stack/distribution/ui/app.py ``` Signed-off-by: Michael Clifford --- llama_stack/distribution/ui/page/playground/tools.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/llama_stack/distribution/ui/page/playground/tools.py b/llama_stack/distribution/ui/page/playground/tools.py index c5bb2216a..96c6a1783 100644 --- a/llama_stack/distribution/ui/page/playground/tools.py +++ b/llama_stack/distribution/ui/page/playground/tools.py @@ -144,7 +144,11 @@ def tool_chat_page(): yield response.event.payload.delta.text if response.event.payload.event_type == "step_complete": if response.event.payload.step_details.step_type == "tool_execution": - yield " πŸ›  " + if response.event.payload.step_details.tool_calls: + tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name) + yield f'\n\nπŸ›  :grey[_Using "{tool_name}" tool:_]\n\n' + else: + yield "No tool_calls present in step_details" else: yield f"Error occurred in the Llama Stack Cluster: {response}" From 6a44e7ba20d1106ee49066e270023250bafcc3cb Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Wed, 23 Apr 2025 09:58:10 -0400 Subject: [PATCH 08/36] docs: add API to external providers table (#2006) Also does a minor reorg of the columns Signed-off-by: Nathan Weinberg --- docs/source/providers/external.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/providers/external.md b/docs/source/providers/external.md index 345b6e71d..4935b1fe6 100644 --- a/docs/source/providers/external.md +++ b/docs/source/providers/external.md @@ -50,10 +50,10 @@ Llama Stack supports two types of external providers: Here's a list of known external providers that you can use with Llama Stack: -| Type | Name | Description | Repository | -|------|------|-------------|------------| -| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) | -| Remote | RamaLama | Inference models with RamaLama | [llama-stack-provider-ramalama](https://github.com/containers/llama-stack-provider-ramalama) | +| Name | Description | API | Type | Repository | +|------|-------------|-----|------|------------| +| KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) | +| RamaLama | Inference models with RamaLama | Inference | Remote | [llama-stack-provider-ramalama](https://github.com/containers/llama-stack-provider-ramalama) | ### Remote Provider Specification From fa5dfee07b251b1fcb85e7d42377aee29e268cd9 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 23 Apr 2025 11:48:32 -0400 Subject: [PATCH 09/36] fix: Return HTTP 400 for OpenAI API validation errors (#2002) # What does this PR do? When clients called the Open AI API with invalid input that wasn't caught by our own Pydantic API validation but instead only caught by the backend inference provider, that backend inference provider was returning a HTTP 400 error. However, we were wrapping that into a HTTP 500 error, obfuscating the actual issue from calling clients and triggering OpenAI client retry logic. This change adjusts our existing `translate_exception` method in `server.py` to wrap `openai.BadRequestError` as HTTP 400 errors, passing through the string representation of the error message to the calling user so they can see the actual input validation error and correct it. I tried changing this in a few other places, but ultimately `translate_exception` was the only real place to handle this for both streaming and non-streaming requests across all inference providers that use the OpenAI server APIs. This also tightens up our validation a bit for the OpenAI chat completions API, to catch empty `messages` parameters, invalid `tool_choice` parameters, invalid `tools` items, or passing `tool_choice` when `tools` isn't given. Lastly, this extends our OpenAI API chat completions verifications to also check for consistent input validation across providers. Providers behind Llama Stack should automatically pass all the new tests due to the input validation added here, but some of the providers fail this test when not run behind Llama Stack due to differences in how they handle input validation and errors. (Closes #1951) ## Test Plan To test this, start an OpenAI API verification stack: ``` llama stack run --image-type venv tests/verifications/openai-api-verification-run.yaml ``` Then, run the new verification tests with your provider(s) of choice: ``` python -m pytest -s -v \ tests/verifications/openai_api/test_chat_completion.py \ --provider openai-llama-stack python -m pytest -s -v \ tests/verifications/openai_api/test_chat_completion.py \ --provider together-llama-stack ``` Signed-off-by: Ben Browning --- llama_stack/distribution/routers/routers.py | 17 ++++++- llama_stack/distribution/server/server.py | 3 ++ .../fixtures/test_cases/chat_completion.yaml | 46 +++++++++++++++++++ .../openai_api/test_chat_completion.py | 45 ++++++++++++++++++ 4 files changed, 110 insertions(+), 1 deletion(-) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 17aecdaf8..d88df00bd 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -8,6 +8,11 @@ import asyncio import time from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union +from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam +from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam +from pydantic import Field, TypeAdapter +from typing_extensions import Annotated + from llama_stack.apis.common.content_types import ( URL, InterleavedContent, @@ -526,7 +531,7 @@ class InferenceRouter(Inference): async def openai_chat_completion( self, model: str, - messages: List[OpenAIMessageParam], + messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)], frequency_penalty: Optional[float] = None, function_call: Optional[Union[str, Dict[str, Any]]] = None, functions: Optional[List[Dict[str, Any]]] = None, @@ -558,6 +563,16 @@ class InferenceRouter(Inference): if model_obj.model_type == ModelType.embedding: raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions") + # Use the OpenAI client for a bit of extra input validation without + # exposing the OpenAI client itself as part of our API surface + if tool_choice: + TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice) + if tools is None: + raise ValueError("'tool_choice' is only allowed when 'tools' is also provided") + if tools: + for tool in tools: + TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool) + params = dict( model=model_obj.identifier, messages=messages, diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 50cf44ec9..2942920d4 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -22,6 +22,7 @@ from fastapi import Body, FastAPI, HTTPException, Request from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse +from openai import BadRequestError from pydantic import BaseModel, ValidationError from typing_extensions import Annotated @@ -110,6 +111,8 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio ) elif isinstance(exc, ValueError): return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}") + elif isinstance(exc, BadRequestError): + return HTTPException(status_code=400, detail=str(exc)) elif isinstance(exc, PermissionError): return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}") elif isinstance(exc, TimeoutError): diff --git a/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml b/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml index 1ace76e34..0c9f1fe9e 100644 --- a/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml +++ b/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml @@ -15,6 +15,52 @@ test_chat_basic: S? role: user output: Saturn +test_chat_input_validation: + test_name: test_chat_input_validation + test_params: + case: + - case_id: "messages_missing" + input: + messages: [] + output: + error: + status_code: 400 + - case_id: "messages_role_invalid" + input: + messages: + - content: Which planet do humans live on? + role: fake_role + output: + error: + status_code: 400 + - case_id: "tool_choice_invalid" + input: + messages: + - content: Which planet do humans live on? + role: user + tool_choice: invalid + output: + error: + status_code: 400 + - case_id: "tool_choice_no_tools" + input: + messages: + - content: Which planet do humans live on? + role: user + tool_choice: required + output: + error: + status_code: 400 + - case_id: "tools_type_invalid" + input: + messages: + - content: Which planet do humans live on? + role: user + tools: + - type: invalid + output: + error: + status_code: 400 test_chat_image: test_name: test_chat_image test_params: diff --git a/tests/verifications/openai_api/test_chat_completion.py b/tests/verifications/openai_api/test_chat_completion.py index 3a311667a..277eaafa3 100644 --- a/tests/verifications/openai_api/test_chat_completion.py +++ b/tests/verifications/openai_api/test_chat_completion.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any import pytest +from openai import APIError from pydantic import BaseModel from tests.verifications.openai_api.fixtures.fixtures import ( @@ -136,6 +137,50 @@ def test_chat_streaming_basic(request, openai_client, model, provider, verificat assert case["output"].lower() in content.lower() +@pytest.mark.parametrize( + "case", + chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"], + ids=case_id_generator, +) +def test_chat_non_streaming_error_handling(request, openai_client, model, provider, verification_config, case): + test_name_base = get_base_test_name(request) + if should_skip_test(verification_config, provider, model, test_name_base): + pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") + + with pytest.raises(APIError) as e: + openai_client.chat.completions.create( + model=model, + messages=case["input"]["messages"], + stream=False, + tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None, + tools=case["input"]["tools"] if "tools" in case["input"] else None, + ) + assert case["output"]["error"]["status_code"] == e.value.status_code + + +@pytest.mark.parametrize( + "case", + chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"], + ids=case_id_generator, +) +def test_chat_streaming_error_handling(request, openai_client, model, provider, verification_config, case): + test_name_base = get_base_test_name(request) + if should_skip_test(verification_config, provider, model, test_name_base): + pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") + + with pytest.raises(APIError) as e: + response = openai_client.chat.completions.create( + model=model, + messages=case["input"]["messages"], + stream=True, + tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None, + tools=case["input"]["tools"] if "tools" in case["input"] else None, + ) + for _chunk in response: + pass + assert str(case["output"]["error"]["status_code"]) in e.value.message + + @pytest.mark.parametrize( "case", chat_completion_test_cases["test_chat_image"]["test_params"]["case"], From a673697858c185be965077bf99f1adaa69838c5a Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Thu, 24 Apr 2025 03:34:15 -0400 Subject: [PATCH 10/36] chore: rename ramalama provider (#2008) # What does this PR do? the ramalama team has decided to rename their external provider `ramalama-stack` (more catchy!). Update docs accordingly Signed-off-by: Charlie Doern --- docs/source/providers/external.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/providers/external.md b/docs/source/providers/external.md index 4935b1fe6..5aab5ee0f 100644 --- a/docs/source/providers/external.md +++ b/docs/source/providers/external.md @@ -53,7 +53,7 @@ Here's a list of known external providers that you can use with Llama Stack: | Name | Description | API | Type | Repository | |------|-------------|-----|------|------------| | KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) | -| RamaLama | Inference models with RamaLama | Inference | Remote | [llama-stack-provider-ramalama](https://github.com/containers/llama-stack-provider-ramalama) | +| RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) | ### Remote Provider Specification From 14e60e3c02b4673f4b67bbfefaeb4be93a324f10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 24 Apr 2025 11:29:53 +0200 Subject: [PATCH 11/36] feat: include run.yaml in the container image (#2005) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As part of the build process, we now include the generated run.yaml (based of the provided build configuration file) into the container. We updated the entrypoint to use this run configuration as well. Given this simple distribution configuration: ``` # build.yaml version: '2' distribution_spec: description: Use (an external) Ollama server for running LLM inference providers: inference: - remote::ollama vector_io: - inline::faiss safety: - inline::llama-guard agents: - inline::meta-reference telemetry: - inline::meta-reference eval: - inline::meta-reference datasetio: - remote::huggingface - inline::localfs scoring: - inline::basic - inline::llm-as-judge - inline::braintrust tool_runtime: - remote::brave-search - remote::tavily-search - inline::code-interpreter - inline::rag-runtime - remote::model-context-protocol - remote::wolfram-alpha container_image: "registry.access.redhat.com/ubi9" image_type: container image_name: test ``` Build it: ``` llama stack build --config build.yaml ``` Run it: ``` podman run --rm \ -p 8321:8321 \ -e OLLAMA_URL=http://host.containers.internal:11434 \ --name llama-stack-server \ localhost/leseb-test:0.2.2 ``` Signed-off-by: SΓ©bastien Han --- .github/workflows/providers-build.yml | 38 +++++++++ llama_stack/cli/stack/_build.py | 22 ++++-- llama_stack/distribution/build.py | 6 ++ llama_stack/distribution/build_container.sh | 86 ++++++++++++++++++--- tests/unit/distribution/test_build_path.py | 4 +- 5 files changed, 139 insertions(+), 17 deletions(-) diff --git a/.github/workflows/providers-build.yml b/.github/workflows/providers-build.yml index 117c8b6d2..23257d7dc 100644 --- a/.github/workflows/providers-build.yml +++ b/.github/workflows/providers-build.yml @@ -107,3 +107,41 @@ jobs: - name: Build a single provider run: | USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --image-type venv --image-name test --providers inference=remote::ollama + + build-custom-container-distribution: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Set up Python + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 + with: + python-version: '3.10' + + - name: Install uv + uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1 + with: + python-version: "3.10" + + - name: Install LlamaStack + run: | + uv venv + source .venv/bin/activate + uv pip install -e . + + - name: Build a single provider + run: | + yq -i '.image_type = "container"' llama_stack/templates/dev/build.yaml + yq -i '.image_name = "test"' llama_stack/templates/dev/build.yaml + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/dev/build.yaml + + - name: Inspect the container image entrypoint + run: | + IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1) + entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID) + echo "Entrypoint: $entrypoint" + if [ "$entrypoint" != "[python -m llama_stack.distribution.server.server --config /app/run.yaml]" ]; then + echo "Entrypoint is not correct" + exit 1 + fi diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index 26c09af4e..80ab0631b 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -317,11 +317,15 @@ def _generate_run_config( to_write = json.loads(run_config.model_dump_json()) f.write(yaml.dump(to_write, sort_keys=False)) - # this path is only invoked when no template is provided - cprint( - f"You can now run your stack with `llama stack run {run_config_file}`", - color="green", - ) + # Only print this message for non-container builds since it will be displayed before the + # container is built + # For non-container builds, the run.yaml is generated at the very end of the build process so it + # makes sense to display this message + if build_config.image_type != LlamaStackImageType.CONTAINER.value: + cprint( + f"You can now run your stack with `llama stack run {run_config_file}`", + color="green", + ) return run_config_file @@ -355,6 +359,13 @@ def _run_stack_build_command_from_build_config( build_file_path = build_dir / f"{image_name}-build.yaml" os.makedirs(build_dir, exist_ok=True) + run_config_file = None + # Generate the run.yaml so it can be included in the container image with the proper entrypoint + # Only do this if we're building a container image and we're not using a template + if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path: + cprint("Generating run.yaml file", color="green") + run_config_file = _generate_run_config(build_config, build_dir, image_name) + with open(build_file_path, "w") as f: to_write = json.loads(build_config.model_dump_json()) f.write(yaml.dump(to_write, sort_keys=False)) @@ -364,6 +375,7 @@ def _run_stack_build_command_from_build_config( build_file_path, image_name, template_or_config=template_name or config_path or str(build_file_path), + run_config=run_config_file, ) if return_code != 0: raise RuntimeError(f"Failed to build image {image_name}") diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index 5b61ae081..9664449f3 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -93,6 +93,7 @@ def build_image( build_file_path: Path, image_name: str, template_or_config: str, + run_config: str | None = None, ): container_base = build_config.distribution_spec.container_image or "python:3.10-slim" @@ -108,6 +109,11 @@ def build_image( container_base, " ".join(normal_deps), ] + + # When building from a config file (not a template), include the run config path in the + # build arguments + if run_config is not None: + args.append(run_config) elif build_config.image_type == LlamaStackImageType.CONDA.value: script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh") args = [ diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index fb4780432..ad316d45e 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -19,12 +19,16 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500} # mounting is not supported by docker buildx, so we use COPY instead USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-} +# Path to the run.yaml file in the container +RUN_CONFIG_PATH=/app/run.yaml + +BUILD_CONTEXT_DIR=$(pwd) + if [ "$#" -lt 4 ]; then # This only works for templates - echo "Usage: $0 []" >&2 + echo "Usage: $0 [] []" >&2 exit 1 fi - set -euo pipefail template_or_config="$1" @@ -35,8 +39,27 @@ container_base="$1" shift pip_dependencies="$1" shift -special_pip_deps="${1:-}" +# Handle optional arguments +run_config="" +special_pip_deps="" + +# Check if there are more arguments +# The logics is becoming cumbersom, we should refactor it if we can do better +if [ $# -gt 0 ]; then + # Check if the argument ends with .yaml + if [[ "$1" == *.yaml ]]; then + run_config="$1" + shift + # If there's another argument after .yaml, it must be special_pip_deps + if [ $# -gt 0 ]; then + special_pip_deps="$1" + fi + else + # If it's not .yaml, it must be special_pip_deps + special_pip_deps="$1" + fi +fi # Define color codes RED='\033[0;31m' @@ -75,7 +98,7 @@ WORKDIR /app # We install the Python 3.11 dev headers and build tools so that any # C‑extension wheels (e.g. polyleven, faiss‑cpu) can compile successfully. -RUN dnf -y update && dnf install -y iputils net-tools wget \ +RUN dnf -y update && dnf install -y iputils git net-tools wget \ vim-minimal python3.11 python3.11-pip python3.11-wheel \ python3.11-setuptools python3.11-devel gcc make && \ ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all @@ -119,6 +142,45 @@ EOF done fi +# Function to get Python command +get_python_cmd() { + if is_command_available python; then + echo "python" + elif is_command_available python3; then + echo "python3" + else + echo "Error: Neither python nor python3 is installed. Please install Python to continue." >&2 + exit 1 + fi +} + +if [ -n "$run_config" ]; then + # Copy the run config to the build context since it's an absolute path + cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml" + add_to_container << EOF +COPY run.yaml $RUN_CONFIG_PATH +EOF + + # Parse the run.yaml configuration to identify external provider directories + # If external providers are specified, copy their directory to the container + # and update the configuration to reference the new container path + python_cmd=$(get_python_cmd) + external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')") + if [ -n "$external_providers_dir" ]; then + echo "Copying external providers directory: $external_providers_dir" + add_to_container << EOF +COPY $external_providers_dir /app/providers.d +EOF + # Edit the run.yaml file to change the external_providers_dir to /app/providers.d + if [ "$(uname)" = "Darwin" ]; then + sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml" + rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak" + else + sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml" + fi + fi +fi + stack_mount="/app/llama-stack-source" client_mount="/app/llama-stack-client-source" @@ -178,15 +240,16 @@ fi RUN pip uninstall -y uv EOF -# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag -if [[ "$template_or_config" != *.yaml ]]; then +# If a run config is provided, we use the --config flag +if [[ -n "$run_config" ]]; then + add_to_container << EOF +ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--config", "$RUN_CONFIG_PATH"] +EOF +# If a template is provided (not a yaml file), we use the --template flag +elif [[ "$template_or_config" != *.yaml ]]; then add_to_container << EOF ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$template_or_config"] EOF -else - add_to_container << EOF -ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"] -EOF fi # Add other require item commands genearic to all containers @@ -258,9 +321,10 @@ $CONTAINER_BINARY build \ "${CLI_ARGS[@]}" \ -t "$image_tag" \ -f "$TEMP_DIR/Containerfile" \ - "." + "$BUILD_CONTEXT_DIR" # clean up tmp/configs +rm -f "$BUILD_CONTEXT_DIR/run.yaml" set +x echo "Success!" diff --git a/tests/unit/distribution/test_build_path.py b/tests/unit/distribution/test_build_path.py index a913bd88b..555cdda4a 100644 --- a/tests/unit/distribution/test_build_path.py +++ b/tests/unit/distribution/test_build_path.py @@ -16,8 +16,9 @@ from llama_stack.distribution.utils.image_types import LlamaStackImageType def test_container_build_passes_path(monkeypatch, tmp_path): called_with = {} - def spy_build_image(cfg, build_file_path, image_name, template_or_config): + def spy_build_image(cfg, build_file_path, image_name, template_or_config, run_config=None): called_with["path"] = template_or_config + called_with["run_config"] = run_config return 0 monkeypatch.setattr( @@ -36,3 +37,4 @@ def test_container_build_passes_path(monkeypatch, tmp_path): assert "path" in called_with assert isinstance(called_with["path"], str) assert Path(called_with["path"]).exists() + assert called_with["run_config"] is None From e664ba91d87bf1735b3b4f1aae43772359c25ca3 Mon Sep 17 00:00:00 2001 From: Ilya Kolchinsky <58424190+ilya-kolchinsky@users.noreply.github.com> Date: Thu, 24 Apr 2025 16:38:38 +0200 Subject: [PATCH 12/36] fix: prevent the knowledge search tool from confusing the model with long content (#1908) # What does this PR do? This PR addresses the content dominance problem that frequently arises with multiple models when executing queries with the RAG tool. When the retrieved content is too large, it disproportionately influences the generation process, causing the model to ignore the original question and to provide meaningless comments on the retrieved information instead. This situation is especially common with agentic RAG, which is the standard way of doing RAG in Llama Stack, since directly manipulating the prompt combining the query with the retrieved content is not possible. This PR appends a grounding message to the results returned by the knowledge search tool, reminding the model about the original query and the purpose of the inference call. This makes the problem significantly less likely to occur. ## Test Plan Running the following script before the fix demonstrates the content dominance problem where the model insists to comment on the retrieved content and refuses to address the question. Running the script after the fix results in getting the correct answer. ``` import os import uuid from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient # the server endpoint LLAMA_STACK_SERVER_URL = "http://localhost:8321" # inference settings MODEL_ID = ""meta-llama/Llama-3.1-8B-Instruct" SYSTEM_PROMPT = "You are a helpful assistant. " # RAG settings VECTOR_DB_EMBEDDING_MODEL = "all-MiniLM-L6-v2" VECTOR_DB_EMBEDDING_DIMENSION = 384 VECTOR_DB_CHUNK_SIZE = 512 # initialize the server connection client = LlamaStackClient(base_url=os.environ.get("LLAMA_STACK_ENDPOINT", LLAMA_STACK_SERVER_URL)) # init the RAG retrieval parameters vector_db_id = f"test_vector_db_{uuid.uuid4()}" vector_providers = [ provider for provider in client.providers.list() if provider.api == "vector_io" ] vector_provider_to_use = vector_providers[0] # define and register the document collection to be used client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=VECTOR_DB_EMBEDDING_MODEL, embedding_dimension=VECTOR_DB_EMBEDDING_DIMENSION, provider_id=vector_provider_to_use.provider_id, ) # ingest the documents into the newly created document collection urls = [ ("https://www.openshift.guide/openshift-guide-screen.pdf", "application/pdf"), ] documents = [ RAGDocument( document_id=f"num-{i}", content=url, mime_type=url_type, metadata={}, ) for i, (url, url_type) in enumerate(urls) ] client.tool_runtime.rag_tool.insert( documents=documents, vector_db_id=vector_db_id, chunk_size_in_tokens=VECTOR_DB_CHUNK_SIZE, ) queries = [ "How to install OpenShift?", ] # initializing the agent agent = Agent( client, model=MODEL_ID, instructions=SYSTEM_PROMPT, # we make our agent aware of the RAG tool by including builtin::rag/knowledge_search in the list of tools tools=[ dict( name="builtin::rag/knowledge_search", args={ "vector_db_ids": [vector_db_id], # list of IDs of document collections to consider during retrieval }, ) ], ) for prompt in queries: print(f"User> {prompt}") # create a new turn with a new session ID for each prompt response = agent.create_turn( messages=[ { "role": "user", "content": prompt, } ], session_id=agent.create_session(f"rag-session_{uuid.uuid4()}") ) # print the response, including tool calls output for log in AgentEventLogger().log(response): print(log.content, end='') ``` --- llama_stack/providers/inline/tool_runtime/rag/memory.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 97c53d454..8d4689e5d 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -33,6 +33,7 @@ from llama_stack.apis.tools import ( ) from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( content_from_doc, make_overlapped_chunks, @@ -153,6 +154,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): ) ) picked.append(TextContentItem(text="END of knowledge_search tool results.\n")) + picked.append( + TextContentItem( + text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n', + ) + ) return RAGQueryResult( content=picked, From dc0d4763a013b560f0efe739923c413acbed866c Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Thu, 24 Apr 2025 09:24:07 -0600 Subject: [PATCH 13/36] chore: Update External Providers CI to not run on changes to docs, rfcs, and scripts (#2009) # What does this PR do? Update External Providers CI to not run on changes to docs, rfcs, and scripts [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) --------- Signed-off-by: Francisco Javier Arceo --- .github/workflows/test-external-providers.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/test-external-providers.yml b/.github/workflows/test-external-providers.yml index f7801c8d3..7ba5924e5 100644 --- a/.github/workflows/test-external-providers.yml +++ b/.github/workflows/test-external-providers.yml @@ -5,6 +5,14 @@ on: branches: [ main ] pull_request: branches: [ main ] + paths: + - 'distributions/**' + - 'llama_stack/**' + - 'tests/integration/**' + - 'uv.lock' + - 'pyproject.toml' + - 'requirements.txt' + - '.github/workflows/test-external-providers.yml' # This workflow jobs: test-external-providers: From 70488abe9c57b67c171ced427999e1f6cf9a682f Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Thu, 24 Apr 2025 09:39:31 -0600 Subject: [PATCH 14/36] chore: Remove `distributions/**` from integration, external provider, and unit tests (#2018) # What does this PR do? Remove `distributions/**` from integration, external provider, and unit tests [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan N/A [//]: # (## Documentation) Signed-off-by: Francisco Javier Arceo --- .github/workflows/integration-tests.yml | 1 - .github/workflows/test-external-providers.yml | 1 - .github/workflows/unit-tests.yml | 1 - 3 files changed, 3 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 0eb252695..f54bed839 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -6,7 +6,6 @@ on: pull_request: branches: [ main ] paths: - - 'distributions/**' - 'llama_stack/**' - 'tests/integration/**' - 'uv.lock' diff --git a/.github/workflows/test-external-providers.yml b/.github/workflows/test-external-providers.yml index 7ba5924e5..37f5c45ab 100644 --- a/.github/workflows/test-external-providers.yml +++ b/.github/workflows/test-external-providers.yml @@ -6,7 +6,6 @@ on: pull_request: branches: [ main ] paths: - - 'distributions/**' - 'llama_stack/**' - 'tests/integration/**' - 'uv.lock' diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 4b0c58b99..962141744 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -6,7 +6,6 @@ on: pull_request: branches: [ main ] paths: - - 'distributions/**' - 'llama_stack/**' - 'tests/unit/**' - 'uv.lock' From a5d6ab16b22c2cc8d774683992b20abf441ba82c Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 24 Apr 2025 11:27:49 -0700 Subject: [PATCH 15/36] fix: meta-reference parallel utils bug, use isinstance not equality --- .../providers/inline/inference/meta_reference/parallel_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 8752f06f3..9ffcf99fe 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -231,7 +231,7 @@ def worker_process_entrypoint( while True: try: task = req_gen.send(result) - if isinstance(task, str) and task == EndSentinel(): + if isinstance(task, EndSentinel): break assert isinstance(task, TaskRequest) From 7ed137e96310984e9a518beac21007e56bef881b Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 24 Apr 2025 13:03:35 -0700 Subject: [PATCH 16/36] fix: meta ref inference (#2022) MAX_BATCH_SIZE=10 LLAMA_MODELS_DEBUG=1 LLAMA_STACK_PORT=5002 LLAMA_STACK_LOGGING='all=info' llama stack run meta-reference-gpu --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct --env INFERENCE_CHECKPOINT_DIR=... LLAMA_STACK_CONFIG=http://localhost:5002/ pytest -s -v tests/integration/inference --safety-shield meta-llama/Llama-Guard-3-8B --vision-model meta-llama/Llama-4-Scout-17B-16E-Instruct --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct Co-authored-by: Eric Huang --- .../inline/inference/meta_reference/inference.py | 3 ++- .../inference/meta_reference/parallel_utils.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 0e69c2e7e..1bc098fab 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -253,7 +253,8 @@ class MetaReferenceInferenceImpl( def impl(): stop_reason = None - for token_result in self.generator.completion(request): + for token_results in self.generator.completion([request]): + token_result = token_results[0] if token_result.token == tokenizer.eot_id: stop_reason = StopReason.end_of_turn text = "" diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 9ffcf99fe..8c0ffc632 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -69,7 +69,10 @@ class CancelSentinel(BaseModel): class TaskRequest(BaseModel): type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request - task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]] + task: Tuple[ + str, + List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent], + ] class TaskResponse(BaseModel): @@ -234,7 +237,7 @@ def worker_process_entrypoint( if isinstance(task, EndSentinel): break - assert isinstance(task, TaskRequest) + assert isinstance(task, TaskRequest), task result = model(task.task) except StopIteration: break @@ -331,7 +334,10 @@ class ModelParallelProcessGroup: def run_inference( self, - req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]], + req: Tuple[ + str, + List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent], + ], ) -> Generator: assert not self.running, "inference already running" From c8797f1125cfded745f0688944d783355b4cfc07 Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Fri, 25 Apr 2025 00:59:10 +0100 Subject: [PATCH 17/36] fix: Including tool call in chat (#1931) Include the tool call details with the chat when doing Rag with Remote vllm Fixes: #1929 With this PR the tool call is included in the chat returned to vllm, the model (meta-llama/Llama-3.1-8B-Instruct) the returns the answer as expected. Signed-off-by: Derek Higgins --- .../utils/inference/openai_compat.py | 17 ++++++- .../providers/inference/test_remote_vllm.py | 48 ++++++++++++++++++- .../utils/inference/test_openai_compat.py | 43 +++++++++++++++++ 3 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 tests/unit/providers/utils/inference/test_openai_compat.py diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index f91e7d7dc..4d690287b 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -524,11 +524,26 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals else: content = [await _convert_content(message.content)] - return { + result = { "role": message.role, "content": content, } + if hasattr(message, "tool_calls") and message.tool_calls: + result["tool_calls"] = [] + for tc in message.tool_calls: + result["tool_calls"].append( + { + "id": tc.call_id, + "type": "function", + "function": { + "name": tc.tool_name, + "arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments), + }, + } + ) + return result + class UnparseableToolCall(BaseModel): """ diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index 88399198d..b3172cad4 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -28,12 +28,15 @@ from openai.types.model import Model as OpenAIModel from llama_stack.apis.inference import ( ChatCompletionRequest, + CompletionMessage, + SystemMessage, ToolChoice, ToolConfig, + ToolResponseMessage, UserMessage, ) from llama_stack.apis.models import Model -from llama_stack.models.llama.datatypes import StopReason +from llama_stack.models.llama.datatypes import StopReason, ToolCall from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig from llama_stack.providers.remote.inference.vllm.vllm import ( VLLMInferenceAdapter, @@ -135,6 +138,49 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter): assert request.tool_config.tool_choice == ToolChoice.none +@pytest.mark.asyncio +async def test_tool_call_response(vllm_inference_adapter): + """Verify that tool call arguments from a CompletionMessage are correctly converted + into the expected JSON format.""" + + # Patch the call to vllm so we can inspect the arguments sent were correct + with patch.object( + vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock + ) as mock_nonstream_completion: + messages = [ + SystemMessage(content="You are a helpful assistant"), + UserMessage(content="How many?"), + CompletionMessage( + content="", + stop_reason=StopReason.end_of_turn, + tool_calls=[ + ToolCall( + call_id="foo", + tool_name="knowledge_search", + arguments={"query": "How many?"}, + arguments_json='{"query": "How many?"}', + ) + ], + ), + ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."), + ] + await vllm_inference_adapter.chat_completion( + "mock-model", + messages, + stream=False, + tools=[], + tool_config=ToolConfig(tool_choice=ToolChoice.auto), + ) + + assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [ + { + "id": "foo", + "type": "function", + "function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'}, + } + ] + + @pytest.mark.asyncio async def test_tool_call_delta_empty_tool_call_buf(): """ diff --git a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py new file mode 100644 index 000000000..eb02f8203 --- /dev/null +++ b/tests/unit/providers/utils/inference/test_openai_compat.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.apis.common.content_types import TextContentItem +from llama_stack.apis.inference.inference import CompletionMessage, UserMessage +from llama_stack.models.llama.datatypes import StopReason, ToolCall +from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict + + +@pytest.mark.asyncio +async def test_convert_message_to_openai_dict(): + message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user") + assert await convert_message_to_openai_dict(message) == { + "role": "user", + "content": [{"type": "text", "text": "Hello, world!"}], + } + + +# Test convert_message_to_openai_dict with a tool call +@pytest.mark.asyncio +async def test_convert_message_to_openai_dict_with_tool_call(): + message = CompletionMessage( + content="", + tool_calls=[ + ToolCall(call_id="123", tool_name="test_tool", arguments_json='{"foo": "bar"}', arguments={"foo": "bar"}) + ], + stop_reason=StopReason.end_of_turn, + ) + + openai_dict = await convert_message_to_openai_dict(message) + + assert openai_dict == { + "role": "assistant", + "content": [{"type": "text", "text": ""}], + "tool_calls": [ + {"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}} + ], + } From 0b6cd45950c37bdd210a62a6bd67479c035eddca Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Thu, 24 Apr 2025 20:01:45 -0400 Subject: [PATCH 18/36] fix: Additional streaming error handling (#2007) # What does this PR do? This expands the `test_sse` test suite and fixes some edge cases with bugs in our SSE error handling to ensure streaming clients always get a proper error response. First, we handle the case where a client disconnects before we actually start streaming the response back. Previously we only handled the case where a client disconnected as we were streaming the response, but there was an edge case where a client disconnecting before we streamed any response back did not trigger our logic to cleanly handle that disconnect. Second, we handle the case where an error is thrown from the server before the actual async generator gets created from the provider. This happens in scenarios like the newly merged OpenAI API input validation, where we eagerly raise validation errors before returning the async generator object that streams the responses back. ## Test Plan Tested via: ``` python -m pytest -s -v tests/unit/server/test_sse.py ``` Both test cases failed before, and passed afterwards. The test cases were written based on me experimenting with actual clients that would do bad things like randomly disconnect or send invalid input in streaming mode and I hit these two cases, where things were misbehaving in our error handling. Signed-off-by: Ben Browning --- llama_stack/distribution/server/server.py | 6 ++-- tests/unit/server/test_sse.py | 38 ++++++++++++++++++++++- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 2942920d4..02f82498b 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -166,14 +166,16 @@ async def maybe_await(value): async def sse_generator(event_gen_coroutine): - event_gen = await event_gen_coroutine + event_gen = None try: + event_gen = await event_gen_coroutine async for item in event_gen: yield create_sse_event(item) await asyncio.sleep(0.01) except asyncio.CancelledError: logger.info("Generator cancelled") - await event_gen.aclose() + if event_gen: + await event_gen.aclose() except Exception as e: logger.exception("Error in sse_generator") yield create_sse_event( diff --git a/tests/unit/server/test_sse.py b/tests/unit/server/test_sse.py index 4a76bdc9b..c78122294 100644 --- a/tests/unit/server/test_sse.py +++ b/tests/unit/server/test_sse.py @@ -47,9 +47,45 @@ async def test_sse_generator_client_disconnected(): sse_gen = sse_generator(async_event_gen()) assert sse_gen is not None - # Start reading the events, ensuring this doesn't raise an exception seen_events = [] async for event in sse_gen: seen_events.append(event) + + # We should see 1 event before the client disconnected assert len(seen_events) == 1 assert seen_events[0] == create_sse_event("Test event 1") + + +@pytest.mark.asyncio +async def test_sse_generator_client_disconnected_before_response_starts(): + # Disconnect before the response starts + async def async_event_gen(): + raise asyncio.CancelledError() + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + + # No events should be seen since the client disconnected immediately + assert len(seen_events) == 0 + + +@pytest.mark.asyncio +async def test_sse_generator_error_before_response_starts(): + # Raise an error before the response starts + async def async_event_gen(): + raise Exception("Test error") + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + + # We should have 1 error event + assert len(seen_events) == 1 + assert 'data: {"error":' in seen_events[0] From cc77f79f552ed9d787cccfef491951d1ab102536 Mon Sep 17 00:00:00 2001 From: Jash Gulabrai <37194352+JashG@users.noreply.github.com> Date: Thu, 24 Apr 2025 20:12:42 -0400 Subject: [PATCH 19/36] feat: Add NVIDIA Eval integration (#1890) # What does this PR do? This PR adds support for NVIDIA's NeMo Evaluator API to the Llama Stack eval module. The integration enables users to evaluate models via the Llama Stack interface. ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] 1. Added unit tests and successfully ran from root of project: `./scripts/unit-tests.sh tests/unit/providers/nvidia/test_eval.py` ``` tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_job_cancel PASSED tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_job_result PASSED tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_job_status PASSED tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_register_benchmark PASSED tests/unit/providers/nvidia/test_eval.py::TestNVIDIAEvalImpl::test_run_eval PASSED ``` 2. Verified I could build the Llama Stack image: `LLAMA_STACK_DIR=$(pwd) llama stack build --template nvidia --image-type venv` Documentation added to `llama_stack/providers/remote/eval/nvidia/README.md` --------- Co-authored-by: Jash Gulabrai --- .../self_hosted_distro/nvidia.md | 3 +- llama_stack/providers/registry/eval.py | 20 +- llama_stack/providers/remote/eval/__init__.py | 5 + .../providers/remote/eval/nvidia/README.md | 134 ++++++++++++ .../providers/remote/eval/nvidia/__init__.py | 31 +++ .../providers/remote/eval/nvidia/config.py | 29 +++ .../providers/remote/eval/nvidia/eval.py | 154 ++++++++++++++ llama_stack/templates/dependencies.json | 4 - llama_stack/templates/nvidia/build.yaml | 4 +- llama_stack/templates/nvidia/nvidia.py | 18 +- .../templates/nvidia/run-with-safety.yaml | 9 +- llama_stack/templates/nvidia/run.yaml | 9 +- tests/unit/providers/nvidia/test_eval.py | 201 ++++++++++++++++++ 13 files changed, 598 insertions(+), 23 deletions(-) create mode 100644 llama_stack/providers/remote/eval/__init__.py create mode 100644 llama_stack/providers/remote/eval/nvidia/README.md create mode 100644 llama_stack/providers/remote/eval/nvidia/__init__.py create mode 100644 llama_stack/providers/remote/eval/nvidia/config.py create mode 100644 llama_stack/providers/remote/eval/nvidia/eval.py create mode 100644 tests/unit/providers/nvidia/test_eval.py diff --git a/docs/source/distributions/self_hosted_distro/nvidia.md b/docs/source/distributions/self_hosted_distro/nvidia.md index 0922cb512..147c5b2ae 100644 --- a/docs/source/distributions/self_hosted_distro/nvidia.md +++ b/docs/source/distributions/self_hosted_distro/nvidia.md @@ -7,7 +7,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov |-----|-------------| | agents | `inline::meta-reference` | | datasetio | `inline::localfs` | -| eval | `inline::meta-reference` | +| eval | `remote::nvidia` | | inference | `remote::nvidia` | | post_training | `remote::nvidia` | | safety | `remote::nvidia` | @@ -29,6 +29,7 @@ The following environment variables can be configured: - `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`) - `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`) - `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`) +- `NVIDIA_EVALUATOR_URL`: URL for the NeMo Evaluator Service (default: `http://0.0.0.0:7331`) - `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`) - `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`) diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py index f3e42c531..9604d5da4 100644 --- a/llama_stack/providers/registry/eval.py +++ b/llama_stack/providers/registry/eval.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec +from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec def available_providers() -> List[ProviderSpec]: @@ -25,4 +25,22 @@ def available_providers() -> List[ProviderSpec]: Api.agents, ], ), + remote_provider_spec( + api=Api.eval, + adapter=AdapterSpec( + adapter_type="nvidia", + pip_packages=[ + "requests", + ], + module="llama_stack.providers.remote.eval.nvidia", + config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig", + ), + api_dependencies=[ + Api.datasetio, + Api.datasets, + Api.scoring, + Api.inference, + Api.agents, + ], + ), ] diff --git a/llama_stack/providers/remote/eval/__init__.py b/llama_stack/providers/remote/eval/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/remote/eval/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/remote/eval/nvidia/README.md b/llama_stack/providers/remote/eval/nvidia/README.md new file mode 100644 index 000000000..cebc77920 --- /dev/null +++ b/llama_stack/providers/remote/eval/nvidia/README.md @@ -0,0 +1,134 @@ +# NVIDIA NeMo Evaluator Eval Provider + + +## Overview + +For the first integration, Benchmarks are mapped to Evaluation Configs on in the NeMo Evaluator. The full evaluation config object is provided as part of the meta-data. The `dataset_id` and `scoring_functions` are not used. + +Below are a few examples of how to register a benchmark, which in turn will create an evaluation config in NeMo Evaluator and how to trigger an evaluation. + +### Example for register an academic benchmark + +``` +POST /eval/benchmarks +``` +```json +{ + "benchmark_id": "mmlu", + "dataset_id": "", + "scoring_functions": [], + "metadata": { + "type": "mmlu" + } +} +``` + +### Example for register a custom evaluation + +``` +POST /eval/benchmarks +``` +```json +{ + "benchmark_id": "my-custom-benchmark", + "dataset_id": "", + "scoring_functions": [], + "metadata": { + "type": "custom", + "params": { + "parallelism": 8 + }, + "tasks": { + "qa": { + "type": "completion", + "params": { + "template": { + "prompt": "{{prompt}}", + "max_tokens": 200 + } + }, + "dataset": { + "files_url": "hf://datasets/default/sample-basic-test/testing/testing.jsonl" + }, + "metrics": { + "bleu": { + "type": "bleu", + "params": { + "references": [ + "{{ideal_response}}" + ] + } + } + } + } + } + } +} +``` + +### Example for triggering a benchmark/custom evaluation + +``` +POST /eval/benchmarks/{benchmark_id}/jobs +``` +```json +{ + "benchmark_id": "my-custom-benchmark", + "benchmark_config": { + "eval_candidate": { + "type": "model", + "model": "meta-llama/Llama3.1-8B-Instruct", + "sampling_params": { + "max_tokens": 100, + "temperature": 0.7 + } + }, + "scoring_params": {} + } +} +``` + +Response example: +```json +{ + "job_id": "eval-1234", + "status": "in_progress" +} +``` + +### Example for getting the status of a job +``` +GET /eval/benchmarks/{benchmark_id}/jobs/{job_id} +``` + +Response example: +```json +{ + "job_id": "eval-1234", + "status": "in_progress" +} +``` + +### Example for cancelling a job +``` +POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel +``` + +### Example for getting the results +``` +GET /eval/benchmarks/{benchmark_id}/results +``` +```json +{ + "generations": [], + "scores": { + "{benchmark_id}": { + "score_rows": [], + "aggregated_results": { + "tasks": {}, + "groups": {} + } + } + } +} +``` diff --git a/llama_stack/providers/remote/eval/nvidia/__init__.py b/llama_stack/providers/remote/eval/nvidia/__init__.py new file mode 100644 index 000000000..8abbec9b2 --- /dev/null +++ b/llama_stack/providers/remote/eval/nvidia/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Any, Dict + +from llama_stack.distribution.datatypes import Api + +from .config import NVIDIAEvalConfig + + +async def get_adapter_impl( + config: NVIDIAEvalConfig, + deps: Dict[Api, Any], +): + from .eval import NVIDIAEvalImpl + + impl = NVIDIAEvalImpl( + config, + deps[Api.datasetio], + deps[Api.datasets], + deps[Api.scoring], + deps[Api.inference], + deps[Api.agents], + ) + await impl.initialize() + return impl + + +__all__ = ["get_adapter_impl", "NVIDIAEvalImpl"] diff --git a/llama_stack/providers/remote/eval/nvidia/config.py b/llama_stack/providers/remote/eval/nvidia/config.py new file mode 100644 index 000000000..b660fcd68 --- /dev/null +++ b/llama_stack/providers/remote/eval/nvidia/config.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +import os +from typing import Any, Dict + +from pydantic import BaseModel, Field + + +class NVIDIAEvalConfig(BaseModel): + """ + Configuration for the NVIDIA NeMo Evaluator microservice endpoint. + + Attributes: + evaluator_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000. + """ + + evaluator_url: str = Field( + default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"), + description="The url for accessing the evaluator service", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}", + } diff --git a/llama_stack/providers/remote/eval/nvidia/eval.py b/llama_stack/providers/remote/eval/nvidia/eval.py new file mode 100644 index 000000000..e1a3b5355 --- /dev/null +++ b/llama_stack/providers/remote/eval/nvidia/eval.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Any, Dict, List + +import requests + +from llama_stack.apis.agents import Agents +from llama_stack.apis.benchmarks import Benchmark +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.inference import Inference +from llama_stack.apis.scoring import Scoring, ScoringResult +from llama_stack.providers.datatypes import BenchmarksProtocolPrivate +from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper + +from .....apis.common.job_types import Job, JobStatus +from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse +from .config import NVIDIAEvalConfig + +DEFAULT_NAMESPACE = "nvidia" + + +class NVIDIAEvalImpl( + Eval, + BenchmarksProtocolPrivate, + ModelRegistryHelper, +): + def __init__( + self, + config: NVIDIAEvalConfig, + datasetio_api: DatasetIO, + datasets_api: Datasets, + scoring_api: Scoring, + inference_api: Inference, + agents_api: Agents, + ) -> None: + self.config = config + self.datasetio_api = datasetio_api + self.datasets_api = datasets_api + self.scoring_api = scoring_api + self.inference_api = inference_api + self.agents_api = agents_api + + ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) + + async def initialize(self) -> None: ... + + async def shutdown(self) -> None: ... + + async def _evaluator_get(self, path): + """Helper for making GET requests to the evaluator service.""" + response = requests.get(url=f"{self.config.evaluator_url}{path}") + response.raise_for_status() + return response.json() + + async def _evaluator_post(self, path, data): + """Helper for making POST requests to the evaluator service.""" + response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data) + response.raise_for_status() + return response.json() + + async def register_benchmark(self, task_def: Benchmark) -> None: + """Register a benchmark as an evaluation configuration.""" + await self._evaluator_post( + "/v1/evaluation/configs", + { + "namespace": DEFAULT_NAMESPACE, + "name": task_def.benchmark_id, + # metadata is copied to request body as-is + **task_def.metadata, + }, + ) + + async def run_eval( + self, + benchmark_id: str, + benchmark_config: BenchmarkConfig, + ) -> Job: + """Run an evaluation job for a benchmark.""" + model = ( + benchmark_config.eval_candidate.model + if benchmark_config.eval_candidate.type == "model" + else benchmark_config.eval_candidate.config.model + ) + nvidia_model = self.get_provider_model_id(model) or model + + result = await self._evaluator_post( + "/v1/evaluation/jobs", + { + "config": f"{DEFAULT_NAMESPACE}/{benchmark_id}", + "target": {"type": "model", "model": nvidia_model}, + }, + ) + + return Job(job_id=result["id"], status=JobStatus.in_progress) + + async def evaluate_rows( + self, + benchmark_id: str, + input_rows: List[Dict[str, Any]], + scoring_functions: List[str], + benchmark_config: BenchmarkConfig, + ) -> EvaluateResponse: + raise NotImplementedError() + + async def job_status(self, benchmark_id: str, job_id: str) -> Job: + """Get the status of an evaluation job. + + EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed". + JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed" + """ + result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}") + result_status = result["status"] + + job_status = JobStatus.failed + if result_status in ["created", "pending"]: + job_status = JobStatus.scheduled + elif result_status in ["running"]: + job_status = JobStatus.in_progress + elif result_status in ["completed"]: + job_status = JobStatus.completed + elif result_status in ["cancelled"]: + job_status = JobStatus.cancelled + + return Job(job_id=job_id, status=job_status) + + async def job_cancel(self, benchmark_id: str, job_id: str) -> None: + """Cancel the evaluation job.""" + await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {}) + + async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: + """Returns the results of the evaluation job.""" + + job = await self.job_status(benchmark_id, job_id) + status = job.status + if not status or status != JobStatus.completed: + raise ValueError(f"Job {job_id} not completed. Status: {status.value}") + + result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results") + + return EvaluateResponse( + # TODO: these are stored in detailed results on NeMo Evaluator side; can be added + generations=[], + scores={ + benchmark_id: ScoringResult( + score_rows=[], + aggregated_results=result, + ) + }, + ) diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json index b96191752..63c4ecfa5 100644 --- a/llama_stack/templates/dependencies.json +++ b/llama_stack/templates/dependencies.json @@ -394,12 +394,10 @@ "aiosqlite", "blobfile", "chardet", - "emoji", "faiss-cpu", "fastapi", "fire", "httpx", - "langdetect", "matplotlib", "nltk", "numpy", @@ -411,7 +409,6 @@ "psycopg2-binary", "pymongo", "pypdf", - "pythainlp", "redis", "requests", "scikit-learn", @@ -419,7 +416,6 @@ "sentencepiece", "tqdm", "transformers", - "tree_sitter", "uvicorn" ], "ollama": [ diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml index f99ff6c81..a33fa3737 100644 --- a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/templates/nvidia/build.yaml @@ -1,6 +1,6 @@ version: '2' distribution_spec: - description: Use NVIDIA NIM for running LLM inference and safety + description: Use NVIDIA NIM for running LLM inference, evaluation and safety providers: inference: - remote::nvidia @@ -13,7 +13,7 @@ distribution_spec: telemetry: - inline::meta-reference eval: - - inline::meta-reference + - remote::nvidia post_training: - remote::nvidia datasetio: diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index a0cefba52..32ddf78e3 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -7,6 +7,7 @@ from pathlib import Path from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput +from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig @@ -20,7 +21,7 @@ def get_distribution_template() -> DistributionTemplate: "safety": ["remote::nvidia"], "agents": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"], - "eval": ["inline::meta-reference"], + "eval": ["remote::nvidia"], "post_training": ["remote::nvidia"], "datasetio": ["inline::localfs"], "scoring": ["inline::basic"], @@ -37,6 +38,11 @@ def get_distribution_template() -> DistributionTemplate: provider_type="remote::nvidia", config=NVIDIASafetyConfig.sample_run_config(), ) + eval_provider = Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + config=NVIDIAEvalConfig.sample_run_config(), + ) inference_model = ModelInput( model_id="${env.INFERENCE_MODEL}", provider_id="nvidia", @@ -60,7 +66,7 @@ def get_distribution_template() -> DistributionTemplate: return DistributionTemplate( name="nvidia", distro_type="self_hosted", - description="Use NVIDIA NIM for running LLM inference and safety", + description="Use NVIDIA NIM for running LLM inference, evaluation and safety", container_image=None, template_path=Path(__file__).parent / "doc_template.md", providers=providers, @@ -69,6 +75,7 @@ def get_distribution_template() -> DistributionTemplate: "run.yaml": RunConfigSettings( provider_overrides={ "inference": [inference_provider], + "eval": [eval_provider], }, default_models=default_models, default_tool_groups=default_tool_groups, @@ -78,7 +85,8 @@ def get_distribution_template() -> DistributionTemplate: "inference": [ inference_provider, safety_provider, - ] + ], + "eval": [eval_provider], }, default_models=[inference_model, safety_model], default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")], @@ -119,6 +127,10 @@ def get_distribution_template() -> DistributionTemplate: "http://0.0.0.0:7331", "URL for the NeMo Guardrails Service", ), + "NVIDIA_EVALUATOR_URL": ( + "http://0.0.0.0:7331", + "URL for the NeMo Evaluator Service", + ), "INFERENCE_MODEL": ( "Llama3.1-8B-Instruct", "Inference model", diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index 658d9377e..8483fb9bf 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -53,13 +53,10 @@ providers: sinks: ${env.TELEMETRY_SINKS:console,sqlite} sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db} eval: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_id: nvidia + provider_type: remote::nvidia config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db + evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331} post_training: - provider_id: nvidia provider_type: remote::nvidia diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index ff548d82e..d7e2753ba 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -48,13 +48,10 @@ providers: sinks: ${env.TELEMETRY_SINKS:console,sqlite} sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db} eval: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_id: nvidia + provider_type: remote::nvidia config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db + evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331} post_training: - provider_id: nvidia provider_type: remote::nvidia diff --git a/tests/unit/providers/nvidia/test_eval.py b/tests/unit/providers/nvidia/test_eval.py new file mode 100644 index 000000000..584ca2101 --- /dev/null +++ b/tests/unit/providers/nvidia/test_eval.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import unittest +from unittest.mock import MagicMock, patch + +import pytest + +from llama_stack.apis.benchmarks import Benchmark +from llama_stack.apis.common.job_types import Job, JobStatus +from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams +from llama_stack.models.llama.sku_types import CoreModelId +from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig +from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl + +MOCK_DATASET_ID = "default/test-dataset" +MOCK_BENCHMARK_ID = "test-benchmark" + + +class TestNVIDIAEvalImpl(unittest.TestCase): + def setUp(self): + os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test" + + # Create mock APIs + self.datasetio_api = MagicMock() + self.datasets_api = MagicMock() + self.scoring_api = MagicMock() + self.inference_api = MagicMock() + self.agents_api = MagicMock() + + self.config = NVIDIAEvalConfig( + evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"], + ) + + self.eval_impl = NVIDIAEvalImpl( + config=self.config, + datasetio_api=self.datasetio_api, + datasets_api=self.datasets_api, + scoring_api=self.scoring_api, + inference_api=self.inference_api, + agents_api=self.agents_api, + ) + + # Mock the HTTP request methods + self.evaluator_get_patcher = patch( + "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get" + ) + self.evaluator_post_patcher = patch( + "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post" + ) + + self.mock_evaluator_get = self.evaluator_get_patcher.start() + self.mock_evaluator_post = self.evaluator_post_patcher.start() + + def tearDown(self): + """Clean up after each test.""" + self.evaluator_get_patcher.stop() + self.evaluator_post_patcher.stop() + + def _assert_request_body(self, expected_json): + """Helper method to verify request body in Evaluator POST request is correct""" + call_args = self.mock_evaluator_post.call_args + actual_json = call_args[0][1] + + # Check that all expected keys contain the expected values in the actual JSON + for key, value in expected_json.items(): + assert key in actual_json, f"Key '{key}' missing in actual JSON" + + if isinstance(value, dict): + for nested_key, nested_value in value.items(): + assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']" + assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'" + else: + assert actual_json[key] == value, f"Value mismatch for '{key}'" + + @pytest.fixture(autouse=True) + def inject_fixtures(self, run_async): + self.run_async = run_async + + def test_register_benchmark(self): + eval_config = { + "type": "custom", + "params": {"parallelism": 8}, + "tasks": { + "qa": { + "type": "completion", + "params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}}, + "dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"}, + "metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}}, + } + }, + } + + benchmark = Benchmark( + provider_id="nvidia", + type="benchmark", + identifier=MOCK_BENCHMARK_ID, + dataset_id=MOCK_DATASET_ID, + scoring_functions=["basic::equality"], + metadata=eval_config, + ) + + # Mock Evaluator API response + mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"} + self.mock_evaluator_post.return_value = mock_evaluator_response + + # Register the benchmark + self.run_async(self.eval_impl.register_benchmark(benchmark)) + + # Verify the Evaluator API was called correctly + self.mock_evaluator_post.assert_called_once() + self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config}) + + def test_run_eval(self): + benchmark_config = BenchmarkConfig( + eval_candidate=ModelCandidate( + type="model", + model=CoreModelId.llama3_1_8b_instruct.value, + sampling_params=SamplingParams(max_tokens=100, temperature=0.7), + ) + ) + + # Mock Evaluator API response + mock_evaluator_response = {"id": "job-123", "status": "created"} + self.mock_evaluator_post.return_value = mock_evaluator_response + + # Run the Evaluation job + result = self.run_async( + self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config) + ) + + # Verify the Evaluator API was called correctly + self.mock_evaluator_post.assert_called_once() + self._assert_request_body( + { + "config": f"nvidia/{MOCK_BENCHMARK_ID}", + "target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"}, + } + ) + + # Verify the result + assert isinstance(result, Job) + assert result.job_id == "job-123" + assert result.status == JobStatus.in_progress + + def test_job_status(self): + # Mock Evaluator API response + mock_evaluator_response = {"id": "job-123", "status": "completed"} + self.mock_evaluator_get.return_value = mock_evaluator_response + + # Get the Evaluation job + result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")) + + # Verify the result + assert isinstance(result, Job) + assert result.job_id == "job-123" + assert result.status == JobStatus.completed + + # Verify the API was called correctly + self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}") + + def test_job_cancel(self): + # Mock Evaluator API response + mock_evaluator_response = {"id": "job-123", "status": "cancelled"} + self.mock_evaluator_post.return_value = mock_evaluator_response + + # Cancel the Evaluation job + self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")) + + # Verify the API was called correctly + self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {}) + + def test_job_result(self): + # Mock Evaluator API responses + mock_job_status_response = {"id": "job-123", "status": "completed"} + mock_job_results_response = { + "id": "job-123", + "status": "completed", + "results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}}, + } + self.mock_evaluator_get.side_effect = [ + mock_job_status_response, # First call to retrieve job + mock_job_results_response, # Second call to retrieve job results + ] + + # Get the Evaluation job results + result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")) + + # Verify the result + assert isinstance(result, EvaluateResponse) + assert MOCK_BENCHMARK_ID in result.scores + assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85 + + # Verify the API was called correctly + assert self.mock_evaluator_get.call_count == 2 + self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123") + self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results") From ace82836c14b4bd5380a14149047013332672bc3 Mon Sep 17 00:00:00 2001 From: Rashmi Pawar <168514198+raspawar@users.noreply.github.com> Date: Fri, 25 Apr 2025 05:43:33 +0530 Subject: [PATCH 20/36] feat: NVIDIA allow non-llama model registration (#1859) # What does this PR do? Adds custom model registration functionality to NVIDIAInferenceAdapter which let's the inference happen on: - post-training model - non-llama models in API Catalogue(behind https://integrate.api.nvidia.com and endpoints compatible with AyncOpenAI) ## Example Usage: ```python from llama_stack.apis.models import Model, ModelType from llama_stack.distribution.library_client import LlamaStackAsLibraryClient client = LlamaStackAsLibraryClient("nvidia") _ = client.initialize() client.models.register( model_id=model_name, model_type=ModelType.llm, provider_id="nvidia" ) response = client.inference.chat_completion( model_id=model_name, messages=[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"Write a limerick about the wonders of GPU computing."}], ) ``` ## Test Plan ```bash pytest tests/unit/providers/nvidia/test_supervised_fine_tuning.py ========================================================== test session starts =========================================================== platform linux -- Python 3.10.0, pytest-8.3.5, pluggy-1.5.0 rootdir: /home/ubuntu/llama-stack configfile: pyproject.toml plugins: anyio-4.9.0 collected 6 items tests/unit/providers/nvidia/test_supervised_fine_tuning.py ...... [100%] ============================================================ warnings summary ============================================================ ../miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076 /home/ubuntu/miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076: PydanticDeprecatedSince20: Using extra keyword arguments on `Field` is deprecated and will be removed. Use `json_schema_extra` instead. (Extra keys: 'contentEncoding'). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/ warn( -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html ====================================================== 6 passed, 1 warning in 1.51s ====================================================== ``` [//]: # (## Documentation) Updated Readme.md cc: @dglogo, @sumitb, @mattf --- .../self_hosted_distro/nvidia.md | 3 +- .../remote/inference/nvidia/config.py | 5 ++ .../remote/inference/nvidia/nvidia.py | 52 +++++++++++++++++-- .../remote/post_training/nvidia/README.md | 16 +++++- llama_stack/templates/nvidia/nvidia.py | 12 ++--- .../templates/nvidia/run-with-safety.yaml | 1 + llama_stack/templates/nvidia/run.yaml | 1 + .../nvidia/test_supervised_fine_tuning.py | 41 +++++++++++++++ 8 files changed, 116 insertions(+), 15 deletions(-) diff --git a/docs/source/distributions/self_hosted_distro/nvidia.md b/docs/source/distributions/self_hosted_distro/nvidia.md index 147c5b2ae..4407de779 100644 --- a/docs/source/distributions/self_hosted_distro/nvidia.md +++ b/docs/source/distributions/self_hosted_distro/nvidia.md @@ -22,9 +22,8 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov The following environment variables can be configured: - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``) -- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`) +- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`) - `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`) -- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`) - `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`) - `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`) - `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`) diff --git a/llama_stack/providers/remote/inference/nvidia/config.py b/llama_stack/providers/remote/inference/nvidia/config.py index abd34b498..8f80408d4 100644 --- a/llama_stack/providers/remote/inference/nvidia/config.py +++ b/llama_stack/providers/remote/inference/nvidia/config.py @@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel): default=60, description="Timeout for the HTTP requests", ) + append_api_version: bool = Field( + default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false", + description="When set to false, the API version will not be appended to the base_url. By default, it is true.", + ) @classmethod def sample_run_config(cls, **kwargs) -> Dict[str, Any]: return { "url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}", "api_key": "${env.NVIDIA_API_KEY:}", + "append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}", } diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index c91b4d768..4a62ad6cb 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -33,7 +33,6 @@ from llama_stack.apis.inference import ( TextTruncation, ToolChoice, ToolConfig, - ToolDefinition, ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, @@ -42,7 +41,11 @@ from llama_stack.apis.inference.inference import ( OpenAIMessageParam, OpenAIResponseFormatParam, ) -from llama_stack.models.llama.datatypes import ToolPromptFormat +from llama_stack.apis.models import Model, ModelType +from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat +from llama_stack.providers.utils.inference import ( + ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, +) from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -120,10 +123,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct", } - base_url = f"{self._config.url}/v1" + base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url + if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls: base_url = special_model_urls[provider_model_id] - return _get_client_for_base_url(base_url) async def _get_provider_model_id(self, model_id: str) -> str: @@ -387,3 +390,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): return await self._get_client(provider_model_id).chat.completions.create(**params) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e + + async def register_model(self, model: Model) -> Model: + """ + Allow non-llama model registration. + + Non-llama model registration: API Catalogue models, post-training models, etc. + client = LlamaStackAsLibraryClient("nvidia") + client.models.register( + model_id="mistralai/mixtral-8x7b-instruct-v0.1", + model_type=ModelType.llm, + provider_id="nvidia", + provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1" + ) + + NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format. + """ + if model.model_type == ModelType.embedding: + # embedding models are always registered by their provider model id and does not need to be mapped to a llama model + provider_resource_id = model.provider_resource_id + else: + provider_resource_id = self.get_provider_model_id(model.provider_resource_id) + + if provider_resource_id: + model.provider_resource_id = provider_resource_id + else: + llama_model = model.metadata.get("llama_model") + existing_llama_model = self.get_llama_model(model.provider_resource_id) + if existing_llama_model: + if existing_llama_model != llama_model: + raise ValueError( + f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'" + ) + else: + # not llama model + if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR: + self.provider_id_to_llama_model_map[model.provider_resource_id] = ( + ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model] + ) + else: + self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id + return model diff --git a/llama_stack/providers/remote/post_training/nvidia/README.md b/llama_stack/providers/remote/post_training/nvidia/README.md index 230587d66..3ef538d29 100644 --- a/llama_stack/providers/remote/post_training/nvidia/README.md +++ b/llama_stack/providers/remote/post_training/nvidia/README.md @@ -36,7 +36,6 @@ import os os.environ["NVIDIA_API_KEY"] = "your-api-key" os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" -os.environ["NVIDIA_USER_ID"] = "llama-stack-user" os.environ["NVIDIA_DATASET_NAMESPACE"] = "default" os.environ["NVIDIA_PROJECT_ID"] = "test-project" os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1" @@ -125,6 +124,21 @@ client.post_training.job.cancel(job_uuid="your-job-id") ### Inference with the fine-tuned model +#### 1. Register the model + +```python +from llama_stack.apis.models import Model, ModelType + +client.models.register( + model_id="test-example-model@v1", + provider_id="nvidia", + provider_model_id="test-example-model@v1", + model_type=ModelType.llm, +) +``` + +#### 2. Inference with the fine-tuned model + ```python response = client.inference.completion( content="Complete the sentence using one word: Roses are red, violets are ", diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index 32ddf78e3..463c13879 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -98,19 +98,15 @@ def get_distribution_template() -> DistributionTemplate: "", "NVIDIA API Key", ), - ## Nemo Customizer related variables - "NVIDIA_USER_ID": ( - "llama-stack-user", - "NVIDIA User ID", + "NVIDIA_APPEND_API_VERSION": ( + "True", + "Whether to append the API version to the base_url", ), + ## Nemo Customizer related variables "NVIDIA_DATASET_NAMESPACE": ( "default", "NVIDIA Dataset Namespace", ), - "NVIDIA_ACCESS_POLICIES": ( - "{}", - "NVIDIA Access Policies", - ), "NVIDIA_PROJECT_ID": ( "test-project", "NVIDIA Project ID", diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index 8483fb9bf..a3e5fefa4 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -18,6 +18,7 @@ providers: config: url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com} api_key: ${env.NVIDIA_API_KEY:} + append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True} - provider_id: nvidia provider_type: remote::nvidia config: diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index d7e2753ba..271ce1a16 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -18,6 +18,7 @@ providers: config: url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com} api_key: ${env.NVIDIA_API_KEY:} + append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True} vector_io: - provider_id: faiss provider_type: inline::faiss diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py index 43e0ac11c..09f67e4e6 100644 --- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py +++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py @@ -17,6 +17,8 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import ( TrainingConfigOptimizerConfig, ) +from llama_stack.apis.models import Model, ModelType +from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter from llama_stack.providers.remote.post_training.nvidia.post_training import ( ListNvidiaPostTrainingJobs, NvidiaPostTrainingAdapter, @@ -40,8 +42,22 @@ class TestNvidiaPostTraining(unittest.TestCase): ) self.mock_make_request = self.make_request_patcher.start() + # Mock the inference client + inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None) + self.inference_adapter = NVIDIAInferenceAdapter(inference_config) + + self.mock_client = unittest.mock.MagicMock() + self.mock_client.chat.completions.create = unittest.mock.AsyncMock() + self.inference_mock_make_request = self.mock_client.chat.completions.create + self.inference_make_request_patcher = patch( + "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client", + return_value=self.mock_client, + ) + self.inference_make_request_patcher.start() + def tearDown(self): self.make_request_patcher.stop() + self.inference_make_request_patcher.stop() @pytest.fixture(autouse=True) def inject_fixtures(self, run_async): @@ -303,6 +319,31 @@ class TestNvidiaPostTraining(unittest.TestCase): expected_params={"job_id": job_id}, ) + def test_inference_register_model(self): + model_id = "default/job-1234" + model_type = ModelType.llm + model = Model( + identifier=model_id, + provider_id="nvidia", + provider_model_id=model_id, + provider_resource_id=model_id, + model_type=model_type, + ) + result = self.run_async(self.inference_adapter.register_model(model)) + assert result == model + assert len(self.inference_adapter.alias_to_provider_id_map) > 1 + assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id + + with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion: + self.run_async( + self.inference_adapter.chat_completion( + model_id=model_id, + messages=[{"role": "user", "content": "Hello, model"}], + ) + ) + + mock_chat_completion.assert_called() + if __name__ == "__main__": unittest.main() From d9e00fca66ac3278464ebf2d733fc51c3bab851e Mon Sep 17 00:00:00 2001 From: Kevin Postlethwait Date: Fri, 25 Apr 2025 04:10:37 -0400 Subject: [PATCH 21/36] fix: specify nbformat version in nb (#2023) # What does this PR do? Adding nbformat version fixes this issue. Not sure exactly why this needs to be done, but this version was rewritten to the bottom of a nb file when I changed its name trying to get to the bottom of this. When I opened it on GH the issue was no longer present Closes #1837 ## Test Plan N/A --- docs/zero_to_hero_guide/00_Inference101.ipynb | 4 +++- docs/zero_to_hero_guide/01_Local_Cloud_Inference101.ipynb | 4 +++- docs/zero_to_hero_guide/02_Prompt_Engineering101.ipynb | 4 +++- docs/zero_to_hero_guide/03_Image_Chat101.ipynb | 4 +++- docs/zero_to_hero_guide/04_Tool_Calling101.ipynb | 4 +++- docs/zero_to_hero_guide/05_Memory101.ipynb | 4 +++- docs/zero_to_hero_guide/06_Safety101.ipynb | 4 +++- docs/zero_to_hero_guide/07_Agents101.ipynb | 4 +++- 8 files changed, 24 insertions(+), 8 deletions(-) diff --git a/docs/zero_to_hero_guide/00_Inference101.ipynb b/docs/zero_to_hero_guide/00_Inference101.ipynb index b3b781375..4f71f9f89 100644 --- a/docs/zero_to_hero_guide/00_Inference101.ipynb +++ b/docs/zero_to_hero_guide/00_Inference101.ipynb @@ -389,5 +389,7 @@ "pygments_lexer": "ipython3", "version": "3.10.15" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/01_Local_Cloud_Inference101.ipynb b/docs/zero_to_hero_guide/01_Local_Cloud_Inference101.ipynb index d66e1b4f5..19a7fe3be 100644 --- a/docs/zero_to_hero_guide/01_Local_Cloud_Inference101.ipynb +++ b/docs/zero_to_hero_guide/01_Local_Cloud_Inference101.ipynb @@ -256,5 +256,7 @@ "pygments_lexer": "ipython3", "version": "3.10.15" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/02_Prompt_Engineering101.ipynb b/docs/zero_to_hero_guide/02_Prompt_Engineering101.ipynb index 7fccf8c51..f3566eeb3 100644 --- a/docs/zero_to_hero_guide/02_Prompt_Engineering101.ipynb +++ b/docs/zero_to_hero_guide/02_Prompt_Engineering101.ipynb @@ -301,5 +301,7 @@ "pygments_lexer": "ipython3", "version": "3.12.2" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/03_Image_Chat101.ipynb b/docs/zero_to_hero_guide/03_Image_Chat101.ipynb index 58353e813..ae10d8808 100644 --- a/docs/zero_to_hero_guide/03_Image_Chat101.ipynb +++ b/docs/zero_to_hero_guide/03_Image_Chat101.ipynb @@ -200,5 +200,7 @@ "pygments_lexer": "ipython3", "version": "3.12.2" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/04_Tool_Calling101.ipynb b/docs/zero_to_hero_guide/04_Tool_Calling101.ipynb index c3a383e8c..de3754b21 100644 --- a/docs/zero_to_hero_guide/04_Tool_Calling101.ipynb +++ b/docs/zero_to_hero_guide/04_Tool_Calling101.ipynb @@ -355,5 +355,7 @@ "pygments_lexer": "ipython3", "version": "3.10.15" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/05_Memory101.ipynb b/docs/zero_to_hero_guide/05_Memory101.ipynb index bfeb40adc..66956259f 100644 --- a/docs/zero_to_hero_guide/05_Memory101.ipynb +++ b/docs/zero_to_hero_guide/05_Memory101.ipynb @@ -398,5 +398,7 @@ "pygments_lexer": "ipython3", "version": "3.10.15" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/06_Safety101.ipynb b/docs/zero_to_hero_guide/06_Safety101.ipynb index c8c1fe9c7..5d7763924 100644 --- a/docs/zero_to_hero_guide/06_Safety101.ipynb +++ b/docs/zero_to_hero_guide/06_Safety101.ipynb @@ -132,5 +132,7 @@ "pygments_lexer": "ipython3", "version": "3.11.10" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/07_Agents101.ipynb b/docs/zero_to_hero_guide/07_Agents101.ipynb index 8c988e1e3..b6df2a4c8 100644 --- a/docs/zero_to_hero_guide/07_Agents101.ipynb +++ b/docs/zero_to_hero_guide/07_Agents101.ipynb @@ -188,5 +188,7 @@ "pygments_lexer": "ipython3", "version": "3.10.15" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } From 59b759360937bd8592fec30e2e0a46acd8cfa27f Mon Sep 17 00:00:00 2001 From: Surya Prakash Pathak Date: Fri, 25 Apr 2025 01:22:22 -0700 Subject: [PATCH 22/36] feat: Enhance tool display in Tools sidebar by simplifying tool identifiers (#2024) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This PR improves the Tools page in the LlamaStack Playground UI by enhancing the readability of the active tool list shown in the sidebar. - Previously, active tools were displayed in a flat JSON array with verbose identifiers (e.g., builtin::code_interpreter:code_interpreter). - This PR updates the logic to group tools by their toolgroup (e.g., builtin::websearch) and renders each tool name in a simplified, human-readable format (e.g., web_search). - This change improves usability when working with multiple toolgroups, especially in configurations involving MCP tools or complex tool identifiers. Before and After Comparison: **Before** ![Screenshot 2025-04-24 at 1 05 47β€―PM](https://github.com/user-attachments/assets/44843a79-49dc-4b4d-ab28-c6187f9bb5ba) **After** ![Screenshot 2025-04-24 at 1 24 08β€―PM](https://github.com/user-attachments/assets/ebb01006-e0a9-4664-a95a-e6f72eea6f94) [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan - Followed the [LlamaStack UI Developer Setup instructions](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/distribution/ui) - Ran the Streamlit UI via: `uv run --with "[.ui]" streamlit run llama_stack/distribution/ui/app.py` - Selected multiple built-in toolgroups (e.g., code_interpreter, websearch, wolfram_alpha) from the sidebar. [//]: # (## Documentation) --- .../distribution/ui/page/playground/tools.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/llama_stack/distribution/ui/page/playground/tools.py b/llama_stack/distribution/ui/page/playground/tools.py index 96c6a1783..5e19c1e4f 100644 --- a/llama_stack/distribution/ui/page/playground/tools.py +++ b/llama_stack/distribution/ui/page/playground/tools.py @@ -66,17 +66,20 @@ def tool_chat_page(): toolgroup_selection.extend(mcp_selection) - active_tool_list = [] - for toolgroup_id in toolgroup_selection: - active_tool_list.extend( - [ - f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}" - for t in client.tools.list(toolgroup_id=toolgroup_id) - ] - ) + grouped_tools = {} + total_tools = 0 - st.markdown(f"Active Tools: πŸ›  {len(active_tool_list)}", help="List of currently active tools.") - st.json(active_tool_list) + for toolgroup_id in toolgroup_selection: + tools = client.tools.list(toolgroup_id=toolgroup_id) + grouped_tools[toolgroup_id] = [tool.identifier for tool in tools] + total_tools += len(tools) + + st.markdown(f"Active Tools: πŸ›  {total_tools}") + + for group_id, tools in grouped_tools.items(): + with st.expander(f"πŸ”§ Tools from `{group_id}`"): + for idx, tool in enumerate(tools, start=1): + st.markdown(f"{idx}. `{tool.split(':')[-1]}`") st.subheader("Agent Configurations") max_tokens = st.slider( From 121c73c2f52a42016da065f6af84f12a67107922 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Hu=C3=9F?= Date: Fri, 25 Apr 2025 16:57:42 +0200 Subject: [PATCH 23/36] feat(cli): add interactive tab completion for image type selection (#2027) # What does this PR do? Enhances the user experience in the `llama stack build` command by adding interactive TAB completion for image type selection. This ensures the UX consistency with other parts of the CLI that already support tab completion, such as provider selection, providing a more intuitive and discoverable interface for users. image --- llama_stack/cli/stack/_build.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index 80ab0631b..2787a93d5 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -136,12 +136,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None: ) image_type = prompt( - f"> Enter the image type you want your Llama Stack to be built as ({' or '.join(e.value for e in ImageType)}): ", + "> Enter the image type you want your Llama Stack to be built as (use to see options): ", + completer=WordCompleter([e.value for e in ImageType]), + complete_while_typing=True, validator=Validator.from_callable( lambda x: x in [e.value for e in ImageType], - error_message=f"Invalid image type, please enter {' or '.join(e.value for e in ImageType)}", + error_message="Invalid image type. Use to see options", ), - default=ImageType.CONDA.value, ) if image_type == ImageType.CONDA.value: From f5dae0517c9e70f30fc59689eb0a6162b1356a97 Mon Sep 17 00:00:00 2001 From: Andy Xie Date: Fri, 25 Apr 2025 11:01:51 -0400 Subject: [PATCH 24/36] feat: Support ReAct Agent on Tools Playground (#2012) # What does this PR do? ReAct prompting attempts to use the Thinking, Action, Observation loop to improve the model's reasoning ability via prompt engineering. With this PR, it now supports the various features in Streamlit's playground: 1. Adding the selection box for choosing between Agent Type: normal, ReAct. 2. Adding the Thinking, Action, Observation loop streamlit logic for ReAct agent, as seen in many LLM clients. 3. Improving tool calling accuracies via ReAct prompting, e.g. using web_search. **Folded** ![react_output_folded png](https://github.com/user-attachments/assets/bf1bdce7-e6ef-455d-b6b0-c22a64e9d5c1) **Collapsed** ![react_output_collapsed](https://github.com/user-attachments/assets/cda2fc17-df0b-400d-971c-988de821f2a4) [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] Run the playground and uses reasoning prompts to see for yourself. Steps to test the ReAct agent mode: 1. Setup a llama-stack server as [getting_started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) describes. 2. Setup your Web Search API keys under `llama_stack/distribution/ui/modules/api.py`. 3. Run the streamlit playground and try ReAct agent, possibly with `websearch`, with the command: `streamlit run llama_stack/distribution/ui/app.py`. ## Test Process Current results are demonstrated with `llama-3.2-3b-instruct`. Results will vary with different models. You should be seeing clear distinction with normal agent and ReAct agent. Example prompts listed below: 1. Aside from the Apple Remote, what other devices can control the program Apple Remote was originally designed to interact with? 2. What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into? ## Example Test Results **Web search on AppleTV** normal_output_appletv react_output_appletv **Web search on Colorado** normal_output_colorado react_output_colorado **Web search tool + MCP Slack server** normal_output_search_slack png react_output_search_slack ![slack_screenshot](https://github.com/user-attachments/assets/bb70e669-6067-462a-bdf6-7aaac6ccbcef) --- .../distribution/ui/page/playground/tools.py | 204 +++++++++++++++++- 1 file changed, 194 insertions(+), 10 deletions(-) diff --git a/llama_stack/distribution/ui/page/playground/tools.py b/llama_stack/distribution/ui/page/playground/tools.py index 5e19c1e4f..6c6a9fcfd 100644 --- a/llama_stack/distribution/ui/page/playground/tools.py +++ b/llama_stack/distribution/ui/page/playground/tools.py @@ -4,14 +4,23 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import enum +import json import uuid import streamlit as st from llama_stack_client import Agent +from llama_stack_client.lib.agents.react.agent import ReActAgent +from llama_stack_client.lib.agents.react.tool_parser import ReActOutput from llama_stack.distribution.ui.modules.api import llama_stack_api +class AgentType(enum.Enum): + REGULAR = "Regular" + REACT = "ReAct" + + def tool_chat_page(): st.title("πŸ›  Tools") @@ -23,6 +32,7 @@ def tool_chat_page(): tool_groups_list = [tool_group.identifier for tool_group in tool_groups] mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")] builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")] + selected_vector_dbs = [] def reset_agent(): st.session_state.clear() @@ -82,12 +92,20 @@ def tool_chat_page(): st.markdown(f"{idx}. `{tool.split(':')[-1]}`") st.subheader("Agent Configurations") + st.subheader("Agent Type") + agent_type = st.radio( + "Select Agent Type", + [AgentType.REGULAR, AgentType.REACT], + format_func=lambda x: x.value, + on_change=reset_agent, + ) + max_tokens = st.slider( "Max Tokens", min_value=0, max_value=4096, value=512, - step=1, + step=64, help="The maximum number of tokens to generate", on_change=reset_agent, ) @@ -104,13 +122,27 @@ def tool_chat_page(): @st.cache_resource def create_agent(): - return Agent( - client, - model=model, - instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.", - tools=toolgroup_selection, - sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens}, - ) + if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT: + return ReActAgent( + client=client, + model=model, + tools=toolgroup_selection, + response_format={ + "type": "json_schema", + "json_schema": ReActOutput.model_json_schema(), + }, + sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens}, + ) + else: + return Agent( + client, + model=model, + instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.", + tools=toolgroup_selection, + sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens}, + ) + + st.session_state.agent_type = agent_type agent = create_agent() @@ -139,6 +171,158 @@ def tool_chat_page(): ) def response_generator(turn_response): + if st.session_state.get("agent_type") == AgentType.REACT: + return _handle_react_response(turn_response) + else: + return _handle_regular_response(turn_response) + + def _handle_react_response(turn_response): + current_step_content = "" + final_answer = None + tool_results = [] + + for response in turn_response: + if not hasattr(response.event, "payload"): + yield ( + "\n\n🚨 :red[_Llama Stack server Error:_]\n" + "The response received is missing an expected `payload` attribute.\n" + "This could indicate a malformed response or an internal issue within the server.\n\n" + f"Error details: {response}" + ) + return + + payload = response.event.payload + + if payload.event_type == "step_progress" and hasattr(payload.delta, "text"): + current_step_content += payload.delta.text + continue + + if payload.event_type == "step_complete": + step_details = payload.step_details + + if step_details.step_type == "inference": + yield from _process_inference_step(current_step_content, tool_results, final_answer) + current_step_content = "" + elif step_details.step_type == "tool_execution": + tool_results = _process_tool_execution(step_details, tool_results) + current_step_content = "" + else: + current_step_content = "" + + if not final_answer and tool_results: + yield from _format_tool_results_summary(tool_results) + + def _process_inference_step(current_step_content, tool_results, final_answer): + try: + react_output_data = json.loads(current_step_content) + thought = react_output_data.get("thought") + action = react_output_data.get("action") + answer = react_output_data.get("answer") + + if answer and answer != "null" and answer is not None: + final_answer = answer + + if thought: + with st.expander("πŸ€” Thinking...", expanded=False): + st.markdown(f":grey[__{thought}__]") + + if action and isinstance(action, dict): + tool_name = action.get("tool_name") + tool_params = action.get("tool_params") + with st.expander(f'πŸ›  Action: Using tool "{tool_name}"', expanded=False): + st.json(tool_params) + + if answer and answer != "null" and answer is not None: + yield f"\n\nβœ… **Final Answer:**\n{answer}" + + except json.JSONDecodeError: + yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```" + except Exception as e: + yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```" + + return final_answer + + def _process_tool_execution(step_details, tool_results): + try: + if hasattr(step_details, "tool_responses") and step_details.tool_responses: + for tool_response in step_details.tool_responses: + tool_name = tool_response.tool_name + content = tool_response.content + tool_results.append((tool_name, content)) + with st.expander(f'βš™οΈ Observation (Result from "{tool_name}")', expanded=False): + try: + parsed_content = json.loads(content) + st.json(parsed_content) + except json.JSONDecodeError: + st.code(content, language=None) + else: + with st.expander("βš™οΈ Observation", expanded=False): + st.markdown(":grey[_Tool execution step completed, but no response data found._]") + except Exception as e: + with st.expander("βš™οΈ Error in Tool Execution", expanded=False): + st.markdown(f":red[_Error processing tool execution: {str(e)}_]") + + return tool_results + + def _format_tool_results_summary(tool_results): + yield "\n\n**Here's what I found:**\n" + for tool_name, content in tool_results: + try: + parsed_content = json.loads(content) + + if tool_name == "web_search" and "top_k" in parsed_content: + yield from _format_web_search_results(parsed_content) + elif "results" in parsed_content and isinstance(parsed_content["results"], list): + yield from _format_results_list(parsed_content["results"]) + elif isinstance(parsed_content, dict) and len(parsed_content) > 0: + yield from _format_dict_results(parsed_content) + elif isinstance(parsed_content, list) and len(parsed_content) > 0: + yield from _format_list_results(parsed_content) + except json.JSONDecodeError: + yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n" + except (TypeError, AttributeError, KeyError, IndexError) as e: + print(f"Error processing {tool_name} result: {type(e).__name__}: {e}") + + def _format_web_search_results(parsed_content): + for i, result in enumerate(parsed_content["top_k"], 1): + if i <= 3: + title = result.get("title", "Untitled") + url = result.get("url", "") + content_text = result.get("content", "").strip() + yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n" + + def _format_results_list(results): + for i, result in enumerate(results, 1): + if i <= 3: + if isinstance(result, dict): + name = result.get("name", result.get("title", "Result " + str(i))) + description = result.get("description", result.get("content", result.get("summary", ""))) + yield f"\n- **{name}**\n {description}\n" + else: + yield f"\n- {result}\n" + + def _format_dict_results(parsed_content): + yield "\n```\n" + for key, value in list(parsed_content.items())[:5]: + if isinstance(value, str) and len(value) < 100: + yield f"{key}: {value}\n" + else: + yield f"{key}: [Complex data]\n" + yield "```\n" + + def _format_list_results(parsed_content): + yield "\n" + for _, item in enumerate(parsed_content[:3], 1): + if isinstance(item, str): + yield f"- {item}\n" + elif isinstance(item, dict) and "text" in item: + yield f"- {item['text']}\n" + elif isinstance(item, dict) and len(item) > 0: + first_value = next(iter(item.values())) + if isinstance(first_value, str) and len(first_value) < 100: + yield f"- {first_value}\n" + + def _handle_regular_response(turn_response): for response in turn_response: if hasattr(response.event, "payload"): print(response.event.payload) @@ -156,9 +340,9 @@ def tool_chat_page(): yield f"Error occurred in the Llama Stack Cluster: {response}" with st.chat_message("assistant"): - response = st.write_stream(response_generator(turn_response)) + response_content = st.write_stream(response_generator(turn_response)) - st.session_state.messages.append({"role": "assistant", "content": response}) + st.session_state.messages.append({"role": "assistant", "content": response_content}) tool_chat_page() From 4bbd0c06939728676a3ade0d28e24fbd8617ce96 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 25 Apr 2025 10:39:30 -0700 Subject: [PATCH 25/36] fix: add endpoint route debugs --- llama_stack/distribution/server/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 02f82498b..6e9941d1c 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -461,6 +461,7 @@ def main(args: Optional[argparse.Namespace] = None): raise ValueError(f"Could not find method {endpoint.name} on {impl}!!") impl_method = getattr(impl, endpoint.name) + logger.debug(f"{endpoint.method.upper()} {endpoint.route}") with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") From 29072f40ab8bf8d47cb6867192e1b2f232f89321 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 25 Apr 2025 11:29:08 -0700 Subject: [PATCH 26/36] feat: new system prompt for llama4 (#2031) Tests: LLAMA_STACK_CONFIG=http://localhost:5002 pytest -s -v tests/integration/inference --safety-shield meta-llama/Llama-Guard-3-8B --vision-model meta-llama/Llama-4-Scout-17B-16E-Instruct --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct Co-authored-by: Eric Huang --- .../llama4/prompt_templates/system_prompts.py | 144 ++++++++++++++++++ .../utils/inference/prompt_adapter.py | 15 +- 2 files changed, 154 insertions(+), 5 deletions(-) create mode 100644 llama_stack/models/llama/llama4/prompt_templates/system_prompts.py diff --git a/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py new file mode 100644 index 000000000..139e204ad --- /dev/null +++ b/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import textwrap +from typing import List, Optional + +from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition +from llama_stack.models.llama.llama3.prompt_templates.base import ( + PromptTemplate, + PromptTemplateGeneratorBase, +) + + +class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 + DEFAULT_PROMPT = textwrap.dedent( + """ + You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines: + + 1. FUNCTION CALLS: + - ONLY use functions that are EXPLICITLY listed in the function list below + - If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information" + - If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information" + - If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s) + - Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)] + Examples: + CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list + INCORRECT: get_weather(location="New York") + INCORRECT: Let me check the weather: [get_weather(location="New York")] + INCORRECT: [get_events(location="Singapore")] <- If function not in list + + 2. RESPONSE RULES: + - For pure function requests matching a listed function: ONLY output the function call(s) + - For knowledge questions: ONLY output text + - For missing parameters: ONLY request the specific missing parameters + - For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call. + - If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations + - NEVER combine text and function calls in the same response + - NEVER suggest alternative functions when the requested service is unavailable + - NEVER create or invent new functions not listed below + + 3. STRICT BOUNDARIES: + - ONLY use functions from the list below - no exceptions + - NEVER use a function as an alternative to unavailable information + - NEVER call functions not present in the function list + - NEVER add explanatory text to function calls + - NEVER respond with empty brackets + - Use proper Python/JSON syntax for function calls + - Check the function list carefully before responding + + 4. TOOL RESPONSE HANDLING: + - When receiving tool responses: provide concise, natural language responses + - Don't repeat tool response verbatim + - Don't add supplementary information + + + {{ function_description }} + """.strip("\n") + ) + + def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate: + system_prompt = system_prompt or self.DEFAULT_PROMPT + return PromptTemplate( + system_prompt, + {"function_description": self._gen_function_description(custom_tools)}, + ) + + def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + Here is a list of functions in JSON format that you can invoke. + + [ + {% for t in tools -%} + {# manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set tparams = t.parameters -%} + {%- set required_params = [] -%} + {%- for name, param in tparams.items() if param.required == true -%} + {%- set _ = required_params.append(name) -%} + {%- endfor -%} + { + "name": "{{tname}}", + "description": "{{tdesc}}", + "parameters": { + "type": "dict", + "required": {{ required_params | tojson }}, + "properties": { + {%- for name, param in tparams.items() %} + "{{name}}": { + "type": "{{param.param_type}}", + "description": "{{param.description}}"{% if param.default %}, + "default": "{{param.default}}"{% endif %} + }{% if not loop.last %},{% endif %} + {%- endfor %} + } + } + }{% if not loop.last %}, + {% endif -%} + {%- endfor %} + ] + + You can answer general questions or invoke tools when necessary. + In addition to tool calls, you should also augment your responses by using the tool outputs. + + """ + ) + return PromptTemplate( + template_str.strip("\n"), + {"tools": [t.model_dump() for t in custom_tools]}, + ).render() + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="get_weather", + description="Get weather info for places", + parameters={ + "city": ToolParamDefinition( + param_type="string", + description="The name of the city to get the weather for", + required=True, + ), + "metric": ToolParamDefinition( + param_type="string", + description="The metric for weather. Options are: celsius, fahrenheit", + required=False, + default="celsius", + ), + }, + ), + ] + ] diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 4f9c4927a..657dc4b86 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -52,6 +52,9 @@ from llama_stack.models.llama.llama3.prompt_templates import ( SystemDefaultGenerator, ) from llama_stack.models.llama.llama3.tokenizer import Tokenizer +from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( + PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4, +) from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal from llama_stack.providers.utils.inference import supported_inference_models @@ -306,10 +309,11 @@ def chat_completion_request_to_messages( elif model.model_family in ( ModelFamily.llama3_2, ModelFamily.llama3_3, - ModelFamily.llama4, ): - # llama3.2, llama3.3 and llama4 models follow the same tool prompt format - messages = augment_messages_for_tools_llama_3_2(request) + # llama3.2, llama3.3 follow the same tool prompt format + messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator) + elif model.model_family == ModelFamily.llama4: + messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4) else: messages = request.messages @@ -399,8 +403,9 @@ def augment_messages_for_tools_llama_3_1( return messages -def augment_messages_for_tools_llama_3_2( +def augment_messages_for_tools_llama( request: ChatCompletionRequest, + custom_tool_prompt_generator, ) -> List[Message]: existing_messages = request.messages existing_system_message = None @@ -434,7 +439,7 @@ def augment_messages_for_tools_llama_3_2( if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace: system_prompt = existing_system_message.content - tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt) + tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt) sys_content += tool_template.render() sys_content += "\n" From 1bb1d9b2bad56671a821d5c42f766060f40951b9 Mon Sep 17 00:00:00 2001 From: Sajikumar JS <35679404+Sajikumarjs@users.noreply.github.com> Date: Fri, 25 Apr 2025 23:59:21 +0530 Subject: [PATCH 27/36] feat: Add watsonx inference adapter (#1895) # What does this PR do? IBM watsonx ai added as the inference [#1741 ](https://github.com/meta-llama/llama-stack/issues/1741) [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) --------- Co-authored-by: Sajikumar JS --- README.md | 1 + .../remote_hosted_distro/watsonx.md | 88 ++++++ llama_stack/providers/registry/inference.py | 10 + .../remote/inference/watsonx/__init__.py | 22 ++ .../remote/inference/watsonx/config.py | 46 ++++ .../remote/inference/watsonx/models.py | 47 ++++ .../remote/inference/watsonx/watsonx.py | 260 ++++++++++++++++++ llama_stack/templates/dependencies.json | 36 +++ llama_stack/templates/watsonx/__init__.py | 7 + llama_stack/templates/watsonx/build.yaml | 30 ++ llama_stack/templates/watsonx/doc_template.md | 74 +++++ llama_stack/templates/watsonx/run.yaml | 210 ++++++++++++++ llama_stack/templates/watsonx/watsonx.py | 90 ++++++ pyproject.toml | 1 + 14 files changed, 922 insertions(+) create mode 100644 docs/source/distributions/remote_hosted_distro/watsonx.md create mode 100644 llama_stack/providers/remote/inference/watsonx/__init__.py create mode 100644 llama_stack/providers/remote/inference/watsonx/config.py create mode 100644 llama_stack/providers/remote/inference/watsonx/models.py create mode 100644 llama_stack/providers/remote/inference/watsonx/watsonx.py create mode 100644 llama_stack/templates/watsonx/__init__.py create mode 100644 llama_stack/templates/watsonx/build.yaml create mode 100644 llama_stack/templates/watsonx/doc_template.md create mode 100644 llama_stack/templates/watsonx/run.yaml create mode 100644 llama_stack/templates/watsonx/watsonx.py diff --git a/README.md b/README.md index 8c201e43d..c2e688763 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,7 @@ Here is a list of the various API providers and available distributions that can | OpenAI | Hosted | | βœ… | | | | | Anthropic | Hosted | | βœ… | | | | | Gemini | Hosted | | βœ… | | | | +| watsonx | Hosted | | βœ… | | | | ### Distributions diff --git a/docs/source/distributions/remote_hosted_distro/watsonx.md b/docs/source/distributions/remote_hosted_distro/watsonx.md new file mode 100644 index 000000000..018dc2a3c --- /dev/null +++ b/docs/source/distributions/remote_hosted_distro/watsonx.md @@ -0,0 +1,88 @@ +--- +orphan: true +--- + +# watsonx Distribution + +```{toctree} +:maxdepth: 2 +:hidden: + +self +``` + +The `llamastack/distribution-watsonx` distribution consists of the following provider configurations. + +| API | Provider(s) | +|-----|-------------| +| agents | `inline::meta-reference` | +| datasetio | `remote::huggingface`, `inline::localfs` | +| eval | `inline::meta-reference` | +| inference | `remote::watsonx` | +| safety | `inline::llama-guard` | +| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | +| telemetry | `inline::meta-reference` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` | +| vector_io | `inline::faiss` | + + + +### Environment Variables + +The following environment variables can be configured: + +- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) +- `WATSONX_API_KEY`: watsonx API Key (default: ``) +- `WATSONX_PROJECT_ID`: watsonx Project ID (default: ``) + +### Models + +The following models are available by default: + +- `meta-llama/llama-3-3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)` +- `meta-llama/llama-2-13b-chat (aliases: meta-llama/Llama-2-13b)` +- `meta-llama/llama-3-1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)` +- `meta-llama/llama-3-1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)` +- `meta-llama/llama-3-2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)` +- `meta-llama/llama-3-2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)` +- `meta-llama/llama-3-2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)` +- `meta-llama/llama-3-2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)` +- `meta-llama/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)` + + +### Prerequisite: API Keys + +Make sure you have access to a watsonx API Key. You can get one by referring [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key). + + +## Running Llama Stack with watsonx + +You can do this via Conda (build code), venv or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-watsonx \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env WATSONX_API_KEY=$WATSONX_API_KEY \ + --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \ + --env WATSONX_BASE_URL=$WATSONX_BASE_URL +``` + +### Via Conda + +```bash +llama stack build --template watsonx --image-type conda +llama stack run ./run.yaml \ + --port $LLAMA_STACK_PORT \ + --env WATSONX_API_KEY=$WATSONX_API_KEY \ + --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID +``` diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 3c54cabcf..4040f0d80 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -288,4 +288,14 @@ def available_providers() -> List[ProviderSpec]: provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="watsonx", + pip_packages=["ibm_watson_machine_learning"], + module="llama_stack.providers.remote.inference.watsonx", + config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", + provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", + ), + ), ] diff --git a/llama_stack/providers/remote/inference/watsonx/__init__.py b/llama_stack/providers/remote/inference/watsonx/__init__.py new file mode 100644 index 000000000..e59e873b6 --- /dev/null +++ b/llama_stack/providers/remote/inference/watsonx/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.inference import Inference + +from .config import WatsonXConfig + + +async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference: + # import dynamically so `llama stack build` does not fail due to missing dependencies + from .watsonx import WatsonXInferenceAdapter + + if not isinstance(config, WatsonXConfig): + raise RuntimeError(f"Unexpected config type: {type(config)}") + adapter = WatsonXInferenceAdapter(config) + return adapter + + +__all__ = ["get_adapter_impl", "WatsonXConfig"] diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py new file mode 100644 index 000000000..7ee99b7e0 --- /dev/null +++ b/llama_stack/providers/remote/inference/watsonx/config.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field, SecretStr + +from llama_stack.schema_utils import json_schema_type + + +class WatsonXProviderDataValidator(BaseModel): + url: str + api_key: str + project_id: str + + +@json_schema_type +class WatsonXConfig(BaseModel): + url: str = Field( + default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"), + description="A base url for accessing the watsonx.ai", + ) + api_key: Optional[SecretStr] = Field( + default_factory=lambda: os.getenv("WATSONX_API_KEY"), + description="The watsonx API key, only needed of using the hosted service", + ) + project_id: Optional[str] = Field( + default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"), + description="The Project ID key, only needed of using the hosted service", + ) + timeout: int = Field( + default=60, + description="Timeout for the HTTP requests", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}", + "api_key": "${env.WATSONX_API_KEY:}", + "project_id": "${env.WATSONX_PROJECT_ID:}", + } diff --git a/llama_stack/providers/remote/inference/watsonx/models.py b/llama_stack/providers/remote/inference/watsonx/models.py new file mode 100644 index 000000000..d98f0510a --- /dev/null +++ b/llama_stack/providers/remote/inference/watsonx/models.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.models.llama.sku_types import CoreModelId +from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry + +MODEL_ENTRIES = [ + build_hf_repo_model_entry( + "meta-llama/llama-3-3-70b-instruct", + CoreModelId.llama3_3_70b_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-2-13b-chat", + CoreModelId.llama2_13b.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-3-1-70b-instruct", + CoreModelId.llama3_1_70b_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-3-1-8b-instruct", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-3-2-11b-vision-instruct", + CoreModelId.llama3_2_11b_vision_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-3-2-1b-instruct", + CoreModelId.llama3_2_1b_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-3-2-3b-instruct", + CoreModelId.llama3_2_3b_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-3-2-90b-vision-instruct", + CoreModelId.llama3_2_90b_vision_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-guard-3-11b-vision", + CoreModelId.llama_guard_3_11b_vision.value, + ), +] diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py new file mode 100644 index 000000000..d5d87ec01 --- /dev/null +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -0,0 +1,260 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import AsyncGenerator, List, Optional, Union + +from ibm_watson_machine_learning.foundation_models import Model +from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams + +from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + EmbeddingsResponse, + EmbeddingTaskType, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + TextTruncation, + ToolChoice, + ToolConfig, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.providers.utils.inference.openai_compat import ( + OpenAICompatCompletionChoice, + OpenAICompatCompletionResponse, + process_chat_completion_response, + process_chat_completion_stream_response, + process_completion_response, + process_completion_stream_response, +) +from llama_stack.providers.utils.inference.prompt_adapter import ( + chat_completion_request_to_prompt, + completion_request_to_prompt, + request_has_media, +) + +from . import WatsonXConfig +from .models import MODEL_ENTRIES + + +class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): + def __init__(self, config: WatsonXConfig) -> None: + ModelRegistryHelper.__init__(self, MODEL_ENTRIES) + + print(f"Initializing watsonx InferenceAdapter({config.url})...") + + self._config = config + + self._project_id = self._config.project_id + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def completion( + self, + model_id: str, + content: InterleavedContent, + sampling_params: Optional[SamplingParams] = None, + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + if sampling_params is None: + sampling_params = SamplingParams() + model = await self.model_store.get_model(model_id) + request = CompletionRequest( + model=model.provider_resource_id, + content=content, + sampling_params=sampling_params, + response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + if stream: + return self._stream_completion(request) + else: + return await self._nonstream_completion(request) + + def _get_client(self, model_id) -> Model: + config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None + config_url = self._config.url + project_id = self._config.project_id + credentials = {"url": config_url, "apikey": config_api_key} + + return Model(model_id=model_id, credentials=credentials, project_id=project_id) + + async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse: + params = await self._get_params(request) + r = self._get_client(request.model).generate(**params) + choices = [] + if "results" in r: + for result in r["results"]: + choice = OpenAICompatCompletionChoice( + finish_reason=result["stop_reason"] if result["stop_reason"] else None, + text=result["generated_text"], + ) + choices.append(choice) + response = OpenAICompatCompletionResponse( + choices=choices, + ) + return process_completion_response(response) + + async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: + params = await self._get_params(request) + + async def _generate_and_convert_to_openai_compat(): + s = self._get_client(request.model).generate_text_stream(**params) + for chunk in s: + choice = OpenAICompatCompletionChoice( + finish_reason=None, + text=chunk, + ) + yield OpenAICompatCompletionResponse( + choices=[choice], + ) + + stream = _generate_and_convert_to_openai_compat() + async for chunk in process_completion_stream_response(stream): + yield chunk + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = None, + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, + ) -> AsyncGenerator: + if sampling_params is None: + sampling_params = SamplingParams() + model = await self.model_store.get_model(model_id) + request = ChatCompletionRequest( + model=model.provider_resource_id, + messages=messages, + sampling_params=sampling_params, + tools=tools or [], + response_format=response_format, + stream=stream, + logprobs=logprobs, + tool_config=tool_config, + ) + + if stream: + return self._stream_chat_completion(request) + else: + return await self._nonstream_chat_completion(request) + + async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + params = await self._get_params(request) + r = self._get_client(request.model).generate(**params) + choices = [] + if "results" in r: + for result in r["results"]: + choice = OpenAICompatCompletionChoice( + finish_reason=result["stop_reason"] if result["stop_reason"] else None, + text=result["generated_text"], + ) + choices.append(choice) + response = OpenAICompatCompletionResponse( + choices=choices, + ) + return process_chat_completion_response(response, request) + + async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: + params = await self._get_params(request) + model_id = request.model + + # if we shift to TogetherAsyncClient, we won't need this wrapper + async def _to_async_generator(): + s = self._get_client(model_id).generate_text_stream(**params) + for chunk in s: + choice = OpenAICompatCompletionChoice( + finish_reason=None, + text=chunk, + ) + yield OpenAICompatCompletionResponse( + choices=[choice], + ) + + stream = _to_async_generator() + async for chunk in process_chat_completion_stream_response(stream, request): + yield chunk + + async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: + input_dict = {"params": {}} + media_present = request_has_media(request) + llama_model = self.get_llama_model(request.model) + if isinstance(request, ChatCompletionRequest): + input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model) + else: + assert not media_present, "Together does not support media for Completion requests" + input_dict["prompt"] = await completion_request_to_prompt(request) + if request.sampling_params: + if request.sampling_params.strategy: + input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type + if request.sampling_params.max_tokens: + input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens + if request.sampling_params.repetition_penalty: + input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty + if request.sampling_params.additional_params.get("top_p"): + input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"] + if request.sampling_params.additional_params.get("top_k"): + input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"] + if request.sampling_params.additional_params.get("temperature"): + input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"] + if request.sampling_params.additional_params.get("length_penalty"): + input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[ + "length_penalty" + ] + if request.sampling_params.additional_params.get("random_seed"): + input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"] + if request.sampling_params.additional_params.get("min_new_tokens"): + input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[ + "min_new_tokens" + ] + if request.sampling_params.additional_params.get("stop_sequences"): + input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[ + "stop_sequences" + ] + if request.sampling_params.additional_params.get("time_limit"): + input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"] + if request.sampling_params.additional_params.get("truncate_input_tokens"): + input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[ + "truncate_input_tokens" + ] + if request.sampling_params.additional_params.get("return_options"): + input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[ + "return_options" + ] + + params = { + **input_dict, + } + return params + + async def embeddings( + self, + model_id: str, + contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, + ) -> EmbeddingsResponse: + pass diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json index 63c4ecfa5..4c16411f0 100644 --- a/llama_stack/templates/dependencies.json +++ b/llama_stack/templates/dependencies.json @@ -755,5 +755,41 @@ "vllm", "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" + ], + "watsonx": [ + "aiosqlite", + "autoevals", + "blobfile", + "chardet", + "datasets", + "emoji", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "ibm_watson_machine_learning", + "langdetect", + "matplotlib", + "mcp", + "nltk", + "numpy", + "openai", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pymongo", + "pypdf", + "pythainlp", + "redis", + "requests", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "tree_sitter", + "uvicorn" ] } diff --git a/llama_stack/templates/watsonx/__init__.py b/llama_stack/templates/watsonx/__init__.py new file mode 100644 index 000000000..078d86144 --- /dev/null +++ b/llama_stack/templates/watsonx/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .watsonx import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/watsonx/build.yaml b/llama_stack/templates/watsonx/build.yaml new file mode 100644 index 000000000..badd643ad --- /dev/null +++ b/llama_stack/templates/watsonx/build.yaml @@ -0,0 +1,30 @@ +version: '2' +distribution_spec: + description: Use watsonx for running LLM inference + providers: + inference: + - remote::watsonx + vector_io: + - inline::faiss + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference + eval: + - inline::meta-reference + datasetio: + - remote::huggingface + - inline::localfs + scoring: + - inline::basic + - inline::llm-as-judge + - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::rag-runtime + - remote::model-context-protocol +image_type: conda diff --git a/llama_stack/templates/watsonx/doc_template.md b/llama_stack/templates/watsonx/doc_template.md new file mode 100644 index 000000000..af0ae15a8 --- /dev/null +++ b/llama_stack/templates/watsonx/doc_template.md @@ -0,0 +1,74 @@ +--- +orphan: true +--- +# watsonx Distribution + +```{toctree} +:maxdepth: 2 +:hidden: + +self +``` + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. + +{{ providers_table }} + +{% if run_config_env_vars %} + +### Environment Variables + +The following environment variables can be configured: + +{% for var, (default_value, description) in run_config_env_vars.items() %} +- `{{ var }}`: {{ description }} (default: `{{ default_value }}`) +{% endfor %} +{% endif %} + +{% if default_models %} +### Models + +The following models are available by default: + +{% for model in default_models %} +- `{{ model.model_id }} {{ model.doc_string }}` +{% endfor %} +{% endif %} + + +### Prerequisite: API Keys + +Make sure you have access to a watsonx API Key. You can get one by referring [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key). + + +## Running Llama Stack with watsonx + +You can do this via Conda (build code), venv or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-{{ name }} \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env WATSONX_API_KEY=$WATSONX_API_KEY \ + --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \ + --env WATSONX_BASE_URL=$WATSONX_BASE_URL +``` + +### Via Conda + +```bash +llama stack build --template watsonx --image-type conda +llama stack run ./run.yaml \ + --port $LLAMA_STACK_PORT \ + --env WATSONX_API_KEY=$WATSONX_API_KEY \ + --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID +``` diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/templates/watsonx/run.yaml new file mode 100644 index 000000000..1048f7192 --- /dev/null +++ b/llama_stack/templates/watsonx/run.yaml @@ -0,0 +1,210 @@ +version: '2' +image_name: watsonx +apis: +- agents +- datasetio +- eval +- inference +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: watsonx + provider_type: remote::watsonx + config: + url: ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com} + api_key: ${env.WATSONX_API_KEY:} + project_id: ${env.WATSONX_PROJECT_ID:} + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/faiss_store.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:\u200B}" + sinks: ${env.TELEMETRY_SINKS:console,sqlite} + sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/watsonx/trace_store.db} + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/meta_reference_eval.db + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/huggingface_datasetio.db + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/localfs_datasetio.db + scoring: + - provider_id: basic + provider_type: inline::basic + config: {} + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + config: {} + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db +models: +- metadata: {} + model_id: meta-llama/llama-3-3-70b-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-3-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.3-70B-Instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-3-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-2-13b-chat + provider_id: watsonx + provider_model_id: meta-llama/llama-2-13b-chat + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-2-13b + provider_id: watsonx + provider_model_id: meta-llama/llama-2-13b-chat + model_type: llm +- metadata: {} + model_id: meta-llama/llama-3-1-70b-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-1-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-70B-Instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-1-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-3-1-8b-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-1-8b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-1-8b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-3-2-11b-vision-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-3-2-1b-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-1b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-1B-Instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-1b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-3-2-3b-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-3b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-3B-Instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-3b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-3-2-90b-vision-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-guard-3-11b-vision + provider_id: watsonx + provider_model_id: meta-llama/llama-guard-3-11b-vision + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-Guard-3-11B-Vision + provider_id: watsonx + provider_model_id: meta-llama/llama-guard-3-11b-vision + model_type: llm +shields: [] +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter +server: + port: 8321 diff --git a/llama_stack/templates/watsonx/watsonx.py b/llama_stack/templates/watsonx/watsonx.py new file mode 100644 index 000000000..d59bb6f20 --- /dev/null +++ b/llama_stack/templates/watsonx/watsonx.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +from llama_stack.distribution.datatypes import Provider, ToolGroupInput +from llama_stack.providers.remote.inference.watsonx import WatsonXConfig +from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES +from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry + + +def get_distribution_template() -> DistributionTemplate: + providers = { + "inference": ["remote::watsonx"], + "vector_io": ["inline::faiss"], + "safety": ["inline::llama-guard"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + "eval": ["inline::meta-reference"], + "datasetio": ["remote::huggingface", "inline::localfs"], + "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::rag-runtime", + "remote::model-context-protocol", + ], + } + + inference_provider = Provider( + provider_id="watsonx", + provider_type="remote::watsonx", + config=WatsonXConfig.sample_run_config(), + ) + + available_models = { + "watsonx": MODEL_ENTRIES, + } + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] + + default_models = get_model_registry(available_models) + return DistributionTemplate( + name="watsonx", + distro_type="remote_hosted", + description="Use watsonx for running LLM inference", + container_image=None, + template_path=Path(__file__).parent / "doc_template.md", + providers=providers, + available_models_by_provider=available_models, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider], + }, + default_models=default_models, + default_tool_groups=default_tool_groups, + ), + }, + run_config_env_vars={ + "LLAMASTACK_PORT": ( + "5001", + "Port for the Llama Stack distribution server", + ), + "WATSONX_API_KEY": ( + "", + "watsonx API Key", + ), + "WATSONX_PROJECT_ID": ( + "", + "watsonx Project ID", + ), + }, + ) diff --git a/pyproject.toml b/pyproject.toml index 209367c4b..d661f45fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -274,6 +274,7 @@ exclude = [ "^llama_stack/providers/remote/inference/sample/", "^llama_stack/providers/remote/inference/tgi/", "^llama_stack/providers/remote/inference/together/", + "^llama_stack/providers/remote/inference/watsonx/", "^llama_stack/providers/remote/safety/bedrock/", "^llama_stack/providers/remote/safety/nvidia/", "^llama_stack/providers/remote/safety/sample/", From 1deab94ea00109c887a455993bb0746e004a1fb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Fri, 25 Apr 2025 21:16:57 +0200 Subject: [PATCH 28/36] chore: exclude test, provider, and template directories from coverage (#2028) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Introduce a `.coveragerc` file to omit: - test files (*/tests/*) - provider code (*/llama_stack/providers/*) - template files (*/llama_stack/templates/*) - virtual environment (.venv/*) This ensures coverage reports focus on core application logic (API and CLI). Note: I'm opening this for discussing as well - we might decide to ignore more and or re-add some directories! Signed-off-by: SΓ©bastien Han --- .coveragerc | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .coveragerc diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..e16c2e461 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,6 @@ +[run] +omit = + */tests/* + */llama_stack/providers/* + */llama_stack/templates/* + .venv/* From 0e4307de0f4fa531ac382654a082b4bc5ba3b7b1 Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Fri, 25 Apr 2025 20:17:31 +0100 Subject: [PATCH 29/36] docs: Fix missing --gpu all flag in Docker run commands (#2026) adding the --gpu all flag to Docker run commands for meta-reference-gpu distributions ensures models are loaded into GPU instead of CPU. Remove docs for meta-reference-quantized-gpu The distribution was removed in #1887 but these files were left behind. Fixes: #1798 # What does this PR do? Fixes doc to add --gpu all command to docker run [//]: # (If resolving an issue, uncomment and update the line below) Closes #1798 ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] verified in docker documentation but untested --------- Signed-off-by: Derek Higgins --- README.md | 1 - docs/source/distributions/building_distro.md | 2 - .../self_hosted_distro/meta-reference-gpu.md | 2 + .../meta-reference-quantized-gpu.md | 123 ------------------ .../meta-reference-gpu/doc_template.md | 2 + 5 files changed, 4 insertions(+), 126 deletions(-) delete mode 100644 docs/source/distributions/self_hosted_distro/meta-reference-quantized-gpu.md diff --git a/README.md b/README.md index c2e688763..9a4f1a849 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider | **Distribution** | **Llama Stack Docker** | Start This Distribution | |:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:| | Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) | -| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) | | SambaNova | [llamastack/distribution-sambanova](https://hub.docker.com/repository/docker/llamastack/distribution-sambanova/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/sambanova.html) | | Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/cerebras.html) | | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) | diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md index 4c342b14b..56b8d30a8 100644 --- a/docs/source/distributions/building_distro.md +++ b/docs/source/distributions/building_distro.md @@ -109,8 +109,6 @@ llama stack build --list-templates +------------------------------+-----------------------------------------------------------------------------+ | nvidia | Use NVIDIA NIM for running LLM inference | +------------------------------+-----------------------------------------------------------------------------+ -| meta-reference-quantized-gpu | Use Meta Reference with fp8, int4 quantization for running LLM inference | -+------------------------------+-----------------------------------------------------------------------------+ | cerebras | Use Cerebras for running LLM inference | +------------------------------+-----------------------------------------------------------------------------+ | ollama | Use (an external) Ollama server for running LLM inference | diff --git a/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md b/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md index b90f75347..f58d7bbee 100644 --- a/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md +++ b/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md @@ -81,6 +81,7 @@ LLAMA_STACK_PORT=8321 docker run \ -it \ --pull always \ + --gpu all \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ llamastack/distribution-meta-reference-gpu \ @@ -94,6 +95,7 @@ If you are using Llama Stack Safety / Shield APIs, use: docker run \ -it \ --pull always \ + --gpu all \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ llamastack/distribution-meta-reference-gpu \ diff --git a/docs/source/distributions/self_hosted_distro/meta-reference-quantized-gpu.md b/docs/source/distributions/self_hosted_distro/meta-reference-quantized-gpu.md deleted file mode 100644 index c3e2b4f2c..000000000 --- a/docs/source/distributions/self_hosted_distro/meta-reference-quantized-gpu.md +++ /dev/null @@ -1,123 +0,0 @@ ---- -orphan: true ---- - -# Meta Reference Quantized Distribution - -```{toctree} -:maxdepth: 2 -:hidden: - -self -``` - -The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists of the following provider configurations: - -| API | Provider(s) | -|-----|-------------| -| agents | `inline::meta-reference` | -| datasetio | `remote::huggingface`, `inline::localfs` | -| eval | `inline::meta-reference` | -| inference | `inline::meta-reference-quantized` | -| safety | `inline::llama-guard` | -| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | -| telemetry | `inline::meta-reference` | -| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` | -| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | - - -The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc. - -Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs. - -### Environment Variables - -The following environment variables can be configured: - -- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`) -- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`) -- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`) - - -## Prerequisite: Downloading Models - -Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. - -``` -$ llama model list --downloaded -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓ -┃ Model ┃ Size ┃ Modified Time ┃ -┑━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩ -β”‚ Llama3.2-1B-Instruct:int4-qlora-eo8 β”‚ 1.53 GB β”‚ 2025-02-26 11:22:28 β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ Llama3.2-1B β”‚ 2.31 GB β”‚ 2025-02-18 21:48:52 β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ Prompt-Guard-86M β”‚ 0.02 GB β”‚ 2025-02-26 11:29:28 β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ Llama3.2-3B-Instruct:int4-spinquant-eo8 β”‚ 3.69 GB β”‚ 2025-02-26 11:37:41 β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ Llama3.2-3B β”‚ 5.99 GB β”‚ 2025-02-18 21:51:26 β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ Llama3.1-8B β”‚ 14.97 GB β”‚ 2025-02-16 10:36:37 β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ Llama3.2-1B-Instruct:int4-spinquant-eo8 β”‚ 1.51 GB β”‚ 2025-02-26 11:35:02 β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ Llama-Guard-3-1B β”‚ 2.80 GB β”‚ 2025-02-26 11:20:46 β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ Llama-Guard-3-1B:int4 β”‚ 0.43 GB β”‚ 2025-02-26 11:33:33 β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` - -## Running the Distribution - -You can do this via Conda (build code) or Docker which has a pre-built image. - -### Via Docker - -This method allows you to get started quickly without having to build the distribution code. - -```bash -LLAMA_STACK_PORT=8321 -docker run \ - -it \ - --pull always \ - -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - -v ~/.llama:/root/.llama \ - llamastack/distribution-meta-reference-quantized-gpu \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct -``` - -If you are using Llama Stack Safety / Shield APIs, use: - -```bash -docker run \ - -it \ - --pull always \ - -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - -v ~/.llama:/root/.llama \ - llamastack/distribution-meta-reference-quantized-gpu \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ - --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B -``` - -### Via Conda - -Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available. - -```bash -llama stack build --template meta-reference-quantized-gpu --image-type conda -llama stack run distributions/meta-reference-quantized-gpu/run.yaml \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct -``` - -If you are using Llama Stack Safety / Shield APIs, use: - -```bash -llama stack run distributions/meta-reference-quantized-gpu/run-with-safety.yaml \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ - --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B -``` diff --git a/llama_stack/templates/meta-reference-gpu/doc_template.md b/llama_stack/templates/meta-reference-gpu/doc_template.md index a174331b4..2ca6793d7 100644 --- a/llama_stack/templates/meta-reference-gpu/doc_template.md +++ b/llama_stack/templates/meta-reference-gpu/doc_template.md @@ -69,6 +69,7 @@ LLAMA_STACK_PORT=8321 docker run \ -it \ --pull always \ + --gpu all \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ llamastack/distribution-{{ name }} \ @@ -82,6 +83,7 @@ If you are using Llama Stack Safety / Shield APIs, use: docker run \ -it \ --pull always \ + --gpu all \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ llamastack/distribution-{{ name }} \ From 4fb583b4076e245cbd6c9c76546d485652f78563 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 25 Apr 2025 12:23:33 -0700 Subject: [PATCH 30/36] fix: check that llama stack client plain can be used as a subst for OpenAI client (#2032) With https://github.com/meta-llama/llama-stack-client-python/pull/226, now we have llama-stack-client be able to used as a substitute for OpenAI client (duck-typed) so you don't need to change downstream library code. image --- .../inference/test_openai_completion.py | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 75b53100c..46ec03d2e 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -75,19 +75,24 @@ def openai_client(client_with_models): return OpenAI(base_url=base_url, api_key="bar") +@pytest.fixture(params=["openai_client", "llama_stack_client"]) +def compat_client(request): + return request.getfixturevalue(request.param) + + @pytest.mark.parametrize( "test_case", [ "inference:completion:sanity", ], ) -def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case): +def test_openai_completion_non_streaming(llama_stack_client, client_with_models, text_model_id, test_case): skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) tc = TestCase(test_case) # ollama needs more verbose prompting for some reason here... prompt = "Respond to this question and explain your answer. " + tc["content"] - response = openai_client.completions.create( + response = llama_stack_client.completions.create( model=text_model_id, prompt=prompt, stream=False, @@ -103,13 +108,13 @@ def test_openai_completion_non_streaming(openai_client, client_with_models, text "inference:completion:sanity", ], ) -def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case): +def test_openai_completion_streaming(llama_stack_client, client_with_models, text_model_id, test_case): skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) tc = TestCase(test_case) # ollama needs more verbose prompting for some reason here... prompt = "Respond to this question and explain your answer. " + tc["content"] - response = openai_client.completions.create( + response = llama_stack_client.completions.create( model=text_model_id, prompt=prompt, stream=True, @@ -127,11 +132,11 @@ def test_openai_completion_streaming(openai_client, client_with_models, text_mod 0, ], ) -def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs): +def test_openai_completion_prompt_logprobs(llama_stack_client, client_with_models, text_model_id, prompt_logprobs): skip_if_provider_isnt_vllm(client_with_models, text_model_id) prompt = "Hello, world!" - response = openai_client.completions.create( + response = llama_stack_client.completions.create( model=text_model_id, prompt=prompt, stream=False, @@ -144,11 +149,11 @@ def test_openai_completion_prompt_logprobs(openai_client, client_with_models, te assert len(choice.prompt_logprobs) > 0 -def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id): +def test_openai_completion_guided_choice(llama_stack_client, client_with_models, text_model_id): skip_if_provider_isnt_vllm(client_with_models, text_model_id) prompt = "I am feeling really sad today." - response = openai_client.completions.create( + response = llama_stack_client.completions.create( model=text_model_id, prompt=prompt, stream=False, @@ -161,6 +166,9 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text assert choice.text in ["joy", "sadness"] +# Run the chat-completion tests with both the OpenAI client and the LlamaStack client + + @pytest.mark.parametrize( "test_case", [ @@ -168,13 +176,13 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text "inference:chat_completion:non_streaming_02", ], ) -def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case): +def test_openai_chat_completion_non_streaming(compat_client, client_with_models, text_model_id, test_case): skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) tc = TestCase(test_case) question = tc["question"] expected = tc["expected"] - response = openai_client.chat.completions.create( + response = compat_client.chat.completions.create( model=text_model_id, messages=[ { @@ -196,13 +204,13 @@ def test_openai_chat_completion_non_streaming(openai_client, client_with_models, "inference:chat_completion:streaming_02", ], ) -def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case): +def test_openai_chat_completion_streaming(compat_client, client_with_models, text_model_id, test_case): skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) tc = TestCase(test_case) question = tc["question"] expected = tc["expected"] - response = openai_client.chat.completions.create( + response = compat_client.chat.completions.create( model=text_model_id, messages=[{"role": "user", "content": question}], stream=True, From 1b2e116a2ad8f6e8661b951fdbd7d9bf9ec19994 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 25 Apr 2025 13:16:16 -0700 Subject: [PATCH 31/36] fix: tool call encoded twice (#2034) # What does this PR do? ## Test Plan LLAMA_STACK_CONFIG=http://localhost:5002 pytest -s -v tests/integration/inference --safety-shield meta-llama/Llama-Guard-3-8B --vision-model meta-llama/Llama-4-Scout-17B-16E-Instruct --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct --- llama_stack/models/llama/llama4/chat_format.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llama_stack/models/llama/llama4/chat_format.py b/llama_stack/models/llama/llama4/chat_format.py index 1debadcc5..1574eeb5e 100644 --- a/llama_stack/models/llama/llama4/chat_format.py +++ b/llama_stack/models/llama/llama4/chat_format.py @@ -303,6 +303,7 @@ class ChatFormat: arguments_json=json.dumps(tool_arguments), ) ) + content = "" return RawMessage( role="assistant", From b5d8e44e81b2ff09c276308ea4fca8ed5dc94fb5 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 25 Apr 2025 13:15:52 -0700 Subject: [PATCH 32/36] fix: only sleep for tests when they pass or fail --- tests/integration/conftest.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 22290b519..131219e52 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -10,6 +10,7 @@ import platform import textwrap import time +import pytest from dotenv import load_dotenv from llama_stack.log import get_logger @@ -19,10 +20,29 @@ from .report import Report logger = get_logger(__name__, category="tests") +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_makereport(item, call): + outcome = yield + report = outcome.get_result() + if report.when == "call": + item.execution_outcome = report.outcome + item.was_xfail = getattr(report, "wasxfail", False) + + def pytest_runtest_teardown(item): - interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS") - if interval_seconds: - time.sleep(float(interval_seconds)) + # Check if the test actually ran and passed or failed, but was not skipped or an expected failure (xfail) + outcome = getattr(item, "execution_outcome", None) + was_xfail = getattr(item, "was_xfail", False) + + name = item.nodeid + if not any(x in name for x in ("inference/", "safety/", "agents/")): + return + + logger.debug(f"Test '{item.nodeid}' outcome was '{outcome}' (xfail={was_xfail})") + if outcome in ("passed", "failed") and not was_xfail: + interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS") + if interval_seconds: + time.sleep(float(interval_seconds)) def pytest_configure(config): From 8713d67ce3cf383bd615934dedd1da99ff2c905c Mon Sep 17 00:00:00 2001 From: Jash Gulabrai <37194352+JashG@users.noreply.github.com> Date: Fri, 25 Apr 2025 16:21:50 -0400 Subject: [PATCH 33/36] fix: Correctly parse algorithm_config when launching NVIDIA customization job; fix internal request handler (#2025) # What does this PR do? This addresses 2 bugs I ran into when launching a fine-tuning job with the NVIDIA Adapter: 1. Session handling in `_make_request` helper function returns an error. ``` INFO: 127.0.0.1:55831 - "POST /v1/post-training/supervised-fine-tune HTTP/1.1" 500 Internal Server Error 16:11:45.643 [END] /v1/post-training/supervised-fine-tune [StatusCode.OK] (270.44ms) 16:11:45.643 [ERROR] Error executing endpoint route='/v1/post-training/supervised-fine-tune' method='post' Traceback (most recent call last): File "/Users/jgulabrai/Projects/forks/llama-stack/llama_stack/distribution/server/server.py", line 201, in endpoint return await maybe_await(value) File "/Users/jgulabrai/Projects/forks/llama-stack/llama_stack/distribution/server/server.py", line 161, in maybe_await return await value File "/Users/jgulabrai/Projects/forks/llama-stack/llama_stack/providers/remote/post_training/nvidia/post_training.py", line 408, in supervised_fine_tune response = await self._make_request( File "/Users/jgulabrai/Projects/forks/llama-stack/llama_stack/providers/remote/post_training/nvidia/post_training.py", line 98, in _make_request async with self.session.request(method, url, params=params, json=json, **kwargs) as response: File "/Users/jgulabrai/Projects/forks/llama-stack/.venv/lib/python3.10/site-packages/aiohttp/client.py", line 1425, in __aenter__ self._resp: _RetType = await self._coro File "/Users/jgulabrai/Projects/forks/llama-stack/.venv/lib/python3.10/site-packages/aiohttp/client.py", line 579, in _request handle = tm.start() File "/Users/jgulabrai/Projects/forks/llama-stack/.venv/lib/python3.10/site-packages/aiohttp/helpers.py", line 587, in start return self._loop.call_at(when, self.__call__) File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/asyncio/base_events.py", line 724, in call_at self._check_closed() File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/asyncio/base_events.py", line 510, in _check_closed raise RuntimeError('Event loop is closed') RuntimeError: Event loop is closed ``` Note: This only occurred when initializing the client like so: ``` client = LlamaStackClient( base_url="http://0.0.0.0:8321" ) response = client.post_training.supervised_fine_tune(...) # Returns error ``` I didn't run into this issue when using the library client: ``` client = LlamaStackAsLibraryClient("nvidia") client.initialize() response = client.post_training.supervised_fine_tune(...) # Works fine ``` 2. The `algorithm_config` param in `supervised_fine_tune` is parsed as a `dict` when run from unit tests, but a Pydantic model when invoked using the Llama Stack client. So, the call fails outside of unit tests: ``` INFO: 127.0.0.1:54024 - "POST /v1/post-training/supervised-fine-tune HTTP/1.1" 500 Internal Server Error 21:14:02.315 [END] /v1/post-training/supervised-fine-tune [StatusCode.OK] (71.18ms) 21:14:02.314 [ERROR] Error executing endpoint route='/v1/post-training/supervised-fine-tune' method='post' Traceback (most recent call last): File "/Users/jgulabrai/Projects/forks/llama-stack/llama_stack/distribution/server/server.py", line 205, in endpoint return await maybe_await(value) File "/Users/jgulabrai/Projects/forks/llama-stack/llama_stack/distribution/server/server.py", line 164, in maybe_await return await value File "/Users/jgulabrai/Projects/forks/llama-stack/llama_stack/providers/remote/post_training/nvidia/post_training.py", line 407, in supervised_fine_tune "adapter_dim": algorithm_config.get("adapter_dim"), File "/Users/jgulabrai/Projects/forks/llama-stack/.venv/lib/python3.10/site-packages/pydantic/main.py", line 891, in __getattr__ raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') AttributeError: 'LoraFinetuningConfig' object has no attribute 'get' ``` The code assumes `algorithm_config` should be `dict`, so I just handle both cases. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan 1. I ran a local Llama Stack server with the necessary env vars: ``` lama stack run llama_stack/templates/nvidia/run.yaml --port 8321 --env ... ``` And invoked `supervised_fine_tune` to confirm neither of the errors above occur. ``` client = LlamaStackClient( base_url="http://0.0.0.0:8321" ) response = client.post_training.supervised_fine_tune(...) ``` 2. I confirmed the unit tests still pass: `./scripts/unit-tests.sh tests/unit/providers/nvidia/test_supervised_fine_tuning.py` [//]: # (## Documentation) --------- Co-authored-by: Jash Gulabrai --- .../post_training/nvidia/post_training.py | 37 +++++------ .../unit/providers/nvidia/test_parameters.py | 65 ++++++++++--------- .../nvidia/test_supervised_fine_tuning.py | 47 +++++++++----- 3 files changed, 83 insertions(+), 66 deletions(-) diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py index d3de930f7..c74fb2a24 100644 --- a/llama_stack/providers/remote/post_training/nvidia/post_training.py +++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py @@ -67,13 +67,18 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): self.timeout = aiohttp.ClientTimeout(total=config.timeout) # TODO: filter by available models based on /config endpoint ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES) - self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout) - self.customizer_url = config.customizer_url + self.session = None + self.customizer_url = config.customizer_url if not self.customizer_url: warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2) self.customizer_url = "http://nemo.test" + async def _get_session(self) -> aiohttp.ClientSession: + if self.session is None or self.session.closed: + self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout) + return self.session + async def _make_request( self, method: str, @@ -94,8 +99,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): if json and "Content-Type" not in request_headers: request_headers["Content-Type"] = "application/json" + session = await self._get_session() for _ in range(self.config.max_retries): - async with self.session.request(method, url, params=params, json=json, **kwargs) as response: + async with session.request(method, url, params=params, json=json, **kwargs) as response: if response.status >= 400: error_data = await response.json() raise Exception(f"API request failed: {error_data}") @@ -122,8 +128,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): jobs = [] for job in response.get("data", []): job_id = job.pop("id") - job_status = job.pop("status", "unknown").lower() - mapped_status = STATUS_MAPPING.get(job_status, "unknown") + job_status = job.pop("status", "scheduled").lower() + mapped_status = STATUS_MAPPING.get(job_status, "scheduled") # Convert string timestamps to datetime objects created_at = ( @@ -177,7 +183,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): ) api_status = response.pop("status").lower() - mapped_status = STATUS_MAPPING.get(api_status, "unknown") + mapped_status = STATUS_MAPPING.get(api_status, "scheduled") return NvidiaPostTrainingJobStatusResponse( status=JobStatus(mapped_status), @@ -239,6 +245,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): Supported models: - meta/llama-3.1-8b-instruct + - meta/llama-3.2-1b-instruct Supported algorithm configs: - LoRA, SFT @@ -284,10 +291,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): - LoRA config: ## NeMo customizer specific LoRA parameters - - adapter_dim: int - Adapter dimension - Default: 8 (supports powers of 2) - - adapter_dropout: float - Adapter dropout - Default: None (0.0-1.0) - alpha: int - Scaling factor for the LoRA update Default: 16 Note: @@ -297,7 +300,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): User is informed about unsupported parameters via warnings. """ # Map model to nvidia model name - # ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models + # See `_MODEL_ENTRIES` for supported models nvidia_model = self.get_provider_model_id(model) # Check for unsupported method parameters @@ -330,7 +333,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): }, "data_config": {"dataset_id", "batch_size"}, "optimizer_config": {"lr", "weight_decay"}, - "lora_config": {"type", "adapter_dim", "adapter_dropout", "alpha"}, + "lora_config": {"type", "alpha"}, } # Validate all parameters at once @@ -389,16 +392,10 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): # Handle LoRA-specific configuration if algorithm_config: - if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA": + if algorithm_config.type == "LoRA": warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config") job_config["hyperparameters"]["lora"] = { - k: v - for k, v in { - "adapter_dim": algorithm_config.get("adapter_dim"), - "alpha": algorithm_config.get("alpha"), - "adapter_dropout": algorithm_config.get("adapter_dropout"), - }.items() - if v is not None + k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None } else: raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}") diff --git a/tests/unit/providers/nvidia/test_parameters.py b/tests/unit/providers/nvidia/test_parameters.py index cb1b92fba..ea12122a0 100644 --- a/tests/unit/providers/nvidia/test_parameters.py +++ b/tests/unit/providers/nvidia/test_parameters.py @@ -10,14 +10,17 @@ import warnings from unittest.mock import patch import pytest -from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig -from llama_stack_client.types.post_training_supervised_fine_tune_params import ( - TrainingConfig, - TrainingConfigDataConfig, - TrainingConfigEfficiencyConfig, - TrainingConfigOptimizerConfig, -) +from llama_stack.apis.post_training.post_training import ( + DataConfig, + DatasetFormat, + EfficiencyConfig, + LoraFinetuningConfig, + OptimizerConfig, + OptimizerType, + TrainingConfig, +) +from llama_stack.distribution.library_client import convert_pydantic_to_json_value from llama_stack.providers.remote.post_training.nvidia.post_training import ( NvidiaPostTrainingAdapter, NvidiaPostTrainingConfig, @@ -66,11 +69,8 @@ class TestNvidiaParameters(unittest.TestCase): def test_customizer_parameters_passed(self): """Test scenario 1: When an optional parameter is passed and value is correctly set.""" - custom_adapter_dim = 32 # Different from default of 8 algorithm_config = LoraFinetuningConfig( type="LoRA", - adapter_dim=custom_adapter_dim, - adapter_dropout=0.2, apply_lora_to_mlp=True, apply_lora_to_output=True, alpha=16, @@ -78,8 +78,15 @@ class TestNvidiaParameters(unittest.TestCase): lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], ) - data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16) - optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002) + data_config = DataConfig( + dataset_id="test-dataset", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct + ) + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, + lr=0.0002, + weight_decay=0.01, + num_warmup_steps=100, + ) training_config = TrainingConfig( n_epochs=3, data_config=data_config, @@ -95,7 +102,7 @@ class TestNvidiaParameters(unittest.TestCase): model="meta-llama/Llama-3.1-8B-Instruct", checkpoint_dir="", algorithm_config=algorithm_config, - training_config=training_config, + training_config=convert_pydantic_to_json_value(training_config), logger_config={}, hyperparam_search_config={}, ) @@ -114,7 +121,7 @@ class TestNvidiaParameters(unittest.TestCase): self._assert_request_params( { "hyperparameters": { - "lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2, "alpha": 16}, + "lora": {"alpha": 16}, "epochs": 3, "learning_rate": 0.0002, "batch_size": 16, @@ -130,8 +137,6 @@ class TestNvidiaParameters(unittest.TestCase): algorithm_config = LoraFinetuningConfig( type="LoRA", - adapter_dim=16, - adapter_dropout=0.1, apply_lora_to_mlp=True, apply_lora_to_output=True, alpha=16, @@ -139,12 +144,16 @@ class TestNvidiaParameters(unittest.TestCase): lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], ) - data_config = TrainingConfigDataConfig( - dataset_id=required_dataset_id, # Required parameter - batch_size=8, + data_config = DataConfig( + dataset_id=required_dataset_id, batch_size=8, shuffle=False, data_format=DatasetFormat.instruct ) - optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001) + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, + lr=0.0001, + weight_decay=0.01, + num_warmup_steps=100, + ) training_config = TrainingConfig( n_epochs=1, @@ -161,7 +170,7 @@ class TestNvidiaParameters(unittest.TestCase): model=required_model, # Required parameter checkpoint_dir="", algorithm_config=algorithm_config, - training_config=training_config, + training_config=convert_pydantic_to_json_value(training_config), logger_config={}, hyperparam_search_config={}, ) @@ -186,24 +195,24 @@ class TestNvidiaParameters(unittest.TestCase): def test_unsupported_parameters_warning(self): """Test that warnings are raised for unsupported parameters.""" - data_config = TrainingConfigDataConfig( + data_config = DataConfig( dataset_id="test-dataset", batch_size=8, # Unsupported parameters shuffle=True, - data_format="instruct", + data_format=DatasetFormat.instruct, validation_dataset_id="val-dataset", ) - optimizer_config = TrainingConfigOptimizerConfig( + optimizer_config = OptimizerConfig( lr=0.0001, weight_decay=0.01, # Unsupported parameters - optimizer_type="adam", + optimizer_type=OptimizerType.adam, num_warmup_steps=100, ) - efficiency_config = TrainingConfigEfficiencyConfig( + efficiency_config = EfficiencyConfig( enable_activation_checkpointing=True # Unsupported parameter ) @@ -230,15 +239,13 @@ class TestNvidiaParameters(unittest.TestCase): checkpoint_dir="test-dir", # Unsupported parameter algorithm_config=LoraFinetuningConfig( type="LoRA", - adapter_dim=16, - adapter_dropout=0.1, apply_lora_to_mlp=True, apply_lora_to_output=True, alpha=16, rank=16, lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], ), - training_config=training_config, + training_config=convert_pydantic_to_json_value(training_config), logger_config={"test": "value"}, # Unsupported parameter hyperparam_search_config={"test": "value"}, # Unsupported parameter ) diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py index 09f67e4e6..319011be3 100644 --- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py +++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py @@ -10,14 +10,18 @@ import warnings from unittest.mock import patch import pytest -from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig -from llama_stack_client.types.post_training_supervised_fine_tune_params import ( - TrainingConfig, - TrainingConfigDataConfig, - TrainingConfigOptimizerConfig, -) from llama_stack.apis.models import Model, ModelType +from llama_stack.apis.post_training.post_training import ( + DataConfig, + DatasetFormat, + LoraFinetuningConfig, + OptimizerConfig, + OptimizerType, + QATFinetuningConfig, + TrainingConfig, +) +from llama_stack.distribution.library_client import convert_pydantic_to_json_value from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter from llama_stack.providers.remote.post_training.nvidia.post_training import ( ListNvidiaPostTrainingJobs, @@ -121,7 +125,7 @@ class TestNvidiaPostTraining(unittest.TestCase): "batch_size": 16, "epochs": 2, "learning_rate": 0.0001, - "lora": {"adapter_dim": 16, "adapter_dropout": 0.1}, + "lora": {"alpha": 16}, }, "output_model": "default/job-1234", "status": "created", @@ -132,8 +136,6 @@ class TestNvidiaPostTraining(unittest.TestCase): algorithm_config = LoraFinetuningConfig( type="LoRA", - adapter_dim=16, - adapter_dropout=0.1, apply_lora_to_mlp=True, apply_lora_to_output=True, alpha=16, @@ -141,10 +143,15 @@ class TestNvidiaPostTraining(unittest.TestCase): lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], ) - data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16) + data_config = DataConfig( + dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct + ) - optimizer_config = TrainingConfigOptimizerConfig( + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, lr=0.0001, + weight_decay=0.01, + num_warmup_steps=100, ) training_config = TrainingConfig( @@ -161,7 +168,7 @@ class TestNvidiaPostTraining(unittest.TestCase): model="meta-llama/Llama-3.1-8B-Instruct", checkpoint_dir="", algorithm_config=algorithm_config, - training_config=training_config, + training_config=convert_pydantic_to_json_value(training_config), logger_config={}, hyperparam_search_config={}, ) @@ -185,16 +192,22 @@ class TestNvidiaPostTraining(unittest.TestCase): "epochs": 2, "batch_size": 16, "learning_rate": 0.0001, - "lora": {"alpha": 16, "adapter_dim": 16, "adapter_dropout": 0.1}, + "weight_decay": 0.01, + "lora": {"alpha": 16}, }, }, ) def test_supervised_fine_tune_with_qat(self): - algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1) - data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16) - optimizer_config = TrainingConfigOptimizerConfig( + algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1) + data_config = DataConfig( + dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct + ) + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, lr=0.0001, + weight_decay=0.01, + num_warmup_steps=100, ) training_config = TrainingConfig( n_epochs=2, @@ -209,7 +222,7 @@ class TestNvidiaPostTraining(unittest.TestCase): model="meta-llama/Llama-3.1-8B-Instruct", checkpoint_dir="", algorithm_config=algorithm_config, - training_config=training_config, + training_config=convert_pydantic_to_json_value(training_config), logger_config={}, hyperparam_search_config={}, ) From bb1a85c9a0b34f5ecafacf4465dcad62737d0cb2 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 25 Apr 2025 15:23:53 -0700 Subject: [PATCH 34/36] fix: make sure test works equally well against llama stack as a server --- tests/integration/tool_runtime/test_registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/tool_runtime/test_registration.py b/tests/integration/tool_runtime/test_registration.py index e4241d813..b36237d05 100644 --- a/tests/integration/tool_runtime/test_registration.py +++ b/tests/integration/tool_runtime/test_registration.py @@ -114,7 +114,7 @@ def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server): llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id) # Verify it is unregistered - with pytest.raises(ValueError, match=f"Tool group '{test_toolgroup_id}' not found"): + with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"): llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) # Verify tools are also unregistered From 0266b20535c6d3a7e2918161c8e0a7804cc08d44 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 25 Apr 2025 15:52:15 -0700 Subject: [PATCH 35/36] docs: update prompt_format.md for llama4 (#2035) torchrun --nproc_per_node=8 scripts/generate_prompt_format.py meta-llama/Llama-4-Scout-17B-16E-Instruct ~/local/checkpoints// llama_stack.models.llama.llama4.prompts llama_stack/models/llama/llama4/prompt_format.md Co-authored-by: Eric Huang --- .../models/llama/llama4/prompt_format.md | 62 ++++++++++++++----- llama_stack/models/llama/llama4/prompts.py | 40 +++--------- 2 files changed, 54 insertions(+), 48 deletions(-) diff --git a/llama_stack/models/llama/llama4/prompt_format.md b/llama_stack/models/llama/llama4/prompt_format.md index 698571093..350a5517a 100644 --- a/llama_stack/models/llama/llama4/prompt_format.md +++ b/llama_stack/models/llama/llama4/prompt_format.md @@ -64,7 +64,7 @@ This example passes an image that is smaller than the tile size, to show the til ##### Model Response Format ``` -The image depicts a dog standing on a skateboard, with its front paws positioned on the board and its back paws hanging off the back. The dog has a distinctive coat pattern, featuring a white face, brown and black fur, and white paws, and is standing on a skateboard with red wheels, set against a blurred background of a street or alleyway with a teal door and beige wall.<|eot|> +The image depicts a dog standing on a skateboard, positioned centrally and facing the camera directly. The dog has a distinctive coat pattern featuring white, black, and brown fur, with floppy ears and a black nose, and is standing on a skateboard with red wheels.<|eot|> ``` @@ -91,7 +91,7 @@ Here is an example of how to pass an image to the model ##### Model Response Format ``` -This image shows a dog standing on a skateboard, with its front paws positioned near the front of the board and its back paws near the back. The dog has a white, black, and orange coat, and is standing on a gray skateboard with red wheels, in front of a blurred background that appears to be a street or alleyway.<|eot|> +The image depicts a dog standing on a skateboard, with the dog positioned centrally and facing forward. The dog has a distinctive coat featuring a mix of white, brown, and black fur, and is wearing a collar as it stands on the skateboard, which has red wheels.<|eot|> ``` @@ -117,7 +117,7 @@ Here is an example of how to pass an image to the model ##### Model Response Format ``` -The first image shows a dog standing on a skateboard, while the second image shows a plate of spaghetti with tomato sauce, parmesan cheese, and parsley. The two images are unrelated, with the first image featuring a dog and the second image featuring a food dish, and they do not share any common elements or themes.<|eot|> +The first image features a dog standing on a skateboard, while the second image showcases a plate of spaghetti with tomato sauce and cheese. The two images appear to be unrelated, with one depicting a playful scene of a dog on a skateboard and the other presenting a classic Italian dish.<|eom|> ``` @@ -135,13 +135,44 @@ We are continuing the format for zero shot function calling used in previous ver ``` <|begin_of_text|><|header_start|>system<|header_end|> -You are an expert in composing functions. You are given a question and a set of possible functions. -Based on the question, you will need to make one or more function/tool calls to achieve the purpose. -If none of the function can be used, point it out. If the given question lacks the parameters required by the function, -also point it out. You should only return the function call in tools call sections. +You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines: + +1. FUNCTION CALLS: +- ONLY use functions that are EXPLICITLY listed in the function list below +- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information" +- If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information" +- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s) +- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)] +Examples: +CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list +INCORRECT: get_weather(location="New York") +INCORRECT: Let me check the weather: [get_weather(location="New York")] +INCORRECT: [get_events(location="Singapore")] <- If function not in list + +2. RESPONSE RULES: +- For pure function requests matching a listed function: ONLY output the function call(s) +- For knowledge questions: ONLY output text +- For missing parameters: ONLY request the specific missing parameters +- For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call. +- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations +- NEVER combine text and function calls in the same response +- NEVER suggest alternative functions when the requested service is unavailable +- NEVER create or invent new functions not listed below + +3. STRICT BOUNDARIES: +- ONLY use functions from the list below - no exceptions +- NEVER use a function as an alternative to unavailable information +- NEVER call functions not present in the function list +- NEVER add explanatory text to function calls +- NEVER respond with empty brackets +- Use proper Python/JSON syntax for function calls +- Check the function list carefully before responding + +4. TOOL RESPONSE HANDLING: +- When receiving tool responses: provide concise, natural language responses +- Don't repeat tool response verbatim +- Don't add supplementary information -If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] -You SHOULD NOT include any other text in the response. Here is a list of functions in JSON format that you can invoke. @@ -151,9 +182,7 @@ Here is a list of functions in JSON format that you can invoke. "description": "Get weather info for places", "parameters": { "type": "dict", - "required": [ - "city" - ], + "required": ["city"], "properties": { "city": { "type": "string", @@ -167,7 +196,10 @@ Here is a list of functions in JSON format that you can invoke. } } } -<|eot|><|header_start|>user<|header_end|> +] + +You can answer general questions or invoke tools when necessary. +In addition to tool calls, you should also augment your responses by using the tool outputs.<|eot|><|header_start|>user<|header_end|> What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_end|> @@ -176,7 +208,7 @@ What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_e ##### Model Response Format ``` -[get_weather(city='SF'), get_weather(city='Seattle')]<|eot|> +[get_weather(city="San Francisco"), get_weather(city="Seattle")]<|eot|> ``` @@ -273,5 +305,5 @@ Use tools to get latest trending songs<|eot|><|header_start|>assistant<|header_e ##### Model Response Format ``` -{"n": "10"}<|eot|> +{"n": 10}<|eot|> ``` diff --git a/llama_stack/models/llama/llama4/prompts.py b/llama_stack/models/llama/llama4/prompts.py index 13b96359a..fe9a59130 100644 --- a/llama_stack/models/llama/llama4/prompts.py +++ b/llama_stack/models/llama/llama4/prompts.py @@ -9,6 +9,10 @@ from io import BytesIO from pathlib import Path from typing import List +from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( + PythonListCustomToolGenerator, +) + from ..datatypes import RawMediaItem, RawMessage, RawTextItem from ..prompt_format import ( Llama4UseCase, @@ -177,39 +181,9 @@ def usecases(base_model: bool = False) -> List[UseCase | str]: [ RawMessage( role="system", - content="""You are an expert in composing functions. You are given a question and a set of possible functions. -Based on the question, you will need to make one or more function/tool calls to achieve the purpose. -If none of the function can be used, point it out. If the given question lacks the parameters required by the function, -also point it out. You should only return the function call in tools call sections. - -If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] -You SHOULD NOT include any other text in the response. - -Here is a list of functions in JSON format that you can invoke. - -[ - { - "name": "get_weather", - "description": "Get weather info for places", - "parameters": { - "type": "dict", - "required": [ - "city" - ], - "properties": { - "city": { - "type": "string", - "description": "The name of the city to get the weather for" - }, - "metric": { - "type": "string", - "description": "The metric for weather. Options are: celsius, fahrenheit", - "default": "celsius" - } - } - } - } -""", + content=PythonListCustomToolGenerator() + .gen(PythonListCustomToolGenerator().data_examples()[0]) + .render(), ), RawMessage( role="user", From 6cf6791de1772fc44bc2192da0a3241babc8e60c Mon Sep 17 00:00:00 2001 From: Sajikumar JS <35679404+Sajikumarjs@users.noreply.github.com> Date: Sat, 26 Apr 2025 22:47:52 +0530 Subject: [PATCH 36/36] fix: updated watsonx inference chat apis with new repo changes (#2033) # What does this PR do? There are new changes in repo which needs to add some additional functions to the inference which is fixed. Also need one additional params to pass some extra arguments to watsonx.ai [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) --------- Co-authored-by: Sajikumar JS --- .../remote/inference/watsonx/watsonx.py | 182 +++++++++++++++--- 1 file changed, 150 insertions(+), 32 deletions(-) diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index d5d87ec01..fa9cc4391 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -4,10 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator, List, Optional, Union +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from ibm_watson_machine_learning.foundation_models import Model from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams +from openai import AsyncOpenAI from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem from llama_stack.apis.inference import ( @@ -27,10 +28,21 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import ( + GreedySamplingStrategy, + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAICompletion, + OpenAIMessageParam, + OpenAIResponseFormatParam, + TopKSamplingStrategy, + TopPSamplingStrategy, +) from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, + prepare_openai_completion_params, process_chat_completion_response, process_chat_completion_stream_response, process_completion_response, @@ -95,6 +107,14 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): return Model(model_id=model_id, credentials=credentials, project_id=project_id) + def _get_openai_client(self) -> AsyncOpenAI: + if not self._openai_client: + self._openai_client = AsyncOpenAI( + base_url=f"{self._config.url}/openai/v1", + api_key=self._config.api_key, + ) + return self._openai_client + async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse: params = await self._get_params(request) r = self._get_client(request.model).generate(**params) @@ -213,36 +233,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens if request.sampling_params.repetition_penalty: input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty - if request.sampling_params.additional_params.get("top_p"): - input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"] - if request.sampling_params.additional_params.get("top_k"): - input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"] - if request.sampling_params.additional_params.get("temperature"): - input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"] - if request.sampling_params.additional_params.get("length_penalty"): - input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[ - "length_penalty" - ] - if request.sampling_params.additional_params.get("random_seed"): - input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"] - if request.sampling_params.additional_params.get("min_new_tokens"): - input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[ - "min_new_tokens" - ] - if request.sampling_params.additional_params.get("stop_sequences"): - input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[ - "stop_sequences" - ] - if request.sampling_params.additional_params.get("time_limit"): - input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"] - if request.sampling_params.additional_params.get("truncate_input_tokens"): - input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[ - "truncate_input_tokens" - ] - if request.sampling_params.additional_params.get("return_options"): - input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[ - "return_options" - ] + + if isinstance(request.sampling_params.strategy, TopPSamplingStrategy): + input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p + input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature + if isinstance(request.sampling_params.strategy, TopKSamplingStrategy): + input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k + if isinstance(request.sampling_params.strategy, GreedySamplingStrategy): + input_dict["params"][GenParams.TEMPERATURE] = 0.0 + + input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"] params = { **input_dict, @@ -257,4 +257,122 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): output_dimension: Optional[int] = None, task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: - pass + raise NotImplementedError("embedding is not supported for watsonx") + + async def openai_completion( + self, + model: str, + prompt: Union[str, List[str], List[int], List[List[int]]], + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, + ) -> OpenAICompletion: + model_obj = await self.model_store.get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + ) + return await self._get_openai_client().completions.create(**params) # type: ignore + + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[OpenAIResponseFormatParam] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + model_obj = await self.model_store.get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + if params.get("stream", False): + return self._stream_openai_chat_completion(params) + return await self._get_openai_client().chat.completions.create(**params) # type: ignore + + async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator: + # watsonx.ai sometimes adds usage data to the stream + include_usage = False + if params.get("stream_options", None): + include_usage = params["stream_options"].get("include_usage", False) + stream = await self._get_openai_client().chat.completions.create(**params) + + seen_finish_reason = False + async for chunk in stream: + # Final usage chunk with no choices that the user didn't request, so discard + if not include_usage and seen_finish_reason and len(chunk.choices) == 0: + break + yield chunk + for choice in chunk.choices: + if choice.finish_reason: + seen_finish_reason = True + break